# VITON-HD / app.py
import os
import torch
import gradio as gr
import tempfile
import gc
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from promptdresser.models.unet import UNet2DConditionModel
from promptdresser.models.cloth_encoder import ClothEncoder
from promptdresser.pipelines.sdxl import PromptDresser
from lib.caption import generate_caption
from lib.mask import generate_clothing_mask
from lib.pose import generate_openpose
load_dotenv()
TOKEN = os.getenv("HF_TOKEN")
if TOKEN:
    login(token=TOKEN)
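# Use fp16 on GPU to keep the SDXL components within memory; fall back to fp32 on CPU.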
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32
def load_models():
"""Загружает все необходимые модели"""
print("⚙️ Загрузка моделей...")
try:
noise_scheduler = DDPMScheduler.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
subfolder="scheduler"
)
tokenizer = CLIPTokenizer.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
subfolder="text_encoder"
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
subfolder="tokenizer_2"
)
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
subfolder="text_encoder_2"
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
unet = UNet2DConditionModel.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
subfolder="unet"
)
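        # Swap in the fine-tuned PromptDresser VITON-HD weights for the inpainting UNet.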
        unet_checkpoint_path = hf_hub_download(
            repo_id="Benrise/VITON-HD",
            filename="VITONHD/model/pytorch_model.bin",
            token=TOKEN
        )
        unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))
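        # The cloth encoder (initialized from the SDXL base UNet) encodes the garment image for conditioning.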
        cloth_encoder = ClothEncoder.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet"
        )
        models = {
            "unet": unet.to(device, dtype=weight_dtype),
            "vae": vae.to(device, dtype=weight_dtype),
            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
            "noise_scheduler": noise_scheduler,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2
        }
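        # Assemble the PromptDresser SDXL inpainting pipeline from the loaded components.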
        pipeline = PromptDresser(
            vae=models["vae"],
            text_encoder=models["text_encoder"],
            text_encoder_2=models["text_encoder_2"],
            tokenizer=models["tokenizer"],
            tokenizer_2=models["tokenizer_2"],
            unet=models["unet"],
            scheduler=models["noise_scheduler"],
        ).to(device, dtype=weight_dtype)
        print("✅ Models loaded successfully")
        return {**models, "pipeline": pipeline}
    except Exception as e:
        print(f"❌ Failed to load models: {e}")
        raise
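# ``label`` selects the human-parsing class to mask; the default 7 is "dress" in the class list below,
# but the UI passes whichever index is chosen in the dropdown.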
def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
"""Генерация виртуальной примерки с очисткой памяти"""
try:
torch.cuda.empty_cache()
gc.collect()
with tempfile.TemporaryDirectory() as tmp_dir:
person_path = os.path.join(tmp_dir, "person.png")
cloth_path = os.path.join(tmp_dir, "cloth.png")
person_image.save(person_path)
cloth_image.save(cloth_path)
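            # Build the conditioning inputs: a segmentation mask over the target clothing region,
            # an OpenPose skeleton of the person, and captions (auto-generated when no prompt is given).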
            mask_image = generate_clothing_mask(person_path, label=label)
            pose_image = generate_openpose(person_path)
            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)
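            # Inpaint the masked region of the person image, conditioned on the garment image
            # (through the cloth encoder) and the two text prompts.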
            with torch.autocast(device):
                result = pipeline(
                    image=person_image,
                    mask_image=mask_image,
                    pose_image=pose_image,
                    cloth_encoder=models["cloth_encoder"],
                    cloth_encoder_image=cloth_image,
                    prompt=final_outfit_prompt,
                    prompt_clothing=final_clothing_prompt,
                    height=1024,
                    width=768,
                    guidance_scale=2.0,
                    guidance_scale_img=4.5,
                    guidance_scale_text=7.5,
                    num_inference_steps=30,
                    strength=1,
                    interm_cloth_start_ratio=0.5,
                    generator=None,
                ).images[0]
            return result
    except Exception as e:
        print(f"❌ Generation failed: {e}")
        return None
    finally:
        torch.cuda.empty_cache()
        gc.collect()
print("🔍 Инициализация моделей...")
models = load_models()
pipeline = models["pipeline"]
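# Gradio UI: upload a person photo and a garment photo, pick the clothing class to replace, then generate.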
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
gr.Markdown("# 🧥 Virtual Try-On")
gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")
    clothing_classes = [
        "background", "hat", "hair", "glasses", "upper clothes", "skirt", "pants", "dress",
        "belt", "left shoe", "right shoe", "face", "left leg", "right leg",
        "left arm", "right arm", "bag", "scarf"
    ]
    with gr.Row():
        with gr.Column():
            person_input = gr.Image(label="Person photo", type="pil", sources=["upload"])
            cloth_input = gr.Image(label="Garment photo", type="pil", sources=["upload"])
            clothing_label = gr.Dropdown(
                choices=[(f"{i}: {desc}", i) for i, desc in enumerate(clothing_classes)],
                label="Clothing class for the mask",
                value=4
            )
            outfit_prompt = gr.Textbox(label="Outfit description (optional)", placeholder="e.g. man in casual outfit")
            clothing_prompt = gr.Textbox(label="Garment description (optional)", placeholder="e.g. red t-shirt with print")
            generate_btn = gr.Button("Generate try-on", variant="primary")
            gr.Examples(
                examples=[
                    ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 4]
                ],
                inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
                label="Quick-test examples"
            )
        with gr.Column():
            output_image = gr.Image(label="Try-on result", interactive=False)
    generate_btn.click(
        fn=generate_vton,
        inputs=[person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label],
        outputs=output_image
    )
    gr.Markdown("### Instructions:")
    gr.Markdown("1. Upload a clear full-length photo of the person\n"
                "2. Upload a photo of the garment on a white background\n"
                "3. Pick the clothing type from the dropdown\n"
                "4. Optionally refine the outfit or garment description\n"
                "5. Click 'Generate try-on'")
if __name__ == "__main__":
    demo.queue(max_size=1).launch(
        server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
        share=os.getenv("GRADIO_SHARE") == "True"
    )