import gc
import os
import tempfile

import gradio as gr
import torch
from diffusers import AutoencoderKL, DDPMScheduler
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from promptdresser.models.unet import UNet2DConditionModel
from promptdresser.models.cloth_encoder import ClothEncoder
from promptdresser.pipelines.sdxl import PromptDresser
from lib.caption import generate_caption
from lib.mask import generate_clothing_mask
from lib.pose import generate_openpose

# Read HF_TOKEN from .env and authenticate so the checkpoint below can be downloaded.
load_dotenv()
TOKEN = os.getenv("HF_TOKEN")
login(token=TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32
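
# float16 on CUDA roughly halves the memory footprint of the SDXL-sized models
# loaded below; CPU inference stays in float32, where fp16 support is limited.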

INPAINT_MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"


def load_models():
    """Load all required models and assemble the PromptDresser pipeline."""
    print("⚙️ Loading models...")

    try:
        noise_scheduler = DDPMScheduler.from_pretrained(INPAINT_MODEL_ID, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(INPAINT_MODEL_ID, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(INPAINT_MODEL_ID, subfolder="text_encoder")
        tokenizer_2 = CLIPTokenizer.from_pretrained(INPAINT_MODEL_ID, subfolder="tokenizer_2")
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            INPAINT_MODEL_ID, subfolder="text_encoder_2"
        )

        # The stock SDXL VAE overflows in float16; this drop-in replacement is fp16-safe.
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        unet = UNet2DConditionModel.from_pretrained(INPAINT_MODEL_ID, subfolder="unet")

        # Swap the stock inpainting UNet weights for the VITON-HD checkpoint.
        unet_checkpoint_path = hf_hub_download(
            repo_id="Benrise/VITON-HD",
            filename="VITONHD/model/pytorch_model.bin",
            token=TOKEN,
        )
        unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))

        # The cloth encoder reuses the SDXL base UNet weights.
        cloth_encoder = ClothEncoder.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet",
        )

        models = {
            "unet": unet.to(device, dtype=weight_dtype),
            "vae": vae.to(device, dtype=weight_dtype),
            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
            "noise_scheduler": noise_scheduler,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
        }

        pipeline = PromptDresser(
            vae=models["vae"],
            text_encoder=models["text_encoder"],
            text_encoder_2=models["text_encoder_2"],
            tokenizer=models["tokenizer"],
            tokenizer_2=models["tokenizer_2"],
            unet=models["unet"],
            scheduler=models["noise_scheduler"],
        ).to(device, dtype=weight_dtype)

        print("✅ Models loaded successfully")
        return {**models, "pipeline": pipeline}

    except Exception as e:
        print(f"❌ Failed to load models: {e}")
        raise
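
# Optional memory savers (a sketch, assuming PromptDresser inherits the standard
# diffusers DiffusionPipeline API; verify before enabling):
#
#     pipeline.enable_attention_slicing()  # lower VRAM peaks at some speed cost
#     pipeline.enable_vae_slicing()        # decode latents slice by slice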

def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
    """Run a virtual try-on, freeing GPU memory before and after the call."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        with tempfile.TemporaryDirectory() as tmp_dir:
            person_path = os.path.join(tmp_dir, "person.png")
            cloth_path = os.path.join(tmp_dir, "cloth.png")
            person_image.save(person_path)
            cloth_image.save(cloth_path)

            # Conditioning inputs: a mask over the garment region and an OpenPose map.
            mask_image = generate_clothing_mask(person_path, label=label)
            pose_image = generate_openpose(person_path)

            # Fall back to auto-generated captions when a prompt is left empty.
            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)

            with torch.autocast(device):
                result = pipeline(
                    image=person_image,
                    mask_image=mask_image,
                    pose_image=pose_image,
                    cloth_encoder=models["cloth_encoder"],
                    cloth_encoder_image=cloth_image,
                    prompt=final_outfit_prompt,
                    prompt_clothing=final_clothing_prompt,
                    height=1024,
                    width=768,
                    guidance_scale=2.0,
                    guidance_scale_img=4.5,
                    guidance_scale_text=7.5,
                    num_inference_steps=30,
                    strength=1.0,
                    interm_cloth_start_ratio=0.5,
                    generator=None,
                ).images[0]

        return result

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        return None
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
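
# Minimal smoke test without the UI (a sketch; it reuses the bundled ./test images
# referenced in the Examples block below):
#
#     from PIL import Image
#     out = generate_vton(Image.open("./test/person2.png"),
#                         Image.open("./test/00008_00.jpg"), label=4)
#     if out is not None:
#         out.save("vton_result.png")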

print("🔍 Initializing models...")
models = load_models()
pipeline = models["pipeline"]
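
# Loading happens once at import time, so every request handled by this process
# reuses the same weights instead of re-initializing them per call.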

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧥 Virtual Try-On")
    gr.Markdown("Upload a photo of a person and a photo of a garment for a virtual try-on")

    clothing_classes = [
        "background", "hat", "hair", "glasses", "upper clothes", "skirt", "pants", "dress",
        "belt", "left shoe", "right shoe", "face", "left leg", "right leg",
        "left arm", "right arm", "bag", "scarf"
    ]
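
    # The label order above appears to follow the ATR human-parsing convention
    # (0 = background ... 17 = scarf); generate_clothing_mask is assumed to use
    # the same numbering, with 4 ("upper clothes") as the default below.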

    with gr.Row():
        with gr.Column():
            person_input = gr.Image(label="Person photo", type="pil", sources=["upload"])
            cloth_input = gr.Image(label="Garment photo", type="pil", sources=["upload"])
            clothing_label = gr.Dropdown(
                choices=[(f"{i}: {desc}", i) for i, desc in enumerate(clothing_classes)],
                label="Clothing class for the mask",
                value=4,
            )

            outfit_prompt = gr.Textbox(label="Outfit description (optional)", placeholder="e.g. man in casual outfit")
            clothing_prompt = gr.Textbox(label="Garment description (optional)", placeholder="e.g. red t-shirt with print")

            generate_btn = gr.Button("Generate try-on", variant="primary")

            gr.Examples(
                examples=[
                    ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 4]
                ],
                inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
                label="Examples for quick testing",
            )

        with gr.Column():
            output_image = gr.Image(label="Try-on result", interactive=False)

    generate_btn.click(
        fn=generate_vton,
        inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
        outputs=output_image,
    )

    gr.Markdown("### Instructions:")
    gr.Markdown("1. Upload a clear full-body photo of the person\n"
                "2. Upload a photo of the garment on a white background\n"
                "3. Pick the clothing type from the dropdown\n"
                "4. Optionally refine the outfit or garment description\n"
                "5. Click 'Generate try-on'")

if __name__ == "__main__":
    # Bind to all interfaces when running on a HF Space; GRADIO_SHARE=True
    # enables a public share link. max_size=1 keeps at most one queued request.
    demo.queue(max_size=1).launch(
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
        share=os.getenv("GRADIO_SHARE") == "True",
    )
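
# Run the script directly (e.g. `python app.py`; file name assumed) and open the
# URL Gradio prints, which defaults to http://127.0.0.1:7860.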