Spaces:
Sleeping
Sleeping
from typing import Dict, Any | |
from langchain_openai import ChatOpenAI | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.schema import StrOutputParser | |
from scripts.rag_chat import build_general_qa_chain | |
# Categories the routing LLM is asked to choose from; "general" is the fallback.
_ROUTE_CATEGORIES = ("code", "summarize", "calculate", "general")

# Characters a model commonly wraps around a one-word answer (quotes, markdown
# emphasis, trailing punctuation). Stripped before matching a category.
_CATEGORY_PADDING = " \t\r\n\"'`*.,:;!?()[]{}"

# Prompt that asks the LLM to choose which "mode" to use.
_ROUTER_TEMPLATE = """
You are a routing assistant for a chatbot.
Classify the following user request into one of these categories:
- "code" for programming or debugging
- "summarize" for summary requests
- "calculate" for math or numeric calculations
- "general" for general Q&A using course files
Return ONLY the category word.
User request: {input}
"""


def _normalize_category(raw):
    """Map the routing LLM's raw reply onto one of ``_ROUTE_CATEGORIES``.

    The prompt shows the category names in double quotes, so replies such as
    ``"code"``, ``Code.`` or ``Category: code`` are plausible. Those are
    accepted here instead of silently falling through to the general branch.
    Anything that contains no known category word maps to ``"general"``.
    """
    text = str(raw).strip().lower()
    bare = text.strip(_CATEGORY_PADDING)
    if bare in _ROUTE_CATEGORIES:
        return bare
    # Tolerate short wrappers around the label: take the first known word.
    for word in text.split():
        candidate = word.strip(_CATEGORY_PADDING)
        if candidate in _ROUTE_CATEGORIES:
            return candidate
    return "general"


def build_router_chain(model_name=None):
    """Build a router that classifies a request and dispatches it to a handler.

    Args:
        model_name: Chat model name used for routing and for the prompt-only
            handlers; it is also forwarded to ``build_general_qa_chain``.
            When falsy, ``"gpt-4o-mini"`` is used for the local LLM.

    Returns:
        An object exposing ``invoke(input_dict)``, where ``input_dict`` must
        contain the user request under the ``"input"`` key. For the "code",
        "summarize" and "calculate" categories the return value is
        ``{"result": <text>}``; for everything else the general QA chain's
        result is returned unchanged.
    """
    general_qa = build_general_qa_chain(model_name=model_name)
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)

    def _prompt_chain(template):
        # prompt -> LLM -> plain string
        return ChatPromptTemplate.from_template(template) | llm | StrOutputParser()

    # All prompt-only chains are static, so build them once rather than on
    # every request.
    router_chain = _prompt_chain(_ROUTER_TEMPLATE)
    code_chain = _prompt_chain(
        "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:"
    )
    calculate_chain = _prompt_chain(
        "Solve the following calculation step-by-step:\n{input}"
    )

    def _ask_general(query):
        # NOTE(review): general_qa is presumably a LangChain Chain, whose
        # direct-call form is deprecated in favour of .invoke -- confirm.
        # Fall back to calling it directly if it exposes no .invoke.
        payload = {"query": query}
        runner = getattr(general_qa, "invoke", None)
        return runner(payload) if callable(runner) else general_qa(payload)

    def _summarize(query):
        # Imported lazily so the summarizer is only loaded when it is needed.
        from langchain.docstore.document import Document
        from scripts.summarizer import get_summarizer

        # 1) Retrieve relevant documents via the general QA (RAG) chain.
        rag_result = _ask_general(query)
        source_docs = rag_result.get("source_documents", []) or []

        # 2) If retrieval returned nothing, summarize the user's own text.
        docs = source_docs if source_docs else [Document(page_content=query)]

        # 3) Summarize; accept either a {"output_text": ...} dict or any
        #    other value (stringified).
        out = get_summarizer().invoke(docs)
        if isinstance(out, dict) and "output_text" in out:
            summary = out["output_text"]
        else:
            summary = str(out)

        # 4) Append de-duplicated, sorted sources only when docs were retrieved.
        if source_docs:
            sources = sorted(
                {str(doc.metadata.get("source", "unknown")) for doc in source_docs}
            )
            if sources:
                summary += f"\n\n📚 Sources: {', '.join(sources)}"
        return summary

    class Router:
        """Classifies each request with the LLM and runs the matching handler."""

        def invoke(self, input_dict: Dict[str, Any]):
            """Route ``input_dict["input"]`` and return the handler's result."""
            query = input_dict["input"]
            category = _normalize_category(router_chain.invoke({"input": query}))
            print(f"[ROUTER] User query routed to category: {category}")

            if category == "code":
                return {"result": code_chain.invoke({"input": query})}
            if category == "summarize":
                return {"result": _summarize(query)}
            if category == "calculate":
                return {"result": calculate_chain.invoke({"input": query})}
            # "general" (also the fallback for unrecognised labels)
            return _ask_general(query)

    return Router()