| import numpy as np |
| from torchvision import transforms |
|
|
class CIFARTransform:
    """Preprocessing pipelines for the 'cifar' entry of ``transform_classes``.

    Holds one train and one test ``transforms.Compose`` per model family
    ('resnet', 'vit', 'alexnet'). Use :meth:`get_transform` to pick one.
    """

    # Per-channel RGB statistics for the ResNet pipelines.
    # NOTE(review): these look like the commonly quoted CIFAR-100 statistics —
    # confirm they are intended for every dataset routed through this class.
    MEAN = [0.5071, 0.4866, 0.4409]
    STD = [0.2675, 0.2565, 0.2761]

    common_trfs = [transforms.ToTensor(),
                   transforms.Normalize(mean=MEAN, std=STD)]

    # ResNet train: pad by 4 and crop back to 32x32, random flip, brightness jitter.
    resnet_train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        *common_trfs
    ])

    # ResNet test: no resizing — assumes inputs are already at the training
    # resolution (presumably 32x32; confirm).
    resnet_test_transform = transforms.Compose([*common_trfs])

    # ViT: inputs are resized to 224. Normalizing with mean 0 / std 1 is an
    # identity, so tensors keep the value range produced by ToTensor.
    dset_mean = (0., 0., 0.)
    dset_std = (1., 1., 1.)
    vit_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std)])

    vit_test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std)])

    # AlexNet: statistics given on a 0-255 scale and divided by 255.
    # NOTE(review): presumably 8-bit CIFAR-10 pixel statistics — confirm.
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]

    # No augmentation: train and test pipelines are identical.
    alexnet_train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    alexnet_test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    @staticmethod
    def get_transform(model_type, mode):
        """Return the preprocessing pipeline for a model family and split.

        Args:
            model_type: 'resnet', 'vit' or 'alexnet'.
            mode: 'train' or 'test'.

        Returns:
            The matching class-level ``transforms.Compose`` (shared, not copied).

        Raises:
            ValueError: if ``model_type`` or ``mode`` is not supported.
        """
        pipelines = {
            'resnet': {'train': CIFARTransform.resnet_train_transform,
                       'test': CIFARTransform.resnet_test_transform},
            'vit': {'train': CIFARTransform.vit_train_transform,
                    'test': CIFARTransform.vit_test_transform},
            'alexnet': {'train': CIFARTransform.alexnet_train_transform,
                        'test': CIFARTransform.alexnet_test_transform},
        }
        if model_type not in pipelines:
            raise ValueError("Unsupported model type")
        by_mode = pipelines[model_type]
        # Fail loudly instead of returning None for an unknown split.
        if mode not in by_mode:
            raise ValueError(
                f"Unsupported mode: {mode!r} (expected 'train' or 'test')")
        return by_mode[mode]
| |
class ImageNetTransform:
    """Preprocessing pipelines for the 'imagenet' entry of ``transform_classes``.

    Holds one train and one test ``transforms.Compose`` per model family
    ('resnet', 'vit'). Use :meth:`get_transform` to pick one.
    """

    # Per-channel RGB statistics for the ResNet pipelines.
    # NOTE(review): these differ from the [0.485, 0.456, 0.406] /
    # [0.229, 0.224, 0.225] statistics TinyImageNetTransform uses in this
    # module, and look like commonly quoted CIFAR-10 statistics. Left unchanged
    # so existing checkpoints reproduce — confirm this is intended.
    MEAN = [0.4914, 0.4822, 0.4465]
    STD = [0.2023, 0.1994, 0.2010]

    common_trfs = [transforms.ToTensor(),
                   transforms.Normalize(mean=MEAN, std=STD)]

    # ResNet train: random-resized 224 crop, random flip, brightness jitter.
    resnet_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        *common_trfs
    ])

    # ResNet test: resize to 256 then take the central 224x224 crop.
    resnet_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        *common_trfs
    ])

    # ViT: mean 0 / std 1 normalization is an identity, so tensors keep the
    # value range produced by ToTensor.
    dset_mean = (0., 0., 0.)
    dset_std = (1., 1., 1.)
    vit_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std),
    ])

    vit_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std),
    ])

    @staticmethod
    def get_transform(model_type, mode):
        """Return the preprocessing pipeline for a model family and split.

        Args:
            model_type: 'resnet' or 'vit' (no AlexNet pipeline is defined here).
            mode: 'train' or 'test'.

        Returns:
            The matching class-level ``transforms.Compose`` (shared, not copied).

        Raises:
            ValueError: if ``model_type`` or ``mode`` is not supported.
        """
        pipelines = {
            'resnet': {'train': ImageNetTransform.resnet_train_transform,
                       'test': ImageNetTransform.resnet_test_transform},
            'vit': {'train': ImageNetTransform.vit_train_transform,
                    'test': ImageNetTransform.vit_test_transform},
        }
        if model_type not in pipelines:
            raise ValueError("Unsupported model type")
        by_mode = pipelines[model_type]
        # Fail loudly instead of returning None for an unknown split.
        if mode not in by_mode:
            raise ValueError(
                f"Unsupported mode: {mode!r} (expected 'train' or 'test')")
        return by_mode[mode]
