import os
import streamlit as st
from dotenv import load_dotenv
import httpx
from huggingface_hub import InferenceClient
import json

# --- LangChain Imports ---
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List

# --- 1. Load API Keys ---
load_dotenv()
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

# --- App Configuration ---
st.set_page_config(page_title="Synapse AI", page_icon="🧠", layout="wide")

# --- Custom CSS ---
st.markdown("""
""", unsafe_allow_html=True)

# --- Title & Header ---
st.title("🧠 Synapse AI")

# --- Session State Initialization ---
if "messages" not in st.session_state:
    st.session_state.messages = []
if "doc_retriever" not in st.session_state:
    st.session_state.doc_retriever = None
if "web_retriever" not in st.session_state:
    st.session_state.web_retriever = None
if "qa_chain" not in st.session_state:
    st.session_state.qa_chain = None
if "sub_query_chain" not in st.session_state:
    st.session_state.sub_query_chain = None

# --- API Key Validation ---
if not HUGGING_FACE_HUB_TOKEN:
    st.error("HUGGING_FACE_HUB_TOKEN not found! Please add it to your environment secrets.")
    st.stop()
if not TAVILY_API_KEY:
    st.sidebar.warning("TAVILY_API_KEY not found. Web search will be disabled.")
if not MISTRAL_API_KEY:
    st.sidebar.warning("MISTRAL_API_KEY not found. Query generation will be less effective.")
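# Illustrative only: the three keys above are read from a local .env file by
# load_dotenv(), or from the platform's environment secrets when deployed.
# A minimal .env layout might look like the following (placeholder values,
# not real credentials):
#
#   HUGGING_FACE_HUB_TOKEN=hf_xxxxxxxxxxxxxxxx
#   TAVILY_API_KEY=tvly-xxxxxxxxxxxxxxxx
#   MISTRAL_API_KEY=xxxxxxxxxxxxxxxx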
# --- Core Logic ---
def invoke_llm(messages_for_api):
    """Manually invokes the HF Inference Client with a simple message list."""
    client = InferenceClient(token=HUGGING_FACE_HUB_TOKEN)
    response = client.chat_completion(
        messages=messages_for_api,
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        max_tokens=2048,
        temperature=0.1
    )
    return response.choices[0].message.content

def llm_wrapper(prompt_value):
    """Converts LangChain message objects to the dictionary format required by the HF client."""
    messages_for_api = []
    for msg in prompt_value.to_messages():
        role = "user" if msg.type == 'human' else "assistant" if msg.type == 'ai' else "system"
        messages_for_api.append({"role": role, "content": msg.content})
    return invoke_llm(messages_for_api)

# Pydantic models for structured output (Tool Calling)
class SubQuery(BaseModel):
    """A single, targeted search query with its designated datasource."""
    query: str = Field(description="The specific, self-contained search query string.")
    datasource: str = Field(description="The source to search, either 'web' or 'doc'.")

class ResearchPlan(BaseModel):
    """A list of sub-queries to execute for answering a user's question."""
    queries: List[SubQuery] = Field(description="A list of 2-4 targeted sub-queries.")

@st.cache_resource
def create_sub_query_chain():
    """Creates a chain to generate targeted sub-queries using a commercial LLM with tool calling."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an expert at breaking down complex user questions into a series of smaller, targeted search queries.
Based on the user's question and the conversation history, generate a research plan by calling the ResearchPlan tool.
{doc_instruction}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ])
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0, api_key=MISTRAL_API_KEY)
    return prompt | llm.with_structured_output(ResearchPlan)
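# Illustrative only: for a question such as "How does the method in the paper
# compare to standard RAG pipelines?", the sub-query chain above might return
# something shaped like the following (the exact queries depend on the model;
# only the ResearchPlan/SubQuery schema is fixed):
#
#   ResearchPlan(queries=[
#       SubQuery(query="method proposed in the uploaded paper", datasource="doc"),
#       SubQuery(query="standard retrieval-augmented generation pipeline", datasource="web"),
#   ])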
@st.cache_resource
def create_qa_chain():
    """Creates the final question-answering chain with annotation and formatting instructions."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an AI research assistant. Your task is to answer the user's question based on the chat history and the provided context.
Synthesize the information from all sources into a single, cohesive, well-formatted answer.

**Formatting Instructions:**
- Use Markdown for clear formatting (headings, bold text, lists).
- Use LaTeX for all mathematical notation, formulas, and technical symbols by enclosing them in '$' or '$$'. For example, write '$L=12$' for variables.
- Structure your response logically with clear sections where appropriate.

IMPORTANT: You MUST cite the sources you use. The context is provided as a numbered list. At the end of each sentence or claim you make, add the corresponding source number(s) in brackets, like [1] or [2, 3].

CONTEXT:
{context}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    return (
        RunnablePassthrough.assign(context=lambda inputs: inputs["context"])  # Pass context directly
        | prompt
        | RunnableLambda(llm_wrapper)
        | StrOutputParser()
    )

@st.cache_resource
def build_doc_retriever(_uploaded_files):
    """Builds and returns a document retriever from uploaded files."""
    if not _uploaded_files:
        return None
    all_splits = []
    for uploaded_file in _uploaded_files:
        temp_file_path = f"/tmp/{uploaded_file.name}"
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getvalue())
        loader = PyPDFLoader(temp_file_path)
        docs = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        for split in splits:
            split.metadata["filename"] = uploaded_file.name
        all_splits.extend(splits)
        os.remove(temp_file_path)
    embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
    vectorstore = FAISS.from_documents(documents=all_splits, embedding=embedding_model)
    return vectorstore.as_retriever(search_kwargs={"k": 5})
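# Illustrative only: a retriever built by build_doc_retriever can be queried
# directly, e.g.
#
#   retriever = build_doc_retriever(uploaded_files)
#   docs = retriever.invoke("transformer architecture")  # hypothetical query string
#
# which returns up to k=5 LangChain Document objects, each carrying the chunk
# text in .page_content and the originating file name in .metadata["filename"]
# (added during splitting above), which downstream code can use when building
# the numbered source list for citations.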
# --- UI & State Management ---
with st.sidebar:
    st.title("Controls")
    st.write("Upload or manage documents.")
    uploaded_files = st.file_uploader("Upload PDFs", type="pdf", accept_multiple_files=True, key="pdf_uploader_main")
    if st.button("Start New Chat"):
        st.session_state.clear()
        st.rerun()

    if "file_names" not in st.session_state:
        st.session_state.file_names = []
    current_file_names = [f.name for f in uploaded_files]
    if set(current_file_names) != set(st.session_state.file_names):
        st.session_state.file_names = current_file_names
        with st.spinner(f"Processing {len(st.session_state.file_names)} document(s)..."):
            st.session_state.doc_retriever = build_doc_retriever(uploaded_files)
        st.success("Documents processed!")

if "web_retriever" not in st.session_state or st.session_state.web_retriever is None:
    if TAVILY_API_KEY:  # only build the web retriever when a key is available (see warning above)
        st.session_state.web_retriever = TavilySearchAPIRetriever(k=5, tavily_api_key=TAVILY_API_KEY)
if "qa_chain" not in st.session_state or st.session_state.qa_chain is None:
    st.session_state.qa_chain = create_qa_chain()
if "sub_query_chain" not in st.session_state or st.session_state.sub_query_chain is None:
    st.session_state.sub_query_chain = create_sub_query_chain()

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if "sub_queries_html" in message:
            st.markdown(message["sub_queries_html"], unsafe_allow_html=True)
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("Sources", expanded=False):
                st.markdown(message["sources"], unsafe_allow_html=True)

# --- Main Conversational Logic ---
if prompt := st.chat_input("Ask a question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Synapse is thinking..."):
            try:
                chat_history = [
                    HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
                    for m in st.session_state.messages[:-1]
                ]

                # Step 1: Generate a structured research plan with forceful doc instruction
                doc_instruction = ""
                if st.session_state.doc_retriever:
                    doc_instruction = "IMPORTANT: The user has uploaded documents. For any part of the user's question that explicitly refers to 'the paper' or 'the document', you MUST set the 'datasource' for that query to 'doc'."
                research_plan = st.session_state.sub_query_chain.invoke({
                    "chat_history": chat_history,
                    "input": prompt,
                    "doc_instruction": doc_instruction
                })
                sub_queries = research_plan.queries

                sub_queries_html = "