|
"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
|
|
Note: the antialiasing step is missing in the current version.
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import kornia
|
|
from kornia.geometry.camera import pixel2cam
|
|
import numpy as np
|
|
from typing import List
|
|
import nvdiffrast.torch as dr
|
|
from scipy.io import loadmat
|
|
from torch import nn
|
|
|
|
def ndc_projection(x=0.1, n=1.0, f=50.0):
    """Build a 4x4 perspective projection matrix for a symmetric frustum.

    Args:
        x: half-extent of the near plane (the same value is used for both
            the horizontal and vertical scale, with opposite signs).
        n: distance to the near clipping plane.
        f: distance to the far clipping plane.

    Returns:
        np.ndarray of shape (4, 4) and dtype float32. Row 3 is (0, 0, -1, 0),
        so the homogeneous w of a projected point is the negated input z.
    """
    scale_x = n / x
    # The vertical scale is the negated horizontal scale (flips image y).
    scale_y = n / -x
    depth_span = f - n
    z_scale = -(f + n) / depth_span
    z_offset = -(2 * f * n) / depth_span

    rows = [
        [scale_x, 0, 0, 0],
        [0, scale_y, 0, 0],
        [0, 0, z_scale, z_offset],
        [0, 0, -1, 0],
    ]
    return np.array(rows, dtype=np.float32)
|
|
|
|
class MeshRenderer(nn.Module):
    """Rasterize triangle meshes with nvdiffrast and interpolate per-vertex data.

    Vertices are projected with a fixed perspective matrix built from the
    field of view and clipping planes. forward() returns a coverage mask, an
    interpolated depth map and, optionally, interpolated per-vertex features.
    No antialiasing is applied.
    """

    def __init__(self,
                 rasterize_fov,
                 znear=0.1,
                 zfar=10,
                 rasterize_size=224):
        """
        Parameters:
            rasterize_fov  -- field of view in degrees (half of it is used for the frustum)
            znear          -- near clipping plane distance
            zfar           -- far clipping plane distance
            rasterize_size -- output resolution; images are rsize x rsize
        """
        super(MeshRenderer, self).__init__()

        # Half-extent of the near plane for the requested field of view.
        x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
        # Right-multiplying by diag(1, -1, -1, 1) negates the y and z input
        # coordinates before the perspective projection is applied.
        self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
            torch.diag(torch.tensor([1., -1, -1, 1])))
        self.rasterize_size = rasterize_size
        # The GL context is created lazily on the first forward() call, on the
        # device of the input vertices.
        self.glctx = None

    def forward(self, vertex, tri, feat=None):
        """
        Return:
            mask               -- torch.tensor, size (B, 1, H, W), 1 where a triangle covers the pixel
            depth              -- torch.tensor, size (B, 1, H, W), interpolated vertex z, 0 outside mask
            features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None, else None

        Parameters:
            vertex          -- torch.tensor, size (B, N, 3), or (B, N, 4) homogeneous (used as-is, no y flip)
            tri             -- torch.tensor, size (B, M, 3) or (M, 3), or a list of B (M_i, 3) tensors;
                               triangle vertex indices (the caller's tri is not modified)
            feat(optional)  -- torch.tensor, size (B, N, C), per-vertex features
        """
        device = vertex.device
        rsize = int(self.rasterize_size)
        ndc_proj = self.ndc_proj.to(device)

        if vertex.shape[-1] == 3:
            # Append w = 1 (same dtype/device as the input) and flip y.
            # cat() returns a new tensor, so the in-place flip does not touch
            # the caller's vertex tensor.
            vertex = torch.cat([vertex, torch.ones_like(vertex[..., :1])], dim=-1)
            vertex[..., 1] = -vertex[..., 1]

        vertex_ndc = vertex @ ndc_proj.t()
        if self.glctx is None:
            self.glctx = dr.RasterizeGLContext(device=device)
            print("create glctx on device cuda:%d"%device.index)

        ranges = None
        # Per-item triangle lists (or a (B, M, 3) tensor) use nvdiffrast's range
        # mode; a single shared (M, 3) tensor uses instanced mode.
        batched_tri = isinstance(tri, (list, tuple)) or len(tri.shape) == 3
        if batched_tri:
            num_verts = vertex_ndc.shape[1]
            fnum = torch.tensor([t.shape[0] for t in tri]).unsqueeze(1)
            fstartidx = torch.cumsum(fnum, dim=0) - fnum
            # Range mode expects an int32 CPU tensor of [start, count] rows.
            ranges = torch.cat([fstartidx, fnum], dim=1).type(torch.int32).cpu()
            # Offset each item's indices into the concatenated vertex buffer.
            # New tensors are built so the caller's tri is left untouched.
            tri = torch.cat([t + i * num_verts for i, t in enumerate(tri)], dim=0)
            vertex_ndc = vertex_ndc.reshape(-1, vertex_ndc.shape[-1])

        tri = tri.type(torch.int32).contiguous()
        rast_out, _ = dr.rasterize(self.glctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges)

        if batched_tri:
            # Range mode: attributes are indexed by the offset (global) vertex
            # ids, so they must be a flat (B*N, 1) table.
            depth_attr = vertex[..., 2].reshape(-1, 1)
        else:
            # Instanced mode: one (N, 1) table per batch item. A flat 2-D table
            # here would make every image use batch item 0's vertices.
            depth_attr = vertex[..., 2:3]
        depth, _ = dr.interpolate(depth_attr.contiguous(), rast_out, tri)
        depth = depth.permute(0, 3, 1, 2)
        # Channel 3 of the rasterizer output is triangle_id + 1 (0 = empty).
        mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
        depth = mask * depth

        image = None
        if feat is not None:
            if batched_tri and feat.dim() == 3:
                # Range mode needs a flat (B*N, C) attribute table.
                feat = feat.reshape(-1, feat.shape[-1])
            image, _ = dr.interpolate(feat.contiguous(), rast_out, tri)
            image = image.permute(0, 3, 1, 2)
            image = mask * image

        return mask, depth, image
|
|
|
|
|