# assignment5 / app.py
# Author: rahideer
# Last commit: c54ca35 "Update app.py" (verified)
import streamlit as st
import torch
from datasets import load_dataset
from transformers import (
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration,
    RagTokenizer,
)
# --- Model / data loading ---------------------------------------------------
# Streamlit re-executes this whole script on every widget interaction, so the
# heavy objects are built once per server process via st.cache_resource instead
# of being reloaded (and re-downloaded from disk into memory) on every rerun.
RAG_CHECKPOINT = "facebook/rag-token-nq"


@st.cache_resource
def _load_pubmed_qa():
    """Load the expert-labelled PubMedQA subset.

    PubMedQA cannot be loaded without a config name, and its configs expose
    only a "train" split, so the original ``split="test"`` call failed.
    """
    # NOTE(review): `dataset` is not referenced anywhere else in this script;
    # confirm whether it is still needed.
    return load_dataset("pubmed_qa", "pqa_labeled", split="train")


@st.cache_resource
def _load_rag_components():
    """Build and return ``(tokenizer, retriever, model)`` for RAG_CHECKPOINT."""
    rag_tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
    # index_name="compressed" selects the checkpoint's canonical passage index
    # (wiki_dpr per the rag-token-nq config), so retrieval is over Wikipedia,
    # not PubMedQA. `passages_path` was dropped: it is only honoured for
    # index_name="custom". TODO: a PubMed-backed retriever needs
    # index_name="custom" plus passages_path/index_path of a pre-embedded set.
    rag_retriever = RagRetriever.from_pretrained(RAG_CHECKPOINT, index_name="compressed")
    # rag-token-nq is a RAG-Token checkpoint; load it with the matching head so
    # generation uses token-level marginalisation as it was trained.
    rag_model = RagTokenForGeneration.from_pretrained(RAG_CHECKPOINT)
    return rag_tokenizer, rag_retriever, rag_model


# Load dataset (pubmed_qa), tokenizer, retriever and the RAG model
dataset = _load_pubmed_qa()
tokenizer, retriever, model = _load_rag_components()
# Define Streamlit app
st.title('Medical QA Assistant')
# NOTE(review): the retriever for facebook/rag-token-nq searches its canonical
# (Wikipedia) index, not PubMedQA - confirm this description is intended.
st.markdown("This app uses a RAG model to answer medical queries based on the PubMed QA dataset.")

# Number of leading characters shown for each retrieved passage.
DOC_PREVIEW_CHARS = 300

# User input for query
user_query = st.text_input("Ask a medical question:")
if user_query:
    # RagTokenizer.__call__ uses the question-encoder tokenizer by default.
    inputs = tokenizer(user_query, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Inference only: no autograd graph is needed for any of these calls.
    with torch.no_grad():
        # Element 0 of the question-encoder output is the pooled question
        # embedding used for dense retrieval.
        question_hidden_states = model.question_encoder(input_ids)[0]

        # The retriever takes NumPy arrays and returns generator-ready contexts
        # (passage + question) together with passage embeddings and ids.
        docs_dict = retriever(
            input_ids.numpy(),
            question_hidden_states.numpy(),
            return_tensors="pt",
        )

        # Score each retrieved passage by its dot product with the question
        # embedding; generate() requires these when contexts are passed in.
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1),
            docs_dict["retrieved_doc_embeds"].float().transpose(1, 2),
        ).squeeze(1)

        # Generate an answer conditioned on the retrieved contexts.
        generated_ids = model.generate(
            context_input_ids=docs_dict["context_input_ids"],
            context_attention_mask=docs_dict["context_attention_mask"],
            doc_scores=doc_scores,
        )

    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Show the answer
    st.write(f"Answer: {answer}")

    # Display the most relevant documents
    st.subheader("Relevant Documents:")
    # get_doc_dicts returns one dict per query row, mapping column -> list.
    doc_dicts = retriever.index.get_doc_dicts(docs_dict["doc_ids"])
    for passage_text in doc_dicts[0]["text"]:
        st.write(passage_text[:DOC_PREVIEW_CHARS] + '...')