File size: 1,573 Bytes
f499d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from collections import defaultdict
from dataclasses import dataclass
import numpy as np
from numpy import ndarray

from typing import Tuple

from .asset import Asset
from .spec import ConfigSpec

@dataclass
class TailConfig(ConfigSpec):
    '''
    Options that control how tail positions are produced.
    '''

    # when True, tails are initialised from the joints
    copy_joint_to_tail: bool

    # when True, a joint that has exactly one son gets its tail attached
    # to that son's joint
    connect_tail_to_unique_son: bool

    @classmethod
    def parse(cls, config) -> 'TailConfig':
        '''
        Build a TailConfig from `config`.

        `cls.check_keys` (not defined in this class) is called on `config`
        first; afterwards both flags are read from `config` as attributes.
        '''
        cls.check_keys(config)
        copy_flag = config.copy_joint_to_tail
        connect_flag = config.connect_tail_to_unique_son
        parsed = TailConfig(
            copy_joint_to_tail=copy_flag,
            connect_tail_to_unique_son=connect_flag,
        )
        return parsed

class Tail():
    '''
    Fills in or adjusts the tails of an asset according to a TailConfig.
    '''

    def __init__(self, config: TailConfig):
        # process_tail reads `copy_joint_to_tail` and
        # `connect_tail_to_unique_son` from this object.
        self.config = config

    def process_tail(self, asset: Asset):
        '''
        Update `asset.tails` in place, following the flags in `self.config`.

        1. If `copy_joint_to_tail` is set, `asset.tails` becomes a copy of
           `asset.joints`; the asset must not already have tails.
        2. If `connect_tail_to_unique_son` is set and tails exist, the tail
           of every joint with exactly one son is set to that son's joint.
           When tails are still None this step is skipped.

        Raises:
            AssertionError: `copy_joint_to_tail` is set but `asset.tails`
                is not None.
        '''
        if self.config.copy_joint_to_tail:
            # Raised explicitly rather than via an `assert` statement so the
            # guard is not stripped under `python -O`; the exception type is
            # kept as AssertionError for existing callers.
            if asset.tails is not None:
                raise AssertionError('copying joints to existing tails is not permitted, please change copy_joint_to_tail to False in transform config')
            asset.tails = asset.joints.copy()
        if self.config.connect_tail_to_unique_son and asset.tails is not None:
            # parent index -> indices of its sons; entries whose parent is
            # None are skipped
            children = defaultdict(list)
            for (joint_id, parent) in enumerate(asset.parents):
                if parent is not None:
                    children[parent].append(joint_id)
            for i in range(asset.J):
                # .get() avoids inserting an empty list for childless joints
                sons = children.get(i, ())
                if len(sons) == 1:
                    asset.tails[i] = asset.joints[sons[0]]

def get_tail(config: TailConfig) -> Tail:
    '''
    Create a Tail processor that uses the given config.
    '''
    tail = Tail(config=config)
    return tail