from typing import Dict, Any

from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser

from scripts.rag_chat import build_general_qa_chain


def build_router_chain(model_name=None):
    general_qa = build_general_qa_chain(model_name=model_name)
    llm = ChatOpenAI(model_name=model_name or "gpt-4o-mini", temperature=0.0)

    # This prompt asks the LLM to choose which "mode" to use
    router_prompt = ChatPromptTemplate.from_template("""
You are a routing assistant for a chatbot.
Classify the following user request into one of these categories:
- "code" for programming or debugging
- "summarize" for summary requests
- "calculate" for math or numeric calculations
- "general" for general Q&A using course files

Return ONLY the category word.

User request: {input}
""")

    router_chain = router_prompt | llm | StrOutputParser()

    class Router:
        def invoke(self, input_dict: Dict[str, Any]):
            category = router_chain.invoke({"input": input_dict["input"]}).strip().lower()
            print(f"[ROUTER] User query routed to category: {category}")

            if category == "code":
                prompt = ChatPromptTemplate.from_template(
                    "As a coding assistant, help with this Python question.\nQuestion: {input}\nAnswer:"
                )
                chain = prompt | llm | StrOutputParser()
                return {"result": chain.invoke({"input": input_dict["input"]})}

            elif category == "summarize":
                # 1) Retrieve relevant documents via your existing RAG chain
                rag_result = general_qa({"query": input_dict["input"]})

                # 2) Get the retrieved docs (already LangChain Document objects)
                source_docs = rag_result.get("source_documents", []) or []

                # 3) Build the summarizer and prepare the docs list
                from langchain.docstore.document import Document
                from scripts.summarizer import get_summarizer

                summarizer_chain = get_summarizer()

                # If retrieval returned nothing, fall back to summarizing the user's text
                docs = source_docs if source_docs else [Document(page_content=input_dict["input"])]

                # 4) Summarize: load_summarize_chain returns {"output_text": "..."}
                out = summarizer_chain.invoke(docs)
                summary = out["output_text"] if isinstance(out, dict) and "output_text" in out else str(out)

                # 5) Append sources (only if we actually had retrieved docs)
                if source_docs:
                    sources = sorted({str(d.metadata.get("source", "unknown")) for d in source_docs})
                    if sources:
                        summary += f"\n\n📚 Sources: {', '.join(sources)}"

                return {"result": summary}

            elif category == "calculate":
                prompt = ChatPromptTemplate.from_template(
                    "Solve the following calculation step-by-step:\n{input}"
                )
                chain = prompt | llm | StrOutputParser()
                return {"result": chain.invoke({"input": input_dict["input"]})}

            else:  # "general"
                return general_qa({"query": input_dict["input"]})

    return Router()
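

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of how the router might be driven, assuming OPENAI_API_KEY is
# set and that scripts.rag_chat / scripts.summarizer resolve on the import path.
# The sample query is hypothetical, and the "general" branch is assumed to return
# a dict containing a "result" key, like the other branches.
if __name__ == "__main__":
    router = build_router_chain()
    response = router.invoke({"input": "Summarize the course notes on decision trees."})
    print(response["result"])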