# Source: sohamc10 — "gradio app" (commit 9b0d6c2)
import warnings
warnings.filterwarnings("ignore")
import os
import sys
import time
import math
import random
import datetime
import subprocess
from collections import defaultdict, deque
import numpy as np
import torch
import torch.distributed as dist
import argparse
from numpy.random import randint
def GMML_replace_list(samples, corrup_prev, masks_prev, drop_type='noise', max_replace=0.35, align=16):
    """Re-corrupt a batch by pasting blocks taken from other images of the batch.

    For every image, with probability ``replace_prob`` the clean image is
    corrupted by copying random aligned blocks from a randomly chosen image of
    the same batch (the donor may be the image itself); otherwise the
    previously computed corruption and mask are kept.

    Args:
        samples: batch of clean images, indexed along dim 0.
        corrup_prev: previously corrupted images, one per entry of ``samples``.
        masks_prev: masks matching ``corrup_prev``.
        drop_type: '-'-separated list of drop types. Only its length is used
            here: ``replace_prob = 1 / (n_types + 1)``, or 1 for ''.
        max_replace: forwarded to ``GMML_drop_rand_patches``.
        align: forwarded to ``GMML_drop_rand_patches``.

    Returns:
        (samples_aug, masks): the new corrupted batch and its masks.
    """
    if drop_type == '':
        replace_prob = 1
    else:
        replace_prob = 1 / (len(drop_type.split('-')) + 1)
    # Use the real first-dim size, which may differ from the configured batch size.
    batch_len = samples.size()[0]
    corrupted = samples.detach().clone()
    out_masks = torch.zeros_like(corrupted)
    for idx in range(batch_len):
        # The donor index is drawn before the coin flip on every iteration.
        donor = randint(0, batch_len)
        if random.random() >= replace_prob:
            corrupted[idx], out_masks[idx] = corrup_prev[idx], masks_prev[idx]
            continue
        corrupted[idx], out_masks[idx] = GMML_drop_rand_patches(
            corrupted[idx], samples[donor], max_replace=max_replace, align=align)
    return corrupted, out_masks
