|
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation |
|
from PIL import Image |
|
import numpy as np |
|
import requests |
|
import torch.nn.functional as F |
|
import torch |
|
import os |
|
|
|
def generate_clothing_mask( |
|
image_path: str, |
|
label: int, |
|
output_path: str = "./test/output_mask.png", |
|
model_name: str = "mattmdjaga/segformer_b2_clothes", |
|
) -> Image.Image: |
|
""" |
|
Генерирует бинарную маску для указанного класса одежды и сохраняет её |
|
|
|
Args: |
|
image_path: Путь к изображению или URL |
|
label: Класс для сегментации (0-17) |
|
output_path: Путь для сохранения маски |
|
model_name: Название модели HuggingFace |
|
show_result: Показать результат matplotlib |
|
|
|
Returns: |
|
PIL.Image: Бинарная маска (белый - выбранный класс, черный - остальное) |
|
""" |
|
|
|
processor = SegformerImageProcessor.from_pretrained(model_name) |
|
model = AutoModelForSemanticSegmentation.from_pretrained(model_name) |
|
|
|
if image_path.startswith(('http://', 'https://')): |
|
image = Image.open(requests.get(image_path, stream=True).raw) |
|
else: |
|
image = Image.open(image_path) |
|
|
|
if image.mode != 'RGB': |
|
image = image.convert('RGB') |
|
|
|
image_np = np.array(image) |
|
if len(image_np.shape) != 3 or image_np.shape[2] != 3: |
|
raise ValueError("Изображение должно быть в формате RGB (H, W, 3)") |
|
|
|
inputs = processor(images=image, return_tensors="pt") |
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
|
|
logits = outputs.logits |
|
upsampled_logits = F.interpolate( |
|
logits, |
|
size=image.size[::-1], |
|
mode="bilinear", |
|
align_corners=False, |
|
) |
|
|
|
pred_seg = upsampled_logits.argmax(dim=1)[0] |
|
mask = (pred_seg == label).numpy().astype('uint8') * 255 |
|
mask_image = Image.fromarray(mask) |
|
|
|
os.makedirs(os.path.dirname(output_path), exist_ok=True) |
|
mask_image.save(output_path) |
|
|
|
return mask_image |
|
|