|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import pickle |
|
|
|
import numpy as np |
|
import torch |
|
import utils.misc as misc |
|
from models.croco_downstream import CroCoDownstreamBinocular |
|
from models.head_downstream import PixelwiseTaskWithDPT |
|
from PIL import Image |
|
from stereoflow.criterion import * |
|
from stereoflow.datasets_flow import flowToColor, get_test_datasets_flow |
|
from stereoflow.datasets_stereo import get_test_datasets_stereo, vis_disparity |
|
from stereoflow.engine import tiled_pred |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
|
|
def get_args_parser():
    """Build the command-line parser of the stereo/flow evaluation script.

    The parser is created with ``add_help=False``, so it defines no
    ``-h/--help`` option itself; this lets it be used as a parent parser
    (``argparse.ArgumentParser(parents=[get_args_parser()])``).

    Returns:
        argparse.ArgumentParser: parser defining ``--model`` and ``--dataset``
        (both required), ``--tile_conf_mode``, ``--tile_overlap``, ``--save``
        and ``--num_workers``.
    """
    # The first positional parameter of ArgumentParser is ``prog``; the
    # sentence is a description, so pass it by keyword to keep usage and error
    # messages prefixed with the real program name.
    parser = argparse.ArgumentParser(
        description="Test CroCo models on stereo/flow", add_help=False
    )

    # What to evaluate.
    parser.add_argument(
        "--model", required=True, type=str, help="Path to the model to evaluate"
    )
    parser.add_argument(
        "--dataset",
        required=True,
        type=str,
        help="test dataset (there can be multiple dataset separated by a +)",
    )

    # Tiling parameters.
    parser.add_argument(
        "--tile_conf_mode",
        type=str,
        default="",
        help=(
            "Weights for the tiling aggregation based on confidence "
            "(empty means use the formula from the loaded checkpoint)"
        ),
    )
    parser.add_argument(
        "--tile_overlap", type=float, default=0.7, help="overlap between tiles"
    )

    # Outputs. Adjacent string literals are used instead of backslash
    # continuations so that no source indentation leaks into the help text.
    parser.add_argument(
        "--save",
        type=str,
        nargs="+",
        default=[],
        help=(
            "what to save: "
            "metrics (pickle file), "
            "pred (raw prediction save as torch tensor), "
            "visu (visualization in png of each prediction), "
            "err10 (visualization in png of the error clamp at 10 for each prediction), "
            "submission (submission file)"
        ),
    )

    parser.add_argument("--num_workers", default=4, type=int)
    return parser
|
|
|
|
|
def _load_model_and_criterion(model_path, do_load_metrics, device):
    """Load a checkpoint and rebuild the model, optionally with its metrics.

    Args:
        model_path: path of the checkpoint file. The loaded object is indexed
            with the keys ``"args"`` and ``"model"``.
        do_load_metrics: when true, also instantiate the metrics object that
            matches the checkpoint task.
        device: device the model (and the metrics, if any) are moved to.

    Returns:
        tuple: ``(model, metrics, crop, with_conf, task, tile_conf_mode)``.
        ``metrics`` is ``None`` when ``do_load_metrics`` is false; ``crop``,
        ``task`` and ``tile_conf_mode`` are read from the checkpoint ``args``.

    Raises:
        FileNotFoundError: if ``model_path`` is not an existing file.
        KeyError: if the checkpoint task is neither ``"stereo"`` nor ``"flow"``.
    """
    print("loading model from", model_path)
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"checkpoint file not found: {model_path}")
    ckpt = torch.load(model_path, map_location="cpu")

    ckpt_args = ckpt["args"]
    task = ckpt_args.task
    tile_conf_mode = ckpt_args.tile_conf_mode

    # One output channel for stereo, two for flow, plus one extra channel when
    # the training criterion reports ``with_conf``.
    num_channels = {"stereo": 1, "flow": 2}[task]
    # NOTE(security): the criterion string stored in the checkpoint is
    # evaluated as a Python expression in this module's namespace. Only load
    # checkpoints from trusted sources.
    with_conf = eval(ckpt_args.criterion).with_conf
    if with_conf:
        num_channels += 1

    print("head: PixelwiseTaskWithDPT()")
    head = PixelwiseTaskWithDPT()
    head.num_channels = num_channels

    print("croco_args:", ckpt_args.croco_args)
    model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args)
    # strict=True: any key mismatch raises, so the returned report is not kept.
    model.load_state_dict(ckpt["model"], strict=True)
    model.eval()
    model = model.to(device)

    metrics = None
    if do_load_metrics:
        metrics_cls = StereoDatasetMetrics if task == "stereo" else FlowDatasetMetrics
        metrics = metrics_cls().to(device)

    return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode
