Spaces:
Runtime error
Runtime error
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. | |
# !/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
import torch | |
from tqdm import tqdm | |
def calculate_lfd_distance(
        q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
        tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8):
    """Return the best-alignment descriptor distance for each src/tgt pair in a batch.

    For every batch item, all views of the source are compared with all views
    of the target through the ``q8_table`` lookup. The four descriptor costs
    are combined as ``art + 2 * fd + 2 * cir + ecc``. The combined cost is then
    summed along each alignment row of ``align_10`` and the minimum over all
    view-set pairs and alignment rows is returned.

    Args:
        q8_table: 2-D lookup table indexed as ``q8_table[src_value, tgt_value]``.
        align_10: 2-D integer table; only its first 10 columns are used. Entry
            ``[r, p]`` is the target view index that is matched with source
            view index ``p`` under alignment row ``r``.
        src_ArtCoeff, tgt_ArtCoeff: rank-4 integer-valued tensors whose last
            axis is summed over. NOTE(review): assumed layout is
            (batch, camera, angle, coeff) -- confirm against the caller.
        src_FdCoeff_q8, tgt_FdCoeff_q8: same rank-4 layout as the Art tensors.
        src_CirCoeff_q8, tgt_CirCoeff_q8: rank-3 tensors, one value per view.
        src_EccCoeff_q8, tgt_EccCoeff_q8: rank-3 tensors, one value per view.

    Returns:
        An int64 tensor with one distance per batch item.
    """

    def pairwise_lookup(src_coeff, tgt_coeff):
        # Source views go on the trailing pair of view axes and target views
        # on the leading pair. Indexing with broadcastable index tensors
        # yields the full all-pairs table directly, without first
        # materialising the expanded index tensors.
        src_index = src_coeff.long().unsqueeze(dim=1).unsqueeze(dim=1)
        tgt_index = tgt_coeff.long().unsqueeze(dim=3).unsqueeze(dim=3)
        return q8_table[src_index, tgt_index]

    with torch.no_grad():
        art_distance = torch.sum(pairwise_lookup(src_ArtCoeff, tgt_ArtCoeff), dim=-1)
        fd_distance = torch.sum(pairwise_lookup(src_FdCoeff_q8, tgt_FdCoeff_q8), dim=-1) * 2.0
        cir_distance = pairwise_lookup(src_CirCoeff_q8, tgt_CirCoeff_q8) * 2.0
        ecc_distance = pairwise_lookup(src_EccCoeff_q8, tgt_EccCoeff_q8)

        # Layout here: batch x tgt_camera x tgt_angle x src_camera x src_angle.
        cost = art_distance + fd_distance + cir_distance + ecc_distance

        # Find the closest matching.
        # After the permute: batch x tgt_camera x src_camera x tgt_angle x src_angle.
        cost = cost.permute(0, 1, 3, 2, 4).long()

        alignment = align_10[:, :10].long()
        src_angle = torch.arange(alignment.shape[1], device=cost.device)
        # matched[b, i, j, r, p] = cost[b, i, j, alignment[r, p], p]; only these
        # entries (the diagonal of a full gather) contribute to the distance.
        matched = cost[:, :, :, alignment, src_angle]
        alignment_cost = torch.sum(matched, dim=-1)

        # Minimum over every (tgt_camera, src_camera, alignment row) choice.
        dist = torch.min(alignment_cost.flatten(start_dim=1), dim=-1)[0]
    return dist
class LightFieldDistanceFunction(torch.autograd.Function):
    """Forward-only autograd function building a source-by-target distance matrix.

    Each source item is compared with every target item through
    ``calculate_lfd_distance``. No gradient is defined for this operation.
    """

    @staticmethod
    def forward(
            ctx, q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8, log):
        """Compute the distance between every source item and every target item.

        Args:
            ctx: autograd context (unused; nothing is saved for backward).
            q8_table: lookup table forwarded to ``calculate_lfd_distance``.
            align_10: alignment table forwarded to ``calculate_lfd_distance``.
            src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8:
                source descriptors; axis 0 indexes the ``n`` source items.
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8:
                target descriptors; axis 0 indexes the ``m`` target items.
            log: when truthy, a tqdm progress bar over the source items is shown.

        Returns:
            A tensor of shape ``(n, m)`` holding the pairwise distances. When
            either side is empty, an empty int64 tensor of that shape is
            returned on the device of ``q8_table``.
        """
        n = src_ArtCoeff.shape[0]
        m = tgt_ArtCoeff.shape[0]
        print(f"src_size: {n}")
        print(f"tgt_size: {m}")

        if n == 0 or m == 0:
            # torch.cat rejects an empty list, so answer the degenerate case
            # directly instead of crashing inside the loop below.
            return torch.empty((n, m), dtype=torch.long, device=q8_table.device)

        # Targets are processed in chunks of this many items so that the
        # intermediate tensors of one distance call stay bounded.
        n_each_run = 1000

        all_dist = []
        with torch.no_grad():
            for i in tqdm(range(n), mininterval=60, disable=not log):
                one_run_d = []
                for start_idx in range(0, m, n_each_run):
                    end_idx = min(m, start_idx + n_each_run)
                    run_length = end_idx - start_idx
                    # The single source item is expanded (as a view) so that it
                    # pairs with each target in the current chunk.
                    d = calculate_lfd_distance(
                        q8_table, align_10,
                        src_ArtCoeff[i:i + 1].expand(run_length, -1, -1, -1),
                        src_FdCoeff_q8[i:i + 1].expand(run_length, -1, -1, -1),
                        src_CirCoeff_q8[i:i + 1].expand(run_length, -1, -1),
                        src_EccCoeff_q8[i:i + 1].expand(run_length, -1, -1),
                        tgt_ArtCoeff[start_idx:end_idx],
                        tgt_FdCoeff_q8[start_idx:end_idx],
                        tgt_CirCoeff_q8[start_idx:end_idx],
                        tgt_EccCoeff_q8[start_idx:end_idx])
                    one_run_d.append(d)
                row = torch.cat(one_run_d, dim=0)
                all_dist.append(row.unsqueeze(dim=0))
        dist = torch.cat(all_dist, dim=0)
        return dist

    @staticmethod
    def backward(ctx, graddist):
        """Always raise ``NotImplementedError``: the distance has no gradient."""
        raise NotImplementedError
class LFD(torch.nn.Module):
    """Module wrapper that delegates to ``LightFieldDistanceFunction.apply``."""

    def forward(
            self, q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8, log):
        """Forward every argument, in order, to the distance function and return its result."""
        source_descriptors = (src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8)
        target_descriptors = (tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8)
        pairwise_distance = LightFieldDistanceFunction.apply(
            q8_table, align_10, *source_descriptors, *target_descriptors, log)
        return pairwise_distance