# Spaces: Running on Zero
# ---------------------------------------------------------------------- | |
# IMPORTS | |
# ---------------------------------------------------------------------- | |
import torch | |
import logging | |
import spaces | |
import sys | |
import traceback | |
from typing import List, Tuple | |
from src.utils import LOG_LEVEL_MAP, EMOJI_MAP | |
# ----------------------------------------------------------------------
# HEAD MODEL CONSTANTS
# ----------------------------------------------------------------------
# Minimum per-query class score (softmax over the logits, last class
# excluded) that detect_head_in_roi requires before it keeps a detection.
# The comparison is inclusive (>=).
HEAD_DETECTION_THRESHOLD = 0.2
# ----------------------------------------------------------------------
# MODEL LABEL CONFIGURATION
# ----------------------------------------------------------------------
# Label groups per model. Only "head_list" is populated for the head model;
# the other groups are kept as empty dicts.
# NOTE(review): presumably the empty groups exist so every model entry
# shares the same set of keys -- confirm against the consumers of this dict.
MODEL_LABEL_CONFIG = {
    "head_model": {
        "person_list": {},
        "product_type_list": {},
        "head_list": {
            "head": ["head", "face"],
        },
        "shoes_list": {},
        "clothing_features_list": {},
        "artifacts_list": {},
    },
}
# ----------------------------------------------------------------------
# HEAD MODEL HELPER FUNCTIONS
# ----------------------------------------------------------------------
def get_label_name_from_model(model, label_id):
    """Return the lower-cased label name for ``label_id``.

    Lookup order:

    1. ``model.config.id2label`` when the model has a ``config`` exposing
       ``id2label``. This branch always returns, so ``model_labels`` is not
       consulted even when the id is missing from ``id2label``.
    2. ``model.model_labels`` when that attribute is a dict.
    3. Otherwise the placeholder ``"unknown_<label_id>"``.

    A missing id in either mapping also yields ``"unknown_<label_id>"``.
    """
    fallback = f"unknown_{label_id}"

    if hasattr(model, 'config') and hasattr(model.config, 'id2label'):
        return model.config.id2label.get(label_id, fallback).lower()

    if hasattr(model, 'model_labels') and isinstance(model.model_labels, dict):
        return model.model_labels.get(label_id, fallback).lower()

    return fallback
def clamp_box_to_region(box: List[int], region: List[int]) -> List[int]:
    """Clamp each corner coordinate of ``box`` into ``region``.

    Both arguments are ``[x1, y1, x2, y2]``. Every x coordinate is limited to
    ``[region x1, region x2]`` and every y coordinate to
    ``[region y1, region y2]``. A box lying entirely outside the region
    therefore collapses onto the nearest region edge (zero width/height)
    rather than being rejected.

    Returns:
        A new ``[x1, y1, x2, y2]`` list; the inputs are not modified.
    """
    x1, y1, x2, y2 = box
    rx1, ry1, rx2, ry2 = region

    xx1 = max(rx1, min(x1, rx2))
    yy1 = max(ry1, min(y1, ry2))
    xx2 = max(rx1, min(x2, rx2))
    yy2 = max(ry1, min(y2, ry2))

    return [xx1, yy1, xx2, yy2]
