SkinTokens / src /data /order.py
pookiefoof's picture
Public release: SkinTokens · TokenRig demo
9d7cf7f
from collections import defaultdict
from dataclasses import dataclass
from numpy import ndarray
from omegaconf import OmegaConf
from typing import Dict, List, Tuple, Optional
from .spec import ConfigSpec
@dataclass
class Order(ConfigSpec):
    '''
    Bone-name ordering rules for one or more skeleton classes.

    Every mapping is keyed by a skeleton class name (`cls`). For each class,
    `parts` lists the bone names belonging to each named part and
    `parts_order` lists the part names in the order they should be emitted.
    '''

    # {cls: {part_name: [bone_name_1, bone_name_2, ...]}}
    parts: Dict[str, Dict[str, List[str]]]

    # {cls: [part_name_1, part_name_2, ...]}, the order parts are arranged in
    parts_order: Dict[str, List[str]]

    # {skeleton_name: path}; when given, parts/parts_order are loaded from these files
    skeleton_path: Optional[Dict[str, str]]=None

    # when True, arrange_names ignores parts and orders bones by joint position
    sort_by_xyz: bool=False

    @classmethod
    def parse(cls, **kwargs) -> 'Order':
        '''
        Build an Order from keyword configuration.

        If `skeleton_path` is given, each path is loaded with OmegaConf and
        its `parts` / `parts_order` entries are stored under the skeleton
        name. Otherwise `parts` and `parts_order` must be passed directly.

        Raises:
            AssertionError: if neither `skeleton_path` nor both of `parts`
                and `parts_order` are provided.
        '''
        # check_keys comes from ConfigSpec (not visible here); presumably it
        # rejects unknown keys.
        cls.check_keys(kwargs)
        skeleton_path = kwargs.get('skeleton_path', None)
        if skeleton_path is not None:
            parts = {}
            parts_order = {}
            # NOTE: the loop variable must not be called `cls`, otherwise it
            # shadows the classmethod's own `cls` argument.
            for (skeleton_name, path) in skeleton_path.items():
                assert skeleton_name not in parts, 'cls conflicts'
                d = OmegaConf.load(path)
                parts[skeleton_name] = d.parts
                parts_order[skeleton_name] = d.parts_order
        else:
            parts = kwargs.get('parts')
            parts_order = kwargs.get('parts_order')
            assert parts is not None, "either skeleton_path or parts must be provided"
            assert parts_order is not None, "either skeleton_path or parts_order must be provided"
        return cls(
            skeleton_path=skeleton_path,
            parts=parts,
            parts_order=parts_order,
            sort_by_xyz=kwargs.get('sort_by_xyz', False),
        )

    def part_exists(self, cls: str, part: str, names: List[str]) -> bool:
        '''
        Check if part exists.

        Returns True only when `part` is defined for `cls` and every bone of
        that part appears in `names`. An unknown `cls` yields False, matching
        how make_names treats unknown classes.
        '''
        if cls not in self.parts or part not in self.parts[cls]:
            return False
        available = set(names)
        return all(name in available for name in self.parts[cls][part])

    def make_names(self, cls: str|None, parts: List[str|None], num_bones: int) -> List[str]:
        '''
        Get names for specified cls.

        Concatenates the bone names of every requested part that is defined
        for `cls` (None entries and unknown parts are skipped), then pads the
        list with generated "bone_{i}" names up to `num_bones`.

        Raises:
            AssertionError: if the requested parts contain more bones than
                `num_bones`.
        '''
        names = []
        for part in parts:
            if part is None: # spring
                continue
            if cls in self.parts and part in self.parts[cls]:
                names.extend(self.parts[cls][part])
        assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones"
        first_generated = len(names)
        names.extend(f"bone_{i}" for i in range(first_generated, num_bones))
        return names

    def arrange_names(self, cls: str|None, names: List[str], parents: List[int], joints: Optional[ndarray]=None) -> Tuple[List[str], Dict[int, str|None]]:
        '''
        Arrange names according to required parts order.

        Args:
            cls: skeleton class name used to look up `parts_order`.
            names: bone names, indexed by bone id.
            parents: parent bone id per bone, -1 for a bone without parent.
            joints: per-bone coordinates, indexed as joints[bone][axis];
                required when `self.sort_by_xyz` is set.

        Returns:
            (new_names, parts_bias). `parts_bias` maps a position in
            `new_names` to the part starting there; a None value marks where
            a spring token is added. It is empty in the sort_by_xyz mode.
        '''
        if self.sort_by_xyz:
            assert joints is not None
            if len(parents) == 0:
                # no hierarchy to walk, nothing to reorder
                return list(names), {}
            children = defaultdict(list)
            has_parent = set()
            for (i, p) in enumerate(parents):
                if p != -1:
                    children[p].append(i)
                    has_parent.add(i)
            # the first bone that is nobody's child is used as the root; only
            # its subtree is visited
            root = next((i for i in range(len(parents)) if i not in has_parent), -1)

            def position_key(v):
                # despite the flag's name the comparison order is (z, x, y)
                return (joints[v][2], joints[v][0], joints[v][1])

            new_names = []
            # depth-first pre-order; siblings are visited in position_key order
            stack = [root]
            while stack:
                u = stack.pop()
                new_names.append(names[u])
                # push in reverse so the smallest child is popped first
                stack.extend(reversed(sorted(children[u], key=position_key)))
            return new_names, {}

        if cls not in self.parts_order:
            return names, {0: None} # add a spring token

        name_to_id = {name: i for (i, name) in enumerate(names)}
        visited = set()
        new_names = []
        parts_bias = {}
        for part in self.parts_order[cls]:
            if not self.part_exists(cls=cls, part=part, names=names):
                continue
            part_names = self.parts[cls][part]
            # mark the part first so parents inside the same part count as visited
            visited.update(part_names)
            parent_ids = (parents[name_to_id[name]] for name in part_names)
            if any(pid != -1 and names[pid] not in visited for pid in parent_ids):
                # incorrect parts order and should immediately add a spring token
                break
            parts_bias[len(new_names)] = part
            new_names.extend(part_names)
        parts_bias[len(new_names)] = None # add a spring token

        # append every remaining bone once, keeping its original relative order
        emitted = set(new_names)
        for name in names:
            if name not in emitted:
                emitted.add(name)
                new_names.append(name)
        return new_names, parts_bias