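# NOTE(review): "Spaces: Sleeping / Sleeping" is page-status text that was
# captured together with this file; it is not part of the module.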
from transformers.agents.tools import Tool | |
from data import get_data | |
class SquadRetrieverTool(Tool):
    """Agent tool that returns SQuAD documents matching a free-text query.

    The retriever is created once, in ``__init__``, from the index exposed by
    ``get_data(download=True)``. Each ``forward`` call is independent of the
    previous ones, which is why the tool description asks for detailed queries.
    """

    name = "squad_retriever"
    # Runtime text shown to the agent; keep the wording and line breaks stable.
    description = (
        "Retrieves documents from the Stanford Question Answering Dataset (SQuAD).\n"
        "Because this tool does not remember context from previous queries, be sure to include\n"
        "as many details as possible in your query.\n"
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The query. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Initialise the base ``Tool`` with ``kwargs`` and build the retriever."""
        super().__init__(**kwargs)
        # NOTE(review): presumably download=True fetches the dataset when it is
        # not available locally -- confirm against data.get_data.
        self.data = get_data(download=True)
        self.retriever = self.data.index.as_retriever()

    def forward(self, query: str) -> str:
        """Retrieve documents for ``query`` and format them as one string.

        Args:
            query: The search text handed to the retriever.

        Returns:
            Each hit rendered as its text followed by ``Score: <score>``, with
            every hit preceded by a ``===Document===`` line, or a fixed
            "no documents" message when the retriever returns nothing.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the exception type and message are unchanged.
        if not isinstance(query, str):
            raise AssertionError("Your search query must be a string")

        responses = self.retriever.retrieve(query)
        if not responses:
            return "No documents found for this query."

        return "===Document===\n" + "\n===Document===\n".join(
            f"{response.text}\nScore: {response.score}" for response in responses
        )
class SquadQueryTool(Tool):
    """Agent tool that answers a question using a query engine over SQuAD.

    The query engine is created once, in ``__init__``, from the index exposed
    by ``get_data(download=True)``. Each ``forward`` call is independent of the
    previous ones, which is why the tool description asks for detailed queries.
    """

    name = "squad_query"
    # Runtime text shown to the agent; keep the wording and line breaks stable.
    description = (
        "Attempts to answer a question using the Stanford Question Answering Dataset (SQuAD).\n"
        "Because this tool does not remember context from previous queries, be sure to include\n"
        "as many details as possible in your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The question. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Initialise the base ``Tool`` with ``kwargs`` and build the query engine."""
        super().__init__(**kwargs)
        # NOTE(review): presumably download=True fetches the dataset when it is
        # not available locally -- confirm against data.get_data.
        self.data = get_data(download=True)
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Run ``query`` through the query engine and return the answer text.

        Args:
            query: The question handed to the query engine.

        Returns:
            The answer text prefixed with ``Query Response:`` and a blank
            line, or a fixed "no answer" message when the engine produced no
            answer text.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the exception type and message are unchanged.
        if not isinstance(query, str):
            raise AssertionError("Your search query must be a string")

        answer = self.query_engine.query(query).response
        # Truthiness covers both an empty string and a missing (None) answer;
        # the previous ``len(...)`` check raised TypeError on None.
        if not answer:
            return "No answer found for this query."

        # The old single-element ``join`` was a no-op, so the answer is
        # appended directly.
        return "Query Response:\n\n" + answer