# Source: UniRig, uploaded by jkorstad.
# Commit f499d3b: "Correctly add UniRig source files"
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Union, List, Tuple

import numpy as np
from numpy import ndarray
import trimesh

from .order import Order
from .raw_data import RawData
from .exporter import Exporter
from ..tokenizer.spec import TokenizeInput
from .utils import linear_blend_skinning
@dataclass
class Asset(Exporter):
    '''
    Dataclass to handle data parsed from raw data.

    Bundles a triangle mesh with an optional armature (joint heads, tails,
    parent links, names, skinning weights) and optional sampled point-cloud
    data. Every per-joint array shares one joint ordering; the methods that
    reorder or drop joints (``set_order_by_names``, ``collapse``) keep those
    arrays in sync.

    Armature invariant: ``parents[k] < k`` for every non-root joint, and
    joint 0 is the root.
    '''
    # data class
    cls: str
    # where is this asset from
    path: str
    # data file name
    data_name: str
    # vertices of the mesh, shape (N, 3), float32
    vertices: ndarray
    # normals of vertices, shape (N, 3), float32
    vertex_normals: ndarray
    # triangles of the mesh as vertex indices into `vertices`, shape (F, 3), int64
    faces: ndarray
    # face normals of the mesh, shape (F, 3), float32
    face_normals: ndarray
    # joint (bone head) positions, shape (J, 3), float32
    joints: Union[ndarray, None]=None
    # tails of joints, shape (J, 3), float32
    tails: Union[ndarray, None]=None
    # skinning weights of joints, shape (N, J), float32
    skin: Union[ndarray, None]=None
    # per-joint skin flag indexed by joint id, bool; presumably True when the
    # joint has no skin (collapse() AND-s the flags of merged joints)
    no_skin: Union[ndarray, None]=None
    # vertex groups
    vertex_groups: Union[Dict[str, ndarray], None]=None
    # parents of joints, None represents no parent (a root joint)
    # make sure parent[k] < k
    parents: Union[List[Union[int, None]], None]=None
    # names of joints
    names: Union[List[str], None]=None
    # sampled vertices, shape (N, 3)
    sampled_vertices: Union[ndarray, None]=None
    # sampled normals, shape (N, 3)
    sampled_normals: Union[ndarray, None]=None
    # sampled vertex groups, every vertex group should be (N, J)
    sampled_vertex_groups: Union[Dict[str, ndarray], None]=None
    # {id: part}, part==None -> a spring token
    parts_bias: Union[Dict[int, Union[str, None]], None]=None
    # local coordinate frame of each joint, shape (J, 4, 4)
    matrix_local: Union[ndarray, None]=None
    # pose matrix for skinning loss calculation, shape (J, 4, 4)
    pose_matrix: Union[ndarray, None]=None
    # extra metadata; from_raw_data() initialises it to an empty dict.
    # (`Dict[str, ...]` is not a valid typing form: Ellipsis is only allowed in
    # Tuple/Callable and raises TypeError at class creation on Python < 3.11.)
    meta: Union[Dict[str, Any], None]=None

    @property
    def N(self) -> int:
        '''Number of mesh vertices.'''
        return self.vertices.shape[0]

    @property
    def F(self) -> int:
        '''Number of mesh faces.'''
        return self.faces.shape[0]

    @property
    def J(self) -> int:
        '''Number of joints (requires ``joints`` to be set).'''
        return self.joints.shape[0]

    def _default_matrix_local(self) -> ndarray:
        '''
        Build fallback local matrices when ``matrix_local`` is missing:
        identity rotation with each joint's head position as translation.
        Returns a float64 array of shape (J, 4, 4).
        '''
        matrix_local = np.tile(np.eye(4), (self.J, 1, 1))
        matrix_local[:, :3, 3] = self.joints
        return matrix_local

    def get_matrix(self, matrix_basis: ndarray, matrix_local: Union[ndarray, None]=None) -> ndarray:
        '''
        Compute the posed matrix of every joint.

        matrix_basis: (J, 4, 4) per-joint pose (basis) transform.
        matrix_local: (J, 4, 4) local matrices. When omitted,
            ``self.matrix_local`` is used, falling back to identity-rotation
            matrices placed at ``self.joints``.

        Returns a (J, 4, 4) float64 array where, for a joint i with parent p,
            matrix[i] = matrix[p] @ (inv(matrix_local[p]) @ matrix_local[i]) @ matrix_basis[i]
        and for a root joint
            matrix[i] = matrix_local[i] @ matrix_basis[i].
        Parents are always computed before their children because
        ``parents[k] < k``.
        '''
        if matrix_local is None:
            assert self.joints is not None
            matrix_local = self.matrix_local
            if matrix_local is None:
                matrix_local = self._default_matrix_local()
        matrix = np.zeros((self.J, 4, 4))
        for i in range(self.J):
            # Joint 0 is always a root. Any later parentless joint is an extra
            # root and is handled the same way; indexing `matrix[None]` would
            # otherwise broadcast and fail.
            pid = None if i == 0 else self.parents[i]
            if pid is None:
                matrix[i] = matrix_local[i] @ matrix_basis[i]
            else:
                matrix[i] = (
                    matrix[pid] @
                    (np.linalg.inv(matrix_local[pid]) @ matrix_local[i]) @
                    matrix_basis[i]
                )
        return matrix

    def apply_matrix_basis(self, matrix_basis: ndarray):
        '''
        Apply a pose to the armature and deform the mesh accordingly.

        matrix_basis: (J, 4, 4) per-joint pose, see ``get_matrix``.

        Updates in place:
        - ``joints``: translation column of the posed matrices.
        - ``matrix_local``: replaced by the posed matrices.
        - ``tails``: skinned with an identity weight matrix, so each tail
          follows only its own joint.
        - ``vertices``: deformed by ``linear_blend_skinning`` with ``skin``.
        - ``vertex_normals`` / ``face_normals``: recomputed by trimesh.

        Sampled point-cloud fields are not touched.
        '''
        matrix_local = self.matrix_local
        if matrix_local is None:
            matrix_local = self._default_matrix_local()
        matrix = self.get_matrix(matrix_basis=matrix_basis, matrix_local=matrix_local)
        self.joints = matrix[:, :3, 3].copy()
        # pad=1, value=1. presumably appends a homogeneous coordinate of 1 —
        # see linear_blend_skinning in .utils
        vertices = linear_blend_skinning(self.vertices, matrix_local, matrix, self.skin, pad=1, value=1.)
        # the posed matrices become the new local matrices
        self.matrix_local = matrix.copy()
        if self.tails is not None:
            t_skin = np.eye(self.J)
            self.tails = linear_blend_skinning(self.tails, matrix_local, matrix, t_skin, pad=1, value=1.)
        # recompute normals with trimesh so they are consistent with trimesh's conventions
        mesh = trimesh.Trimesh(vertices=vertices, faces=self.faces, process=False)
        self.vertices = vertices
        self.vertex_normals = mesh.vertex_normals.copy()
        self.face_normals = mesh.face_normals.copy()

    def set_order_by_names(self, new_names: List[str]):
        '''
        Reorder all per-joint data so that joints follow ``new_names``.

        ``new_names`` must be a permutation of ``self.names`` whose first entry
        is the root and in which every joint comes after its parent.

        Raises AssertionError if the lengths differ, the first name is not a
        root, or a child would precede its parent. Raises KeyError for a name
        not present in ``self.names``.
        '''
        assert len(new_names) == len(self.names)
        name_to_id = {name: idx for (idx, name) in enumerate(self.names)}
        new_name_to_id = {name: idx for (idx, name) in enumerate(new_names)}
        perm = []
        new_parents = []
        for (new_id, name) in enumerate(new_names):
            old_id = name_to_id[name]
            perm.append(old_id)
            pid = self.parents[old_id]
            if new_id == 0:
                assert pid is None, 'first bone is not root bone'
            else:
                # translate the parent's old index into its new index
                pid = new_name_to_id[self.names[pid]]
                assert pid < new_id, 'new order does not form a tree'
            new_parents.append(pid)
        if self.joints is not None:
            self.joints = self.joints[perm]
        self.parents = new_parents
        if self.tails is not None:
            self.tails = self.tails[perm]
        if self.skin is not None:
            self.skin = self.skin[:, perm]
        if self.no_skin is not None:
            self.no_skin = self.no_skin[perm]
        if self.matrix_local is not None:
            self.matrix_local = self.matrix_local[perm]
        # pose_matrix is per-joint as well and must follow the same ordering
        if self.pose_matrix is not None:
            self.pose_matrix = self.pose_matrix[perm]
        self.names = new_names

    def set_order(self, order: Order):
        '''
        Reorder joints according to ``order.arrange_names`` and store the
        ``parts_bias`` it returns. Does nothing when names or parents are
        missing.
        '''
        if self.names is None or self.parents is None:
            return
        new_names, self.parts_bias = order.arrange_names(cls=self.cls, names=self.names, parents=self.parents)
        self.set_order_by_names(new_names=new_names)

    def collapse(self, keep: List[str]):
        '''
        Remove every joint whose name is not in ``keep``.

        Each removed joint is folded into its nearest kept ancestor:
        - its skinning column is added to the ancestor's column;
        - its ``no_skin`` flag is AND-ed into the ancestor's flag.

        Kept joints are re-parented to their nearest kept ancestor and keep
        their relative order. Tails of kept joints are left unchanged.

        Raises AssertionError if the root (joint 0) is not kept.
        '''
        # set for O(1) membership instead of a list scan per joint
        keep = set(keep)

        # Union-find over joint ids. A removed joint is merged into the set of
        # its parent, so find(j) yields the nearest kept ancestor of j (or j
        # itself when j is kept or not yet processed).
        dsu = list(range(self.J))

        def find(x: int) -> int:
            # iterative to avoid deep recursion on long chains of removed joints
            root = x
            while dsu[root] != root:
                root = dsu[root]
            # path compression
            while dsu[x] != root:
                nxt = dsu[x]
                dsu[x] = root
                x = nxt
            return root

        def merge(x: int, y: int):
            dsu[find(x)] = find(y)

        # skin / no_skin are accumulated in place, so work on copies
        new_skin = None if self.skin is None else self.skin.copy()
        new_no_skin = None if self.no_skin is None else self.no_skin.copy()
        new_names = []
        new_parents = []
        perm = []
        new_name_to_id = {}
        for (i, name) in enumerate(self.names):
            if name in keep:
                new_name_to_id[name] = len(new_names)
                new_names.append(name)
                perm.append(i)
                pid = self.parents[i]
                if pid is None:
                    new_parents.append(None)
                else:
                    new_parents.append(new_name_to_id[self.names[find(pid)]])
                continue
            assert i != 0, 'cannot remove root'
            rep = find(i)
            pid = find(self.parents[rep])
            # be careful!
            # do not copy the tail here: the ancestor may have several children
            # and it is unknown which child's tail it should inherit
            if new_skin is not None:
                new_skin[:, pid] += new_skin[:, rep]
            if new_no_skin is not None:
                new_no_skin[pid] &= new_no_skin[rep]
            merge(rep, pid)
        # fancy indexing returns new arrays, so no extra copies are needed
        if self.joints is not None:
            self.joints = self.joints[perm]
        self.parents = new_parents
        self.tails = None if self.tails is None else self.tails[perm]
        self.skin = None if new_skin is None else new_skin[:, perm]
        self.no_skin = None if new_no_skin is None else new_no_skin[perm]
        self.names = new_names
        self.matrix_local = None if self.matrix_local is None else self.matrix_local[perm]
        # pose_matrix is per-joint as well and must drop the removed joints
        if self.pose_matrix is not None:
            self.pose_matrix = self.pose_matrix[perm]

    @staticmethod
    def from_raw_data(
        raw_data: RawData,
        cls: str,
        path: str,
        data_name: str,
    ) -> 'Asset':
        '''
        Return an asset initialized from raw data.

        Copies the mesh and armature fields of ``raw_data`` by reference;
        sampled fields, ``parts_bias`` and ``pose_matrix`` stay None and
        ``meta`` starts as an empty dict.
        '''
        return Asset(
            cls=cls,
            path=path,
            data_name=data_name,
            vertices=raw_data.vertices,
            vertex_normals=raw_data.vertex_normals,
            faces=raw_data.faces,
            face_normals=raw_data.face_normals,
            joints=raw_data.joints,
            tails=raw_data.tails,
            skin=raw_data.skin,
            no_skin=raw_data.no_skin,
            parents=raw_data.parents,
            names=raw_data.names,
            matrix_local=raw_data.matrix_local,
            meta={},
        )

    def get_tokenize_input(self) -> TokenizeInput:
        '''
        Pack the armature into a ``TokenizeInput``.

        - bones: (J, 6) rows of [parent head, joint head]; the root row repeats
          its own head.
        - branch: True for joint i > 0 whose parent is not joint i-1, i.e. the
          skeleton branches instead of continuing the previous chain. False
          for the root.
        - is_leaf: True for joints that are nobody's parent.

        tails, no_skin, cls and parts_bias are passed through unchanged.
        '''
        has_child = {p for p in self.parents if p is not None}
        bones = []
        branch = []
        for i in range(self.J):
            if i == 0:
                bones.append(np.concatenate([self.joints[i], self.joints[i]]))
                branch.append(False)
            else:
                pid = self.parents[i]
                bones.append(np.concatenate([self.joints[pid], self.joints[i]]))
                branch.append(pid != i - 1)
        return TokenizeInput(
            bones=np.stack(bones),
            tails=self.tails,
            branch=np.array(branch, dtype=bool),
            is_leaf=np.array([i not in has_child for i in range(self.J)], dtype=bool),
            no_skin=self.no_skin,
            cls=self.cls,
            parts_bias=self.parts_bias,
        )

    def export_pc(self, path: str, with_normal: bool=True, normal_size=0.01):
        '''
        Export a point cloud.

        Uses the sampled points/normals when ``sampled_vertices`` is set,
        otherwise the mesh vertices and vertex normals. Normals are omitted
        when ``with_normal`` is false.
        '''
        vertices = self.vertices
        normals = self.vertex_normals
        if self.sampled_vertices is not None:
            vertices = self.sampled_vertices
            normals = self.sampled_normals
        if not with_normal:
            normals = None
        self._export_pc(vertices=vertices, path=path, vertex_normals=normals, normal_size=normal_size)

    def export_mesh(self, path: str):
        '''Export the mesh (vertices and faces).'''
        self._export_mesh(vertices=self.vertices, faces=self.faces, path=path)

    def export_skeleton(self, path: str):
        '''Export the skeleton (joint heads and parent links).'''
        self._export_skeleton(joints=self.joints, parents=self.parents, path=path)

    def export_skeleton_sequence(self, path: str):
        '''Export the skeleton (joint heads and parent links) as a sequence.'''
        self._export_skeleton_sequence(joints=self.joints, parents=self.parents, path=path)

    def export_fbx(
        self,
        path: str,
        vertex_group_name: str,
        extrude_size: float=0.03,
        group_per_vertex: int=-1,
        add_root: bool=False,
        do_not_normalize: bool=False,
        use_extrude_bone: bool=True,
        use_connect_unique_child: bool=True,
        extrude_from_parent: bool=True,
        use_tail: bool=False,
        use_origin: bool=False,
    ):
        '''
        Export the whole model with skinning.

        Geometry:
        - ``use_origin`` set: the original mesh (vertices + faces).
        - otherwise: ``sampled_vertices`` with no faces.

        Skin weights always come from
        ``sampled_vertex_groups[vertex_group_name]``. Tails are passed only
        when ``use_tail`` is set. All other options are forwarded unchanged to
        ``_export_fbx``.
        '''
        # NOTE(review): with use_origin=True the weights are still the sampled
        # ones; this presumably requires them to be defined on the original
        # vertices — confirm with callers.
        self._export_fbx(
            path=path,
            vertices=self.vertices if use_origin else self.sampled_vertices,
            joints=self.joints,
            skin=self.sampled_vertex_groups[vertex_group_name],
            parents=self.parents,
            names=self.names,
            faces=self.faces if use_origin else None,
            extrude_size=extrude_size,
            group_per_vertex=group_per_vertex,
            add_root=add_root,
            do_not_normalize=do_not_normalize,
            use_extrude_bone=use_extrude_bone,
            use_connect_unique_child=use_connect_unique_child,
            extrude_from_parent=extrude_from_parent,
            tails=self.tails if use_tail else None,
        )

    def export_render(self, path: str, resolution: Tuple[int, int]=(256, 256), use_tail: bool=False):
        '''
        Render the mesh together with its bones.

        Bones are (J', 6) segments:
        - ``use_tail`` set (requires ``tails``): [head, tail] for every joint.
        - otherwise: [parent head, joint head] for every joint except joint 0,
          which assumes joint 0 is the only root.

        ``resolution`` defaults to an immutable tuple so the default cannot be
        mutated across calls.
        '''
        if use_tail:
            assert self.tails is not None
            bones = np.concatenate([self.joints, self.tails], axis=-1)
        else:
            pjoints = self.joints[self.parents[1:]]
            bones = np.concatenate([pjoints, self.joints[1:]], axis=-1)
        self._export_render(
            path=path,
            vertices=self.vertices,
            faces=self.faces,
            bones=bones,
            resolution=resolution,
        )