|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
import os.path as osp |
|
import pickle |
|
import sys |
|
from glob import glob |
|
|
|
import cv2 |
|
import h5py |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torch.utils import data |
|
|
|
from .augmentor import StereoAugmentor |
|
|
|
# Mapping from dataset name to the root directory holding its raw data.
# NOTE(review): a few paths contain a doubled slash ("stereoflow//...");
# harmless for the OS, but worth normalizing upstream.
dataset_to_root = {
    "CREStereo": "./data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/",
    "SceneFlow": "./data/stereoflow//SceneFlow/",
    "ETH3DLowRes": "./data/stereoflow/eth3d_lowres/",
    "Booster": "./data/stereoflow/booster_gt/",
    "Middlebury2021": "./data/stereoflow/middlebury/2021/data/",
    "Middlebury2014": "./data/stereoflow/middlebury/2014/",
    "Middlebury2006": "./data/stereoflow/middlebury/2006/",
    "Middlebury2005": "./data/stereoflow/middlebury/2005/train/",
    "MiddleburyEval3": "./data/stereoflow/middlebury/MiddEval3/",
    "Spring": "./data/stereoflow/spring/",
    "Kitti15": "./data/stereoflow/kitti-stereo-2015/",
    "Kitti12": "./data/stereoflow/kitti-stereo-2012/",
}

# Directory where each dataset's split -> pairname lists are pickled
# (see StereoDataset._load_or_build_cache).
cache_dir = "./data/stereoflow/datasets_stereo_cache/"
|
|
|
|
|
# ImageNet-1k channel statistics, shaped (3, 1, 1) so they broadcast over H x W.
in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def img_to_tensor(img):
    """Convert an HxWx3 uint8 image array to an ImageNet-normalized CxHxW float tensor."""
    chw = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    return (chw - in1k_mean) / in1k_std
|
|
|
|
|
def disp_to_tensor(disp):
    """Return the HxW numpy disparity map as a 1xHxW torch tensor."""
    return torch.from_numpy(disp).unsqueeze(0)
|
|
|
|
|
class StereoDataset(data.Dataset):
    """Base class for stereo-pair datasets.

    Subclasses must implement:
      - ``_prepare_data``: set ``self.name``, the ``pairname_to_*`` accessors
        and ``self.load_disparity``;
      - ``_build_cache``: return the split -> list-of-pairnames mapping, which
        is pickled under ``cache_dir`` on first use.

    ``__getitem__`` returns ``(Limg, Rimg, disp, pairname)``; ``disp`` is an
    empty tensor when no ground truth exists for the pair.
    """

    def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
        self.split = split
        # crop_size only makes sense together with augmentation
        if not augmentor:
            assert crop_size is None
        if crop_size:
            assert augmentor
        self.crop_size = crop_size
        self.augmentor_str = augmentor
        self.augmentor = StereoAugmentor(crop_size) if augmentor else None
        self.totensor = totensor
        self.rmul = 1  # replication factor applied via __rmul__
        self.has_constant_resolution = True  # whether all pairs share one resolution
        self._prepare_data()
        self._load_or_build_cache()

    def _prepare_data(self):
        """
        to be defined for each dataset
        """
        raise NotImplementedError

    # Backward-compatible alias: the stub used to be exposed as `prepare_data`,
    # although __init__ and every subclass use `_prepare_data`.
    prepare_data = _prepare_data

    def __len__(self):
        return len(self.pairnames)

    def __getitem__(self, index):
        pairname = self.pairnames[index]

        # resolve the filenames for this pair
        Limgname = self.pairname_to_Limgname(pairname)
        Rimgname = self.pairname_to_Rimgname(pairname)
        Ldispname = (
            self.pairname_to_Ldispname(pairname)
            if self.pairname_to_Ldispname is not None
            else None
        )

        # load images, and the left disparity when ground truth exists
        Limg = _read_img(Limgname)
        Rimg = _read_img(Rimgname)
        disp = self.load_disparity(Ldispname) if Ldispname is not None else None

        # sanity check: disparities are strictly positive (invalid pixels are
        # +inf); Spring is exempt because it contains negative disparities
        if disp is not None:
            assert np.all(disp > 0) or self.name == "Spring", (
                self.name,
                pairname,
                Ldispname,
            )

        if self.augmentor is not None:
            Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name)

        if self.totensor:
            Limg = img_to_tensor(Limg)
            Rimg = img_to_tensor(Rimg)
            if disp is None:
                # empty tensor so the default collate function can batch it
                disp = torch.tensor([])
            else:
                disp = disp_to_tensor(disp)

        return Limg, Rimg, disp, str(pairname)

    def __rmul__(self, v):
        """Replicate the pair list v times (``3 * dataset``), in place."""
        self.rmul *= v
        self.pairnames = v * self.pairnames
        return self

    def __str__(self):
        return f"{self.__class__.__name__}_{self.split}"

    def __repr__(self):
        s = f"{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})"
        if self.rmul == 1:
            s += f"\n\tnum pairs: {len(self.pairnames)}"
        else:
            s += f"\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})"
        return s

    def _set_root(self):
        """Resolve self.root from dataset_to_root and check it exists on disk."""
        self.root = dataset_to_root[self.name]
        assert os.path.isdir(
            self.root
        ), f"could not find root directory for dataset {self.name}: {self.root}"

    def _load_or_build_cache(self):
        """Load the cached pairname lists, building and pickling them on first use."""
        cache_file = osp.join(cache_dir, self.name + ".pkl")
        if osp.isfile(cache_file):
            with open(cache_file, "rb") as fid:
                self.pairnames = pickle.load(fid)[self.split]
        else:
            tosave = self._build_cache()
            os.makedirs(cache_dir, exist_ok=True)
            with open(cache_file, "wb") as fid:
                pickle.dump(tosave, fid)
            self.pairnames = tosave[self.split]
