# DocuFlow_V3 / app.py
import os
import streamlit as st
from dotenv import load_dotenv
import httpx
from huggingface_hub import InferenceClient
# --- LangChain Imports ---
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_mistralai import ChatMistralAI
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
# --- 1. Load API Keys ---
load_dotenv()
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
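# On a hosted Space these are typically provided as repository secrets; locally they can
# come from a .env file, which load_dotenv() above reads into the environment.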
# --- App Configuration ---
st.set_page_config(page_title="Synapse AI", page_icon="🧠", layout="wide")
# --- Custom CSS ---
st.markdown("""
<style>
.stApp { background-color: #1E1E1E; color: #E0E0E0; }
[data-testid="stChatMessage"] { background-color: #2B2B2B; border-radius: 10px; padding: 1rem; border: 1px solid #333; }
[data-testid="stChatInput"] { background-color: #2B2B2B; border-top: 1px solid #333; }
[data-testid="stSidebar"] { background-color: #1A1A1A; border-right: 1px solid #333; }
.st-expander, .st-expander header { background-color: #2B2B2B !important; color: #E0E0E0 !important; border-radius: 10px; border: 1px solid #333; }
.st-expander header:hover { background-color: #333 !important; }
.stButton>button { background-color: #4CAF50; color: white; border-radius: 8px; border: none; }
.stAlert { border-radius: 8px; }
.search-query-display {
background-color: #2B2B2B;
border: 1px solid #444;
padding: 0.5rem 1rem;
border-radius: 8px;
margin-bottom: 1rem;
font-family: monospace;
color: #A0A0A0;
}
</style>
""", unsafe_allow_html=True)
# --- Title & Header ---
st.title("🧠 Synapse AI")
# --- Session State Initialization ---
if "messages" not in st.session_state:
st.session_state.messages = []
if "doc_retriever" not in st.session_state:
st.session_state.doc_retriever = None
if "web_retriever" not in st.session_state:
st.session_state.web_retriever = None
if "qa_chain" not in st.session_state:
st.session_state.qa_chain = None
if "sub_query_chain" not in st.session_state:
st.session_state.sub_query_chain = None
# --- API Key Validation ---
if not HUGGING_FACE_HUB_TOKEN:
    st.error("HUGGING_FACE_HUB_TOKEN not found! Please add it to your environment secrets.")
    st.stop()
if not TAVILY_API_KEY:
    st.sidebar.warning("TAVILY_API_KEY not found. Web search will be disabled.")
if not MISTRAL_API_KEY:
    st.sidebar.warning("MISTRAL_API_KEY not found. Query generation will be less effective.")
# --- Core Logic ---
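# Request flow implemented below: a Mistral tool-calling chain turns the user's question
# into a small research plan of sub-queries; each sub-query is routed to the FAISS document
# retriever or the Tavily web retriever; the numbered results are then handed to the
# Mixtral QA chain, which answers with bracketed citations back to those numbers.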
def invoke_llm(messages_for_api):
    """Manually invokes the HF Inference Client with a simple message list."""
    client = InferenceClient(token=HUGGING_FACE_HUB_TOKEN)
    response = client.chat_completion(
        messages=messages_for_api,
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        max_tokens=2048,
        temperature=0.1
    )
    return response.choices[0].message.content
def llm_wrapper(prompt_value):
    """Converts LangChain message objects to the dictionary format required by the HF client."""
    messages_for_api = []
    for msg in prompt_value.to_messages():
        role = "user" if msg.type == 'human' else "assistant" if msg.type == 'ai' else "system"
        messages_for_api.append({"role": role, "content": msg.content})
    return invoke_llm(messages_for_api)
# Pydantic models for structured output (Tool Calling)
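# Illustrative shape of the structured output (hypothetical example, not produced verbatim):
#   ResearchPlan(queries=[
#       SubQuery(query="key findings of the uploaded paper", datasource="doc"),
#       SubQuery(query="recent benchmark results for the same task", datasource="web"),
#   ])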
class SubQuery(BaseModel):
    """A single, targeted search query with its designated datasource."""
    query: str = Field(description="The specific, self-contained search query string.")
    datasource: str = Field(description="The source to search, either 'web' or 'doc'.")
class ResearchPlan(BaseModel):
    """A list of sub-queries to execute for answering a user's question."""
    queries: List[SubQuery] = Field(description="A list of 2-4 targeted sub-queries.")
@st.cache_resource
def create_sub_query_chain():
    """Creates a chain to generate targeted sub-queries using a commercial LLM with tool calling."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an expert at breaking down complex user questions into a series of smaller, targeted search queries.
Based on the user's question and the conversation history, generate a research plan by calling the ResearchPlan tool.
{doc_instruction}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}")
    ])
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0, api_key=MISTRAL_API_KEY)
    return prompt | llm.with_structured_output(ResearchPlan)
@st.cache_resource
def create_qa_chain():
    """Creates the final question-answering chain with annotation and formatting instructions."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an AI research assistant. Your task is to answer the user's question based on the chat history and the provided context.
Synthesize the information from all sources into a single, cohesive, well-formatted answer.
**Formatting Instructions:**
- Use Markdown for clear formatting (headings, bold text, lists).
- Use LaTeX for all mathematical notation, formulas, and technical symbols by enclosing them in '$' or '$$'. For example, write '$L=12$' for variables.
- Structure your response logically with clear sections where appropriate.
IMPORTANT: You MUST cite the sources you use. The context is provided as a numbered list. At the end of each sentence or claim you make, add the corresponding source number(s) in brackets, like [1] or [2, 3].
CONTEXT:
{context}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    return (
        RunnablePassthrough.assign(context=lambda inputs: inputs["context"])  # Pass the already-formatted context string straight through
        | prompt
        | RunnableLambda(llm_wrapper)
        | StrOutputParser()
    )
@st.cache_resource
def build_doc_retriever(_uploaded_files, file_names):
    """Builds and returns a document retriever from uploaded files.

    `file_names` keys the cache: the underscore-prefixed file objects are excluded from
    hashing by st.cache_resource, so without it the first index built would be reused
    even after the set of uploads changes.
    """
    if not _uploaded_files:
        return None
    all_splits = []
    for uploaded_file in _uploaded_files:
        temp_file_path = f"/tmp/{uploaded_file.name}"
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getvalue())
        loader = PyPDFLoader(temp_file_path)
        docs = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        for split in splits:
            split.metadata["filename"] = uploaded_file.name
        all_splits.extend(splits)
        os.remove(temp_file_path)
    embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
    vectorstore = FAISS.from_documents(documents=all_splits, embedding=embedding_model)
    return vectorstore.as_retriever(search_kwargs={"k": 5})
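# Note: the FAISS index is held in memory for the current session only; each sub-query
# retrieves the 5 most similar chunks (search_kwargs={"k": 5}).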
# --- UI & State Management ---
with st.sidebar:
    st.title("Controls")
    st.write("Upload or manage documents.")
    uploaded_files = st.file_uploader("Upload PDFs", type="pdf", accept_multiple_files=True, key="pdf_uploader_main")
    if st.button("Start New Chat"):
        st.session_state.clear()
        st.rerun()
    if "file_names" not in st.session_state:
        st.session_state.file_names = []
    current_file_names = [f.name for f in uploaded_files]
    # Rebuild the document index only when the set of uploaded files actually changes.
    if set(current_file_names) != set(st.session_state.file_names):
        st.session_state.file_names = current_file_names
        with st.spinner(f"Processing {len(st.session_state.file_names)} document(s)..."):
            st.session_state.doc_retriever = build_doc_retriever(uploaded_files, tuple(sorted(current_file_names)))
        if uploaded_files:
            st.success("Documents processed!")
if "web_retriever" not in st.session_state or st.session_state.web_retriever is None:
st.session_state.web_retriever = TavilySearchAPIRetriever(k=5, tavily_api_key=TAVILY_API_KEY)
if "qa_chain" not in st.session_state or st.session_state.qa_chain is None:
st.session_state.qa_chain = create_qa_chain()
if "sub_query_chain" not in st.session_state or st.session_state.sub_query_chain is None:
st.session_state.sub_query_chain = create_sub_query_chain()
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if "sub_queries_html" in message:
            st.markdown(message["sub_queries_html"], unsafe_allow_html=True)
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("Sources", expanded=False):
                st.markdown(message["sources"], unsafe_allow_html=True)
# --- Main Conversational Logic ---
if prompt := st.chat_input("Ask a question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        with st.spinner("Synapse is thinking..."):
            try:
                # Rebuild the LangChain-format history from every stored turn except the prompt just added.
                chat_history = [HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"]) for m in st.session_state.messages[:-1]]
                # Step 1: Generate a structured research plan with a forceful doc instruction
                doc_instruction = ""
                if st.session_state.doc_retriever:
                    doc_instruction = "IMPORTANT: The user has uploaded documents. For any part of the user's question that explicitly refers to 'the paper' or 'the document', you MUST set the 'datasource' for that query to 'doc'."
                research_plan = st.session_state.sub_query_chain.invoke({
                    "chat_history": chat_history,
                    "input": prompt,
                    "doc_instruction": doc_instruction
                })
                sub_queries = research_plan.queries
                sub_queries_html = "<div class='search-query-display'><b>Research Plan:</b><ul>" + "".join([f"<li><b>Search {q.datasource}:</b> {q.query}</li>" for q in sub_queries]) + "</ul></div>"
                st.markdown(sub_queries_html, unsafe_allow_html=True)
                # Step 2: Execute retrievals based on the reliable plan
                retrieved_docs = []
                for query_info in sub_queries:
                    if query_info.datasource == "doc" and st.session_state.doc_retriever:
                        results = st.session_state.doc_retriever.invoke(query_info.query)
                        retrieved_docs.extend(results)
                    elif st.session_state.web_retriever:
                        results = st.session_state.web_retriever.invoke(query_info.query)
                        retrieved_docs.extend(results)
                # Step 3: Format sources and context for annotation
                numbered_context_list = []
                source_markdown_list = []
                for i, doc in enumerate(retrieved_docs):
                    source_id = i + 1
                    numbered_context_list.append(f"[{source_id}] Source: {doc.metadata.get('source', 'N/A')}\nContent: {doc.page_content}")
                    if "filename" in doc.metadata:
                        filename = doc.metadata.get('filename', 'Unknown Document')
                        page_meta = doc.metadata.get('page', 'N/A')
                        # PyPDFLoader page numbers are 0-indexed; display them 1-indexed.
                        display_page = page_meta + 1 if isinstance(page_meta, int) else 'N/A'
                        source_markdown_list.append(f"**[{source_id}]** Document: {filename} (Page {display_page})")
                    elif "title" in doc.metadata and "source" in doc.metadata:
                        source_markdown_list.append(f"**[{source_id}]** Web: [{doc.metadata['title']}]({doc.metadata['source']})")
                numbered_context_str = "\n\n".join(numbered_context_list)
                source_markdown = "\n\n".join(source_markdown_list)
                with st.expander("Sources", expanded=False):
                    st.markdown(source_markdown, unsafe_allow_html=True)
                # Step 4: Generate final, annotated answer
                answer = st.session_state.qa_chain.invoke({
                    "chat_history": chat_history,
                    "input": prompt,
                    "context": numbered_context_str
                })
                st.markdown(answer)
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": answer,
                    "sub_queries_html": sub_queries_html,
                    "sources": source_markdown
                })
            except httpx.HTTPStatusError as e:
                st.error(f"An API error occurred: {e}. The service may be busy. Please try again shortly.")
            except Exception as e:
                st.error(f"An unexpected error occurred: {e}")
if not st.session_state.messages:
    st.info("Upload your documents and ask a question to get started.")