import os
import torch
import gradio as gr
import tempfile
import gc
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

from promptdresser.models.unet import UNet2DConditionModel
from promptdresser.models.cloth_encoder import ClothEncoder
from promptdresser.pipelines.sdxl import PromptDresser
from lib.caption import generate_caption
from lib.mask import generate_clothing_mask
from lib.pose import generate_openpose

# HF_TOKEN is read from a local .env file (or the environment) and is used both
# for the huggingface_hub login and for downloading the fine-tuned UNet checkpoint.
load_dotenv()
TOKEN = os.getenv("HF_TOKEN")
if TOKEN:
    login(token=TOKEN)

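# Run in float16 on the GPU; fall back to float32 on the CPU.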
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32

def load_models():
    """Load all models required by the try-on pipeline."""
    print("⚙️ Loading models...")
    
    try:
        # All SDXL inpainting components come from the same base repository.
        base_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

        noise_scheduler = DDPMScheduler.from_pretrained(base_repo, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(base_repo, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(base_repo, subfolder="text_encoder")
        tokenizer_2 = CLIPTokenizer.from_pretrained(base_repo, subfolder="tokenizer_2")
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(base_repo, subfolder="text_encoder_2")

        # fp16-safe SDXL VAE.
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        unet = UNet2DConditionModel.from_pretrained(base_repo, subfolder="unet")
        
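        # Overwrite the base inpainting UNet weights with the fine-tuned try-on
        # checkpoint from the Benrise/VITON-HD repository.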
        unet_checkpoint_path = hf_hub_download(
            repo_id="Benrise/VITON-HD",
            filename="VITONHD/model/pytorch_model.bin",
            token=TOKEN
        )
        unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))
        
        cloth_encoder = ClothEncoder.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet"
        )
        
        models = {
            "unet": unet.to(device, dtype=weight_dtype),
            "vae": vae.to(device, dtype=weight_dtype),
            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
            "noise_scheduler": noise_scheduler,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2
        }
        
        pipeline = PromptDresser(
            vae=models["vae"],
            text_encoder=models["text_encoder"],
            text_encoder_2=models["text_encoder_2"],
            tokenizer=models["tokenizer"],
            tokenizer_2=models["tokenizer_2"],
            unet=models["unet"],
            scheduler=models["noise_scheduler"],
        ).to(device, dtype=weight_dtype)
        
        print("✅ Модели успешно загружены")
        return {**models, "pipeline": pipeline}
    
    except Exception as e:
        print(f"❌ Ошибка загрузки моделей: {e}")
        raise

def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
    """Generate a virtual try-on result, cleaning up GPU memory afterwards."""
    if person_image is None or cloth_image is None:
        return None
    try:
        torch.cuda.empty_cache()
        gc.collect()
        
        with tempfile.TemporaryDirectory() as tmp_dir:
            person_path = os.path.join(tmp_dir, "person.png")
            cloth_path = os.path.join(tmp_dir, "cloth.png")
            person_image.save(person_path)
            cloth_image.save(cloth_path)
            
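            # Conditioning inputs: a clothing segmentation mask for the region to
            # repaint and an OpenPose skeleton of the person.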
            mask_image = generate_clothing_mask(person_path, label=label)
            pose_image = generate_openpose(person_path)
            
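            # Fall back to auto-generated captions when the user left a prompt empty.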
            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)
            
            with torch.autocast(device):
                result = pipeline(
                    image=person_image,
                    mask_image=mask_image,
                    pose_image=pose_image,
                    cloth_encoder=models["cloth_encoder"],
                    cloth_encoder_image=cloth_image,
                    prompt=final_outfit_prompt,
                    prompt_clothing=final_clothing_prompt,
                    height=1024,
                    width=768,
                    guidance_scale=2.0,
                    guidance_scale_img=4.5,
                    guidance_scale_text=7.5,
                    num_inference_steps=30,
                    strength=1,
                    interm_cloth_start_ratio=0.5,
                    generator=None,
                ).images[0]
        
        return result
    
    except Exception as e:
        print(f"❌ Ошибка генерации: {e}")
        return None
    finally:
        torch.cuda.empty_cache()
        gc.collect()

print("🔍 Инициализация моделей...")
models = load_models()
pipeline = models["pipeline"]
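
# Illustrative usage (not executed), based on the bundled example images below:
# once the models above are loaded, generate_vton can also be called directly,
# without the Gradio UI:
#
#     from PIL import Image
#     person = Image.open("./test/person2.png").convert("RGB")
#     cloth = Image.open("./test/00008_00.jpg").convert("RGB")
#     out = generate_vton(person, cloth, "man in skirt", "black longsleeve", label=4)
#     if out is not None:
#         out.save("result.png")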

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧥 Virtual Try-On")
    gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")

    clothing_classes = [
        "background", "hat", "hair", "sunglasses", "upper clothes", "skirt", "pants", "dress",
        "belt", "left shoe", "right shoe", "face", "left leg", "right leg",
        "left arm", "right arm", "bag", "scarf"
    ]

    with gr.Row():
        with gr.Column():
            person_input = gr.Image(label="Person photo", type="pil", sources=["upload"])
            cloth_input = gr.Image(label="Garment photo", type="pil", sources=["upload"])
            clothing_label = gr.Dropdown(
                choices=[(f"{i}: {desc}", i) for i, desc in enumerate(clothing_classes)],
                label="Clothing class for the mask",
                value=4
            )

            outfit_prompt = gr.Textbox(label="Outfit description (optional)", placeholder="e.g. man in casual outfit")
            clothing_prompt = gr.Textbox(label="Garment description (optional)", placeholder="e.g. red t-shirt with print")

            generate_btn = gr.Button("Generate try-on", variant="primary")

            gr.Examples(
                examples=[
                    ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 4]
                ],
                inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
                label="Примеры для быстрого тестирования"
            )

        with gr.Column():
            output_image = gr.Image(label="Try-on result", interactive=False)

    generate_btn.click(
        fn=generate_vton,
        inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
        outputs=output_image
    )

    gr.Markdown("### Инструкция:")
    gr.Markdown("1. Загрузите четкое фото человека в полный рост\n"
                "2. Загрузите фото одежды на белом фоне\n"
                "3. Выберите тип одежды из выпадающего списка\n"
                "4. При необходимости уточните описание образа или одежды\n"
                "5. Нажмите кнопку 'Сгенерировать примерку'")

if __name__ == "__main__":
    demo.queue(max_size=1).launch(
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
        share=os.getenv("GRADIO_SHARE") == "True"
    )