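"""Interactive Q&A CLI over the FAISS index built by ingest.py.

Loads the index from ./data_index, embeds queries with the same Hugging Face
embedding model recorded at ingest time, and answers with a Gemini chat model.
Requires GOOGLE_API_KEY (and the model settings read below) in a .env file.
"""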
import argparse
import os
import textwrap
from pathlib import Path

from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI

from qa_prompts import PROMPT_TMPL

load_dotenv()

HF_API_TOKEN = os.getenv("HUGGING_FACE_API_TOKEN")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Note: the embedding model actually used at query time is re-read from the
# saved index (see load_retriever), so EMBED_MODEL_NAME is not used below.
EMBED_MODEL_NAME = os.getenv("HUGGING_FACE_EMBEDDING_MODEL")
LLM_MODEL_NAME = os.getenv("LLM_MODEL")

ROOT_DIR = Path(__file__).parent
INDEX_DIR = ROOT_DIR / "data_index"


def load_retriever(index_dir: Path, k: int = 4):
    # Ensure we use the same embedding model that was used during ingest
    embed_model_name_path = index_dir / "embeddings_model.txt"
    if not embed_model_name_path.exists():
        raise RuntimeError(f"Missing {embed_model_name_path}. Re-run ingest.py.")
    embed_model_name = embed_model_name_path.read_text(encoding="utf-8").strip()

    embeddings = HuggingFaceEmbeddings(model_name=embed_model_name)
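    # The embeddings run locally via sentence-transformers, so no HF API call is
    # made here. allow_dangerous_deserialization is required because the FAISS
    # docstore is pickled; only load indexes you created yourself.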
    vs = FAISS.load_local(str(index_dir), embeddings, allow_dangerous_deserialization=True)
    return vs.as_retriever(search_kwargs={"k": k})


def build_chain_gemini(retriever):
    if not GOOGLE_API_KEY:
        raise RuntimeError("Set GOOGLE_API_KEY in your .env to use the Gemini inference endpoint.")

    # Uses Google Generative AI (Gemini) hosted inference endpoint
    llm = ChatGoogleGenerativeAI(
        model=LLM_MODEL_NAME,
        api_key=GOOGLE_API_KEY,
        temperature=0.1,
        max_output_tokens=512,
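        # Older Gemini endpoints rejected system messages; this folds any
        # system prompt into the first human turn.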
        convert_system_message_to_human=True,
    )

    # PROMPT_TMPL must contain {context} and {question} placeholders, which the
    # retrieval chain fills with the retrieved chunks and the user's question.
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=PROMPT_TMPL,
    )

    # "stuff" concatenates all retrieved chunks into a single prompt, which
    # keeps the chain simple and is fine for a small top-k.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
    return qa


def main():
    parser = argparse.ArgumentParser(description="Run recruiter Q/A over a saved FAISS index.")
    parser.add_argument("-k", type=int, default=4, help="Number of chunks to retrieve per query.")
    args = parser.parse_args()

    retriever = load_retriever(INDEX_DIR, k=args.k)

    chain = build_chain_gemini(retriever)

    print("\My Profile Chatbot ready. Ask about me.")
    print("Type 'exit' to quit.\n")

    while True:
        try:
            q = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not q:
            continue
        if q.lower() in {"exit", "quit", "q"}:
            print("Bye!")
            break

        try:
            res = chain.invoke({"query": q})
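            # With return_source_documents=True the chain also returns
            # res["source_documents"]; only the answer text is shown here.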
            answer = res["result"] if isinstance(res, dict) else str(res)
        except Exception as e:
            answer = f"[error] {e}"

        print("\nMaheen:", textwrap.fill(answer, width=100))
        print()


if __name__ == "__main__":
    main()