File size: 1,541 Bytes
e0edce4
f0cab08
 
e0edce4
c54ca35
f0cab08
 
c54ca35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0cab08
c54ca35
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import streamlit as st
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset


@st.cache_resource
def _load_components():
    """Load the dataset, tokenizer, retriever, and RAG model once per process.

    Streamlit re-executes the whole script on every user interaction;
    st.cache_resource prevents re-downloading and re-initializing these
    heavyweight objects on each rerun.

    Returns:
        (dataset, tokenizer, retriever, model) tuple.
    """
    # NOTE(review): dataset is loaded but never read below — kept for parity
    # with the original script; confirm whether it is needed at all.
    dataset = load_dataset("pubmed_qa", split="test")
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    # NOTE(review): "./path_to_dataset" looks like a placeholder — confirm the
    # real passages path before deploying.
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq",
        index_name="compressed",
        passages_path="./path_to_dataset",
    )
    # Attach the retriever to the model so that model.generate() performs
    # question encoding, retrieval, and generation internally — the documented
    # RAG usage. (The original called retriever.retrieve(input_ids), which is
    # not the API: retrieve() takes question hidden states, not token ids.)
    model = RagSequenceForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=retriever
    )
    model.eval()
    return dataset, tokenizer, retriever, model


dataset, tokenizer, retriever, model = _load_components()

# Define Streamlit app
st.title('Medical QA Assistant')

st.markdown("This app uses a RAG model to answer medical queries based on the PubMed QA dataset.")

# User input for query
user_query = st.text_input("Ask a medical question:")

if user_query:
    # Tokenize the question.
    inputs = tokenizer(user_query, return_tensors="pt")
    input_ids = inputs['input_ids']

    # With the retriever attached, generate() retrieves passages and decodes
    # the answer in one call; no manual context_input_ids plumbing is needed.
    with torch.no_grad():
        generated_ids = model.generate(input_ids=input_ids)
    # batch_decode is the documented decode path for RAG generation output.
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Show the answer
    st.write(f"Answer: {answer}")

    # Re-run retrieval explicitly so the supporting passages can be shown.
    # retriever(...) expects the question token ids plus the question encoder's
    # hidden states as numpy arrays and returns (among others) "doc_ids".
    with torch.no_grad():
        question_hidden_states = model.question_encoder(input_ids)[0]
    docs_dict = retriever(
        input_ids.numpy(),
        question_hidden_states.detach().numpy(),
        return_tensors="pt",
    )
    # get_doc_dicts maps the retrieved doc ids back to their stored passages;
    # one dict per batch item, with parallel "title"/"text" lists.
    doc_dicts = retriever.index.get_doc_dicts(docs_dict["doc_ids"])

    # Display the most relevant documents
    st.subheader("Relevant Documents:")
    for doc_text in doc_dicts[0]["text"]:
        st.write(doc_text[:300] + '...')  # Display first 300 characters of each doc