# LawChatbot — lawchatbot/rag_chain.py
# RAG chain setup (LLM, prompt, retrieval chain) for the LawChatbot Panel app.
import os

from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
def initialize_llm() -> "ChatOpenAI":
    """
    Initialize the LLM (ChatOpenAI) for RAG, pointed at the OpenRouter API.

    The API key is read from the ``OPENROUTER_API_KEY`` environment variable
    rather than being hard-coded, so the secret never lives in source control.

    Returns:
        A configured ChatOpenAI client (deepseek-r1 via OpenRouter).

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is not set.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        # Fail fast with a clear message instead of a confusing 401 later.
        raise RuntimeError(
            "OPENROUTER_API_KEY environment variable is not set; "
            "export it before initializing the LLM."
        )
    return ChatOpenAI(
        model="deepseek/deepseek-r1-0528",
        openai_api_key=api_key,
        openai_api_base="https://openrouter.ai/api/v1",
        temperature=0,  # deterministic answers for legal Q&A
        max_tokens=8192,  # completion cap; raise if answers get truncated
    )
# Prompt for the "stuff" QA chain. It declares two input variables that the
# chain fills in at query time:
#   {context}  - the retrieved documents, concatenated by the chain
#   {question} - the user's question
# The instructions ask the model to cite sources inline, which pairs with
# return_source_documents=True in build_rag_chain.
RAG_PROMPT = PromptTemplate.from_template(
    """You are an expert legal assistant. Use the provided context to answer the user question at the end.
If you use any document, cite it in the format: [source_name] with metadata (e.g., URL or case_id).
Be accurate, concise, and include citations for facts.
Context:
{context}
Question:
{question}
Answer (with citations):"""
)
def build_rag_chain(llm: ChatOpenAI, retriever) -> RetrievalQAWithSourcesChain:
    """
    Assemble a "stuff"-type RetrievalQAWithSourcesChain from an LLM and retriever.

    The custom RAG_PROMPT is injected through chain_type_kwargs, and the
    retrieved documents are returned alongside the answer so callers can
    display citations.
    """
    stuff_chain_kwargs = {
        "prompt": RAG_PROMPT,
        # RAG_PROMPT's placeholder is {context}, so tell the chain to use
        # that name for the stuffed documents.
        "document_variable_name": "context",
    }
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs=stuff_chain_kwargs,
        return_source_documents=True,  # expose retrieved docs for citation display
    )
def run_rag_query(
    rag_chain: "RetrievalQAWithSourcesChain",
    query: str,
    show_sources: bool = True,
) -> str:
    """
    Run a RAG query and return the answer string.

    Args:
        rag_chain: The retrieval-QA chain to invoke (from build_rag_chain).
        query: The user's question.
        show_sources: If True, print the metadata of the retrieved source
            documents so users can verify the citations in the answer.
            (Previously this flag was accepted but silently ignored.)

    Returns:
        The model's answer text.
    """
    response = rag_chain.invoke({"question": query})
    answer = response["answer"]
    documents = response.get("source_documents", [])
    if show_sources and documents:
        # Surface the retrieved context alongside the answer.
        print("Sources:")
        for doc in documents:
            # Documents from langchain carry a .metadata dict (URL, case_id, ...).
            print(f"- {getattr(doc, 'metadata', {})}")
    return answer