File size: 435 Bytes
b8f273f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import functools

from transformers import RagRetriever, RagTokenizer
# Build the RAG retriever at import time from the "facebook/rag-token-base"
# checkpoint, with the index overridden to a "custom" one located via
# `index_path`.
# NOTE(review): the original comment said this "automatically loads the index
# from the Hub". As far as I can tell from the transformers RAG docs, a
# "custom" index expects `passages_path` (a dataset saved with
# `Dataset.save_to_disk`) plus a local FAISS index file for `index_path`, not
# a Hub repo id such as "GOGO198/GOGO_dataset" -- confirm this call actually
# succeeds (or download the dataset/index locally first) before relying on it.
retriever = RagRetriever.from_pretrained(
"facebook/rag-token-base",
index_name="custom",
index_path="GOGO198/GOGO_dataset"
)
# Tokenizer loaded from the same checkpoint as the retriever above.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
def answer_question(question):
inputs = tokenizer(question, return_tensors="pt")
outputs = retriever(**inputs)
return outputs['answer'] |