File size: 6,349 Bytes
7b127f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

# !/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
from tqdm import tqdm


def _pairwise_table_lookup(q8_table, src, tgt):
    """Look up ``q8_table[src_code, tgt_code]`` for every target/source view pairing.

    ``src`` and ``tgt`` share a leading batch axis followed by two view axes and
    optional trailing coefficient axes. The result is laid out as
    ``(batch, t1, t2, s1, s2, *rest)`` with element
    ``q8_table[src[b, s1, s2, ...], tgt[b, t1, t2, ...]]``.
    """
    # Insert singleton axes so the two index tensors broadcast against each
    # other (src varies over axes 3-4, tgt over axes 1-2) instead of
    # materialising expanded copies of both.
    src_idx = src.long().unsqueeze(dim=1).unsqueeze(dim=1)
    tgt_idx = tgt.long().unsqueeze(dim=3).unsqueeze(dim=3)
    return q8_table[src_idx, tgt_idx]


def calculate_lfd_distance(
        q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
        tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8):
    """Compute the Light Field Distance between paired source/target shapes.

    For batch element ``b`` the per-view cost of pairing target view
    ``(t1, t2)`` with source view ``(s1, s2)`` is::

        sum(Art lookups) + 2 * sum(Fd lookups) + 2 * Cir lookup + Ecc lookup

    where every lookup is ``q8_table[src_code, tgt_code]``. The cost is
    truncated to an integer, and the returned distance is the minimum over
    ``t1``, ``s1`` and every alignment row ``a`` of
    ``sum_v cost[b, t1, align_10[a, v], s1, v]``.

    Args:
        q8_table: 2-D look-up table; source codes index its rows and target
            codes its columns.
        align_10: alignment table; row ``a`` maps source view ``v`` to target
            view ``align_10[a, v]``. Only the first ``n_views`` columns are
            read. Presumably the 60 rotations of the LFD camera setup
            (TODO confirm).
        src_ArtCoeff, src_FdCoeff_q8: assumed ``(batch, 10, 10, K)`` integer
            codes (TODO confirm against the caller).
        src_CirCoeff_q8, src_EccCoeff_q8: assumed ``(batch, 10, 10)`` codes.
        tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8: same
            layout as the matching ``src_*`` argument; batch sizes must
            broadcast.

    Returns:
        Long tensor of shape ``(batch,)``.
    """
    with torch.no_grad():
        art_distance = _pairwise_table_lookup(q8_table, src_ArtCoeff, tgt_ArtCoeff).sum(dim=-1)
        fd_distance = _pairwise_table_lookup(q8_table, src_FdCoeff_q8, tgt_FdCoeff_q8).sum(dim=-1) * 2.0
        cir_distance = _pairwise_table_lookup(q8_table, src_CirCoeff_q8, tgt_CirCoeff_q8) * 2.0
        ecc_distance = _pairwise_table_lookup(q8_table, src_EccCoeff_q8, tgt_EccCoeff_q8)
        # cost[b, t1, t2, s1, s2]: target view (t1, t2) vs source view (s1, s2).
        cost = art_distance + fd_distance + cir_distance + ecc_distance
        # Reorder to (b, t1, s1, t2, s2) so the last two axes are the views
        # being aligned, and truncate to integers before summing over views.
        cost = cost.permute(0, 1, 3, 2, 4).long()

        n_views = cost.shape[-1]
        align = align_10[:, :n_views].long().to(cost.device)  # (n_align, n_views)
        views = torch.arange(n_views, device=cost.device)
        # Find the closest matching alignment:
        # align_err[b, t1, s1, a, v] = cost[b, t1, s1, align[a, v], v]
        align_err = cost[..., align, views]
        # Total cost of each (t1, s1, a) candidate, then the best one.
        # flatten (not reshape(B, -1)) keeps an empty batch valid.
        sum_diag = align_err.sum(dim=-1).flatten(start_dim=1)
        dist = sum_diag.min(dim=-1).values
    return dist


class LightFieldDistanceFunction(torch.autograd.Function):
    """Autograd wrapper that builds the pairwise LFD matrix between two shape sets.

    Only the forward pass is implemented; backward raises NotImplementedError.
    """

    @staticmethod
    def forward(
            ctx, q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8, log):
        """Return the ``(n_src, n_tgt)`` matrix of LFD distances.

        Entry ``[i, j]`` is ``calculate_lfd_distance`` between source shape
        ``i`` and target shape ``j``. Each source shape is compared against
        the targets in chunks of at most 1000 to bound intermediate memory.

        Args:
            ctx: autograd context (unused; nothing is saved for backward).
            q8_table, align_10: forwarded unchanged to calculate_lfd_distance.
            src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8:
                source descriptors; dim 0 indexes the shape.
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8:
                target descriptors; dim 0 indexes the shape.
            log: when truthy, print the input sizes and show a tqdm progress bar.
        """
        n = src_ArtCoeff.shape[0]
        m = tgt_ArtCoeff.shape[0]
        if log:
            print(f"src_size: {n}")
            print(f"tgt_size: {m}")
        if n == 0 or m == 0:
            # torch.cat/stack on an empty list would raise; calculate_lfd_distance
            # yields long distances, so return an empty long matrix instead.
            return torch.empty((n, m), dtype=torch.long, device=src_ArtCoeff.device)

        chunk_size = 1000
        rows = []
        with torch.no_grad():
            for i in tqdm(range(n), mininterval=60, disable=not log):
                chunks = []
                for start in range(0, m, chunk_size):
                    end = min(m, start + chunk_size)
                    run_length = end - start
                    # Repeat source shape i once per target in this chunk.
                    chunks.append(calculate_lfd_distance(
                        q8_table, align_10,
                        src_ArtCoeff[i:i + 1].expand(run_length, -1, -1, -1),
                        src_FdCoeff_q8[i:i + 1].expand(run_length, -1, -1, -1),
                        src_CirCoeff_q8[i:i + 1].expand(run_length, -1, -1),
                        src_EccCoeff_q8[i:i + 1].expand(run_length, -1, -1),
                        tgt_ArtCoeff[start:end],
                        tgt_FdCoeff_q8[start:end],
                        tgt_CirCoeff_q8[start:end],
                        tgt_EccCoeff_q8[start:end]))
                rows.append(torch.cat(chunks, dim=0))
        return torch.stack(rows, dim=0)

    @staticmethod
    def backward(ctx, graddist):
        """Not supported: no gradient is defined for the LFD computation."""
        raise NotImplementedError("LightFieldDistanceFunction does not support backward")


class LFD(torch.nn.Module):
    """``nn.Module`` front-end for ``LightFieldDistanceFunction``.

    Defines no parameters or buffers of its own; ``forward`` passes every
    argument, in order, to ``LightFieldDistanceFunction.apply``.
    """

    def forward(
            self, q8_table, align_10, src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8,
            tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8, log):
        """Return whatever ``LightFieldDistanceFunction.apply`` returns for these inputs."""
        source_descriptors = (src_ArtCoeff, src_FdCoeff_q8, src_CirCoeff_q8, src_EccCoeff_q8)
        target_descriptors = (tgt_ArtCoeff, tgt_FdCoeff_q8, tgt_CirCoeff_q8, tgt_EccCoeff_q8)
        distance_fn = LightFieldDistanceFunction.apply
        return distance_fn(q8_table, align_10, *source_descriptors, *target_descriptors, log)