File size: 4,765 Bytes
9d7cf7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132

from collections import defaultdict
from dataclasses import dataclass
from numpy import ndarray
from omegaconf import OmegaConf
from typing import Dict, List, Tuple, Optional

from .spec import ConfigSpec

@dataclass
class Order(ConfigSpec):
    '''
    Bone-ordering configuration.

    Parts (named groups of bones) and their required order are given per
    skeleton class, either inline (``parts`` / ``parts_order``) or loaded
    from one config file per class via ``OmegaConf.load`` (``skeleton_path``).
    '''
    
    # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
    parts: Dict[str, Dict[str, List[str]]]
    
    # {cls: [part_name_1, part_name_2, ...]} -- order in which parts are arranged
    parts_order: Dict[str, List[str]]
    
    # {cls: path}; when set, `parts` / `parts_order` are read from these files
    skeleton_path: Optional[Dict[str, str]]=None
    
    # when True, arrange_names ignores parts and orders bones by a DFS over the
    # hierarchy with siblings sorted by joint coordinates
    sort_by_xyz: bool=False
    
    @classmethod
    def parse(cls, **kwargs) -> 'Order':
        '''
        Build an Order from config keyword arguments.

        If ``skeleton_path`` is given, every ``{cls: path}`` entry is loaded
        with ``OmegaConf.load`` and its ``parts`` / ``parts_order`` entries are
        used; inline ``parts`` / ``parts_order`` kwargs are then ignored.
        Otherwise both ``parts`` and ``parts_order`` must be provided
        (AssertionError if either is missing).
        '''
        cls.check_keys(kwargs)
        skeleton_path = kwargs.get('skeleton_path', None)
        if skeleton_path is not None:
            parts = {}
            parts_order = {}
            # The loop variable must not be called `cls`: that would shadow the
            # class this classmethod was invoked on.
            for (skeleton_cls, path) in skeleton_path.items():
                d = OmegaConf.load(path)
                parts[skeleton_cls] = d.parts
                parts_order[skeleton_cls] = d.parts_order
        else:
            parts = kwargs.get('parts')
            parts_order = kwargs.get('parts_order')
            assert parts is not None
            assert parts_order is not None
        return cls(
            skeleton_path=skeleton_path,
            parts=parts,
            parts_order=parts_order,
            sort_by_xyz=kwargs.get('sort_by_xyz', False),
        )
    
    def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
        '''
        Return True if ``part`` is defined for ``cls`` and every bone it lists
        appears in ``names``.
        '''
        if part not in self.parts[cls]:
            return False
        return all(name in names for name in self.parts[cls][part])
    
    def make_names(self, cls: str|None, parts: List[str|None], num_bones: int) -> List[str]:
        '''
        Build ``num_bones`` bone names for ``cls`` from the requested parts.

        Bones of each part known for ``cls`` are emitted in order; ``None``
        entries (spring tokens) and unknown parts are skipped. The remaining
        slots are filled with ``bone_{i}`` placeholders.

        Raises AssertionError if the listed parts hold more than
        ``num_bones`` bones.
        '''
        names = []
        for part in parts:
            if part is None: # spring token
                continue
            if cls in self.parts and part in self.parts[cls]:
                names.extend(self.parts[cls][part])
        assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
        first_placeholder = len(names)
        names.extend(f"bone_{i}" for i in range(first_placeholder, num_bones))
        return names
    
    def arrange_names(self, cls: str|None, names: List[str], parents: List[int], joints: Optional[ndarray]=None) -> Tuple[List[str], Dict[int, str|None]]:
        '''
        Arrange names according to the required parts order.

        Args:
            cls: skeleton class whose ``parts_order`` is used.
            names: bone names, indexed like ``parents``.
            parents: parent index of each bone, ``-1`` for a root.
            joints: per-bone coordinates; required when ``sort_by_xyz`` is set.

        Returns:
            ``(new_names, parts_bias)``. ``new_names`` contains every entry of
            ``names`` in the arranged order. ``parts_bias`` maps an offset into
            ``new_names`` to the part starting there, or to ``None`` for a
            spring token; it is empty when ``sort_by_xyz`` is set.
        '''
        if self.sort_by_xyz:
            assert joints is not None
            return self._arrange_by_xyz(names, parents, joints), {}
        if cls not in self.parts_order:
            return names, {0: None} # add a spring token
        return self._arrange_by_parts(cls, names, parents)
    
    @staticmethod
    def _arrange_by_xyz(names: List[str], parents: List[int], joints: ndarray) -> List[str]:
        '''
        Pre-order DFS over the bone forest.

        Roots are visited in index order. Siblings are visited sorted by joint
        coordinate 2, then 0, then 1 (presumably z, x, y -- confirm with the
        joints layout); ties keep index order because the sort is stable.
        '''
        children = defaultdict(list)
        roots = []
        for (i, p) in enumerate(parents):
            if p == -1:
                roots.append(i)
            else:
                children[p].append(i)

        def coord_key(v: int):
            return (joints[v][2], joints[v][0], joints[v][1])

        visited = [False] * len(parents)
        new_names = []
        # The top of the stack is the next bone to visit. Children are pushed
        # in reverse sorted order so the smallest key is popped first.
        stack = list(reversed(roots))
        while stack:
            u = stack.pop()
            visited[u] = True
            new_names.append(names[u])
            stack.extend(reversed(sorted(children[u], key=coord_key)))
        # Bones unreachable from any root (e.g. a parent cycle) are kept in
        # index order rather than silently dropped.
        new_names.extend(names[i] for i in range(len(parents)) if not visited[i])
        return new_names
    
    def _arrange_by_parts(self, cls: str, names: List[str], parents: List[int]) -> Tuple[List[str], Dict[int, str|None]]:
        '''
        Emit whole parts in ``parts_order[cls]`` order, then the remaining bones.

        A part is emitted only if all of its bones exist in ``names``. Emission
        stops at the first such part containing a bone whose parent has not
        been placed yet, i.e. when the parts order would break the hierarchy.
        A spring token (``None``) is recorded right after the last emitted part.
        '''
        name_to_id = {name: i for (i, name) in enumerate(names)}
        placed = set()
        new_names = []
        parts_bias = {}
        for part in self.parts_order[cls]:
            if not self.part_exists(cls=cls, part=part, names=names):
                continue
            bones = self.parts[cls][part]
            # This part's own bones count as placed, so a parent inside the
            # same part is accepted.
            placed.update(bones)
            parent_ids = (parents[name_to_id[name]] for name in bones)
            if any(pid != -1 and names[pid] not in placed for pid in parent_ids):
                break # incorrect parts order: stop and add the spring token
            parts_bias[len(new_names)] = part
            new_names.extend(bones)
        parts_bias[len(new_names)] = None # add a spring token
        emitted = set(new_names)
        for name in names:
            if name not in emitted:
                emitted.add(name)
                new_names.append(name)
        return new_names, parts_bias