File size: 13,519 Bytes
46b1e53
 
 
d9a47f1
b21c9ef
3b39ab4
46b1e53
 
 
 
b21c9ef
46b1e53
0df374b
 
ed8f0b3
9597185
0df374b
3b39ab4
d674d8c
 
46b1e53
029243a
46b1e53
aef9f65
029243a
3b39ab4
46b1e53
 
0e891bd
029243a
04f7a8e
029243a
 
d6fe6fd
 
 
 
 
 
 
 
d998a88
 
 
 
 
 
 
 
 
029243a
 
46b1e53
029243a
abe218d
029243a
e902497
0e891bd
 
0df374b
 
 
 
 
 
3b39ab4
 
 
029243a
aef9f65
 
46b1e53
029243a
 
3b39ab4
 
 
029243a
d998a88
 
4dfaa48
0e891bd
b21c9ef
 
4dfaa48
0e891bd
 
c9d1786
aef9f65
b21c9ef
97d882d
4dfaa48
3b39ab4
0df374b
 
 
 
4dfaa48
 
d674d8c
 
 
 
 
 
 
 
 
 
3b39ab4
 
d674d8c
3b39ab4
 
634bd9d
 
0df374b
 
 
d674d8c
 
0df374b
 
 
634bd9d
0df374b
d674d8c
 
634bd9d
 
 
 
 
 
d674d8c
 
 
 
0df374b
 
 
 
d674d8c
0df374b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b39ab4
e902497
 
0df374b
3b39ab4
51db066
 
 
 
 
 
0df374b
51db066
 
0df374b
 
51db066
 
d674d8c
0df374b
d674d8c
0df374b
d674d8c
3b39ab4
c74639f
0df374b
 
3b39ab4
 
0df374b
 
 
 
772eb61
3b39ab4
51db066
 
3b39ab4
51db066
 
 
 
3b39ab4
0df374b
634bd9d
 
 
 
 
 
 
 
 
 
d674d8c
9597185
d674d8c
3b39ab4
0df374b
d674d8c
 
3b39ab4
d674d8c
 
3b39ab4
d674d8c
 
3b39ab4
 
d674d8c
 
 
 
 
634bd9d
d674d8c
0d2befc
0df374b
 
 
d674d8c
 
 
0e891bd
d674d8c
 
4dfaa48
d674d8c
 
97d882d
d674d8c
 
 
 
 
 
0e891bd
0df374b
4dfaa48
0df374b
4dfaa48
3b39ab4
0df374b
4dfaa48
0e891bd
0df374b
 
97d882d
 
bd837cf
0e891bd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import html
import json
import os
import tempfile
from typing import List

import httpx
import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import InferenceClient

# --- LangChain Imports ---
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.retrievers import TavilySearchAPIRetriever
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mistralai import ChatMistralAI

# --- 1. Load API Keys ---
# Merge a local .env file (when present) into the process environment first, so the
# lookups below see values from either source. Missing keys come back as None.
load_dotenv()
HUGGING_FACE_HUB_TOKEN, TAVILY_API_KEY, MISTRAL_API_KEY = (
    os.getenv(env_name)
    for env_name in ("HUGGING_FACE_HUB_TOKEN", "TAVILY_API_KEY", "MISTRAL_API_KEY")
)

# --- App Configuration ---
# This is the first Streamlit call in the script; page-level settings must come first.
st.set_page_config(page_title="Synapse AI", page_icon="🧠", layout="wide")

