Spaces:
Running on Zero
Running on Zero
File size: 4,765 Bytes
9d7cf7f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from collections import defaultdict
from dataclasses import dataclass
from numpy import ndarray
from omegaconf import OmegaConf
from typing import Dict, List, Tuple, Optional
from .spec import ConfigSpec
@dataclass
class Order(ConfigSpec):
    '''
    Configuration describing how bone names are grouped into parts and in
    which order those parts are arranged, per skeleton class.
    '''
    # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
    parts: Dict[str, Dict[str, List[str]]]
    # {cls: [part_name_1, part_name_2, ...]}, the order in which parts are arranged
    parts_order: Dict[str, List[str]]
    # {skeleton_name: path}
    skeleton_path: Optional[Dict[str, str]] = None
    # when True, arrange_names ignores parts and orders bones by joint position
    sort_by_xyz: bool = False

    @classmethod
    def parse(cls, **kwargs) -> 'Order':
        '''
        Build an Order from keyword arguments.

        If `skeleton_path` is given, `parts` and `parts_order` are loaded with
        OmegaConf from each path and keyed by the skeleton name; otherwise both
        must be passed directly in `kwargs`.
        '''
        cls.check_keys(kwargs)
        skeleton_path = kwargs.get('skeleton_path', None)
        if skeleton_path is not None:
            parts = {}
            parts_order = {}
            # NOTE: the loop variable must not be called `cls`, which would
            # shadow the class this classmethod is bound to.
            for skeleton_cls, path in skeleton_path.items():
                assert skeleton_cls not in parts, 'cls conflicts'
                loaded = OmegaConf.load(path)
                parts[skeleton_cls] = loaded.parts
                parts_order[skeleton_cls] = loaded.parts_order
        else:
            parts = kwargs.get('parts')
            parts_order = kwargs.get('parts_order')
            assert parts is not None
            assert parts_order is not None
        return cls(
            skeleton_path=skeleton_path,
            parts=parts,
            parts_order=parts_order,
            sort_by_xyz=kwargs.get('sort_by_xyz', False),
        )

    def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
        '''
        Check if part exists.

        Returns True only when `part` is defined for `cls` and every bone name
        of that part is contained in `names`. An unknown `cls` yields False.
        '''
        if cls not in self.parts or part not in self.parts[cls]:
            return False
        available = set(names)
        return all(name in available for name in self.parts[cls][part])

    def make_names(self, cls: str|None, parts: List[str|None], num_bones: int) -> List[str]:
        '''
        Get names for specified cls.

        Concatenates the bone names of every known part in `parts` (None
        entries are skipped) and pads the result with "bone_{i}" placeholders
        up to `num_bones` entries.
        '''
        names = []
        for part in parts:
            if part is None: # spring
                continue
            if cls in self.parts and part in self.parts[cls]:
                names.extend(self.parts[cls][part])
        assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
        names.extend(f"bone_{i}" for i in range(len(names), num_bones))
        return names

    def arrange_names(self, cls: str|None, names: List[str], parents: List[int], joints: Optional[ndarray]=None) -> Tuple[List[str], Dict[int, str|None]]:
        '''
        Arrange names according to required parts order.

        Returns the reordered names and a dict mapping a position in that list
        to the part starting there (None marks a spring token).

        When `self.sort_by_xyz` is set, parts are ignored: `joints` is required
        and names are returned in position-sorted depth-first order with an
        empty dict.
        '''
        if self.sort_by_xyz:
            assert joints is not None
            return self._arrange_by_position(names, parents, joints), {}

        if cls not in self.parts_order:
            return names, {0: None} # add a spring token

        name_to_id = {name: i for (i, name) in enumerate(names)}
        placed = set()  # bone names belonging to parts accepted so far
        new_names = []
        parts_bias = {}
        for part in self.parts_order[cls]:
            if not self.part_exists(cls=cls, part=part, names=names):
                continue
            part_names = self.parts[cls][part]
            placed.update(part_names)
            # every non-root bone of this part needs its parent in this part
            # or in an earlier one
            misordered = False
            for name in part_names:
                pid = parents[name_to_id[name]]
                if pid == -1:
                    continue
                if names[pid] not in placed:
                    misordered = True
                    break
            if misordered: # incorrect parts order and should immediately add a spring token
                break
            parts_bias[len(new_names)] = part
            new_names.extend(part_names)
        parts_bias[len(new_names)] = None # add a spring token

        # append the remaining names in their original order, skipping any
        # name that is already present
        emitted = set(new_names)
        for name in names:
            if name not in emitted:
                emitted.add(name)
                new_names.append(name)
        return new_names, parts_bias

    def _arrange_by_position(self, names: List[str], parents: List[int], joints: ndarray) -> List[str]:
        '''
        Return names in depth-first pre-order starting at the first bone whose
        parent is -1, visiting siblings sorted by (joint[2], joint[0], joint[1]).

        Only the subtree of that first root is visited. Raises ValueError when
        `parents` is non-empty but contains no -1 entry.
        '''
        if len(parents) == 0:
            return []
        children = defaultdict(list)
        root = None
        for (i, p) in enumerate(parents):
            if p != -1:
                children[p].append(i)
            elif root is None:
                root = i
        if root is None:
            raise ValueError("no root bone found: every entry of parents has a parent")

        ordered = []
        stack = [root]
        while stack:
            u = stack.pop()
            ordered.append(names[u])
            siblings = sorted(
                children.get(u, ()),
                key=lambda v: (joints[v][2], joints[v][0], joints[v][1]),
            )
            # push reversed so the smallest sibling is popped (visited) first
            stack.extend(reversed(siblings))
        return ordered