import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import map_coordinates
import cv2
import math
from os import makedirs
from os.path import join, exists


# Based on https://github.com/sunset1995/py360convert
class Equirec2Cube:
    def __init__(self, equ_h, equ_w, face_w):
        '''
        equ_h: int, height of the equirectangular image
        equ_w: int, width of the equirectangular image
        face_w: int, side length of each cubemap face
        '''
        self.equ_h = equ_h
        self.equ_w = equ_w
        self.face_w = face_w

        self._xyzcube()
        self._xyz2coor()

        # To convert radial (R) distance to Z-depth on the cubemap faces
        cosmap = 1 / np.sqrt((2 * self.grid[..., 0]) ** 2
                             + (2 * self.grid[..., 1]) ** 2 + 1)
        self.cosmaps = np.concatenate(6 * [cosmap], axis=1)[..., np.newaxis]

    def _xyzcube(self):
        '''
        Compute the xyz coordinates of the unit cube in [F R B L U D] format.
        '''
        self.xyz = np.zeros((self.face_w, self.face_w * 6, 3), np.float32)
        rng = np.linspace(-0.5, 0.5, num=self.face_w, dtype=np.float32)
        self.grid = np.stack(np.meshgrid(rng, -rng), -1)

        # Front face (z = 0.5)
        self.xyz[:, 0 * self.face_w:1 * self.face_w, [0, 1]] = self.grid
        self.xyz[:, 0 * self.face_w:1 * self.face_w, 2] = 0.5

        # Right face (x = 0.5)
        self.xyz[:, 1 * self.face_w:2 * self.face_w, [2, 1]] = self.grid[:, ::-1]
        self.xyz[:, 1 * self.face_w:2 * self.face_w, 0] = 0.5

        # Back face (z = -0.5)
        self.xyz[:, 2 * self.face_w:3 * self.face_w, [0, 1]] = self.grid[:, ::-1]
        self.xyz[:, 2 * self.face_w:3 * self.face_w, 2] = -0.5

        # Left face (x = -0.5)
        self.xyz[:, 3 * self.face_w:4 * self.face_w, [2, 1]] = self.grid
        self.xyz[:, 3 * self.face_w:4 * self.face_w, 0] = -0.5

        # Up face (y = 0.5)
        self.xyz[:, 4 * self.face_w:5 * self.face_w, [0, 2]] = self.grid[::-1, :]
        self.xyz[:, 4 * self.face_w:5 * self.face_w, 1] = 0.5

        # Down face (y = -0.5)
        self.xyz[:, 5 * self.face_w:6 * self.face_w, [0, 2]] = self.grid
        self.xyz[:, 5 * self.face_w:6 * self.face_w, 1] = -0.5

    def _xyz2coor(self):
        # x, y, z to longitude and latitude
        x, y, z = np.split(self.xyz, 3, axis=-1)
        lon = np.arctan2(x, z)
        c = np.sqrt(x ** 2 + z ** 2)
        lat = np.arctan2(y, c)

        # longitude and latitude to equirectangular pixel coordinates
        self.coor_x = (lon / (2 * np.pi) + 0.5) * self.equ_w - 0.5
        self.coor_y = (-lat / np.pi + 0.5) * self.equ_h - 0.5

    def sample_equirec(self, e_img, order=0):
        # Pad top/bottom with the opposite-side boundary rows (rolled by half
        # a width) so interpolation stays continuous across the poles.
        pad_u = np.roll(e_img[[0]], self.equ_w // 2, 1)
        pad_d = np.roll(e_img[[-1]], self.equ_w // 2, 1)
        e_img = np.concatenate([e_img, pad_d, pad_u], 0)
        return map_coordinates(e_img, [self.coor_y, self.coor_x],
                               order=order, mode='wrap')[..., 0]

    def run(self, equ_img, equ_dep=None):
        h, w = equ_img.shape[:2]
        if h != self.equ_h or w != self.equ_w:
            equ_img = cv2.resize(equ_img, (self.equ_w, self.equ_h))
            if equ_dep is not None:
                equ_dep = cv2.resize(equ_dep, (self.equ_w, self.equ_h),
                                     interpolation=cv2.INTER_NEAREST)
                # cv2.resize drops a trailing singleton channel; restore it.
                if equ_dep.ndim == 2:
                    equ_dep = equ_dep[..., np.newaxis]

        cube_img = np.stack([self.sample_equirec(equ_img[..., i], order=1)
                             for i in range(equ_img.shape[2])], axis=-1)
        if equ_dep is not None:
            cube_dep = np.stack([self.sample_equirec(equ_dep[..., i], order=0)
                                 for i in range(equ_dep.shape[2])], axis=-1)
            cube_dep = cube_dep * self.cosmaps
            return cube_img, cube_dep
        return cube_img
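

# Usage sketch (not part of the original module): project a synthetic
# equirectangular RGB image and depth map onto the cubemap. The resolution
# (512x1024), face width (256), and the `_demo_equirec2cube` name are
# illustrative assumptions.
def _demo_equirec2cube():
    e2c = Equirec2Cube(equ_h=512, equ_w=1024, face_w=256)
    equ_img = np.random.rand(512, 1024, 3).astype(np.float32)
    equ_dep = np.random.rand(512, 1024, 1).astype(np.float32)
    cube_img, cube_dep = e2c.run(equ_img, equ_dep)
    # The six faces are packed along the width in [F R B L U D] order, and
    # the depth faces are rescaled from radial distance to Z-depth.
    assert cube_img.shape == (256, 256 * 6, 3)
    assert cube_dep.shape == (256, 256 * 6, 1)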


# Based on https://github.com/sunset1995/py360convert
class Cube2Equirec(nn.Module):
    def __init__(self, face_w, equ_h, equ_w):
        '''
        face_w: int, side length of each cubemap face
        equ_h: int, height of the equirectangular image
        equ_w: int, width of the equirectangular image
        '''
        super(Cube2Equirec, self).__init__()
        self.face_w = face_w
        self.equ_h = equ_h
        self.equ_w = equ_w

        # Get face id for each pixel: 0F 1R 2B 3L 4U 5D
        self._equirect_facetype()
        self._equirect_faceuv()

    def _equirect_facetype(self):
        '''
        0F 1R 2B 3L 4U 5D
        '''
        tp = np.roll(np.arange(4).repeat(self.equ_w // 4)[None, :].repeat(self.equ_h, 0),
                     3 * self.equ_w // 8, 1)

        # Prepare ceiling mask
        mask = np.zeros((self.equ_h, self.equ_w // 4), bool)
        idx = np.linspace(-np.pi, np.pi, self.equ_w // 4) / 4
        idx = self.equ_h // 2 - np.round(np.arctan(np.cos(idx)) * self.equ_h / np.pi).astype(int)
        for i, j in enumerate(idx):
            mask[:j, i] = 1
        mask = np.roll(np.concatenate([mask] * 4, 1), 3 * self.equ_w // 8, 1)

        tp[mask] = 4
        tp[np.flip(mask, 0)] = 5

        self.tp = tp
        self.mask = mask

    def _equirect_faceuv(self):
        lon = ((np.linspace(0, self.equ_w - 1, num=self.equ_w, dtype=np.float32) + 0.5)
               / self.equ_w - 0.5) * 2 * np.pi
        lat = -((np.linspace(0, self.equ_h - 1, num=self.equ_h, dtype=np.float32) + 0.5)
                / self.equ_h - 0.5) * np.pi

        lon, lat = np.meshgrid(lon, lat)

        coor_u = np.zeros((self.equ_h, self.equ_w), dtype=np.float32)
        coor_v = np.zeros((self.equ_h, self.equ_w), dtype=np.float32)

        for i in range(4):
            mask = (self.tp == i)
            coor_u[mask] = 0.5 * np.tan(lon[mask] - np.pi * i / 2)
            coor_v[mask] = -0.5 * np.tan(lat[mask]) / np.cos(lon[mask] - np.pi * i / 2)

        mask = (self.tp == 4)
        c = 0.5 * np.tan(np.pi / 2 - lat[mask])
        coor_u[mask] = c * np.sin(lon[mask])
        coor_v[mask] = c * np.cos(lon[mask])

        mask = (self.tp == 5)
        c = 0.5 * np.tan(np.pi / 2 - np.abs(lat[mask]))
        coor_u[mask] = c * np.sin(lon[mask])
        coor_v[mask] = -c * np.cos(lon[mask])

        # Final renormalization to [-1, 1]
        coor_u = np.clip(coor_u, -0.5, 0.5) * 2
        coor_v = np.clip(coor_v, -0.5, 0.5) * 2

        # Convert to torch tensors; the face index becomes the depth
        # coordinate of a 3D sampling grid over the (6, face_w, face_w) cube.
        self.tp = torch.from_numpy(self.tp.astype(np.float32) / 2.5 - 1)
        self.coor_u = torch.from_numpy(coor_u)
        self.coor_v = torch.from_numpy(coor_v)

        sample_grid = torch.stack([self.coor_u, self.coor_v, self.tp], dim=-1).view(
            1, 1, self.equ_h, self.equ_w, 3)
        self.sample_grid = nn.Parameter(sample_grid, requires_grad=False)

    def forward(self, cube_feat):
        bs, ch, h, w = cube_feat.shape
        assert h == self.face_w and w // 6 == self.face_w

        cube_feat = cube_feat.view(bs, ch, 1, h, w)
        cube_feat = torch.cat(torch.split(cube_feat, self.face_w, dim=-1), dim=2)
        cube_feat = cube_feat.view([bs, ch, 6, self.face_w, self.face_w])
        sample_grid = torch.cat(bs * [self.sample_grid], dim=0)
        equi_feat = F.grid_sample(cube_feat, sample_grid,
                                  padding_mode="border", align_corners=True)

        return equi_feat.squeeze(2)
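

# Usage sketch (illustrative sizes, not from the original code): map a packed
# cubemap feature tensor back to an equirectangular feature map. Faces must be
# packed along the width in the same [F R B L U D] order that Equirec2Cube
# produces.
def _demo_cube2equirec():
    c2e = Cube2Equirec(face_w=64, equ_h=128, equ_w=256)
    cube_feat = torch.rand(2, 8, 64, 64 * 6)  # (bs, ch, face_w, 6 * face_w)
    equi_feat = c2e(cube_feat)
    assert equi_feat.shape == (2, 8, 128, 256)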


# Generate tangent patches in closed form. The transformation and equations
# are adapted from http://blog.nitishmutha.com/equirectangular/360degree/2017/06/12/How-to-project-Equirectangular-image-to-rectilinear-view.html
def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def uv2xyz(uv):
    xyz = np.zeros((*uv.shape[:-1], 3), dtype=np.float32)
    xyz[..., 0] = np.multiply(np.cos(uv[..., 1]), np.sin(uv[..., 0]))
    xyz[..., 1] = np.multiply(np.cos(uv[..., 1]), np.cos(uv[..., 0]))
    xyz[..., 2] = np.sin(uv[..., 1])
    return xyz


def equi2pers(erp_img, fov, nrows, patch_size):
    bs, _, erp_h, erp_w = erp_img.shape
    height, width = pair(patch_size)
    fov_h, fov_w = pair(fov)
    FOV = torch.tensor([fov_w / 360.0, fov_h / 180.0], dtype=torch.float32)

    PI = math.pi
    PI_2 = math.pi * 0.5
    PI2 = math.pi * 2
    yy, xx = torch.meshgrid(torch.linspace(0, 1, height), torch.linspace(0, 1, width))
    screen_points = torch.stack([xx.flatten(), yy.flatten()], -1)

    if nrows == 4:
        num_rows = 4
        num_cols = [3, 6, 6, 3]
        phi_centers = [-67.5, -22.5, 22.5, 67.5]
    elif nrows == 6:
        num_rows = 6
        num_cols = [3, 8, 12, 12, 8, 3]
        phi_centers = [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2]
    elif nrows == 3:
        num_rows = 3
        num_cols = [3, 4, 3]
        phi_centers = [-60, 0, 60]
    elif nrows == 5:
        num_rows = 5
        num_cols = [3, 6, 8, 6, 3]
        phi_centers = [-72.2, -36.1, 0, 36.1, 72.2]
    else:
        raise ValueError('nrows must be one of 3, 4, 5, 6')
    phi_interval = 180 // num_rows
    all_combos = []
    erp_mask = []
    for i, n_cols in enumerate(num_cols):
        for j in np.arange(n_cols):
            theta_interval = 360 / n_cols
            theta_center = j * theta_interval + theta_interval / 2

            center = [theta_center, phi_centers[i]]
            all_combos.append(center)

            # Per-patch coverage mask on the ERP grid
            up = phi_centers[i] + phi_interval / 2
            down = phi_centers[i] - phi_interval / 2
            left = theta_center - theta_interval / 2
            right = theta_center + theta_interval / 2
            up = int((up + 90) / 180 * erp_h)
            down = int((down + 90) / 180 * erp_h)
            left = int(left / 360 * erp_w)
            right = int(right / 360 * erp_w)
            mask = np.zeros((erp_h, erp_w), dtype=int)
            mask[down:up, left:right] = 1
            erp_mask.append(mask)
    all_combos = np.vstack(all_combos)
    shifts = np.arange(all_combos.shape[0]) * width
    shifts = torch.from_numpy(shifts).float()
    erp_mask = np.stack(erp_mask)
    erp_mask = torch.from_numpy(erp_mask).float()
    num_patch = all_combos.shape[0]

    center_point = torch.from_numpy(all_combos).float()  # -180 to 180, -90 to 90
    center_point[:, 0] = (center_point[:, 0]) / 360  # 0 to 1
    center_point[:, 1] = (center_point[:, 1] + 90) / 180  # 0 to 1

    cp = center_point * 2 - 1
    center_p = cp.clone()
    cp[:, 0] = cp[:, 0] * PI
    cp[:, 1] = cp[:, 1] * PI_2
    cp = cp.unsqueeze(1)
    convertedCoord = screen_points * 2 - 1
    convertedCoord[:, 0] = convertedCoord[:, 0] * PI
    convertedCoord[:, 1] = convertedCoord[:, 1] * PI_2
    convertedCoord = convertedCoord * FOV
    convertedCoord = convertedCoord.unsqueeze(0).repeat(cp.shape[0], 1, 1)

    # Inverse gnomonic projection from each tangent plane back to the sphere
    x = convertedCoord[:, :, 0]
    y = convertedCoord[:, :, 1]

    rou = torch.sqrt(x ** 2 + y ** 2)
    c = torch.atan(rou)
    sin_c = torch.sin(c)
    cos_c = torch.cos(c)
    lat = torch.asin(cos_c * torch.sin(cp[:, :, 1])
                     + (y * sin_c * torch.cos(cp[:, :, 1])) / rou)
    lon = cp[:, :, 0] + torch.atan2(
        x * sin_c,
        rou * torch.cos(cp[:, :, 1]) * cos_c - y * torch.sin(cp[:, :, 1]) * sin_c)

    lat_new = lat / PI_2
    lon_new = lon / PI
    lon_new[lon_new > 1] -= 2
    lon_new[lon_new < -1] += 2

    lon_new = lon_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch * width)
    lat_new = lat_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch * width)
    grid = torch.stack([lon_new, lat_new], -1)
    grid = grid.unsqueeze(0).repeat(bs, 1, 1, 1).to(erp_img.device)

    pers = F.grid_sample(erp_img, grid, mode='bilinear', padding_mode='border', align_corners=True)
    pers = F.unfold(pers, kernel_size=(height, width), stride=(height, width))
    pers = pers.reshape(bs, -1, height, width, num_patch)

    grid_tmp = torch.stack([lon, lat], -1)
    xyz = uv2xyz(grid_tmp.numpy())
    xyz = xyz.reshape(num_patch, height, width, 3).transpose(0, 3, 1, 2)
    xyz = torch.from_numpy(xyz).to(pers.device).contiguous()

    # The columns of `grid` are patch-major (patch p occupies columns
    # [p * width, (p + 1) * width)), so reshape with num_patch before width.
    uv = grid[0, ...].reshape(height, num_patch, width, 2).permute(1, 3, 0, 2)
    uv = uv.contiguous()
    return pers, xyz, uv, center_p
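

# Usage sketch (sizes are assumptions): extract the 18 tangent patches that
# nrows=4 produces (3 + 6 + 6 + 3 centers) from a random ERP image, together
# with the per-pixel xyz directions and uv coordinates of each patch.
def _demo_equi2pers():
    erp = torch.rand(1, 3, 256, 512)
    pers, xyz, uv, center_p = equi2pers(erp, fov=80, nrows=4, patch_size=64)
    assert pers.shape == (1, 3, 64, 64, 18)  # (bs, ch, h, w, n_patch)
    assert xyz.shape == (18, 3, 64, 64)      # unit-sphere directions
    assert uv.shape == (18, 2, 64, 64)       # normalized ERP lon/lat
    assert center_p.shape == (18, 2)         # patch centers in [-1, 1]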


def pers2equi(pers_img, fov, nrows, patch_size, erp_size, layer_name):
    bs = pers_img.shape[0]
    channel = pers_img.shape[1]
    device = pers_img.device
    height, width = pair(patch_size)
    fov_h, fov_w = pair(fov)
    erp_h, erp_w = pair(erp_size)
    n_patch = pers_img.shape[-1]
    grid_dir = './grid'
    if not exists(grid_dir):
        makedirs(grid_dir)
    # NOTE: the cache key is the layer name only; if fov, nrows, patch_size or
    # erp_size change for the same layer_name, the cached file must be deleted.
    grid_file = join(grid_dir, layer_name + '.pth')

    if not exists(grid_file):
        FOV = torch.tensor([fov_w / 360.0, fov_h / 180.0], dtype=torch.float32)

        PI = math.pi
        PI_2 = math.pi * 0.5
        PI2 = math.pi * 2

        if nrows == 4:
            num_rows = 4
            num_cols = [3, 6, 6, 3]
            phi_centers = [-67.5, -22.5, 22.5, 67.5]
        elif nrows == 6:
            num_rows = 6
            num_cols = [3, 8, 12, 12, 8, 3]
            phi_centers = [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2]
        elif nrows == 3:
            num_rows = 3
            num_cols = [3, 4, 3]
            # NOTE: these centers differ from the [-60, 0, 60] used in
            # equi2pers; the two must agree for pers2equi to be its inverse.
            phi_centers = [-59.6, 0, 59.6]
        elif nrows == 5:
            num_rows = 5
            num_cols = [3, 6, 8, 6, 3]
            phi_centers = [-72.2, -36.1, 0, 36.1, 72.2]
        else:
            raise ValueError('nrows must be one of 3, 4, 5, 6')
        phi_interval = 180 // num_rows
        all_combos = []

        for i, n_cols in enumerate(num_cols):
            for j in np.arange(n_cols):
                theta_interval = 360 / n_cols
                theta_center = j * theta_interval + theta_interval / 2

                center = [theta_center, phi_centers[i]]
                all_combos.append(center)

        all_combos = np.vstack(all_combos)
        n_patch = all_combos.shape[0]

        center_point = torch.from_numpy(all_combos).float()  # -180 to 180, -90 to 90
        center_point[:, 0] = (center_point[:, 0]) / 360  # 0 to 1
        center_point[:, 1] = (center_point[:, 1] + 90) / 180  # 0 to 1

        cp = center_point * 2 - 1
        cp[:, 0] = cp[:, 0] * PI
        cp[:, 1] = cp[:, 1] * PI_2
        cp = cp.unsqueeze(1)

        lat_grid, lon_grid = torch.meshgrid(torch.linspace(-PI_2, PI_2, erp_h),
                                            torch.linspace(-PI, PI, erp_w))
        lon_grid = lon_grid.float().reshape(1, -1)
        lat_grid = lat_grid.float().reshape(1, -1)

        # Forward gnomonic projection of every ERP pixel onto each tangent plane
        cos_c = torch.sin(cp[..., 1]) * torch.sin(lat_grid) \
            + torch.cos(cp[..., 1]) * torch.cos(lat_grid) * torch.cos(lon_grid - cp[..., 0])
        new_x = (torch.cos(lat_grid) * torch.sin(lon_grid - cp[..., 0])) / cos_c
        new_y = (torch.cos(cp[..., 1]) * torch.sin(lat_grid)
                 - torch.sin(cp[..., 1]) * torch.cos(lat_grid) * torch.cos(lon_grid - cp[..., 0])) / cos_c
        new_x = new_x / FOV[0] / PI  # -1 to 1
        new_y = new_y / FOV[1] / PI_2

        # cos_c > 0 keeps only points on the same hemisphere as the patch center
        cos_c_mask = cos_c.reshape(n_patch, erp_h, erp_w)
        cos_c_mask = torch.where(cos_c_mask > 0, 1, 0)

        w_list = torch.zeros((n_patch, erp_h, erp_w, 4), dtype=torch.float32)

        # Patches are assumed square (height == width), so scaling x by height
        # and y by width matches the original formulation.
        new_x_patch = (new_x + 1) * 0.5 * height
        new_y_patch = (new_y + 1) * 0.5 * width
        new_x_patch = new_x_patch.reshape(n_patch, erp_h, erp_w)
        new_y_patch = new_y_patch.reshape(n_patch, erp_h, erp_w)
        mask = torch.where((new_x_patch < width) & (new_x_patch > 0)
                           & (new_y_patch < height) & (new_y_patch > 0), 1, 0)
        mask *= cos_c_mask

        # Bilinear interpolation weights over the four neighbouring texels
        x0 = torch.floor(new_x_patch).type(torch.int64)
        x1 = x0 + 1
        y0 = torch.floor(new_y_patch).type(torch.int64)
        y1 = y0 + 1

        x0 = torch.clamp(x0, 0, width - 1)
        x1 = torch.clamp(x1, 0, width - 1)
        y0 = torch.clamp(y0, 0, height - 1)
        y1 = torch.clamp(y1, 0, height - 1)

        wa = (x1.type(torch.float32) - new_x_patch) * (y1.type(torch.float32) - new_y_patch)
        wb = (x1.type(torch.float32) - new_x_patch) * (new_y_patch - y0.type(torch.float32))
        wc = (new_x_patch - x0.type(torch.float32)) * (y1.type(torch.float32) - new_y_patch)
        wd = (new_x_patch - x0.type(torch.float32)) * (new_y_patch - y0.type(torch.float32))

        wa = wa * mask.expand_as(wa)
        wb = wb * mask.expand_as(wb)
        wc = wc * mask.expand_as(wc)
        wd = wd * mask.expand_as(wd)

        w_list[..., 0] = wa
        w_list[..., 1] = wb
        w_list[..., 2] = wc
        w_list[..., 3] = wd

        save_file = {'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1,
                     'w_list': w_list, 'mask': mask}
        torch.save(save_file, grid_file)
    else:
        # The online merge is expensive, so the grid is pre-computed once and
        # reused during training.
        load_file = torch.load(grid_file)
        x0 = load_file['x0']
        y0 = load_file['y0']
        x1 = load_file['x1']
        y1 = load_file['y1']
        w_list = load_file['w_list']
        mask = load_file['mask']

    w_list = w_list.to(device)
    mask = mask.to(device)
    z = torch.arange(n_patch)
    z = z.reshape(n_patch, 1, 1)
    # Gather the four neighbours of each ERP pixel from its source patch
    Ia = pers_img[:, :, y0, x0, z]
    Ib = pers_img[:, :, y1, x0, z]
    Ic = pers_img[:, :, y0, x1, z]
    Id = pers_img[:, :, y1, x1, z]

    output_a = Ia * mask.expand_as(Ia)
    output_b = Ib * mask.expand_as(Ib)
    output_c = Ic * mask.expand_as(Ic)
    output_d = Id * mask.expand_as(Id)

    output_a = output_a.permute(0, 1, 3, 4, 2)
    output_b = output_b.permute(0, 1, 3, 4, 2)
    output_c = output_c.permute(0, 1, 3, 4, 2)
    output_d = output_d.permute(0, 1, 3, 4, 2)

    # Normalize the weights so overlapping patches blend to a weighted average
    w_list = w_list.permute(1, 2, 0, 3)
    w_list = w_list.flatten(2)
    w_list *= torch.gt(w_list, 1e-5).type(torch.float32)
    w_list = F.normalize(w_list, p=1, dim=-1).reshape(erp_h, erp_w, n_patch, 4)
    w_list = w_list.unsqueeze(0).unsqueeze(0)
    output = output_a * w_list[..., 0] + output_b * w_list[..., 1] + \
        output_c * w_list[..., 2] + output_d * w_list[..., 3]
    img_erp = output.sum(-1)

    return img_erp
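

# Usage sketch (sizes and the 'demo' cache name are assumptions): round-trip
# an ERP image through equi2pers and back. pers2equi caches its sampling grid
# under ./grid/<layer_name>.pth, so this writes ./grid/demo.pth on first call.
# The reconstruction is only approximate near patch borders and the poles.
def _demo_pers2equi():
    erp = torch.rand(1, 3, 128, 256)
    pers, _, _, _ = equi2pers(erp, fov=80, nrows=4, patch_size=32)
    erp_rec = pers2equi(pers, fov=80, nrows=4, patch_size=32,
                        erp_size=(128, 256), layer_name='demo')
    assert erp_rec.shape == (1, 3, 128, 256)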


def img2windows(img, H_sp, W_sp):
    """
    Partition an image tensor into non-overlapping windows.
    img: B C H W -> (B * H/H_sp * W/W_sp) H_sp W_sp C
    """
    B, C, H, W = img.shape
    img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
    img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp, W_sp, C)
    return img_perm


def windows2img(img_splits_hw, H_sp, W_sp, H, W):
    """
    Merge windows back into an image tensor.
    img_splits_hw: B' H_sp W_sp C -> B H W C
    """
    B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))

    img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
    img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return img
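

# Usage sketch (illustrative sizes): img2windows and windows2img are exact
# inverses for any H, W divisible by the window size, as the round trip below
# checks. The __main__ guard running all demo sketches is likewise an addition.
def _demo_windows_roundtrip():
    x = torch.rand(2, 8, 16, 24)
    wins = img2windows(x, 4, 6)          # (B * 4 * 4, H_sp, W_sp, C)
    assert wins.shape == (2 * 4 * 4, 4, 6, 8)
    y = windows2img(wins, 4, 6, 16, 24)  # (B, H, W, C)
    assert torch.equal(y.permute(0, 3, 1, 2), x)


if __name__ == '__main__':
    _demo_equirec2cube()
    _demo_cube2equirec()
    _demo_equi2pers()
    _demo_pers2equi()
    _demo_windows_roundtrip()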