# Virtual try-on Gradio demo built on the PromptDresser SDXL pipeline.
# NOTE: the file-viewer header that preceded the code (file size, per-line commit hashes, line-number gutter) was extraction residue, not source.
import os
import torch
import gradio as gr
import tempfile
import gc
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, login
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from promptdresser.models.unet import UNet2DConditionModel
from promptdresser.models.cloth_encoder import ClothEncoder
from promptdresser.pipelines.sdxl import PromptDresser
from lib.caption import generate_caption
from lib.mask import generate_clothing_mask
from lib.pose import generate_openpose
# Pull variables from a local .env file (if present) into the environment.
load_dotenv()
TOKEN = os.getenv("HF_TOKEN")
# Authenticate only when a token is configured: login(token=None) falls back
# to an interactive prompt, which cannot be answered in a headless deployment.
if TOKEN:
    login(token=TOKEN)
# Half precision is used on GPU only; CPU execution stays in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.float16 if device == "cuda" else torch.float32
def load_models():
    """Load every model component and assemble the try-on pipeline.

    The scheduler, both tokenizers, both text encoders and the UNet are read
    from the SDXL inpainting repository, the VAE from
    ``madebyollin/sdxl-vae-fp16-fix``, and the cloth encoder from the
    ``unet`` subfolder of the SDXL base repository. The UNet weights are then
    overwritten with the ``Benrise/VITON-HD`` checkpoint.

    Returns:
        dict: the components under the keys ``"unet"``, ``"vae"``,
        ``"text_encoder"``, ``"text_encoder_2"``, ``"cloth_encoder"``,
        ``"noise_scheduler"``, ``"tokenizer"``, ``"tokenizer_2"`` plus the
        assembled ``"pipeline"``.

    Raises:
        Exception: any failure during loading is printed and re-raised
        unchanged.
    """
    base_repo = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    print("⚙️ Загрузка моделей...")
    try:
        noise_scheduler = DDPMScheduler.from_pretrained(base_repo, subfolder="scheduler")
        tokenizer = CLIPTokenizer.from_pretrained(base_repo, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(base_repo, subfolder="text_encoder")
        tokenizer_2 = CLIPTokenizer.from_pretrained(base_repo, subfolder="tokenizer_2")
        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
            base_repo,
            subfolder="text_encoder_2",
        )
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
        unet = UNet2DConditionModel.from_pretrained(base_repo, subfolder="unet")

        unet_checkpoint_path = hf_hub_download(
            repo_id="Benrise/VITON-HD",
            filename="VITONHD/model/pytorch_model.bin",
            token=TOKEN,
        )
        # The UNet is still on the CPU here, so deserialize the checkpoint to
        # the CPU as well; mapping it to the GPU would only create a second,
        # temporary copy of all weights in device memory.
        # NOTE(security): torch.load unpickles the file, and the checkpoint is
        # downloaded from a third-party repository — consider
        # weights_only=True once the minimum supported torch version is known.
        state_dict = torch.load(unet_checkpoint_path, map_location="cpu")
        unet.load_state_dict(state_dict)
        del state_dict  # release the extra copy before moving models to the device

        cloth_encoder = ClothEncoder.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="unet",
        )

        # Move every trainable component to the target device and dtype.
        models = {
            "unet": unet.to(device, dtype=weight_dtype),
            "vae": vae.to(device, dtype=weight_dtype),
            "text_encoder": text_encoder.to(device, dtype=weight_dtype),
            "text_encoder_2": text_encoder_2.to(device, dtype=weight_dtype),
            "cloth_encoder": cloth_encoder.to(device, dtype=weight_dtype),
            "noise_scheduler": noise_scheduler,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
        }

        # The cloth encoder is not part of the pipeline object; it is passed
        # to the pipeline per call by generate_vton.
        pipeline = PromptDresser(
            vae=models["vae"],
            text_encoder=models["text_encoder"],
            text_encoder_2=models["text_encoder_2"],
            tokenizer=models["tokenizer"],
            tokenizer_2=models["tokenizer_2"],
            unet=models["unet"],
            scheduler=models["noise_scheduler"],
        ).to(device, dtype=weight_dtype)

        print("✅ Модели успешно загружены")
        return {**models, "pipeline": pipeline}
    except Exception as e:
        print(f"❌ Ошибка загрузки моделей: {e}")
        raise
