Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from abc import ABC, abstractmethod | |
from typing import Dict, Union, get_origin, get_args | |
from pydantic import BaseModel, Field | |
from types import UnionType | |
import os | |
import logging | |
from src.vectorstore import VectorStore | |
# from langchain.tools import tool | |
def _unwrap_nullable(annotation):
    """Split a field annotation into its concrete type and a nullability flag.

    ``typing.Union`` / ``typing.Optional`` and PEP 604 ``X | None`` unions are
    handled identically, so ``Optional[int]`` and ``int | None`` give the same
    result.

    Args:
        annotation: The annotation of a Pydantic field (``FieldInfo.annotation``).

    Returns:
        A ``(concrete_type, nullable)`` tuple. Non-union annotations are
        returned unchanged with ``nullable=False``.

    Raises:
        TypeError: If the union has more than one non-``None`` member, or has
            no non-``None`` member at all.
    """
    if get_origin(annotation) not in (Union, UnionType):
        return annotation, False
    members = get_args(annotation)
    nullable = type(None) in members
    concrete = [member for member in members if member is not type(None)]
    if len(concrete) > 1:
        raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
    if not concrete:
        raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
    return concrete[0], nullable


def _json_schema_type(annotation) -> str:
    """Return the JSON-Schema ``type`` name for a Python type; unknown types map to ``"string"``."""
    if annotation is bool:
        return "boolean"
    if annotation is int:
        return "integer"
    if annotation is float:
        return "number"
    return "string"


class ToolBase(BaseModel, ABC):
    """Base class for tools offered to an OpenAI-style function-calling model.

    Subclasses declare their call arguments as Pydantic fields (with a
    ``description``) and override :meth:`invoke`. :meth:`to_openai_tool`
    derives the tool schema from the class name, docstring and fields.
    """

    @classmethod
    def invoke(cls, input: Dict):
        """Run the tool with the parsed call arguments.

        The base implementation does nothing and returns ``None``; subclasses
        are expected to override it.
        """
        return None

    @classmethod
    def to_openai_tool(cls) -> dict:
        """Build an OpenAI function-tool description from this Pydantic class.

        The function name is the class name (used verbatim), the description is
        the stripped class docstring (empty if there is none), and each model
        field becomes a parameter property with its JSON type and description.

        A field is listed in ``required`` when Pydantic considers it required
        (no default) or when its annotation does not admit ``None``.

        Raises:
            TypeError: If a field is annotated with a union of several concrete
                types, or with a union containing only ``None``.
        """
        parameters = {"type": "object", "properties": {}, "required": []}
        function_metadata = {
            "type": "function",
            "function": {
                "name": cls.__name__,
                # A subclass without a docstring has __doc__ == None.
                "description": (cls.__doc__ or "").strip(),
                "parameters": parameters,
            },
        }

        for field_name, field_info in cls.model_fields.items():
            annotation, nullable = _unwrap_nullable(field_info.annotation)
            prop = {
                "type": _json_schema_type(annotation),
                "description": field_info.description,
            }
            # A list default is reused as the set of allowed values (enum).
            default = getattr(field_info, "default", None)
            if isinstance(default, list):
                prop["enum"] = list(default)
            parameters["properties"][field_name] = prop

            # Only fields that have a default AND accept None are optional for the model.
            if field_info.is_required() or not nullable:
                parameters["required"].append(field_name)

        return function_metadata
# Tool classes keyed by their OpenAI function name; populated by tool_register().
tools: Dict[str, type[ToolBase]] = {}
# OpenAI-style tool schemas in registration order; populated by tool_register().
oitools: list[dict] = []

# Shared retriever used by the tools in this module, configured from environment variables.
# NOTE(review): EMBEDDINGS_MODEL, VS_LOCAL_PATH and VS_HF_PATH are passed as None when unset —
# confirm VectorStore tolerates that (an embeddings model used previously was "BAAI/bge-m3").
vector_store = VectorStore(
    embeddings_model=os.environ.get("EMBEDDINGS_MODEL"),
    vs_local_path=os.environ.get("VS_LOCAL_PATH"),
    vs_hf_path=os.environ.get("VS_HF_PATH"),
    number_of_contexts=int(os.environ.get("RETRIEVE_NUM_CONTEXTS", 3)),
)
def tool_register(cls: "type[ToolBase]") -> "type[ToolBase]":
    """Register a tool class so it can be offered to, and dispatched from, the model.

    Appends the class's OpenAI tool schema to ``oitools`` and maps the schema's
    function name to the class in ``tools``.

    Returns:
        The class itself, unchanged. This lets the function be used as a class
        decorator (``@tool_register``) without rebinding the decorated name to
        ``None``.
    """
    schema = cls.to_openai_tool()
    oitools.append(schema)
    tools[schema["function"]["name"]] = cls
    return cls
class retrieve_wiki_data(ToolBase):
    """Retrieves relevant information from wikipedia, based on the user's query."""
    # The docstring above is sent to the model as the tool description.

    # NOTE(review): this runs once, at class-definition time. Nothing visible in this
    # file passes this class to tool_register — confirm it is actually registered.
    logging.info("@tool_register: retrieve_wiki_data()")

    query: str = Field(description="The user's input or question, used to search Wikipedia.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up Wikipedia passages relevant to ``input["query"]``.

        Args:
            input: Parsed tool-call arguments; expected to be a dict with a
                ``"query"`` key.

        Returns:
            The result of ``vector_store.get_context(query)`` (annotated here as
            ``str``), or an error message for the model when the query is
            missing, empty, or the arguments are not a mapping.
        """
        logging.info(f"retrieve_wiki_data.invoke() input: {input}")
        # Arguments come from the model and may be malformed; anything without .get() has no query.
        try:
            query = input.get("query", None)
        except AttributeError:
            query = None
        if not query:
            return "Missing required argument: query."
        return vector_store.get_context(query)