File size: 3,123 Bytes
b61f3f8
 
 
 
 
198f320
 
6989926
b61f3f8
6989926
 
 
 
ddec91c
a4a6754
b61f3f8
1edd3bd
6989926
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4a6754
6989926
198f320
6989926
 
 
198f320
6989926
 
 
 
 
 
 
 
 
198f320
a4a6754
6989926
 
 
 
 
 
 
 
 
 
 
 
b61f3f8
6989926
198f320
6989926
 
 
a4a6754
b61f3f8
6989926
 
 
a4a6754
6989926
 
b61f3f8
198f320
 
6989926
 
 
 
 
 
 
198f320
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
from networks import GMM, TOM, load_checkpoint, Options
import torchvision.transforms.functional as TF

def prepare_inputs(dress_path, design_path, height=256, width=192):
    """Load, resize and normalize the dress and design images and build a design mask.

    Args:
        dress_path: Path to the dress image (any mode; converted to RGB).
        design_path: Path to the design image. If it carries transparency
            (an alpha band, or a ``transparency`` entry such as a PNG tRNS
            chunk), every pixel with non-zero alpha is 1 in the mask.
            Otherwise the mask is all ones.
        height: Output height of every returned tensor.
        width: Output width of every returned tensor.

    Returns:
        ``(dress_tensor, design_tensor, mask_tensor)``:
        - dress and design are ``(1, 3, height, width)``, normalized with
          mean/std 0.5, so values lie in ``[-1, 1]``;
        - the mask is ``(1, 1, height, width)`` with values in ``[0, 1]``.
    """
    transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Context managers close the underlying file handles; .convert() returns
    # a fully loaded, independent image, so closing the source is safe.
    with Image.open(dress_path) as dress_src:
        dress_img = dress_src.convert('RGB')

    with Image.open(design_path) as design_src:
        # The alpha channel must be read BEFORE converting to RGB. After
        # .convert('RGB') the image always has exactly 3 bands, so an
        # alpha check on the converted array can never succeed.
        has_alpha = 'A' in design_src.getbands() or 'transparency' in design_src.info
        if has_alpha:
            alpha = np.array(design_src.convert('RGBA').getchannel('A'))
            mask = (alpha > 0).astype(np.float32)
        design_img = design_src.convert('RGB')

    if not has_alpha:
        # No transparency information: treat the whole design as foreground.
        mask = np.ones((design_img.height, design_img.width), dtype=np.float32)

    dress_tensor = transform(dress_img).unsqueeze(0)
    design_tensor = transform(design_img).unsqueeze(0)

    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    mask_tensor = TF.to_tensor(TF.resize(mask_img, (height, width))).unsqueeze(0)

    return dress_tensor, design_tensor, mask_tensor

def create_agnostic(dress_tensor):
    """Return the agnostic representation fed to the GMM/TOM networks.

    At present no region of the dress is masked out. The result is an
    independent copy of ``dress_tensor`` with identical values, so in-place
    edits to it do not affect the caller's tensor.
    """
    agnostic_copy = torch.clone(dress_tensor)
    return agnostic_copy

def run_design_warp_on_dress(dress_path, design_path, gmm_ckpt, tom_ckpt, output_dir):
    """Warp a design onto a dress with the GMM, blend it with the TOM, and save a JPEG.

    Args:
        dress_path: Path to the dress image.
        design_path: Path to the design image.
        gmm_ckpt: Checkpoint path handed to ``load_checkpoint`` for the GMM.
        tom_ckpt: Checkpoint path handed to ``load_checkpoint`` for the TOM.
        output_dir: Output directory; created if it does not exist.

    Returns:
        The path of the written image, ``<output_dir>/warped_design.jpg``.
    """
    os.makedirs(output_dir, exist_ok=True)

    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dress, design, mask = prepare_inputs(dress_path, design_path)
    agnostic = create_agnostic(dress)

    opt = Options()
    gmm_net = GMM(opt).to(run_device)
    tom_net = TOM(opt).to(run_device)
    # NOTE(review): load_checkpoint lives in `networks`; confirm it keeps the
    # weights on run_device (some versions force .cuda()).
    load_checkpoint(gmm_net, gmm_ckpt)
    load_checkpoint(tom_net, tom_ckpt)

    agnostic, design, mask = [t.to(run_device) for t in (agnostic, design, mask)]

    with torch.no_grad():
        # Stage 1: the GMM predicts a sampling grid; warp the design and its mask.
        gmm_net.eval()
        grid, _ = gmm_net(agnostic, mask)
        warped_design = F.grid_sample(design, grid, padding_mode='border', align_corners=True)
        warped_mask = F.grid_sample(mask, grid, padding_mode='zeros', align_corners=True)

        # Stage 2: the TOM consumes [agnostic | warped design | warped mask]
        # concatenated on the channel axis.
        tom_net.eval()
        tom_in = torch.cat([agnostic, warped_design, warped_mask], dim=1)
        rendered, composite_mask = tom_net(tom_in)

        # Blend: warped design where the composite mask is 1, rendered image where it is 0.
        tryon = warped_design * composite_mask + rendered * (1 - composite_mask)

    # Drop size-1 dims, go CHW -> HWC, and map [-1, 1] onto [0, 255] (clipping strays).
    pixels = tryon.squeeze().permute(1, 2, 0).cpu().numpy()
    pixels = ((pixels * 0.5 + 0.5) * 255).clip(0, 255).astype(np.uint8)

    output_path = os.path.join(output_dir, "warped_design.jpg")
    Image.fromarray(pixels).save(output_path)
    return output_path