|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import sys |
|
from typing import Iterable |
|
|
|
import numpy as np |
|
import torch |
|
import torchvision |
|
from utils import misc as misc |
|
|
|
|
|
def split_prediction_conf(predictions, with_conf=False):
    """Split a network output into (prediction, confidence).

    When ``with_conf`` is true, the last channel of ``predictions`` (shape
    B x C x H x W) is interpreted as a confidence map and returned separately;
    otherwise the input passes through unchanged with ``None`` confidence.
    """
    if with_conf:
        # Last channel is confidence; everything before it is the prediction.
        return predictions[:, :-1, :, :], predictions[:, -1:, :, :]
    return predictions, None
|
|
|
|
|
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    metrics: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    print_freq=20,
    args=None,
):
    """Run one training epoch over ``data_loader``.

    For each batch: moves the two input images and the ground truth to
    ``device``, runs the model under (optional) AMP autocast, splits off the
    trailing confidence channel when ``criterion.with_conf`` is set, computes
    the loss (confidence-weighted when a confidence is present) and
    backpropagates through ``loss_scaler`` with gradient accumulation over
    ``args.accum_iter`` steps. Aborts the process on a non-finite loss.

    Args:
        model: network called as ``model(image1, image2)``.
        criterion: loss module; must expose a ``with_conf`` attribute.
        metrics: module returning a dict of scalar tensors from
            ``(prediction, gt)``.
        data_loader: yields ``(image1, image2, gt, pairname)`` tuples.
        optimizer: optimizer; first param group's ``lr`` is logged.
        device: target device for batch tensors.
        epoch: current epoch index (used by the LR schedule and logging).
        loss_scaler: AMP-style scaler, invoked as
            ``loss_scaler(loss, optimizer, parameters=..., update_grad=...)``.
        log_writer: optional TensorBoard-like writer (``add_scalar``).
        print_freq: console-logging frequency (in iterations).
        args: namespace providing ``accum_iter``, ``amp``, ``batch_size``,
            ``img_per_epoch`` and ``tboard_log_step``.

    Returns:
        Dict mapping meter names to their global averages over the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    details = {}  # NOTE(review): unused — candidate for removal

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    if args.img_per_epoch:
        # Cap the epoch at ceil(img_per_epoch / batch_size) iterations.
        iter_per_epoch = args.img_per_epoch // args.batch_size + int(
            args.img_per_epoch % args.batch_size > 0
        )
        assert (
            len(data_loader) >= iter_per_epoch
        ), "Dataset is too small for so many iterations"
        len_data_loader = iter_per_epoch
    else:
        # None lets the metric logger iterate the full loader.
        len_data_loader, iter_per_epoch = len(data_loader), None

    for data_iter_step, (image1, image2, gt, pairname) in enumerate(
        metric_logger.log_every(
            data_loader, print_freq, header, max_iter=iter_per_epoch
        )
    ):
        image1 = image1.to(device, non_blocking=True)
        image2 = image2.to(device, non_blocking=True)
        gt = gt.to(device, non_blocking=True)

        # Per-iteration LR schedule; stepped once per accumulation cycle,
        # using a fractional epoch as the schedule position.
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(
                optimizer, data_iter_step / len_data_loader + epoch, args
            )

        with torch.cuda.amp.autocast(enabled=bool(args.amp)):
            prediction = model(image1, image2)
            prediction, conf = split_prediction_conf(prediction, criterion.with_conf)
            # Metrics are computed on a detached copy so they don't enter the graph.
            batch_metrics = metrics(prediction.detach(), gt)
            loss = (
                criterion(prediction, gt)
                if conf is None
                else criterion(prediction, gt, conf)
            )

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            # Diverged (NaN/inf) — stop rather than corrupt the weights.
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale down so accumulated gradients average over accum_iter batches.
        loss /= accum_iter
        loss_scaler(
            loss,
            optimizer,
            parameters=model.parameters(),
            update_grad=(data_iter_step + 1) % accum_iter == 0,
        )
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        for k, v in batch_metrics.items():
            metric_logger.update(**{k: v.item()})
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        # Log every tboard_log_step optimizer steps, and on the last iteration.
        time_to_log = (data_iter_step + 1) % (
            args.tboard_log_step * accum_iter
        ) == 0 or data_iter_step == len_data_loader - 1
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and time_to_log:
            # x-axis in thousandths of an epoch so curves align across batch sizes.
            epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000)

            log_writer.add_scalar("train/loss", loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)
            for k, v in batch_metrics.items():
                log_writer.add_scalar("train/" + k, v.item(), epoch_1000x)

    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
|
|
|
|
@torch.no_grad()
def validate_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    metrics: torch.nn.Module,
    data_loaders: list[Iterable],
    device: torch.device,
    epoch: int,
    log_writer=None,
    args=None,
):
    """Evaluate ``model`` on every validation loader using tiled inference.

    For each loader: batches are moved to ``device``; for datasets whose name
    starts with "Spring" the ground truth (stored at twice the image
    resolution) is 2x2 average-pooled back to image resolution; predictions
    come from :func:`tiled_pred` and both the per-tile loss and the
    full-image loss/metrics are accumulated per dataset.

    Args:
        model: network evaluated through ``tiled_pred``.
        criterion: loss module; must expose a ``with_conf`` attribute.
        metrics: module returning a dict of scalar tensors from
            ``(prediction, gt)``.
        data_loaders: one loader per validation dataset; ``str(loader.dataset)``
            is used as the dataset name prefix for metric keys.
        device: target device for batch tensors.
        epoch: current epoch index (used only for logging).
        log_writer: optional TensorBoard-like writer (``add_scalar``).
        args: namespace providing ``tile_conf_mode``, ``crop`` and
            ``val_overlap``.

    Returns:
        Dict of averaged stats; per-dataset keys plus, when several datasets
        are evaluated, ``"AVG_<metric>"`` means across datasets.
    """
    model.eval()
    metric_loggers = []
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    conf_mode = args.tile_conf_mode
    crop = args.crop

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    results = {}
    dnames = []
    image1, image2, gt, prediction = None, None, None, None
    for didx, data_loader in enumerate(data_loaders):
        dname = str(data_loader.dataset)
        dnames.append(dname)
        metric_loggers.append(misc.MetricLogger(delimiter="  "))
        for data_iter_step, (image1, image2, gt, pairname) in enumerate(
            metric_loggers[didx].log_every(data_loader, print_freq, header)
        ):
            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            gt = gt.to(device, non_blocking=True)
            if dname.startswith("Spring"):
                # Spring GT comes at 2x the image resolution; average each
                # 2x2 block down to the prediction resolution.
                assert (
                    gt.size(2) == image1.size(2) * 2
                    and gt.size(3) == image1.size(3) * 2
                )
                gt = (
                    gt[:, :, 0::2, 0::2]
                    + gt[:, :, 0::2, 1::2]
                    + gt[:, :, 1::2, 0::2]
                    + gt[:, :, 1::2, 1::2]
                ) / 4.0

            with torch.inference_mode():
                prediction, tiled_loss, c = tiled_pred(
                    model,
                    criterion,
                    image1,
                    image2,
                    gt,
                    conf_mode=conf_mode,
                    overlap=args.val_overlap,
                    crop=crop,
                    with_conf=criterion.with_conf,
                )
            batch_metrics = metrics(prediction.detach(), gt)
            loss = (
                criterion(prediction.detach(), gt)
                if not criterion.with_conf
                else criterion(prediction.detach(), gt, c)
            )
            loss_value = loss.item()
            metric_loggers[didx].update(loss_tiled=tiled_loss.item())
            # FIX: was update(**{f"loss": loss_value}) — f-string without
            # placeholders; a plain keyword argument is equivalent.
            metric_loggers[didx].update(loss=loss_value)
            for k, v in batch_metrics.items():
                metric_loggers[didx].update(**{dname + "_" + k: v.item()})

    results = {
        k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()
    }
    if len(dnames) > 1:
        # Cross-dataset averages, keyed by the metric names of the last batch.
        for k in batch_metrics.keys():
            results["AVG_" + k] = sum(
                results[dname + "_" + k] for dname in dnames
            ) / len(dnames)

    if log_writer is not None:
        # Validation runs once per epoch; log at the epoch boundary
        # on the same thousandths-of-an-epoch axis used during training.
        epoch_1000x = int((1 + epoch) * 1000)
        for k, v in results.items():
            log_writer.add_scalar("val/" + k, v, epoch_1000x)

    print("Averaged stats:", results)
    return results
|
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
def _resize_img(img, new_size): |
|
return F.interpolate(img, size=new_size, mode="bicubic", align_corners=False) |
|
|
|
|
|
def _resize_stereo_or_flow(data, new_size): |
|
assert data.ndim == 4 |
|
assert data.size(1) in [1, 2] |
|
scale_x = new_size[1] / float(data.size(3)) |
|
out = F.interpolate(data, size=new_size, mode="bicubic", align_corners=False) |
|
out[:, 0, :, :] *= scale_x |
|
if out.size(1) == 2: |
|
scale_y = new_size[0] / float(data.size(2)) |
|
out[:, 1, :, :] *= scale_y |
|
print(scale_x, new_size, data.shape) |
|
return out |
|
|
|
|
|
@torch.no_grad()
def tiled_pred(
    model,
    criterion,
    img1,
    img2,
    gt,
    overlap=0.5,
    bad_crop_thr=0.05,
    downscale=False,
    crop=512,
    ret="loss",
    conf_mode="conf_expsigmoid_10_5",
    with_conf=False,
    return_time=False,
):
    """Confidence-weighted tiled (sliding-window) prediction over a full image pair.

    The pair is processed in overlapping crops of size ``crop`` (height, width)
    and the per-tile predictions are blended into a full-resolution output
    using weights derived from the predicted confidence channel.

    Args:
        model: network called as ``model(crop1, crop2)``.
        criterion: optional loss; when given (and ``gt`` is not None) the
            per-tile losses are averaged into the returned ``tiled_loss``.
        img1, img2: (B, 3, H, W) input images.
        gt: optional (B, C, H, W) ground truth; also fixes the output channel
            count C (otherwise taken from ``model.head.num_channels``).
        overlap: fractional overlap between adjacent tiles, in [0, 1).
        bad_crop_thr, downscale, ret: unused; kept for interface compatibility.
        crop: ``(win_height, win_width)`` tile size.
        conf_mode: confidence-to-weight mapping, either
            ``"conf_expsigmoid_<beta>_<betasigmoid>"`` or ``"conf_expbeta<beta>"``.
        with_conf: whether the model output carries a trailing confidence channel.
        return_time: additionally return the CUDA wall time in seconds.

    Returns:
        ``(pred, tiled_loss, c)``: blended prediction, mean per-tile loss
        (NaN when no loss was computed) and blended raw confidence — plus the
        elapsed time when ``return_time`` is set.
    """
    # Output geometry: from gt when available, else from the model head.
    if gt is not None:
        B, C, H, W = gt.shape
    else:
        B, _, H, W = img1.shape
        C = model.head.num_channels - int(with_conf)
    win_height, win_width = crop[0], crop[1]

    # If the image is smaller than the tile in either dimension, upscale it so
    # a full tile fits; the prediction is resized back at the end.
    do_change_scale = H < win_height or W < win_width
    if do_change_scale:
        # BUGFIX: the height ratio was computed as win_height / W (dividing by
        # the width twice), under-scaling images that are short but wide.
        upscale_factor = max(win_width / W, win_height / H)
        original_size = (H, W)
        new_size = (round(H * upscale_factor), round(W * upscale_factor))
        img1 = _resize_img(img1, new_size)
        img2 = _resize_img(img2, new_size)

        if gt is not None:
            gt = _resize_stereo_or_flow(gt, new_size)
        H, W = img1.shape[2:4]

    # Parse the confidence-weighting hyper-parameters from the mode string.
    if conf_mode.startswith("conf_expsigmoid_"):
        beta, betasigmoid = map(float, conf_mode[len("conf_expsigmoid_") :].split("_"))
    elif conf_mode.startswith("conf_expbeta"):
        beta = float(conf_mode[len("conf_expbeta") :])
    else:
        raise NotImplementedError(f"conf_mode {conf_mode} is not implemented")

    def crop_generator():
        # Identical (aligned) tilings for both images.
        for sy in _overlapping(H, win_height, overlap):
            for sx in _overlapping(W, win_width, overlap):
                yield sy, sx, sy, sx, True

    # Accumulators for the weighted blend; the epsilon avoids division by zero.
    accu_pred = img1.new_zeros((B, C, H, W))
    accu_conf = img1.new_zeros((B, H, W)) + 1e-16
    accu_c = img1.new_zeros((B, H, W))

    tiled_losses = []

    if return_time:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

    for sy1, sx1, sy2, sx2, aligned in crop_generator():
        pred = model(_crop(img1, sy1, sx1), _crop(img2, sy2, sx2))
        pred, predconf = split_prediction_conf(pred, with_conf=with_conf)

        if gt is not None:
            gtcrop = _crop(gt, sy1, sx1)
        if criterion is not None and gt is not None:
            tiled_losses.append(
                criterion(pred, gtcrop).item()
                if predconf is None
                else criterion(pred, gtcrop, predconf).item()
            )

        if predconf is None:
            # ROBUSTNESS: without a confidence channel, weight tiles uniformly
            # (this path previously crashed on predconf=None).
            conf = pred.new_ones((B, win_height, win_width))
        elif conf_mode.startswith("conf_expsigmoid_"):
            conf = torch.exp(
                -beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5)
            ).view(B, win_height, win_width)
        elif conf_mode.startswith("conf_expbeta"):
            conf = torch.exp(-beta * predconf).view(B, win_height, win_width)
        else:
            raise NotImplementedError

        accu_pred[..., sy1, sx1] += pred * conf[:, None, :, :]
        accu_conf[..., sy1, sx1] += conf
        if predconf is not None:
            accu_c[..., sy1, sx1] += predconf.view(B, win_height, win_width) * conf

    # Normalize the weighted sums into the final blend.
    pred = accu_pred / accu_conf[:, None, :, :]
    c = accu_c / accu_conf
    assert not torch.any(torch.isnan(pred))

    if return_time:
        end.record()
        torch.cuda.synchronize()
        time = start.elapsed_time(end) / 1000.0  # elapsed_time is in ms

    if do_change_scale:
        pred = _resize_stereo_or_flow(pred, original_size)

    if return_time:
        return pred, torch.mean(torch.tensor(tiled_losses)), c, time
    return pred, torch.mean(torch.tensor(tiled_losses)), c
|
|
|
|
|
def _overlapping(total, window, overlap=0.5): |
|
assert total >= window and 0 <= overlap < 1, (total, window, overlap) |
|
num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window))) |
|
offsets = np.linspace(0, total - window, num_windows).round().astype(int) |
|
yield from (slice(x, x + window) for x in offsets) |
|
|
|
|
|
def _crop(img, sy, sx): |
|
B, THREE, H, W = img.shape |
|
if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W: |
|
return img[:, :, sy, sx] |
|
l, r = max(0, -sx.start), max(0, sx.stop - W) |
|
t, b = max(0, -sy.start), max(0, sy.stop - H) |
|
img = torch.nn.functional.pad(img, (l, r, t, b), mode="constant") |
|
return img[:, :, slice(sy.start + t, sy.stop + t), slice(sx.start + l, sx.stop + l)] |
|
|