# Lazy asset loaders and Datapath configuration.
# (Web-page scrape header "Spaces: Running on Zero" removed — it was not part of the module.)
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from random import shuffle
from typing import Dict, List, Optional
import os

import numpy as np
import requests
from numpy import ndarray

from ..rig_package.info.asset import Asset
from ..server.spec import BPY_SERVER, bytes_to_object, object_to_bytes
from .spec import ConfigSpec
# NOTE(review): @dataclass restored — `dataclass` is imported at the top of the
# file but was otherwise unused, and `Datapath.make` constructs loaders via
# `self.loader(path=..., cls=...)`, which needs the generated keyword __init__.
@dataclass
class LazyAsset(ABC):
    """store datapath and load upon requiring"""
    # filesystem path of the data to load
    path: str
    # optional class label attached to the loaded asset
    cls: Optional[str] = None

    # NOTE(review): @abstractmethod restored — `abstractmethod` is imported at
    # the top of the file but was otherwise unused.
    @abstractmethod
    def load(self) -> 'Asset':
        """Load and return the Asset at ``self.path``; implemented by subclasses."""
        raise NotImplementedError()
class BpyLazyAsset(LazyAsset):
    """Load an Asset by parsing ``self.path`` with the in-process bpy parser."""

    def load(self) -> 'Asset':
        # imported lazily: bpy is heavy and only needed when actually loading
        from ..rig_package.parser.bpy import BpyParser

        loaded = BpyParser.load(filepath=self.path)
        # tag the result with this loader's class label and source path
        loaded.cls = self.cls
        loaded.path = self.path
        return loaded
class BpyServerLazyAsset(LazyAsset):
    """workaround while bpy is working in multiple threads"""

    def load(self) -> 'Asset':
        """Fetch a parsed Asset from the external bpy server.

        Raises:
            RuntimeError: on any transport/deserialization failure, or when
                the server returns an error string instead of an Asset.
        """
        try:
            # NOTE(review): GET with a request body is unusual but appears to be
            # this server's contract — confirm before changing to POST.
            response = requests.get(f"{BPY_SERVER}/load", data=object_to_bytes(self.path))
            asset = bytes_to_object(response.content)
            if isinstance(asset, str):
                # the server signals failure by returning the error message as a str
                raise RuntimeError(f"bpy server failed: {asset}")
            assert isinstance(asset, Asset)
            asset.cls = self.cls
            asset.path = self.path
            return asset
        except RuntimeError:
            # already wrapped above — the original re-wrapped it, producing a
            # doubled "bpy server failed: bpy server failed: ..." message
            raise
        except Exception as e:
            raise RuntimeError(f"bpy server failed: {str(e)}") from e
class NpzLazyAsset(LazyAsset):
    """Load an Asset from a ``.npz`` archive written with matching keys."""

    def load(self) -> 'Asset':
        # `np.load` on a .npz keeps the archive's file handle open; the context
        # manager closes it once all arrays have been read (the original leaked it)
        with np.load(self.path, allow_pickle=True) as d:
            asset = Asset(
                vertices=d['vertices'],
                faces=d['faces'],
                mesh_names=d.get('mesh_names', None),
                joint_names=d.get('joint_names', None),
                parents=d.get('parents', None),
                lengths=d.get('lengths', None),
                matrix_world=d.get('matrix_world', None),
                matrix_local=d.get('matrix_local', None),
                armature_name=d.get('armature_name', None),
                skin=d.get('skin', None),
                cls=self.cls,
                path=self.path
            )
        # kept from the original: re-assigned after construction as well
        asset.cls = self.cls
        asset.path = self.path
        return asset
class UniRigLazyAsset(LazyAsset):
    """map unirig's data correctly"""

    def load(self) -> 'Asset':
        def bn(x):
            # np.load wraps pickled python objects in 0-d object arrays; unwrap them
            if isinstance(x, ndarray) and x.ndim==0:
                return x.item()
            return x
        # context manager closes the npz archive's file handle (the original leaked it)
        with np.load(self.path, allow_pickle=True) as d:
            parents = bn(d.get('parents', None))
            if parents is not None:
                # unirig stores the root joint's parent as None; map it to -1
                parents = [-1 if x is None else x for x in parents]
                parents = np.array(parents)
            matrix_local = bn(d.get('matrix_local', None))
            joints = bn(d.get('joints', None))
            if matrix_local is not None and matrix_local.ndim != 3 and joints is not None:
                # rebuild per-joint local transforms: identity rotation, joint translation
                matrix_local = np.zeros((joints.shape[0], 4, 4))
                matrix_local[...] = np.eye(4)
                matrix_local[:, :3, 3] = joints
            asset = Asset(
                vertices=d['vertices'],
                faces=d['faces'],
                joint_names=bn(d.get('names', None)),
                parents=parents, # type: ignore
                lengths=bn(d.get('lengths', None)),
                matrix_world=bn(d.get('matrix_world', None)),
                matrix_local=matrix_local,
                armature_name=bn(d.get('armature_name', None)),
                skin=bn(d.get('skin', None)),
                cls=self.cls,
                path=self.path
            ).change_dtype(float_dtype=np.float32, int_dtype=np.int32)
        asset.cls = self.cls
        asset.path = self.path
        return asset
