# PersonalChatbot / src/streamlit_app.py
# Last commit 8c1e6b5 by maheensaleh40: "updated UI for question input bar"
import os
import subprocess
import sys
from pathlib import Path
from typing import List

import streamlit as st
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain.embeddings.base import Embeddings
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI

from qa_prompts import PROMPT_TMPL
# Pull variables from a local .env file (if one is found) into os.environ.
# Variables already set in the real environment take precedence.
load_dotenv()

# Credentials and model selection, all overridable via environment variables.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "gemini-1.5-flash")

# Paths are anchored to this file's directory, not the current working directory.
ROOT_DIR = Path(__file__).resolve().parent
INDEX_DIR = ROOT_DIR / "data_index"
###### run ingest.py (to be run locally) ######
# Build the FAISS index on first launch when INDEX_DIR is missing.
# NOTE(review): these st.* calls run before st.set_page_config further down;
# older Streamlit releases require set_page_config to be the first st.* call —
# confirm against the deployed Streamlit version.
if not INDEX_DIR.exists():
    with st.spinner("Index not found. Building FAISS index (first run)…"):
        # Use the interpreter running this app (sys.executable) so ingest.py sees
        # the same installed packages, and locate the script next to this file so
        # the launch directory does not matter. The child process inherits
        # os.environ, so it reads the same env/secrets (model names, tokens).
        ingest_script = ROOT_DIR / "ingest.py"
        proc = subprocess.run(
            [sys.executable, str(ingest_script)],
            capture_output=True,
            text=True,
        )
    if proc.returncode != 0:
        # Some failures only print to stdout; fall back so the message isn't blank.
        st.error(f"ingest.py failed:\n{proc.stderr or proc.stdout}")
        st.stop()
class HFAPIEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by the Hugging Face Inference API.

    Vectors are computed remotely with ``InferenceClient.feature_extraction``
    and returned as plain Python lists of floats, as the ``Embeddings``
    interface declares.
    """

    def __init__(self, repo_id: str, token: str | None = None, timeout: float = 120.0):
        """Create the remote inference client.

        Args:
            repo_id: Hugging Face model id used for feature extraction.
            token: Hugging Face API token, or ``None`` for unauthenticated calls.
            timeout: Request timeout in seconds, forwarded to ``InferenceClient``.
        """
        self.client = InferenceClient(model=repo_id, token=token, timeout=timeout)

    @staticmethod
    def _as_list(vec):
        """Convert an array-like API result (e.g. a numpy array) into nested lists."""
        # feature_extraction is documented to return a numpy array; without this
        # conversion the list-based checks below never match and callers receive
        # ndarrays despite the List[...] annotations.
        return vec.tolist() if hasattr(vec, "tolist") else vec

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts, returning one vector per input text."""
        if not texts:
            # Nothing to embed; skip the network round-trip.
            return []
        return self._as_list(self.client.feature_extraction(texts))

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return a flat (1-D) vector."""
        vec = self._as_list(self.client.feature_extraction(text))
        # A single input may come back wrapped as a batch of one ([[...]]);
        # unwrap it so FAISS receives a 1-D query vector.
        if isinstance(vec, list) and vec and isinstance(vec[0], list):
            return vec[0]
        return vec
def build_chain_gemini(retriever, _llm_repo, _max_new, _temp, _show_sources):
    """Assemble a RetrievalQA chain that answers with a Gemini chat model.

    Args:
        retriever: LangChain retriever supplying context documents.
        _llm_repo: Gemini model name passed to ``ChatGoogleGenerativeAI``.
        _max_new: Maximum number of output tokens to generate.
        _temp: Sampling temperature.
        _show_sources: When true, the chain also returns its source documents.

    Returns:
        The configured ``RetrievalQA`` chain.

    Raises:
        RuntimeError: If ``GOOGLE_API_KEY`` is not set.
    """
    api_key = GOOGLE_API_KEY
    if not api_key:
        raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")

    # Gemini, served by Google's hosted Generative AI inference endpoint.
    chat_model = ChatGoogleGenerativeAI(
        model=_llm_repo,
        api_key=api_key,
        max_output_tokens=_max_new,
        temperature=_temp,
        convert_system_message_to_human=True,
    )
    qa_prompt = PromptTemplate(template=PROMPT_TMPL, input_variables=["context", "question"])

    # "stuff" places all retrieved chunks into one prompt; "map_reduce" is the
    # alternative chain type the author considered.
    return RetrievalQA.from_chain_type(
        llm=chat_model,
        retriever=retriever,
        chain_type="stuff",
        return_source_documents=_show_sources,
        chain_type_kwargs={"prompt": qa_prompt},
    )
# ========================= Streamlit UI =========================
st.set_page_config(page_title="Maheen's Profile Chatbot", page_icon="💬", layout="centered")
st.title("Maheen's Profile Chatbot")
st.caption("Want to know about my skills and experience? Enter your question below 👇")

# Sidebar settings
st.sidebar.header("Settings")
# Module-level name read by _load_chain when it builds the embeddings client.
hf_token = HF_API_TOKEN
if not hf_token:
    # Name the variable the app actually reads (HF_API_TOKEN comes from
    # HUGGING_FACE_API_TOKEN), which may live in the shell or in .env.
    st.sidebar.warning(
        "HUGGING_FACE_API_TOKEN is not set. Set it in your shell or .env file before running the app."
    )

# Display model names as text (read-only)
st.sidebar.markdown(f"**Embedding Model:** `{EMBED_MODEL_NAME}`")
st.sidebar.markdown(f"**Chat Model:** `{LLM_MODEL_NAME}`")

# Fixed retrieval/generation settings (not exposed in the UI).
k = 4  # retriever top-k
max_new_tokens = 512
temperature = 0.1
show_sources = False
###################
# Conversation log kept in session_state so it survives Streamlit reruns.
# Each entry is a (question, answer, sources) tuple; created empty on first run.
st.session_state.setdefault("history", [])
# Load vector store & chain lazily, cache across reruns
@st.cache_resource(show_spinner=True)
def _load_chain(
    _store_dir: str,
    _embed_repo: str,
    _llm_repo: str,
    _k: int,
    _max_new: int,
    _temp: float,
    _show_sources: bool,
):
    """Load the persisted FAISS index and wrap it in a Gemini RetrievalQA chain.

    Decorated with ``st.cache_resource`` so the index and clients are built once
    and reused across Streamlit reruns.

    NOTE(review): Streamlit does not hash parameters whose names start with an
    underscore, so none of these arguments form part of the cache key; a call
    with different values returns the first cached chain. This is harmless while
    every argument is a module constant. Rename the parameters if any become
    user-adjustable.

    Args:
        _store_dir: Directory containing the saved FAISS index.
        _embed_repo: Hugging Face embedding model id (must match the index).
        _llm_repo: Gemini model name.
        _k: Number of chunks the retriever returns per query.
        _max_new: Maximum output tokens for the LLM.
        _temp: LLM sampling temperature.
        _show_sources: Whether the chain returns source documents.

    Returns:
        The RetrievalQA chain produced by ``build_chain_gemini``.

    Raises:
        FileNotFoundError: If ``_store_dir`` does not exist.
        RuntimeError: From ``build_chain_gemini`` when GOOGLE_API_KEY is unset.
    """
    store_path = Path(_store_dir)
    if not store_path.exists():
        raise FileNotFoundError(f"FAISS store not found at '{_store_dir}'. Run ingest.py first.")

    embeddings = HFAPIEmbeddings(repo_id=_embed_repo, token=hf_token)
    # load_local unpickles the stored docstore, which newer LangChain versions
    # refuse without this flag. Presumably safe only because the index is built
    # locally by ingest.py; never point it at an untrusted directory.
    vs = FAISS.load_local(
        str(store_path),
        embeddings,
        allow_dangerous_deserialization=True,
    )
    # Honour the caller's top-k (previously hard-coded to 4).
    retriever = vs.as_retriever(search_kwargs={"k": _k})
    return build_chain_gemini(retriever, _llm_repo, _max_new, _temp, _show_sources)
# Prepare chain (built once, then served from _load_chain's resource cache)
with st.spinner("Preparing retriever & LLM…"):
    try:
        chain = _load_chain(
            str(INDEX_DIR), EMBED_MODEL_NAME, LLM_MODEL_NAME, k, max_new_tokens, temperature, show_sources
        )
    except (FileNotFoundError, RuntimeError) as err:
        # Missing index or missing GOOGLE_API_KEY: show a readable message and
        # halt this run instead of rendering a raw traceback.
        st.error(str(err))
        st.stop()
def _source_label(meta) -> str:
    """Build a display label: the source's file name plus a 1-based page when known."""
    label = Path(meta.get("source", "unknown")).name
    page = meta.get("page")
    if isinstance(page, int):
        label += f" (page {page + 1})"
    return label


def render_sources(docs):
    """Render retrieved documents as numbered, collapsible text previews.

    Args:
        docs: Sequence of documents exposing ``metadata`` (a dict) and
            ``page_content`` (a str). Empty or ``None`` renders nothing.
    """
    if docs:
        st.markdown("**Sources**")
        for num, doc in enumerate(docs, 1):
            text = doc.page_content
            # Long chunks are cut to 1500 characters and marked with an ellipsis.
            preview = text if len(text) <= 1500 else text[:1500] + "…"
            with st.expander(f"{num}. {_source_label(doc.metadata)}"):
                st.write(preview)
# --- Chat input with Enter submit ---
user_input = st.chat_input("e.g. Tell me about your experience as AI Engineer")
# Normalise once. A whitespace-only submission is ignored instead of being sent
# to the chain as an empty query.
question = user_input.strip() if user_input else ""
if question:
    with st.spinner("Thinking…"):
        try:
            res = chain.invoke({"query": question})
        except Exception as e:  # top-level UI boundary: report the failure in the chat log
            answer, sources = f"[error] {e}", []
        else:
            if isinstance(res, dict):
                answer = res.get("result", "")
                sources = res.get("source_documents", []) if show_sources else []
            else:
                answer, sources = str(res), []
    st.session_state.history.append((question, answer, sources))
# Re-render the whole conversation on every rerun, oldest exchange first.
for turn in st.session_state.history:
    asked, reply, turn_sources = turn
    st.markdown(f"**You:** {asked}")
    st.markdown(f"**Assistant:** {reply}")
    if show_sources:
        render_sources(turn_sources)
    st.markdown("---")