from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation, YolosImageProcessor, YolosForObjectDetection
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
import cv2
import io
import base64

app = FastAPI(title="Fashion Detection API", version="1.0.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global models, populated once at startup
seg_processor = None
seg_model = None
obj_processor = None
obj_model = None

# Registered as a startup hook so both models are loaded once, before the first request
@app.on_event("startup")
async def load_models():
    """Load both models on startup"""
    global seg_processor, seg_model, obj_processor, obj_model
    print("Loading models...")
    seg_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
    seg_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
    obj_processor = YolosImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
    obj_model = YolosForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")
    print("✅ Models loaded!")

def detect_clothing_items(image):
    """Detect main clothing using segmentation"""
    # Label IDs follow the segformer_b2_clothes id2label mapping
    MAIN_CLOTHING = {4: "Upper-clothes", 6: "Pants", 5: "Skirt", 7: "Dress", 8: "Belt", 9: "Left-shoe", 10: "Right-shoe", 16: "Bag"}

    # Segmentation
    seg_inputs = seg_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        seg_outputs = seg_model(**seg_inputs)

    # Upsample logits to the original resolution; PIL size is (W, H), interpolate wants (H, W)
    logits = seg_outputs.logits.cpu()
    upsampled_logits = nn.functional.interpolate(logits, size=image.size[::-1], mode="bilinear", align_corners=False)
    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

    items = []
    for label_id, label_name in MAIN_CLOTHING.items():
        item_mask = (pred_seg == label_id).astype(np.uint8)
        if np.sum(item_mask) < 500:  # skip labels with too few pixels
            continue

        contours, _ = cv2.findContours(item_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue

        for contour in contours:
            if cv2.contourArea(contour) < 500:
                continue
            x, y, w, h = cv2.boundingRect(contour)
            if w < 50 or h < 50:
                continue

            # Add padding, clamped to the image bounds
            padding = 15
            x1 = max(0, x - padding)
            y1 = max(0, y - padding)
            x2 = min(image.width, x + w + padding)
            y2 = min(image.height, y + h + padding)

            # Crop and convert to base64
            cropped = image.crop((x1, y1, x2, y2))
            buffer = io.BytesIO()
            cropped.save(buffer, format="PNG")
            img_base64 = base64.b64encode(buffer.getvalue()).decode()

            items.append({
                "type": label_name,
                # Fill ratio of the contour within its padded box, used as a confidence proxy
                "confidence": round(cv2.contourArea(contour) / ((x2 - x1) * (y2 - y1)), 2),
                "bbox": [x1, y1, x2, y2],
                "image": img_base64
            })
    return items

def detect_accessories(image):
    """Detect accessories using object detection"""
    # Map Fashionpedia label names to display names
    accessory_map = {'glasses': 'Glasses', 'hat': 'Hat', 'watch': 'Watch', 'scarf': 'Scarf', 'tie': 'Tie', 'glove': 'Glove'}

    inputs = obj_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = obj_model(**inputs)

    # Post-process to absolute pixel boxes; target size is (H, W)
    target_sizes = torch.tensor([image.size[::-1]])
    results = obj_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]

    items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_name = obj_model.config.id2label[label.item()]
        if label_name in accessory_map:
            x1, y1, x2, y2 = [int(coord) for coord in box.tolist()]
            if (x2 - x1) < 30 or (y2 - y1) < 30:  # skip tiny detections
                continue

            # Add padding, clamped to the image bounds
            padding = 10
            x1 = max(0, x1 - padding)
            y1 = max(0, y1 - padding)
            x2 = min(image.width, x2 + padding)
            y2 = min(image.height, y2 + padding)

            # Crop and convert to base64
            cropped = image.crop((x1, y1, x2, y2))
            buffer = io.BytesIO()
            cropped.save(buffer, format="PNG")
            img_base64 = base64.b64encode(buffer.getvalue()).decode()

            items.append({
                "type": accessory_map[label_name],
                "confidence": round(score.item(), 2),
                "bbox": [x1, y1, x2, y2],
                "image": img_base64
            })
    return items
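
# A minimal local test of the two detectors without the HTTP layer (a sketch;
# assumes load_models() is run manually first, and "outfit.jpg" is a placeholder path):
#
#   import asyncio
#   asyncio.run(load_models())
#   img = Image.open("outfit.jpg").convert("RGB")
#   print([item["type"] for item in detect_clothing_items(img)])
#   print([item["type"] for item in detect_accessories(img)])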

# Exposed as POST /detect (the path the root endpoint advertises)
@app.post("/detect")
async def detect_fashion_items(file: UploadFile = File(...)):
    """Upload image and get detected fashion items"""
    if not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Must be an image file")
    try:
        # Read and decode the upload
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        # Run both models
        clothing = detect_clothing_items(image)
        accessories = detect_accessories(image)

        return {
            "success": True,
            "total_items": len(clothing) + len(accessories),
            "results": {
                "clothing": clothing,
                "accessories": accessories
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/")
async def root():
    return {"message": "Fashion Detection API", "endpoint": "POST /detect"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
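
# --- Example client (a sketch, not part of the API) ---
# Assumes the server is running locally on port 8000; "outfit.jpg" is a
# placeholder file name and `requests` is an extra dependency.
#
#   import base64, requests
#
#   with open("outfit.jpg", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/detect",
#           files={"file": ("outfit.jpg", f, "image/jpeg")},
#       )
#   data = resp.json()
#   print(data["total_items"])
#   # Each item carries a base64-encoded PNG crop of the detection:
#   for item in data["results"]["clothing"] + data["results"]["accessories"]:
#       with open(f"{item['type']}.png", "wb") as out:
#           out.write(base64.b64decode(item["image"]))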