import os
from typing import List, Optional, Tuple

import chromadb
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Disable Chroma telemetry and point the client at a local server.
os.environ["ANONYMIZED_TELEMETRY"] = "False"
os.environ["CHROMA_SERVER_HOST"] = "localhost"
os.environ["CHROMA_SERVER_HTTP_PORT"] = "8000"


class ImprovedTFIDFEmbeddings(Embeddings):
    """Improved TF-IDF based embedding function with better preprocessing."""

    def __init__(self):
        self.vectorizer = TfidfVectorizer(
            max_features=5000,
            stop_words='english',
            ngram_range=(1, 3),
            min_df=1,
            max_df=0.85,
            lowercase=True,
            strip_accents='unicode',
            analyzer='word',
        )
        self.fitted = False
        self.documents = []

    @staticmethod
    def _normalize_and_resize(embedding: np.ndarray) -> List[float]:
        """L2-normalize a vector, then pad or truncate it to 512 dimensions."""
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        if len(embedding) < 512:
            embedding = np.pad(embedding, (0, 512 - len(embedding)))
        else:
            embedding = embedding[:512]
        return embedding.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Create embeddings for a list of texts, fitting the vectorizer on first use."""
        if not self.fitted:
            self.documents = texts
            self.vectorizer.fit(texts)
            self.fitted = True
        tfidf_matrix = self.vectorizer.transform(texts)
        return [self._normalize_and_resize(tfidf_matrix[i].toarray().flatten())
                for i in range(tfidf_matrix.shape[0])]

    def embed_query(self, text: str) -> List[float]:
        """Create an embedding for a single query text."""
        if not self.fitted:
            # Degenerate case: fit on the query alone so transform() does not fail.
            self.vectorizer.fit([text])
            self.fitted = True
        tfidf_matrix = self.vectorizer.transform([text])
        return self._normalize_and_resize(tfidf_matrix[0].toarray().flatten())
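

# Illustrative sketch (not called anywhere): how the embedder above is meant to
# be used. The sample texts are made up; every vector comes back L2-normalized
# and padded/truncated to a fixed 512 dimensions so the store sees a consistent
# dimensionality.
def _demo_tfidf_embeddings() -> None:
    emb = ImprovedTFIDFEmbeddings()
    vectors = emb.embed_documents([
        "What is anxiety?",
        "How is depression treated?",
    ])
    assert all(len(v) == 512 for v in vectors)
    query_vector = emb.embed_query("anxiety symptoms")
    assert len(query_vector) == 512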


class SmartFAQRetriever(BaseRetriever):
    """Smart retriever optimized for FAQ datasets with semantic similarity."""

    def __init__(self, documents: List[Document], k: int = 4):
        super().__init__()
        self._documents = documents
        self._k = k
        self._vectorizer = None

    @property
    def documents(self):
        return self._documents

    @property
    def k(self):
        return self._k

    def _extract_questions(self) -> List[str]:
        """Pull the question text out of each 'QUESTION: ... ANSWER: ...' document."""
        questions = []
        for doc in self._documents:
            if "QUESTION:" in doc.page_content:
                question_part = doc.page_content.split("ANSWER:")[0]
                questions.append(question_part.replace("QUESTION:", "").strip())
            else:
                questions.append(doc.page_content)
        return questions

    def _ensure_vectorizer(self) -> None:
        """Lazily fit the TF-IDF vectorizer on the question texts."""
        if self._vectorizer is None or not getattr(self._vectorizer, "vocabulary_", None):
            print("[SmartFAQRetriever] Fitting vectorizer...")
            self._vectorizer = TfidfVectorizer(
                max_features=3000,
                stop_words='english',
                ngram_range=(1, 2),
                min_df=1,
                max_df=0.9,
            )
            self._vectorizer.fit(self._extract_questions())

    def get_documents_with_confidence(self, query: str) -> List[dict]:
        """Return top documents and their confidence (similarity) scores."""
        results = self._get_relevant_documents_with_scores(query)
        return [{"document": doc.page_content, "confidence": round(score, 3)}
                for doc, score in results]

    def _get_relevant_documents_with_scores(self, query: str) -> List[Tuple[Document, float]]:
        """Retrieve documents along with cosine-similarity scores."""
        self._ensure_vectorizer()
        query_vector = self._vectorizer.transform([query.lower().strip()])
        question_vectors = self._vectorizer.transform(self._extract_questions())
        similarities = cosine_similarity(query_vector, question_vectors).flatten()
        # Take the k best matches, discarding anything below the 0.1 threshold.
        top_indices = similarities.argsort()[-self._k:][::-1]
        return [(self._documents[i], float(similarities[i]))
                for i in top_indices if similarities[i] > 0.1]

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Retrieve documents based on semantic similarity, falling back to the
        first k documents when nothing clears the similarity threshold."""
        relevant_docs = [doc for doc, _ in self._get_relevant_documents_with_scores(query)]
        if not relevant_docs:
            relevant_docs = self._documents[:self._k]
        return relevant_docs

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        """Async version of _get_relevant_documents."""
        return self._get_relevant_documents(query)
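

# Illustrative sketch (not called anywhere): constructing the FAQ retriever
# directly. The documents are made-up examples in the QUESTION/ANSWER format
# the retriever expects.
def _demo_smart_faq_retriever() -> None:
    docs = [
        Document(page_content="QUESTION: What is anxiety?\nANSWER: Anxiety is a feeling of worry or fear."),
        Document(page_content="QUESTION: How common are mental illnesses?\nANSWER: They are very common."),
    ]
    retriever = SmartFAQRetriever(docs, k=2)
    # Standard LangChain retrieval interface:
    results = retriever.invoke("what is anxiety")
    # Retrieval with similarity scores attached:
    scored = retriever.get_documents_with_confidence("what is anxiety")
    print(results[0].page_content, scored[0]["confidence"])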


def _load_mental_health_faq(path: str = "data/Mental_Health_FAQ.csv") -> List[Document]:
    """Load the local mental health FAQ CSV into QUESTION/ANSWER documents."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"Mental health FAQ file not found: {path}")
    df = pd.read_csv(path)
    documents = []
    for _, row in df.iterrows():
        content = f"QUESTION: {row['Questions']}\nANSWER: {row['Answers']}"
        documents.append(Document(page_content=content))
    print(f"Loaded {len(documents)} mental health FAQ documents")
    for i, doc in enumerate(documents[:3]):
        print(f"Sample FAQ {i + 1}: {doc.page_content[:200]}...")
    return documents


def setup_retriever(use_kaggle_data: bool = False, kaggle_dataset: Optional[str] = None,
                    kaggle_username: Optional[str] = None, kaggle_key: Optional[str] = None,
                    use_local_mental_health_data: bool = False) -> BaseRetriever:
    """
    Creates a retriever backed by a Kaggle dataset or the local mental health
    FAQ data, falling back to the FAQ data when nothing else loads.

    Args:
        use_kaggle_data: Whether to load a Kaggle dataset
        kaggle_dataset: Kaggle dataset name (e.g., 'username/dataset-name')
        kaggle_username: Your Kaggle username (optional if using kaggle.json)
        kaggle_key: Your Kaggle API key (optional if using kaggle.json)
        use_local_mental_health_data: Whether to load the local mental health FAQ data
    """
    print("Setting up the retriever...")
    documents: List[Document] = []

    if use_local_mental_health_data:
        try:
            print("Loading mental health FAQ data from local file...")
            documents = _load_mental_health_faq()
        except Exception as e:
            print(f"Error loading mental health data: {e}")

    if use_kaggle_data and kaggle_dataset:
        try:
            from src.kaggle_loader import KaggleDataLoader

            print(f"Loading Kaggle dataset: {kaggle_dataset}")
            loader = KaggleDataLoader()
            dataset_path = loader.download_dataset(kaggle_dataset)
            documents = []
            print(f"Processing files in dataset directory: {dataset_path}")

            for file in os.listdir(dataset_path):
                file_path = os.path.join(dataset_path, file)
                if file.endswith('.csv'):
                    print(f"Loading CSV file: {file}")
                    if 'faq' in file.lower() or 'mental' in file.lower():
                        documents.extend(loader.load_csv_dataset(file_path, [], chunk_size=50))
                    else:
                        # Fall back to treating the first three columns as text.
                        df = pd.read_csv(file_path)
                        text_columns = df.columns[:3].tolist()
                        documents.extend(loader.load_csv_dataset(file_path, text_columns, chunk_size=50))
                elif file.endswith('.json'):
                    print(f"Loading JSON file: {file}")
                    documents.extend(loader.load_json_dataset(file_path))
                elif file.endswith('.txt'):
                    print(f"Loading text file: {file}")
                    documents.extend(loader.load_text_dataset(file_path))

            print(f"Loaded {len(documents)} documents from Kaggle dataset")
            for i, doc in enumerate(documents[:3]):
                print(f"Sample doc {i + 1}: {doc.page_content[:200]}")
        except Exception as e:
            print(f"Error loading Kaggle data: {e}")
            print("Falling back to the local mental health FAQ data...")

    if not documents:
        print("No documents loaded yet; using mental health FAQ data as the default...")
        try:
            documents = _load_mental_health_faq()
        except Exception as e:
            print(f"Error loading mental health data: {e}")
            raise RuntimeError("No valid data source available. Please ensure the mental "
                               "health FAQ data is present or provide Kaggle credentials.")

    print("Creating TF-IDF embeddings...")
    embeddings = ImprovedTFIDFEmbeddings()

    print("Creating ChromaDB vector store...")
    client = chromadb.PersistentClient(path="./tmp/chroma_db")

    # Clear any stale collections so repeated runs start from a clean store.
    try:
        for collection in client.list_collections():
            print(f"Deleting existing collection: {collection.name}")
            client.delete_collection(collection.name)
    except Exception as e:
        print(f"Warning: Could not clear existing collections: {e}")

    print(f"Processing {len(documents)} documents...")

    # FAQ-style documents get the TF-IDF question matcher; everything else goes
    # into the Chroma vector store.
    if any("QUESTION:" in doc.page_content for doc in documents):
        print("Using SmartFAQRetriever for better semantic matching...")
        return SmartFAQRetriever(documents, k=4)
    else:
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embeddings,
            client=client,
        )
        print("Retriever setup complete.")
        return vectorstore.as_retriever(search_kwargs={"k": 4})
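

# Illustrative sketch (not called anywhere): the ways setup_retriever() can
# source its documents. The Kaggle dataset name below is a placeholder.
def _demo_setup_retriever() -> None:
    # Default: falls back to the local mental health FAQ CSV.
    setup_retriever()
    # Explicitly request the local FAQ data.
    setup_retriever(use_local_mental_health_data=True)
    # Or pull a Kaggle dataset (requires kaggle.json or credentials).
    setup_retriever(
        use_kaggle_data=True,
        kaggle_dataset="username/dataset-name",  # placeholder
    )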


def setup_rag_chain() -> Runnable:
    """Sets up the RAG chain with a prompt template and an LLM."""
    prompt = PromptTemplate(
        template="""Context: You are a medical information assistant that answers health questions using verified medical documents.

Primary Task: Answer the medical question using ONLY the provided documents.

Instructions:
1. For medical questions: Provide a clear, accurate answer based solely on the document content
2. If the documents lack sufficient information: "I don't have enough information in the provided documents to answer this question"
3. For non-medical questions: "I specialize in medical information. Please ask a health-related question."
4. For identity questions: "I am a medical information assistant designed to help answer health-related questions based on verified medical documents."
5. Always use patient-friendly language
6. Keep responses to 2-4 sentences maximum
7. For serious symptoms, recommend consulting healthcare professionals

Documents: {documents}

Question: {question}

Medical Answer:""",
        input_variables=["question", "documents"],
    )

    try:
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
        model = AutoModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM3-3B",
            device_map="auto",
            torch_dtype=torch.float16,
        )

        # Some causal LMs ship without a pad token; reuse EOS so batching works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print(f"Tokenizer pad_token_id: {tokenizer.pad_token_id}")
        print(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")

        hf_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=50,
            temperature=0.2,
            return_full_text=False,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            clean_up_tokenization_spaces=True,
        )

        # Smoke-test the pipeline before wiring it into the chain.
        test_input = "What is diabetes?"
        print(f"Testing pipeline with: {test_input}")
        test_result = hf_pipeline(test_input)
        print(f"Pipeline test successful: {test_result}")
    except Exception as e:
        print(f"Error setting up SmolLM3-3B: {e}")
        print("Falling back to DistilGPT-2...")

        hf_pipeline = pipeline(
            "text-generation",
            model="distilgpt2",
            max_new_tokens=50,
            temperature=0.2,
            return_full_text=False,
            do_sample=True,
            clean_up_tokenization_spaces=True,
        )

        test_input = "What is diabetes?"
        print(f"Testing fallback pipeline with: {test_input}")
        test_result = hf_pipeline(test_input)
        print(f"Fallback pipeline test successful: {test_result}")

    llm = HuggingFacePipeline(pipeline=hf_pipeline)

    return prompt | llm | StrOutputParser()
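

# Illustrative sketch (not called anywhere): the chain returned by
# setup_rag_chain() is a plain LangChain Runnable, so it can be invoked
# directly with the two prompt variables. The context string is made up.
def _demo_rag_chain() -> None:
    chain = setup_rag_chain()
    answer = chain.invoke({
        "question": "What is diabetes?",
        "documents": "QUESTION: What is diabetes?\nANSWER: Diabetes is a chronic condition affecting blood sugar regulation.",
    })
    print(answer)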


class RAGApplication:
    """Ties the retriever and the RAG chain together behind a single run() call."""

    def __init__(self, retriever: BaseRetriever, rag_chain: Runnable):
        self.retriever = retriever
        self.rag_chain = rag_chain

    def run(self, question: str) -> str:
        try:
            if not question.strip():
                return "Please provide a valid question."

            print(f"\nProcessing question: '{question}'")

            # Prefer retrievers that expose confidence scores (SmartFAQRetriever).
            if hasattr(self.retriever, "get_documents_with_confidence"):
                docs_with_scores = self.retriever.get_documents_with_confidence(question)
                documents = [Document(page_content=d["document"]) for d in docs_with_scores]
                confidence_info = "\n".join(
                    f"- Score: {d['confidence']}, Snippet: {d['document'][:100]}..."
                    for d in docs_with_scores
                )
            else:
                documents = self.retriever.invoke(question)
                confidence_info = "Confidence scoring not available."

            print(f"Retrieved {len(documents)} documents")
            print(confidence_info)

            # Truncate the stuffed context to stay within the small model's budget.
            doc_texts = "\n\n".join(doc.page_content for doc in documents)
            if len(doc_texts) > 500:
                doc_texts = doc_texts[:500] + "..."

            answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})

            footer = ("\n\n(Note: This answer is based on documents with confidence "
                      "scores. Review the full context if critical.)")
            return answer.strip() + footer
        except Exception as e:
            print(f"Error in RAG application: {str(e)}")
            import traceback
            traceback.print_exc()
            return (f"I apologize, but I encountered an error processing your question: "
                    f"{str(e)}. Please try rephrasing it or ask a different question.")


if __name__ == "__main__":
    load_dotenv()

    retriever = setup_retriever()
    rag_chain = setup_rag_chain()
    rag_application = RAGApplication(retriever, rag_chain)

    question = "What is terminal illness?"
    print("\n--- Running RAG Application ---")
    print(f"Question: {question}")
    answer = rag_application.run(question)
    print(f"Answer: {answer}")