"""RAG (Retrieval-Augmented Generation) chain implementation"""

from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough

from legisqa_local.core.llm import get_llm
from legisqa_local.core.vectorstore import get_vectorstore, get_vectorstore_filter
from legisqa_local.utils.formatting import format_docs


def create_rag_chain(llm, retriever):
    """Create a RAG chain from the given LLM and retriever.

    The returned chain takes a query string and produces a dict with
    "docs", "query", "context", and "aimessage" keys.
    """

    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user.

---

Congressional Legislation Excerpts:

{context}

---

Query: {query}"""

    prompt = ChatPromptTemplate.from_messages([
        ("human", QUERY_RAG_TEMPLATE),
    ])

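    # How the chain is assembled:
    #   1. RunnableParallel fans the incoming query out: the retriever fetches
    #      matching documents while RunnablePassthrough forwards the raw query.
    #   2. .assign(context=...) serializes the retrieved docs (via format_docs)
    #      into the JSON structure described in QUERY_RAG_TEMPLATE.
    #   3. .assign(aimessage=...) runs prompt | llm over the accumulated dict
    #      and stores the model reply under "aimessage".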
    rag_chain = (
        RunnableParallel({
            "docs": retriever,
            "query": RunnablePassthrough(),
        })
        .assign(context=lambda x: format_docs(x["docs"]))
        .assign(aimessage=prompt | llm)
    )

    return rag_chain
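
# Because the chain is built from LCEL runnables, the result of
# create_rag_chain also supports rag_chain.stream(query) and
# rag_chain.batch([...]) in addition to invoke().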


def process_query(gen_config: dict, ret_config: dict, query: str):
    """Run a query through the RAG chain and return its response dict."""
    # Check whether the vectorstore has finished loading.
    vectorstore = get_vectorstore()
    if vectorstore is None:
        # Mirror the chain's output shape (an AIMessage under "aimessage")
        # so callers can handle both paths uniformly.
        return {
            "aimessage": AIMessage(
                content="⏳ Vectorstore is still loading. Please wait a moment and try again."
            ),
            "docs": [],
            "context": "",
            "query": query,
        }

    llm = get_llm(gen_config)
    vs_filter = get_vectorstore_filter(ret_config)
    
    # ChromaDB uses 'filter' parameter in search_kwargs
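    # A hedged example of such a filter (the "congress_num" metadata key is
    # illustrative; the $in operator is standard Chroma filter syntax):
    #   {"congress_num": {"$in": [117, 118]}}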
    search_kwargs = {"k": ret_config["n_ret_docs"]}
    if vs_filter:
        search_kwargs["filter"] = vs_filter
        
    retriever = vectorstore.as_retriever(search_kwargs=search_kwargs)
    rag_chain = create_rag_chain(llm, retriever)
    response = rag_chain.invoke(query)
    return response
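

# ---------------------------------------------------------------------------
# Minimal usage sketch. "n_ret_docs" is the only config key this module reads
# directly; the gen_config key below is an assumption for illustration, so
# check get_llm and get_vectorstore_filter for the schemas they actually use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    gen_config = {"temperature": 0.0}  # hypothetical key; get_llm defines the real schema
    ret_config = {"n_ret_docs": 5}     # read by process_query above

    response = process_query(
        gen_config, ret_config, "Which bills address rural broadband access?"
    )
    print(response["aimessage"].content)
    for doc in response["docs"]:
        # "legis_id" and "title" match the prompt template's fields, but their
        # presence in doc.metadata is an assumption about the ingest pipeline.
        print(doc.metadata.get("legis_id"), doc.metadata.get("title"))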