Spaces:
Sleeping
Sleeping
import os

os.environ["HUGGINGFACE_DEMO"] = "1"  # set before import from app
from dotenv import load_dotenv

load_dotenv()
################################################################################################
import ast
import json
import shutil
import uuid

import gradio as gr

from app.config import get_settings
from app.schemas.requests import Attribute
from app.request_handler import handle_extract
from app.services.factory import AIServiceFactory
# Longest image side (in pixels) allowed before forward_request downscales uploads.
IMAGE_MAX_SIZE = 1536
# Application settings; provides the model lists used by the UI and the request handler.
settings = get_settings()
def _parse_literal(text, error_message):
    """Evaluate *text* as a Python literal, raising ``gr.Error(error_message)`` on failure.

    ``ast.literal_eval`` only accepts literals (dicts, lists, strings, numbers,
    booleans, ``None``), so user input can neither execute code nor touch module
    globals -- unlike the ``exec(..., globals())`` this replaces.
    """
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise gr.Error(error_message) from err


def _parse_attribute_schema(attributes):
    """Parse the "Attribute Schema" text (a dict body without braces) into ``{name: Attribute}``."""
    error_message = (
        "Invalid `Attribute Schema`. Please insert valid schema following the example."
    )
    schema = _parse_literal("{" + (attributes or "") + "}", error_message)
    # Braces around e.g. ``1, 2`` would yield a set, not a dict.
    if not isinstance(schema, dict):
        raise gr.Error(error_message)
    try:
        return {name: Attribute(**spec) for name, spec in schema.items()}
    except (TypeError, ValueError) as err:
        # TypeError: an entry is not a mapping; ValueError covers model validation errors.
        raise gr.Error(error_message) from err


def _parse_product_data(product_data):
    """Parse the optional "Product Data" text into a dict; blank input means ``{}``."""
    error_message = (
        "Invalid `Product Data`. Please insert valid dictionary or leave it empty."
    )
    text = (product_data or "").strip() or "{}"
    data = _parse_literal(text, error_message)
    if not isinstance(data, dict):
        raise gr.Error(error_message)
    return data


def _save_images(pil_images, folder):
    """Downscale gallery images, write them into *folder* and return the file paths.

    Images whose longest side exceeds ``IMAGE_MAX_SIZE`` are resized
    proportionally. Images with real transparency are saved as PNG; everything
    else as JPEG (quality 100, no chroma subsampling). Each file's extension
    matches the format actually written.
    """
    img_paths = []
    for index, item in enumerate(pil_images):
        # Gallery items are indexed with [0]; presumably (image, caption) pairs.
        pil_image = item[0]
        longest_side = max(pil_image.size)
        if longest_side > IMAGE_MAX_SIZE:
            ratio = IMAGE_MAX_SIZE / longest_side
            # max(1, ...) keeps extreme aspect ratios from collapsing to a 0-pixel side.
            pil_image = pil_image.resize(
                (
                    max(1, int(pil_image.width * ratio)),
                    max(1, int(pil_image.height * ratio)),
                )
            )
        if pil_image.mode in ("RGBA", "LA") or (
            pil_image.mode == "P" and "transparency" in pil_image.info
        ):
            pil_image = pil_image.convert("RGBA")
            # A fully opaque alpha channel carries no information: store as JPEG.
            if pil_image.getchannel("A").getextrema() == (255, 255):
                pil_image = pil_image.convert("RGB")
        elif pil_image.mode not in ("RGB", "L"):
            # JPEG cannot store modes such as P (palette), I or F.
            pil_image = pil_image.convert("RGB")

        if pil_image.mode == "RGBA":
            img_path = os.path.join(folder, f"{index}.png")
            pil_image.save(img_path, "PNG")
        else:
            img_path = os.path.join(folder, f"{index}.jpg")
            pil_image.save(img_path, "JPEG", quality=100, subsampling=0)
        img_paths.append(img_path)
    return img_paths


def _resolve_vendor(ai_model):
    """Map a model name to the vendor key passed to ``AIServiceFactory.get_service``."""
    vendor_models = (
        ("openai", settings.OPENAI_MODELS),
        ("anthropic", settings.ANTHROPIC_MODELS),
        ("gemini", settings.GEMINI_MODELS),
    )
    for vendor, models in vendor_models:
        if ai_model in models:
            return vendor
    # Previously this fell through and left the vendor unbound (UnboundLocalError).
    raise gr.Error("Please select a supported AI model.")


