File size: 246 Bytes
580136d
 
 
 
97b4c52
 
b314508
1
2
3
4
5
6
7
import torch

def pad_to_22_channels(input_tensor):
    """Expand a 3-channel tensor to 22 channels by cycling its channels.

    Dimension 1 is treated as the channel axis. When it has exactly three
    entries, the whole tensor is tiled seven times along that axis
    (21 channels), and the first channel is appended once more. The result
    has channel order 0, 1, 2, 0, 1, 2, ..., 0. Any other channel count is
    passed through untouched.

    Args:
        input_tensor: tensor with at least two dimensions. Dimension 1 is
            read as the channel count.

    Returns:
        A new tensor with 22 entries along dimension 1 if the input had 3,
        otherwise ``input_tensor`` itself (the same object).
    """
    channel_count = input_tensor.shape[1]
    if channel_count != 3:
        # Not an RGB-shaped input: hand it back unchanged.
        return input_tensor

    # Slice with 0:1 (not 0) so the channel axis is kept for concatenation.
    first_channel = input_tensor[:, 0:1]
    pieces = [input_tensor for _ in range(7)]  # 7 * 3 = 21 channels
    pieces.append(first_channel)  # +1 -> 22 channels
    return torch.cat(pieces, dim=1)