# DiffSplat: src/models/gs_render/deferred_bp.py
# Uploaded by paulpanwang to Hugging Face ("Upload folder using huggingface_hub"), commit 476e0f0 (verified).
import torch
from torch.utils.checkpoint import _get_autocast_kwargs
from src.models.gs_render.gs_util import render
class DeferredBP(torch.autograd.Function):
    """Patch-wise Gaussian rendering with deferred back-propagation.

    ``forward`` renders each ``height x width`` view as a grid of
    ``patch_size x patch_size`` patches under ``torch.no_grad()``, so no
    rasterizer graph is kept alive between the forward and backward passes.

    ``backward`` re-renders every patch with grad enabled and immediately
    back-propagates the matching slice of the incoming output gradients. As a
    result, only one patch's render graph exists at a time (plus the graph
    built by ``gaussian_model.set_data`` for the current batch item).

    Each patch render is seeded with a per-patch seed recorded in ``forward``
    and replayed in ``backward``. Any randomness inside ``render`` is thus
    identical in both passes.

    Only ``xyz``, ``rgb``, ``scale``, ``rotation`` and ``opacity`` receive
    gradients; every other input is treated as a constant.
    """

    @staticmethod
    def _iter_patches(num_views, height, width, patch_size):
        """Yield ``(view, col, row)`` indices for every patch, in a fixed order.

        Both passes iterate in this order, which keeps the seeds recorded in
        ``forward`` aligned with their replay in ``backward``.
        """
        for j in range(num_views):
            for m in range(width // patch_size):
                for n in range(height // patch_size):
                    yield j, m, n

    @staticmethod
    def _patch_slices(m, n, patch_size):
        """Return ``(rows, cols)`` slices selecting patch (column ``m``, row ``n``)."""
        rows = slice(n * patch_size, (n + 1) * patch_size)
        cols = slice(m * patch_size, (m + 1) * patch_size)
        return rows, cols

    @staticmethod
    def _patch_fxfycxcy(fxfycxcy_ij, m, n, height, width, patch_size):
        """Return the intrinsics of patch (column ``m``, row ``n``) as a 4-element tensor.

        The arithmetic assumes ``fxfycxcy_ij`` is ``(fx, fy, cx, cy)`` normalized
        by the full image width/height. NOTE(review): confirm this against ``render``.

        Zooming into a patch multiplies the normalized focal lengths by
        ``width // patch_size`` (resp. ``height // patch_size``). Re-expressed in
        patch-normalized coordinates, the principal point becomes
        ``(cx * width - m * patch_size) / patch_size == scale_x * cx - m``.
        """
        scale_x = width // patch_size
        scale_y = height // patch_size
        fx, fy, cx, cy = fxfycxcy_ij[0], fxfycxcy_ij[1], fxfycxcy_ij[2], fxfycxcy_ij[3]
        # The integer patch offset is subtracted exactly. Deriving the shift from a
        # patch centre of `m * patch_size + patch_size // 2` would misplace the
        # principal point by half a pixel whenever patch_size is odd.
        return torch.stack([scale_x * fx, scale_y * fy, scale_x * cx - m, scale_y * cy - n], dim=0)

    @staticmethod
    def _fork_rng(device):
        """Return a context that restores the CPU RNG state (and ``device``'s CUDA RNG) on exit.

        This lets each patch render be reseeded with ``torch.manual_seed`` without
        leaking that seed into the caller's random stream on the CPU or on ``device``.
        """
        if device.type == "cuda":
            index = device.index if device.index is not None else torch.cuda.current_device()
            return torch.random.fork_rng(devices=[index])
        return torch.random.fork_rng(devices=[])

    @staticmethod
    def forward(ctx,
                xyz, rgb, scale, rotation, opacity,
                height, width, C2W, fxfycxcy,
                patch_size, gaussian_model,
                znear, zfar,
                bg_color,
                scaling_modifier,
                render_dn,
                ):
        """Render all views patch by patch without recording an autograd graph.

        ``C2W`` is indexed as ``C2W[i, j]`` for batch item ``i`` and view ``j``.
        ``fxfycxcy[i, j]`` holds ``(fx, fy, cx, cy)`` for that view.

        Returns:
            ``(images, alphas, depths, normals)``, allocated with shapes
            ``(B, V, 3, H, W)``, ``(B, V, 1, H, W)``, ``(B, V, 1, H, W)`` and
            ``(B, V, 3, H, W)``.
        """
        assert (xyz.dim() == 3) and (rgb.dim() == 3) and (scale.dim() == 3) and (rotation.dim() == 3)
        assert height % patch_size == 0 and width % patch_size == 0
        ctx.save_for_backward(xyz, rgb, scale, rotation, opacity)  # save tensors for backward
        ctx.height = height
        ctx.width = width
        ctx.C2W = C2W
        ctx.fxfycxcy = fxfycxcy
        ctx.patch_size = patch_size
        ctx.gaussian_model = gaussian_model
        ctx.znear = znear
        ctx.zfar = zfar
        ctx.bg_color = bg_color
        ctx.scaling_modifier = scaling_modifier
        ctx.render_dn = render_dn
        # backward may run on another thread, where autocast state is not
        # inherited; capture it here so the re-render uses the same precision.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        # ctx.manual_seeds[i] stores one seed per patch of batch item i, in
        # _iter_patches order. backward replays these seeds.
        ctx.manual_seeds = []

        with (
            torch.no_grad(),
            torch.autocast("cuda", **ctx.gpu_autocast_kwargs),
            torch.autocast("cpu", **ctx.cpu_autocast_kwargs),
        ):
            device, (B, V) = C2W.device, C2W.shape[:2]
            images = torch.zeros(B, V, 3, height, width, device=device)
            alphas = torch.zeros(B, V, 1, height, width, device=device)
            depths = torch.zeros(B, V, 1, height, width, device=device)
            normals = torch.zeros(B, V, 3, height, width, device=device)
            patches = list(DeferredBP._iter_patches(V, height, width, patch_size))

            for i in range(B):
                pc = gaussian_model.set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i])
                # Seeds are drawn from the caller's RNG before it is forked, so the
                # caller's generator advances by exactly one draw per patch.
                seeds_i = [torch.randint(0, 2**32, (1,)).long().item() for _ in patches]
                ctx.manual_seeds.append(seeds_i)
                with DeferredBP._fork_rng(device):
                    for (j, m, n), seed in zip(patches, seeds_i):
                        torch.manual_seed(seed)
                        new_fxfycxcy = DeferredBP._patch_fxfycxcy(fxfycxcy[i, j], m, n, height, width, patch_size)
                        render_results = render(
                            pc, patch_size, patch_size, C2W[i, j], new_fxfycxcy,
                            znear, zfar, bg_color, scaling_modifier, render_dn,
                        )
                        rows, cols = DeferredBP._patch_slices(m, n, patch_size)
                        images[i, j, :, rows, cols] = render_results["image"]
                        alphas[i, j, :, rows, cols] = render_results["alpha"]
                        depths[i, j, :, rows, cols] = render_results["depth"]
                        normals[i, j, :, rows, cols] = render_results["normal"]
        return images, alphas, depths, normals

    @staticmethod
    def backward(ctx, grad_images, grad_alphas, grad_depths, grad_normals):
        """Re-render each patch with grad enabled and back-propagate its gradient slice.

        Returns:
            Gradients for ``xyz``, ``rgb``, ``scale``, ``rotation`` and ``opacity``,
            followed by ``None`` for the 11 non-differentiable inputs.
        """
        xyz, rgb, scale, rotation, opacity = ctx.saved_tensors
        # Fresh leaf copies: each patch re-render builds a small graph rooted at
        # these leaves, and successive .backward() calls accumulate into their .grad.
        xyz_nosync, rgb_nosync, scale_nosync, rotation_nosync, opacity_nosync = (
            t.detach().clone().requires_grad_(True)
            for t in (xyz, rgb, scale, rotation, opacity)
        )
        ps = ctx.patch_size

        with (
            torch.enable_grad(),
            torch.autocast("cuda", **ctx.gpu_autocast_kwargs),
            torch.autocast("cpu", **ctx.cpu_autocast_kwargs),
        ):
            device, (B, V) = ctx.C2W.device, ctx.C2W.shape[:2]
            patches = list(DeferredBP._iter_patches(V, ctx.height, ctx.width, ps))

            for i in range(B):
                # NOTE(review): `pc` is built once per batch item and shared by every
                # patch's backward() call. This relies on set_data not saving tensors
                # that the first backward() would free.
                pc = ctx.gaussian_model.set_data(
                    xyz_nosync[i], rgb_nosync[i], scale_nosync[i], rotation_nosync[i], opacity_nosync[i],
                )
                with DeferredBP._fork_rng(device):
                    # Replay the forward seeds so the re-render matches the forward output.
                    for (j, m, n), seed in zip(patches, ctx.manual_seeds[i]):
                        torch.manual_seed(seed)
                        new_fxfycxcy = DeferredBP._patch_fxfycxcy(ctx.fxfycxcy[i, j], m, n, ctx.height, ctx.width, ps)
                        # render_dn must match forward; otherwise the recomputed outputs
                        # (and therefore the gradients) can differ from what was returned.
                        render_results = render(
                            pc, ps, ps, ctx.C2W[i, j], new_fxfycxcy,
                            ctx.znear, ctx.zfar, ctx.bg_color, ctx.scaling_modifier, ctx.render_dn,
                        )
                        rows, cols = DeferredBP._patch_slices(m, n, ps)
                        render_split = torch.cat([
                            render_results["image"],
                            render_results["alpha"],
                            render_results["depth"],
                            render_results["normal"],
                        ], dim=0)
                        grad_split = torch.cat([
                            grad_images[i, j, :, rows, cols],
                            grad_alphas[i, j, :, rows, cols],
                            grad_depths[i, j, :, rows, cols],
                            grad_normals[i, j, :, rows, cols],
                        ], dim=0)
                        render_split.backward(grad_split)

        return (
            xyz_nosync.grad, rgb_nosync.grad, scale_nosync.grad, rotation_nosync.grad, opacity_nosync.grad,
        ) + (None,) * 11
def deferred_bp(
    xyz, rgb, scale, rotation, opacity,
    height, width, C2W, fxfycxcy,
    patch_size, gaussian_model,
    znear, zfar,
    bg_color,
    scaling_modifier,
    render_dn,
):
    """Functional entry point for :class:`DeferredBP`.

    Forwards every argument positionally to ``DeferredBP.apply``. It returns the
    ``(images, alphas, depths, normals)`` tuple produced by ``DeferredBP.forward``.
    Gradients flow back only to the five Gaussian parameter tensors.
    """
    gaussian_params = (xyz, rgb, scale, rotation, opacity)
    camera = (height, width, C2W, fxfycxcy)
    render_options = (patch_size, gaussian_model, znear, zfar, bg_color, scaling_modifier, render_dn)
    return DeferredBP.apply(*gaussian_params, *camera, *render_options)