File size: 3,892 Bytes
f499d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from typing import Dict, List, Tuple, Union
from collections import defaultdict
from dataclasses import dataclass
import yaml
from box import Box

from .spec import ConfigSpec

@dataclass
class OrderConfig(ConfigSpec):
    '''
    Config to handle bones re-ordering.

    `parse` reads one YAML file per entry of `skeleton_path`; each file is
    expected to provide a `parts` mapping and a `parts_order` list, which are
    collected under the corresponding skeleton class name.
    '''
    
    # {skeleton_name: path}
    skeleton_path: Dict[str, str]
    
    # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
    parts: Dict[str, Dict[str, List[str]]]
    
    # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
    parts_order: Dict[str, List[str]]
    
    @classmethod
    def parse(cls, config):
        '''
        Build an OrderConfig from a config object exposing `skeleton_path`.

        Args:
            config: object with a `skeleton_path` attribute mapping skeleton
                class names to YAML file paths. Its keys are validated by
                `check_keys` (presumably provided by ConfigSpec).

        Returns:
            OrderConfig whose `parts` / `parts_order` hold, per skeleton class,
            the `parts` / `parts_order` entries of the matching YAML file.
        '''
        cls.check_keys(config)
        skeleton_path = config.skeleton_path
        parts = {}
        parts_order = {}
        # Named `skeleton_cls` so the classmethod's own `cls` is not shadowed.
        for skeleton_cls, path in skeleton_path.items():
            assert skeleton_cls not in parts, 'cls conflicts'
            # Use a context manager so the file handle is closed promptly.
            with open(path, 'r') as f:
                d = Box(yaml.safe_load(f))
            parts[skeleton_cls] = d.parts
            parts_order[skeleton_cls] = d.parts_order
        return OrderConfig(
            skeleton_path=skeleton_path,
            parts=parts,
            parts_order=parts_order,
        )

class Order():
    '''
    Re-orders bone names so that configured skeleton parts appear as
    contiguous runs in a fixed order.

    Parts and their order are looked up per skeleton class (`cls`) in the
    mappings taken from the config passed to the constructor.
    '''
    
    # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
    parts: Dict[str, Dict[str, List[str]]]
    
    # {cls: parts of bones to be arranged in [part_name_1, part_name_2, ...]}
    parts_order: Dict[str, List[str]]
    
    def __init__(self, config: 'OrderConfig'):
        self.parts          = config.parts
        self.parts_order    = config.parts_order
    
    def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
        '''
        Check whether every bone of `part` (for skeleton class `cls`) is in `names`.

        Returns False when `cls` or `part` is not configured.
        '''
        if cls not in self.parts or part not in self.parts[cls]:
            return False
        # Set membership keeps this linear in len(names) + len(part).
        present = set(names)
        return all(name in present for name in self.parts[cls][part])
    
    def make_names(self, cls: Union[str, None], parts: List[Union[str, None]], num_bones: int) -> List[str]:
        '''
        Build the bone-name list for skeleton class `cls` from the given parts.

        Bone names of each configured part are concatenated in the order given;
        `None` entries (spring tokens) and parts not configured for `cls`
        contribute nothing. The list is then padded with placeholder names
        `bone_{i}` up to `num_bones` entries.

        Raises:
            AssertionError: if the selected parts name more than `num_bones` bones.
        '''
        names = []
        for part in parts:
            if part is None: # spring
                continue
            if cls in self.parts and part in self.parts[cls]:
                names.extend(self.parts[cls][part])
        assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
        names.extend(f"bone_{i}" for i in range(len(names), num_bones))
        return names
    
    def arrange_names(self, cls: str, names: List[str], parents: List[Union[int, None]]) -> Tuple[List[str], Dict[int, Union[str, None]]]:
        '''
        Arrange names according to the required parts order for `cls`.

        Parts listed in `parts_order[cls]` whose bones are all present in
        `names` are emitted as contiguous runs, in order. Parts not present are
        skipped. Arrangement stops at the first part containing a bone whose
        parent was not placed by an earlier part or the same part. All
        remaining names then follow in their original order, each name at most
        once.

        Args:
            cls: skeleton class name.
            names: bone names.
            parents: parents[i] is the index into `names` of bone i's parent,
                or None for a root.

        Returns:
            (new_names, parts_bias). parts_bias maps the start index of each
            placed part in new_names to its part name, and the index right after
            the last placed part to None (spring token). If `cls` has no parts
            order, `names` itself is returned with {0: None}.
        '''
        if cls not in self.parts_order:
            return names, {0: None} # add a spring token
        vis = defaultdict(bool)
        name_to_id = {name: i for (i, name) in enumerate(names)}
        new_names = []
        parts_bias = {}
        for part in self.parts_order[cls]:
            if not self.part_exists(cls=cls, part=part, names=names):
                continue
            part_names = self.parts[cls][part]
            # Mark the whole part first so parents inside the same part count as placed.
            for name in part_names:
                vis[name] = True
            misordered = False
            for name in part_names:
                pid = parents[name_to_id[name]]
                if pid is not None and not vis[names[pid]]:
                    misordered = True
                    break
            if misordered: # incorrect parts order and should immediately add a spring token
                break
            parts_bias[len(new_names)] = part
            new_names.extend(part_names)
        parts_bias[len(new_names)] = None # add a spring token
        # Track emitted names in a set (updated as we go) to avoid O(n^2) list scans
        # while still de-duplicating repeated entries of `names`.
        placed = set(new_names)
        for name in names:
            if name not in placed:
                placed.add(name)
                new_names.append(name)
        return new_names, parts_bias

def get_order(config: OrderConfig) -> Order:
    '''
    Construct an `Order` helper from a parsed `OrderConfig`.
    '''
    order = Order(config)
    return order