import os

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from networks import GMM, TOM, load_checkpoint, Options


def prepare_inputs(dress_path, design_path, height=256, width=192):
    """Load the dress and design images and build the design mask.

    Both images are converted to RGB, resized to ``(height, width)`` and
    normalized with mean 0.5 / std 0.5 per channel, which maps pixel values
    into [-1, 1].

    The mask is derived from the design image's transparency: pixels whose
    alpha is greater than 0 become 1.0, everything else 0.0. When the design
    image carries no transparency information the mask is all ones.

    Args:
        dress_path: Path to the dress image file.
        design_path: Path to the design image file.
        height: Target height in pixels.
        width: Target width in pixels.

    Returns:
        A tuple ``(dress_tensor, design_tensor, mask_tensor)``. The two image
        tensors have shape ``(1, 3, height, width)``; the mask has shape
        ``(1, 1, height, width)`` with values in [0, 1].
    """
    transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    with Image.open(dress_path) as dress_file:
        dress_img = dress_file.convert('RGB')

    with Image.open(design_path) as design_file:
        # Transparency must be inspected BEFORE converting to RGB: the
        # conversion drops the alpha channel, which would make the mask
        # unconditionally all ones.
        has_alpha = (
            'A' in design_file.getbands()
            or 'transparency' in design_file.info
        )
        if has_alpha:
            rgba_arr = np.array(design_file.convert('RGBA'))
            mask = (rgba_arr[:, :, 3] > 0).astype(np.float32)
        else:
            mask = np.ones(
                (design_file.height, design_file.width), dtype=np.float32
            )
        design_img = design_file.convert('RGB')

    dress_tensor = transform(dress_img).unsqueeze(0)
    design_tensor = transform(design_img).unsqueeze(0)

    # The mask is built at the source resolution and resized afterwards so it
    # lines up with the resized design tensor. It is deliberately NOT
    # normalized: it stays in [0, 1].
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    mask_tensor = TF.to_tensor(
        TF.resize(mask_img, (height, width))
    ).unsqueeze(0)

    return dress_tensor, design_tensor, mask_tensor


def create_agnostic(dress_tensor):
    """Create the agnostic representation of the dress.

    Currently this is an independent copy of ``dress_tensor``; nothing is
    masked out or removed.
    """
    return dress_tensor.clone()


def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt,
                             output_dir):
    """Warp a design onto a dress with GMM + TOM and save the result.

    Args:
        dress_path: Path to the dress image file.
        design_path: Path to the design image file.
        gmm_ckpt: Checkpoint passed to ``load_checkpoint`` for the GMM model.
        tom_ckpt: Checkpoint passed to ``load_checkpoint`` for the TOM model.
        output_dir: Directory for the output image; created if missing.

    Returns:
        Path of the written file, ``<output_dir>/warped_design.jpg``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Prepare inputs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dress_tensor, design_tensor, design_mask = prepare_inputs(
        dress_path, design_path
    )
    agnostic = create_agnostic(dress_tensor)

    # Initialize models
    opt = Options()
    gmm = GMM(opt).to(device)
    tom = TOM(opt).to(device)

    # Load checkpoints
    load_checkpoint(gmm, gmm_ckpt)
    load_checkpoint(tom, tom_ckpt)
    gmm.eval()
    tom.eval()

    # Move tensors to device
    agnostic = agnostic.to(device)
    design_tensor = design_tensor.to(device)
    design_mask = design_mask.to(device)

    # Pure inference: keep every tensor op out of the autograd graph.
    with torch.no_grad():
        # GMM processing: the model yields a sampling grid (its second
        # return value is unused here).
        grid, _ = gmm(agnostic, design_mask)
        warped_design = F.grid_sample(
            design_tensor, grid, padding_mode='border', align_corners=True
        )
        warped_mask = F.grid_sample(
            design_mask, grid, padding_mode='zeros', align_corners=True
        )

        # TOM processing. Input is the channel-wise concatenation
        # [agnostic, warped_design, warped_mask].
        tom_input = torch.cat([agnostic, warped_design, warped_mask], dim=1)
        p_rendered, m_composite = tom(tom_input)

        # Final composition: blend the warped design and the rendered image
        # using the composite mask.
        tryon = warped_design * m_composite + p_rendered * (1 - m_composite)

    # Save output. Only the batch dimension is squeezed so that a spatial
    # dimension of size 1 can never be dropped by accident.
    tryon = tryon.squeeze(0).permute(1, 2, 0).cpu().numpy()
    tryon = (tryon * 0.5 + 0.5) * 255  # undo the mean-0.5 / std-0.5 scaling
    tryon = tryon.clip(0, 255).astype(np.uint8)

    output_path = os.path.join(output_dir, "warped_design.jpg")
    Image.fromarray(tryon).save(output_path)
    return output_path