Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,561 Bytes
d519be4 cb3dcae d519be4 cb3dcae d519be4 cb3dcae d519be4 cb3dcae d519be4 3f5217a d519be4 cb3dcae aa6ef3d 4701923 d519be4 aa6ef3d d519be4 cb34a9e d519be4 cb3dcae d519be4 aa6ef3d d519be4 cb3dcae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from abc import ABC, abstractmethod
from typing import Dict, Union, get_origin, get_args
from pydantic import BaseModel, Field
from types import UnionType
import os
import logging
from src.vectorstore import VectorStore
# from langchain.tools import tool
class ToolBase(BaseModel, ABC):
    """Base class for tools that can be described as OpenAI function tools.

    Subclasses declare their arguments as pydantic fields, describe the tool
    in the class docstring and implement ``invoke`` as a classmethod.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool with the argument dictionary ``input``."""

    @classmethod
    def to_openai_tool(cls):
        """
        Build OpenAI-style function metadata from this pydantic class.

        The function name is the class name (unchanged), the description is
        the stripped class docstring (empty string when there is none), and
        every model field becomes one entry in ``parameters.properties``.

        Type mapping: ``int`` -> "integer", ``bool`` -> "boolean",
        ``float`` -> "number", anything else -> "string". A field may be a
        union of exactly one type and ``None`` (``Optional[X]``,
        ``Union[X, None]`` or ``X | None``); such a field is treated as
        nullable. A field is listed in ``required`` when pydantic reports it
        as required or when it is not nullable. A field whose default is a
        list gets that list as its ``enum``.

        Returns:
            dict: ``{"type": "function", "function": {...}}``.

        Raises:
            TypeError: if a field is a union of more than one non-None type.
        """
        properties = {}
        required = []

        for field_name, field_info in cls.model_fields.items():
            annotation = field_info.annotation
            nullable = False

            # Treat typing.Union/Optional and the ``X | None`` syntax alike,
            # regardless of where ``None`` sits in the union.
            origin = get_origin(annotation)
            if origin is Union or origin is UnionType:
                members = get_args(annotation)
                concrete = [m for m in members if m is not type(None)]
                nullable = len(concrete) != len(members)
                if len(concrete) > 1:
                    raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
                if not concrete:
                    raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
                annotation = concrete[0]

            if annotation is bool:
                field_type = "boolean"
            elif annotation is int:
                field_type = "integer"
            elif annotation is float:
                field_type = "number"
            else:
                # Default for str and for any type without a dedicated mapping.
                field_type = "string"

            prop = {
                "type": field_type,
                "description": field_info.description,
            }
            # A list default is interpreted as the set of allowed values.
            default = getattr(field_info, "default", None)
            if isinstance(default, list):
                prop["enum"] = default
            properties[field_name] = prop

            # Non-nullable fields are always required, even with a default.
            if field_info.is_required() or not nullable:
                required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,
                # ``__doc__`` is None for a class without a docstring.
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }
# Registry of tool name -> registered tool class; filled in by tool_register.
tools: Dict[str, ToolBase] = {}

# OpenAI-format metadata of the registered tools, in registration order.
oitools = []

# Shared vector store used by the tools; every setting is read from the
# environment (RETRIEVE_NUM_CONTEXTS falls back to 3 when unset).
vector_store = VectorStore(
    # embeddings_model="BAAI/bge-m3",
    embeddings_model=os.getenv("EMBEDDINGS_MODEL"),
    vs_local_path=os.getenv("VS_LOCAL_PATH"),
    vs_hf_path=os.getenv("VS_HF_PATH"),
    number_of_contexts=int(os.getenv("RETRIEVE_NUM_CONTEXTS", "3")),
)
def tool_register(cls: "type[ToolBase]") -> "type[ToolBase]":
    """Class decorator that registers a tool class.

    Appends the class's OpenAI tool metadata to ``oitools`` and stores the
    class in ``tools`` under the metadata's function name.

    Returns:
        The class itself, so that the decorated name stays bound to the class
        instead of being rebound to ``None``.
    """
    metadata = cls.to_openai_tool()
    oitools.append(metadata)
    tools[metadata["function"]["name"]] = cls
    return cls
@tool_register
class retrieve_wiki_data(ToolBase):
    """Retrieves relevant information from wikipedia, based on the user's query."""
    # NOTE: the docstring above is used as the tool description, and the class
    # name as the tool name, so both are part of the runtime contract.

    # Runs once, when the class body is executed (i.e. at import time).
    logging.info("@tool_register: retrieve_wiki_data()")

    query: str = Field(description="The user's input or question, used to search Wikipedia.")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for ``input["query"]`` in the shared vector store.

        Args:
            input: Tool arguments; expected to be a dict with a "query" key.

        Returns:
            The result of ``vector_store.get_context(query)``, or the string
            "Missing required argument: query." when ``input`` is not a dict
            or has no non-empty "query" value.
        """
        logging.info(f"retrieve_wiki_data.invoke() input: {input}")

        # Guard against non-dict input so a malformed call yields the same
        # error string instead of raising AttributeError on ``.get``.
        if not isinstance(input, dict):
            return "Missing required argument: query."

        query = input.get("query")
        if not query:
            return "Missing required argument: query."

        return vector_store.get_context(query)
|