# --- Custom CSS ---
# Dark-theme overrides injected as a raw <style> tag (hence unsafe_allow_html).
# `.search-query-display` styles the research-plan box rendered in the chat loop.
_CUSTOM_CSS = """
<style>
    .stApp { background-color: #1E1E1E; color: #E0E0E0; }
    [data-testid="stChatMessage"] { background-color: #2B2B2B; border-radius: 10px; padding: 1rem; border: 1px solid #333; }
    [data-testid="stChatInput"] { background-color: #2B2B2B; border-top: 1px solid #333; }
    [data-testid="stSidebar"] { background-color: #1A1A1A; border-right: 1px solid #333; }
    .st-expander, .st-expander header { background-color: #2B2B2B !important; color: #E0E0E0 !important; border-radius: 10px; border: 1px solid #333; }
    .st-expander header:hover { background-color: #333 !important; }
    .stButton>button { background-color: #4CAF50; color: white; border-radius: 8px; border: none; }
    .stAlert { border-radius: 8px; }
    .search-query-display { 
        background-color: #2B2B2B; 
        border: 1px solid #444; 
        padding: 0.5rem 1rem; 
        border-radius: 8px; 
        margin-bottom: 1rem;
        font-family: monospace;
        color: #A0A0A0;
    }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# --- Title & Header ---
st.title("🧠 Synapse AI")

# --- Session State Initialization ---
# Seed keys only when absent, so values persist across Streamlit's top-to-bottom reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# Retrievers and chains start unset; they are built further down the script.
for _state_key in ("doc_retriever", "web_retriever", "qa_chain", "sub_query_chain"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
    
# --- API Key Validation ---
# The Hugging Face token is required to generate answers, so halt this script run without it.
if not HUGGING_FACE_HUB_TOKEN:
    st.error("HUGGING_FACE_HUB_TOKEN not found! Please add it to your environment secrets.")
    st.stop()

# The other keys only produce a non-blocking sidebar notice when missing.
_optional_key_notices = (
    (TAVILY_API_KEY, "TAVILY_API_KEY not found. Web search will be disabled."),
    (MISTRAL_API_KEY, "MISTRAL_API_KEY not found. Query generation will be less effective."),
)
for _key_value, _notice in _optional_key_notices:
    if not _key_value:
        st.sidebar.warning(_notice)


# --- Core Logic ---

def invoke_llm(messages_for_api):
    """Send chat messages to the Hugging Face Inference API and return the reply text.

    Args:
        messages_for_api: list of ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The ``content`` of the first completion choice.
    """
    completion = InferenceClient(token=HUGGING_FACE_HUB_TOKEN).chat_completion(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages=messages_for_api,
        temperature=0.1,
        max_tokens=2048,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def llm_wrapper(prompt_value):
    """Adapt a LangChain prompt value to the plain-dict format and call ``invoke_llm``.

    Message types map as: ``human`` -> ``user``, ``ai`` -> ``assistant``,
    and any other type -> ``system``.
    """
    role_for_type = {"human": "user", "ai": "assistant"}
    api_messages = [
        {"role": role_for_type.get(lc_message.type, "system"), "content": lc_message.content}
        for lc_message in prompt_value.to_messages()
    ]
    return invoke_llm(api_messages)

# Pydantic models for structured output (Tool Calling)
# NOTE: with_structured_output typically turns a model's docstring and Field
# descriptions into the tool schema shown to the LLM, so treat that text as prompt
# text: editing it can change what the model generates.
class SubQuery(BaseModel):
    """A single, targeted search query with its designated datasource."""
    # The search string handed to the chosen retriever.
    query: str = Field(description="The specific, self-contained search query string.")
    # Declared as plain `str`: nothing here restricts it to 'web'/'doc', so callers
    # must decide how to route unexpected values.
    datasource: str = Field(description="The source to search, either 'web' or 'doc'.")

class ResearchPlan(BaseModel):
    """A list of sub-queries to execute for answering a user's question."""
    # Top-level structured-output schema used by create_sub_query_chain.
    # "2-4" is only a hint in the description; no length validation is applied.
    queries: List[SubQuery] = Field(description="A list of 2-4 targeted sub-queries.")

