File size: 2,199 Bytes
9b63413
 
 
 
 
 
 
 
 
 
 
2835d45
9b63413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import numpy as np
import requests
import torch.nn.functional as F
import torch
import os

def generate_clothing_mask(
    image_path: str,
    label: int,
    output_path: str = "./test/output_mask.png",
    model_name: str = "mattmdjaga/segformer_b2_clothes",
) -> Image.Image:
    """
    Генерирует бинарную маску для указанного класса одежды и сохраняет её
    
    Args:
        image_path: Путь к изображению или URL
        label: Класс для сегментации (0-17)
        output_path: Путь для сохранения маски
        model_name: Название модели HuggingFace
        show_result: Показать результат matplotlib
        
    Returns:
        PIL.Image: Бинарная маска (белый - выбранный класс, черный - остальное)
    """

    processor = SegformerImageProcessor.from_pretrained(model_name)
    model = AutoModelForSemanticSegmentation.from_pretrained(model_name)
    
    if image_path.startswith(('http://', 'https://')):
        image = Image.open(requests.get(image_path, stream=True).raw)
    else:
        image = Image.open(image_path)
    
    if image.mode != 'RGB':
        image = image.convert('RGB')
    
    image_np = np.array(image)
    if len(image_np.shape) != 3 or image_np.shape[2] != 3:
        raise ValueError("Изображение должно быть в формате RGB (H, W, 3)")
    
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    
    logits = outputs.logits
    upsampled_logits = F.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    mask = (pred_seg == label).numpy().astype('uint8') * 255
    mask_image = Image.fromarray(mask)
    
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    mask_image.save(output_path)
    
    return mask_image