# Design_warper / preprocessing.py
# Last commit by gaur3009: "Update preprocessing.py" (97b4c52, verified)
import torch
def pad_to_22_channels(input_tensor, num_channels=22):
    """Expand a 3-channel (RGB) tensor to ``num_channels`` channels by tiling.

    Output channel ``i`` is a copy of input channel ``i % 3``. The default
    target of 22 therefore gives seven full R, G, B repeats followed by one
    extra R channel.

    Channels are taken from dimension 1. This is presumably (N, C, H, W)
    image batches, but only ``ndim >= 2`` is actually required.

    Args:
        input_tensor: Tensor whose dimension 1 is the channel axis.
        num_channels: Number of channels to produce for 3-channel input.
            Defaults to 22.

    Returns:
        A new tensor with ``num_channels`` channels when ``input_tensor`` has
        exactly 3 channels. Otherwise ``input_tensor`` itself, unchanged.

    Raises:
        ValueError: If the input has 3 channels and ``num_channels < 3``
            (that would drop channels rather than pad them).
    """
    in_channels = input_tensor.shape[1]
    if in_channels != 3:
        # Only RGB input is expanded; other channel counts pass through.
        return input_tensor
    if num_channels < in_channels:
        raise ValueError(
            f"num_channels must be >= {in_channels}, got {num_channels}"
        )
    # Cyclic channel index 0, 1, 2, 0, 1, 2, ...
    # Build it on the input's device so index_select does not fail on GPU.
    index = torch.arange(num_channels, device=input_tensor.device) % in_channels
    return input_tensor.index_select(1, index)