async def forward_request(
    attributes, product_taxonomy, product_data, ai_model, pil_images
):
    """Validate the UI inputs, run attribute extraction and return the result.

    Click handler for "Extract Attributes". Uploaded images are written to a
    per-request folder under ``gradio_temp`` that is always removed afterwards.
    All parsed inputs are kept in locals, so concurrent requests cannot
    interfere with each other.

    Args:
        attributes: Body of the attribute-schema dict literal (without braces).
        product_taxonomy: Free-text taxonomy, passed through to the service.
        product_data: Dict literal of known product data; blank means ``{}``.
        ai_model: Selected model name; must appear in one of the settings'
            OPENAI/ANTHROPIC/GEMINI model lists.
        pil_images: Gallery value -- a sequence of items whose first element is
            a PIL image -- or ``None`` / empty when nothing was uploaded.

    Returns:
        The value returned by ``service.extract_attributes_with_validation``.

    Raises:
        gr.Error: On invalid input, or when the extraction call fails.
    """
    request_temp_folder = os.path.join("gradio_temp", str(uuid.uuid4()))
    os.makedirs(request_temp_folder, exist_ok=True)
    try:
        attributes_object = _parse_attribute_schema(attributes)
        product_data_object = _parse_product_data(product_data)
        if not pil_images:
            raise gr.Error("Please upload image(s) of the product")
        img_paths = _save_images(pil_images, request_temp_folder)
        service = AIServiceFactory.get_service(_resolve_vendor(ai_model))
        try:
            json_attributes = await service.extract_attributes_with_validation(
                attributes_object,
                ai_model,
                None,
                product_taxonomy,
                product_data_object,
                img_paths=img_paths,
            )
        except Exception as err:
            # Chain the cause so the real failure stays visible in server logs.
            raise gr.Error("Failed to extract attributes. Something went wrong.") from err
    finally:
        # Remove the temp folder anyway; a cleanup failure must not mask the real error.
        shutil.rmtree(request_temp_folder, ignore_errors=True)
    gr.Info("Process completed!")
    return json_attributes
def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
    """Append one attribute entry to the schema text and clear the form inputs.

    The schema text is the body of a dict literal (no surrounding braces), so
    every generated entry ends with a trailing comma.

    Args:
        attributes: Current contents of the "Attribute Schema" textarea.
        attr_name: Attribute name; becomes the entry's key.
        attr_desc: Free-text description.
        attr_type: Selected data type (e.g. ``"list[string]"``); ``None`` when
            nothing was chosen, which is written as an empty string.
        allowed_values: Comma-separated allowed values; blank items are dropped.

    Returns:
        A 5-tuple: the updated schema text, then four empty strings that reset
        the name, description, type and allowed-values inputs.
    """

    def quote(value):
        # json.dumps yields a double-quoted, escaped string that is also a valid
        # Python literal, so quotes/backslashes in user text cannot break the schema.
        return json.dumps("" if value is None else str(value), ensure_ascii=False)

    values = [item.strip() for item in (allowed_values or "").split(",")]
    values_literal = ", ".join(quote(item) for item in values if item)

    existing = (attributes or "").rstrip()
    # The last existing entry may lack a trailing comma (e.g. typed by hand);
    # without one, appending another entry yields an invalid dict body.
    if existing and not existing.endswith(","):
        existing += ","

    entry = (
        f"{quote(attr_name)}: {{\n"
        f'    "description": {quote(attr_desc)},\n'
        f'    "data_type": {quote(attr_type)},\n'
        f'    "allowed_values": [{values_literal}]\n'
        "},\n"
    )
    return f"{existing}\n{entry}", "", "", "", ""
