File size: 243 Bytes
9b63413
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch.nn as nn

class CombinedModel(nn.Module):
    """Wrapper that holds a UNet and a cloth encoder side by side.

    The two constructor arguments are stored as the attributes ``unet`` and
    ``cloth_encoder``. When they are ``nn.Module`` instances, this assignment
    registers them as child modules of the wrapper, in that order.

    ``forward`` does not use either submodule; it hands back its input as-is.
    """

    def __init__(self, unet, cloth_encoder):
        super().__init__()
        # Assignment order is kept as unet -> cloth_encoder because it fixes
        # the order of named_children() and the state_dict keys.
        self.unet = unet
        self.cloth_encoder = cloth_encoder

    def forward(self, x):
        """Return ``x`` unchanged (identity pass-through)."""
        # NOTE(review): neither self.unet nor self.cloth_encoder is invoked
        # here; this looks like a placeholder — confirm the intended forward.
        passthrough = x
        return passthrough