|
|
|
|
|
def _save_batch(
    pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None
):
    """Write the requested outputs for every sample of a batch.

    Args:
        pred: batched predictions, indexed by sample along the first axis.
            Each sample is permuted as a 3-D tensor (presumably
            channels x height x width -- confirm against ``tiled_pred``).
        gt: batched ground truth laid out like ``pred``, or ``None`` when no
            ground truth is available.
        pairnames: one name per sample. A name starting with ``"("`` is
            evaluated back into a Python object before use.
        dataset: object providing ``pairname_to_str`` and, for submissions,
            ``submission_save_pairname``.
        task: ``"stereo"`` selects the disparity visualization; any other
            value selects the flow visualization.
        save: container of output kinds; ``"pred"``, ``"visu"``, ``"err10"``
            and ``"submission"`` are handled here.
        outdir: root directory of the written files.
        time: value forwarded to ``dataset.submission_save_pairname``.
        submission_dir: directory forwarded to
            ``dataset.submission_save_pairname``; required for ``"submission"``.

    Raises:
        ValueError: if ``"err10"`` is requested without ground truth, or
            ``"submission"`` is requested without ``submission_dir``.
    """
    for i, raw_pairname in enumerate(pairnames):
        # NOTE(security): eval() rebuilds pair names that were turned into
        # strings (those starting with "("). The names come from the dataset
        # object; do not feed untrusted strings through this path.
        pairname = eval(raw_pairname) if raw_pairname.startswith("(") else raw_pairname
        fname = os.path.join(outdir, dataset.pairname_to_str(pairname))
        os.makedirs(os.path.dirname(fname), exist_ok=True)

        predi = pred[i, ...]
        gti = gt[i, ...] if gt is not None else None

        if "pred" in save:
            # Raw prediction of this sample.
            torch.save(predi.squeeze(0).cpu(), fname + "_pred.pth")

        if "visu" in save:
            if task == "stereo":
                disparity = predi.permute((1, 2, 0)).squeeze(2).cpu().numpy()
                # Display range: taken from the finite ground-truth values of
                # this sample when there are any, otherwise left to
                # vis_disparity (m=None, M=None).
                m = M = None
                if gti is not None:
                    finite_gt = gti[torch.isfinite(gti)]
                    if finite_gt.numel() > 0:
                        # Plain Python floats, like flowNorm below, so the
                        # bounds do not depend on the tensor's device.
                        m = finite_gt.min().item()
                        M = finite_gt.max().item()
                img_disparity = vis_disparity(disparity, m=m, M=M)
                Image.fromarray(img_disparity).save(fname + "_pred.png")
            else:
                # Scale the color coding by the largest vector norm of the
                # ground truth, or of the prediction when there is none.
                reference = gti if gti is not None else predi
                flowNorm = torch.sqrt(torch.sum(reference**2, dim=0)).max().item()
                imgflow = flowToColor(
                    predi.permute((1, 2, 0)).cpu().numpy(), maxflow=flowNorm
                )
                Image.fromarray(imgflow).save(fname + "_pred.png")

        if "err10" in save:
            if gti is None:
                raise ValueError("'err10' requires ground truth, but gt is None")
            # Per-pixel L2 error, zeroed where the first ground-truth channel
            # is not finite, clamped at 10 and written to the red channel.
            L2err = torch.sqrt(torch.sum((gti - predi) ** 2, dim=0))
            valid = torch.isfinite(gti[0, :, :])
            L2err[~valid] = 0.0
            L2err = torch.clamp(L2err, max=10.0)
            red = (L2err * 255.0 / 10.0).to(dtype=torch.uint8)[:, :, None]
            zer = torch.zeros_like(red)
            imgerr = torch.cat((red, zer, zer), dim=2).cpu().numpy()
            Image.fromarray(imgerr).save(fname + "_err10.png")

        if "submission" in save:
            if submission_dir is None:
                raise ValueError("'submission' requires submission_dir to be set")
            # Channel-last array; a singleton last axis is dropped.
            predi_np = predi.permute(1, 2, 0).squeeze(2).cpu().numpy()
            dataset.submission_save_pairname(pairname, predi_np, submission_dir, time)
