import logging
import os
from abc import ABC, abstractmethod
from types import UnionType
from typing import Dict, Union, get_args, get_origin

from pydantic import BaseModel, Field

from src.vectorstore import VectorStore
# from langchain.tools import tool


class ToolBase(BaseModel, ABC):
    """Base class for tools exposed to an OpenAI-style function-calling API.

    Subclasses declare their arguments as pydantic fields, describe the tool
    in the class docstring, and implement :meth:`invoke`.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool with the arguments supplied in ``input``."""

    @classmethod
    def to_openai_tool(cls):
        """Build the OpenAI-style function metadata for this tool class.

        The function name is the class name (unchanged), the description is
        the stripped class docstring (empty when there is none), and every
        pydantic field becomes one entry of ``parameters.properties``.

        Field types map to JSON types as follows: ``int`` -> ``"integer"``,
        ``bool`` -> ``"boolean"``, anything else -> ``"string"``.  A field
        may be a union of exactly one type and ``None`` (spelled either
        ``X | None`` or ``Optional[X]``).

        A field is listed under ``required`` when pydantic reports it as
        required, or when its annotation does not allow ``None``.  A field
        whose default is a list gets that list as its ``enum``.

        Returns:
            dict: ``{"type": "function", "function": {...}}``.

        Raises:
            TypeError: If a field is a union of more than one non-``None``
                type, or of no non-``None`` type at all.
        """
        properties = {}
        required = []

        for field_name, field_info in cls.model_fields.items():
            annotation = field_info.annotation
            nullable = False

            # Treat `Optional[X]` / `Union[X, None]` and `X | None` alike:
            # remember that None is allowed and unwrap the real type.
            origin = get_origin(annotation)
            if origin is Union or origin is UnionType:
                members = get_args(annotation)
                concrete = [m for m in members if m is not type(None)]
                nullable = len(concrete) != len(members)
                if len(concrete) > 1:
                    raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
                if not concrete:
                    raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
                annotation = concrete[0]

            # Default to string; only int and bool get a dedicated JSON type.
            if annotation is int:
                field_type = "integer"
            elif annotation is bool:
                field_type = "boolean"
            else:
                field_type = "string"

            schema = {
                "type": field_type,
                "description": field_info.description,
            }

            # A list default is interpreted as the set of allowed values
            # (an enum), not as an actual default.
            if isinstance(field_info.default, list):
                schema["enum"] = field_info.default

            properties[field_name] = schema

            # Required when pydantic says so, or when None is not accepted.
            if field_info.is_required() or not nullable:
                required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,  # Same as the class name, case preserved
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }


# Registry of tool classes keyed by their OpenAI function name.
tools: Dict[str, type[ToolBase]] = {}
# OpenAI-style metadata dicts, in registration order.
oitools: list = []

# Shared vector store, configured entirely from environment variables.
vector_store = VectorStore(
    # embeddings_model="BAAI/bge-m3",
    embeddings_model=os.environ.get("EMBEDDINGS_MODEL"),
    vs_local_path=os.environ.get("VS_LOCAL_PATH"),
    vs_hf_path=os.environ.get("VS_HF_PATH"),
    number_of_contexts=int(os.environ.get("RETRIEVE_NUM_CONTEXTS", 3)),
)


def tool_register(cls: type[ToolBase]) -> type[ToolBase]:
    """Class decorator that registers a tool in ``tools`` and ``oitools``.

    Returns the class unchanged so the decorated name stays bound to the
    class instead of being replaced by ``None``.
    """
    oaitool = cls.to_openai_tool()
    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls
    return cls


@tool_register
class retrieve_wiki_data(ToolBase):
    """Retrieves relevant information from wikipedia, based on the user's query."""

    # These log calls run once, while the class body is executed at import.
    logging.info("@tool_register: retrieve_wiki_data()")
    query: str = Field(description="The user's input or question, used to search Wikipedia.")
    # NOTE: `query` here is the object returned by Field(...), not a user query.
    logging.info(f"query: {query}")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for ``input["query"]`` in the shared vector store.

        Args:
            input: Tool arguments; must contain a non-empty ``"query"``.

        Returns:
            The result of ``vector_store.get_context(query)``, or an error
            message string when ``"query"`` is missing or empty.
        """
        logging.info(f"retrieve_wiki_data.invoke() input: {input}")

        query = input.get("query", None)
        if not query:
            return "Missing required argument: query."

        # return "We are currently working on it. You can't use this tool right now—please try again later. Thank you for your patience!"
        return vector_store.get_context(query)