# UVIS / models/detection/detector.py
# Last updated by DurgaDeepak in commit 03c9511 (verified):
#   "Update models/detection/detector.py"
import os
import logging
import random
from PIL import Image, ImageDraw, ImageFont
from huggingface_hub import hf_hub_download
# Logger for this module; named after the module so log output can be
# filtered/configured per module via the logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
class ObjectDetector:
    """Wrapper around an Ultralytics YOLO model whose weights come from the Hugging Face Hub.

    Weights are fetched with ``hf_hub_download`` when the detector is
    constructed. The YOLO model itself is created lazily by ``load_model()``,
    which ``predict()`` calls automatically.
    """

    # Supported model keys -> (Hugging Face repo id, weights filename).
    # NOTE(review): repo ids / filenames are kept exactly as in the original
    # code; confirm they exist on the Hub (notably the YOLO11 entry).
    _HF_WEIGHTS = {
        "yolov8n": ("ultralytics/yolov8", "yolov8n.pt"),
        "yolov8s": ("ultralytics/yolov8", "yolov8s.pt"),
        "yolov8l": ("ultralytics/yolov8", "yolov8l.pt"),
        "yolov11b": ("Ultralytics/YOLO11", "yolov11b.pt"),
    }

    # Cache directory handed to hf_hub_download.
    _WEIGHTS_CACHE_DIR = "models/detection/weights"

    def __init__(self, model_key="yolov8n", device="cpu"):
        """Resolve ``model_key`` and fetch its weights from the Hugging Face Hub.

        Args:
            model_key: One of the keys of ``_HF_WEIGHTS``. It is lower-cased and
                every ".pt" substring is removed (e.g. "YOLOv8n.pt" -> "yolov8n").
            device: "cpu", or a CUDA device string such as "cuda" or "cuda:0".

        Raises:
            ValueError: If ``model_key`` is not a supported key.
        """
        self.device = device
        self.model = None
        self.model_key = model_key.lower().replace(".pt", "")

        if self.model_key not in self._HF_WEIGHTS:
            raise ValueError(f"Unsupported model key: {self.model_key}")

        repo_id, filename = self._HF_WEIGHTS[self.model_key]
        # force_download=False lets hf_hub_download reuse a cached copy.
        self.weights_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            cache_dir=self._WEIGHTS_CACHE_DIR,
            force_download=False,
        )

    def load_model(self):
        """Create the YOLO model from ``self.weights_path`` if not already loaded.

        Returns:
            ObjectDetector: ``self``, so calls can be chained.
        """
        if self.model is not None:
            # Already loaded; predict() calls this on every invocation, so
            # return early (and quietly) instead of logging each time.
            return self

        logger.info("Loading model from path: %s", self.weights_path)
        import torch  # deferred: only needed once a model is actually loaded
        from ultralytics import YOLO  # deferred import

        if self.device == "cpu":
            # NOTE(review): this hides every GPU from the whole process, not
            # just from this model. It is kept for compatibility with the
            # original behaviour.
            os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

        self.model = YOLO(self.weights_path)

        # Accept "cuda" as well as indexed devices like "cuda:1". Move only
        # when CUDA is actually available.
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            self.model.to(self.device)
        return self

    def predict(self, image: Image.Image, conf_threshold=0.25):
        """Run object detection on ``image``.

        Args:
            image (PIL.Image.Image): Input image.
            conf_threshold (float): Minimum confidence for a detection to be
                returned.

        Returns:
            List[Dict]: One dict per detection with keys "class_name" (str),
            "confidence" (float) and "bbox" (the first row of ``box.xyxy`` as
            a list, i.e. ``[x1, y1, x2, y2]``).

        Raises:
            RuntimeError: If the model could not be loaded.
        """
        self.load_model()
        if self.model is None:
            raise RuntimeError("YOLO model not loaded. Call load_model() first.")

        # The threshold was previously accepted but ignored. Pass it to the
        # model (Ultralytics `conf` inference argument) and also filter here,
        # so the contract holds regardless of the backend's own default.
        results = self.model(image, conf=conf_threshold)

        detections = []
        for result in results:
            for box in result.boxes:
                confidence = float(box.conf)
                if confidence < conf_threshold:
                    continue
                detections.append({
                    "class_name": result.names[int(box.cls)],
                    "confidence": confidence,
                    "bbox": box.xyxy[0].tolist(),
                })
        return detections

    def draw(self, image: Image.Image, detections, alpha=0.5):
        """
        Draws thicker, per-class-colored bounding boxes and labels.

        The input image is not modified. Annotations are drawn on a copy, which
        is then blended with the original.

        Args:
            image (PIL.Image.Image): Original image. Images that are not RGB or
                RGBA are converted to RGB first.
            detections (List[Dict]): Each dict has "bbox", "class_name", "confidence".
            alpha (float): Blend strength for overlay, as passed to
                ``Image.blend`` (0.0 = original only, 1.0 = overlay only).

        Returns:
            PIL.Image.Image: Blended image with overlays.
        """
        # RGB tuple colours need a true-colour image, and Image.blend rejects
        # palette images, so normalise other modes to RGB.
        if image.mode not in ("RGB", "RGBA"):
            image = image.convert("RGB")

        overlay = image.copy()
        painter = ImageDraw.Draw(overlay)

        # Try a TrueType font. Fall back to Pillow's default font when the file
        # cannot be opened (OSError) or FreeType support is missing (ImportError).
        try:
            font = ImageFont.truetype("arial.ttf", 18)
        except (OSError, ImportError):
            font = ImageFont.load_default()

        border = 4  # outline thickness in pixels, grown outward from the bbox
        pad = 3     # horizontal padding on each side of the label text

        class_colors = {}

        def get_color(cls_name):
            # A private RNG seeded with the class name yields the same colour
            # for that class on every run. The 100-255 range presumably keeps
            # colours light enough for the black label text.
            if cls_name not in class_colors:
                rnd = random.Random(cls_name)
                class_colors[cls_name] = (
                    rnd.randint(100, 255),
                    rnd.randint(100, 255),
                    rnd.randint(100, 255),
                )
            return class_colors[cls_name]

        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            cls_name = det["class_name"]
            label = f"{cls_name} {det['confidence']:.2f}"
            color = get_color(cls_name)

            # Thick box. Pillow's `width` grows inward, so expand the rectangle
            # by (border - 1) to keep the bbox itself as the inner edge.
            grow = border - 1
            painter.rectangle(
                (x1 - grow, y1 - grow, x2 + grow, y2 + grow),
                outline=color,
                width=border,
            )

            # Measure at the origin so the glyph offset (left, top) can be
            # compensated when placing the text inside its background.
            left, top, right, bottom = painter.textbbox((0, 0), label, font=font)
            text_w = right - left
            text_h = bottom - top

            # Label sits above the box. If that would leave the image, place
            # it just inside the top edge of the box instead.
            label_top = y1 - text_h if y1 - text_h >= 0 else y1
            painter.rectangle(
                (x1, label_top, x1 + text_w + 2 * pad, label_top + text_h),
                fill=color,
            )
            painter.text(
                (x1 + pad - left, label_top - top),
                label,
                font=font,
                fill="black",
            )

        return Image.blend(image, overlay, alpha)