import os
import streamlit as st
from dotenv import load_dotenv
import httpx
from huggingface_hub import InferenceClient
# --- LangChain Imports ---
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import AIMessage, HumanMessage
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_mistralai import ChatMistralAI
from pydantic import BaseModel, Field
from typing import List, Literal
# --- Load API Keys ---
load_dotenv()
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
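# load_dotenv() reads a local .env file during development; on a deployed Space the
# same variables arrive as environment secrets, so os.getenv covers both cases.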
# --- App Configuration ---
st.set_page_config(page_title="Synapse AI", page_icon="🧠", layout="wide")
# --- Custom CSS ---
st.markdown("""
<style>
.stApp { background-color: #1E1E1E; color: #E0E0E0; }
[data-testid="stChatMessage"] { background-color: #2B2B2B; border-radius: 10px; padding: 1rem; border: 1px solid #333; }
[data-testid="stChatInput"] { background-color: #2B2B2B; border-top: 1px solid #333; }
[data-testid="stSidebar"] { background-color: #1A1A1A; border-right: 1px solid #333; }
.st-expander, .st-expander header { background-color: #2B2B2B !important; color: #E0E0E0 !important; border-radius: 10px; border: 1px solid #333; }
.st-expander header:hover { background-color: #333 !important; }
.stButton>button { background-color: #4CAF50; color: white; border-radius: 8px; border: none; }
.stAlert { border-radius: 8px; }
.search-query-display {
background-color: #2B2B2B;
border: 1px solid #444;
padding: 0.5rem 1rem;
border-radius: 8px;
margin-bottom: 1rem;
font-family: monospace;
color: #A0A0A0;
}
</style>
""", unsafe_allow_html=True)
# --- Title & Header ---
st.title("🧠 Synapse AI")
# --- Session State Initialization ---
if "messages" not in st.session_state:
st.session_state.messages = []
if "doc_retriever" not in st.session_state:
st.session_state.doc_retriever = None
if "web_retriever" not in st.session_state:
st.session_state.web_retriever = None
if "qa_chain" not in st.session_state:
st.session_state.qa_chain = None
if "sub_query_chain" not in st.session_state:
st.session_state.sub_query_chain = None
# --- API Key Validation ---
if not HUGGING_FACE_HUB_TOKEN:
    st.error("HUGGING_FACE_HUB_TOKEN not found! Please add it to your environment secrets.")
    st.stop()
if not TAVILY_API_KEY:
    st.sidebar.warning("TAVILY_API_KEY not found. Web search will be disabled.")
if not MISTRAL_API_KEY:
    st.sidebar.warning("MISTRAL_API_KEY not found. Research planning will be disabled.")
# --- Core Logic ---
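# The Hugging Face Inference API is called directly here instead of through a
# LangChain chat-model class; llm_wrapper below adapts LangChain prompt values
# into the raw role/content dicts that InferenceClient.chat_completion expects.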
def invoke_llm(messages_for_api):
    """Manually invokes the HF Inference Client with a simple message list."""
    client = InferenceClient(token=HUGGING_FACE_HUB_TOKEN)
    response = client.chat_completion(
        messages=messages_for_api,
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        max_tokens=2048,
        temperature=0.1,
    )
    return response.choices[0].message.content

def llm_wrapper(prompt_value):
    """Converts LangChain message objects to the dictionary format required by the HF client."""
    messages_for_api = []
    for msg in prompt_value.to_messages():
        role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else "system"
        messages_for_api.append({"role": role, "content": msg.content})
    return invoke_llm(messages_for_api)
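# Illustrative mapping performed by llm_wrapper:
#   HumanMessage(content="hi")  -> {"role": "user", "content": "hi"}
#   AIMessage(content="hello")  -> {"role": "assistant", "content": "hello"}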
# Pydantic models for structured output (Tool Calling)
class SubQuery(BaseModel):
    """A single, targeted search query with its designated datasource."""
    query: str = Field(description="The specific, self-contained search query string.")
    datasource: Literal["web", "doc"] = Field(description="The source to search, either 'web' or 'doc'.")

class ResearchPlan(BaseModel):
    """A list of sub-queries to execute for answering a user's question."""
    queries: List[SubQuery] = Field(description="A list of 2-4 targeted sub-queries.")
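# Illustrative plan the model might return (values are hypothetical):
#   ResearchPlan(queries=[
#       SubQuery(query="main contributions of the uploaded paper", datasource="doc"),
#       SubQuery(query="recent benchmarks for retrieval-augmented generation", datasource="web"),
#   ])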
@st.cache_resource
def create_sub_query_chain():
    """Creates a chain to generate targeted sub-queries using a commercial LLM with tool calling."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an expert at breaking down complex user questions into a series of smaller, targeted search queries.
Based on the user's question and the conversation history, generate a research plan by calling the ResearchPlan tool.
{doc_instruction}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0, api_key=MISTRAL_API_KEY)
    return prompt | llm.with_structured_output(ResearchPlan)
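# with_structured_output binds the ResearchPlan schema as a tool on the Mistral call,
# so the chain yields a parsed ResearchPlan instance rather than free-form text.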
@st.cache_resource
def create_qa_chain():
    """Creates the final question-answering chain with annotation and formatting instructions."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an AI research assistant. Your task is to answer the user's question based on the chat history and the provided context.
Synthesize the information from all sources into a single, cohesive, well-formatted answer.
**Formatting Instructions:**
- Use Markdown for clear formatting (headings, bold text, lists).
- Use LaTeX for all mathematical notation, formulas, and technical symbols by enclosing them in '$' or '$$'. For example, write '$L=12$' for variables.
- Structure your response logically with clear sections where appropriate.
IMPORTANT: You MUST cite the sources you use. The context is provided as a numbered list. At the end of each sentence or claim you make, add the corresponding source number(s) in brackets, like [1] or [2, 3].
CONTEXT:
{context}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    return prompt | RunnableLambda(llm_wrapper) | StrOutputParser()
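# The {context} placeholder is filled with the numbered source string assembled in
# Step 3 of the main loop; no retrieval happens inside this chain itself.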
@st.cache_resource
def build_doc_retriever(_uploaded_files, file_names):
    """Builds and returns a document retriever from uploaded files.

    `_uploaded_files` is excluded from Streamlit's cache key (leading underscore),
    so the hashable `file_names` tuple is passed alongside it to invalidate the
    cache whenever the uploaded set changes.
    """
    if not _uploaded_files:
        return None
    all_splits = []
    for uploaded_file in _uploaded_files:
        temp_file_path = f"/tmp/{uploaded_file.name}"
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getvalue())
        loader = PyPDFLoader(temp_file_path)
        docs = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)
        for split in splits:
            split.metadata["filename"] = uploaded_file.name
        all_splits.extend(splits)
        os.remove(temp_file_path)
    embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
    vectorstore = FAISS.from_documents(documents=all_splits, embedding=embedding_model)
    return vectorstore.as_retriever(search_kwargs={"k": 5})
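# The BAAI/bge-large-en-v1.5 model is downloaded on first use and runs locally via
# sentence-transformers; the FAISS index lives in memory for the life of the process.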
# --- UI & State Management ---
with st.sidebar:
st.title("Controls")
st.write("Upload or manage documents.")
uploaded_files = st.file_uploader("Upload PDFs", type="pdf", accept_multiple_files=True, key="pdf_uploader_main")
if st.button("Start New Chat"):
st.session_state.clear()
st.rerun()
if "file_names" not in st.session_state: st.session_state.file_names = []
current_file_names = [f.name for f in uploaded_files]
if set(current_file_names) != set(st.session_state.file_names):
st.session_state.file_names = current_file_names
with st.spinner(f"Processing {len(st.session_state.file_names)} document(s)..."):
st.session_state.doc_retriever = build_doc_retriever(uploaded_files)
st.success("Documents processed!")
if "web_retriever" not in st.session_state or st.session_state.web_retriever is None:
st.session_state.web_retriever = TavilySearchAPIRetriever(k=5, tavily_api_key=TAVILY_API_KEY)
if "qa_chain" not in st.session_state or st.session_state.qa_chain is None:
st.session_state.qa_chain = create_qa_chain()
if "sub_query_chain" not in st.session_state or st.session_state.sub_query_chain is None:
st.session_state.sub_query_chain = create_sub_query_chain()
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if "sub_queries_html" in message:
            st.markdown(message["sub_queries_html"], unsafe_allow_html=True)
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("Sources", expanded=False):
                st.markdown(message["sources"], unsafe_allow_html=True)
# --- Main Conversational Logic ---
if prompt := st.chat_input("Ask a question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    if st.session_state.sub_query_chain is None:
        st.error("MISTRAL_API_KEY is required for research planning. Please add it and reload.")
        st.stop()
    with st.chat_message("assistant"):
        with st.spinner("Synapse is thinking..."):
            try:
                chat_history = [
                    HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
                    for m in st.session_state.messages[:-1]
                ]
                # Step 1: Generate a structured research plan with a forceful doc instruction
                doc_instruction = ""
                if st.session_state.doc_retriever:
                    doc_instruction = "IMPORTANT: The user has uploaded documents. For any part of the user's question that explicitly refers to 'the paper' or 'the document', you MUST set the 'datasource' for that query to 'doc'."
                research_plan = st.session_state.sub_query_chain.invoke({
                    "chat_history": chat_history,
                    "input": prompt,
                    "doc_instruction": doc_instruction
                })
                sub_queries = research_plan.queries
                sub_queries_html = (
                    "<div class='search-query-display'><b>Research Plan:</b><ul>"
                    + "".join(f"<li><b>Search {q.datasource}:</b> {q.query}</li>" for q in sub_queries)
                    + "</ul></div>"
                )
                st.markdown(sub_queries_html, unsafe_allow_html=True)
                # Step 2: Execute retrievals based on the plan
                retrieved_docs = []
                for query_info in sub_queries:
                    if query_info.datasource == "doc" and st.session_state.doc_retriever:
                        retrieved_docs.extend(st.session_state.doc_retriever.invoke(query_info.query))
                    elif st.session_state.web_retriever:
                        retrieved_docs.extend(st.session_state.web_retriever.invoke(query_info.query))
                # Step 3: Format sources and context for annotation
                numbered_context_list = []
                source_markdown_list = []
                for source_id, doc in enumerate(retrieved_docs, start=1):
                    numbered_context_list.append(f"[{source_id}] Source: {doc.metadata.get('source', 'N/A')}\nContent: {doc.page_content}")
                    if "filename" in doc.metadata:
                        filename = doc.metadata.get('filename', 'Unknown Document')
                        page_meta = doc.metadata.get('page', 'N/A')
                        # PyPDF page numbers are 0-indexed, so shift by one for display
                        display_page = page_meta + 1 if isinstance(page_meta, int) else 'N/A'
                        source_markdown_list.append(f"**[{source_id}]** Document: {filename} (Page {display_page})")
                    elif "title" in doc.metadata and "source" in doc.metadata:
                        source_markdown_list.append(f"**[{source_id}]** Web: [{doc.metadata['title']}]({doc.metadata['source']})")
                numbered_context_str = "\n\n".join(numbered_context_list)
                source_markdown = "\n\n".join(source_markdown_list)
with st.expander("Sources", expanded=False):
st.markdown(source_markdown, unsafe_allow_html=True)
# Step 4: Generate final, annotated answer
answer = st.session_state.qa_chain.invoke({
"chat_history": chat_history,
"input": prompt,
"context": numbered_context_str
})
st.markdown(answer)
st.session_state.messages.append({
"role": "assistant",
"content": answer,
"sub_queries_html": sub_queries_html,
"sources": source_markdown
})
except httpx.HTTPStatusError as e:
st.error(f"An API error occurred: {e}. The service may be busy. Please try again shortly.")
except Exception as e:
st.error(f"An unexpected error occurred: {e}")
if not st.session_state.messages:
    st.info("Upload your documents and ask a question to get started.")