# NOTE: page-scrape residue ("Spaces: Running Running") -- not part of the program.
import streamlit as st | |
from langchain.memory import ConversationBufferMemory | |
from llama_index.core.indices.query.schema import QueryBundle | |
from llama_index.core import Document, VectorStoreIndex | |
from llama_index.core.text_splitter import SentenceSplitter | |
from llama_index.core.retrievers import QueryFusionRetriever | |
from llama_index.retrievers.bm25 import BM25Retriever | |
from llama_index.core.postprocessor import SentenceTransformerRerank | |
from llama_index.core.prompts import PromptTemplate | |
from llama_index.core.query_engine import RetrieverQueryEngine | |
from llama_index.embeddings.gemini import GeminiEmbedding | |
from llama_index.llms.gemini import Gemini | |
from llama_index.core import Settings | |
from llama_index.vector_stores.faiss import FaissVectorStore | |
from llama_index.core import ( | |
SimpleDirectoryReader, | |
load_index_from_storage, | |
VectorStoreIndex, | |
StorageContext, | |
) | |
from llama_index.core.node_parser import SemanticSplitterNodeParser | |
import os | |
import faiss | |
import pickle | |
import spacy | |
# The spaCy sentence model is currently disabled:
# nlp = spacy.load("en_core_web_sm")

# API key for the Gemini LLM, taken from the process environment.
# The value is None when the variable is not set.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
def load_documents(filename="documents.pkl"):
    """Return the object stored in the pickle file *filename*.

    Args:
        filename: Path of the pickle file to read. Defaults to
            ``documents.pkl`` in the current working directory.

    Returns:
        Whatever object was pickled into the file. The caller passes it on as
        the list of source documents to split and index.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        pickle.UnpicklingError: If the file content is not a valid pickle.
    """
    # SECURITY: unpickling executes arbitrary code embedded in the file. Only
    # load pickle files that were produced by a trusted source.
    with open(filename, "rb") as file:
        return pickle.load(file)
# Configure the LLM used for answer synthesis (and by anything else that
# falls back to the global LlamaIndex settings).
Settings.llm = Gemini(model="models/gemini-2.0-flash", api_key=GOOGLE_API_KEY, temperature=0.5)

# Dimensionality of the FAISS index; it has to match the size of the vectors
# produced by the embedding model used below.
dimension = 768


@st.cache_resource(show_spinner="Building the document index...")
def _build_index():
    """Load, chunk, embed and index the stored documents.

    Streamlit re-executes the whole script on every user interaction. Without
    caching, the corpus would be re-split, re-embedded and re-persisted for
    every chat message, so the finished index is built once and then reused.

    Returns:
        The ``VectorStoreIndex`` built over the semantically chunked documents
        and backed by an in-memory FAISS L2 index.
    """
    loaded_docs = load_documents()

    embed_model = GeminiEmbedding(model_name="models/embedding-001", use_async=False)

    # Split on semantic breakpoints (computed from embeddings) rather than on
    # fixed-size windows.
    splitter = SemanticSplitterNodeParser(
        buffer_size=5,
        breakpoint_percentile_threshold=95,
        embed_model=embed_model,
    )
    nodes = splitter.get_nodes_from_documents(list(loaded_docs))
    chunked_documents = [
        Document(text=node.text, metadata=node.metadata) for node in nodes
    ]

    faiss_index = faiss.IndexFlatL2(dimension)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    built_index = VectorStoreIndex.from_documents(
        documents=chunked_documents,
        storage_context=storage_context,
        embed_model=embed_model,
        show_progress=True,
    )
    # Write the index to the default persist directory.
    built_index.storage_context.persist()
    return built_index


# Build index (only the first script run actually does the work).
index = _build_index()
# Initialize memory.
# Streamlit re-runs this script from the top on every interaction, so a plain
# module-level object would be replaced by a fresh, empty one each time and the
# bot would never see earlier turns. Keeping the memory in the session state
# makes it survive reruns and keeps it separate for each browser session.
if "memory" not in st.session_state:
    st.session_state.memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True
    )
memory = st.session_state.memory


def get_chat_history():
    """Return the conversation stored so far in this session's memory.

    Returns:
        The value stored under the ``chat_history`` memory key.
    """
    return memory.load_memory_variables({})["chat_history"]