|
|
|
|
|
class CREStereoDataset(StereoDataset):
    """CREStereo synthetic training set (200k pairs, jpg images + disparity png)."""

    def _prepare_data(self):
        self.name = "CREStereo"
        self._set_root()
        assert self.split in ["train"]
        # all three files share the pairname stem: <seq>/<id>_left.jpg,
        # <seq>/<id>_right.jpg and <seq>/<id>_left.disp.png
        self.pairname_to_Limgname = lambda pairname: osp.join(
            self.root, pairname + "_left.jpg"
        )
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname + "_right.jpg"
        )
        self.pairname_to_Ldispname = lambda pairname: osp.join(
            self.root, pairname + "_left.disp.png"
        )
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_crestereo_disp

    def _build_cache(self):
        # one pairname "<subdir>/<stem>" per *_left.jpg found on disk
        allpairs = []
        for sub in sorted(os.listdir(self.root)):
            for fname in sorted(os.listdir(self.root + "/" + sub)):
                if fname.endswith("_left.jpg"):
                    allpairs.append(sub + "/" + fname[: -len("_left.jpg")])
        assert len(allpairs) == 200000, "incorrect parsing of pairs in CreStereo"
        return {"train": allpairs}
|
|
|
|
|
class SceneFlowDataset(StereoDataset):
    """SceneFlow synthetic dataset (Driving + Monkaa + FlyingThings).

    A pairname is the left-image path relative to the root; the split name
    selects the rendering pass (finalpass / cleanpass / allpass for both).
    """

    def _prepare_data(self):
        self.name = "SceneFlow"
        self._set_root()
        assert self.split in [
            "train_finalpass",
            "train_cleanpass",
            "train_allpass",
            "test_finalpass",
            "test_cleanpass",
            "test_allpass",
            "test1of100_cleanpass",
            "test1of100_finalpass",
        ]
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        # right image lives in the sibling "right" directory
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname
        ).replace("/left/", "/right/")
        # disparity mirrors the image tree, with a .pfm extension
        self.pairname_to_Ldispname = (
            lambda pairname: osp.join(self.root, pairname)
            .replace("/frames_finalpass/", "/disparity/")
            .replace("/frames_cleanpass/", "/disparity/")[:-4]
            + ".pfm"
        )
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_sceneflow_disp

    def _build_cache(self):
        trainpairs = []

        # Driving subset
        pairs = sorted(glob(self.root + "Driving/frames_finalpass/*/*/*/left/*.png"))
        pairs = list(map(lambda x: x[len(self.root) :], pairs))
        assert len(pairs) == 4400, "incorrect parsing of pairs in SceneFlow"
        trainpairs += pairs

        # Monkaa subset
        pairs = sorted(glob(self.root + "Monkaa/frames_finalpass/*/left/*.png"))
        pairs = list(map(lambda x: x[len(self.root) :], pairs))
        assert len(pairs) == 8664, "incorrect parsing of pairs in SceneFlow"
        trainpairs += pairs

        # FlyingThings TRAIN subset
        pairs = sorted(
            glob(self.root + "FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png")
        )
        pairs = list(map(lambda x: x[len(self.root) :], pairs))
        assert len(pairs) == 22390, "incorrect parsing of pairs in SceneFlow"
        trainpairs += pairs
        assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow"
        # FlyingThings TEST subset is the test split
        testpairs = sorted(
            glob(self.root + "FlyingThings/frames_finalpass/TEST/*/*/left/*.png")
        )
        testpairs = list(map(lambda x: x[len(self.root) :], testpairs))
        assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow"
        # every 100th test pair for quick evaluation
        test1of100pairs = testpairs[::100]
        assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow"

        # cleanpass lists are the finalpass lists with the pass name swapped
        tosave = {
            "train_finalpass": trainpairs,
            "train_cleanpass": list(
                map(
                    lambda x: x.replace("frames_finalpass", "frames_cleanpass"),
                    trainpairs,
                )
            ),
            "test_finalpass": testpairs,
            "test_cleanpass": list(
                map(
                    lambda x: x.replace("frames_finalpass", "frames_cleanpass"),
                    testpairs,
                )
            ),
            "test1of100_finalpass": test1of100pairs,
            "test1of100_cleanpass": list(
                map(
                    lambda x: x.replace("frames_finalpass", "frames_cleanpass"),
                    test1of100pairs,
                )
            ),
        }
        tosave["train_allpass"] = tosave["train_finalpass"] + tosave["train_cleanpass"]
        tosave["test_allpass"] = tosave["test_finalpass"] + tosave["test_cleanpass"]
        return tosave