# Default "Attribute Schema" text: the body of a dict literal (the surrounding
# braces are added when the request is parsed). Every entry, including the
# last, ends with a comma so further entries can be appended directly.
sample_schema = """"category": {
"description": "Category of the garment",
"data_type": "list[string]",
"allowed_values": [
"upper garment", "lower garment", "footwear", "accessory", "headwear", "dresses"
]
},
"color": {
"description": "Color of the garment",
"data_type": "list[string]",
"allowed_values": [
"black", "white", "red", "blue", "green", "yellow", "pink", "purple", "orange", "brown", "grey", "beige", "multi-color", "other"
]
},
"pattern": {
"description": "Pattern of the garment",
"data_type": "list[string]",
"allowed_values": [
"plain", "striped", "checkered", "floral", "polka dot", "camouflage", "animal print", "abstract", "other"
]
},
"material": {
"description": "Material of the garment",
"data_type": "string",
"allowed_values": []
},
"""
# Markdown usage instructions rendered under the page title; each step names
# the UI element it refers to.
description = """
This is a simple demo for Attribution. Follow the steps below:
1. Upload image(s) of a product.
2. Enter the product taxonomy (e.g. 'upper garment', 'lower garment', 'bag'). If only one product is in the image, you can leave this field empty.
3. Select the AI model to use.
4. Enter known product data as a dictionary in the "Product Data" field (optional).
5. Enter the attribute schema or use the "Add Attributes" section to add attributes.
6. Click "Extract Attributes" to get the extracted attributes.
"""
# Placeholder text for the "Product Data" textarea.
product_data_placeholder = """Example:
{
"brand": "Leaf",
"size": "M",
"product_name": "Leaf T-shirt",
"color": "red"
}
"""
# Initial "Product Data" value (stripped before use as the textarea value).
product_data_value = """
{
"data1": "",
"data2": ""
}
"""
with gr.Blocks(title="Internal Demo for Attribution") as demo:
    # Header: page title followed by the usage instructions.
    with gr.Row():
        with gr.Column(scale=12):
            gr.Markdown(
                """<div style="text-align: center; font-size: 24px;"><strong>Internal Demo for Attribution</strong></div>"""
            )
            gr.Markdown(description)
    with gr.Row():
        # Inputs on the left (scale 12), extracted JSON on the right (scale 6).
        with gr.Column(scale=12):
            with gr.Row():
                # Product inputs: images, taxonomy, model choice and known data.
                with gr.Column():
                    gallery = gr.Gallery(
                        label="Upload images of your product here", type="pil"
                    )
                    product_taxonomy = gr.Textbox(
                        label="Product Taxonomy",
                        placeholder="Enter product taxonomy here (e.g. 'upper garment', 'lower garment', 'bag')",
                        lines=1,
                        max_lines=1,
                    )
                    ai_model = gr.Dropdown(
                        label="AI Model",
                        choices=settings.SUPPORTED_MODELS,
                        # Preselect the first supported model (None if the list is empty)
                        # so a model is chosen out of the box.
                        value=next(iter(settings.SUPPORTED_MODELS), None),
                        interactive=True,
                    )
                    product_data = gr.TextArea(
                        label="Product Data (Optional)",
                        placeholder=product_data_placeholder,
                        value=product_data_value.strip(),
                        interactive=True,
                        lines=10,
                        max_lines=10,
                    )
                # Attribute schema editor plus a small form that appends entries to it.
                with gr.Column():
                    attributes = gr.TextArea(
                        label="Attribute Schema",
                        value=sample_schema,
                        placeholder="Enter schema here or use Add Attributes below",
                        interactive=True,
                        lines=30,
                        max_lines=30,
                    )
                    with gr.Accordion("Add Attributes", open=False):
                        attr_name = gr.Textbox(
                            label="Attribute name", placeholder="Enter attribute name"
                        )
                        attr_desc = gr.Textbox(
                            label="Description", placeholder="Enter description"
                        )
                        attr_type = gr.Dropdown(
                            label="Type",
                            choices=[
                                "string",
                                "list[string]",
                                "int",
                                "list[int]",
                                "float",
                                "list[float]",
                                "bool",
                                "list[bool]",
                            ],
                            interactive=True,
                        )
                        allowed_values = gr.Textbox(
                            label="Allowed values (separated by comma)",
                            placeholder="yellow, red, blue",
                        )
                        add_btn = gr.Button("Add Attribute")
            with gr.Row():
                submit_btn = gr.Button("Extract Attributes")
        with gr.Column(scale=6):
            output_json = gr.Json(
                label="Extracted Attributes", value={}, show_indices=False
            )

    # "Add Attribute" appends to the schema text and clears the four form inputs.
    add_btn.click(
        add_attribute_schema,
        inputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
        outputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
    )
    # Input order must match forward_request's positional parameters.
    submit_btn.click(
        forward_request,
        inputs=[attributes, product_taxonomy, product_data, ai_model, gallery],
        outputs=output_json,
    )
if __name__ == "__main__":
    import warnings

    # Basic-auth credentials for the demo, read from the environment. The
    # fallbacks are trivially guessable, so warn when they (or empty values) apply.
    attr_user = os.getenv("ATTR_USER", "1")
    attr_pass = os.getenv("ATTR_PASS", "a")
    if not os.getenv("ATTR_USER") or not os.getenv("ATTR_PASS"):
        warnings.warn(
            "ATTR_USER and/or ATTR_PASS is not set; the demo is protected only by "
            "insecure default or empty credentials.",
            stacklevel=1,
        )
    auth = (attr_user, attr_pass)
    # Launch only when run as a script, so importing this module (e.g. Gradio
    # reload mode) does not start a second server.
    demo.launch(auth=auth, debug=True, ssr_mode=False)