| import torch | |
| import torch.nn as nn | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
def get_transforms(means, stds):
    """Build the albumentations pipelines used for training and evaluation.

    Args:
        means: Per-channel means forwarded to ``A.Normalize``. Since
            ``A.Normalize`` uses its default ``max_pixel_value`` here, these
            are presumably on the [0, 1] scale (e.g. ``(0.49, 0.48, 0.45)``)
            — confirm against the caller.
        stds: Per-channel standard deviations forwarded to ``A.Normalize``,
            on the same scale as ``means``.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of ``A.Compose``
        pipelines. Both normalize with ``means``/``stds`` and end with
        ``ToTensorV2``; only the training pipeline applies random
        augmentation.
    """
    train_transforms = A.Compose(
        [
            A.Normalize(mean=means, std=stds, always_apply=True),
            # Pad up to 36x36, then take a random 32x32 crop: a small random
            # shift augmentation (assumes 32x32 inputs — TODO confirm).
            A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
            A.RandomCrop(height=32, width=32, always_apply=True),
            A.HorizontalFlip(),
            # Cutout runs after Normalize, so pixels are already mean-centred:
            # a pixel equal to the dataset mean is now 0 in every channel.
            # Filling with the raw ``means`` values would paint a colour that
            # is NOT the mean, so fill with 0 to get a neutral mean patch.
            # NOTE(review): A.Cutout is deprecated upstream in favour of
            # A.CoarseDropout; migrate when moving off the always_apply API.
            A.Cutout(fill_value=0),
            ToTensorV2(),
        ]
    )
    # Evaluation: deterministic — normalization and tensor conversion only.
    test_transforms = A.Compose(
        [
            A.Normalize(mean=means, std=stds, always_apply=True),
            ToTensorV2(),
        ]
    )
    return train_transforms, test_transforms