from typing import Dict, List, Tuple, Union from collections import defaultdict from dataclasses import dataclass import yaml from box import Box from .spec import ConfigSpec @dataclass class OrderConfig(ConfigSpec): ''' Config to handle bones re-ordering. ''' # {skeleton_name: path} skeleton_path: Dict[str, str] # {cls: {part_name: [bone_name_1, bone_name_2, ...]}} parts: Dict[str, Dict[str, List[str]]] # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]} parts_order: Dict[str, List[str]] @classmethod def parse(cls, config): cls.check_keys(config) skeleton_path = config.skeleton_path parts = {} parts_order = {} for (cls, path) in skeleton_path.items(): assert cls not in parts, 'cls conflicts' d = Box(yaml.safe_load(open(path, 'r'))) parts[cls] = d.parts parts_order[cls] = d.parts_order return OrderConfig( skeleton_path=skeleton_path, parts=parts, parts_order=parts_order, ) class Order(): # {part_name: [bone_name_1, bone_name_2, ...]} parts: Dict[str, Dict[str, List[str]]] # parts of bones to be arranged in [part_name_1, part_name_2, ...] parts_order: Dict[str, List[str]] def __init__(self, config: OrderConfig): self.parts = config.parts self.parts_order = config.parts_order def part_exists(self, cls: str, part: str, names: List[str]) -> bool: ''' Check if part exists. ''' if part not in self.parts[cls]: return False for name in self.parts[cls][part]: if name not in names: return False return True def make_names(self, cls: Union[str, None], parts: List[Union[str, None]], num_bones: int) -> List[str]: ''' Get names for specified cls. ''' names = [] for part in parts: if part is None: # spring continue if cls in self.parts and part in self.parts[cls]: names.extend(self.parts[cls][part]) assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones" for i in range(len(names), num_bones): names.append(f"bone_{i}") return names def arrange_names(self, cls: str, names: List[str], parents: List[Union[int, None]]) -> Tuple[List[str], Dict[int, Union[str]]]: ''' Arrange names according to required parts order. ''' if cls not in self.parts_order: return names, {0: None} # add a spring token vis = defaultdict(bool) name_to_id = {name: i for (i, name) in enumerate(names)} new_names = [] parts_bias = {} for part in self.parts_order[cls]: if self.part_exists(cls=cls, part=part, names=names): for name in self.parts[cls][part]: vis[name] = True flag = False for name in self.parts[cls][part]: pid = parents[name_to_id[name]] if pid is None: continue if not vis[names[pid]]: flag = True break if flag: # incorrect parts order and should immediately add a spring token break parts_bias[len(new_names)] = part new_names.extend(self.parts[cls][part]) parts_bias[len(new_names)] = None # add a spring token for name in names: if name not in new_names: new_names.append(name) return new_names, parts_bias def get_order(config: OrderConfig) -> Order: return Order(config=config)