"""Guest Information Retrieval Toll""" from smolagents import Tool from langchain_community.retrievers import BM25Retriever from langchain.docstore.document import Document import datasets class GuestInfoRetrieverTool(Tool): """Derived Class for Guest Information Retrieval Tool""" name = "guest_info_retriever" description = "Retrieves detailed information about gala guests based on their name or relation." # pylint: disable=line-too-long inputs = { "query": { "type": "string", "description": "The name or relation of the guest you want information about.", } } output_type = "string" def __init__(self, docs): # pylint: disable=super-init-not-called self.is_initialized = False self.retriever = BM25Retriever.from_documents(docs) def forward(self, query: str): # pylint: disable=arguments-differ results = self.retriever.get_relevant_documents(query) if results: return "\n\n".join([doc.page_content for doc in results[:3]]) return "No matching guest information found." def guest_info_retriever_factory(): """Get Guest Information Retrieval Tool""" hf_datasets = [ "agents-course/unit3-invitees", "Data-Gem/agents-course-unit3-invitees-expanded", ] guest_dataset = [] for dataset_name in hf_datasets: new_dataset = datasets.load_dataset(dataset_name, split="train") guest_dataset.append(new_dataset) # Convert dataset entries into Document objects docs = [ Document( page_content="\n".join( [ f"Name: {guest['name']}", f"Relation: {guest['relation']}", f"Description: {guest['description']}", f"Email: {guest['email']}", ] ), metadata={"name": guest["name"]}, ) for guest in guest_dataset ] return GuestInfoRetrieverTool(docs)