Spaces: Running on Zero
# ---------------------------------------------------------------------- | |
# IMPORTS | |
# ---------------------------------------------------------------------- | |
import torch | |
import logging | |
import time | |
import spaces | |
import sys | |
import traceback | |
from PIL import Image | |
from typing import List, Optional, Any | |
from collections import defaultdict | |
from src.utils import LOG_LEVEL_MAP, EMOJI_MAP | |
# ----------------------------------------------------------------------
# RT-DETR CONSTANTS
# ----------------------------------------------------------------------
# Score threshold handed to post-processing in detect_rtdetr_in_roi.
RTDETR_CONF: float = 0.4
# Score threshold for artifact detections; detect_rtdetr_artifacts_in_roi uses
# it both for post-processing and as an explicit per-detection cut-off.
RTDETR_ARTIFACT_CONF: float = 0.35
# ----------------------------------------------------------------------
# MODEL LABEL CONFIGURATION
# ----------------------------------------------------------------------
# Per-model mapping of  category list -> {keyword: [raw model label names]}.
# map_label_to_keyword walks these categories in a fixed order and returns the
# first keyword whose label list matches a detection.
# NOTE(review): the raw names look like the COCO class set (including the
# "hair drier" spelling) - confirm against the checkpoint's id2label.
MODEL_LABEL_CONFIG = {
    "rtdetr_model": {
        "person_list": {
            "person": ["person"],
        },
        # Nothing is mapped for these categories for this model.
        "product_type_list": {},
        "head_list": {},
        "shoes_list": {},
        "clothing_features_list": {
            # NOTE(review): a collar is approximated by the "tie" label - confirm intent.
            "collar": ["tie"],
        },
        "artifacts_list": {
            "bag": ["backpack", "handbag", "suitcase"],
            "cup": ["bottle", "wine glass", "cup"],
            "umbrella": ["umbrella"],
            "book": ["book"],
            "phone": ["cell phone"],
            "camera": [],  # empty: no raw label is mapped to this keyword
            # Catch-all bucket; get_rtdetr_artifact_labels deliberately skips it.
            "other": [
                # tableware
                "fork", "knife", "spoon", "bowl",
                # sports equipment
                "frisbee", "sports ball", "kite", "baseball bat",
                "baseball glove", "skateboard", "surfboard", "tennis racket",
                # furniture / fixtures
                "chair", "couch", "potted plant", "bed", "dining table", "toilet",
                # electronics
                "tv", "laptop", "mouse", "remote", "keyboard",
                # appliances
                "microwave", "oven", "toaster", "sink", "refrigerator",
                # household items
                "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
            ],
        },
    },
}
# ----------------------------------------------------------------------
# RT-DETR HELPER FUNCTIONS
# ----------------------------------------------------------------------
def get_rtdetr_clothing_labels():
    """Return the raw RT-DETR labels accepted as a person/clothing hit.

    The set is the union of every label listed under ``person_list`` and
    ``product_type_list`` for ``"rtdetr_model"`` in MODEL_LABEL_CONFIG, plus a
    fixed set of garment names. With the current config this includes
    ``"person"``.

    NOTE(review): the garment names below appear nowhere in MODEL_LABEL_CONFIG;
    confirm the loaded checkpoint's id2label actually contains them.
    """
    section = MODEL_LABEL_CONFIG.get("rtdetr_model", {})
    accepted = {
        raw_name
        for category in ("person_list", "product_type_list")
        for raw_names in section.get(category, {}).values()
        for raw_name in raw_names
    }
    accepted |= {"coat", "dress", "jacket", "shirt", "skirt", "pants", "shorts"}
    return accepted
def get_rtdetr_person_and_product_labels():
    """Return RT-DETR labels for people and product types.

    Starts from ``"person"`` plus a fixed set of garment names, then adds every
    label listed under ``person_list`` and ``product_type_list`` for
    ``"rtdetr_model"`` in MODEL_LABEL_CONFIG.
    """
    result = {"person", "coat", "dress", "jacket", "shirt", "skirt", "pants", "shorts"}
    section = MODEL_LABEL_CONFIG.get("rtdetr_model", {})
    for category in ("person_list", "product_type_list"):
        for raw_names in section.get(category, {}).values():
            result.update(raw_names)
    return result
def get_rtdetr_artifact_labels():
    """Return raw RT-DETR labels for every artifact keyword except ``"other"``.

    Reads ``artifacts_list`` for ``"rtdetr_model"`` in MODEL_LABEL_CONFIG. The
    catch-all ``"other"`` bucket is skipped, so those labels are never treated
    as artifacts.
    """
    artifact_map = MODEL_LABEL_CONFIG.get("rtdetr_model", {}).get("artifacts_list", {})
    return {
        raw_name
        for keyword, raw_names in artifact_map.items()
        if keyword != "other"
        for raw_name in raw_names
    }
