"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import os.path as op
import torch
import logging
import code
from custom_mesh_graphormer.utils.comm import get_world_size
from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset)
from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset)


def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a body-mesh dataset from a TSV yaml config.

    Relative paths are resolved against ``args.data_dir``.
    """
    print(yaml_file)
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
        # code.interact(local=locals())
    assert op.isfile(yaml_file), "Cannot find dataset yaml file: {}".format(yaml_file)
    return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)


class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
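
# Illustrative sketch (not part of the original code): how the wrapper keeps
# yielding batches past one pass over the data. The sizes below are made up.
#
#     base = torch.utils.data.sampler.BatchSampler(
#         torch.utils.data.sampler.SequentialSampler(range(10)), 4, drop_last=False)
#     looped = IterationBasedBatchSampler(base, num_iterations=7)
#     # `base` holds 3 batches per pass; `looped` re-cycles it and stops after
#     # exactly 7 batches have been yielded.
#     assert sum(1 for _ in looped) == 7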


def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
    """Wrap ``sampler`` into a BatchSampler; if ``num_iters`` is given, loop it
    with IterationBasedBatchSampler so training runs for a fixed number of steps."""
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_gpu, drop_last=False
    )
    if num_iters is not None and num_iters >= 0:
        batch_sampler = IterationBasedBatchSampler(
            batch_sampler, num_iters, start_iter
        )
    return batch_sampler


def make_data_sampler(dataset, shuffle, distributed):
    """Return a DistributedSampler when running distributed, otherwise a
    Random/Sequential sampler depending on ``shuffle``."""
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        sampler = torch.utils.data.sampler.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.sampler.SequentialSampler(dataset)
    return sampler
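
# Illustrative sketch (not from the original code): composing the two helpers
# above in a single-process (non-distributed) setting; `dataset` is a
# placeholder and the numbers are arbitrary.
#
#     sampler = make_data_sampler(dataset, shuffle=True, distributed=False)
#     batch_sampler = make_batch_data_sampler(sampler, images_per_gpu=32, num_iters=10000)
#     # Each element yielded by `batch_sampler` is a list of 32 dataset indices,
#     # and iteration stops after 10000 batches regardless of the dataset size.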


def make_data_loader(args, yaml_file, is_distributed=True,
                     is_train=True, start_iter=0, scale_factor=1):
    """Build a DataLoader over the body-mesh dataset described by ``yaml_file``.

    In training mode the loader is iteration-based: it yields
    ``len(dataset) // (per_gpu_batch_size * world_size) * num_train_epochs`` batches.
    """
    dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        iters_per_batch = len(dataset) // images_per_batch  # iterations per epoch
        num_iters = iters_per_batch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None
        start_iter = 0

    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )
    return data_loader
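
# Worked example of the step computation above (all numbers hypothetical):
# with 4 GPUs and args.per_gpu_train_batch_size = 32, images_per_batch = 128;
# a dataset of 40000 samples gives 40000 // 128 = 312 iterations per epoch, so
# args.num_train_epochs = 200 yields num_iters = 312 * 200 = 62400 training steps.
#
# A minimal call sketch, assuming `args` carries the attributes used above
# (values and the yaml path are placeholders):
#
#     from argparse import Namespace
#     args = Namespace(data_dir="datasets", per_gpu_train_batch_size=32,
#                      per_gpu_eval_batch_size=32, num_train_epochs=200,
#                      num_workers=4)
#     train_loader = make_data_loader(args, "path/to/train.yaml",
#                                     is_distributed=False, is_train=True)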
#==============================================================================================

def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a hand-mesh dataset from a TSV yaml config.

    Relative paths are resolved against ``args.data_dir``.
    """
    print(yaml_file)
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
        # code.interact(local=locals())
    assert op.isfile(yaml_file), "Cannot find dataset yaml file: {}".format(yaml_file)
    return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)


def make_hand_data_loader(args, yaml_file, is_distributed=True,
                          is_train=True, start_iter=0, scale_factor=1):
    """Build a DataLoader over the hand-mesh dataset described by ``yaml_file``.

    Mirrors :func:`make_data_loader`, but uses :func:`build_hand_dataset`.
    """
    dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        iters_per_batch = len(dataset) // images_per_batch  # iterations per epoch
        num_iters = iters_per_batch * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(images_per_batch))
        logger.info("Total training steps {}".format(num_iters))
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None
        start_iter = 0

    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )
    return data_loader
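
# Evaluation-mode sketch for the hand loader (illustrative only; the yaml path
# is a placeholder, not taken from the original code):
#
#     val_loader = make_hand_data_loader(args, "path/to/test.yaml",
#                                        is_distributed=False, is_train=False)
#     # With is_train=False the sampler is sequential, num_iters is None, and the
#     # loader makes exactly one pass over the dataset using
#     # args.per_gpu_eval_batch_size images per batch.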