| |
class ImageNetRTransform:
    """Preprocessing pipelines for the 'imagenet-r' entry of ``transform_classes``.

    Holds one train and one test ``transforms.Compose`` per model family
    ('resnet', 'vit', 'alexnet'). Use :meth:`get_transform` to pick one.

    ``mean``, ``std`` and ``common_trfs`` are rebound several times in this
    class body. Each pipeline captures the objects bound at the point it is
    built, so every model family gets its own statistics. After the class is
    built, the attributes hold the last (AlexNet) ``mean``/``std`` and the ViT
    ``common_trfs``.
    """

    # ResNet statistics.
    # NOTE(review): these look like commonly quoted CIFAR-10 statistics rather
    # than the usual ImageNet ones — confirm this is intended.
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    common_trfs = [transforms.ToTensor(),
                   transforms.Normalize(mean, std)]

    resnet_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        *common_trfs])

    resnet_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        *common_trfs])

    # ViT: mean 0 / std 1 normalization is an identity, so tensors keep the
    # value range produced by ToTensor.
    mean = [0., 0., 0.]
    std = [1., 1., 1.]

    common_trfs = [transforms.ToTensor(),
                   transforms.Normalize(mean, std)]

    vit_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        *common_trfs])

    vit_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        *common_trfs])

    # AlexNet: statistics given on a 0-255 scale and divided by 255.
    # NOTE(review): presumably 8-bit CIFAR-10 pixel statistics — confirm.
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]

    alexnet_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    # Resize the shorter side to 32 *before* the 32x32 center crop. Cropping
    # 32 px out of a 256-px resize would keep only a tiny central patch, which
    # does not match the RandomResizedCrop(32) view used in training.
    alexnet_test_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    @staticmethod
    def get_transform(model_type, mode):
        """Return the preprocessing pipeline for a model family and split.

        Args:
            model_type: 'resnet', 'vit' or 'alexnet'.
            mode: 'train' or 'test'.

        Returns:
            The matching class-level ``transforms.Compose`` (shared, not copied).

        Raises:
            ValueError: if ``model_type`` or ``mode`` is not supported.
        """
        pipelines = {
            'resnet': {'train': ImageNetRTransform.resnet_train_transform,
                       'test': ImageNetRTransform.resnet_test_transform},
            'vit': {'train': ImageNetRTransform.vit_train_transform,
                    'test': ImageNetRTransform.vit_test_transform},
            'alexnet': {'train': ImageNetRTransform.alexnet_train_transform,
                        'test': ImageNetRTransform.alexnet_test_transform},
        }
        if model_type not in pipelines:
            raise ValueError("Unsupported model type")
        by_mode = pipelines[model_type]
        # Fail loudly instead of returning None for an unknown split.
        if mode not in by_mode:
            raise ValueError(
                f"Unsupported mode: {mode!r} (expected 'train' or 'test')")
        return by_mode[mode]