# ----------------------------------------------------------------------
# HEAD DETECTION FUNCTIONS
# ----------------------------------------------------------------------
def detect_head_in_roi(roi_rgb, rx1, ry1, rW, rH, HEAD_PROCESSOR, HEAD_MODEL, HEAD_DETECTION_FULL_PRECISION, DEVICE, log_item):
    """Run the head model on one region of interest and collect head/face boxes.

    Args:
        roi_rgb: Image of the region, passed to ``HEAD_PROCESSOR`` as ``images``.
        rx1, ry1: Offset of the region; added to every box so results are
            expressed in the coordinate system the region was cut from.
        rW, rH: Region width and height, used to scale the model's boxes.
        HEAD_PROCESSOR: Callable pre-processor; invoked with
            ``return_tensors="pt"``, ``do_resize=False``, ``do_normalize=True``.
        HEAD_MODEL: Model whose output exposes ``logits`` and ``pred_boxes``.
        HEAD_DETECTION_FULL_PRECISION: When falsy and the model dtype is
            ``torch.float16``, float32 inputs are cast to half precision.
        DEVICE: Device the processed inputs are moved to.
        log_item: Dict that receives ``"warnings"`` (appended) and
            ``"traceback"`` (overwritten) when detection fails.

    Returns:
        ``(boxes, labels, scores, raw_labels)`` -- four parallel lists.
        ``boxes`` holds ``[x1, y1, x2, y2]`` ints clamped to the region,
        ``labels`` is always ``9999`` per detection, ``scores`` the softmax
        score and ``raw_labels`` the model's lower-cased label name.
        Detections gathered before an error are still returned.

    Note:
        Any exception is logged and swallowed, except that an error message
        containing ``"CUDA must not be initialized"`` terminates the process
        with ``sys.exit(1)``.
    """
    boxes = []
    labels = []
    scores = []
    raw_labels = []
    try:
        hd_in = HEAD_PROCESSOR(
            images=roi_rgb,
            return_tensors="pt",
            do_resize=False,
            do_normalize=True
        ).to(DEVICE)
        if not HEAD_DETECTION_FULL_PRECISION and HEAD_MODEL.dtype == torch.float16:
            # Match the model's half-precision weights; non-float32 tensors
            # are passed through unchanged.
            hd_in = {k: v.half() if v.dtype == torch.float32 else v for k, v in hd_in.items()}
        with torch.no_grad():
            hd_out = HEAD_MODEL(**hd_in)
        # Only the first batch element is used (a single ROI is processed).
        hd_logits = hd_out.logits[0]
        hd_boxes = hd_out.pred_boxes[0]

        # With a single logit column there is nothing left once the last
        # class is dropped below, so skip post-processing entirely.
        if hd_logits.size(-1) > 1:
            # The last class column is excluded from scoring.
            # NOTE(review): presumably the "no object" class -- confirm for this model.
            class_scores = torch.softmax(hd_logits, dim=-1)[:, :-1]
            max_scores, max_score_indices = torch.max(class_scores, dim=1)
            above_threshold_indices = torch.where(max_scores >= HEAD_DETECTION_THRESHOLD)[0].cpu().tolist()

            # One host transfer per tensor instead of one .item() per query.
            score_values = max_scores.tolist()
            label_indices = max_score_indices.tolist()
            num_boxes = len(hd_boxes)
            roi_bounds = [rx1, ry1, rx1 + rW, ry1 + rH]

            for i_ in above_threshold_indices:
                score_val = score_values[i_]
                label_name = get_label_name_from_model(HEAD_MODEL, label_indices[i_])
                if label_name not in ("face", "head"):
                    continue
                if i_ >= num_boxes:
                    continue
                box_data = hd_boxes[i_].tolist()
                if len(box_data) < 4:
                    continue
                # Use only the first four values so a longer row cannot break
                # the unpack. Boxes are interpreted as (centre x, centre y,
                # width, height) fractions of the ROI.
                cx, cy, w_, h_ = box_data[:4]
                x1 = int(rx1 + (cx - 0.5 * w_) * rW)
                y1 = int(ry1 + (cy - 0.5 * h_) * rH)
                x2 = int(rx1 + (cx + 0.5 * w_) * rW)
                y2 = int(ry1 + (cy + 0.5 * h_) * rH)
                x1, y1, x2, y2 = clamp_box_to_region([x1, y1, x2, y2], roi_bounds)
                boxes.append([x1, y1, x2, y2])
                labels.append(9999)
                scores.append(score_val)
                raw_labels.append(label_name)
                logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Head detected: {label_name} at score {score_val:.3f}")
    except Exception as e:
        # Best-effort stage: record the failure on log_item and fall through
        # to return whatever was collected so far.
        error_msg = f"Head detection error: {str(e)}"
        error_trace = traceback.format_exc()
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} {error_msg}")
        logging.error(f"Traceback:\n{error_trace}")
        log_item["warnings"] = log_item.get("warnings", []) + [error_msg]
        log_item["traceback"] = error_trace
        if "CUDA must not be initialized" in str(e):
            # Treated as unrecoverable: the whole process is stopped.
            logging.critical("CUDA initialization error in Spaces Zero GPU environment")
            sys.exit(1)
    return boxes, labels, scores, raw_labels