# Hugging Face Space (page status when captured: Sleeping)
import streamlit as st
import torch
from datasets import load_dataset
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# Number of characters of each retrieved passage shown in the UI.
SNIPPET_CHARS = 300

# NOTE(review): the original script called
# load_dataset("pubmed_qa", split="test") here. The result was never used,
# and the call appears to fail at startup (pubmed_qa needs a config name such
# as "pqa_labeled" and seems to publish only a "train" split), so it has been
# dropped. Re-add it with an explicit config if the data is actually needed.


@st.cache_resource
def load_rag_components():
    """Load the tokenizer, retriever and RAG model once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    the heavyweight objects keeps them from being rebuilt on each rerun.

    Returns:
        A ``(tokenizer, retriever, model)`` tuple.
    """
    rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    # NOTE(review): "./path_to_dataset" looks like a placeholder, and the
    # retriever is configured exactly as before -- confirm that this index is
    # the intended knowledge source for a PubMed-focused app.
    rag_retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq",
        index_name="compressed",
        passages_path="./path_to_dataset",
    )
    rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
    rag_model.eval()  # inference only
    return rag_tokenizer, rag_retriever, rag_model


def answer_question(question, tokenizer, retriever, model):
    """Answer ``question`` with RAG and return the supporting passages.

    Follows the manual retrieval flow from the Transformers RAG
    documentation: encode the question, retrieve passages with the question
    embedding, score the passages, then generate from the retrieved context.

    Args:
        question: The user's question as plain text.
        tokenizer: The RAG tokenizer.
        retriever: The RAG retriever.
        model: The RAG generation model.

    Returns:
        A ``(answer, passages)`` tuple: the decoded answer string and the
        list of retrieved passage texts.
    """
    input_ids = tokenizer(question, return_tensors="pt")["input_ids"]

    with torch.no_grad():
        # 1. Encode the question; the first output is the pooled embedding.
        question_hidden_states = model.question_encoder(input_ids)[0]

        # 2. Retrieve. The retriever takes the question *embedding* (as
        #    numpy), not token ids, and returns tokenized contexts together
        #    with the passage embeddings and ids.
        docs = retriever(
            input_ids.numpy(),
            question_hidden_states.numpy(),
            return_tensors="pt",
        )

        # 3. Score each passage by its dot product with the question.
        doc_scores = torch.bmm(
            question_hidden_states.unsqueeze(1),
            docs["retrieved_doc_embeds"].float().transpose(1, 2),
        ).squeeze(1)

        # 4. Generate from the retrieved context.
        generated_ids = model.generate(
            context_input_ids=docs["context_input_ids"],
            context_attention_mask=docs["context_attention_mask"],
            doc_scores=doc_scores,
        )

    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # One dict per question; its "text" entry lists the retrieved passages.
    passages = retriever.index.get_doc_dicts(docs["doc_ids"].numpy())[0]
    return answer, list(passages["text"])


# Define Streamlit app
st.title('Medical QA Assistant')
# NOTE(review): this text says answers are based on PubMed QA, but nothing
# below feeds that dataset to the retriever -- confirm the wording.
st.markdown("This app uses a RAG model to answer medical queries based on the PubMed QA dataset.")

tokenizer, retriever, model = load_rag_components()

# User input for query
user_query = st.text_input("Ask a medical question:")

if user_query and user_query.strip():
    answer, passages = answer_question(user_query, tokenizer, retriever, model)

    # Show the answer
    st.write(f"Answer: {answer}")

    # Display the most relevant documents
    st.subheader("Relevant Documents:")
    for passage_text in passages:
        # Display only the first SNIPPET_CHARS characters of each passage.
        st.write(passage_text[:SNIPPET_CHARS] + '...')