import timm
import torch.nn as nn


def _swap_submodule(model, name, new_module):
    """Replace the submodule registered under dotted path *name* with *new_module*."""
    if not name:
        # An empty name means *model* itself matched; it cannot be replaced
        # in place (the original silently did ``setattr(model, "", ...)``).
        raise ValueError("cannot replace the root module itself")
    # "a.b.fc" -> parent "a.b", child "fc"; "fc" -> parent "" (the model).
    parent_name, _, child_name = name.rpartition(".")
    # named_modules() includes ("", model), so a top-level child resolves too.
    parent = dict(model.named_modules())[parent_name]
    setattr(parent, child_name, new_module)


def replace_last_layer(model, num_classes):
    """Replace the last ``nn.Linear`` or ``nn.Conv2d`` in *model* with a fresh layer.

    "Last" means last in ``model.named_modules()`` registration order. That is
    not guaranteed to be the last layer executed in ``forward``.
    NOTE(review): confirm it is the classifier for every architecture used.

    The new layer has the same input size and configuration as the one it
    replaces, but outputs ``num_classes`` features/channels. It is freshly
    initialized, so its parameters require grad even if *model* was frozen
    beforehand. If *model* contains no ``nn.Linear``/``nn.Conv2d``, nothing
    happens.

    Args:
        model: Module to modify in place.
        num_classes: Output features (Linear) or output channels (Conv2d) of
            the replacement layer.

    Raises:
        ValueError: If the matched layer is *model* itself.
    """
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, nn.Linear):
            new_layer = nn.Linear(
                module.in_features,
                num_classes,
                bias=module.bias is not None,
            )
        elif isinstance(module, nn.Conv2d):
            # The replacement consumes the same input as the old conv, so its
            # input channel count is the old layer's *in_channels*.
            new_layer = nn.Conv2d(
                module.in_channels,
                num_classes,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                bias=module.bias is not None,
                padding_mode=module.padding_mode,
            )
        else:
            continue
        # Keep the new layer on the same device/dtype as the one it replaces.
        new_layer.to(device=module.weight.device, dtype=module.weight.dtype)
        _swap_submodule(model, name, new_layer)
        return


def enable_first_layer_grad(model):
    """Unfreeze every parameter of the first ``nn.Conv2d`` in *model*.

    "First" means first in ``model.modules()`` registration order.
    NOTE(review): presumably the input stem conv; confirm per architecture.
    Does nothing if *model* has no ``nn.Conv2d``.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # Assigning ``module.requires_grad = True`` would only set a plain
            # attribute on the Module; requires_grad_ flips every parameter.
            module.requires_grad_(True)
            return


def create_model(key, in_chans=1, num_classes=1, pretrained=False):
    """Build a timm model where only the first conv and the classifier train.

    All parameters are frozen. Then the first ``nn.Conv2d`` is unfrozen, and
    the last ``nn.Linear``/``nn.Conv2d`` is replaced with a fresh trainable
    layer producing ``num_classes`` outputs.

    Args:
        key: timm model name passed to ``timm.create_model``.
        in_chans: Number of input channels.
        num_classes: Number of outputs of the classifier layer.
        pretrained: Forwarded to ``timm.create_model``. Defaults to False,
            as before. Freezing the backbone is mostly useful with pretrained
            weights.

    Returns:
        The configured ``nn.Module``.
    """
    model = timm.create_model(
        key, pretrained=pretrained, in_chans=in_chans, num_classes=num_classes
    )
    model.requires_grad_(False)
    enable_first_layer_grad(model)
    replace_last_layer(model, num_classes)
    return model