|
|
|
|
|
class Md21Dataset(StereoDataset):
    """Middlebury 2021 scenes, using every ambient-lighting left image."""

    def _prepare_data(self):
        self.name = "Middlebury2021"
        self._set_root()
        assert self.split in ["train", "subtrain", "subval"]
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname.replace("/im0", "/im1")
        )
        # one ground-truth disparity per scene, shared by all its ambient images
        self.pairname_to_Ldispname = lambda pairname: osp.join(
            self.root, pairname.split("/")[0], "disp0.pfm"
        )
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for seq in seqs:
            ambient = osp.join(self.root, seq, "ambient")
            for sub in sorted(os.listdir(ambient)):
                for fname in sorted(os.listdir(osp.join(ambient, sub))):
                    if fname.startswith("im0"):
                        trainpairs.append(seq + "/ambient/" + sub + "/" + fname)
        assert len(trainpairs) == 355
        # hold out the last two scenes (alphabetically) for validation
        subtrainpairs = [
            p for p in trainpairs if any(p.startswith(s + "/") for s in seqs[:-2])
        ]
        subvalpairs = [
            p for p in trainpairs if any(p.startswith(s + "/") for s in seqs[-2:])
        ]
        assert (
            len(subtrainpairs) == 335 and len(subvalpairs) == 20
        ), "incorrect parsing of pairs in Middlebury 2021"
        return {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
|
|
|
|
|
class Md14Dataset(StereoDataset):
    """Middlebury 2014: the left image is fixed (im0.png); the pairname selects
    the right-image variant (im1 / im1E / im1L)."""

    def _prepare_data(self):
        self.name = "Middlebury2014"
        self._set_root()
        assert self.split in ["train", "subtrain", "subval"]
        self.pairname_to_Limgname = lambda pairname: osp.join(
            self.root, osp.dirname(pairname), "im0.png"
        )
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Ldispname = lambda pairname: osp.join(
            self.root, osp.dirname(pairname), "disp0.pfm"
        )
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        # three right-image variants per scene (im1E/im1L are alternative
        # captures — presumably exposure/lighting changes, to be confirmed)
        trainpairs = []
        for seq in seqs:
            trainpairs.extend([seq + "/im1.png", seq + "/im1E.png", seq + "/im1L.png"])
        assert len(trainpairs) == 138
        valseqs = ["Umbrella-imperfect", "Vintage-perfect"]
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [
            p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs)
        ]
        subvalpairs = [
            p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs)
        ]
        assert (
            len(subtrainpairs) == 132 and len(subvalpairs) == 6
        ), "incorrect parsing of pairs in Middlebury 2014"
        return {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
|
|
|
|
|
class Md06Dataset(StereoDataset):
    """Middlebury 2006: per scene, 3 illuminations x 3 exposures of view1/view5."""

    def _prepare_data(self):
        self.name = "Middlebury2006"
        self._set_root()
        assert self.split in ["train", "subtrain", "subval"]
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, osp.dirname(pairname), "view5.png"
        )
        # one ground-truth disparity per scene, shared by all illum/exp variants
        self.pairname_to_Ldispname = lambda pairname: osp.join(
            self.root, pairname.split("/")[0], "disp1.png"
        )
        # fix: pairname_to_str was missing here although every sibling dataset
        # defines it (cf. the structurally identical Md05Dataset); strip ".png"
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for s in seqs:
            for i in ["Illum1", "Illum2", "Illum3"]:
                for e in ["Exp0", "Exp1", "Exp2"]:
                    trainpairs.append(osp.join(s, i, e, "view1.png"))
        assert len(trainpairs) == 189
        # hold out two scenes for validation
        valseqs = ["Rocks1", "Wood2"]
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [
            p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs)
        ]
        subvalpairs = [
            p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs)
        ]
        assert (
            len(subtrainpairs) == 171 and len(subvalpairs) == 18
        ), "incorrect parsing of pairs in Middlebury 2006"
        tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
        return tosave
