wiki_tools / tools.py
ankush13r's picture
Update tools.py
13b03f0 verified
raw
history blame
4.46 kB
from abc import ABC, abstractmethod
from typing import Dict, Union, get_origin, get_args
from pydantic import BaseModel, Field
from types import UnionType
import os
import logging
from src.vectorstore import VectorStore
# from langchain.tools import tool
class ToolBase(BaseModel, ABC):
    """Base class for tools exposed to an LLM through OpenAI-style function calling.

    A subclass declares the tool's arguments as Pydantic fields, uses its class
    docstring as the tool description, and implements ``invoke``.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool with the argument dict ``input`` and return its result."""

    @classmethod
    def to_openai_tool(cls) -> Dict:
        """
        Extracts function metadata from a Pydantic class, including function name, parameters, and descriptions.
        Formats it into a structure similar to OpenAI's function metadata.

        The function name is the class name exactly as written. The description
        is the stripped class docstring, or "" if the class has none.

        Each model field becomes an entry of ``parameters["properties"]`` with a
        JSON-schema ``type``:
        ``int`` -> "integer", ``bool`` -> "boolean", ``float`` -> "number",
        anything else -> "string".

        A field is listed in ``parameters["required"]`` when Pydantic reports it
        as required, or when its annotation does not admit ``None``. A
        list-valued default is exposed as the property's ``enum``.

        Returns:
            ``{"type": "function", "function": {"name", "description", "parameters"}}``

        Raises:
            TypeError: if a field's annotation is a union of more than one
                non-None type, or a union containing only ``None``.
        """
        parameters: Dict = {"type": "object", "properties": {}, "required": []}
        function_metadata = {
            "type": "function",
            "function": {
                "name": cls.__name__,  # Function name is the class name, unchanged
                # A class without a docstring has __doc__ == None; don't crash on it.
                "description": (cls.__doc__ or "").strip(),
                "parameters": parameters,
            },
        }

        for field_name, field_info in cls.model_fields.items():
            annotation = field_info.annotation
            has_none = False

            # typing.Union / Optional[X] and PEP 604 `X | None` are treated the
            # same way, so both spellings yield the same schema and the same
            # "required" decision.
            if get_origin(annotation) in (Union, UnionType):
                members = get_args(annotation)
                has_none = type(None) in members
                members = [arg for arg in members if arg is not type(None)]
                if len(members) > 1:
                    raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
                if not members:
                    raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
                annotation = members[0]

            # Map the Python type to a JSON-schema type; default to string.
            if annotation is bool:
                field_type = "boolean"
            elif annotation is int:
                field_type = "integer"
            elif annotation is float:
                field_type = "number"
            else:
                field_type = "string"

            prop = {
                "type": field_type,
                "description": field_info.description,
            }
            # A list default is interpreted as the set of allowed values.
            default = getattr(field_info, "default", None)
            if isinstance(default, list):
                prop["enum"] = default
            parameters["properties"][field_name] = prop

            # Required if Pydantic says so, or if the type can never be None
            # (the model must then always supply a value).
            if field_info.is_required() or not has_none:
                parameters["required"].append(field_name)

        return function_metadata
# Registry of tool *classes* (not instances) keyed by their OpenAI function
# name; populated by @tool_register.
tools: Dict[str, type[ToolBase]] = {}
# OpenAI-style tool descriptions in registration order; populated by @tool_register.
oitools: list[dict] = []

# Shared vector store used by the retrieval tool(s), configured from environment
# variables. os.environ.get returns None for any variable that is unset.
# NOTE(review): confirm how VectorStore handles None for these settings.
# A model previously hard-coded here was "BAAI/bge-m3" (e.g. EMBEDDINGS_MODEL=BAAI/bge-m3).
vector_store = VectorStore(
    embeddings_model=os.environ.get("EMBEDDINGS_MODEL"),
    vs_local_path=os.environ.get("VS_LOCAL_PATH"),
    vs_hf_path=os.environ.get("VS_HF_PATH"),
)
def tool_register(cls: "type[ToolBase]") -> "type[ToolBase]":
    """Class decorator that registers a ToolBase subclass as an LLM tool.

    Appends the class's OpenAI tool description to ``oitools`` and stores the
    class in ``tools`` under the description's function name.

    Returns:
        ``cls`` itself, so the decorated name stays bound to the class. Without
        this return, ``@tool_register`` would rebind the class name to None.
    """
    oaitool = cls.to_openai_tool()
    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls
    return cls
@tool_register
class get_documents(ToolBase):
    """Retrieves general information from Wikipedia based on the user's query. """
    # The class docstring above is sent to the LLM as the tool description.
    # The class name is the tool's function name, so it is intentionally lowercase.
    query: str = Field(description="Search query to retrieve relevant documents.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for ``input["query"]`` in the shared vector store.

        Args:
            input: parsed tool-call arguments; expected to contain "query".

        Returns:
            The result of ``vector_store.get_context(query)``. If ``input`` is
            not a dict, or the query is missing or blank, a short error message
            for the LLM is returned instead.
        """
        logging.info("get_documents.invoke() input: %s", input)
        # Guard against non-dict payloads (e.g. an unparsed JSON string).
        if not isinstance(input, dict):
            return "Invalid input: expected a dictionary of arguments."
        query = input.get("query")
        if isinstance(query, str):
            # A whitespace-only query is treated as missing.
            query = query.strip()
        if not query:
            return "Missing required argument: query."
        return vector_store.get_context(query)