import torch
import logging
import time
import spaces
import sys
import traceback

from PIL import Image
from typing import List, Optional, Any
from collections import defaultdict

from src.utils import LOG_LEVEL_MAP, EMOJI_MAP

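# Confidence thresholds applied when post-processing RT-DETR outputs:
# RTDETR_CONF gates general (person/clothing) detections, RTDETR_ARTIFACT_CONF gates artifact detections.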
RTDETR_CONF = 0.4
RTDETR_ARTIFACT_CONF = 0.35

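# Maps each downstream keyword to the raw RT-DETR class names that count as a
# match for it, grouped by category list. Empty entries mean RT-DETR provides no
# corresponding label for that keyword.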
MODEL_LABEL_CONFIG = {
    "rtdetr_model": {
        "person_list": {
            "person": ["person"]
        },
        "product_type_list": {},
        "head_list": {},
        "shoes_list": {},
        "clothing_features_list": {
            "collar": ["tie"]
        },
        "artifacts_list": {
            "bag": ["backpack", "handbag", "suitcase"],
            "cup": ["bottle", "wine glass", "cup"],
            "umbrella": ["umbrella"],
            "book": ["book"],
            "phone": ["cell phone"],
            "camera": [],
            "other": ["fork", "knife", "spoon", "bowl", "frisbee", "sports ball", "kite",
                      "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
                      "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
                      "laptop", "mouse", "remote", "keyboard", "microwave", "oven", "toaster",
                      "sink", "refrigerator", "clock", "vase", "scissors", "teddy bear",
                      "hair drier", "toothbrush"]
        }
    }
}

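# Example: with this config, a raw "backpack" detection maps to the "bag" keyword
# and "wine glass" maps to "cup" (see map_label_to_keyword below).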
def get_rtdetr_clothing_labels():
    """Collect the RT-DETR label names that should be treated as person/clothing detections."""
    clothing_labels = set()
    rtdetr_config = MODEL_LABEL_CONFIG.get("rtdetr_model", {})

    for keyword, labels in rtdetr_config.get("person_list", {}).items():
        clothing_labels.update(labels)

    for keyword, labels in rtdetr_config.get("product_type_list", {}).items():
        clothing_labels.update(labels)

    # Generic garment names added on top of the configured lists.
    clothing_labels.update(["coat", "dress", "jacket", "shirt", "skirt", "pants", "shorts"])

    return clothing_labels

def get_rtdetr_person_and_product_labels():
    """Collect the RT-DETR label names used for person and product-type detections."""
    labels = set()
    rtdetr_config = MODEL_LABEL_CONFIG.get("rtdetr_model", {})

    for keyword, label_list in rtdetr_config.get("person_list", {}).items():
        labels.update(label_list)

    for keyword, label_list in rtdetr_config.get("product_type_list", {}).items():
        labels.update(label_list)

    # "person" plus generic garment names added on top of the configured lists.
    labels.update(["person", "coat", "dress", "jacket", "shirt", "skirt", "pants", "shorts"])

    return labels

def get_rtdetr_artifact_labels():
    """Collect the RT-DETR label names configured as artifacts, excluding the catch-all "other" group."""
    artifact_labels = set()
    rtdetr_config = MODEL_LABEL_CONFIG.get("rtdetr_model", {})

    for keyword, labels in rtdetr_config.get("artifacts_list", {}).items():
        if keyword != "other":
            artifact_labels.update(labels)

    return artifact_labels

def get_label_name_from_model(model, label_id):
    """Resolve a numeric label id to a lowercase class name via the model's label mapping."""
    if hasattr(model, 'config') and hasattr(model.config, 'id2label'):
        return model.config.id2label.get(label_id, f"unknown_{label_id}").lower()
    if hasattr(model, 'model_labels') and isinstance(model.model_labels, dict):
        return model.model_labels.get(label_id, f"unknown_{label_id}").lower()
    return f"unknown_{label_id}"

def map_label_to_keyword(label_name: str, valid_kws: List[str], model_name: str) -> Optional[str]:
    """Map a raw model label to the first matching keyword in valid_kws, or None if nothing matches."""
    ln = label_name.strip().lower()

    model_config = MODEL_LABEL_CONFIG.get(model_name, {})

    for list_type in ["person_list", "product_type_list", "head_list",
                      "shoes_list", "clothing_features_list", "artifacts_list"]:
        category_config = model_config.get(list_type, {})

        for keyword, labels in category_config.items():
            if keyword in valid_kws:
                for label in labels:
                    if ln == label.lower() or ln in label.lower():
                        return keyword

    return None

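# Illustrative calls (hypothetical inputs), following the config above:
#   map_label_to_keyword("handbag", ["bag", "cup"], "rtdetr_model") -> "bag"
#   map_label_to_keyword("handbag", ["cup"], "rtdetr_model")        -> None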
def process_rtdetr_results(results, model, label_set, threshold, fallback_box=None):
    """Return (box, score, label) for the first detection above threshold whose label is in label_set."""
    try:
        if isinstance(results, list):
            if len(results) > 0:
                result = results[0]
            else:
                return None, 0.0, None
        else:
            result = results

        found_box = None
        found_score = 0.0
        found_label = None

        for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
            score_val = score.item()
            if score_val < threshold:
                continue

            label_id = label.item()
            label_name = get_label_name_from_model(model, label_id)

            if label_name in label_set:
                x1, y1, x2, y2 = [int(val) for val in box.tolist()]
                found_box = [x1, y1, x2, y2]
                found_score = score_val
                found_label = label_name
                break

        return found_box, found_score, found_label
    except Exception as e:
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} Error processing RTDETR results: {e}")
        return fallback_box, 0.0, None

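# Minimal usage sketch for the ROI detectors below (assumptions: a COCO-trained
# RT-DETR checkpoint such as "PekingU/rtdetr_r50vd" loaded with transformers'
# RTDetrImageProcessor / RTDetrForObjectDetection, and a PIL RGB crop `roi_rgb`):
#
#   processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
#   model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd").to(device)
#   boxes, label_ids, scores, raw_labels = detect_rtdetr_in_roi(
#       roi_rgb, processor, model, device, log_item={})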
def detect_rtdetr_in_roi(roi_rgb, RTDETR_PROCESSOR, RTDETR_MODEL, DEVICE, log_item):
    """Run RT-DETR on an RGB ROI and return all detections above RTDETR_CONF."""
    boxes = []
    labels = []
    scores = []
    raw_labels = []

    try:
        rtdetr_inputs = RTDETR_PROCESSOR(images=roi_rgb, return_tensors="pt")
        rtdetr_inputs = {k: v.to(DEVICE) for k, v in rtdetr_inputs.items()}

        with torch.no_grad():
            rtdetr_outputs = RTDETR_MODEL(**rtdetr_inputs)

        rtdetr_results = RTDETR_PROCESSOR.post_process_object_detection(
            rtdetr_outputs,
            target_sizes=torch.tensor([[roi_rgb.height, roi_rgb.width]]).to(DEVICE),
            threshold=RTDETR_CONF
        )

        if isinstance(rtdetr_results, list) and len(rtdetr_results) > 0:
            result = rtdetr_results[0]
        else:
            result = rtdetr_results

        for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
            label_id = label.item()
            score_val = score.item()
            x1, y1, x2, y2 = [int(val) for val in box.tolist()]
            label_name = get_label_name_from_model(RTDETR_MODEL, label_id)

            boxes.append([x1, y1, x2, y2])
            labels.append(label_id)
            scores.append(score_val)
            raw_labels.append(label_name)

            logging.log(LOG_LEVEL_MAP["INFO"], f"rtdetr_model: {EMOJI_MAP['INFO']} RT-DETR detected: {label_name} at score {score_val:.3f}")

    except Exception as e:
        error_msg = f"RTDETR detection error: {str(e)}"
        error_trace = traceback.format_exc()

        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} {error_msg}")
        logging.error(f"Traceback:\n{error_trace}")

        log_item["warnings"] = log_item.get("warnings", []) + [error_msg]
        log_item["traceback"] = error_trace

        # In the Spaces Zero GPU environment a CUDA initialization error is unrecoverable; exit hard.
        if "CUDA must not be initialized" in str(e):
            logging.critical("CUDA initialization error in Spaces Zero GPU environment")
            sys.exit(1)

    return boxes, labels, scores, raw_labels

def detect_rtdetr_artifacts_in_roi(roi_rgb, keywords, RTDETR_PROCESSOR, RTDETR_MODEL, DEVICE, log_item):
    """Run RT-DETR on an RGB ROI and keep only artifact detections that map to one of the given keywords."""
    boxes = []
    labels = []
    scores = []
    raw_labels = []

    try:
        rtdetr_inputs = RTDETR_PROCESSOR(images=roi_rgb, return_tensors="pt")
        rtdetr_inputs = {k: v.to(DEVICE) for k, v in rtdetr_inputs.items()}

        with torch.no_grad():
            rtdetr_outputs = RTDETR_MODEL(**rtdetr_inputs)

        rtdetr_results = RTDETR_PROCESSOR.post_process_object_detection(
            rtdetr_outputs,
            target_sizes=torch.tensor([[roi_rgb.height, roi_rgb.width]]).to(DEVICE),
            threshold=RTDETR_ARTIFACT_CONF
        )

        rtdetr_artifact_labels = get_rtdetr_artifact_labels()

        if isinstance(rtdetr_results, list) and len(rtdetr_results) > 0:
            result = rtdetr_results[0]
        else:
            result = rtdetr_results

        for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
            label_id = label.item()
            score_val = score.item()

            if score_val < RTDETR_ARTIFACT_CONF:
                continue

            label_name = get_label_name_from_model(RTDETR_MODEL, label_id)

            if label_name in rtdetr_artifact_labels:
                x1, y1, x2, y2 = [int(val) for val in box.tolist()]

                artifact_keyword = map_label_to_keyword(label_name, keywords, "rtdetr_model")
                if not artifact_keyword:
                    continue

                boxes.append([x1, y1, x2, y2])
                labels.append(label_id)
                scores.append(score_val)
                raw_labels.append(label_name)

                logging.log(LOG_LEVEL_MAP["INFO"], f"rtdetr_model: {EMOJI_MAP['INFO']} Artifact detected: {label_name} at score {score_val:.3f}")

    except Exception as e:
        error_msg = f"RTDETR artifact detection error: {str(e)}"
        error_trace = traceback.format_exc()

        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} {error_msg}")
        logging.error(f"Traceback:\n{error_trace}")

        log_item["warnings"] = log_item.get("warnings", []) + [error_msg]
        log_item["traceback"] = error_trace

        # In the Spaces Zero GPU environment a CUDA initialization error is unrecoverable; exit hard.
        if "CUDA must not be initialized" in str(e):
            logging.critical("CUDA initialization error in Spaces Zero GPU environment")
            sys.exit(1)

    return boxes, labels, scores, raw_labels

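# Sentinel label ids used by update_fallback_detection below: 90001 marks a
# clothing box recovered by RT-DETR inside the fallback region, 90000 marks the
# raw fallback box itself (no hit, or an error during re-detection).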
def update_fallback_detection(ctx, pi_rgba, fallback_box, RTDETR_PROCESSOR, RTDETR_MODEL, DEVICE, RTDETR_CONF, final_boxes, final_labels, final_scores, final_kws, final_raws, final_mods, dd_log):
    """Re-run RT-DETR inside fallback_box; append either the recovered clothing box or the raw fallback box."""
    try:
        if not (fallback_box and isinstance(fallback_box, list) and len(fallback_box) == 4):
            return final_boxes, final_labels, final_scores, final_kws, final_raws, final_mods, dd_log

        # Crop the fallback region from the full RGBA image and run RT-DETR on the crop.
        sub_ = pi_rgba.crop((
            fallback_box[0],
            fallback_box[1],
            fallback_box[2],
            fallback_box[3]
        ))
        sub_ = sub_.convert("RGB")
        subW = sub_.width
        subH = sub_.height

        rtdetr_inputs = RTDETR_PROCESSOR(images=sub_, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            rtdetr_outputs = RTDETR_MODEL(**rtdetr_inputs)

        rtdetr_results = RTDETR_PROCESSOR.post_process_object_detection(
            rtdetr_outputs,
            target_sizes=torch.tensor([[subH, subW]]).to(DEVICE),
            threshold=RTDETR_CONF
        )

        rtdetr_clothing_labels = get_rtdetr_clothing_labels()
        found_fb_box, found_fb_score, _ = process_rtdetr_results(
            rtdetr_results, RTDETR_MODEL, rtdetr_clothing_labels, RTDETR_CONF
        )

        if found_fb_box:
            # Translate the crop-relative box back into full-image coordinates.
            fx1 = fallback_box[0] + found_fb_box[0]
            fy1 = fallback_box[1] + found_fb_box[1]
            fx2 = fallback_box[0] + found_fb_box[2]
            fy2 = fallback_box[1] + found_fb_box[3]

            found_fb_box = [fx1, fy1, fx2, fy2]
            final_boxes.append(found_fb_box)
            final_labels.append(90001)
            final_scores.append(round(found_fb_score, 2))
            final_kws.append(ctx.product_type)
            final_raws.append("fallback_label")
            final_mods.append("rtdetr_model")
            dd_log[ctx.product_type].append({
                "box": found_fb_box,
                "score": found_fb_score,
                "raw_label": "fallback_label",
                "model": "rtdetr_model"
            })
        else:
            # No clothing detection inside the crop: keep the fallback box itself with score 0.0.
            final_boxes.append(fallback_box)
            final_labels.append(90000)
            final_scores.append(0.0)
            final_kws.append(ctx.product_type)
            final_raws.append("fallback_label")
            final_mods.append("fallback")
            dd_log[ctx.product_type].append({
                "box": fallback_box,
                "score": 0.0,
                "raw_label": "fallback_label",
                "model": "fallback"
            })

        return final_boxes, final_labels, final_scores, final_kws, final_raws, final_mods, dd_log
    except Exception as e:
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} Fallback detection error: {e}")
        final_boxes.append(fallback_box)
        final_labels.append(90000)
        final_scores.append(0.0)
        final_kws.append(ctx.product_type)
        final_raws.append("fallback_label")
        final_mods.append("fallback_error")
        dd_log[ctx.product_type].append({
            "box": fallback_box,
            "score": 0.0,
            "raw_label": "fallback_label",
            "model": "fallback_error"
        })
        return final_boxes, final_labels, final_scores, final_kws, final_raws, final_mods, dd_log
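# Hypothetical call (names are illustrative, not part of this module; `ctx` is any
# object exposing a product_type attribute):
#   outputs = update_fallback_detection(
#       ctx, pil_image_rgba, fallback_box, processor, model, device, RTDETR_CONF,
#       [], [], [], [], [], [], defaultdict(list))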