|
|
|
|
|
def main(args):
    """Evaluate a checkpoint on every requested test dataset.

    For each dataset named in ``args.dataset`` (names separated by ``+``) the
    model is run on all samples with tiled prediction, and the outputs listed
    in ``args.save`` are produced under ``<args.model>_<dataset name>``.

    Args:
        args: parsed command-line namespace. The attributes ``model``,
            ``dataset``, ``save``, ``tile_conf_mode``, ``tile_overlap`` and
            ``num_workers`` are read. ``tile_conf_mode`` is overwritten with
            the value stored in the checkpoint when it is the empty string.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model, metrics, cropsize, with_conf, task, tile_conf_mode = (
        _load_model_and_criterion(args.model, "metrics" in args.save, device)
    )
    if args.tile_conf_mode == "":
        args.tile_conf_mode = tile_conf_mode

    # One dataset (and one batch-size-1 loader) per '+'-separated name.
    get_test_datasets = (
        get_test_datasets_stereo if task == "stereo" else get_test_datasets_flow
    )
    datasets = get_test_datasets(args.dataset)
    dataloaders = [
        DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )
        for dataset in datasets
    ]
    dataset_names = args.dataset.split("+")

    # Loop-invariant views of what has to be written.
    save_metrics = "metrics" in args.save
    save_submission = "submission" in args.save
    save_per_sample = any(
        k in args.save for k in ("pred", "visu", "err10", "submission")
    )

    for i, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
        dstr = dataset_names[i]
        outdir = args.model + "_" + misc.filename(dstr)

        # The metrics file name is needed whenever metrics are saved, not only
        # when they are the sole output; otherwise the dump after the
        # evaluation loop would reference an undefined name.
        metrics_fname = None
        if save_metrics:
            metrics_fname = os.path.join(
                outdir, f"conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl"
            )
            # Cached results are reused only when nothing else has to be
            # produced for this dataset.
            if len(args.save) == 1 and os.path.isfile(metrics_fname):
                print(" metrics already compute in " + metrics_fname)
                with open(metrics_fname, "rb") as fid:
                    results = pickle.load(fid)
                for k, v in results.items():
                    print("{:s}: {:.3f}".format(k, v))
                continue

        if save_submission:
            dirname = (
                f"submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}"
            )
            submission_dir = os.path.join(outdir, dirname)
        else:
            submission_dir = None

        print("")
        print("saving {:s} in {:s}".format("+".join(args.save), outdir))
        print(repr(dataset))

        if metrics is not None:
            metrics.reset()

        for image1, image2, gt, pairnames in tqdm(dataloader):
            # Stereo pairs of a "Spring*" dataset whose name contains "right"
            # are mirrored horizontally before inference.
            do_flip = (
                task == "stereo"
                and dstr.startswith("Spring")
                and any("right" in p for p in pairnames)
            )

            image1 = image1.to(device, non_blocking=True)
            image2 = image2.to(device, non_blocking=True)
            # An empty tensor stands for "no ground truth".
            gt = gt.to(device, non_blocking=True) if gt.numel() > 0 else None

            if do_flip:
                assert all("right" in p for p in pairnames)
                # Only the inputs are flipped; the prediction is flipped back
                # below and gt is left exactly as loaded.
                image1 = image1.flip(dims=[3])
                image2 = image2.flip(dims=[3])

            with torch.inference_mode():
                pred, _, _, time = tiled_pred(
                    model,
                    None,
                    image1,
                    image2,
                    None if dataset.name == "Spring" else gt,
                    conf_mode=args.tile_conf_mode,
                    overlap=args.tile_overlap,
                    crop=cropsize,
                    with_conf=with_conf,
                    return_time=True,
                )

            if do_flip:
                pred = pred.flip(dims=[3])

            if metrics is not None:
                metrics.add_batch(pred, gt)

            if save_per_sample:
                _save_batch(
                    pred,
                    gt,
                    pairnames,
                    dataset,
                    task,
                    args.save,
                    outdir,
                    time,
                    submission_dir=submission_dir,
                )

        # ``metrics`` is created exactly when "metrics" is in args.save, so
        # printing and saving the results share one condition.
        if metrics is not None:
            results = metrics.get_results()
            for k, v in results.items():
                print("{:s}: {:.3f}".format(k, v))

            if save_metrics:
                os.makedirs(os.path.dirname(metrics_fname), exist_ok=True)
                with open(metrics_fname, "wb") as fid:
                    pickle.dump(results, fid)
                print("metrics saved in", metrics_fname)

        if save_submission:
            dataset.finalize_submission(submission_dir)
|
|
|
|
|
if __name__ == "__main__":
    # get_args_parser() returns a parser built with add_help=False, so using
    # it directly leaves the script without -h/--help. Wrapping it as a parent
    # parser keeps every option and adds the standard help flag.
    parser = argparse.ArgumentParser(
        description="Test CroCo models on stereo/flow",
        parents=[get_args_parser()],
    )
    main(parser.parse_args())
|
|