|
|
|
|
|
class Md05Dataset(StereoDataset):
    """Middlebury 2005: per scene, 3 illuminations x 3 exposures of view1/view5."""

    def _prepare_data(self):
        self.name = "Middlebury2005"
        self._set_root()
        assert self.split in ["train", "subtrain", "subval"]
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, osp.dirname(pairname), "view5.png"
        )
        # one ground-truth disparity per scene, shared by all illum/exp variants
        self.pairname_to_Ldispname = lambda pairname: osp.join(
            self.root, pairname.split("/")[0], "disp1.png"
        )
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = [
            osp.join(seq, illum, exp, "view1.png")
            for seq in seqs
            for illum in ["Illum1", "Illum2", "Illum3"]
            for exp in ["Exp0", "Exp1", "Exp2"]
        ]
        assert len(trainpairs) == 54, "incorrect parsing of pairs in Middlebury 2005"
        # hold out one scene for validation
        valseqs = ["Reindeer"]
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [
            p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs)
        ]
        subvalpairs = [
            p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs)
        ]
        assert (
            len(subtrainpairs) == 45 and len(subvalpairs) == 9
        ), "incorrect parsing of pairs in Middlebury 2005"
        return {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
|
|
|
|
|
class MdEval3Dataset(StereoDataset):
    """Middlebury Evaluation v3 benchmark.

    Splits are "<subset>_<resolution>" (e.g. "subtrain_half"); the resolution
    suffix selects the dataset root variant (MiddEval3_F / MiddEval3_H, with
    quarter resolution in the default root).
    """

    def _prepare_data(self):
        self.name = "MiddleburyEval3"
        self._set_root()
        assert self.split in [
            s + "_" + r
            for s in ["train", "subtrain", "subval", "test", "all"]
            for r in ["full", "half", "quarter"]
        ]
        # full/half resolution data lives in separate root directories
        if self.split.endswith("_full"):
            self.root = self.root.replace("/MiddEval3", "/MiddEval3_F")
        elif self.split.endswith("_half"):
            self.root = self.root.replace("/MiddEval3", "/MiddEval3_H")
        else:
            assert self.split.endswith("_quarter")
        self.pairname_to_Limgname = lambda pairname: osp.join(
            self.root, pairname, "im0.png"
        )
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname, "im1.png"
        )
        # test pairs ship no ground truth
        self.pairname_to_Ldispname = (
            lambda pairname: None
            if pairname.startswith("test")
            else osp.join(self.root, pairname, "disp0GT.pfm")
        )
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_middlebury_disp

        # metadata used only when writing a benchmark submission
        self.submission_methodname = "CroCo-Stereo"
        self.submission_sresolution = (
            "F"
            if self.split.endswith("_full")
            else ("H" if self.split.endswith("_half") else "Q")
        )

    def _build_cache(self):
        trainpairs = ["train/" + s for s in sorted(os.listdir(self.root + "train/"))]
        testpairs = ["test/" + s for s in sorted(os.listdir(self.root + "test/"))]
        # last train scene (alphabetically) is held out for validation
        subvalpairs = trainpairs[-1:]
        subtrainpairs = trainpairs[:-1]
        allpairs = trainpairs + testpairs
        assert (
            len(trainpairs) == 15
            and len(testpairs) == 15
            and len(subvalpairs) == 1
            and len(subtrainpairs) == 14
            and len(allpairs) == 30
        ), "incorrect parsing of pairs in Middlebury Eval v3"
        # the same pair lists are reused for every resolution suffix
        tosave = {}
        for r in ["full", "half", "quarter"]:
            tosave.update(
                **{
                    "train_" + r: trainpairs,
                    "subtrain_" + r: subtrainpairs,
                    "subval_" + r: subvalpairs,
                    "test_" + r: testpairs,
                    "all_" + r: allpairs,
                }
            )
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one PFM prediction (+ runtime file) in the MiddEval3 submission layout."""
        assert prediction.ndim == 2
        assert prediction.dtype == np.float32
        outfile = os.path.join(
            outdir,
            pairname.split("/")[0].replace("train", "training")
            + self.submission_sresolution,
            pairname.split("/")[1],
            "disp0" + self.submission_methodname + ".pfm",
        )
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        timefile = os.path.join(
            os.path.dirname(outfile), "time" + self.submission_methodname + ".txt"
        )
        with open(timefile, "w") as fid:
            fid.write(str(time))

    def finalize_submission(self, outdir):
        """Zip all written predictions into a single submission archive."""
        cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .'
        print(cmd)
        os.system(cmd)
        print(f"Done. Submission file at {outdir}/{self.submission_methodname}.zip")
|
|
|
|
|
class ETH3DLowResDataset(StereoDataset):
    """ETH3D low-resolution two-view benchmark."""

    def _prepare_data(self):
        self.name = "ETH3DLowRes"
        self._set_root()
        assert self.split in ["train", "test", "subtrain", "subval", "all"]
        self.pairname_to_Limgname = lambda pairname: osp.join(
            self.root, pairname, "im0.png"
        )
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname, "im1.png"
        )
        # pure test split: no GT at all; mixed splits ("all"): test pairs map
        # to None while train pairs map to their GT file under train_gt/
        self.pairname_to_Ldispname = (
            None
            if self.split == "test"
            else lambda pairname: None
            if pairname.startswith("test/")
            else osp.join(
                self.root, pairname.replace("train/", "train_gt/"), "disp0GT.pfm"
            )
        )
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_eth3d_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        trainpairs = ["train/" + s for s in sorted(os.listdir(self.root + "train/"))]
        testpairs = ["test/" + s for s in sorted(os.listdir(self.root + "test/"))]
        assert (
            len(trainpairs) == 27 and len(testpairs) == 20
        ), "incorrect parsing of pairs in ETH3D Low Res"
        # hand-picked validation scenes
        subvalpairs = [
            "train/delivery_area_3s",
            "train/electro_3l",
            "train/playground_3l",
        ]
        assert all(p in trainpairs for p in subvalpairs)
        subtrainpairs = [p for p in trainpairs if not p in subvalpairs]
        assert (
            len(subvalpairs) == 3 and len(subtrainpairs) == 24
        ), "incorrect parsing of pairs in ETH3D Low Res"
        tosave = {
            "train": trainpairs,
            "test": testpairs,
            "subtrain": subtrainpairs,
            "subval": subvalpairs,
            "all": trainpairs + testpairs,
        }
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one PFM prediction (+ runtime file) in the ETH3D submission layout."""
        assert prediction.ndim == 2
        assert prediction.dtype == np.float32
        outfile = os.path.join(
            outdir, "low_res_two_view", pairname.split("/")[1] + ".pfm"
        )
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        timefile = outfile[:-4] + ".txt"
        with open(timefile, "w") as fid:
            fid.write("runtime " + str(time))

    def finalize_submission(self, outdir):
        """Zip the predictions into the archive expected by the ETH3D server."""
        cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view'
        print(cmd)
        os.system(cmd)
        print(f"Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip")
