"""Recruiter Q/A chatbot: answers questions over a saved FAISS profile index."""
import argparse
import os
import textwrap
from pathlib import Path

from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI

from qa_prompts import PROMPT_TMPL
# Load API keys and model names from a local .env file.
load_dotenv()

HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL")
LLM_MODEL_NAME = os.getenv("LLM_MODEL")

# Resolve paths relative to this file so the script works from any CWD.
ROOT_DIR = Path(__file__).parent
# pathlib "/" join instead of f-string concatenation (same resulting path).
INDEX_DIR = ROOT_DIR / "data_index"
def load_retriever(index_dir: Path, k: int = 4):
    """Load the saved FAISS index from *index_dir* and wrap it as a retriever.

    The embedding-model name recorded by ingest.py is read back so that
    queries are embedded with the exact model that built the index.

    Raises:
        RuntimeError: If the recorded model-name file is missing.
    """
    model_name_file = index_dir / "embeddings_model.txt"
    if not model_name_file.exists():
        raise RuntimeError(f"Missing {model_name_file}. Re-run ingest.py.")

    model_name = model_name_file.read_text(encoding="utf-8").strip()
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    store = FAISS.load_local(
        str(index_dir), embeddings, allow_dangerous_deserialization=True
    )
    return store.as_retriever(search_kwargs={"k": k})
def build_chain_gemini(retriever):
    """Build a RetrievalQA chain backed by Google's hosted Gemini API.

    Args:
        retriever: A LangChain retriever supplying context documents.

    Returns:
        A RetrievalQA chain configured to return source documents
        alongside the answer.

    Raises:
        RuntimeError: If GOOGLE_API_KEY is not set in the environment.
    """
    if not GOOGLE_API_KEY:
        raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")

    # Uses the Google Generative AI (Gemini) hosted inference endpoint.
    llm = ChatGoogleGenerativeAI(
        model=LLM_MODEL_NAME,
        api_key=GOOGLE_API_KEY,
        temperature=0.1,  # low temperature for stable, factual answers
        max_output_tokens=512,
        # NOTE(review): deprecated in newer langchain-google-genai —
        # confirm it is still required for this model.
        convert_system_message_to_human=True,
    )

    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=PROMPT_TMPL,
    )

    # "stuff" concatenates all retrieved chunks into one prompt; fine here
    # since k is small.  (The old comment claimed map_reduce, which did not
    # match the actual chain_type below.)
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
    return qa
def main():
    """Run an interactive Q/A REPL over the saved FAISS index."""
    parser = argparse.ArgumentParser(description="Run recruiter Q/A over a saved FAISS index.")
    parser.parse_args()  # no options yet; still provides --help for free

    retriever = load_retriever(INDEX_DIR)
    chain = build_chain_gemini(retriever)

    # BUG FIX: original printed "\My..." — "\M" is an invalid escape that
    # emits a literal backslash; "\n" was intended.
    print("\nMy Profile Chatbot ready. Ask about me.")
    print("Type 'exit' to quit.\n")

    while True:
        try:
            q = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C exits cleanly instead of dumping a traceback.
            print("\nBye!")
            break
        if not q:
            continue
        if q.lower() in {"exit", "quit", "q"}:
            print("Bye!")
            break
        try:
            res = chain.invoke({"query": q})
            answer = res["result"] if isinstance(res, dict) else str(res)
        except Exception as e:
            # Surface the error inline so one bad query doesn't kill the REPL.
            answer = f"[error] {e}"
        print("\nMaheen:", textwrap.fill(answer, width=100))
        print()


if __name__ == "__main__":
    main()