import sys
from functools import partial

from typing import Callable
from typing import Dict
from typing import Tuple
from typing import Union
from argparse import Namespace

# Make the torchvision segmentation reference utilities (presets, utils) importable.
sys.path.append("vision/references/segmentation")

import presets
import torch
import torch.utils.data
import torchvision
import utils
from torch import nn
from common import flops_calculation_function
from common import NanSafeConfusionMatrix as ConfusionMatrix
from common import get_coco

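# Example invocation (illustrative only; the script name, dataset path and checkpoint
# file below are placeholders, not fixed by this repository):
#
#   python evaluate.py --data-path /path/to/COCO --model-path checkpoint.pth --device cuda
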
def get_dataset(args: Namespace, is_train: bool, transform: Callable = None) -> Tuple[torch.utils.data.Dataset, int]:
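    """Build the evaluation dataset and return it together with its number of classes.

    The dataset choice is hard-wired to the ``coco_orig`` entry below (the original
    81-class COCO annotations loaded via ``get_coco(use_orig=True)``); the other
    entries are left in place but unused.
    """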
    def sbd(*args, **kwargs):
        kwargs.pop("use_v2")
        return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)

    def voc(*args, **kwargs):
        kwargs.pop("use_v2")
        return torchvision.datasets.VOCSegmentation(*args, **kwargs)

    paths = {
        "voc": (args.data_path, voc, 21),
        "voc_aug": (args.data_path, sbd, 21),
        "coco": (args.data_path, get_coco, 21),
        "coco_orig": (args.data_path, partial(get_coco, use_orig=True), 81),
    }
    # Evaluation is fixed to the original 81-class COCO split.
    p, ds_fn, num_classes = paths["coco_orig"]

    if transform is None:
        transform = presets.SegmentationPresetEval(base_size=520, backend=args.backend, use_v2=args.use_v2)
    image_set = "train" if is_train else "val"
    ds = ds_fn(p, image_set=image_set, transforms=transform, use_v2=args.use_v2)
    return ds, num_classes


def criterion(inputs: Dict[str, torch.Tensor], target: torch.Tensor) -> torch.Tensor:
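    """Cross-entropy loss per model head, ignoring the void label 255.

    When an auxiliary head is present, its loss is added with a weight of 0.5,
    following the torchvision segmentation reference.
    """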
    losses = {}
    for name, x in inputs.items():
        losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

    if len(losses) == 1:
        return losses["out"]

    return losses["out"] + 0.5 * losses["aux"]


def evaluate(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: Union[str, torch.device],
    num_classes: int,
    criterion: Callable,
) -> Tuple[ConfusionMatrix, float]:
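    """Run one pass over ``data_loader`` in inference mode.

    Returns the accumulated confusion matrix and the globally averaged loss.
    """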
    model.eval()
    confmat = ConfusionMatrix(num_classes)
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            output = output["out"]

            confmat.update(target.flatten(), output.argmax(1).flatten())

            num_processed_samples += image.shape[0]

            metric_logger.update(loss=loss.item())

    confmat.reduce_from_all_processes()

    return confmat, metric_logger.loss.global_avg


def main(args):
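    """Load the checkpointed model, report its FLOPs, and evaluate it on the validation split."""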
    if args.backend.lower() != "pil" and not args.use_v2:
        raise ValueError("Use --use-v2 if you want to use the tv_tensor or tensor backend.")
    if args.use_v2:
        raise ValueError("v2 is only supported for coco dataset for now.")

    print(args)

    device = torch.device(args.device)

    if args.use_deterministic_algorithms:
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    else:
        torch.backends.cudnn.benchmark = True

    dataset_test, num_classes = get_dataset(args, is_train=False)

    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )
|
| | checkpoint = torch.load(args.model_path) |
| | model = checkpoint["model"] |
| | model.to(device) |
| | model_flops = flops_calculation_function(model=model, input_sample=next(iter(data_loader_test))[0].to(device)) |
| | print(f"Model Flops: {model_flops}M") |
| |
|
| | |
| | torch.backends.cudnn.benchmark = False |
| | torch.backends.cudnn.deterministic = True |
    confmat, loss = evaluate(
        model=model,
        data_loader=data_loader_test,
        device=device,
        num_classes=num_classes,
        criterion=criterion,
    )
    print(confmat)
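    # The torchvision segmentation reference ConfusionMatrix exposes
    # compute() -> (acc_global, acc, iu). If the NanSafe variant in common.py keeps
    # that API (an assumption, not something this file confirms), mean IoU could be
    # reported explicitly as well:
    #
    #   _, _, iu = confmat.compute()
    #   print(f"mean IoU: {iu.mean().item() * 100:.1f}")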
    return


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Segmentation Evaluation", add_help=add_help)

    parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
    parser.add_argument("--device", default="cuda", type=str, help="device to use (cuda or cpu, default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run")

    parser.add_argument(
        "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument(
        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
    )

    parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
    parser.add_argument("--model-path", default=None, help="Path to model checkpoint.")
    return parser


if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)