|
|
|
|
|
class BoosterDataset(StereoDataset):
    """Booster dataset (balanced version); camera_00 is left, camera_02 is right."""

    def _prepare_data(self):
        self.name = "Booster"
        self._set_root()
        assert self.split in [
            "train_balanced",
            "test_balanced",
            "subtrain_balanced",
            "subval_balanced",
        ]
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname
        ).replace("/camera_00/", "/camera_02/")
        # one GT disparity per scene, one level above the camera directory
        self.pairname_to_Ldispname = lambda pairname: osp.join(
            self.root, osp.dirname(pairname), "../disp_00.npy"
        )
        self.pairname_to_str = lambda pairname: pairname[:-4].replace(
            "/camera_00/", "/"
        )
        self.load_disparity = _read_booster_disp

    def _build_cache(self):
        trainseqs = sorted(os.listdir(self.root + "train/balanced"))
        trainpairs = [
            "train/balanced/" + s + "/camera_00/" + imname
            for s in trainseqs
            for imname in sorted(
                os.listdir(self.root + "train/balanced/" + s + "/camera_00/")
            )
        ]
        testpairs = [
            "test/balanced/" + s + "/camera_00/" + imname
            for s in sorted(os.listdir(self.root + "test/balanced"))
            for imname in sorted(
                os.listdir(self.root + "test/balanced/" + s + "/camera_00/")
            )
        ]
        assert len(trainpairs) == 228 and len(testpairs) == 191
        # last two scenes are held out for validation.
        # NOTE(review): `s in p` is substring matching, not a path-prefix check;
        # a scene name that is a substring of another could leak across splits
        subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])]
        subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])]

        tosave = {
            "train_balanced": trainpairs,
            "test_balanced": testpairs,
            "subtrain_balanced": subtrainpairs,
            "subval_balanced": subvalpairs,
        }
        return tosave
