Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,253 Bytes
18faf97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
# ----------------------------------------------------------------------
# IMPORTS
# ----------------------------------------------------------------------
import torch
import logging
import spaces
import sys
import traceback
from typing import List, Tuple
from src.utils import LOG_LEVEL_MAP, EMOJI_MAP
# ----------------------------------------------------------------------
# HEAD MODEL CONSTANTS
# ----------------------------------------------------------------------
# Minimum softmax class score a prediction must reach (compared with >=)
# to be considered by detect_head_in_roi.
HEAD_DETECTION_THRESHOLD = 0.2
# ----------------------------------------------------------------------
# MODEL LABEL CONFIGURATION
# ----------------------------------------------------------------------
# Per-model label groups. For "head_model" only "head_list" is populated:
# it maps the group name "head" to the raw label names "head" and "face".
# The remaining groups are intentionally empty placeholders.
MODEL_LABEL_CONFIG = {
    "head_model": {
        "person_list": {},
        "product_type_list": {},
        "head_list": {
            "head": ["head", "face"],
        },
        "shoes_list": {},
        "clothing_features_list": {},
        "artifacts_list": {},
    },
}
# ----------------------------------------------------------------------
# HEAD MODEL HELPER FUNCTIONS
# ----------------------------------------------------------------------
def get_label_name_from_model(model, label_id):
    """Return the lower-cased label name that *model* associates with *label_id*.

    The label table is taken from ``model.config.id2label`` when that
    attribute exists; otherwise from ``model.model_labels`` when it is a
    dict. If the chosen table has no entry for *label_id*, or no table is
    available at all, the placeholder ``"unknown_<label_id>"`` is returned.
    """
    placeholder = f"unknown_{label_id}"

    if hasattr(model, "config") and hasattr(model.config, "id2label"):
        label_table = model.config.id2label
    elif isinstance(getattr(model, "model_labels", None), dict):
        label_table = model.model_labels
    else:
        # No label table at all: the placeholder is returned as-is
        # (not lower-cased), unlike the table-lookup path below.
        return placeholder

    return label_table.get(label_id, placeholder).lower()
def clamp_box_to_region(box: List[int], region: List[int]) -> List[int]:
    """Clamp each coordinate of *box* into the bounds given by *region*.

    Both arguments are ``[x1, y1, x2, y2]``. Every x coordinate of the box
    is limited to ``[region x1, region x2]`` and every y coordinate to
    ``[region y1, region y2]``. A new list is returned; the inputs are not
    modified.
    """
    bx1, by1, bx2, by2 = box
    left, top, right, bottom = region

    def _bound(value, low, high):
        # Same evaluation order as max(low, min(value, high)).
        return max(low, min(value, high))

    return [
        _bound(bx1, left, right),
        _bound(by1, top, bottom),
        _bound(bx2, left, right),
        _bound(by2, top, bottom),
    ]
