File size: 435 Bytes
b8f273f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from transformers import RagRetriever, RagTokenizer

# Build the retriever for the "facebook/rag-token-base" checkpoint, backed by a
# custom document index. (Intent of the original comment: load the index
# automatically from the Hugging Face Hub.)
# NOTE(review): per the transformers RagRetriever/RagConfig docs,
# index_name="custom" loads the passages from `passages_path` (a `datasets`
# dataset saved with `save_to_disk`) and the FAISS index from `index_path`,
# and both are expected to be local filesystem paths. This call passes no
# `passages_path`, and `index_path` looks like a Hub dataset repo id
# ("GOGO198/GOGO_dataset") rather than a local index file. It therefore
# probably fails at import time. TODO: confirm, then point both arguments at
# local files (e.g. after downloading the dataset repo).
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name="custom",
    index_path="GOGO198/GOGO_dataset"
)

# Tokenizer for the same checkpoint; answer_question() uses it to encode the
# question text.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")

# Generation model shared across calls; built lazily by _get_generator() so
# the (large) checkpoint is only loaded when a question is actually asked.
_generator = None


def _get_generator():
    """Return the shared RAG-Token generation model, loading it on first use.

    The model is constructed with the module-level ``retriever`` so that
    ``generate`` retrieves supporting passages from the configured index
    before producing an answer.
    """
    global _generator
    if _generator is None:
        # Function-scope import keeps this block self-contained; the module
        # already depends on transformers.
        from transformers import RagTokenForGeneration

        _generator = RagTokenForGeneration.from_pretrained(
            "facebook/rag-token-base", retriever=retriever
        )
    return _generator


def answer_question(question):
    """Answer *question* using retrieval-augmented generation.

    The question is encoded with the module-level ``tokenizer``. The RAG
    generator then retrieves passages via ``retriever`` and generates an
    answer, which is decoded back to text.

    A retriever on its own cannot answer questions. ``RagRetriever.__call__``
    needs question-encoder hidden states and returns only context tensors,
    with no ``'answer'`` key. Generation therefore has to go through the RAG
    model.

    Args:
        question: The question text.

    Returns:
        The generated answer as a string.

    Raises:
        ValueError: If *question* is an empty or whitespace-only string.
    """
    if isinstance(question, str) and not question.strip():
        raise ValueError("question must be a non-empty string")
    inputs = tokenizer(question, return_tensors="pt")
    generated = _get_generator().generate(input_ids=inputs["input_ids"])
    # RagTokenizer.batch_decode delegates to the generator's tokenizer, which
    # is the correct vocabulary for generated ids.
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]