# NOTE(review): @dataclass restored — `field(default_factory=...)` below and the
# keyword constructor used in `parse` only work on a dataclass.
@dataclass
class Datapath(ConfigSpec):
    """handle input data paths"""
    # all filepaths
    filepaths: List[str]
    # root to add to prefix
    input_dataset_dir: str=''
    # name of class
    cls_name: Optional[List[str]]=None
    # bias in a single class (offset of each class's first file in `filepaths`)
    cls_bias: Optional[List[int]]=None
    # num of files in a single class
    cls_length: Optional[List[int]]=None
    # how many files to return when using data sampling
    num_files: Optional[int]=None
    # use proportion data sampling
    use_prob: bool=False
    # per-class sampling weight (normalized to sum to 1 in `parse`)
    cls_weight: Optional[List[float]]=None
    # LazyAsset subclass used to wrap each resolved path
    loader: type[LazyAsset]=BpyLazyAsset
    # data file name appended to each directory path (None = use path as-is)
    data_name: Optional[str]=None
    # skip the on-disk existence check in `parse`
    ignore_check: bool=False
    #################################################################
    # other vertex groups
    vertex_groups: Dict[str, ndarray]=field(default_factory=dict)
    # sampled vertices
    sampled_vertices: Optional[ndarray]=None
    # sampled normals
    sampled_normals: Optional[ndarray]=None
    # sampled vertex groups
    sampled_vertex_groups: Optional[Dict[str, ndarray]]=None

    # NOTE(review): @classmethod restored — the first parameter is named `cls`
    # and the method never touches instance state.
    @classmethod
    def parse(cls, **kwargs) -> 'Datapath':
        """Build a Datapath from raw config kwargs.

        Accepts either:
          * ``data_path``: dict mapping class name -> list of datalist files
            (each entry a path string or a ``(path, weight)`` pair); every
            datalist file lists one data directory per line; or
          * ``filepaths``: a flat list of paths, or a dict mapping class
            name -> list of paths (each class gets weight 1.0).
        """
        # loader config value -> LazyAsset subclass
        MAP = {
            None: BpyLazyAsset,
            'bpy': BpyLazyAsset,
            'bpy_server': BpyServerLazyAsset,
            'npz': NpzLazyAsset,
            'unirig': UniRigLazyAsset,
        }
        input_dataset_dir = kwargs.get('input_dataset_dir', '')
        num_files = kwargs.get('num_files', None)
        use_prob = kwargs.get('use_prob', False)
        data_name = kwargs.get('data_name', 'raw_data.npz')
        data_path = kwargs.get('data_path', None)
        loader_cls = MAP[kwargs.get('loader', None)]
        ignore_check = kwargs.get('ignore_check', False)
        if data_path is not None:
            filepaths = []
            if isinstance(data_path, dict):
                cls_name = []
                cls_bias = []
                cls_length = []
                cls_weight = []
                for name, v in data_path.items():
                    assert isinstance(v, list), "items in the dict must be a list of data list paths"
                    for item in v:
                        if isinstance(item, str):
                            datalist_path = item
                            weight = 1.0
                        else:
                            # (path, weight) pair
                            datalist_path = item[0]
                            weight = item[1]
                        cls_name.append(name)
                        # `with` closes the datalist file (the original leaked the handle)
                        with open(datalist_path, "r") as f:
                            lines = [x.strip() for x in f.readlines()]
                        # keep only entries whose data file exists (unless ignore_check)
                        ok_lines = []
                        missing = 0
                        for line in lines:
                            if ignore_check:
                                ok_lines.append(line)
                            elif os.path.exists(os.path.join(input_dataset_dir, line, data_name)):
                                ok_lines.append(line)
                            else:
                                missing += 1
                        if missing != 0:
                            print(f"\033[31m{datalist_path}: {missing} missing files\033[0m")
                        cls_bias.append(len(filepaths))
                        cls_length.append(len(ok_lines))
                        cls_weight.append(weight)
                        filepaths.extend(ok_lines)
            else:
                raise NotImplementedError()
        else:
            _filepaths = kwargs['filepaths']
            if isinstance(_filepaths, list):
                # flat list: no class bookkeeping
                filepaths = _filepaths
                cls_name = None
                cls_bias = None
                cls_length = None
                cls_weight = None
            elif isinstance(_filepaths, dict):
                filepaths = []
                cls_name = []
                cls_bias = []
                cls_length = []
                cls_weight = []
                for k, v in _filepaths.items():
                    assert isinstance(v, list), "items in the dict must be a list of paths"
                    cls_name.append(k)
                    cls_bias.append(len(filepaths))
                    cls_length.append(len(v))
                    cls_weight.append(1.0)
                    filepaths.extend(v)
            else:
                raise NotImplementedError()
        if cls_weight is not None:
            # normalize weights into a probability distribution
            total = sum(cls_weight)
            cls_weight = [x/total for x in cls_weight]
        return Datapath(
            filepaths=filepaths,
            input_dataset_dir=input_dataset_dir,
            cls_name=cls_name,
            cls_bias=cls_bias,
            cls_length=cls_length,
            num_files=num_files,
            use_prob=use_prob,
            cls_weight=cls_weight,
            loader=loader_cls,
            data_name=data_name,
            ignore_check=ignore_check,
        )

    def make(self, path: str, cls: str|None) -> LazyAsset:
        """Wrap a resolved path into this Datapath's configured LazyAsset."""
        return self.loader(path=path, cls=cls)

    def __getitem__(self, index: int) -> LazyAsset:
        """Return a LazyAsset for `index`.

        With ``use_prob`` and weights set, `index` is ignored: a class is drawn
        according to ``cls_weight`` and files cycle through a per-class
        permutation instead.
        """
        if self.use_prob and self.cls_weight is not None:
            if self.cls_bias is None:
                raise ValueError("do not have cls_bias")
            if self.cls_length is None:
                raise ValueError("do not have cls_length")
            # lazily build one index permutation + cursor per class
            if not hasattr(self, "perms"):
                self.perms = []
                self.current_bias = []
                for i in range(len(self.cls_weight)):
                    self.perms.append([x for x in range(self.cls_length[i])])
                    self.current_bias.append(0)
            idx = np.random.choice(len(self.cls_weight), p=self.cls_weight)
            i = self.perms[idx][self.current_bias[idx]]
            self.current_bias[idx] += 1
            if self.current_bias[idx] >= self.cls_length[idx]:
                # class exhausted: reshuffle and restart
                # NOTE(review): the first pass through each class is unshuffled
                shuffle(self.perms[idx])
                self.current_bias[idx] = 0
            if self.cls_name is None:
                name = None
            else:
                name = self.cls_name[idx]
            path = os.path.join(self.input_dataset_dir, self.filepaths[i+self.cls_bias[idx]])
            if self.data_name is not None:
                path = os.path.join(path, self.data_name)
            return self.make(path=path, cls=name)
        else:
            if self.cls_name is None or self.cls_bias is None or self.cls_length is None:
                name = None
            else:
                # find which class span contains `index`
                name = None
                for i in range(len(self.cls_bias)):
                    start = self.cls_bias[i]
                    end = start + self.cls_length[i]
                    if start <= index < end:
                        name = self.cls_name[i]
                        break
            path = os.path.join(self.input_dataset_dir, self.filepaths[index])
            if self.data_name is not None:
                path = os.path.join(path, self.data_name)
            return self.make(path=path, cls=name)

    def get_data(self) -> List[LazyAsset]:
        """Materialize one LazyAsset per index (per __len__)."""
        return [self[i] for i in range(len(self))]

    def split_by_cls(self) -> Dict[str|None, 'Datapath']:
        """Group files by class name and return one Datapath per class."""
        res: Dict[str|None, Datapath] = {}
        if self.cls_name is None:
            # no class info: everything stays in one bucket
            res[None] = self
            return res
        if self.cls_bias is None:
            raise ValueError("do not have cls_bias")
        if self.cls_length is None:
            raise ValueError("do not have cls_length")
        d_filepaths = defaultdict(list)
        d_length = defaultdict(int)
        d_weight = defaultdict(list)
        for (i, cls) in enumerate(self.cls_name):
            s = slice(self.cls_bias[i], self.cls_bias[i]+self.cls_length[i])
            d_filepaths[cls].extend(self.filepaths[s].copy())
            d_length[cls] += self.cls_length[i]
            if self.cls_weight is not None:
                d_weight[cls].append(self.cls_weight[i])
        for cls in d_filepaths:
            cls_weight = None if self.cls_weight is None else d_weight[cls]
            if cls_weight is not None:
                # re-normalize the surviving weights
                # NOTE(review): a class built from several datalists keeps several
                # weights here while cls_name/cls_bias/cls_length are singletons —
                # verify downstream consumers tolerate the length mismatch
                total = sum(cls_weight)
                cls_weight = [x/total for x in cls_weight]
            res[cls] = Datapath(
                filepaths=d_filepaths[cls],
                input_dataset_dir=self.input_dataset_dir,
                cls_name=[cls],
                cls_bias=[0],
                cls_length=[len(d_filepaths[cls])],
                num_files=self.num_files,
                use_prob=self.use_prob,
                cls_weight=cls_weight,
                loader=self.loader,
                data_name=self.data_name,
            )
        return res

    def __len__(self):
        # under probabilistic sampling, length is the configured sample count
        if self.use_prob:
            assert self.num_files is not None, 'num_files is not specified'
            return self.num_files
        return len(self.filepaths)