Design_warper / warp_design_on_dress.py
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
from networks import GMM, TOM, load_checkpoint, Options
import torchvision.transforms.functional as TF
def prepare_inputs(dress_path, design_path, height=256, width=192):
"""Prepare and normalize input images"""
transform = transforms.Compose([
transforms.Resize((height, width)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
    dress_img = Image.open(dress_path).convert('RGB')
    design_img = Image.open(design_path)

    # Build the design mask before discarding any alpha channel:
    # a transparent background gives a binary mask, otherwise use a full mask.
    design_arr = np.array(design_img)
    if design_arr.ndim == 3 and design_arr.shape[2] == 4:  # RGBA: use alpha as mask
        mask = (design_arr[:, :, 3] > 0).astype(np.float32)
    else:
        mask = np.ones(design_arr.shape[:2], dtype=np.float32)

    design_img = design_img.convert('RGB')
    dress_tensor = transform(dress_img).unsqueeze(0)
    design_tensor = transform(design_img).unsqueeze(0)

    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    mask_tensor = TF.to_tensor(TF.resize(mask_img, (height, width))).unsqueeze(0)
    return dress_tensor, design_tensor, mask_tensor

def create_agnostic(dress_tensor):
"""Create agnostic representation of dress"""
return dress_tensor.clone()
def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
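    """Warp a flat design onto a dress image with GMM, then blend the result with TOM."""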
os.makedirs(output_dir, exist_ok=True)
# Prepare inputs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dress_tensor, design_tensor, design_mask = prepare_inputs(dress_path, design_path)
agnostic = create_agnostic(dress_tensor)
# Initialize models
opt = Options()
gmm = GMM(opt).to(device)
tom = TOM(opt).to(device)
# Load checkpoints
load_checkpoint(gmm, gmm_ckpt)
load_checkpoint(tom, tom_ckpt)
# Move tensors to device
agnostic = agnostic.to(device)
design_tensor = design_tensor.to(device)
design_mask = design_mask.to(device)
# GMM Processing
with torch.no_grad():
gmm.eval()
grid, _ = gmm(agnostic, design_mask)
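        # Resample the design with the sampling grid predicted by GMM so its
        # texture follows the dress geometry; warp the mask the same way.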
warped_design = F.grid_sample(design_tensor, grid, padding_mode='border', align_corners=True)
warped_mask = F.grid_sample(design_mask, grid, padding_mode='zeros', align_corners=True)
# TOM Processing
with torch.no_grad():
tom.eval()
# Prepare TOM input: [agnostic, warped_design, warped_mask]
tom_input = torch.cat([agnostic, warped_design, warped_mask], dim=1)
p_rendered, m_composite = tom(tom_input)
# Final composition
tryon = warped_design * m_composite + p_rendered * (1 - m_composite)
# Save output
tryon = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
tryon = (tryon * 0.5 + 0.5) * 255
tryon = tryon.clip(0, 255).astype(np.uint8)
output_path = os.path.join(output_dir, "warped_design.jpg")
Image.fromarray(tryon).save(output_path)
return output_path
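

if __name__ == "__main__":
    # Minimal usage sketch. The image and checkpoint paths below are
    # placeholders (assumptions), not files shipped with this repo.
    result = run_design_warp_on_dress(
        dress_path="samples/dress.jpg",
        design_path="samples/design.png",
        gmm_ckpt="checkpoints/gmm_final.pth",
        tom_ckpt="checkpoints/tom_final.pth",
        output_dir="outputs",
    )
    print(f"Saved warped try-on image to {result}")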