# gabrielaltay's picture
# simplify
# eeef8f5
"""RAG (Retrieval-Augmented Generation) chain implementation"""
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from legisqa_local.core.llm import get_llm
from legisqa_local.core.vectorstore import get_vectorstore, get_vectorstore_filter
from legisqa_local.utils.formatting import format_docs
def create_rag_chain(llm, retriever):
    """Assemble the retrieval-augmented generation pipeline.

    The returned runnable is invoked with the user's query. It fans the
    query out to ``retriever`` (stored under ``"docs"``) and to a
    passthrough (stored under ``"query"``), then adds ``"context"`` (the
    retrieved docs rendered by ``format_docs``) and finally
    ``"aimessage"`` (the output of the prompt piped into ``llm``).

    Args:
        llm: Model (or other runnable) that the filled-in prompt is piped into.
        retriever: Runnable that maps the query to the retrieved documents.

    Returns:
        The composed runnable chain.
    """
    # Single human-turn prompt; {context} and {query} are filled from the
    # keys accumulated by the pipeline stages below.
    rag_template = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user.
---
Congressional Legislation Excerpts:
{context}
---
Query: {query}"""
    chat_prompt = ChatPromptTemplate.from_messages([("human", rag_template)])

    # Stage 1: run retrieval and keep the raw query side by side.
    retrieve_stage = RunnableParallel(
        {"docs": retriever, "query": RunnablePassthrough()}
    )
    # Stage 2: render the retrieved documents into the prompt's context slot.
    context_stage = retrieve_stage.assign(
        context=lambda state: format_docs(state["docs"])
    )
    # Stage 3: generate the answer while keeping every earlier key in the output.
    return context_stage.assign(aimessage=chat_prompt | llm)
def process_query(gen_config: dict, ret_config: dict, query: str):
    """Run ``query`` through a freshly built RAG chain and return its output.

    Args:
        gen_config: Generation settings forwarded to ``get_llm``.
        ret_config: Retrieval settings; ``"n_ret_docs"`` is read as the
            number of documents to retrieve, and the whole dict is passed
            to ``get_vectorstore_filter``.
        query: The user's question.

    Returns:
        The result of invoking the chain, or, when ``get_vectorstore()``
        returns ``None``, a placeholder dict with ``"aimessage"``,
        ``"docs"`` and ``"query"`` keys asking the user to retry.
    """
    store = get_vectorstore()
    if store is None:
        # Nothing to search yet: answer with a retry notice instead of failing.
        return {
            "aimessage": "⏳ Vectorstore is still loading. Please wait a moment and try again.",
            "docs": [],
            "query": query,
        }

    chat_model = get_llm(gen_config)
    metadata_filter = get_vectorstore_filter(ret_config)

    # ChromaDB takes its metadata filter via the 'filter' entry of
    # search_kwargs; only include it when a filter was actually produced.
    retrieval_kwargs = {"k": ret_config["n_ret_docs"]}
    if metadata_filter:
        retrieval_kwargs["filter"] = metadata_filter

    return create_rag_chain(
        chat_model,
        store.as_retriever(search_kwargs=retrieval_kwargs),
    ).invoke(query)