def get_label_name_from_model(model, label_id):
    """Resolve a class id to a lower-cased label name.

    Lookup order:
      1. ``model.config.id2label`` (used whenever ``model.config`` exists and has
         that attribute, even if the id is missing from it);
      2. ``model.model_labels``, only if it is a ``dict``.

    An id that cannot be resolved becomes ``"unknown_<label_id>"``.
    """
    placeholder = f"unknown_{label_id}"
    config = getattr(model, "config", None)
    if hasattr(config, "id2label"):
        return config.id2label.get(label_id, placeholder).lower()
    custom_labels = getattr(model, "model_labels", None)
    if isinstance(custom_labels, dict):
        return custom_labels.get(label_id, placeholder).lower()
    return placeholder
def map_label_to_keyword(label_name: str, valid_kws: List[str], model_name: str) -> Optional[str]:
    """Map a raw detector label to the first allowed keyword that claims it.

    Categories of ``MODEL_LABEL_CONFIG[model_name]`` are searched in a fixed
    order (person, product type, head, shoes, clothing features, artifacts).
    Within each category, only keywords present in ``valid_kws`` are considered.

    Matching is deliberately loose. A keyword matches when the normalized label
    (stripped, lower-cased) equals, or is a substring of, one of its configured
    labels. For example, ``"phone"`` would match ``"cell phone"``.
    NOTE(review): this also lets ``"bear"`` match ``"teddy bear"`` - confirm
    that loose matching is intended.

    Args:
        label_name: Raw label emitted by the model.
        valid_kws: Keywords the caller accepts.
        model_name: Key into MODEL_LABEL_CONFIG (e.g. ``"rtdetr_model"``).

    Returns:
        The matching keyword, or ``None`` if nothing matches. Also ``None`` for
        an empty or whitespace-only label. Without that guard, ``""`` is a
        substring of every string and would match the first allowed keyword.
    """
    ln = label_name.strip().lower()
    if not ln:
        return None
    model_config = MODEL_LABEL_CONFIG.get(model_name, {})
    for list_type in ("person_list", "product_type_list", "head_list",
                      "shoes_list", "clothing_features_list", "artifacts_list"):
        for keyword, labels in model_config.get(list_type, {}).items():
            if keyword not in valid_kws:
                continue
            # Equality implies substring, so one containment test covers both.
            if any(ln in label.lower() for label in labels):
                return keyword
    return None
def process_rtdetr_results(results, model, label_set, threshold, fallback_box=None):
    """Pick the first RT-DETR detection that passes both filters.

    A detection qualifies when its score is at least ``threshold`` and its
    lower-cased label name (via get_label_name_from_model) is in ``label_set``.
    Detections are scanned in the order the post-processor returned them.

    Args:
        results: A post-processed result dict with ``scores``/``labels``/``boxes``,
            or a list whose first element is such a dict.
        model: Model used to translate class ids into names.
        label_set: Accepted label names.
        threshold: Minimum score.
        fallback_box: Returned as the box only if processing raises.

    Returns:
        ``(box, score, label)``, where ``box`` is ``[x1, y1, x2, y2]`` as ints.
        Returns ``(None, 0.0, None)`` when nothing qualifies or the list is
        empty, and ``(fallback_box, 0.0, None)`` after a logged error.
    """
    try:
        if not isinstance(results, list):
            result = results
        elif results:
            result = results[0]
        else:
            return None, 0.0, None
        for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
            confidence = score.item()
            if confidence < threshold:
                continue
            name = get_label_name_from_model(model, label.item())
            if name not in label_set:
                continue
            # Unpacking enforces exactly four coordinates, as before.
            x1, y1, x2, y2 = [int(coord) for coord in box.tolist()]
            return [x1, y1, x2, y2], confidence, name
        return None, 0.0, None
    except Exception as e:
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} Error processing RTDETR results: {e}")
        return fallback_box, 0.0, None
