import functools

from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer
# Build the retriever for the "facebook/rag-token-base" RAG model, backed by a
# custom passage index.
#
# NOTE(review): this index is NOT fetched automatically from the Hub. With
# index_name="custom", transformers builds the index from *local* files:
#   - passages_path: a directory written by `datasets.Dataset.save_to_disk`
#     containing "title", "text" and "embeddings" columns, and
#   - index_path: a FAISS index file written by
#     `dataset.get_index("embeddings").save(...)`.
# "GOGO198/GOGO_dataset" looks like a Hub repo id rather than a local FAISS
# file, and passages_path is not passed at all. transformers rejects a missing
# path with ValueError when building a custom index.
# TODO: prepare the dataset locally and pass both passages_path= and
# index_path=, or pass a FAISS-indexed Dataset via indexed_dataset=.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name="custom",
    index_path="GOGO198/GOGO_dataset",
)
# RagTokenizer bundles two tokenizers. The question-encoder tokenizer is used
# when the object is called on text; the generator tokenizer is used for
# decoding generated ids.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
@functools.lru_cache(maxsize=1)
def _get_generator():
    """Load the RAG-Token generation model once, wired to ``retriever``.

    Cached so repeated ``answer_question`` calls reuse the same weights
    instead of reloading the model on every question.
    """
    # NOTE(review): "facebook/rag-token-base" is published as a non-fine-tuned
    # checkpoint; a QA-fine-tuned one (e.g. "facebook/rag-token-nq") may give
    # better answers — confirm which model is intended.
    return RagTokenForGeneration.from_pretrained(
        "facebook/rag-token-base",
        retriever=retriever,
    )


def answer_question(question):
    """Answer *question* with retrieval-augmented generation.

    The question is tokenized and passed to a ``RagTokenForGeneration``
    model. That model encodes the question, queries ``retriever`` for
    supporting passages and generates conditioned on them. The first
    generated sequence is then decoded to text.

    Args:
        question: Natural-language question text.

    Returns:
        The generated answer string.

    Raises:
        ValueError: If *question* is empty or only whitespace.
    """
    if not question or not question.strip():
        raise ValueError("question must be a non-empty string")

    inputs = tokenizer(question, return_tensors="pt")
    # The retriever cannot be called with tokenizer output directly. Its
    # __call__ needs question-encoder hidden states and returns retrieved
    # contexts, not an answer. Retrieval is therefore driven by the RAG
    # model's generate().
    generated_ids = _get_generator().generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]