import os
import torch
import gradio as gr
import tempfile
import gc
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

from promptdresser.models.unet import UNet2DConditionModel
from promptdresser.models.cloth_encoder import ClothEncoder
from promptdresser.pipelines.sdxl import PromptDresser
from lib.caption import generate_caption
from lib.mask import generate_clothing_mask
from lib.pose import generate_openpose

load_dotenv()
TOKEN = os.getenv("HF_TOKEN")
login(token=TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32


def load_models():
    """Load all required models."""
    print("⚙️ Loading models...")
    try:
        noise_scheduler = DDPMScheduler.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="scheduler"
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="tokenizer"
        )
        text_encoder = CLIPTextModel.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="text_encoder"
        )
        tokenizer_2 = CLIPTokenizer.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="tokenizer_2"
        )
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="text_encoder_2"
        )
        # fp16-safe VAE: the stock SDXL VAE is prone to NaNs in half precision
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        unet = UNet2DConditionModel.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", subfolder="unet"
        )
        # Swap in the PromptDresser UNet weights fine-tuned on VITON-HD
        unet_checkpoint_path = hf_hub_download(
            repo_id="Benrise/VITON-HD",
            filename="VITONHD/model/pytorch_model.bin",
            token=TOKEN,
        )
        unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))
        cloth_encoder = ClothEncoder.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
        )

        models = {
            "unet": unet.to(device, dtype=weight_dtype),
            "vae": vae.to(device, dtype=weight_dtype),
            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
            "noise_scheduler": noise_scheduler,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
        }

        pipeline = PromptDresser(
            vae=models["vae"],
            text_encoder=models["text_encoder"],
            text_encoder_2=models["text_encoder_2"],
            tokenizer=models["tokenizer"],
            tokenizer_2=models["tokenizer_2"],
            unet=models["unet"],
            scheduler=models["noise_scheduler"],
        ).to(device, dtype=weight_dtype)

        print("✅ Models loaded successfully")
        return {**models, "pipeline": pipeline}
    except Exception as e:
        print(f"❌ Failed to load models: {e}")
        raise


def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
    """Generate a virtual try-on, freeing GPU memory before and after."""
    try:
        torch.cuda.empty_cache()
        gc.collect()
        with tempfile.TemporaryDirectory() as tmp_dir:
            person_path = os.path.join(tmp_dir, "person.png")
            cloth_path = os.path.join(tmp_dir, "cloth.png")
            person_image.save(person_path)
            cloth_image.save(cloth_path)

            # Segmentation mask for the chosen clothing class, plus the OpenPose map
            mask_image = generate_clothing_mask(person_path, label=label)
            pose_image = generate_openpose(person_path)

            # Fall back to auto-generated captions when a prompt is left empty
            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)

            with torch.autocast(device):
                result = pipeline(
                    image=person_image,
                    mask_image=mask_image,
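                    # PromptDresser-specific conditioning: the pose map plus the
                    # garment image, routed through the separate ClothEncoder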
                    pose_image=pose_image,
                    cloth_encoder=models["cloth_encoder"],
                    cloth_encoder_image=cloth_image,
                    prompt=final_outfit_prompt,
                    prompt_clothing=final_clothing_prompt,
                    height=1024,
                    width=768,
                    guidance_scale=2.0,
                    guidance_scale_img=4.5,
                    guidance_scale_text=7.5,
                    num_inference_steps=30,
                    strength=1,
                    interm_cloth_start_ratio=0.5,
                    generator=None,
                ).images[0]
        return result
    except Exception as e:
        print(f"❌ Generation failed: {e}")
        return None
    finally:
        torch.cuda.empty_cache()
        gc.collect()


print("🔍 Initializing models...")
models = load_models()
pipeline = models["pipeline"]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧥 Virtual Try-On")
    gr.Markdown("Upload a photo of a person and a garment for a virtual try-on")

    # Human-parsing label set; the dropdown value is the index passed to the mask generator
    clothing_classes = [
        "background", "hat", "hair", "sunglasses", "upper clothes", "skirt",
        "pants", "dress", "belt", "left shoe", "right shoe", "face",
        "left leg", "right leg", "left arm", "right arm", "bag", "scarf",
    ]

    with gr.Row():
        with gr.Column():
            person_input = gr.Image(label="Person photo", type="pil", sources=["upload"])
            cloth_input = gr.Image(label="Garment photo", type="pil", sources=["upload"])
            clothing_label = gr.Dropdown(
                choices=[(f"{i}: {desc}", i) for i, desc in enumerate(clothing_classes)],
                label="Clothing class for the mask",
                value=4,
            )
            outfit_prompt = gr.Textbox(
                label="Outfit description (optional)",
                placeholder="E.g.: man in casual outfit",
            )
            clothing_prompt = gr.Textbox(
                label="Garment description (optional)",
                placeholder="E.g.: red t-shirt with print",
            )
            generate_btn = gr.Button("Generate try-on", variant="primary")
            gr.Examples(
                examples=[
                    ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 4]
                ],
                inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
                label="Examples for quick testing",
            )
        with gr.Column():
            output_image = gr.Image(label="Try-on result", interactive=False)

    generate_btn.click(
        fn=generate_vton,
        inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
        outputs=output_image,
    )

    gr.Markdown("### Instructions:")
    gr.Markdown(
        "1. Upload a clear full-length photo of the person\n"
        "2. Upload a photo of the garment on a white background\n"
        "3. Pick the clothing type from the dropdown\n"
        "4. Refine the outfit or garment description if needed\n"
        "5. Click the 'Generate try-on' button"
    )

if __name__ == "__main__":
    demo.queue(max_size=1).launch(
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
        share=os.getenv("GRADIO_SHARE") == "True",
    )
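
# Minimal local smoke test (a sketch, not part of the app flow; assumes the
# ./test images referenced in the Examples above exist and a GPU with enough
# free memory is available):
#
#   from PIL import Image
#   result = generate_vton(
#       Image.open("./test/person2.png"),
#       Image.open("./test/00008_00.jpg"),
#       outfit_prompt="man in skirt",
#       clothing_prompt="black longsleeve",
#       label=4,
#   )
#   if result is not None:
#       result.save("vton_result.png")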