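"""Synapse AI: a Streamlit research assistant that answers questions over uploaded PDFs
and live web results, with numbered source citations.

Pipeline as implemented below: a Mistral chat model turns each question into a small
research plan of sub-queries, each routed to either the uploaded-document retriever
(FAISS over PDF chunks) or the Tavily web retriever; the retrieved passages are passed
as numbered context to Mixtral-8x7B-Instruct via the Hugging Face Inference API, which
produces the cited answer."""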
import os
import streamlit as st
from dotenv import load_dotenv
import httpx
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError
import json
# --- LangChain Imports ---
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_mistralai import ChatMistralAI
from pydantic import BaseModel, Field
from typing import List
# --- 1. Load API Keys ---
load_dotenv()
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

# --- App Configuration ---
st.set_page_config(page_title="Synapse AI", page_icon="🧠", layout="wide")

# --- Custom CSS ---
st.markdown("""
<style>
    .stApp { background-color: #1E1E1E; color: #E0E0E0; }
    [data-testid="stChatMessage"] { background-color: #2B2B2B; border-radius: 10px; padding: 1rem; border: 1px solid #333; }
    [data-testid="stChatInput"] { background-color: #2B2B2B; border-top: 1px solid #333; }
    [data-testid="stSidebar"] { background-color: #1A1A1A; border-right: 1px solid #333; }
    .st-expander, .st-expander header { background-color: #2B2B2B !important; color: #E0E0E0 !important; border-radius: 10px; border: 1px solid #333; }
    .st-expander header:hover { background-color: #333 !important; }
    .stButton>button { background-color: #4CAF50; color: white; border-radius: 8px; border: none; }
    .stAlert { border-radius: 8px; }
    .search-query-display {
        background-color: #2B2B2B;
        border: 1px solid #444;
        padding: 0.5rem 1rem;
        border-radius: 8px;
        margin-bottom: 1rem;
        font-family: monospace;
        color: #A0A0A0;
    }
</style>
""", unsafe_allow_html=True)
# --- Title & Header ---
st.title("🧠 Synapse AI")

# --- Session State Initialization ---
if "messages" not in st.session_state:
    st.session_state.messages = []
if "doc_retriever" not in st.session_state:
    st.session_state.doc_retriever = None
if "web_retriever" not in st.session_state:
    st.session_state.web_retriever = None
if "qa_chain" not in st.session_state:
    st.session_state.qa_chain = None
if "sub_query_chain" not in st.session_state:
    st.session_state.sub_query_chain = None
# --- API Key Validation ---
if not HUGGING_FACE_HUB_TOKEN:
    st.error("HUGGING_FACE_HUB_TOKEN not found! Please add it to your environment secrets.")
    st.stop()
if not TAVILY_API_KEY:
    st.sidebar.warning("TAVILY_API_KEY not found. Web search will be disabled.")
if not MISTRAL_API_KEY:
    st.sidebar.warning("MISTRAL_API_KEY not found. Research-plan generation requires it and will fail without it.")
# --- Core Logic ---
def invoke_llm(messages_for_api):
    """Manually invokes the HF Inference Client with a simple message list."""
    client = InferenceClient(token=HUGGING_FACE_HUB_TOKEN)
    response = client.chat_completion(
        messages=messages_for_api,
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        max_tokens=2048,
        temperature=0.1
    )
    return response.choices[0].message.content

def llm_wrapper(prompt_value):
    """Converts LangChain message objects to the dictionary format required by the HF client."""
    messages_for_api = []
    for msg in prompt_value.to_messages():
        role = "user" if msg.type == 'human' else "assistant" if msg.type == 'ai' else "system"
        messages_for_api.append({"role": role, "content": msg.content})
    return invoke_llm(messages_for_api)
# Pydantic models for structured output (Tool Calling)
class SubQuery(BaseModel):
    """A single, targeted search query with its designated datasource."""
    query: str = Field(description="The specific, self-contained search query string.")
    datasource: str = Field(description="The source to search, either 'web' or 'doc'.")

class ResearchPlan(BaseModel):
    """A list of sub-queries to execute for answering a user's question."""
    queries: List[SubQuery] = Field(description="A list of 2-4 targeted sub-queries.")
def create_sub_query_chain():
    """Creates a chain to generate targeted sub-queries using a commercial LLM with tool calling."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an expert at breaking down complex user questions into a series of smaller, targeted search queries.
Based on the user's question and the conversation history, generate a research plan by calling the ResearchPlan tool.
{doc_instruction}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ])
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0, api_key=MISTRAL_API_KEY)
    return prompt | llm.with_structured_output(ResearchPlan)
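
# The sub-query chain returns a ResearchPlan object; its .queries list is what the main
# conversational logic below iterates over to route each search to 'doc' or 'web'.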
def create_qa_chain():
    """Creates the final question-answering chain with annotation and formatting instructions."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an AI research assistant. Your task is to answer the user's question based on the chat history and the provided context.
Synthesize the information from all sources into a single, cohesive, well-formatted answer.
**Formatting Instructions:**
- Use Markdown for clear formatting (headings, bold text, lists).
- Use LaTeX for all mathematical notation, formulas, and technical symbols by enclosing them in '$' or '$$'. For example, write '$L=12$' for variables.
- Structure your response logically with clear sections where appropriate.
IMPORTANT: You MUST cite the sources you use. The context is provided as a numbered list. At the end of each sentence or claim you make, add the corresponding source number(s) in brackets, like [1] or [2, 3].
CONTEXT:
{context}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    return (
        RunnablePassthrough.assign(context=lambda inputs: inputs["context"])  # Pass context directly
        | prompt
        | RunnableLambda(llm_wrapper)
        | StrOutputParser()
    )
def build_doc_retriever(_uploaded_files):
    """Builds and returns a document retriever from uploaded files."""
    if not _uploaded_files:
        return None
    all_splits = []
    for uploaded_file in _uploaded_files:
        temp_file_path = f"/tmp/{uploaded_file.name}"
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getvalue())
        loader = PyPDFLoader(temp_file_path)
        docs = loader.load()
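        # Split each PDF into overlapping chunks so every embedded passage stays small enough
        # to retrieve precisely, while the overlap preserves context across chunk boundaries.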
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        for split in splits:
            split.metadata["filename"] = uploaded_file.name
        all_splits.extend(splits)
        os.remove(temp_file_path)
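    # Embed every chunk with a locally run sentence-transformers model and index the vectors
    # in an in-memory FAISS store; the retriever returns the top-5 matches per query.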
    embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
    vectorstore = FAISS.from_documents(documents=all_splits, embedding=embedding_model)
    return vectorstore.as_retriever(search_kwargs={"k": 5})
# --- UI & State Management ---
with st.sidebar:
    st.title("Controls")
    st.write("Upload or manage documents.")
    uploaded_files = st.file_uploader("Upload PDFs", type="pdf", accept_multiple_files=True, key="pdf_uploader_main")
    if st.button("Start New Chat"):
        st.session_state.clear()
        st.rerun()
    if "file_names" not in st.session_state:
        st.session_state.file_names = []
    current_file_names = [f.name for f in (uploaded_files or [])]
    if set(current_file_names) != set(st.session_state.file_names):
        st.session_state.file_names = current_file_names
        with st.spinner(f"Processing {len(st.session_state.file_names)} document(s)..."):
            st.session_state.doc_retriever = build_doc_retriever(uploaded_files)
        st.success("Documents processed!")
# Only build the web retriever when a Tavily key is configured (matches the sidebar warning above).
if ("web_retriever" not in st.session_state or st.session_state.web_retriever is None) and TAVILY_API_KEY:
    st.session_state.web_retriever = TavilySearchAPIRetriever(k=5, tavily_api_key=TAVILY_API_KEY)
if "qa_chain" not in st.session_state or st.session_state.qa_chain is None:
    st.session_state.qa_chain = create_qa_chain()
if "sub_query_chain" not in st.session_state or st.session_state.sub_query_chain is None:
    st.session_state.sub_query_chain = create_sub_query_chain()
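
# Replay the conversation so far (including each saved research plan and source list) on every rerun.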
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if "sub_queries_html" in message:
            st.markdown(message["sub_queries_html"], unsafe_allow_html=True)
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("Sources", expanded=False):
                st.markdown(message["sources"], unsafe_allow_html=True)
# --- Main Conversational Logic ---
if prompt := st.chat_input("Ask a question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        with st.spinner("Synapse is thinking..."):
            try:
                chat_history = [HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"]) for m in st.session_state.messages[:-1]]

                # Step 1: Generate a structured research plan with forceful doc instruction
                doc_instruction = ""
                if st.session_state.doc_retriever:
                    doc_instruction = "IMPORTANT: The user has uploaded documents. For any part of the user's question that explicitly refers to 'the paper' or 'the document', you MUST set the 'datasource' for that query to 'doc'."
                research_plan = st.session_state.sub_query_chain.invoke({
                    "chat_history": chat_history,
                    "input": prompt,
                    "doc_instruction": doc_instruction
                })
                sub_queries = research_plan.queries
                sub_queries_html = "<div class='search-query-display'><b>Research Plan:</b><ul>" + "".join([f"<li><b>Search {q.datasource}:</b> {q.query}</li>" for q in sub_queries]) + "</ul></div>"
                st.markdown(sub_queries_html, unsafe_allow_html=True)
                # Step 2: Execute retrievals based on the plan. 'doc' queries fall back to the web
                # retriever when no documents are loaded; web queries are skipped if Tavily is not configured.
                retrieved_docs = []
                for query_info in sub_queries:
                    if query_info.datasource == "doc" and st.session_state.doc_retriever:
                        results = st.session_state.doc_retriever.invoke(query_info.query)
                        retrieved_docs.extend(results)
                    elif st.session_state.web_retriever is not None:
                        results = st.session_state.web_retriever.invoke(query_info.query)
                        retrieved_docs.extend(results)
                # Step 3: Format sources and context for annotation
                numbered_context_list = []
                source_markdown_list = []
                for i, doc in enumerate(retrieved_docs):
                    source_id = i + 1
                    numbered_context_list.append(f"[{source_id}] Source: {doc.metadata.get('source', 'N/A')}\nContent: {doc.page_content}")
                    if "filename" in doc.metadata:
                        filename = doc.metadata.get('filename', 'Unknown Document')
                        page_meta = doc.metadata.get('page', 'N/A')
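                        # PyPDFLoader page metadata is 0-indexed; show a 1-indexed page number.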
                        display_page = page_meta + 1 if isinstance(page_meta, int) else 'N/A'
                        source_markdown_list.append(f"**[{source_id}]** Document: {filename} (Page {display_page})")
                    elif "title" in doc.metadata and "source" in doc.metadata:
                        source_markdown_list.append(f"**[{source_id}]** Web: [{doc.metadata['title']}]({doc.metadata['source']})")
                numbered_context_str = "\n\n".join(numbered_context_list)
                source_markdown = "\n\n".join(source_markdown_list)
                with st.expander("Sources", expanded=False):
                    st.markdown(source_markdown, unsafe_allow_html=True)

                # Step 4: Generate final, annotated answer
                answer = st.session_state.qa_chain.invoke({
                    "chat_history": chat_history,
                    "input": prompt,
                    "context": numbered_context_str
                })
                st.markdown(answer)
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": answer,
                    "sub_queries_html": sub_queries_html,
                    "sources": source_markdown
                })
            except (httpx.HTTPStatusError, HfHubHTTPError) as e:
                st.error(f"An API error occurred: {e}. The service may be busy. Please try again shortly.")
            except Exception as e:
                st.error(f"An unexpected error occurred: {e}")
if not st.session_state.messages:
    st.info("Upload your documents and ask a question to get started.")