|
|
|
|
|
class SpringDataset(StereoDataset):
    """Spring dataset: stereo pairs for both left and right reference frames.

    Disparities are stored as .dsp5 (HDF5) files; the test split has no
    ground truth. Spring may contain non-positive disparities (see the
    exemption in StereoDataset.__getitem__).
    """

    def _prepare_data(self):
        self.name = "Spring"
        self._set_root()
        assert self.split in ["train", "test", "subtrain", "subval"]
        self.pairname_to_Limgname = lambda pairname: osp.join(
            self.root, pairname + ".png"
        )
        # swap frame_left <-> frame_right (placeholder avoids double-replacing)
        self.pairname_to_Rimgname = (
            lambda pairname: osp.join(self.root, pairname + ".png")
            .replace("frame_right", "<frame_right>")
            .replace("frame_left", "frame_right")
            .replace("<frame_right>", "frame_left")
        )
        self.pairname_to_Ldispname = (
            lambda pairname: None
            if pairname.startswith("test")
            else osp.join(self.root, pairname + ".dsp5")
            .replace("frame_left", "disp1_left")
            .replace("frame_right", "disp1_right")
        )
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_hdf5_disp

    def _build_cache(self):
        trainseqs = sorted(os.listdir(osp.join(self.root, "train")))
        trainpairs = [
            osp.join("train", s, "frame_left", f[:-4])
            for s in trainseqs
            for f in sorted(os.listdir(osp.join(self.root, "train", s, "frame_left")))
        ]
        testseqs = sorted(os.listdir(osp.join(self.root, "test")))
        testpairs = [
            osp.join("test", s, "frame_left", f[:-4])
            for s in testseqs
            for f in sorted(os.listdir(osp.join(self.root, "test", s, "frame_left")))
        ]
        # test predictions are required for both reference frames
        testpairs += [p.replace("frame_left", "frame_right") for p in testpairs]
        """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041"""
        # sequence 0041 is held out for validation (cf. the note above)
        subtrainpairs = [p for p in trainpairs if p.split("/")[1] != "0041"]
        subvalpairs = [p for p in trainpairs if p.split("/")[1] == "0041"]
        assert (
            len(trainpairs) == 5000
            and len(testpairs) == 2000
            and len(subtrainpairs) == 4904
            and len(subvalpairs) == 96
        ), "incorrect parsing of pairs in Spring"
        tosave = {
            "train": trainpairs,
            "test": testpairs,
            "subtrain": subtrainpairs,
            "subval": subvalpairs,
        }
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one prediction as a .dsp5 file in the Spring submission layout."""
        assert prediction.ndim == 2
        assert prediction.dtype == np.float32
        outfile = (
            os.path.join(outdir, pairname + ".dsp5")
            .replace("frame_left", "disp1_left")
            .replace("frame_right", "disp1_right")
        )
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        writeDsp5File(prediction, outfile)

    def finalize_submission(self, outdir):
        """Run the official disp1_subsampling tool over the written predictions."""
        assert self.split == "test"
        # fix: this string was missing its f-prefix, so the literal text
        # "{self.root}/disp1_subsampling" was tested and never found
        exe = f"{self.root}/disp1_subsampling"
        if os.path.isfile(exe):
            cmd = f'cd "{outdir}/test"; {exe} .'
            print(cmd)
            os.system(cmd)
        else:
            print("Could not find disp1_subsampling executable for submission.")
            print("Please download it and run:")
            print(f'cd "{outdir}/test"; <disp1_subsampling_exe> .')
|
|
|
|
|
class Kitti12Dataset(StereoDataset):
    """KITTI 2012 stereo benchmark (colored_0 / colored_1 pair, frame _10)."""

    def _prepare_data(self):
        self.name = "Kitti12"
        self._set_root()
        assert self.split in ["train", "test"]
        self.pairname_to_Limgname = lambda pairname: osp.join(
            self.root, pairname + "_10.png"
        )
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname.replace("/colored_0/", "/colored_1/") + "_10.png"
        )
        # the test split ships no ground truth
        self.pairname_to_Ldispname = (
            None
            if self.split == "test"
            else lambda pairname: osp.join(
                self.root, pairname.replace("/colored_0/", "/disp_occ/") + "_10.png"
            )
        )
        self.pairname_to_str = lambda pairname: pairname.replace("/colored_0/", "/")
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = [f"training/colored_0/{i:06d}" for i in range(194)]
        testseqs = [f"testing/colored_0/{i:06d}" for i in range(195)]
        assert (
            len(trainseqs) == 194 and len(testseqs) == 195
        ), "incorrect parsing of pairs in Kitti12"
        return {"train": trainseqs, "test": testseqs}

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one prediction as a 16-bit PNG in the KITTI submission layout."""
        assert prediction.ndim == 2
        assert prediction.dtype == np.float32
        outfile = os.path.join(outdir, pairname.split("/")[-1] + "_10.png")
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        # KITTI convention: disparities stored as uint16 scaled by 256
        Image.fromarray((prediction * 256).astype("uint16")).save(outfile)

    def finalize_submission(self, outdir):
        """Zip the predictions into the archive expected by the KITTI server."""
        assert self.split == "test"
        cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .'
        print(cmd)
        os.system(cmd)
        print(f"Done. Submission file at {outdir}/kitti12_results.zip")
