import contextlib

import torch
from torch.utils.checkpoint import _get_autocast_kwargs

from src.models.gs_render.gs_util import render


def _iter_patches(height, width, patch_size):
    """Yield ``(m, n)`` patch indices: ``m`` over image columns (outer), ``n`` over rows (inner).

    The order is fixed because forward records one RNG seed per patch in this
    order and backward replays the seeds in the same order.
    """
    for m in range(width // patch_size):
        for n in range(height // patch_size):
            yield m, n


def _patch_slices(m, n, patch_size):
    """Return ``(rows, cols)`` slices selecting patch ``(m, n)`` from the last two (H, W) dims."""
    rows = slice(n * patch_size, (n + 1) * patch_size)
    cols = slice(m * patch_size, (m + 1) * patch_size)
    return rows, cols


def _patch_fxfycxcy(fxfycxcy_ij, m, n, height, width, patch_size):
    """Return ``[fx, fy, cx, cy]`` for rendering patch ``(m, n)`` as a standalone image.

    Focal lengths are multiplied by the number of patches along each axis and
    the principal point is shifted so that the patch fills the render target.
    NOTE(review): assumes ``fxfycxcy_ij`` holds intrinsics normalized by image
    width/height (the ``width // patch_size`` rescaling only maps onto the
    patch under that convention) -- confirm against callers.
    """
    fx, fy, cx, cy = fxfycxcy_ij[0], fxfycxcy_ij[1], fxfycxcy_ij[2], fxfycxcy_ij[3]
    scale_x = width // patch_size
    scale_y = height // patch_size
    # Continuous patch center. `patch_size / 2` (not `// 2`) keeps odd patch
    # sizes centered; for even sizes both forms give the same value.
    center_x = (m * patch_size + patch_size / 2) / width
    center_y = (n * patch_size + patch_size / 2) / height
    trans_x = 0.5 - scale_x * center_x
    trans_y = 0.5 - scale_y * center_y
    return torch.stack(
        [scale_x * fx, scale_y * fy, scale_x * cx + trans_x, scale_y * cy + trans_y],
        dim=0,
    )


@contextlib.contextmanager
def _seeded_rng(seed, device):
    """Run the body with the CPU RNG (and ``device``'s RNG when it is CUDA) seeded to ``seed``.

    ``torch.random.fork_rng`` restores the previous RNG states on exit, so
    replaying a seed does not disturb the caller's random stream.
    """
    cuda_devices = [device] if device.type == "cuda" else []
    with torch.random.fork_rng(devices=cuda_devices):
        torch.default_generator.manual_seed(seed)
        if device.type == "cuda":
            # Seed only the generator of `device`, the one fork_rng restores.
            with torch.cuda.device(device):
                torch.cuda.manual_seed(seed)
        yield


class DeferredBP(torch.autograd.Function):
    """Patch-wise Gaussian rendering with deferred backpropagation.

    ``forward`` renders every view patch by patch under ``torch.no_grad()`` and
    stitches the patches into full images, so no autograd graph is kept.
    ``backward`` re-renders each patch with gradients enabled, back-propagates
    that patch's slice of the incoming gradients, and accumulates the results
    on detached copies of the Gaussian parameters. Each patch's graph is
    released by its own ``backward`` call before the next patch is rendered.
    """

    @staticmethod
    def forward(
        ctx,
        xyz,
        rgb,
        scale,
        rotation,
        opacity,
        height,
        width,
        C2W,
        fxfycxcy,
        patch_size,
        gaussian_model,
        znear,
        zfar,
        bg_color,
        scaling_modifier,
        render_dn,
    ):
        """Forward rendering.

        Args:
            xyz, rgb, scale, rotation, opacity: Gaussian parameters indexed by
                batch on dim 0 (the first four must be 3-D tensors).
            height, width: output image size; both must be divisible by ``patch_size``.
            C2W: camera-to-world matrices indexed ``[b, v]``; ``C2W.shape[:2]`` gives ``(B, V)``.
            fxfycxcy: intrinsics indexed ``[b, v]`` as ``[fx, fy, cx, cy]``.
            patch_size: side length of the square patches rendered one at a time.
            gaussian_model: object whose ``set_data`` builds the point cloud passed to ``render``.
            znear, zfar, bg_color, scaling_modifier, render_dn: forwarded to ``render`` unchanged.

        Returns:
            ``(images, alphas, depths, normals)`` with shapes ``(B, V, 3, H, W)``,
            ``(B, V, 1, H, W)``, ``(B, V, 1, H, W)`` and ``(B, V, 3, H, W)``.
        """
        assert (xyz.dim() == 3) and (rgb.dim() == 3) and (scale.dim() == 3) and (rotation.dim() == 3)
        assert height % patch_size == 0 and width % patch_size == 0

        ctx.save_for_backward(xyz, rgb, scale, rotation, opacity)  # save tensors for backward
        ctx.height = height
        ctx.width = width
        ctx.C2W = C2W
        ctx.fxfycxcy = fxfycxcy
        ctx.patch_size = patch_size
        ctx.gaussian_model = gaussian_model
        ctx.znear = znear
        ctx.zfar = zfar
        ctx.bg_color = bg_color
        ctx.scaling_modifier = scaling_modifier
        ctx.render_dn = render_dn
        # Capture the caller's autocast state so backward re-renders under the same precision.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        # ctx.manual_seeds[i] lists one seed per patch of batch item i, for each
        # view in order and `_iter_patches` order within a view; backward replays them.
        ctx.manual_seeds = []

        with (
            torch.no_grad(),
            torch.autocast("cuda", **ctx.gpu_autocast_kwargs),
            torch.autocast("cpu", **ctx.cpu_autocast_kwargs),
        ):
            device, (B, V) = C2W.device, C2W.shape[:2]
            images = torch.zeros(B, V, 3, height, width, device=device)
            alphas = torch.zeros(B, V, 1, height, width, device=device)
            depths = torch.zeros(B, V, 1, height, width, device=device)
            normals = torch.zeros(B, V, 3, height, width, device=device)

            for i in range(B):
                batch_seeds = []
                ctx.manual_seeds.append(batch_seeds)
                pc = gaussian_model.set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i])
                for j in range(V):
                    for m, n in _iter_patches(height, width, patch_size):
                        seed = torch.randint(0, 2**32, (1,)).long().item()
                        batch_seeds.append(seed)

                        rows, cols = _patch_slices(m, n, patch_size)
                        patch_fxfycxcy = _patch_fxfycxcy(fxfycxcy[i, j], m, n, height, width, patch_size)
                        with _seeded_rng(seed, device):
                            render_results = render(
                                pc, patch_size, patch_size, C2W[i, j], patch_fxfycxcy,
                                znear, zfar, bg_color, scaling_modifier, render_dn,
                            )

                        images[i, j, :, rows, cols] = render_results["image"]
                        alphas[i, j, :, rows, cols] = render_results["alpha"]
                        depths[i, j, :, rows, cols] = render_results["depth"]
                        normals[i, j, :, rows, cols] = render_results["normal"]

        return images, alphas, depths, normals

    @staticmethod
    def backward(ctx, grad_images, grad_alphas, grad_depths, grad_normals):
        """Backward process.

        Re-renders each patch with gradients enabled (replaying the forward's
        per-patch seed) and back-propagates the matching slice of the incoming
        gradients. Returns gradients for the five Gaussian parameter tensors
        and ``None`` for the eleven remaining forward inputs.
        """
        xyz, rgb, scale, rotation, opacity = ctx.saved_tensors
        # Fresh leaf copies: every patch's backward accumulates into their `.grad`.
        leaves = [t.detach().clone().requires_grad_(True) for t in (xyz, rgb, scale, rotation, opacity)]
        xyz_nosync, rgb_nosync, scale_nosync, rotation_nosync, opacity_nosync = leaves

        patch_size = ctx.patch_size
        device, (B, V) = ctx.C2W.device, ctx.C2W.shape[:2]
        grads_out = (grad_images, grad_alphas, grad_depths, grad_normals)

        with (
            torch.enable_grad(),
            torch.autocast("cuda", **ctx.gpu_autocast_kwargs),
            torch.autocast("cpu", **ctx.cpu_autocast_kwargs),
        ):
            for i in range(B):
                # Seeds are consumed in exactly the order forward recorded them.
                seeds = iter(ctx.manual_seeds[i])
                # NOTE(review): `pc` is shared by every patch graph of item i, and each
                # `render_split.backward` frees saved tensors; this assumes `set_data`
                # records no ops that save tensors for backward -- confirm.
                pc = ctx.gaussian_model.set_data(
                    xyz_nosync[i], rgb_nosync[i], scale_nosync[i], rotation_nosync[i], opacity_nosync[i]
                )
                for j in range(V):
                    for m, n in _iter_patches(ctx.height, ctx.width, patch_size):
                        rows, cols = _patch_slices(m, n, patch_size)
                        patch_fxfycxcy = _patch_fxfycxcy(
                            ctx.fxfycxcy[i, j], m, n, ctx.height, ctx.width, patch_size
                        )
                        with _seeded_rng(next(seeds), device):
                            # Pass `render_dn` exactly as forward did so the recomputed
                            # outputs match the ones the incoming gradients refer to.
                            render_results = render(
                                pc, patch_size, patch_size, ctx.C2W[i, j], patch_fxfycxcy,
                                ctx.znear, ctx.zfar, ctx.bg_color, ctx.scaling_modifier, ctx.render_dn,
                            )

                        render_split = torch.cat(
                            [
                                render_results["image"],
                                render_results["alpha"],
                                render_results["depth"],
                                render_results["normal"],
                            ],
                            dim=0,
                        )
                        grad_split = torch.cat([g[i, j, :, rows, cols] for g in grads_out], dim=0)
                        render_split.backward(grad_split)

        # Eleven non-differentiable inputs follow the five parameter tensors.
        return tuple(leaf.grad for leaf in leaves) + (None,) * 11


def deferred_bp(
    xyz,
    rgb,
    scale,
    rotation,
    opacity,
    height,
    width,
    C2W,
    fxfycxcy,
    patch_size,
    gaussian_model,
    znear,
    zfar,
    bg_color,
    scaling_modifier,
    render_dn,
):
    """Functional wrapper around ``DeferredBP.apply``; see ``DeferredBP.forward`` for arguments."""
    return DeferredBP.apply(
        xyz,
        rgb,
        scale,
        rotation,
        opacity,
        height,
        width,
        C2W,
        fxfycxcy,
        patch_size,
        gaussian_model,
        znear,
        zfar,
        bg_color,
        scaling_modifier,
        render_dn,
    )