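"""FastAPI backend for a Legal Research API.

Exposes endpoints for legal advisory, case outcome prediction, and legal
report generation over user-uploaded PDF documents, backed by a Groq-hosted
Llama model, LangGraph ReAct agents, and a Tavily web-search tool.
"""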
import os
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PyPDF2 import PdfReader
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_groq import ChatGroq
from langgraph.prebuilt import create_react_agent

from retrieval import create_retriever
from templates import advisor_template, predictor_template, generator_template
from tools import tavily_tool

load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
chat = ChatGroq(model="llama-3.3-70b-versatile", api_key=groq_api_key)
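# Expected variables in the local .env file (a sketch; only GROQ_API_KEY is
# read explicitly above, and TAVILY_API_KEY is an assumption, since the
# Tavily search tool conventionally reads its key from the environment):
#
#   GROQ_API_KEY=...
#   TAVILY_API_KEY=...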
app = FastAPI()

# wide-open CORS policy; tighten allow_origins for production deployments
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def read_root():
    return {"message": "Welcome to the Legal Research API! Please use one of the endpoints for requests."}
def process_files(files: List[UploadFile]):
    """Extract text from the uploaded PDFs and split it into retrievable chunks."""
    if not files:
        raise HTTPException(status_code=400, detail="Please upload at least one PDF file.")
    docs = []
    for uploaded_file in files:
        reader = PdfReader(uploaded_file.file)
        text = ""
        for page in reader.pages:
            # extract_text() can return None for image-only pages, so guard against it
            text += page.extract_text() or ""
        docs.append(Document(page_content=text, metadata={"source": uploaded_file.filename}))
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    pdf_content = text_splitter.split_documents(docs)
    return pdf_content
def setup_retriever(pdf_content):
    retriever = create_retriever(pdf_content)
    retrieval_tool = create_retriever_tool(
        retriever,
        "Pdf_content_retriever",
        "Searches and returns excerpts from the set of PDF docs.",
    )
    return retriever, retrieval_tool
def setup_agents(tools):
    """Create the advisor and predictor ReAct agents, each with its own system prompt."""
    # state_modifier injects the role-specific prompt template into each agent
    advisor_graph = create_react_agent(chat, tools=tools, state_modifier=advisor_template)
    predictor_graph = create_react_agent(chat, tools=tools, state_modifier=predictor_template)
    return advisor_graph, predictor_graph
@app.post("/legal_assistance")  # route path assumed from the handler name
async def legal_assistance(
    query: str = Form(...),
    option: str = Form(...),
    files: List[UploadFile] = File(...)
):
    if not query:
        raise HTTPException(status_code=400, detail="Please enter a query.")
    pdf_content = process_files(files)
    retriever, retrieval_tool = setup_retriever(pdf_content)
    tools = [tavily_tool, retrieval_tool]
    advisor_graph, predictor_graph = setup_agents(tools)
    inputs = {"messages": [("human", query)]}
    if option == "Legal Advisory":
        # stream the agent run and keep the final state; its last message is the answer
        async for chunk in advisor_graph.astream(inputs, stream_mode="values"):
            final_result = chunk
        result = final_result["messages"][-1].content
        return {"result": result}
    elif option == "Legal Report Generation":
        set_ret = RunnableParallel({"context": retriever, "query": RunnablePassthrough()})
        rag_chain = set_ret | generator_template | chat | StrOutputParser()
        report = rag_chain.invoke(query)
        return {"report": report}
    elif option == "Case Outcome Prediction":
        async for chunk in predictor_graph.astream(inputs, stream_mode="values"):
            final_prediction = chunk
        # return the message text, not the message object, so the response is serializable
        prediction = final_prediction["messages"][-1].content
        return {"prediction": prediction}
    else:
        raise HTTPException(status_code=400, detail="Invalid option selected.")
@app.post("/legal_advisory")  # route path assumed from the handler name
async def legal_advisory_endpoint(
    query: str = Form(...),
    files: List[UploadFile] = File(...)
):
    if not query:
        raise HTTPException(status_code=400, detail="Please enter a query.")
    pdf_content = process_files(files)
    retriever, retrieval_tool = setup_retriever(pdf_content)
    tools = [tavily_tool, retrieval_tool]
    advisor_graph, _ = setup_agents(tools)
    inputs = {"messages": [("human", query)]}
    async for chunk in advisor_graph.astream(inputs, stream_mode="values"):
        final_result = chunk
    result = final_result["messages"][-1].content
    return {"result": result}
@app.post("/case_outcome_prediction")  # route path assumed from the handler name
async def case_outcome_prediction_endpoint(
    query: str = Form(...),
    files: List[UploadFile] = File(...)
):
    if not query:
        raise HTTPException(status_code=400, detail="Please enter a query.")
    pdf_content = process_files(files)
    retriever, retrieval_tool = setup_retriever(pdf_content)
    tools = [tavily_tool, retrieval_tool]
    _, predictor_graph = setup_agents(tools)
    inputs = {"messages": [("human", query)]}
    async for chunk in predictor_graph.astream(inputs, stream_mode="values"):
        final_prediction = chunk
    prediction = final_prediction["messages"][-1].content
    return {"prediction": prediction}
@app.post("/report_generation")  # route path assumed from the handler name
async def report_generator_endpoint(
    query: str = Form(...),
    files: List[UploadFile] = File(...)
):
    if not query:
        raise HTTPException(status_code=400, detail="Please enter a query.")
    pdf_content = process_files(files)
    retriever, _ = setup_retriever(pdf_content)
    # plain RAG chain: retrieve context in parallel with passing the query through,
    # then fill the report template and parse the LLM output to a string
    set_ret = RunnableParallel({"context": retriever, "query": RunnablePassthrough()})
    rag_chain = set_ret | generator_template | chat | StrOutputParser()
    report = rag_chain.invoke(query)
    return {"report": report}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=10000)
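# Example requests (a sketch; the route paths above are assumptions inferred
# from the handler names):
#
#   curl -X POST http://localhost:10000/legal_advisory \
#        -F "query=What are my options under this contract?" \
#        -F "files=@contract.pdf"
#
#   curl -X POST http://localhost:10000/legal_assistance \
#        -F "query=Summarize the dispute" \
#        -F "option=Legal Report Generation" \
#        -F "files=@case.pdf"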