@st.cache_resource
def create_sub_query_chain():
    """Return a runnable mapping (chat_history, input, doc_instruction) to a ResearchPlan.

    The Mistral chat model is wrapped with ``with_structured_output(ResearchPlan)`` so
    its reply is parsed into the ResearchPlan model rather than returned as free text.
    The chain is cached process-wide by ``st.cache_resource``.
    """
    planner_model = ChatMistralAI(model="mistral-large-latest", temperature=0, api_key=MISTRAL_API_KEY)
    structured_planner = planner_model.with_structured_output(ResearchPlan)

    # Written as explicit concatenation so the exact whitespace sent to the model is visible.
    system_template = (
        "You are an expert at breaking down complex user questions into a series of smaller, targeted search queries. \n"
        "        Based on the user's question and the conversation history, generate a research plan by calling the ResearchPlan tool.\n"
        "        {doc_instruction}"
    )
    planning_prompt = ChatPromptTemplate.from_messages([
        ("system", system_template),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    return planning_prompt | structured_planner

@st.cache_resource
def create_qa_chain():
    """Build the final answer chain: prompt -> HF LLM (via llm_wrapper) -> plain string.

    The chain is invoked with a dict holding ``chat_history`` (LangChain messages),
    ``input`` (the user's question) and ``context`` (a pre-numbered source listing
    the model is told to cite as ``[n]``). Cached process-wide by ``st.cache_resource``.
    """
    prompt = ChatPromptTemplate.from_messages([
        ("system", """You are an AI research assistant. Your task is to answer the user's question based on the chat history and the provided context.
        Synthesize the information from all sources into a single, cohesive, well-formatted answer.
        
        **Formatting Instructions:**
        - Use Markdown for clear formatting (headings, bold text, lists).
        - Use LaTeX for all mathematical notation, formulas, and technical symbols by enclosing them in '$' or '$$'. For example, write '$L=12$' for variables.
        - Structure your response logically with clear sections where appropriate.

        IMPORTANT: You MUST cite the sources you use. The context is provided as a numbered list. At the end of each sentence or claim you make, add the corresponding source number(s) in brackets, like [1] or [2, 3].

        CONTEXT:
        {context}"""),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ])
    # The caller already supplies a ready-made "context" string, so the input dict
    # feeds the prompt directly; no extra passthrough/assign stage is needed.
    return prompt | RunnableLambda(llm_wrapper) | StrOutputParser()

@st.cache_resource
def _load_embedding_model():
    """Load the BGE embedding model once per server process and reuse it."""
    return HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")


# max_entries bounds how many distinct document sets stay indexed in memory at once.
@st.cache_resource(max_entries=8)
def _build_retriever_from_payloads(file_payloads):
    """Index ``(filename, pdf_bytes)`` pairs into a FAISS retriever returning the top 5 chunks.

    ``file_payloads`` is a hashable tuple, so Streamlit keys this cache on the actual
    file names and contents: a different upload set gets its own index.

    Returns:
        A retriever, or None when the PDFs yield no text chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    all_splits = []
    for file_name, file_bytes in file_payloads:
        # A unique temp file avoids collisions between concurrent sessions and never
        # uses the client-supplied file name as a filesystem path.
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            tmp.write(file_bytes)
            temp_path = tmp.name
        try:
            docs = PyPDFLoader(temp_path).load()
        finally:
            # Always clean up, even if the PDF fails to parse.
            os.remove(temp_path)
        splits = text_splitter.split_documents(docs)
        for split in splits:
            split.metadata["filename"] = file_name
            # Show the user-facing file name instead of the throwaway temp path.
            split.metadata["source"] = file_name
        all_splits.extend(splits)
    if not all_splits:
        # Such as scanned PDFs with no extractable text; FAISS.from_documents
        # cannot build an index from an empty list.
        return None
    vectorstore = FAISS.from_documents(documents=all_splits, embedding=_load_embedding_model())
    return vectorstore.as_retriever(search_kwargs={"k": 5})


def build_doc_retriever(_uploaded_files):
    """Build (or fetch from cache) a document retriever over the uploaded PDFs.

    Args:
        _uploaded_files: iterable of uploaded-file objects exposing ``.name`` and
            ``.getvalue()``; may be empty or None.

    Returns:
        A FAISS-backed retriever, or None when there is nothing to index.
    """
    if not _uploaded_files:
        return None
    # Read the bytes up front so the cache key reflects the real file contents.
    file_payloads = tuple((f.name, f.getvalue()) for f in _uploaded_files)
    return _build_retriever_from_payloads(file_payloads)

# --- UI & State Management ---
with st.sidebar:
    st.title("Controls")
    st.write("Upload or manage documents.")
    uploaded_files = st.file_uploader("Upload PDFs", type="pdf", accept_multiple_files=True, key="pdf_uploader_main")
    if st.button("Start New Chat"):
        st.session_state.clear()
        st.rerun()

# Treat "nothing uploaded" uniformly as an empty list, even if the uploader returns None.
uploaded_files = uploaded_files or []

if "file_names" not in st.session_state:
    st.session_state.file_names = []
current_file_names = [f.name for f in uploaded_files]

# Re-index only when the set of uploaded file names changes between reruns.
if set(current_file_names) != set(st.session_state.file_names):
    st.session_state.file_names = current_file_names
    if uploaded_files:
        with st.spinner(f"Processing {len(current_file_names)} document(s)..."):
            st.session_state.doc_retriever = build_doc_retriever(uploaded_files)
        st.success("Documents processed!")
    else:
        # Every file was removed: drop the old index instead of "processing" zero documents.
        st.session_state.doc_retriever = None

# Create the shared runnables once per session (they are reset by "Start New Chat").
if "web_retriever" not in st.session_state or st.session_state.web_retriever is None:
    # The retriever's key field is `api_key`.
    # NOTE(review): confirm against the installed langchain_community version.
    st.session_state.web_retriever = TavilySearchAPIRetriever(k=5, api_key=TAVILY_API_KEY)
if "qa_chain" not in st.session_state or st.session_state.qa_chain is None:
    st.session_state.qa_chain = create_qa_chain()
if "sub_query_chain" not in st.session_state or st.session_state.sub_query_chain is None:
    st.session_state.sub_query_chain = create_sub_query_chain()

# Replay the stored conversation on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if "sub_queries_html" in message:
            # The research plan is stored as pre-built HTML, so raw HTML rendering is required.
            st.markdown(message["sub_queries_html"], unsafe_allow_html=True)
        st.markdown(message["content"])
        if "sources" in message:
            with st.expander("Sources", expanded=False):
                # Sources are plain Markdown that can include third-party web page titles;
                # render without raw HTML so those titles cannot inject markup.
                st.markdown(message["sources"])

# --- Main Conversational Logic ---

def _plan_research(user_question, chat_history):
    """Ask the sub-query chain for a research plan and return its list of SubQuery items.

    Falls back to searching with the raw question when the chain yields no plan
    (structured output may come back as None) or an empty query list.
    """
    has_docs = bool(st.session_state.doc_retriever)
    doc_instruction = ""
    if has_docs:
        doc_instruction = "IMPORTANT: The user has uploaded documents. For any part of the user's question that explicitly refers to 'the paper' or 'the document', you MUST set the 'datasource' for that query to 'doc'."

    research_plan = st.session_state.sub_query_chain.invoke({
        "chat_history": chat_history,
        "input": user_question,
        "doc_instruction": doc_instruction,
    })
    sub_queries = list(getattr(research_plan, "queries", None) or [])
    if not sub_queries:
        sub_queries = [SubQuery(query=user_question, datasource="web")]
        if has_docs:
            sub_queries.insert(0, SubQuery(query=user_question, datasource="doc"))
    return sub_queries


def _render_plan_html(sub_queries):
    """Render the research plan as HTML, escaping the model-generated text."""
    items = "".join(
        f"<li><b>Search {html.escape(str(q.datasource))}:</b> {html.escape(q.query)}</li>"
        for q in sub_queries
    )
    return "<div class='search-query-display'><b>Research Plan:</b><ul>" + items + "</ul></div>"


def _retrieve_documents(sub_queries):
    """Run each sub-query against its datasource and return de-duplicated chunks.

    'doc' queries (case-insensitive) use the uploaded-PDF retriever when one exists.
    Everything else goes to web search. Web search is skipped when no TAVILY_API_KEY
    is configured, matching the sidebar warning.
    """
    retrieved = []
    seen = set()
    for query_info in sub_queries:
        wants_doc = str(query_info.datasource).strip().lower() == "doc"
        if wants_doc and st.session_state.doc_retriever:
            retriever = st.session_state.doc_retriever
        elif TAVILY_API_KEY:
            retriever = st.session_state.web_retriever
        else:
            continue
        for doc in retriever.invoke(query_info.query):
            # Overlapping sub-queries often return the same chunk; keep (and cite) it once.
            fingerprint = (doc.metadata.get("source"), doc.metadata.get("page"), doc.page_content)
            if fingerprint in seen:
                continue
            seen.add(fingerprint)
            retrieved.append(doc)
    return retrieved


def _format_sources(docs):
    """Return (numbered context for the LLM, Markdown source list for display).

    Both use the same 1-based numbering so the model's [n] citations match the list.
    """
    context_entries = []
    source_entries = []
    for source_id, doc in enumerate(docs, start=1):
        metadata = doc.metadata
        context_entries.append(
            f"[{source_id}] Source: {metadata.get('source', 'N/A')}\nContent: {doc.page_content}"
        )
        if "filename" in metadata:
            page_meta = metadata.get("page", "N/A")
            # Loader page indices are 0-based; show 1-based page numbers to the user.
            display_page = page_meta + 1 if isinstance(page_meta, int) else "N/A"
            source_entries.append(f"**[{source_id}]** Document: {metadata['filename']} (Page {display_page})")
        elif "source" in metadata:
            # Fall back to the URL as link text when a web result has no title.
            title = metadata.get("title") or metadata["source"]
            source_entries.append(f"**[{source_id}]** Web: [{title}]({metadata['source']})")
    return "\n\n".join(context_entries), "\n\n".join(source_entries)


if prompt := st.chat_input("Ask a question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Synapse is thinking..."):
            try:
                # Every message except the question just appended becomes LLM chat history.
                chat_history = [
                    HumanMessage(content=m["content"]) if m["role"] == "user" else AIMessage(content=m["content"])
                    for m in st.session_state.messages[:-1]
                ]

                # Step 1: Generate a structured research plan.
                sub_queries = _plan_research(prompt, chat_history)
                sub_queries_html = _render_plan_html(sub_queries)
                st.markdown(sub_queries_html, unsafe_allow_html=True)

                # Step 2: Execute retrievals based on the plan.
                retrieved_docs = _retrieve_documents(sub_queries)

                # Step 3: Format sources and context for annotation.
                numbered_context_str, source_markdown = _format_sources(retrieved_docs)
                with st.expander("Sources", expanded=False):
                    # Plain Markdown: web titles are untrusted, so raw HTML stays disabled.
                    st.markdown(source_markdown)

                # Step 4: Generate the final, annotated answer.
                answer = st.session_state.qa_chain.invoke({
                    "chat_history": chat_history,
                    "input": prompt,
                    "context": numbered_context_str,
                })
                st.markdown(answer)

                st.session_state.messages.append({
                    "role": "assistant",
                    "content": answer,
                    "sub_queries_html": sub_queries_html,
                    "sources": source_markdown,
                })

            except httpx.HTTPStatusError as e:
                st.error(f"An API error occurred: {e}. The service may be busy. Please try again shortly.")
            except Exception as e:
                st.error(f"An unexpected error occurred: {e}")

# Show a getting-started hint until the conversation has at least one message.
conversation_started = bool(st.session_state.messages)
if not conversation_started:
    st.info("Upload your documents and ask a question to get started.")