def generate_vton(person_image, cloth_image, outfit_prompt="", clothing_prompt="", label=7):
    """Run one virtual try-on generation, freeing cached memory around it.

    Args:
        person_image: PIL image of the person (the UI supplies ``type="pil"``).
        cloth_image: PIL image of the garment.
        outfit_prompt: Optional description of the outfit; when empty, a
            caption is generated from the person image.
        clothing_prompt: Optional description of the garment; when empty, a
            caption is generated from the garment image.
        label: Class index forwarded to ``generate_clothing_mask``.

    Returns:
        The first image produced by the pipeline, or ``None`` when an input
        image is missing or generation fails (the error is printed).
    """
    # Without this guard a missing upload surfaces as an opaque
    # "'NoneType' object has no attribute 'save'" message.
    if person_image is None or cloth_image is None:
        print("❌ Ошибка генерации: не загружено фото человека или одежды")
        return None

    try:
        torch.cuda.empty_cache()
        gc.collect()
        # The mask, pose and caption helpers take file paths, so both images
        # are written to a temporary directory first.
        with tempfile.TemporaryDirectory() as tmp_dir:
            person_path = os.path.join(tmp_dir, "person.png")
            cloth_path = os.path.join(tmp_dir, "cloth.png")
            person_image.save(person_path)
            cloth_image.save(cloth_path)

            mask_image = generate_clothing_mask(person_path, label=label)
            pose_image = generate_openpose(person_path)

            # User-provided prompts win; captions are generated only as a fallback.
            final_outfit_prompt = outfit_prompt or generate_caption(person_path, device)
            final_clothing_prompt = clothing_prompt or generate_caption(cloth_path, device)

            with torch.autocast(device):
                result = pipeline(
                    image=person_image,
                    mask_image=mask_image,
                    pose_image=pose_image,
                    cloth_encoder=models["cloth_encoder"],
                    cloth_encoder_image=cloth_image,
                    prompt=final_outfit_prompt,
                    prompt_clothing=final_clothing_prompt,
                    height=1024,
                    width=768,
                    guidance_scale=2.0,
                    guidance_scale_img=4.5,
                    guidance_scale_text=7.5,
                    num_inference_steps=30,
                    strength=1,
                    interm_cloth_start_ratio=0.5,
                    generator=None,
                ).images[0]
            return result
    except Exception as e:
        import traceback

        print(f"❌ Ошибка генерации: {e}")
        # Keep the full stack in the server log; the message alone rarely
        # identifies which stage failed.
        traceback.print_exc()
        return None
    finally:
        torch.cuda.empty_cache()
        gc.collect()
# Load all models once at import time. generate_vton reads the module-level
# names `models` and `pipeline` defined here.
print("🔍 Инициализация моделей...")
models = load_models()
pipeline = models["pipeline"]
# Display names of the classes offered for the mask; a name's position in
# this list is the integer value the dropdown submits.
clothing_classes = [
    "фон", "шляпа", "волосы", "очки", "верхняя одежда", "юбка", "брюки", "платье",
    "ремень", "левая обувь", "правая обувь", "лицо", "левая нога", "правая нога",
    "левая рука", "правая рука", "сумка", "шарф"
]
label_choices = [(f"{idx}: {name}", idx) for idx, name in enumerate(clothing_classes)]

# Numbered usage steps rendered below the form.
usage_steps = "\n".join([
    "1. Загрузите четкое фото человека в полный рост",
    "2. Загрузите фото одежды на белом фоне",
    "3. Выберите тип одежды из выпадающего списка",
    "4. При необходимости уточните описание образа или одежды",
    "5. Нажмите кнопку 'Сгенерировать примерку'",
])

# NOTE(review): the css argument is a bare selector with no declarations —
# looks like a leftover; confirm whether it can be dropped.
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container") as demo:
    gr.Markdown("# 🧥 Virtual Try-On")
    gr.Markdown("Загрузите фото человека и одежды для виртуальной примерки")

    with gr.Row():
        # Left column: inputs, trigger button and a ready-made example.
        with gr.Column():
            person_input = gr.Image(label="Фото человека", type="pil", sources=["upload"])
            cloth_input = gr.Image(label="Фото одежды", type="pil", sources=["upload"])
            clothing_label = gr.Dropdown(
                choices=label_choices,
                label="Класс одежды для маски",
                value=4,
            )
            outfit_prompt = gr.Textbox(
                label="Описание образа (опционально)",
                placeholder="Например: man in casual outfit",
            )
            clothing_prompt = gr.Textbox(
                label="Описание одежды (опционально)",
                placeholder="Например: red t-shirt with print",
            )
            generate_btn = gr.Button("Сгенерировать примерку", variant="primary")

            # Same component order as the positional parameters of generate_vton.
            vton_inputs = [person_input, cloth_input, outfit_prompt, clothing_prompt, clothing_label]

            gr.Examples(
                examples=[
                    ["./test/person2.png", "./test/00008_00.jpg", "man in skirt", "black longsleeve", 4]
                ],
                inputs=vton_inputs,
                label="Примеры для быстрого тестирования",
            )

        # Right column: read-only result.
        with gr.Column():
            output_image = gr.Image(label="Результат примерки", interactive=False)

    generate_btn.click(fn=generate_vton, inputs=vton_inputs, outputs=output_image)

    gr.Markdown("### Инструкция:")
    gr.Markdown(usage_steps)
if __name__ == "__main__":
    # Listen on all interfaces only when SPACE_ID is set (presumably by the
    # hosting Space — confirm); otherwise keep Gradio's default host.
    server_name = "0.0.0.0" if os.getenv("SPACE_ID") else None
    # Accept the usual truthy spellings; the previous exact match on "True"
    # silently ignored values such as "true" or "1".
    share = os.getenv("GRADIO_SHARE", "").strip().lower() in {"1", "true", "yes", "on"}
    # max_size=1 caps the request queue at a single pending job.
    demo.queue(max_size=1).launch(server_name=server_name, share=share)