Spaces:
Running on Zero
Running on Zero
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from numpy import ndarray | |
| from omegaconf import OmegaConf | |
| from typing import Dict, List, Tuple, Optional | |
| from .spec import ConfigSpec | |
| class Order(ConfigSpec): | |
| # {part_name: [bone_name_1, bone_name_2, ...]} | |
| parts: Dict[str, Dict[str, List[str]]] | |
| # parts of bones to be arranged in [part_name_1, part_name_2, ...] | |
| parts_order: Dict[str, List[str]] | |
| # {skeleton_name: path} | |
| skeleton_path: Optional[Dict[str, str]]=None | |
| sort_by_xyz: bool=False | |
| def parse(cls, **kwargs) -> 'Order': | |
| cls.check_keys(kwargs) | |
| skeleton_path = kwargs.get('skeleton_path', None) | |
| if skeleton_path is not None: | |
| parts = {} | |
| parts_order = {} | |
| for (cls, path) in skeleton_path.items(): | |
| assert cls not in parts, 'cls conflicts' | |
| d = OmegaConf.load(path) | |
| parts[cls] = d.parts | |
| parts_order[cls] = d.parts_order | |
| else: | |
| parts = kwargs.get('parts') | |
| parts_order = kwargs.get('parts_order') | |
| assert parts is not None | |
| assert parts_order is not None | |
| return Order( | |
| skeleton_path=skeleton_path, | |
| parts=parts, | |
| parts_order=parts_order, | |
| sort_by_xyz=kwargs.get('sort_by_xyz', False), | |
| ) | |
| def part_exists(self, cls: str, part: str, names: List[str]) -> bool: | |
| ''' | |
| Check if part exists. | |
| ''' | |
| if part not in self.parts[cls]: | |
| return False | |
| for name in self.parts[cls][part]: | |
| if name not in names: | |
| return False | |
| return True | |
| def make_names(self, cls: str|None, parts: List[str|None], num_bones: int) -> List[str]: | |
| ''' | |
| Get names for specified cls. | |
| ''' | |
| names = [] | |
| for part in parts: | |
| if part is None: # spring | |
| continue | |
| if cls in self.parts and part in self.parts[cls]: | |
| names.extend(self.parts[cls][part]) | |
| assert len(names) <= num_bones, "number of bones in required skeleton is more than existing bones" | |
| for i in range(len(names), num_bones): | |
| names.append(f"bone_{i}") | |
| return names | |
| def arrange_names(self, cls: str|None, names: List[str], parents: List[int], joints: Optional[ndarray]=None) -> Tuple[List[str], Dict[int, str|None]]: | |
| ''' | |
| Arrange names according to required parts order. | |
| ''' | |
| def sort_by_xyz(joints): | |
| return sorted(joints, key=lambda joint: (joint[1][2], joint[1][0], joint[1][1])) | |
| if self.sort_by_xyz: | |
| assert joints is not None | |
| new_names = [] | |
| root = -1 | |
| son = defaultdict(list) | |
| not_root = {} | |
| for (i, p) in enumerate(parents): | |
| if p != -1: | |
| son[p].append(i) | |
| not_root[i] = True | |
| for i in range(len(parents)): | |
| if not_root.get(i, False) == False: | |
| root = i | |
| break | |
| Q = [root] | |
| while Q: | |
| u = Q.pop(0) | |
| new_names.append(names[u]) | |
| wait = [] | |
| for v in son[u]: | |
| wait.append((v, joints[v])) | |
| wait_sorted = sort_by_xyz(wait) | |
| new_wait = [v for v, _ in wait_sorted] | |
| Q = new_wait + Q | |
| return new_names, {} | |
| if cls not in self.parts_order: | |
| return names, {0: None} # add a spring token | |
| vis = defaultdict(bool) | |
| name_to_id = {name: i for (i, name) in enumerate(names)} | |
| new_names = [] | |
| parts_bias = {} | |
| for part in self.parts_order[cls]: | |
| if self.part_exists(cls=cls, part=part, names=names): | |
| for name in self.parts[cls][part]: | |
| vis[name] = True | |
| flag = False | |
| for name in self.parts[cls][part]: | |
| pid = parents[name_to_id[name]] | |
| if pid==-1: | |
| continue | |
| if not vis[names[pid]]: | |
| flag = True | |
| break | |
| if flag: # incorrect parts order and should immediately add a spring token | |
| break | |
| parts_bias[len(new_names)] = part | |
| new_names.extend(self.parts[cls][part]) | |
| parts_bias[len(new_names)] = None # add a spring token | |
| for name in names: | |
| if name not in new_names: | |
| new_names.append(name) | |
| return new_names, parts_bias |