9b63413
1
2
3
4
5
6
7
8
9
10
11
import torch.nn as nn


class CombinedModel(nn.Module):
    """Container module that holds a UNet and a cloth encoder together.

    Both objects are stored as attributes so they can be reached through a
    single parent module. If they are ``nn.Module`` instances, attribute
    assignment on an ``nn.Module`` registers them as sub-modules.
    """

    def __init__(self, unet, cloth_encoder):
        """Store the two wrapped components.

        Args:
            unet: Object kept as ``self.unet`` (presumably an ``nn.Module``;
                confirm against the caller).
            cloth_encoder: Object kept as ``self.cloth_encoder`` (presumably
                an ``nn.Module``; confirm against the caller).
        """
        super().__init__()
        self.unet = unet
        self.cloth_encoder = cloth_encoder

    def forward(self, x):
        """Return ``x`` unchanged.

        NOTE(review): neither ``self.unet`` nor ``self.cloth_encoder`` is
        invoked here; this looks like a placeholder -- confirm whether the
        identity pass-through is intentional.
        """
        return x