Unit_3_Agentic_RAG / guest_info_retriever.py
mikejay14's picture
change inner quotes to single quotes
62d1a65
"""Guest Information Retrieval Tool"""
from smolagents import Tool
from langchain_community.retrievers import BM25Retriever
from langchain.docstore.document import Document
import datasets
class GuestInfoRetrieverTool(Tool):
    """Tool that looks up gala guest information with a BM25 retriever.

    The retriever is built once, from the documents passed to the
    constructor; every query is answered from that fixed index.
    """

    name = "guest_info_retriever"
    description = "Retrieves detailed information about gala guests based on their name or relation."  # pylint: disable=line-too-long
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about.",
        }
    }
    output_type = "string"

    def __init__(self, docs):
        """Index ``docs`` so they can be searched by ``forward``.

        Args:
            docs: Documents handed to ``BM25Retriever.from_documents``.
        """
        # Run the base initializer instead of bypassing it, so any state the
        # Tool base class sets up is present on this instance.
        super().__init__()
        # Kept explicit, as in the original implementation.
        self.is_initialized = False
        self.retriever = BM25Retriever.from_documents(docs)

    def forward(self, query: str) -> str:  # pylint: disable=arguments-differ
        """Return the text of the best matches for ``query``.

        Args:
            query: Name or relation of the guest to look up.

        Returns:
            The ``page_content`` of up to three matching documents, separated
            by blank lines, or a fixed "not found" message when the retriever
            returns nothing.
        """
        # `invoke` is the supported retriever entry point; the older
        # `get_relevant_documents` method is deprecated.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest information found."
        # Only the first three hits are reported.
        return "\n\n".join(doc.page_content for doc in results[:3])
def guest_info_retriever_factory():
    """Build a GuestInfoRetrieverTool from the Hugging Face guest datasets.

    Loads the "train" split of every dataset named in ``hf_datasets``, turns
    each guest row into a Document and returns a tool indexing all of them.

    Returns:
        GuestInfoRetrieverTool: tool built from one Document per guest row.
    """
    hf_datasets = [
        "agents-course/unit3-invitees",
        "Data-Gem/agents-course-unit3-invitees-expanded",
    ]

    guest_datasets = [
        datasets.load_dataset(dataset_name, split="train")
        for dataset_name in hf_datasets
    ]

    # Convert dataset entries into Document objects.
    # Every loaded dataset is walked row by row. Iterating the list of
    # datasets directly would index a whole dataset by column name and put
    # entire columns, not one guest's values, into a single Document.
    # Assumes each row provides "name", "relation", "description" and
    # "email" -- TODO confirm for every dataset listed above.
    docs = [
        Document(
            page_content="\n".join(
                [
                    f"Name: {guest['name']}",
                    f"Relation: {guest['relation']}",
                    f"Description: {guest['description']}",
                    f"Email: {guest['email']}",
                ]
            ),
            metadata={"name": guest["name"]},
        )
        for guest_dataset in guest_datasets
        for guest in guest_dataset
    ]

    return GuestInfoRetrieverTool(docs)