from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import (
    SegformerImageProcessor,
    AutoModelForSemanticSegmentation,
    YolosImageProcessor,
    YolosForObjectDetection,
)
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
import cv2
import io
import base64

app = FastAPI(title="Fashion Detection API", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global models, loaded once at startup
seg_processor = None
seg_model = None
obj_processor = None
obj_model = None


@app.on_event("startup")
async def load_models():
    """Load both models on startup."""
    global seg_processor, seg_model, obj_processor, obj_model
    print("Loading models...")
    seg_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    seg_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
    obj_processor = YolosImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
    obj_model = YolosForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")
    print("✅ Models loaded!")


def detect_clothing_items(image):
    """Detect main clothing items using semantic segmentation."""
    # Label ids from the segformer_b2_clothes label map
    MAIN_CLOTHING = {
        4: "Upper-clothes",
        6: "Pants",
        5: "Skirt",
        7: "Dress",
        8: "Belt",
        9: "Left-shoe",
        10: "Right-shoe",
        16: "Bag",
    }

    # Run segmentation
    seg_inputs = seg_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        seg_outputs = seg_model(**seg_inputs)

    # Upsample logits back to the original image resolution before taking argmax
    logits = seg_outputs.logits.cpu()
    upsampled_logits = nn.functional.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    items = []
    for label_id, label_name in MAIN_CLOTHING.items():
        item_mask = (pred_seg == label_id).astype(np.uint8)
        if np.sum(item_mask) < 500:  # skip tiny, likely spurious masks
            continue

        contours, _ = cv2.findContours(item_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue

        for contour in contours:
            if cv2.contourArea(contour) < 500:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            if w < 50 or h < 50:
                continue

            # Pad the bounding box, clamped to the image bounds
            padding = 15
            x1 = max(0, x - padding)
            y1 = max(0, y - padding)
            x2 = min(image.width, x + w + padding)
            y2 = min(image.height, y + h + padding)

            # Crop the item and encode it as a base64 PNG
            cropped = image.crop((x1, y1, x2, y2))
            buffer = io.BytesIO()
            cropped.save(buffer, format="PNG")
            img_base64 = base64.b64encode(buffer.getvalue()).decode()

            items.append({
                "type": label_name,
                # Fill ratio of the contour within its padded box; a rough
                # confidence proxy, since segmentation has no per-item score
                "confidence": round(cv2.contourArea(contour) / ((x2 - x1) * (y2 - y1)), 2),
                "bbox": [x1, y1, x2, y2],
                "image": img_base64,
            })

    return items

def detect_accessories(image):
    """Detect accessories using object detection."""
    accessory_map = {
        "glasses": "Glasses",
        "hat": "Hat",
        "watch": "Watch",
        "scarf": "Scarf",
        "tie": "Tie",
        "glove": "Glove",
    }

    inputs = obj_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = obj_model(**inputs)

    # Rescale boxes to the original image size and keep confident detections
    target_sizes = torch.tensor([image.size[::-1]])
    results = obj_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.5
    )[0]

    items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_name = obj_model.config.id2label[label.item()]
        if label_name not in accessory_map:
            continue

        x1, y1, x2, y2 = [int(coord) for coord in box.tolist()]
        if (x2 - x1) < 30 or (y2 - y1) < 30:  # skip implausibly small boxes
            continue

        # Pad the bounding box, clamped to the image bounds
        padding = 10
        x1 = max(0, x1 - padding)
        y1 = max(0, y1 - padding)
        x2 = min(image.width, x2 + padding)
        y2 = min(image.height, y2 + padding)

        # Crop the item and encode it as a base64 PNG
        cropped = image.crop((x1, y1, x2, y2))
        buffer = io.BytesIO()
        cropped.save(buffer, format="PNG")
        img_base64 = base64.b64encode(buffer.getvalue()).decode()

        items.append({
            "type": accessory_map[label_name],
            "confidence": round(score.item(), 2),
            "bbox": [x1, y1, x2, y2],
            "image": img_base64,
        })

    return items


@app.post("/detect")
async def detect_fashion_items(file: UploadFile = File(...)):
    """Upload an image and get the detected fashion items."""
    # content_type may be None, so guard before calling startswith
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Must be an image file")

    try:
        # Decode the uploaded image
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        # Run both models
        clothing = detect_clothing_items(image)
        accessories = detect_accessories(image)

        return {
            "success": True,
            "total_items": len(clothing) + len(accessories),
            "results": {
                "clothing": clothing,
                "accessories": accessories,
            },
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def root():
    return {"message": "Fashion Detection API", "endpoint": "POST /detect"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
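# --- Usage sketch (illustrative) ---
# A minimal client call against the /detect endpoint, assuming the server is
# running locally on port 8000 and an "outfit.jpg" file exists on disk; the
# `requests` dependency and the filename are assumptions for illustration,
# not part of this module.
#
#   import requests
#
#   with open("outfit.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/detect",
#           files={"file": ("outfit.jpg", f, "image/jpeg")},
#       )
#   data = resp.json()
#   print(data["total_items"], "items detected")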