# Define chatbot prompt template.
# Placeholders: {chat_history} (earlier turns), {context_str} (retrieved
# passages) and {query_str} (the current question).
# A stray double quote that used to follow the {context_str} line has been
# removed; it was sent to the model verbatim as part of the prompt.
prompt_template = PromptTemplate(
    """You are a friendly college counselor with expertise in Indian technical institutes.
Previous conversation context (if any):\n{chat_history}\n\n
Available college information:\n{context_str}\n\n
User query: {query_str}\n\n
Instructions:\n
1. Provide a brief, direct answer using only the information available above\n
2. If specific data is not available, clearly state that\n
3. Keep responses under 3 sentences when possible\n
4. If comparing colleges, use bullet points for clarity\n
5. Use a friendly, conversational tone\n
6. Always be interactive and ask follow-up questions\n
7. Always try to give answers in points each point should focus on single aspect of the response.\n
8. Always try to give conclusion of your answer in the end for the user to take a decision.\n
Response:"""
)
# Configure retrieval and query engine
@st.cache_resource(show_spinner=False)
def _build_query_engine(_index, _qa_prompt):
    """Assemble the hybrid-retrieval query engine once and reuse it.

    Streamlit re-executes the script on every interaction; caching avoids
    rebuilding the BM25 retriever and reloading the cross-encoder reranker for
    each chat message. The leading underscores tell Streamlit not to hash the
    arguments, so the engine created on the first call is returned afterwards.

    Args:
        _index: The vector index to retrieve from.
        _qa_prompt: Prompt template used to synthesize the final answer.

    Returns:
        A ``RetrieverQueryEngine`` that fuses dense and BM25 results, reranks
        them with a cross-encoder and answers with ``Settings.llm``.
    """
    vector_retriever = _index.as_retriever(similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(index=_index, similarity_top_k=10)

    # Dense + keyword retrieval, merged by reciprocal rank fusion.
    hybrid_retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        similarity_top_k=10,
        num_queries=10,
        mode="reciprocal_rerank",
        use_async=False,
    )

    reranker = SentenceTransformerRerank(
        model="cross-encoder/ms-marco-MiniLM-L-2-v2",
        top_n=10,
    )

    # The answer prompt has to be passed as `text_qa_template`. The previous
    # `prompt_template=` keyword is not a parameter of `from_args`, so the
    # custom prompt never reached the response synthesizer.
    # NOTE(review): `verbose=True` was dropped for the same reason -- confirm
    # against the installed llama-index version.
    return RetrieverQueryEngine.from_args(
        retriever=hybrid_retriever,
        node_postprocessors=[reranker],
        llm=Settings.llm,
        text_qa_template=_qa_prompt,
        use_async=False,
    )


# The response synthesizer only fills {context_str} and {query_str}. The
# conversation transcript reaches the model through the query string, so the
# template's {chat_history} slot is bound up front to keep formatting from
# failing on a missing variable.
query_engine = _build_query_engine(
    index, prompt_template.partial_format(chat_history="")
)
# ---- Page header ----
st.title("📚 Precollege Chatbot")
st.write("Ask me anything about different colleges and their courses!")

# ---- Look and feel ----
# Dark, WhatsApp-like theme: right-aligned green bubbles for the user,
# left-aligned grey bubbles for the assistant, styled tables inside assistant
# bubbles, a round send button, and a hidden Streamlit toolbar.
_CHAT_STYLESHEET = """
<style>
body {
    background-color: #111b21;
    color: #e9edef;
}
.stApp {
    background-color: #111b21;
}
.chat-container {
    padding: 10px;
    color: #111b21;
}
.user-message {
    background-color: #005c4b;
    color: #e9edef;
    padding: 10px 15px;
    border-radius: 15px;
    margin: 5px 0;
    max-width: 70%;
    margin-left: auto;
    margin-right: 10px;
}
.ai-message {
    background-color: #1f2c33;
    color: #e9edef;
    padding: 10px 15px;
    border-radius: 15px;
    margin: 5px 0;
    max-width: 70%;
    margin-right: auto;
    margin-left: 10px;
    box-shadow: 0 1px 2px rgba(255,255,255,0.1);
}
.ai-message table {
    border-collapse: collapse;
    width: 100%;
    margin: 10px 0;
}
.ai-message th, .ai-message td {
    border: 1px solid #e9edef;
    padding: 8px;
    text-align: left;
}
.ai-message th {
    background-color: #2a3942;
}
.message-container {
    display: flex;
    margin-bottom: 10px;
}
.stTextInput input {
    border-radius: 20px;
    padding: 10px 20px;
    border: 1px solid #ccc;
    background-color: #2a3942;
    color: #e9edef;
}
.stButton button {
    border-radius: 50%; /* Make it circular */
    width: 40px;
    height: 40px;
    padding: 0px;
    background-color: #005c4b;
    color: #e9edef;
    font-size: 20px;
    display: flex;
    align-items: center;
    justify-content: center;
    border: none;
    cursor: pointer;
}
.stButton button:hover {
    background-color: #00735e;
}
div[data-testid="stToolbar"] {
    display: none;
}
.stMarkdown {
    color: #e9edef;
}
header {
    background-color: #202c33 !important;
}
.ai-message table.ai-table {
    border-collapse: collapse;
    width: 100%;
    margin: 10px 0;
    background-color: #2a3942;
}
.ai-message table.ai-table th,
.ai-message table.ai-table td {
    border: 1px solid #e9edef;
    padding: 8px;
    text-align: left;
    color: #e9edef;
}
.ai-message table.ai-table th {
    background-color: #005c4b;
    font-weight: bold;
}
.ai-message table.ai-table tr:nth-child(even) {
    background-color: #1f2c33;
}
</style>
"""

