import streamlit as st
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset

# Load the PubMed QA dataset (the "pqa_labeled" config only ships a "train" split).
# The retriever below reads an embedded, FAISS-indexed copy of these passages from disk.
dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")

# Load the tokenizer and a retriever backed by a custom passages index.
# "./path_to_dataset" and "./path_to_index" are placeholders for a dataset saved
# with save_to_disk() (columns: "title", "text", "embeddings") and its FAISS index.
# The rag-sequence-nq checkpoint matches the RagSequenceForGeneration head used below.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="custom",
    passages_path="./path_to_dataset",
    index_path="./path_to_index",
)

# Initialize the RAG model and attach the retriever so that generate()
# can fetch supporting passages internally
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

# Define Streamlit app
st.title("Medical QA Assistant")
st.markdown("This app uses a RAG model to answer medical queries based on the PubMed QA dataset.")

# User input for query
user_query = st.text_input("Ask a medical question:")

if user_query:
    # Tokenize the input question
    inputs = tokenizer(user_query, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # Encode the question and retrieve the supporting documents so they
    # can be displayed alongside the generated answer
    with torch.no_grad():
        question_hidden_states = model.question_encoder(input_ids)[0]
    retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(
        question_hidden_states.numpy(), n_docs=5
    )

    # Generate an answer; the attached retriever supplies the context documents
    generated_ids = model.generate(input_ids=input_ids)
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Show the answer
    st.write(f"Answer: {answer}")

    # Display the most relevant documents
    st.subheader("Relevant Documents:")
    for doc_text in doc_dicts[0]["text"]:
        st.write(doc_text[:300] + "...")  # first 300 characters of each doc
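
The retriever above assumes the passages dataset and FAISS index already exist at the placeholder paths. The one-off script below is a minimal sketch of how they could be built with a DPR context encoder; the choice of facebook/dpr-ctx_encoder-multiset-base, the flattening of each PubMed QA record's context["contexts"] into a single text field, and the use of the question as a stand-in title are assumptions, not part of the original app.

# build_index.py -- one-time index build, run separately from the app
import torch
from datasets import load_dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

# Assumption: a DPR context encoder produces the passage embeddings;
# RAG expects wiki_dpr-style columns: "title", "text", "embeddings"
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")

raw = load_dataset("pubmed_qa", "pqa_labeled", split="train")

# Flatten each record into a title/text passage (field layout assumed
# from the pqa_labeled config; the question serves as a stand-in title)
passages = raw.map(
    lambda ex: {"title": ex["question"], "text": " ".join(ex["context"]["contexts"])},
    remove_columns=raw.column_names,
)

def embed(batch):
    # Encode title/text pairs into DPR passage embeddings
    enc = ctx_tokenizer(batch["title"], batch["text"], truncation=True,
                        padding="longest", return_tensors="pt")
    with torch.no_grad():
        return {"embeddings": ctx_encoder(**enc).pooler_output.cpu().numpy()}

passages = passages.map(embed, batched=True, batch_size=16)
passages.save_to_disk("./path_to_dataset")  # what passages_path points at
passages.add_faiss_index(column="embeddings")  # requires faiss-cpu or faiss-gpu
passages.get_index("embeddings").save("./path_to_index")  # what index_path points at

With both artifacts on disk, the app itself is launched in the usual Streamlit way, e.g. streamlit run app.py (assuming the app code is saved as app.py).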