import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from networks import GMM, TOM, load_checkpoint, Options

def prepare_inputs(dress_path, design_path, height=256, width=192):
    """Load the dress and design images, resize them, and normalize to [-1, 1]."""
    transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dress_img = Image.open(dress_path).convert('RGB')
    # Keep the design in its original mode first: converting to RGB straight away
    # would discard the alpha channel that the mask is derived from.
    design_raw = Image.open(design_path)
    design_img = design_raw.convert('RGB')
    dress_tensor = transform(dress_img).unsqueeze(0)
    design_tensor = transform(design_img).unsqueeze(0)
    # Create mask from the alpha channel when the design has a transparent
    # background; otherwise fall back to an all-ones mask.
    if design_raw.mode == 'RGBA':
        design_arr = np.array(design_raw)
        mask = (design_arr[:, :, 3] > 0).astype(np.float32)
    else:
        mask = np.ones((design_raw.height, design_raw.width), dtype=np.float32)
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    mask_tensor = TF.to_tensor(TF.resize(mask_img, (height, width))).unsqueeze(0)
    return dress_tensor, design_tensor, mask_tensor

def create_agnostic(dress_tensor):
    """Create the agnostic representation of the dress (currently a plain copy)."""
    return dress_tensor.clone()

def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Prepare inputs
    dress_tensor, design_tensor, design_mask = prepare_inputs(dress_path, design_path)
    agnostic = create_agnostic(dress_tensor)

    # Initialize models and load checkpoints
    opt = Options()
    gmm = GMM(opt).to(device)
    tom = TOM(opt).to(device)
    load_checkpoint(gmm, gmm_ckpt)
    load_checkpoint(tom, tom_ckpt)
    gmm.eval()
    tom.eval()

    # Move tensors to device
    agnostic = agnostic.to(device)
    design_tensor = design_tensor.to(device)
    design_mask = design_mask.to(device)

    with torch.no_grad():
        # GMM: predict a warping grid, then warp the design and its mask onto the dress
        grid, _ = gmm(agnostic, design_mask)
        warped_design = F.grid_sample(design_tensor, grid, padding_mode='border', align_corners=True)
        warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros', align_corners=True)

        # TOM: render the try-on image and a composition mask
        # from the concatenated [agnostic, warped_design, warped_mask] input
        tom_input = torch.cat([agnostic, warped_design, warped_mask], dim=1)
        p_rendered, m_composite = tom(tom_input)

        # Final composition: keep the warped design where the mask is confident,
        # fall back to the rendered image elsewhere
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Denormalize from [-1, 1] back to [0, 255] and save
    tryon = tryon.squeeze(0).permute(1, 2, 0).cpu().numpy()
    tryon = (tryon * 0.5 + 0.5) * 255
    tryon = tryon.clip(0, 255).astype(np.uint8)
    output_path = os.path.join(output_dir, "warped_design.jpg")
    Image.fromarray(tryon).save(output_path)
    return output_path
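

# Example invocation: a minimal sketch of how the pipeline above can be driven.
# All paths below are hypothetical placeholders, not files shipped with this repo.
if __name__ == "__main__":
    result = run_design_warp_on_dress(
        dress_path="inputs/dress.jpg",         # hypothetical example path
        design_path="inputs/design.png",       # hypothetical example path
        gmm_ckpt="checkpoints/gmm_final.pth",  # hypothetical example path
        tom_ckpt="checkpoints/tom_final.pth",  # hypothetical example path
        output_dir="outputs",
    )
    print(f"Saved try-on result to {result}")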