# Inject the stylesheet; raw HTML must be explicitly allowed for <style>.
st.markdown(_CHAT_STYLESHEET, unsafe_allow_html=True)
# Transcript shown in the UI: a list of (role, text) tuples kept per session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Create a container for chat messages. It is created before the form so the
# conversation is laid out above the input box even though it is filled later.
chat_container = st.container()

# Create a form for input; the text box is cleared after every submission.
with st.form(key="message_form", clear_on_submit=True):
    col1, col2 = st.columns([5, 1])
    with col1:
        # The label is hidden visually, but it must not be empty: Streamlit
        # warns about empty widget labels, which it needs for accessibility.
        user_input = st.text_input(
            "Message",
            placeholder="Type a message...",
            label_visibility="collapsed",
        )
    with col2:
        submit_button = st.form_submit_button("➤")
# Answer the submitted question and record the exchange.
if submit_button and user_input.strip():
    # Render earlier turns as a readable "User: ... / AI: ..." transcript.
    # Interpolating the history object directly would put the repr of a list
    # of message objects (or "[]" on the first turn) into the query text.
    transcript = "\n".join(
        f"{'User' if getattr(msg, 'type', '') == 'human' else 'AI'}: "
        f"{getattr(msg, 'content', msg)}"
        for msg in get_chat_history()
    )
    question = f"User: {user_input}"
    query_text = f"{transcript}\n\n{question}" if transcript else question

    try:
        with st.spinner("Thinking..."):
            response_obj = query_engine.query(QueryBundle(query_str=query_text))
    except Exception as exc:  # UI boundary: report the failure, keep the app alive
        st.error(f"Sorry, something went wrong while answering: {exc}")
    else:
        response_text = (
            str(response_obj.response)
            if hasattr(response_obj, "response")
            else str(response_obj)
        )
        # Only successful exchanges are stored, both in the LLM memory and in
        # the transcript shown on screen.
        memory.save_context({"query_str": user_input}, {"response": response_text})
        st.session_state.chat_history.append(("You", user_input))
        st.session_state.chat_history.append(("AI", response_text))
# Display chat history with custom styling
def _escape_user_text(text):
    """Return *text* with HTML special characters escaped."""
    import html  # function-scope import: not among the file's top-level imports

    return html.escape(text)


def _is_table_separator(line):
    """Return True if *line* is a Markdown table delimiter row.

    Accepts both the compact form (``|---|---|``) and the spaced or aligned
    forms (``| --- | :---: |``): the row must contain at least one dash and
    one pipe and nothing but pipes, dashes, colons and whitespace.
    """
    stripped = line.strip()
    return "-" in stripped and "|" in stripped and set(stripped) <= set("|:- \t")


def _split_table_row(line):
    """Split one Markdown table row into its stripped cell texts."""
    return [cell.strip() for cell in line.strip().strip("|").split("|")]


def _html_table_row(cells, tag):
    """Wrap *cells* in a ``<tr>`` using *tag* (``th`` or ``td``) per cell."""
    return "<tr>" + "".join(f"<{tag}>{cell}</{tag}>" for cell in cells) + "</tr>"


def _markdown_tables_to_html(message):
    """Replace Markdown tables in *message* with styled HTML tables.

    A table starts at a line containing ``|`` that is immediately followed by
    a delimiter row, and runs until the first line without ``|``. Every other
    line is passed through unchanged.
    """
    lines = message.split("\n")
    rendered = []
    pos = 0
    while pos < len(lines):
        line = lines[pos]
        starts_table = (
            "|" in line
            and pos + 1 < len(lines)
            and _is_table_separator(lines[pos + 1])
        )
        if not starts_table:
            rendered.append(line)
            pos += 1
            continue

        parts = ['<table class="ai-table">', _html_table_row(_split_table_row(line), "th")]
        pos += 2  # skip the header and its delimiter row
        while pos < len(lines) and "|" in lines[pos]:
            if not _is_table_separator(lines[pos]):
                parts.append(_html_table_row(_split_table_row(lines[pos]), "td"))
            pos += 1
        parts.append("</table>")
        # The whole table is emitted as a single line of HTML.
        rendered.append("".join(parts))
    return "\n".join(rendered)


def _render_bubble(css_class, body_html):
    """Draw one chat bubble with the given CSS class and HTML body."""
    st.markdown(
        f'<div class="message-container"><div class="{css_class}">{body_html}</div></div>',
        unsafe_allow_html=True,
    )


with chat_container:
    for role, message in st.session_state.chat_history:
        if role == "You":
            # User input is plain text: escape it so characters such as "<"
            # or "&" cannot break the bubble markup or inject HTML.
            _render_bubble("user-message", _escape_user_text(message))
        else:
            # Model output is kept as markup so its formatting still renders;
            # only stray <div> tags that would break the bubble are removed.
            # NOTE(review): this is not full sanitization of model output.
            cleaned = message.replace("</div>", "").replace("<div>", "")
            _render_bubble("ai-message", _markdown_tables_to_html(cleaned))