import torch
import numpy as np
from scipy.spatial.transform import Rotation


def inverse(matrices):
    """
    torch.inverse() sometimes produces outputs containing nan when the batch size is 2.
    Ref: https://github.com/pytorch/pytorch/issues/47272
    This function keeps inverting the matrices until it succeeds or the maximum
    number of tries is reached.
    :param matrices: Bx3x3
    """
    inv = None
    max_tries = 5
    while (inv is None) or torch.isnan(inv).any():
        inv = torch.inverse(matrices)

        # Stop retrying once there are no more tries left
        max_tries -= 1
        if max_tries == 0:
            break

    # Raise an exception if the inverse still contains nan
    if torch.isnan(inv).any():
        raise RuntimeError("Matrix inverse contains nan!")
    return inv
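
# A minimal usage sketch: invert a batch of intrinsics with the nan-retry guard.
# Batch size 2 is exactly the case the upstream issue describes.
#   K = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)  # Bx3x3 with B=2
#   K_inv = inverse(K)                             # retries torch.inverse up to 5 times
#   assert not torch.isnan(K_inv).any()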


class HomographySample:
    def __init__(self, H_tgt, W_tgt, device=None):
        if device is None:
            self.device = torch.device("cpu")
        else:
            self.device = device

        self.Height_tgt = H_tgt
        self.Width_tgt = W_tgt
        self.meshgrid = self.grid_generation(self.Height_tgt, self.Width_tgt, self.device)
        self.meshgrid = self.meshgrid.permute(2, 0, 1).contiguous()  # 3xHxW

        self.n = self.plane_normal_generation(self.device)

    @staticmethod
    def grid_generation(H, W, device):
        x = np.linspace(0, W - 1, W)
        y = np.linspace(0, H - 1, H)
        xv, yv = np.meshgrid(x, y)  # each HxW
        xv = torch.from_numpy(xv.astype(np.float32)).to(device=device)
        yv = torch.from_numpy(yv.astype(np.float32)).to(device=device)
        ones = torch.ones_like(xv)
        meshgrid = torch.stack((xv, yv, ones), dim=2)  # HxWx3 homogeneous pixel coords
        return meshgrid
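
    # A minimal sanity sketch (illustrative values, CPU assumed):
    #   g = HomographySample.grid_generation(2, 3, torch.device("cpu"))  # 2x3x3
    #   g[0, 0]  # tensor([0., 0., 1.]) -> pixel (x=0, y=0) in homogeneous coords
    #   g[1, 2]  # tensor([2., 1., 1.]) -> pixel (x=2, y=1)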

    @staticmethod
    def plane_normal_generation(device):
        n = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
        return n

    @staticmethod
    def euler_to_rotation_matrix(x_angle, y_angle, z_angle, seq='xyz', degrees=False):
        """
        Return a rotation matrix rot_mtx that transforms tgt points into the src frame,
        i.e. rot_mtx * p_tgt = p_src.
        The angles describe the tgt frame relative to src, so we negate x/y/z_angle.
        :param x_angle: rotation about the x axis (roll)
        :param y_angle: rotation about the y axis (pitch)
        :param z_angle: rotation about the z axis (yaw)
        :param seq: Euler axis sequence, passed to scipy's Rotation.from_euler
        :param degrees: if True, the angles are interpreted as degrees
        :return: 3x3 np.float32 rotation matrix
        """
        r = Rotation.from_euler(seq,
                                [-x_angle, -y_angle, -z_angle],
                                degrees=degrees)
        rot_mtx = r.as_matrix().astype(np.float32)
        return rot_mtx
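
    # A minimal sketch of the sign convention (degrees used for readability):
    #   R = HomographySample.euler_to_rotation_matrix(0, 0, 90, degrees=True)
    #   R @ np.array([1, 0, 0], np.float32)  # ~[0, -1, 0]: tgt x-axis seen from src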
    def sample(self, src_BCHW, d_src_B,
               G_tgt_src,
               K_src_inv, K_tgt):
        """
        Coordinate system: x, y are the image directions, z is pointing to depth direction
        :param src_BCHW: torch tensor float, 0-1, rgb/rgba. BxCxHxW
                         Assume to be at position P=[I|0]
        :param d_src_B: distance of image plane to src camera origin
        :param G_tgt_src: Bx4x4
        :param K_src_inv: Bx3x3
        :param K_tgt: Bx3x3
        :return: tgt_BCHW
        """
        # parameter processing ------ begin ------
        B, channels, Height_src, Width_src = src_BCHW.shape
        R_tgt_src = G_tgt_src[:, 0:3, 0:3]
        t_tgt_src = G_tgt_src[:, 0:3, 3]

        Height_tgt = self.Height_tgt
        Width_tgt = self.Width_tgt

        # relationship between FoV and focal length:
        # assume W > H
        # W / 2 = f * tan(\theta_h / 2)
        # here we default the horizontal FoV to 53.13 degrees
        # the vertical FoV then follows from H / 2 = f * tan(\theta_v / 2)

        R_tgt_src = R_tgt_src.to(device=src_BCHW.device)
        t_tgt_src = t_tgt_src.to(device=src_BCHW.device)
        K_src_inv = K_src_inv.to(device=src_BCHW.device)
        K_tgt = K_tgt.to(device=src_BCHW.device)
        # parameter processing ------ end ------

        # the goal is to compute H_src_tgt, which maps a tgt pixel to a src pixel,
        # so we compute H_tgt_src first and then invert it
        n = self.n.to(device=src_BCHW.device)
        n = n.unsqueeze(0).repeat(B, 1)  # Bx3
        # plane-induced homography: H_tgt_src = K_tgt * (R - t * n^T / d) * K_src_inv
        # Bx3x3 - (Bx3x1 * Bx1x3)
        # note we divide by -d_src, because the plane equation is n^T * X - d_src = 0
        d_src_B33 = d_src_B.reshape(B, 1, 1).repeat(1, 3, 3)  # B -> Bx3x3
        R_tnd = R_tgt_src - torch.matmul(t_tgt_src.unsqueeze(2), n.unsqueeze(1)) / -d_src_B33
        H_tgt_src = torch.matmul(K_tgt,
                                 torch.matmul(R_tnd, K_src_inv))

        # TODO: fix cuda inverse
        with torch.no_grad():
            H_src_tgt = inverse(H_tgt_src)

        # create tgt image grid, and map to src
        meshgrid_tgt_homo = self.meshgrid.to(src_BCHW.device)
        # 3xHxW -> Bx3xHxW
        meshgrid_tgt_homo = meshgrid_tgt_homo.unsqueeze(0).expand(B, 3, Height_tgt, Width_tgt)
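        # (expand creates a broadcast view, avoiding a per-batch copy of the grid)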

        # warp meshgrid_tgt_homo to meshgrid_src
        meshgrid_tgt_homo_B3N = meshgrid_tgt_homo.view(B, 3, -1)  # Bx3xHW
        meshgrid_src_homo_B3N = torch.matmul(H_src_tgt, meshgrid_tgt_homo_B3N)  # Bx3x3 * Bx3xHW -> Bx3xHW
        # Bx3xHW -> Bx3xHxW -> BxHxWx3
        meshgrid_src_homo = meshgrid_src_homo_B3N.view(B, 3, Height_tgt, Width_tgt).permute(0, 2, 3, 1)
        meshgrid_src = meshgrid_src_homo[:, :, :, 0:2] / meshgrid_src_homo[:, :, :, 2:]  # BxHxWx2

        valid_mask_x = torch.logical_and(meshgrid_src[:, :, :, 0] < Width_src,
                                         meshgrid_src[:, :, :, 0] > -1)
        valid_mask_y = torch.logical_and(meshgrid_src[:, :, :, 1] < Height_src,
                                         meshgrid_src[:, :, :, 1] > -1)
        valid_mask = torch.logical_and(valid_mask_x, valid_mask_y)  # BxHxW

        # sample from src_BCHW
        # normalize meshgrid_src to [-1, 1]; with align_corners=False, grid_sample
        # places pixel centers at ((x + 0.5) / (W / 2)) - 1
        meshgrid_src[:, :, :, 0] = (meshgrid_src[:, :, :, 0] + 0.5) / (Width_src * 0.5) - 1
        meshgrid_src[:, :, :, 1] = (meshgrid_src[:, :, :, 1] + 0.5) / (Height_src * 0.5) - 1
        tgt_BCHW = torch.nn.functional.grid_sample(src_BCHW, grid=meshgrid_src, padding_mode='border',
                                                   align_corners=False)
        # BxCxHxW, BxHxW
        return tgt_BCHW, valid_mask
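

if __name__ == "__main__":
    # A minimal smoke-test sketch (not part of the original module): warp a random
    # source image to an identical target view. With G_tgt_src = identity and
    # K_tgt = K_src, sample() should reproduce the input up to numerical error.
    # The intrinsics below are illustrative, not values assumed by the class.
    B, C, H, W = 2, 3, 32, 32
    src_BCHW = torch.rand(B, C, H, W)
    K = torch.tensor([[30.0, 0.0, W / 2.0],
                      [0.0, 30.0, H / 2.0],
                      [0.0, 0.0, 1.0]]).unsqueeze(0).repeat(B, 1, 1)  # Bx3x3
    G_tgt_src = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)  # Bx4x4 identity pose
    d_src_B = torch.ones(B)  # plane at unit depth in front of the src camera
    hs = HomographySample(H, W)
    tgt_BCHW, valid_mask = hs.sample(src_BCHW, d_src_B, G_tgt_src, inverse(K), K)
    print(tgt_BCHW.shape, valid_mask.shape)  # (2, 3, 32, 32), (2, 32, 32)
    print(torch.allclose(tgt_BCHW, src_BCHW, atol=1e-4))  # expected: True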