# ----------------------------------------------------------------------
# RT-DETR DETECTION FUNCTIONS
# ----------------------------------------------------------------------
def detect_rtdetr_in_roi(roi_rgb, RTDETR_PROCESSOR, RTDETR_MODEL, DEVICE, log_item):
    """Run RT-DETR on one image region and collect every detection.

    Post-processing is thresholded at RTDETR_CONF, and boxes are scaled to the
    region's own height/width.

    Args:
        roi_rgb: Region to analyse. It needs ``.height``/``.width``; presumably
            a PIL image - confirm with callers.
        RTDETR_PROCESSOR: Processor used for encoding and post-processing.
        RTDETR_MODEL: RT-DETR model.
        DEVICE: Device that inputs and target sizes are moved to.
        log_item: Mutable dict. On failure, ``warnings`` is extended and
            ``traceback`` is set.

    Returns:
        Four parallel lists ``(boxes, labels, scores, raw_labels)``:
        ``[x1, y1, x2, y2]`` int boxes, class ids, float scores and lower-cased
        label names. On error, whatever was collected before the failure.

    Side effects:
        Calls ``sys.exit(1)`` when the error message contains
        "CUDA must not be initialized".
    """
    boxes, labels, scores, raw_labels = [], [], [], []
    try:
        batch = RTDETR_PROCESSOR(images=roi_rgb, return_tensors="pt")
        batch = {key: tensor.to(DEVICE) for key, tensor in batch.items()}
        with torch.no_grad():
            outputs = RTDETR_MODEL(**batch)
        detections = RTDETR_PROCESSOR.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([[roi_rgb.height, roi_rgb.width]]).to(DEVICE),
            threshold=RTDETR_CONF,
        )
        first = detections[0] if isinstance(detections, list) and len(detections) > 0 else detections
        for score, label, box in zip(first["scores"], first["labels"], first["boxes"]):
            class_id = label.item()
            confidence = score.item()
            x1, y1, x2, y2 = [int(coord) for coord in box.tolist()]
            name = get_label_name_from_model(RTDETR_MODEL, class_id)
            boxes.append([x1, y1, x2, y2])
            labels.append(class_id)
            scores.append(confidence)
            raw_labels.append(name)
            logging.log(LOG_LEVEL_MAP["INFO"], f"rtdetr_model: {EMOJI_MAP['INFO']} RT-DETR detected: {name} at score {confidence:.3f}")
    except Exception as e:
        message = f"RTDETR detection error: {str(e)}"
        trace = traceback.format_exc()
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} {message}")
        logging.error(f"Traceback:\n{trace}")
        # Rebinds a new list rather than mutating an existing one.
        log_item["warnings"] = log_item.get("warnings", []) + [message]
        log_item["traceback"] = trace
        if "CUDA must not be initialized" in str(e):
            logging.critical("CUDA initialization error in Spaces Zero GPU environment")
            sys.exit(1)
    return boxes, labels, scores, raw_labels
def detect_rtdetr_artifacts_in_roi(roi_rgb, keywords, RTDETR_PROCESSOR, RTDETR_MODEL, DEVICE, log_item):
    """Run RT-DETR on one image region and keep only artifact detections.

    A detection is kept when all of the following hold:
      * its score is at least RTDETR_ARTIFACT_CONF;
      * its label name is one of the non-"other" artifact labels
        (get_rtdetr_artifact_labels);
      * map_label_to_keyword maps it to one of ``keywords``.

    Args:
        roi_rgb: Region to analyse. It needs ``.height``/``.width``; presumably
            a PIL image - confirm with callers.
        keywords: Artifact keywords the caller accepts.
        RTDETR_PROCESSOR: Processor used for encoding and post-processing.
        RTDETR_MODEL: RT-DETR model.
        DEVICE: Device that inputs and target sizes are moved to.
        log_item: Mutable dict. On failure, ``warnings`` is extended and
            ``traceback`` is set.

    Returns:
        Four parallel lists ``(boxes, labels, scores, raw_labels)`` for the
        kept detections, with boxes in the region's pixel coordinates.

    Side effects:
        Calls ``sys.exit(1)`` when the error message contains
        "CUDA must not be initialized".
    """
    boxes, labels, scores, raw_labels = [], [], [], []
    try:
        batch = RTDETR_PROCESSOR(images=roi_rgb, return_tensors="pt")
        batch = {key: tensor.to(DEVICE) for key, tensor in batch.items()}
        with torch.no_grad():
            outputs = RTDETR_MODEL(**batch)
        detections = RTDETR_PROCESSOR.post_process_object_detection(
            outputs,
            target_sizes=torch.tensor([[roi_rgb.height, roi_rgb.width]]).to(DEVICE),
            threshold=RTDETR_ARTIFACT_CONF,
        )
        allowed = get_rtdetr_artifact_labels()
        first = detections[0] if isinstance(detections, list) and len(detections) > 0 else detections
        for score, label, box in zip(first["scores"], first["labels"], first["boxes"]):
            class_id = label.item()
            confidence = score.item()
            if confidence < RTDETR_ARTIFACT_CONF:
                continue
            name = get_label_name_from_model(RTDETR_MODEL, class_id)
            if name not in allowed:
                continue
            x1, y1, x2, y2 = [int(coord) for coord in box.tolist()]
            if not map_label_to_keyword(name, keywords, "rtdetr_model"):
                continue
            boxes.append([x1, y1, x2, y2])
            labels.append(class_id)
            scores.append(confidence)
            raw_labels.append(name)
            logging.log(LOG_LEVEL_MAP["INFO"], f"rtdetr_model: {EMOJI_MAP['INFO']} Artifact detected: {name} at score {confidence:.3f}")
    except Exception as e:
        message = f"RTDETR artifact detection error: {str(e)}"
        trace = traceback.format_exc()
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} {message}")
        logging.error(f"Traceback:\n{trace}")
        log_item["warnings"] = log_item.get("warnings", []) + [message]
        log_item["traceback"] = trace
        if "CUDA must not be initialized" in str(e):
            logging.critical("CUDA initialization error in Spaces Zero GPU environment")
            sys.exit(1)
    return boxes, labels, scores, raw_labels