|
|
class TinyImageNetTransform:
    """Preprocessing pipelines for the 'tiny-imagenet' entry of ``transform_classes``.

    Holds one train and one test ``transforms.Compose`` per model family
    ('resnet', 'vit', 'alexnet'). Use :meth:`get_transform` to pick one.
    """

    # Per-channel RGB statistics for the ResNet pipelines.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    common_trfs = [transforms.ToTensor(),
                   transforms.Normalize(mean=MEAN, std=STD)]

    # ResNet train: random-resized 64 crop, random flip, brightness jitter.
    resnet_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        *common_trfs
    ])

    # ResNet test: resize the shorter side to 64 and center-crop to 64x64.
    resnet_test_transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        *common_trfs
    ])

    # ViT: mean 0 / std 1 normalization is an identity, so tensors keep the
    # value range produced by ToTensor.
    dset_mean = (0., 0., 0.)
    dset_std = (1., 1., 1.)

    vit_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std)
    ])

    vit_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std)
    ])

    # AlexNet: statistics given on a 0-255 scale and divided by 255.
    # NOTE(review): presumably 8-bit CIFAR-10 pixel statistics — confirm they
    # are intended for this dataset.
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]

    alexnet_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    # Resize the shorter side to 32 *before* the 32x32 center crop. Cropping
    # 32 px out of a 256-px resize would keep only a tiny central patch, which
    # does not match the RandomResizedCrop(32) view used in training.
    alexnet_test_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    @staticmethod
    def get_transform(model_type, mode):
        """Return the preprocessing pipeline for a model family and split.

        Args:
            model_type: 'resnet', 'vit' or 'alexnet'.
            mode: 'train' or 'test'.

        Returns:
            The matching class-level ``transforms.Compose`` (shared, not copied).

        Raises:
            ValueError: if ``model_type`` or ``mode`` is not supported.
        """
        pipelines = {
            'resnet': {'train': TinyImageNetTransform.resnet_train_transform,
                       'test': TinyImageNetTransform.resnet_test_transform},
            'vit': {'train': TinyImageNetTransform.vit_train_transform,
                    'test': TinyImageNetTransform.vit_test_transform},
            'alexnet': {'train': TinyImageNetTransform.alexnet_train_transform,
                        'test': TinyImageNetTransform.alexnet_test_transform},
        }
        if model_type not in pipelines:
            raise ValueError("Unsupported model type")
        by_mode = pipelines[model_type]
        # Fail loudly instead of returning None for an unknown split.
        if mode not in by_mode:
            raise ValueError(
                f"Unsupported mode: {mode!r} (expected 'train' or 'test')")
        return by_mode[mode]
|
|
|
|
class FiveDatasetsTransform:
    """Preprocessing pipelines for the '5-datasets' entry of ``transform_classes``.

    Holds one train and one test ``transforms.Compose`` per model family
    ('resnet', 'vit', 'alexnet'). Use :meth:`get_transform` to pick one.

    NOTE(review): every Normalize here uses three-entry statistics, so inputs
    are assumed to be 3-channel. Confirm that any grayscale source datasets
    are converted to RGB upstream.
    """

    # Per-channel RGB statistics for the ResNet pipelines.
    MEAN = [0.5071, 0.4866, 0.4409]
    STD = [0.2675, 0.2565, 0.2761]

    common_trfs = [transforms.ToTensor(),
                   transforms.Normalize(mean=MEAN, std=STD)]

    # ResNet train: pad by 4 and crop to 32x32, random flip, brightness jitter.
    resnet_train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=63 / 255),
        *common_trfs
    ])

    # ResNet test: resize the shorter side to 32 (no crop).
    resnet_test_transform = transforms.Compose([
        transforms.Resize(32),
        *common_trfs
    ])

    # ViT: mean 0 / std 1 normalization is an identity, so tensors keep the
    # value range produced by ToTensor.
    dset_mean = (0., 0., 0.)
    dset_std = (1., 1., 1.)
    vit_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std)])

    vit_test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(dset_mean, dset_std)])

    # AlexNet: statistics given on a 0-255 scale and divided by 255.
    # NOTE(review): presumably 8-bit CIFAR-10 pixel statistics — confirm.
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]

    # No augmentation: train and test pipelines are identical.
    alexnet_train_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    alexnet_test_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    @staticmethod
    def get_transform(model_type, mode):
        """Return the preprocessing pipeline for a model family and split.

        Args:
            model_type: 'resnet', 'vit' or 'alexnet'.
            mode: 'train' or 'test'.

        Returns:
            The matching class-level ``transforms.Compose`` (shared, not copied).

        Raises:
            ValueError: if ``model_type`` or ``mode`` is not supported.
        """
        pipelines = {
            'resnet': {'train': FiveDatasetsTransform.resnet_train_transform,
                       'test': FiveDatasetsTransform.resnet_test_transform},
            'vit': {'train': FiveDatasetsTransform.vit_train_transform,
                    'test': FiveDatasetsTransform.vit_test_transform},
            'alexnet': {'train': FiveDatasetsTransform.alexnet_train_transform,
                        'test': FiveDatasetsTransform.alexnet_test_transform},
        }
        if model_type not in pipelines:
            raise ValueError("Unsupported model type")
        by_mode = pipelines[model_type]
        # Fail loudly instead of returning None for an unknown split.
        if mode not in by_mode:
            raise ValueError(
                f"Unsupported mode: {mode!r} (expected 'train' or 'test')")
        return by_mode[mode]
|
|
# Registry: dataset identifier -> class bundling that dataset's preprocessing
# pipelines. Every registered class exposes get_transform(model_type, mode).
transform_classes = dict([
    ('cifar', CIFARTransform),
    ('imagenet', ImageNetTransform),
    ('imagenet-r', ImageNetRTransform),
    ('tiny-imagenet', TinyImageNetTransform),
    ('5-datasets', FiveDatasetsTransform),
])