import json
import os
from typing import Any, Dict, List, Type, Union

import openai
import weave
from openai import AsyncOpenAI
from pydantic import BaseModel

from app.utils.converter import product_data_to_str
from app.utils.image_processing import (
    get_data_format,
    get_image_base64_and_type,
    get_image_data,
)
from app.utils.logger import exception_to_str, setup_logger

from ..config import get_settings
from ..core import errors
from ..core.errors import BadRequestError, VendorError
from ..core.prompts import get_prompts
from .base import BaseAttributionService
# Deployment environment selects the Weave project that would receive traces.
ENV = os.getenv("ENV", "LOCAL")

# "LOCAL" also covers demo deployments.
_WEAVE_PROJECT_BY_ENV = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}
# Fix: previously left undefined for PROD/unknown envs; now None there.
weave_project_name = _WEAVE_PROJECT_BY_ENV.get(ENV)

if ENV != "PROD":
    # Weave tracing is temporarily disabled; re-enable when ready.
    # weave.init(project_name=weave_project_name)
    pass  # Fix: removed leftover debug print that ran at import time.

settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
def get_response_format(json_schema: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a JSON schema into OpenAI's strict ``response_format`` payload.

    OpenAI strict structured-output mode requires ``additionalProperties``
    to be False on the top-level schema and on every entry under ``$defs``.

    Args:
        json_schema: JSON schema dict. NOTE: mutated in place (matches the
            original behavior, which ``follow_schema`` relies on).

    Returns:
        A ``response_format`` dict suitable for ``chat.completions``.
    """
    # Fix: annotations previously used the builtin ``any`` instead of typing.Any.
    json_schema["additionalProperties"] = False
    for sub_schema in json_schema.get("$defs", {}).values():
        sub_schema["additionalProperties"] = False
    return {
        "type": "json_schema",
        "json_schema": {
            "strict": True,
            "name": "GarmentSchema",
            "schema": json_schema,
        },
    }
class OpenAIService(BaseAttributionService):
    """Attribution service that extracts garment attributes via the OpenAI
    chat-completions API (structured outputs)."""

    # Fix: follow_schema previously referenced an undefined ``ai_model`` name
    # and raised NameError on every call; it now defaults to this model.
    # TODO(review): confirm the intended default model.
    DEFAULT_MODEL = "gpt-4o"

    def __init__(self):
        # Async client; API key comes from application settings.
        self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)

    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: List[Any] = None,  # do not remove, this is for weave
        img_paths: List[str] = None,
        data: Dict[str, Any] = None,
    ) -> Dict[str, Any]:
        """Extract structured attributes from product data and images.

        Args:
            attributes_model: Pydantic model describing the expected output.
            ai_model: OpenAI model name to call.
            img_urls: Image URLs to attach; takes precedence over img_paths.
            product_taxonomy: Product taxonomy (not used in the prompt here).
            product_data: Raw product fields injected into the human prompt.
            pil_images: Unused here; kept for weave instrumentation.
            img_paths: Local image paths, used only when img_urls is None.
            data: Extra payload (unused in this method).

        Returns:
            The parsed JSON attribute dict returned by the model.

        Raises:
            BadRequestError: When OpenAI rejects the request (4xx).
            VendorError: On any other vendor failure or an invalid JSON reply.
        """
        human_prompt = prompts.EXTRACT_INFO_HUMAN_MESSAGE.replace(
            "{product_data}", product_data_to_str(product_data)
        )
        # Fix: replaced bare print() debugging with lazy logger output.
        logger.debug("Prompt: %s", human_prompt)

        text_content = [
            {
                "type": "text",
                "text": human_prompt,
            },
        ]

        # Fix: image_content was undefined (NameError) when neither img_urls
        # nor img_paths was provided; default to no images.
        image_content: List[Dict[str, Any]] = []
        if img_urls is not None:
            for img_url in img_urls:
                base64_data, data_format = get_image_base64_and_type(img_url)
                image_content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/{data_format};base64,{base64_data}",
                        },
                    }
                )
        elif img_paths is not None:
            image_content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{get_data_format(img_path)};base64,{get_image_data(img_path)}",
                    },
                }
                for img_path in img_paths
            ]

        try:
            logger.info("Extracting info via OpenAI...")
            response = await self.client.beta.chat.completions.parse(
                model=ai_model,
                messages=[
                    {
                        "role": "system",
                        "content": prompts.EXTRACT_INFO_SYSTEM_MESSAGE,
                    },
                    {
                        "role": "user",
                        "content": text_content + image_content,
                    },
                ],
                max_tokens=1000,
                response_format=attributes_model,
                logprobs=False,
                # top_logprobs=2,
                # temperature=0.0,
                # Effectively greedy decoding: keep only the most likely token.
                top_p=1e-45,
            )
        except openai.BadRequestError as e:
            # Surface OpenAI 4xx errors as our own BadRequestError.
            raise BadRequestError(exception_to_str(e))
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            )

        try:
            content = response.choices[0].message.content
            parsed_data = json.loads(content)
        # Fix: narrowed bare except; TypeError covers content being None.
        except (json.JSONDecodeError, TypeError):
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON)
        return parsed_data

    async def follow_schema(
        self,
        schema: Dict[str, Any],
        data: Dict[str, Any],
        ai_model: str = DEFAULT_MODEL,
    ) -> Dict[str, Any]:
        """Re-shape already-extracted data so it conforms to ``schema``.

        Args:
            schema: Target JSON schema (mutated by get_response_format).
            data: Extracted data to be restructured.
            ai_model: OpenAI model name. Fix: was an undefined free name
                (NameError at runtime); now a backward-compatible keyword
                parameter with a class-level default.

        Returns:
            The parsed JSON dict, or ``{"status": "refused"}`` when the model
            refuses to answer.

        Raises:
            VendorError: On vendor failure or an invalid JSON reply.
        """
        logger.info("Following structure via OpenAI...")
        text_content = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.replace("{json_info}", json.dumps(data)),
            },
        ]
        try:
            response = await self.client.beta.chat.completions.parse(
                model=ai_model,
                messages=[
                    {
                        "role": "system",
                        "content": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE,
                    },
                    {
                        "role": "user",
                        "content": text_content,
                    },
                ],
                max_tokens=1000,
                response_format=get_response_format(schema),
                logprobs=False,
                # top_logprobs=2,
                temperature=0.0,
            )
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            )

        if response.choices[0].message.refusal:
            logger.info("OpenAI refused to respond to the request")
            return {"status": "refused"}

        try:
            content = response.choices[0].message.content
            parsed_data = json.loads(content)
        # Fix: narrowed bare except; raise VendorError (was ValueError) for
        # consistency with extract_attributes.
        except (json.JSONDecodeError, TypeError):
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON)
        return parsed_data