def _refine_fallback_box(pi_rgba, fallback_box, processor, model, device, conf):
    """Run RT-DETR inside ``fallback_box`` and return ``(box, label_id, score, model_name)``.

    If a person/clothing label (get_rtdetr_clothing_labels) is found, its box is
    shifted back to full-image coordinates and tagged ``90001`` /
    ``"rtdetr_model"``. Otherwise the untouched fallback box is returned, tagged
    ``90000`` / ``"fallback"`` with score ``0.0``. Exceptions propagate to the
    caller.
    """
    crop = pi_rgba.crop(tuple(fallback_box)).convert("RGB")
    inputs = processor(images=crop, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=torch.tensor([[crop.height, crop.width]]).to(device),
        threshold=conf,
    )
    hit_box, hit_score, _ = process_rtdetr_results(
        results, model, get_rtdetr_clothing_labels(), conf
    )
    if not hit_box:
        return fallback_box, 90000, 0.0, "fallback"
    # The hit is relative to the crop; offset it by the crop's top-left corner.
    ox, oy = fallback_box[0], fallback_box[1]
    refined = [ox + hit_box[0], oy + hit_box[1], ox + hit_box[2], oy + hit_box[3]]
    return refined, 90001, hit_score, "rtdetr_model"


def update_fallback_detection(ctx, pi_rgba, fallback_box, RTDETR_PROCESSOR, RTDETR_MODEL, DEVICE, RTDETR_CONF, final_boxes, final_labels, final_scores, final_kws, final_raws, final_mods, dd_log):
    """Append one fallback detection for ``ctx.product_type`` to the result lists.

    If ``fallback_box`` is not a 4-element list, nothing is appended.
    Otherwise RT-DETR is run on that crop of ``pi_rgba``. Exactly one entry is
    then appended to each parallel list and to ``dd_log[ctx.product_type]``:

    * refined hit:     label 90001, model ``"rtdetr_model"``, score rounded to 2 dp;
    * no hit:          the original box, label 90000, model ``"fallback"``, score 0.0;
    * detection error: the original box, label 90000, model ``"fallback_error"``,
      score 0.0, plus a logged warning.

    The appends happen after detection, outside the ``try``, so a failure can
    never leave the parallel lists with a partial or duplicated entry.
    ``dd_log`` is indexed directly, so it presumably is a ``defaultdict(list)``
    - confirm with callers.

    Returns:
        The same (mutated) ``final_*`` lists and ``dd_log``, as a 7-tuple.
    """
    outputs = (final_boxes, final_labels, final_scores, final_kws, final_raws, final_mods, dd_log)
    if not (isinstance(fallback_box, list) and len(fallback_box) == 4):
        return outputs
    try:
        box, label_id, score, model_name = _refine_fallback_box(
            pi_rgba, fallback_box, RTDETR_PROCESSOR, RTDETR_MODEL, DEVICE, RTDETR_CONF
        )
    except Exception as e:
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} Fallback detection error: {e}")
        box, label_id, score, model_name = fallback_box, 90000, 0.0, "fallback_error"
    keyword = ctx.product_type
    final_boxes.append(box)
    final_labels.append(label_id)
    final_scores.append(round(score, 2))
    final_kws.append(keyword)
    final_raws.append("fallback_label")
    final_mods.append(model_name)
    # dd_log keeps the unrounded score, as before.
    dd_log[keyword].append({
        "box": box,
        "score": score,
        "raw_label": "fallback_label",
        "model": model_name,
    })
    return outputs