import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import map_coordinates
import cv2
import math
from os import makedirs
from os.path import join, exists
# Based on https://github.com/sunset1995/py360convert
class Equirec2Cube:
def __init__(self, equ_h, equ_w, face_w):
'''
equ_h: int, height of the equirectangular image
equ_w: int, width of the equirectangular image
face_w: int, the length of each face of the cubemap
'''
self.equ_h = equ_h
self.equ_w = equ_w
self.face_w = face_w
self._xyzcube()
self._xyz2coor()
        # Conversion factor from radial (R) distance to Z-depth for the cube-map faces
cosmap = 1 / np.sqrt((2 * self.grid[..., 0]) ** 2 + (2 * self.grid[..., 1]) ** 2 + 1)
self.cosmaps = np.concatenate(6 * [cosmap], axis=1)[..., np.newaxis]
def _xyzcube(self):
'''
        Compute the xyz coordinates of the unit cube in [F R B L U D] format.
'''
self.xyz = np.zeros((self.face_w, self.face_w * 6, 3), np.float32)
rng = np.linspace(-0.5, 0.5, num=self.face_w, dtype=np.float32)
self.grid = np.stack(np.meshgrid(rng, -rng), -1)
# Front face (z = 0.5)
self.xyz[:, 0 * self.face_w:1 * self.face_w, [0, 1]] = self.grid
self.xyz[:, 0 * self.face_w:1 * self.face_w, 2] = 0.5
# Right face (x = 0.5)
self.xyz[:, 1 * self.face_w:2 * self.face_w, [2, 1]] = self.grid[:, ::-1]
self.xyz[:, 1 * self.face_w:2 * self.face_w, 0] = 0.5
# Back face (z = -0.5)
self.xyz[:, 2 * self.face_w:3 * self.face_w, [0, 1]] = self.grid[:, ::-1]
self.xyz[:, 2 * self.face_w:3 * self.face_w, 2] = -0.5
# Left face (x = -0.5)
self.xyz[:, 3 * self.face_w:4 * self.face_w, [2, 1]] = self.grid
self.xyz[:, 3 * self.face_w:4 * self.face_w, 0] = -0.5
# Up face (y = 0.5)
self.xyz[:, 4 * self.face_w:5 * self.face_w, [0, 2]] = self.grid[::-1, :]
self.xyz[:, 4 * self.face_w:5 * self.face_w, 1] = 0.5
# Down face (y = -0.5)
self.xyz[:, 5 * self.face_w:6 * self.face_w, [0, 2]] = self.grid
self.xyz[:, 5 * self.face_w:6 * self.face_w, 1] = -0.5
def _xyz2coor(self):
# x, y, z to longitude and latitude
x, y, z = np.split(self.xyz, 3, axis=-1)
lon = np.arctan2(x, z)
c = np.sqrt(x ** 2 + z ** 2)
lat = np.arctan2(y, c)
# longitude and latitude to equirectangular coordinate
self.coor_x = (lon / (2 * np.pi) + 0.5) * self.equ_w - 0.5
self.coor_y = (-lat / np.pi + 0.5) * self.equ_h - 0.5
def sample_equirec(self, e_img, order=0):
pad_u = np.roll(e_img[[0]], self.equ_w // 2, 1)
pad_d = np.roll(e_img[[-1]], self.equ_w // 2, 1)
e_img = np.concatenate([e_img, pad_d, pad_u], 0)
# pad_l = e_img[:, [0]]
# pad_r = e_img[:, [-1]]
# e_img = np.concatenate([e_img, pad_l, pad_r], 1)
return map_coordinates(e_img, [self.coor_y, self.coor_x],
order=order, mode='wrap')[..., 0]
    def run(self, equ_img, equ_dep=None):
        h, w = equ_img.shape[:2]
        if h != self.equ_h or w != self.equ_w:
            equ_img = cv2.resize(equ_img, (self.equ_w, self.equ_h))
            if equ_dep is not None:
                equ_dep = cv2.resize(equ_dep, (self.equ_w, self.equ_h), interpolation=cv2.INTER_NEAREST)
        cube_img = np.stack([self.sample_equirec(equ_img[..., i], order=1)
                             for i in range(equ_img.shape[2])], axis=-1)
        if equ_dep is not None:
            cube_dep = np.stack([self.sample_equirec(equ_dep[..., i], order=0)
                                 for i in range(equ_dep.shape[2])], axis=-1)
            cube_dep = cube_dep * self.cosmaps
            return cube_img, cube_dep
        else:
            return cube_img
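# Minimal usage sketch (illustrative only, not part of the original pipeline): convert one
# equirectangular RGB image and its radial-distance map into 6-face cube-map strips.
# The sizes and random inputs below are assumptions standing in for real data.
def _example_equirec2cube():
    e2c = Equirec2Cube(equ_h=512, equ_w=1024, face_w=256)
    erp_rgb = np.random.rand(512, 1024, 3).astype(np.float32)   # stand-in for a loaded panorama
    erp_dist = np.random.rand(512, 1024, 1).astype(np.float32)  # stand-in for a radial-distance map
    cube_rgb, cube_dep = e2c.run(erp_rgb, erp_dist)
    # cube_rgb: (face_w, 6 * face_w, 3), faces in [F R B L U D] order
    # cube_dep: (face_w, 6 * face_w, 1), R-distance converted to Z-depth via cosmaps
    return cube_rgb, cube_dep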
# Based on https://github.com/sunset1995/py360convert
class Cube2Equirec(nn.Module):
def __init__(self, face_w, equ_h, equ_w):
        '''
        face_w: int, the length of each face of the cubemap
        equ_h: int, height of the equirectangular image
        equ_w: int, width of the equirectangular image
        '''
        super(Cube2Equirec, self).__init__()
self.face_w = face_w
self.equ_h = equ_h
self.equ_w = equ_w
        # Get the face id of each pixel: 0F 1R 2B 3L 4U 5D
self._equirect_facetype()
self._equirect_faceuv()
def _equirect_facetype(self):
'''
0F 1R 2B 3L 4U 5D
'''
tp = np.roll(np.arange(4).repeat(self.equ_w // 4)[None, :].repeat(self.equ_h, 0), 3 * self.equ_w // 8, 1)
        # Prepare the ceiling (up-face) mask
mask = np.zeros((self.equ_h, self.equ_w // 4), bool)
idx = np.linspace(-np.pi, np.pi, self.equ_w // 4) / 4
idx = self.equ_h // 2 - np.round(np.arctan(np.cos(idx)) * self.equ_h / np.pi).astype(int)
for i, j in enumerate(idx):
mask[:j, i] = 1
mask = np.roll(np.concatenate([mask] * 4, 1), 3 * self.equ_w // 8, 1)
tp[mask] = 4
tp[np.flip(mask, 0)] = 5
self.tp = tp
self.mask = mask
def _equirect_faceuv(self):
        lon = ((np.linspace(0, self.equ_w - 1, num=self.equ_w, dtype=np.float32) + 0.5) / self.equ_w - 0.5) * 2 * np.pi
        lat = -((np.linspace(0, self.equ_h - 1, num=self.equ_h, dtype=np.float32) + 0.5) / self.equ_h - 0.5) * np.pi
lon, lat = np.meshgrid(lon, lat)
coor_u = np.zeros((self.equ_h, self.equ_w), dtype=np.float32)
coor_v = np.zeros((self.equ_h, self.equ_w), dtype=np.float32)
for i in range(4):
mask = (self.tp == i)
coor_u[mask] = 0.5 * np.tan(lon[mask] - np.pi * i / 2)
coor_v[mask] = -0.5 * np.tan(lat[mask]) / np.cos(lon[mask] - np.pi * i / 2)
mask = (self.tp == 4)
c = 0.5 * np.tan(np.pi / 2 - lat[mask])
coor_u[mask] = c * np.sin(lon[mask])
coor_v[mask] = c * np.cos(lon[mask])
mask = (self.tp == 5)
c = 0.5 * np.tan(np.pi / 2 - np.abs(lat[mask]))
coor_u[mask] = c * np.sin(lon[mask])
coor_v[mask] = -c * np.cos(lon[mask])
        # Final renormalization: clip to [-0.5, 0.5] and rescale to [-1, 1]
coor_u = (np.clip(coor_u, -0.5, 0.5)) * 2
coor_v = (np.clip(coor_v, -0.5, 0.5)) * 2
# Convert to torch tensor
        self.tp = torch.from_numpy(self.tp.astype(np.float32) / 2.5 - 1)  # face index 0..5 -> grid_sample depth coordinate in [-1, 1]
self.coor_u = torch.from_numpy(coor_u)
self.coor_v = torch.from_numpy(coor_v)
sample_grid = torch.stack([self.coor_u, self.coor_v, self.tp], dim=-1).view(1, 1, self.equ_h, self.equ_w, 3)
self.sample_grid = nn.Parameter(sample_grid, requires_grad=False)
def forward(self, cube_feat):
bs, ch, h, w = cube_feat.shape
assert h == self.face_w and w // 6 == self.face_w
cube_feat = cube_feat.view(bs, ch, 1, h, w)
cube_feat = torch.cat(torch.split(cube_feat, self.face_w, dim=-1), dim=2)
cube_feat = cube_feat.view([bs, ch, 6, self.face_w, self.face_w])
sample_grid = torch.cat(bs * [self.sample_grid], dim=0)
equi_feat = F.grid_sample(cube_feat, sample_grid, padding_mode="border", align_corners=True)
return equi_feat.squeeze(2)
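# Minimal usage sketch (assumed shapes): project a cube-map feature strip of shape
# (B, C, face_w, 6 * face_w), faces in [F R B L U D] order, back onto an
# equirectangular grid of shape (B, C, equ_h, equ_w).
def _example_cube2equirec():
    c2e = Cube2Equirec(face_w=64, equ_h=128, equ_w=256)
    cube_feat = torch.randn(2, 32, 64, 6 * 64)  # B=2, C=32
    equi_feat = c2e(cube_feat)                  # (2, 32, 128, 256)
    return equi_feat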
# Generate tangent-plane patches in closed form.
# The transformation and equations are taken from http://blog.nitishmutha.com/equirectangular/360degree/2017/06/12/How-to-project-Equirectangular-image-to-rectilinear-view.html
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def uv2xyz(uv):
xyz = np.zeros((*uv.shape[:-1], 3), dtype = np.float32)
xyz[..., 0] = np.multiply(np.cos(uv[..., 1]), np.sin(uv[..., 0]))
xyz[..., 1] = np.multiply(np.cos(uv[..., 1]), np.cos(uv[..., 0]))
xyz[..., 2] = np.sin(uv[..., 1])
return xyz
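# Quick sanity sketch for uv2xyz: (lon, lat) pairs map onto the unit sphere in this
# convention, e.g. (0, 0) goes to roughly (0, 1, 0) and (pi/2, 0) to roughly (1, 0, 0).
def _example_uv2xyz():
    uv = np.array([[0.0, 0.0], [np.pi / 2, 0.0]], dtype=np.float32)
    return uv2xyz(uv)  # approximately [[0, 1, 0], [1, 0, 0]]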
def equi2pers(erp_img, fov, nrows, patch_size):
bs, _, erp_h, erp_w = erp_img.shape
height, width = pair(patch_size)
fov_h, fov_w = pair(fov)
FOV = torch.tensor([fov_w/360.0, fov_h/180.0], dtype=torch.float32)
PI = math.pi
PI_2 = math.pi * 0.5
PI2 = math.pi * 2
    yy, xx = torch.meshgrid(torch.linspace(0, 1, height), torch.linspace(0, 1, width), indexing='ij')
screen_points = torch.stack([xx.flatten(), yy.flatten()], -1)
    if nrows == 4:
        num_rows = 4
        num_cols = [3, 6, 6, 3]
        phi_centers = [-67.5, -22.5, 22.5, 67.5]
    elif nrows == 6:
        num_rows = 6
        num_cols = [3, 8, 12, 12, 8, 3]
        phi_centers = [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2]
    elif nrows == 3:
        num_rows = 3
        num_cols = [3, 4, 3]
        phi_centers = [-60, 0, 60]
    elif nrows == 5:
        num_rows = 5
        num_cols = [3, 6, 8, 6, 3]
        phi_centers = [-72.2, -36.1, 0, 36.1, 72.2]
    else:
        raise ValueError('nrows must be 3, 4, 5 or 6')
phi_interval = 180 // num_rows
all_combos = []
erp_mask = []
for i, n_cols in enumerate(num_cols):
for j in np.arange(n_cols):
theta_interval = 360 / n_cols
theta_center = j * theta_interval + theta_interval / 2
center = [theta_center, phi_centers[i]]
all_combos.append(center)
up = phi_centers[i] + phi_interval / 2
down = phi_centers[i] - phi_interval / 2
left = theta_center - theta_interval / 2
right = theta_center + theta_interval / 2
up = int((up + 90) / 180 * erp_h)
down = int((down + 90) / 180 * erp_h)
left = int(left / 360 * erp_w)
right = int(right / 360 * erp_w)
mask = np.zeros((erp_h, erp_w), dtype=int)
mask[down:up, left:right] = 1
erp_mask.append(mask)
all_combos = np.vstack(all_combos)
shifts = np.arange(all_combos.shape[0]) * width
shifts = torch.from_numpy(shifts).float()
erp_mask = np.stack(erp_mask)
erp_mask = torch.from_numpy(erp_mask).float()
num_patch = all_combos.shape[0]
center_point = torch.from_numpy(all_combos).float() # -180 to 180, -90 to 90
center_point[:, 0] = (center_point[:, 0]) / 360 #0 to 1
center_point[:, 1] = (center_point[:, 1] + 90) / 180 #0 to 1
cp = center_point * 2 - 1
center_p = cp.clone()
cp[:, 0] = cp[:, 0] * PI
cp[:, 1] = cp[:, 1] * PI_2
cp = cp.unsqueeze(1)
convertedCoord = screen_points * 2 - 1
convertedCoord[:, 0] = convertedCoord[:, 0] * PI
convertedCoord[:, 1] = convertedCoord[:, 1] * PI_2
convertedCoord = convertedCoord * (torch.ones(screen_points.shape, dtype=torch.float32) * FOV)
convertedCoord = convertedCoord.unsqueeze(0).repeat(cp.shape[0], 1, 1)
x = convertedCoord[:, :, 0]
y = convertedCoord[:, :, 1]
rou = torch.sqrt(x ** 2 + y ** 2)
c = torch.atan(rou)
sin_c = torch.sin(c)
cos_c = torch.cos(c)
lat = torch.asin(cos_c * torch.sin(cp[:, :, 1]) + (y * sin_c * torch.cos(cp[:, :, 1])) / rou)
lon = cp[:, :, 0] + torch.atan2(x * sin_c, rou * torch.cos(cp[:, :, 1]) * cos_c - y * torch.sin(cp[:, :, 1]) * sin_c)
lat_new = lat / PI_2
lon_new = lon / PI
    lon_new[lon_new > 1] -= 2
    lon_new[lon_new < -1] += 2
lon_new = lon_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch*width)
lat_new = lat_new.view(1, num_patch, height, width).permute(0, 2, 1, 3).contiguous().view(height, num_patch*width)
grid = torch.stack([lon_new, lat_new], -1)
grid = grid.unsqueeze(0).repeat(bs, 1, 1, 1).to(erp_img.device)
pers = F.grid_sample(erp_img, grid, mode='bilinear', padding_mode='border', align_corners=True)
pers = F.unfold(pers, kernel_size=(height, width), stride=(height, width))
pers = pers.reshape(bs, -1, height, width, num_patch)
grid_tmp = torch.stack([lon, lat], -1)
xyz = uv2xyz(grid_tmp)
xyz = xyz.reshape(num_patch, height, width, 3).transpose(0, 3, 1, 2)
xyz = torch.from_numpy(xyz).to(pers.device).contiguous()
uv = grid[0, ...].reshape(height, width, num_patch, 2).permute(2, 3, 0, 1)
uv = uv.contiguous()
return pers, xyz, uv, center_p
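# Minimal usage sketch (assumed sizes): split an ERP feature map into tangent-plane
# (perspective) patches. With nrows=4 the layout has 18 patches (3 + 6 + 6 + 3 over
# the four latitude bands); fov and patch_size below are illustrative values.
def _example_equi2pers():
    erp = torch.randn(1, 3, 256, 512)  # B, C, H, W
    pers, xyz, uv, center_p = equi2pers(erp, fov=80, nrows=4, patch_size=64)
    # pers:     (1, 3, 64, 64, 18)  patch pixels
    # xyz:      (18, 3, 64, 64)     unit-sphere coordinates of each patch pixel
    # uv:       (18, 2, 64, 64)     (lon, lat) sampling grid per patch
    # center_p: (18, 2)             normalized patch centers in [-1, 1]
    return pers, xyz, uv, center_p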
def pers2equi(pers_img, fov, nrows, patch_size, erp_size, layer_name):
bs = pers_img.shape[0]
channel = pers_img.shape[1]
    device = pers_img.device
height, width = pair(patch_size)
fov_h, fov_w = pair(fov)
erp_h, erp_w = pair(erp_size)
n_patch = pers_img.shape[-1]
grid_dir = './grid'
if not exists(grid_dir):
makedirs(grid_dir)
grid_file = join(grid_dir, layer_name + '.pth')
if not exists(grid_file):
FOV = torch.tensor([fov_w/360.0, fov_h/180.0], dtype=torch.float32)
PI = math.pi
PI_2 = math.pi * 0.5
PI2 = math.pi * 2
        if nrows == 4:
            num_rows = 4
            num_cols = [3, 6, 6, 3]
            phi_centers = [-67.5, -22.5, 22.5, 67.5]
        elif nrows == 6:
            num_rows = 6
            num_cols = [3, 8, 12, 12, 8, 3]
            phi_centers = [-75.2, -45.93, -15.72, 15.72, 45.93, 75.2]
        elif nrows == 3:
            num_rows = 3
            num_cols = [3, 4, 3]
            phi_centers = [-59.6, 0, 59.6]
        elif nrows == 5:
            num_rows = 5
            num_cols = [3, 6, 8, 6, 3]
            phi_centers = [-72.2, -36.1, 0, 36.1, 72.2]
        else:
            raise ValueError('nrows must be 3, 4, 5 or 6')
phi_interval = 180 // num_rows
all_combos = []
for i, n_cols in enumerate(num_cols):
for j in np.arange(n_cols):
theta_interval = 360 / n_cols
theta_center = j * theta_interval + theta_interval / 2
center = [theta_center, phi_centers[i]]
all_combos.append(center)
all_combos = np.vstack(all_combos)
n_patch = all_combos.shape[0]
center_point = torch.from_numpy(all_combos).float() # -180 to 180, -90 to 90
center_point[:, 0] = (center_point[:, 0]) / 360 #0 to 1
center_point[:, 1] = (center_point[:, 1] + 90) / 180 #0 to 1
cp = center_point * 2 - 1
cp[:, 0] = cp[:, 0] * PI
cp[:, 1] = cp[:, 1] * PI_2
cp = cp.unsqueeze(1)
        lat_grid, lon_grid = torch.meshgrid(torch.linspace(-PI_2, PI_2, erp_h), torch.linspace(-PI, PI, erp_w), indexing='ij')
lon_grid = lon_grid.float().reshape(1, -1)#.repeat(num_rows*num_cols, 1)
lat_grid = lat_grid.float().reshape(1, -1)#.repeat(num_rows*num_cols, 1)
cos_c = torch.sin(cp[..., 1]) * torch.sin(lat_grid) + torch.cos(cp[..., 1]) * torch.cos(lat_grid) * torch.cos(lon_grid - cp[..., 0])
new_x = (torch.cos(lat_grid) * torch.sin(lon_grid - cp[..., 0])) / cos_c
new_y = (torch.cos(cp[..., 1])*torch.sin(lat_grid) - torch.sin(cp[...,1])*torch.cos(lat_grid)*torch.cos(lon_grid-cp[...,0])) / cos_c
new_x = new_x / FOV[0] / PI # -1 to 1
new_y = new_y / FOV[1] / PI_2
cos_c_mask = cos_c.reshape(n_patch, erp_h, erp_w)
cos_c_mask = torch.where(cos_c_mask > 0, 1, 0)
w_list = torch.zeros((n_patch, erp_h, erp_w, 4), dtype=torch.float32)
        new_x_patch = (new_x + 1) * 0.5 * width   # horizontal coordinate in patch pixels
        new_y_patch = (new_y + 1) * 0.5 * height  # vertical coordinate in patch pixels
new_x_patch = new_x_patch.reshape(n_patch, erp_h, erp_w)
new_y_patch = new_y_patch.reshape(n_patch, erp_h, erp_w)
mask = torch.where((new_x_patch < width) & (new_x_patch > 0) & (new_y_patch < height) & (new_y_patch > 0), 1, 0)
mask *= cos_c_mask
x0 = torch.floor(new_x_patch).type(torch.int64)
x1 = x0 + 1
y0 = torch.floor(new_y_patch).type(torch.int64)
y1 = y0 + 1
x0 = torch.clamp(x0, 0, width-1)
x1 = torch.clamp(x1, 0, width-1)
y0 = torch.clamp(y0, 0, height-1)
y1 = torch.clamp(y1, 0, height-1)
wa = (x1.type(torch.float32)-new_x_patch) * (y1.type(torch.float32)-new_y_patch)
wb = (x1.type(torch.float32)-new_x_patch) * (new_y_patch-y0.type(torch.float32))
wc = (new_x_patch-x0.type(torch.float32)) * (y1.type(torch.float32)-new_y_patch)
wd = (new_x_patch-x0.type(torch.float32)) * (new_y_patch-y0.type(torch.float32))
wa = wa * mask.expand_as(wa)
wb = wb * mask.expand_as(wb)
wc = wc * mask.expand_as(wc)
wd = wd * mask.expand_as(wd)
w_list[..., 0] = wa
w_list[..., 1] = wb
w_list[..., 2] = wc
w_list[..., 3] = wd
save_file = {'x0':x0, 'y0':y0, 'x1':x1, 'y1':y1, 'w_list': w_list, 'mask':mask}
torch.save(save_file, grid_file)
else:
        # The on-the-fly merge is slow, so the sampling grid is pre-computed once
        # and reused during training.
        load_file = torch.load(grid_file)
x0 = load_file['x0']
y0 = load_file['y0']
x1 = load_file['x1']
y1 = load_file['y1']
w_list = load_file['w_list']
mask = load_file['mask']
w_list = w_list.to(device)
mask = mask.to(device)
z = torch.arange(n_patch)
z = z.reshape(n_patch, 1, 1)
Ia = pers_img[:, :, y0, x0, z]
Ib = pers_img[:, :, y1, x0, z]
Ic = pers_img[:, :, y0, x1, z]
Id = pers_img[:, :, y1, x1, z]
output_a = Ia * mask.expand_as(Ia)
output_b = Ib * mask.expand_as(Ib)
output_c = Ic * mask.expand_as(Ic)
output_d = Id * mask.expand_as(Id)
output_a = output_a.permute(0, 1, 3, 4, 2)
output_b = output_b.permute(0, 1, 3, 4, 2)
output_c = output_c.permute(0, 1, 3, 4, 2)
output_d = output_d.permute(0, 1, 3, 4, 2)
w_list = w_list.permute(1, 2, 0, 3)
w_list = w_list.flatten(2)
w_list *= torch.gt(w_list, 1e-5).type(torch.float32)
w_list = F.normalize(w_list, p=1, dim=-1).reshape(erp_h, erp_w, n_patch, 4)
w_list = w_list.unsqueeze(0).unsqueeze(0)
output = output_a * w_list[..., 0] + output_b * w_list[..., 1] + \
output_c * w_list[..., 2] + output_d * w_list[..., 3]
img_erp = output.sum(-1)
return img_erp
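# Minimal round-trip sketch (assumed sizes, hypothetical layer_name): project to tangent
# patches with equi2pers, then merge back to ERP with pers2equi. The first call caches a
# sampling grid under ./grid/<layer_name>.pth and later calls reuse it.
def _example_pers2equi():
    erp = torch.randn(1, 3, 256, 512)
    pers, _, _, _ = equi2pers(erp, fov=80, nrows=4, patch_size=64)
    erp_rec = pers2equi(pers, fov=80, nrows=4, patch_size=64,
                        erp_size=(256, 512), layer_name='demo_level0')
    return erp_rec  # (1, 3, 256, 512)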
def img2windows(img, H_sp, W_sp):
"""
img: B C H W
"""
B, C, H, W = img.shape
img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp, W_sp, C)
return img_perm
def windows2img(img_splits_hw, H_sp, W_sp, H, W):
"""
img_splits_hw: B' H W C
"""
B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return img
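# Minimal round-trip sketch (assumed sizes): partition a feature map into non-overlapping
# windows and stitch it back; up to a channel-last permute, the result equals the input.
def _example_window_roundtrip():
    x = torch.randn(2, 96, 32, 32)          # B, C, H, W
    win = img2windows(x, H_sp=8, W_sp=8)    # (2 * 16, 8, 8, 96)
    x_rec = windows2img(win, 8, 8, 32, 32)  # (2, 32, 32, 96), channel-last
    assert torch.allclose(x_rec.permute(0, 3, 1, 2), x)
    return x_rec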