# medical-qa-assistant / rag_pipeline.py
# Uploaded by rahideer ("Upload 4 files", commit 381b00d, verified).
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# Placeholder knowledge base: a handful of hard-coded passages whose
# "embedding" slot starts out as None. Swap in real, precomputed embeddings
# for production use.
knowledge_base = [
    {"text": passage, "embedding": None}
    for passage in (
        "Aspirin is used to reduce fever and relieve mild to moderate pain.",
        "Hypertension is a condition in which the blood pressure in the arteries is elevated.",
        "Diabetes is a chronic condition that affects how the body processes blood sugar.",
    )
]
def load_rag_pipeline():
    """Build the retrieval half of the RAG pipeline.

    Loads the sentence-embedding model, embeds every passage in the
    module-level ``knowledge_base`` and stores the vectors in a flat L2
    FAISS index.

    As a side effect each ``knowledge_base`` entry gets its ``"embedding"``
    slot filled in with the vector computed for its text.

    Returns:
        A dict with three keys:
        ``"embedder"`` - the loaded ``SentenceTransformer`` model,
        ``"index"`` - the populated ``faiss.IndexFlatL2``,
        ``"texts"`` - the passage texts, in the same order as the index rows.

    Raises:
        ValueError: if ``knowledge_base`` is empty, since there is nothing
            to index and no embedding dimension to size the index with.
    """
    if not knowledge_base:
        raise ValueError("knowledge_base is empty; nothing to index")

    embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    texts = [entry["text"] for entry in knowledge_base]

    # Encode all passages in one batched call instead of one call per entry,
    # and hand FAISS the contiguous float32 matrix it requires.
    embeddings = np.ascontiguousarray(embedder.encode(texts), dtype=np.float32)

    # Keep the per-entry embedding populated for code that reads it.
    for entry, vector in zip(knowledge_base, embeddings):
        entry["embedding"] = vector

    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return {"embedder": embedder, "index": index, "texts": texts}
def ask_question(pipe, query):
    """Answer ``query`` using the single closest knowledge-base passage.

    The query is embedded, the nearest passage is looked up in the FAISS
    index, and a text2text generation model is prompted with the question
    plus that passage as context.

    Args:
        pipe: The dict returned by ``load_rag_pipeline`` (keys ``"embedder"``,
            ``"index"``, ``"texts"``). The generation pipeline is created on
            first use and cached under ``pipe["generator"]`` so later calls
            with the same dict do not reload the model.
        query: The user's question as a string.

    Returns:
        The generated answer text.

    Raises:
        IndexError: if the index search yields no match (e.g. the index is
            empty).
    """
    # FAISS expects a 2-D float32 matrix of shape (n_queries, dim).
    query_matrix = np.asarray(
        pipe["embedder"].encode(query), dtype=np.float32
    ).reshape(1, -1)
    _distances, neighbours = pipe["index"].search(query_matrix, k=1)

    best = int(neighbours[0][0])
    # FAISS reports "no result" as label -1; used as a list index that would
    # silently select the last passage, so reject it explicitly.
    if best < 0:
        raise IndexError("knowledge base index returned no match for the query")
    context = pipe["texts"][best]

    # Loading the model is expensive, so build it once per pipeline dict
    # rather than on every question.
    generator = pipe.get("generator")
    if generator is None:
        # NOTE(review): facebook/bart-large looks like the base checkpoint
        # rather than one fine-tuned for question answering - confirm that
        # this is the intended model.
        generator = pipeline("text2text-generation", model="facebook/bart-large")
        pipe["generator"] = generator

    answer = generator(
        f"question: {query} context: {context}", max_length=100, do_sample=False
    )
    return answer[0]['generated_text']