wiki_tools / tools.py
nurasaki's picture
Changed num context retrieved
3f5217a
from abc import ABC, abstractmethod
from typing import Dict, Union, get_origin, get_args
from pydantic import BaseModel, Field
from types import UnionType
import os
import logging
from src.vectorstore import VectorStore
# from langchain.tools import tool
class ToolBase(BaseModel, ABC):
    """Base class for tools described in OpenAI's function-calling format.

    A concrete tool subclasses ``ToolBase``, declares its arguments as pydantic
    fields, documents itself with a class docstring (emitted as the tool
    description) and implements ``invoke``.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool.

        Args:
            input: Mapping of argument names to the values for this call.
        """

    @classmethod
    def to_openai_tool(cls):
        """Build the OpenAI tool ("function") description of this class.

        The function name is the class name, the description is the class
        docstring, and every pydantic field becomes one parameter.

        Type mapping: ``int`` -> "integer", ``bool`` -> "boolean",
        ``float`` -> "number", anything else -> "string". ``X | None``,
        ``Optional[X]`` and ``Union[X, None]`` are all described as ``X``.

        A parameter is listed under "required" when pydantic reports the field
        as required, or when its annotation does not admit ``None``. A field
        whose default is a list is treated as an enumeration of allowed values.

        Returns:
            dict: ``{"type": "function", "function": {...}}`` metadata.

        Raises:
            TypeError: If a field is a union of more than one non-None type,
                or a union containing only None.
        """
        properties = {}
        required = []

        for field_name, field_info in cls.model_fields.items():
            annotation = field_info.annotation
            nullable = False

            # Handle typing.Union/Optional and PEP 604 unions the same way, so
            # both spellings yield the same type and the same "required" flag.
            if get_origin(annotation) in (Union, UnionType):
                args = get_args(annotation)
                members = [arg for arg in args if arg is not type(None)]
                nullable = len(members) < len(args)
                if len(members) > 1:
                    raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
                if not members:
                    raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
                annotation = members[0]

            if annotation is bool:
                field_type = "boolean"
            elif annotation is int:
                field_type = "integer"
            elif annotation is float:
                field_type = "number"
            else:
                field_type = "string"  # Fallback for str and any unmapped type

            prop = {
                "type": field_type,
                "description": field_info.description,
            }
            # A list default is this module's convention for declaring an enum
            # (e.g. the allowed values of a `unit` argument).
            if isinstance(field_info.default, list):
                prop["enum"] = field_info.default
            properties[field_name] = prop

            # Non-nullable fields count as required even when they carry a
            # default, which keeps list-default (enum) fields mandatory.
            if field_info.is_required() or not nullable:
                required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,  # Class name, used verbatim
                # A class without a docstring has __doc__ == None.
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
# Registry of tool classes keyed by OpenAI function name; filled by `tool_register`.
tools: Dict[str, type[ToolBase]] = {}
# OpenAI-format tool descriptions, in registration order; filled by `tool_register`.
oitools: list = []

# Vector store instance configured from environment variables.
vector_store = VectorStore(
    embeddings_model=os.environ.get("EMBEDDINGS_MODEL"),  # e.g. "BAAI/bge-m3"
    vs_local_path=os.environ.get("VS_LOCAL_PATH"),
    vs_hf_path=os.environ.get("VS_HF_PATH"),
    # `or 3` also covers a variable that is set but empty, which int() rejects.
    number_of_contexts=int(os.environ.get("RETRIEVE_NUM_CONTEXTS") or 3),
)
def tool_register(cls: "type[ToolBase]") -> "type[ToolBase]":
    """Class decorator that records a tool in the module-level registries.

    Appends the tool's OpenAI function description to ``oitools`` and maps the
    function name to the class in ``tools``.

    Args:
        cls: Tool class exposing a ``to_openai_tool()`` classmethod.

    Returns:
        ``cls`` itself, so the decorated name stays bound to the class
        instead of being rebound to ``None``.
    """
    oaitool = cls.to_openai_tool()
    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls
    return cls
# Retrieval tool: answers a query with context fetched from `vector_store`.
# NOTE: the class docstring and the Field description are emitted verbatim in
# the OpenAI tool schema (ToolBase.to_openai_tool reads them), so they are
# runtime text, not just documentation.
@tool_register
class retrieve_wiki_data(ToolBase):
    """Retrieves relevant information from wikipedia, based on the user's query."""

    # Runs once, while the class body is executed at import time.
    logging.info("@tool_register: retrieve_wiki_data()")

    query: str = Field(description="The user's input or question, used to search Wikipedia.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for the query contained in ``input``.

        Args:
            input: Tool-call arguments; expected to be a dict with a
                non-empty ``"query"`` entry.

        Returns:
            The result of ``vector_store.get_context(query)``, or the message
            "Missing required argument: query." when ``input`` is not a dict
            or carries no usable query.
        """
        logging.info(f"retrieve_wiki_data.invoke() input: {input}")

        # Anything but a dict cannot carry a "query" key; report it the same
        # way as a missing argument instead of failing on `.get`.
        query = input.get("query") if isinstance(input, dict) else None
        if not query:
            return "Missing required argument: query."

        return vector_store.get_context(query)