# ----------------------------------------------------------------------
# HEAD DETECTION FUNCTIONS
# ----------------------------------------------------------------------
def detect_head_in_roi(roi_rgb, rx1, ry1, rW, rH, HEAD_PROCESSOR, HEAD_MODEL, HEAD_DETECTION_FULL_PRECISION, DEVICE, log_item):
    """Run the head detector on one region of interest and collect head/face boxes.

    Args:
        roi_rgb: Image handed to ``HEAD_PROCESSOR`` as ``images``
            (presumably the RGB crop of the region -- confirm with caller).
        rx1, ry1: Offset added to every predicted box coordinate.
        rW, rH: Factors the normalised box values are multiplied by; also
            used as the extent of the clamping region
            ``[rx1, ry1, rx1 + rW, ry1 + rH]``.
        HEAD_PROCESSOR: Callable pre-processor; its result must support
            ``.to(device)`` and ``.items()``.
        HEAD_MODEL: Model whose output exposes ``logits`` and ``pred_boxes``.
        HEAD_DETECTION_FULL_PRECISION: When falsy and ``HEAD_MODEL.dtype`` is
            ``torch.float16``, float32 inputs are cast to half precision.
        DEVICE: Device the processed inputs are moved to.
        log_item: Dict that receives ``"warnings"`` and ``"traceback"``
            entries if detection fails.

    Returns:
        ``(boxes, labels, scores, raw_labels)`` -- four parallel lists.
        ``boxes`` holds ``[x1, y1, x2, y2]`` ints clamped to the region,
        ``labels`` holds the constant ``9999`` per detection, ``scores`` the
        winning class score and ``raw_labels`` the label name.

    Any exception raised during detection is logged, recorded in
    ``log_item`` and swallowed (the lists collected so far are returned),
    except that a "CUDA must not be initialized" error terminates the
    process via ``sys.exit(1)``.
    """
    boxes = []
    labels = []
    scores = []
    raw_labels = []
    try:
        hd_in = HEAD_PROCESSOR(
            images=roi_rgb,
            return_tensors="pt",
            do_resize=False,
            do_normalize=True
        ).to(DEVICE)
        if not HEAD_DETECTION_FULL_PRECISION and HEAD_MODEL.dtype == torch.float16:
            # Match the model's half precision; leave non-float32 tensors alone.
            hd_in = {k: v.half() if v.dtype == torch.float32 else v for k, v in hd_in.items()}
        with torch.no_grad():
            hd_out = HEAD_MODEL(**hd_in)
        hd_logits = hd_out.logits[0]
        hd_boxes = hd_out.pred_boxes[0]

        # More than one class column is required because the last column is
        # dropped before taking the per-prediction maximum.
        if hd_logits.size(-1) > 1:
            class_scores = torch.softmax(hd_logits, dim=-1)[:, :-1]
            max_scores, max_score_indices = torch.max(class_scores, dim=1)
            above_threshold_indices = torch.where(max_scores >= HEAD_DETECTION_THRESHOLD)[0].cpu().tolist()
            # Loop-invariant clamping bounds for this region.
            region = [rx1, ry1, rx1 + rW, ry1 + rH]
            for i_ in above_threshold_indices:
                label_idx = max_score_indices[i_].item()
                label_name = get_label_name_from_model(HEAD_MODEL, label_idx)
                if label_name not in ("face", "head"):
                    continue
                if i_ >= len(hd_boxes):
                    continue
                box_data = hd_boxes[i_].tolist()
                if len(box_data) < 4:
                    continue
                # Take exactly the first four values so that a longer box
                # vector does not raise on unpacking and abort the loop.
                cx, cy, w_, h_ = box_data[:4]
                score_val = max_scores[i_].item()
                # Convert centre/size box values to corner coordinates,
                # scaled by the region size and shifted by its offset.
                x1 = int(rx1 + (cx - 0.5 * w_) * rW)
                y1 = int(ry1 + (cy - 0.5 * h_) * rH)
                x2 = int(rx1 + (cx + 0.5 * w_) * rW)
                y2 = int(ry1 + (cy + 0.5 * h_) * rH)
                x1, y1, x2, y2 = clamp_box_to_region([x1, y1, x2, y2], region)
                boxes.append([x1, y1, x2, y2])
                labels.append(9999)
                scores.append(score_val)
                raw_labels.append(label_name)
                logging.log(LOG_LEVEL_MAP["INFO"], f"{EMOJI_MAP['INFO']} Head detected: {label_name} at score {score_val:.3f}")
    except Exception as e:
        # Best-effort: record the failure and return whatever was collected.
        error_msg = f"Head detection error: {str(e)}"
        error_trace = traceback.format_exc()
        logging.log(LOG_LEVEL_MAP["WARNING"], f"{EMOJI_MAP['WARNING']} {error_msg}")
        logging.error(f"Traceback:\n{error_trace}")
        log_item["warnings"] = log_item.get("warnings", []) + [error_msg]
        log_item["traceback"] = error_trace
        if "CUDA must not be initialized" in str(e):
            # This specific error is treated as fatal for the whole process.
            logging.critical("CUDA initialization error in Spaces Zero GPU environment")
            sys.exit(1)
    return boxes, labels, scores, raw_labels
|