|
|
|
|
|
class Kitti15Dataset(StereoDataset):
    """KITTI 2015 stereo benchmark (image_2 / image_3 pair, frame _10)."""

    def _prepare_data(self):
        self.name = "Kitti15"
        self._set_root()
        assert self.split in ["train", "subtrain", "subval", "test"]
        self.pairname_to_Limgname = lambda pairname: osp.join(
            self.root, pairname + "_10.png"
        )
        self.pairname_to_Rimgname = lambda pairname: osp.join(
            self.root, pairname.replace("/image_2/", "/image_3/") + "_10.png"
        )
        # the test split ships no ground truth
        self.pairname_to_Ldispname = (
            None
            if self.split == "test"
            else lambda pairname: osp.join(
                self.root, pairname.replace("/image_2/", "/disp_occ_0/") + "_10.png"
            )
        )
        self.pairname_to_str = lambda pairname: pairname.replace("/image_2/", "/")
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = [f"training/image_2/{i:06d}" for i in range(200)]
        # last five training frames are held out for validation
        subtrainseqs = trainseqs[:-5]
        subvalseqs = trainseqs[-5:]
        testseqs = [f"testing/image_2/{i:06d}" for i in range(200)]
        assert (
            len(trainseqs) == 200
            and len(subtrainseqs) == 195
            and len(subvalseqs) == 5
            and len(testseqs) == 200
        ), "incorrect parsing of pairs in Kitti15"
        return {
            "train": trainseqs,
            "subtrain": subtrainseqs,
            "subval": subvalseqs,
            "test": testseqs,
        }

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write one prediction as a 16-bit PNG under disp_0/ for submission."""
        assert prediction.ndim == 2
        assert prediction.dtype == np.float32
        outfile = os.path.join(outdir, "disp_0", pairname.split("/")[-1] + "_10.png")
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        # KITTI convention: disparities stored as uint16 scaled by 256
        Image.fromarray((prediction * 256).astype("uint16")).save(outfile)

    def finalize_submission(self, outdir):
        """Zip the disp_0 predictions into the archive expected by the KITTI server."""
        assert self.split == "test"
        cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0'
        print(cmd)
        os.system(cmd)
        print(f"Done. Submission file at {outdir}/kitti15_results.zip")
|
|
|
|
|
|
|
|
|
|
|
def _read_img(filename):
    """Load an image from disk as an HxWx3 uint8 RGB array."""
    return np.asarray(Image.open(filename).convert("RGB"))
|
|
|
|
|
def _read_booster_disp(filename): |
|
disp = np.load(filename) |
|
disp[disp == 0.0] = np.inf |
|
return disp |
|
|
|
|
|
def _read_png_disp(filename, coef=1.0):
    """Decode a PNG disparity map, dividing raw values by `coef`.

    Zero-valued pixels mark missing data and are mapped to +inf.
    """
    raw = np.asarray(Image.open(filename)).astype(np.float32)
    raw /= coef
    raw[raw == 0.0] = np.inf
    return raw
|
|
|
|
|
def _read_pfm_disp(filename):
    """Read a PFM disparity map; non-positive values are marked invalid (+inf)."""
    disp = np.ascontiguousarray(_read_pfm(filename)[0])
    disp[disp <= 0] = np.inf
    return disp
|
|
|
|
|
def _read_npy_disp(filename):
    # raw numpy disparity dump; no invalid-pixel masking is applied here
    return np.load(filename)
|
|
|
|
|
def _read_crestereo_disp(filename):
    # CREStereo disparity PNGs are scaled by 32; divide to recover pixels
    return _read_png_disp(filename, coef=32.0)
|
|
|
|
|
def _read_middlebury20052006_disp(filename):
    # Middlebury 2005/2006 disparity PNGs store pixel values directly (coef 1)
    return _read_png_disp(filename, coef=1.0)
|
|
|
|
|
def _read_kitti_disp(filename):
    # KITTI disparity PNGs are uint16 scaled by 256; divide to recover pixels
    return _read_png_disp(filename, coef=256.0)
|
|
|
|
|
# Per-dataset disparity readers: each maps to the generic reader for its format.
_read_sceneflow_disp = _read_pfm_disp
_read_eth3d_disp = _read_pfm_disp
_read_middlebury_disp = _read_pfm_disp
_read_carla_disp = _read_pfm_disp
_read_tartanair_disp = _read_npy_disp
|
|
|
|
|
def _read_hdf5_disp(filename):
    """Read a disparity map from an HDF5 file (its 'disparity' dataset).

    NaNs (invalid pixels) are replaced by +inf; the result is float32.
    """
    # context manager fixes the original's leaked HDF5 file handle
    with h5py.File(filename, "r") as fid:
        disp = np.asarray(fid["disparity"])
    disp[np.isnan(disp)] = np.inf
    return disp.astype(np.float32)
|
|
|
|
|
import re |
|
|
|
|
|
def _read_pfm(file): |
|
file = open(file, "rb") |
|
|
|
color = None |
|
width = None |
|
height = None |
|
scale = None |
|
endian = None |
|
|
|
header = file.readline().rstrip() |
|
if header.decode("ascii") == "PF": |
|
color = True |
|
elif header.decode("ascii") == "Pf": |
|
color = False |
|
else: |
|
raise Exception("Not a PFM file.") |
|
|
|
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) |
|
if dim_match: |
|
width, height = list(map(int, dim_match.groups())) |
|
else: |
|
raise Exception("Malformed PFM header.") |
|
|
|
scale = float(file.readline().decode("ascii").rstrip()) |
|
if scale < 0: |
|
endian = "<" |
|
scale = -scale |
|
else: |
|
endian = ">" |
|
|
|
data = np.fromfile(file, endian + "f") |
|
shape = (height, width, 3) if color else (height, width) |
|
|
|
data = np.reshape(data, shape) |
|
data = np.flipud(data) |
|
return data, scale |
|
|
|
|
|
def writePFM(file, image, scale=1):
    """Write a float32 image to `file` in PFM format.

    image: (H, W) or (H, W, 1) grayscale, or (H, W, 3) color.
    scale: written as negative when the data is little-endian (PFM convention).
    Raises Exception on a wrong dtype or shape.
    """
    if image.dtype.name != "float32":
        raise Exception("Image dtype must be float32.")

    # PFM stores rows bottom-up
    image = np.flipud(image)

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1):
        color = False
    else:
        raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

    # context manager fixes the original's leaked file handle
    with open(file, "wb") as fid:
        # fix: the original wrote '"PF\n" if color else "Pf\n".encode()' —
        # .encode() bound only to the else branch, so color images wrote a
        # str to a binary file (TypeError); encode the selected header instead
        fid.write(("PF\n" if color else "Pf\n").encode())
        fid.write(b"%d %d\n" % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder
        # '=' means native byte order; a negative scale marks little-endian
        if endian == "<" or (endian == "=" and sys.byteorder == "little"):
            scale = -scale
        fid.write(b"%f\n" % scale)

        image.tofile(fid)
|
|
|
|
|
def writeDsp5File(disp, filename):
    """Store a disparity map as a gzip-compressed .dsp5 (HDF5) file."""
    with h5py.File(filename, "w") as fid:
        fid.create_dataset(
            "disparity", data=disp, compression="gzip", compression_opts=5
        )
|
|
|
|
|
|
|
|
|
|
|
def vis_disparity(disp, m=None, M=None):
    """Colorize a disparity map for visualization (inferno colormap).

    m / M default to the map's own min / max. Returns an HxWx3 uint8 image
    (OpenCV channel order).
    """
    lo = disp.min() if m is None else m
    hi = disp.max() if M is None else M
    normalized = ((disp - lo) / (hi - lo) * 255.0).astype("uint8")
    return cv2.applyColorMap(normalized, cv2.COLORMAP_INFERNO)
|
|
|
|
|
|
|
|
|
|
|
def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None):
    """Instantiate a training dataset from a spec string like "CREStereo('train')".

    NOTE: the spec is evaluated with eval() — only pass trusted input.
    """
    spec = dataset_str.replace("(", "Dataset(")
    if augmentor:
        spec = spec.replace(")", ", augmentor=True)")
    if crop_size is not None:
        spec = spec.replace(")", f", crop_size={crop_size})")
    return eval(spec)
|
|
|
|
|
def get_test_datasets_stereo(dataset_str):
    """Instantiate one test dataset per '+'-separated spec string.

    NOTE: each spec is evaluated with eval() — only pass trusted input.
    """
    specs = dataset_str.replace("(", "Dataset(").split("+")
    return [eval(spec) for spec in specs]
|
|