def GMML_drop_rand_patches(X, X_rep=None, drop_type='noise', max_replace=0.7, align=16, max_block_sz=0.3):
    """Corrupt random aligned rectangular blocks of a single image in place.

    Blocks are drawn at random until the number of masked pixels reaches a
    target drawn uniformly from ``[min(0.5, max_replace), max_replace] * H * W``.
    Block offsets are snapped down to multiples of ``align``. Block heights and
    widths are clipped to the image and rounded to multiples of ``align``.

    Args:
        X: image tensor unpacked as (C, H, W); modified in place and returned.
        X_rep: optional donor tensor indexed like ``X``. When given, each block
            is copied from the same location of ``X_rep`` and ``drop_type`` is
            not used to fill the block.
        drop_type: '-'-separated drop types; one is picked at random per call.
            'noise' fills blocks with standard-normal noise, 'zeros' with zeros,
            and any other value copies a random block from elsewhere in ``X``.
        max_replace: upper bound on the fraction of the image to corrupt.
            Values above 1 are capped at 1, since a target above H * W could
            never be reached and the loop would not terminate.
        align: alignment of block offsets and sizes (raised to at least 1).
        max_block_sz: maximum block height/width as a fraction of H/W.

    Returns:
        (X, mask): the corrupted image and a tensor shaped like ``X`` holding
        1 on corrupted pixels and 0 elsewhere.

    Raises:
        ValueError: from numpy ``randint`` when the image is too small for the
            requested ``align``/``max_block_sz`` (e.g. ``int(H * max_block_sz) <= align``).
    """
    # Reseeds NumPy's global RNG from OS entropy on every call (presumably so
    # forked DataLoader workers do not produce identical corruptions); results
    # are therefore not reproducible through a fixed NumPy seed.
    np.random.seed()
    C, H, W = X.size()
    # A target above H * W is unreachable and would spin the loop forever.
    max_replace = min(max_replace, 1.0)
    n_drop_pix = np.random.uniform(min(0.5, max_replace), max_replace) * H * W
    mx_blk_height = int(H * max_block_sz)
    mx_blk_width = int(W * max_block_sz)
    align = max(1, align)
    mask = torch.zeros_like(X)
    drop_t = np.random.choice(drop_type.split('-'))
    # Every channel shares the same mask, so channel 0 counts corrupted pixels.
    while mask[0].sum() < n_drop_pix:
        # Pick a random block: aligned top-left corner, aligned size.
        rnd_r = (randint(0, H - align) // align) * align
        rnd_c = (randint(0, W - align) // align) * align
        rnd_h = min(randint(align, mx_blk_height), H - rnd_r)
        rnd_h = round(rnd_h / align) * align
        rnd_w = min(randint(align, mx_blk_width), W - rnd_c)
        rnd_w = round(rnd_w / align) * align
        if X_rep is not None:
            X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = X_rep[:, rnd_r:rnd_r + rnd_h,
                                                                   rnd_c:rnd_c + rnd_w].detach().clone()
        else:
            if drop_t == 'noise':
                X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = torch.empty((C, rnd_h, rnd_w), dtype=X.dtype,
                                                                             device=X.device).normal_()
            elif drop_t == 'zeros':
                X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = torch.zeros((C, rnd_h, rnd_w), dtype=X.dtype,
                                                                             device=X.device)
            else:
                # Pick a random source block of the same size from X itself.
                rnd_r2 = (randint(0, H - rnd_h) // align) * align
                rnd_c2 = (randint(0, W - rnd_w) // align) * align
                X[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = X[:, rnd_r2:rnd_r2 + rnd_h,
                                                                   rnd_c2:rnd_c2 + rnd_w].detach().clone()
        mask[:, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = 1
    return X, mask
class collate_batch(object):  # replace from other images
    """Collate callable that can re-corrupt both views with cross-image patches.

    Args:
        drop_replace: maximum fraction of each image to replace; 0 disables it.
        drop_align: alignment (in pixels) of the replaced blocks.
    """

    def __init__(self, drop_replace=0., drop_align=1):
        self.drop_replace = drop_replace
        self.drop_align = drop_align

    def __call__(self, batch):
        """Collate ``batch`` with ``default_collate``; optionally re-corrupt it.

        NOTE(review): when ``drop_replace > 0`` this assumes ``collated[0]``
        holds ``[clean, corrupted, masks]``, each indexable by view 0 and 1 —
        confirm against the dataset's ``__getitem__``.
        """
        collated = torch.utils.data.dataloader.default_collate(batch)
        if self.drop_replace > 0:
            views = collated[0]
            for v in (0, 1):
                views[1][v], views[2][v] = GMML_replace_list(
                    views[0][v], views[1][v], views[2][v],
                    max_replace=self.drop_replace, align=self.drop_align)
        return collated
def clip_gradients(model, clip):
    """Clip every parameter gradient independently to an L2 norm of at most ``clip``.

    This is per-tensor clipping, not global-norm clipping: each gradient is
    rescaled in place using only its own norm.

    Args:
        model: module whose parameters' ``.grad`` tensors are clipped.
        clip: maximum allowed L2 norm for each gradient tensor.

    Returns:
        List of pre-clipping gradient norms (floats), one per parameter that
        has a gradient, in ``named_parameters()`` order.
    """
    grad_norms = []
    for _, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        norm = grad.data.norm(2)
        grad_norms.append(norm.item())
        # The epsilon keeps the ratio finite for all-zero gradients.
        scale = clip / (norm + 1e-6)
        if scale < 1:
            grad.data.mul_(scale)
    return grad_norms
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    """Discard gradients of every parameter whose name contains "last_layer".

    Active only while ``epoch < freeze_last_layer``; from then on it does nothing.
    """
    if epoch < freeze_last_layer:
        for param_name, param in model.named_parameters():
            if "last_layer" in param_name:
                param.grad = None
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start from checkpoint.

    Loads ``ckp_path`` onto the CPU (if the file exists) and restores objects
    and run variables from it.

    Args:
        ckp_path: path of a checkpoint saved with ``torch.save`` holding a dict.
        run_variables: optional dict; each of its keys that is also present in
            the checkpoint is overwritten in place with the checkpoint value.
        **kwargs: ``name=obj`` pairs, e.g. ``state_dict=model``.
            ``obj.load_state_dict(checkpoint[name])`` is called for each name
            found in the checkpoint. Objects passed as ``None`` are skipped.

    Returns:
        None. Does nothing when ``ckp_path`` is not an existing file.
    """
    if not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckp_path, map_location="cpu")
    for key, value in kwargs.items():
        if key not in checkpoint:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
            continue
        if value is None:
            # The key is present but the caller gave no object to load it into;
            # it must not be reported as "not found".
            continue
        try:
            # strict=False tolerates missing/unexpected keys.
            msg = value.load_state_dict(checkpoint[key], strict=False)
            print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
        except TypeError:
            # Fallback for load_state_dict methods without a ``strict`` argument
            # (e.g. torch optimizers).
            try:
                msg = value.load_state_dict(checkpoint[key])
                print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
            except ValueError:
                print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
    # Reload variables important for the run (e.g. the epoch counter).
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """Build a per-iteration schedule: linear warmup followed by cosine decay.

    The first ``warmup_epochs * niter_per_ep`` entries rise linearly from
    ``start_warmup_value`` to ``base_value``. The rest follow a half cosine
    from ``base_value`` towards ``final_value``; the phase stops one step
    short of pi, so the last entry does not reach ``final_value`` exactly
    (unless the two values are equal).

    Returns:
        numpy array with exactly ``epochs * niter_per_ep`` entries.
    """
    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep
    warmup_part = (np.linspace(start_warmup_value, base_value, warmup_iters)
                   if warmup_epochs > 0 else np.array([]))
    steps = np.arange(total_iters - warmup_iters)
    phase = np.pi * steps / len(steps)
    decay_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(phase))
    schedule = np.concatenate((warmup_part, decay_part))
    assert len(schedule) == total_iters
    return schedule
def bool_flag(s):
    """
    Parse a boolean command-line value, case-insensitively.

    Accepts "on"/"true"/"1" as True and "off"/"false"/"0" as False.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    lookup = {
        "off": False, "false": False, "0": False,
        "on": True, "true": True, "1": True,
    }
    parsed = lookup.get(s.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError("invalid value for a boolean flag")
    return parsed
def fix_random_seeds(seed=31):
    """
    Fix random seeds for every RNG this module draws from.

    Seeds Python's ``random`` module (used via ``random.random()`` in this
    file), torch (CPU and all CUDA devices) and NumPy's global RNG.

    Args:
        seed: the seed applied to every generator.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
class SmoothedValue(object):
    """Running statistics over a sliding window and over the whole series.

    ``median``, ``avg``, ``max`` and ``value`` look only at the last
    ``window_size`` recorded values; ``global_avg`` uses the weighted running
    ``total``/``count`` of everything recorded.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        # Default format shows the window median and the global average.
        self.fmt = "{median:.6f} ({global_avg:.6f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value``; ``n`` weights it in ``total``/``count`` only."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        reduced_count, reduced_total = stats.tolist()
        self.count = int(reduced_count)
        self.total = reduced_total

    @property
    def median(self):
        """Median of the window (torch's median: lower middle for even sizes)."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the window, computed in float32."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean of every recorded value."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        fields = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**fields)
def reduce_dict(input_dict, average=True):
    """
    All-reduce a dict of tensors across processes.

    Args:
        input_dict (dict): all the values will be reduced; every value must be
            a tensor and all of them stackable together.
        average (bool): whether to average (True) or sum (False).

    Returns:
        ``input_dict`` itself when running with fewer than 2 processes;
        otherwise a new dict with the same keys and the reduced values.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # Sorted keys guarantee the same stacking order on every process.
        keys = sorted(input_dict)
        stacked = torch.stack([input_dict[k] for k in keys], dim=0)
        dist.all_reduce(stacked)
        if average:
            stacked /= world_size
        return dict(zip(keys, stacked))
class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with progress logging.

    Meters are created on first ``update`` (through a ``defaultdict``) and are
    also reachable as attributes, e.g. ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword in the meter of the same name.

        Tensors are converted with ``.item()``; every value must end up an
        ``int`` or ``float``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. ``meters`` is read through
        # ``__dict__`` so that an instance lacking it (e.g. one built via
        # ``__new__`` by copy/pickle) raises AttributeError instead of
        # recursing into ``__getattr__('meters')`` forever.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress every ``print_freq`` steps.

        ``iterable`` must support ``len()``. Each progress line shows the step,
        ETA, all meters, per-iteration and data-wait time and, when CUDA is
        available, peak allocated GPU memory in MB. A total-time summary is
        printed once the iterable is exhausted (also when it is empty).
        """
        i = 0
        if not header:
            header = ''
        n_iters = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.6f}')
        data_time = SmoothedValue(fmt='{avg:.6f}')
        space_fmt = ':' + str(len(str(n_iters))) + 'd'
        use_cuda = torch.cuda.is_available()
        fields = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        if use_cuda:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the next item.
            data_time.update(time.time() - end)
            yield obj
            # Full step time, including the consumer's work on ``obj``.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == n_iters - 1:
                eta_seconds = iter_time.global_avg * (n_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fmt_kwargs = dict(eta=eta_string, meters=str(self),
                                  time=str(iter_time), data=str(data_time))
                if use_cuda:
                    fmt_kwargs['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, n_iters, **fmt_kwargs))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids ZeroDivisionError for an empty iterable.
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / max(n_iters, 1)))
def get_sha():
    """Describe the git state of the directory that contains this file.

    Returns:
        "sha: <commit>, status: <clean|has uncommited changes>, branch: <name>".
        Any git failure is swallowed (best effort); fields not yet computed
        keep their defaults ("N/A", status "clean").
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def git(*cmd_args):
        raw = subprocess.check_output(['git', *cmd_args], cwd=repo_dir)
        return raw.decode('ascii').strip()

    sha, status, branch = 'N/A', "clean", 'N/A'
    try:
        sha = git('rev-parse', 'HEAD')
        # Output discarded; presumably run to refresh git's index before diff-index.
        subprocess.check_output(['git', 'diff'], cwd=repo_dir)
        status = "has uncommited changes" if git('diff-index', 'HEAD') else "clean"
        branch = git('rev-parse', '--abbrev-ref', 'HEAD')
    except Exception:
        pass
    return f"sha: {sha}, status: {status}, branch: {branch}"
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and its default process group is initialized."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Number of distributed processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Rank of this process, or 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
def is_main_process():
    """True on rank 0 (always True when not running distributed)."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """Call ``torch.save(*args, **kwargs)`` on the main process only; no-op elsewhere."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so that only the master process prints.

    Non-master processes can still print by passing ``force=True``; the
    ``force`` keyword is always stripped before the real print is called.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        if kwargs.pop('force', False) or is_master:
            original_print(*args, **kwargs)

    builtins.print = print
def init_distributed_mode(args):
    """Initialize torch.distributed (NCCL backend) from the launch environment.

    Sets ``args.rank``, ``args.gpu``, ``args.world_size`` (where the launch
    mode provides it) and ``args.distributed``. It then initializes the
    process group with ``init_method=args.dist_url``, binds this process to
    ``args.gpu`` and silences ``print`` on non-zero ranks.

    Launch modes, checked in order:
      * ``RANK`` and ``WORLD_SIZE`` env vars (torch.distributed.launch);
        ``LOCAL_RANK`` must also be set.
      * ``SLURM_PROCID`` (submitit / SLURM). ``args.world_size`` is expected to
        be set by the launcher; if the attribute is missing it is taken from
        ``SLURM_NTASKS`` when that variable exists.
      * CUDA available: a single-process group (rank 0, GPU 0) on 127.0.0.1:29500.
      * No CUDA: prints a message and exits with status 1.
    """
    # launched with torch.distributed.launch
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
        # init_process_group below reads args.world_size; fill it from SLURM
        # only when the launcher has not already provided it.
        if not hasattr(args, 'world_size') and 'SLURM_NTASKS' in os.environ:
            args.world_size = int(os.environ['SLURM_NTASKS'])
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
    else:
        print('Does not support training without GPU.')
        sys.exit(1)
    args.distributed = True
    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.barrier()
    setup_for_distributed(args.rank == 0)
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: 2-D scores with classes along dim 1 (one row per sample).
        target: class labels, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        List of 0-dim float tensors, one percentage per entry of ``topk``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)
    # After the transpose, row j holds each sample's j-th highest-scoring class.
    top_classes = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_classes.eq(target.reshape(1, -1).expand_as(top_classes))
    results = []
    for k in topk:
        results.append(hits[:k].reshape(-1).float().sum(0) * 100. / n_samples)
    return results
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place from a normal(mean, std) truncated to [a, b]; return it."""
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
def get_params_groups(model):
    """Split trainable parameters into weight-decayed and non-decayed groups.

    Biases (names ending in ".bias") and every 1-D parameter, such as Norm
    parameters, go into the second group with ``weight_decay=0``. Frozen
    parameters are left out entirely.

    Returns:
        ``[{'params': decayed}, {'params': not_decayed, 'weight_decay': 0.}]``
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        exempt = name.endswith(".bias") or len(param.shape) == 1
        (no_decay if exempt else decay).append(param)
    return [{'params': decay}, {'params': no_decay, 'weight_decay': 0.}]