# PersonalChatbot/src/qa_chain_cli.py
# Author: Maheen Saleh
# Interactive recruiter Q/A CLI over a saved FAISS index (commit 4a3a2c0).
import argparse
import textwrap
from pathlib import Path
import os
from dotenv import load_dotenv
from qa_prompts import PROMPT_TMPL
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_google_genai import ChatGoogleGenerativeAI
# Load environment variables from a local .env file before reading config.
load_dotenv()

# Credentials and model names are supplied via environment variables;
# any of these may be None if the variable is unset (checked at use sites).
HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL")
LLM_MODEL_NAME = os.getenv("LLM_MODEL")

# The FAISS index directory lives next to this script in data_index/.
ROOT_DIR = Path(__file__).parent
INDEX_DIR = ROOT_DIR / "data_index"  # pathlib join instead of f-string path building
def load_retriever(index_dir: Path, k: int = 4):
    """Load the saved FAISS index from *index_dir* and return a retriever.

    The embedding model name recorded at ingest time (embeddings_model.txt)
    is reused so that query vectors match the indexed vectors.

    Args:
        index_dir: Directory containing the FAISS index files.
        k: Number of documents the retriever returns per query.

    Raises:
        RuntimeError: if the embeddings_model.txt marker file is absent.
    """
    marker = index_dir / "embeddings_model.txt"
    if not marker.exists():
        raise RuntimeError(f"Missing {marker}. Re-run ingest.py.")

    model_name = marker.read_text(encoding="utf-8").strip()
    embedder = HuggingFaceEmbeddings(model_name=model_name)
    store = FAISS.load_local(
        str(index_dir),
        embedder,
        # Required by recent langchain to load pickled FAISS metadata;
        # acceptable here because the index is produced locally by ingest.py.
        allow_dangerous_deserialization=True,
    )
    return store.as_retriever(search_kwargs={"k": k})
def build_chain_gemini(retriever):
    """Build a RetrievalQA chain backed by Google's Gemini hosted endpoint.

    Args:
        retriever: A vector-store retriever supplying context documents.

    Returns:
        A RetrievalQA chain; invoking it yields the answer under "result"
        plus the retrieved source documents.

    Raises:
        RuntimeError: if GOOGLE_API_KEY or LLM_MODEL is not configured.
    """
    if not GOOGLE_API_KEY:
        raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")
    if not LLM_MODEL_NAME:
        # Fail fast with a clear message instead of an opaque client error.
        raise RuntimeError("Set LLM_MODEL in your .env to choose a Gemini model.")

    # Google Generative AI (Gemini) hosted inference endpoint.
    llm = ChatGoogleGenerativeAI(
        model=LLM_MODEL_NAME,
        api_key=GOOGLE_API_KEY,
        temperature=0.1,  # near-deterministic answers for recruiter Q/A
        max_output_tokens=512,
        convert_system_message_to_human=True,
    )
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=PROMPT_TMPL,
    )
    # "stuff" concatenates all retrieved chunks into a single prompt; fine
    # for the small k used here. (The old comment claimed map_reduce, which
    # contradicted the actual chain_type — switch to map_reduce only if
    # retrieved contexts grow beyond the model's input budget.)
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
    return qa
def main():
    """Run an interactive Q/A loop against the saved FAISS index."""
    parser = argparse.ArgumentParser(description="Run recruiter Q/A over a saved FAISS index.")
    parser.parse_args()  # no flags yet; still provides --help and rejects stray args

    retriever = load_retriever(INDEX_DIR)
    chain = build_chain_gemini(retriever)

    # BUG FIX: the original banner used "\M...", an invalid escape sequence
    # that printed a literal backslash; "\n" was intended.
    print("\nMy Profile Chatbot ready. Ask about me.")
    print("Type 'exit' to quit.\n")

    while True:
        try:
            q = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly.
            print("\nBye!")
            break
        if not q:
            continue
        if q.lower() in {"exit", "quit", "q"}:
            print("Bye!")
            break
        try:
            res = chain.invoke({"query": q})
            # RetrievalQA returns a dict with the answer under "result".
            answer = res["result"] if isinstance(res, dict) else str(res)
        except Exception as e:
            # Best-effort: surface the failure in the transcript rather than crash.
            answer = f"[error] {e}"
        print("\nMaheen:", textwrap.fill(answer, width=100))
        print()


if __name__ == "__main__":
    main()