# navilaw-ai / main.py
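"""FastAPI service for legal research: advisory, case-outcome prediction, and
report generation over user-uploaded PDFs, backed by a Groq-hosted Llama model
with LangChain retrieval and LangGraph ReAct agents."""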
import os
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from langchain_groq import ChatGroq
from PyPDF2 import PdfReader
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langgraph.prebuilt import create_react_agent
from retrieval import create_retriever
from templates import advisor_template, predictor_template, generator_template
from langchain.tools.retriever import create_retriever_tool
from tools import tavily_tool
from dotenv import load_dotenv
from typing import List
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")
chat = ChatGroq(model="llama-3.3-70b-versatile", api_key=groq_api_key)
app = FastAPI()
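# Wildcard CORS: accept requests from any origin, with any method and headers.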
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def read_root():
return {"message": "Welcome to the Legal Research API! Please use one of the endpoints for requests."}
def process_files(files: List[UploadFile]):
    """Extract text from each uploaded PDF and split it into overlapping chunks."""
    if not files:
        raise HTTPException(status_code=400, detail="Please upload at least one PDF file.")
    docs = []
    for uploaded_file in files:
        reader = PdfReader(uploaded_file.file)
        text = ""
        for page in reader.pages:
            # extract_text() can return None for pages with no extractable text.
            text += page.extract_text() or ""
        docs.append(Document(page_content=text, metadata={"source": uploaded_file.filename}))
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    pdf_content = text_splitter.split_documents(docs)
    return pdf_content
def setup_retriever(pdf_content):
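    """Build a retriever over the PDF chunks and wrap it as an agent tool."""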
retriever = create_retriever(pdf_content)
retrieval_tool = create_retriever_tool(
retriever,
"Pdf_content_retriever",
"Searches and returns excerpts from the set of PDF docs.",
)
return retriever, retrieval_tool
def setup_agents(tools):
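    """Create ReAct agents for legal advisory and case-outcome prediction from the shared tool set."""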
advisor_graph = create_react_agent(chat, tools=tools, state_modifier=advisor_template)
predictor_graph = create_react_agent(chat, tools=tools, state_modifier=predictor_template)
return advisor_graph, predictor_graph
@app.post("/legal-assistance/")
async def legal_assistance(
query: str = Form(...),
option: str = Form(...),
files: List[UploadFile] = File(...)
):
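    """Dispatch a query over the uploaded PDFs to the advisor agent, predictor agent, or report chain, depending on the selected option."""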
if not query:
raise HTTPException(status_code=400, detail="Please enter a query.")
pdf_content = process_files(files)
retriever, retrieval_tool = setup_retriever(pdf_content)
tools = [tavily_tool, retrieval_tool]
advisor_graph, predictor_graph = setup_agents(tools)
inputs = {"messages": [("human", query)]}
if option == "Legal Advisory":
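        # stream_mode="values" emits the full agent state after each step; the last chunk holds the final answer.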
async for chunk in advisor_graph.astream(inputs, stream_mode="values"):
final_result = chunk
result = final_result["messages"][-1].content
return {"result": result}
elif option == "Legal Report Generation":
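        # RAG chain: retrieve context in parallel with the raw query, then prompt, model, and parse.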
set_ret = RunnableParallel({"context": retriever, "query": RunnablePassthrough()})
rag_chain = set_ret | generator_template | chat | StrOutputParser()
report = rag_chain.invoke(query)
return {"report": report}
elif option == "Case Outcome Prediction":
async for chunk in predictor_graph.astream(inputs, stream_mode="values"):
final_prediction = chunk
        prediction = final_prediction["messages"][-1].content
return {"prediction": prediction}
else:
raise HTTPException(status_code=400, detail="Invalid option selected.")
@app.post("/legal-advisory/")
async def legal_advisory_endpoint(
query: str = Form(...),
files: List[UploadFile] = File(...)
):
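    """Run the advisory ReAct agent over the query and uploaded PDFs."""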
if not query:
raise HTTPException(status_code=400, detail="Please enter a query.")
pdf_content = process_files(files)
retriever, retrieval_tool = setup_retriever(pdf_content)
tools = [tavily_tool, retrieval_tool]
advisor_graph, _ = setup_agents(tools)
inputs = {"messages": [("human", query)]}
async for chunk in advisor_graph.astream(inputs, stream_mode="values"):
final_result = chunk
result = final_result["messages"][-1].content
return {"result": result}
@app.post("/case-outcome-prediction/")
async def case_outcome_prediction_endpoint(
query: str = Form(...),
files: List[UploadFile] = File(...)
):
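    """Run the outcome-prediction ReAct agent over the query and uploaded PDFs."""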
if not query:
raise HTTPException(status_code=400, detail="Please enter a query.")
pdf_content = process_files(files)
retriever, retrieval_tool = setup_retriever(pdf_content)
tools = [tavily_tool, retrieval_tool]
_, predictor_graph = setup_agents(tools)
inputs = {"messages": [("human", query)]}
async for chunk in predictor_graph.astream(inputs, stream_mode="values"):
final_prediction = chunk
prediction = final_prediction["messages"][-1].content
return {"prediction": prediction}
@app.post("/report-generator/")
async def report_generator_endpoint(
query: str = Form(...),
files: List[UploadFile] = File(...)
):
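    """Generate a legal report via the RAG chain over the uploaded PDFs."""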
if not query:
raise HTTPException(status_code=400, detail="Please enter a query.")
pdf_content = process_files(files)
retriever, _ = setup_retriever(pdf_content)
set_ret = RunnableParallel({"context": retriever, "query": RunnablePassthrough()})
rag_chain = set_ret | generator_template | chat | StrOutputParser()
report = rag_chain.invoke(query)
return {"report": report}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=10000)
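
# Example request against the advisory endpoint (assuming a local run on
# port 10000; contract.pdf is a placeholder file name):
#   curl -X POST http://localhost:10000/legal-advisory/ \
#     -F "query=Summarize the key obligations in this contract" \
#     -F "files=@contract.pdf"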