# jkorstad — commit f499d3b: "Correctly add UniRig source files"
from typing import Dict, List, Tuple, Union
from collections import defaultdict
from dataclasses import dataclass
import yaml
from box import Box
from .spec import ConfigSpec
@dataclass
class OrderConfig(ConfigSpec):
    '''
    Config to handle bones re-ordering.

    Each skeleton class maps to a YAML file whose ``parts`` and
    ``parts_order`` entries are loaded into the fields below.
    '''
    # {skeleton_name: path}
    skeleton_path: Dict[str, str]
    # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
    parts: Dict[str, Dict[str, List[str]]]
    # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
    parts_order: Dict[str, List[str]]

    @classmethod
    def parse(cls, config):
        '''
        Build an OrderConfig from ``config``.

        ``config`` must pass ``cls.check_keys`` and expose ``skeleton_path``
        as a ``{skeleton_cls: yaml_path}`` mapping. For each entry, the YAML
        file is loaded and its ``parts`` / ``parts_order`` values are stored
        under ``skeleton_cls``.
        '''
        cls.check_keys(config)
        skeleton_path = config.skeleton_path
        parts = {}
        parts_order = {}
        # Named skeleton_cls (not cls) so the classmethod's cls stays usable
        # for construction below.
        for skeleton_cls, path in skeleton_path.items():
            # Context manager closes the file promptly; YAML is UTF-8 by spec,
            # so don't depend on the platform's locale encoding.
            with open(path, 'r', encoding='utf-8') as f:
                d = Box(yaml.safe_load(f))
            parts[skeleton_cls] = d.parts
            parts_order[skeleton_cls] = d.parts_order
        return cls(
            skeleton_path=skeleton_path,
            parts=parts,
            parts_order=parts_order,
        )
class Order:
    '''
    Helper that names and re-orders skeleton bones by configured parts.
    '''
    # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
    parts: Dict[str, Dict[str, List[str]]]
    # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
    parts_order: Dict[str, List[str]]

    def __init__(self, config: 'OrderConfig'):
        self.parts = config.parts
        self.parts_order = config.parts_order

    def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
        '''
        Check whether every bone of ``part`` for skeleton ``cls`` is in ``names``.

        Returns False when ``cls`` or ``part`` is not configured. An empty
        part is considered present.
        '''
        cls_parts = self.parts.get(cls)
        if cls_parts is None or part not in cls_parts:
            return False
        available = set(names)
        return all(name in available for name in cls_parts[part])

    def make_names(self, cls: Union[str, None], parts: List[Union[str, None]], num_bones: int) -> List[str]:
        '''
        Get ``num_bones`` names for specified cls.

        Bones of each part in ``parts`` (looked up in ``self.parts[cls]``)
        come first, in the given order. ``None`` entries (spring tokens) and
        parts not configured for ``cls`` are skipped. The remaining slots are
        filled with ``bone_{i}`` where ``i`` is the slot index.

        Raises AssertionError if the selected parts contain more than
        ``num_bones`` bones.
        '''
        names = []
        for part in parts:
            if part is None:  # spring
                continue
            if cls in self.parts and part in self.parts[cls]:
                names.extend(self.parts[cls][part])
        assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
        names.extend(f"bone_{i}" for i in range(len(names), num_bones))
        return names

    def arrange_names(self, cls: str, names: List[str], parents: List[Union[int, None]]) -> Tuple[List[str], Dict[int, Union[str, None]]]:
        '''
        Arrange names according to required parts order.

        Args:
            cls: skeleton class; if it has no configured parts order, ``names``
                is returned unchanged with ``{0: None}``.
            names: bone names.
            parents: ``parents[i]`` is the index in ``names`` of the parent of
                ``names[i]``, or None for a root.

        Walks ``self.parts_order[cls]``. Parts not fully present in ``names``
        are skipped. A present part is placed only if each of its bones is a
        root or has its parent in an already-placed part or in the same part;
        the first part violating this stops the walk. Bones not placed by any
        part follow in their original order (duplicates kept once).

        Returns:
            (new_names, parts_bias): ``parts_bias`` maps a start index in
            ``new_names`` to the part name beginning there, and the index
            right after the last placed part to None (a spring token).
        '''
        if cls not in self.parts_order:
            return names, {0: None}  # add a spring token
        name_to_id = {name: i for (i, name) in enumerate(names)}
        placed_in_parts = set()
        new_names = []
        parts_bias = {}
        for part in self.parts_order[cls]:
            if not self.part_exists(cls=cls, part=part, names=names):
                continue
            bones = self.parts[cls][part]
            # Mark the part first so parents inside the same part are accepted.
            placed_in_parts.update(bones)
            out_of_order = False
            for name in bones:
                pid = parents[name_to_id[name]]
                if pid is not None and names[pid] not in placed_in_parts:
                    out_of_order = True
                    break
            if out_of_order:  # incorrect parts order and should immediately add a spring token
                break
            parts_bias[len(new_names)] = part
            new_names.extend(bones)
        parts_bias[len(new_names)] = None  # add a spring token
        # Set-based membership keeps this linear; update it while appending so
        # duplicate names are still emitted only once.
        seen = set(new_names)
        for name in names:
            if name not in seen:
                new_names.append(name)
                seen.add(name)
        return new_names, parts_bias
def get_order(config: OrderConfig) -> Order:
    '''
    Construct the Order helper described by ``config``.
    '''
    order = Order(config=config)
    return order