from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from numpy import ndarray
from omegaconf import OmegaConf
from typing import Dict, List, Optional, final
from torch import Tensor

import numpy as np
import lightning.pytorch as pl
import torch

from ..data.transform import Transform
from ..rig_package.info.asset import Asset
from ..tokenizer.spec import DetokenizeOutput


@dataclass
class ModelInput():
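    """A single sample fed to the model: the source asset plus, optionally,
    its pre-computed token sequence."""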
    asset: Asset
    tokens: Optional[ndarray] = None


class ModelSpec(pl.LightningModule, ABC):
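    """Abstract base class for rigging models.

    Normalizes OmegaConf/dict configs, restores models from checkpoints,
    and builds the train/validate/predict transforms; subclasses supply
    process_fn and the Lightning hooks.
    """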
    model_config: Dict
    transform_config: Dict
    tokenizer_config: Dict | None

    def __init__(self, model_config, transform_config, tokenizer_config=None):
        super().__init__()
        # Normalize OmegaConf nodes into plain dicts so they pickle cleanly
        # into checkpoints.
        if not isinstance(model_config, dict):
            model_cfg = OmegaConf.to_container(model_config, resolve=True)
        else:
            model_cfg = model_config
        if not isinstance(transform_config, dict):
            transform_cfg = OmegaConf.to_container(transform_config, resolve=True)
        else:
            transform_cfg = transform_config
        if tokenizer_config is not None and not isinstance(tokenizer_config, dict):
            tokenizer_cfg = OmegaConf.to_container(tokenizer_config, resolve=True)
        else:
            tokenizer_cfg = tokenizer_config
        self.model_config = model_cfg  # type: ignore
        self.transform_config = transform_cfg  # type: ignore
        self.tokenizer_config = tokenizer_cfg  # type: ignore
        # Store the configs under named keys so load_from_system_checkpoint
        # can read them back from ckpt['hyper_parameters'].
        self.save_hyperparameters({
            'model_config': model_cfg,
            'transform_config': transform_cfg,
            'tokenizer_config': tokenizer_cfg,
        })

    @final
    def _process_fn(self, batch: List[ModelInput]) -> List[Dict]:
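        """Collate a batch through process_fn, then post-process it.

        Outside training, checks that every array/tensor key has the same
        shape across the batch and stashes the raw ModelInput under
        non['model_input']; during training, drops the asset to save memory.
        """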
        n_batch = self.process_fn(batch)
        if self._trainer is None or not self.trainer.training:
            # Sanity check: every array/tensor entry must share a shape
            # across the batch.
            for k in n_batch[0].keys():
                if not isinstance(n_batch[0][k], (ndarray, Tensor)):
                    continue
                s = n_batch[0][k].shape
                for i in range(1, len(n_batch)):
                    assert n_batch[i][k].shape == s, f"{k} has different shape in batch"
            # Keep a copy of the raw model input for inspection/prediction.
            for i, b in enumerate(batch):
                non = n_batch[i].get('non', {})
                non['model_input'] = deepcopy(b)
                n_batch[i]['non'] = non
        else:
            # Training: drop the asset to reduce memory pressure.
            for b in batch:
                del b.asset
        return n_batch

    @abstractmethod
    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        """
        Fetch data from the dataloader and turn it into Tensor objects.
        """
        raise NotImplementedError()

    def compile_model(self):
        """
        Compile the model. Do this before training and after loading state dicts.
        """
        pass

    @classmethod
    def load_from_system_checkpoint(cls, checkpoint_path: str, strict: bool=True, **kwargs):
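        """Instantiate the model from a Lightning checkpoint file.

        Configs default to the ones stored under the checkpoint's
        hyper_parameters and can be overridden via kwargs.
        """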
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        state_dict = ckpt['state_dict']
        model_config = kwargs.get('model_config', None)
        transform_config = kwargs.get('transform_config', None)
        tokenizer_config = kwargs.get('tokenizer_config', None)
        if model_config is None:
            model_config = ckpt['hyper_parameters']['model_config']
        if transform_config is None:
            transform_config = ckpt['hyper_parameters']['transform_config']
        if tokenizer_config is None:
            tokenizer_config = ckpt['hyper_parameters']['tokenizer_config']
        # Strip prefixes introduced by torch.compile ("_orig_mod.") and by
        # wrapping the network in a "model." attribute.
        new_state_dict = {}
        for k, v in state_dict.items():
            k = k.replace("_orig_mod.", "")
            if k.startswith("model."):
                k = k[len("model."):]
            new_state_dict[k] = v
        model = cls(
            model_config=model_config,
            transform_config=transform_config,
            tokenizer_config=tokenizer_config,
        )
        missing, unexpected = model.load_state_dict(new_state_dict, strict=strict)
        if missing or unexpected:
            print(f"[Warning] Missing keys: {missing}")
            print(f"[Warning] Unexpected keys: {unexpected}")
        model.on_load_checkpoint(ckpt)
        return model

    def get_train_transform(self) -> Transform | None:
        cfg = self.transform_config.get('train_transform', None)
        if cfg is None:
            return None
        return Transform.parse(**cfg)

    def get_validate_transform(self) -> Transform | None:
        cfg = self.transform_config.get('validate_transform', None)
        if cfg is None:
            return None
        return Transform.parse(**cfg)

    def get_predict_transform(self) -> Transform | None:
        cfg = self.transform_config.get('predict_transform', None)
        if cfg is None:
            return None
        return Transform.parse(**cfg)

    def predict_step(self, batch: Dict, no_cls: bool=False, skeleton_tokens=None) -> Dict:
        raise NotImplementedError()


@dataclass
class VaeInput():
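    """Batched input for the skinning VAE: per-joint dense surface samples
    plus uniform surface samples shared across joints.

    Shapes: B = batch size, J = per-sample joint count, N = number of
    uniform samples, skin_samples = dense samples per joint.
    """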
    dense_cond: List[Tensor]        # [(J, skin_samples, 6)]
    dense_skin: List[Tensor]        # [(J, skin_samples)]
    dense_indices: List[List[int]]  # [List[J]], corresponding indices of gt
    uniform_cond: Tensor            # (B, N, 6)
    uniform_skin: List[Tensor]      # [(N, J)]

    @property
    def B(self) -> int:
        return self.uniform_cond.shape[0]

    @property
    def max_J(self) -> int:
        return max([len(s) for s in self.dense_indices])

    def get_len(self, i) -> int:
        return len(self.dense_indices[i])

    def _clamp_j(self, i: int, j: int) -> int:
        return min(j, len(self.dense_indices[i]) - 1)

    def get_dense_cond(self, j: int) -> Tensor:
        """return (B, skin_samples, 6)"""
        return torch.stack([self.dense_cond[i][self._clamp_j(i=i, j=j)] for i in range(self.B)])

    def get_dense_skin(self, j: int) -> Tensor:
        """return (B, skin_samples)"""
        return torch.stack([self.dense_skin[i][self._clamp_j(i=i, j=j)] for i in range(self.B)])

    def get_full_cond(self, j: int) -> Tensor:
        """return (B, N+skin_samples, 6)"""
        return torch.cat([self.uniform_cond, self.get_dense_cond(j=j)], dim=1)

    def get_uniform_skin(self, j: int) -> Tensor:
        """return (B, N)"""
        return torch.stack([self.uniform_skin[i][:, self._clamp_j(i=i, j=j)] for i in range(self.B)])

    def get_full_skin(self, j: int) -> Tensor:
        """return (B, N+skin_samples)"""
        return torch.cat([self.get_uniform_skin(j=j), self.get_dense_skin(j=j)], dim=1)

    def get_flatten_uniform_cond(self) -> Tensor:
        """return (sum_J, N, 6)"""
        return self.uniform_cond[self.get_flatten_indices()]

    def get_flatten_dense_cond(self) -> Tensor:
        """return (sum_J, skin_samples, 6)"""
        return torch.cat(self.dense_cond, dim=0)

    def get_flatten_dense_skin(self) -> Tensor:
        """return (sum_J, skin_samples)"""
        return torch.cat(self.dense_skin, dim=0)

    def get_flatten_full_skin(self) -> Tensor:
        """return (sum_J, N+skin_samples)"""
        # (sum_J, N)
        s = torch.cat(self.uniform_skin, dim=-1).permute(1, 0)
        return torch.cat([s, self.get_flatten_dense_skin()], dim=1)

    def get_flatten_full_cond(self) -> Tensor:
        """return (sum_J, N+skin_samples, 6)"""
        return torch.cat([self.get_flatten_uniform_cond(), self.get_flatten_dense_cond()], dim=1)

    def get_flatten_indices(self) -> List[int]:
        """return (sum_J)"""
        return [i for i in range(self.B) for _ in range(self.get_len(i=i))]

    def true_j(self, i: int, j: int) -> int:
        """return the (clamped) corresponding index in the skeleton"""
        return self.dense_indices[i][self._clamp_j(i=i, j=j)]


@dataclass
class TokenRigResult():
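    """Output bundle of a token-based rigging pass; every field is optional
    so partial results can be carried around."""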
    cond: Optional[Tensor] = None          # [vertices, normals]
    cond_latents: Optional[Tensor] = None  # (len, dim)
    input_ids: Optional[Tensor] = None     # (l,)
    output_ids: Optional[Tensor] = None    # (l,)
    skin_pred: Optional[Tensor] = None     # (N, J)
    detokenize_output: Optional[DetokenizeOutput] = None
    asset: Optional[Asset] = None


@dataclass
class BoneVaeInput():
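    """Same layout as VaeInput, extended with per-joint bone features
    (6 values per joint)."""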
    dense_cond: List[Tensor]        # [(J, skin_samples, 6)]
    dense_skin: List[Tensor]        # [(J, skin_samples)]
    dense_indices: List[List[int]]  # [List[J]], corresponding indices of gt
    bones: List[Tensor]             # [(J, 6)]
    uniform_cond: Tensor            # (B, N, 6)
    uniform_skin: List[Tensor]      # [(N, J)]

    @property
    def total_samples(self) -> int:
        return self.dense_cond[0].shape[1] + self.uniform_cond.shape[1]

    @property
    def B(self) -> int:
        return self.uniform_cond.shape[0]

    @property
    def max_J(self) -> int:
        return max([len(s) for s in self.dense_indices])

    def get_len(self, i) -> int:
        return len(self.dense_indices[i])

    def _clamp_j(self, i: int, j: int) -> int:
        return min(j, len(self.dense_indices[i]) - 1)

    def get_dense_cond(self, j: int) -> Tensor:
        """return (B, skin_samples, 6)"""
        return torch.stack([self.dense_cond[i][self._clamp_j(i=i, j=j)] for i in range(self.B)])

    def get_dense_skin(self, j: int) -> Tensor:
        """return (B, skin_samples)"""
        return torch.stack([self.dense_skin[i][self._clamp_j(i=i, j=j)] for i in range(self.B)])

    def get_full_cond(self, j: int) -> Tensor:
        """return (B, N+skin_samples, 6)"""
        return torch.cat([self.uniform_cond, self.get_dense_cond(j=j)], dim=1)

    def get_uniform_skin(self, j: int) -> Tensor:
        """return (B, N)"""
        return torch.stack([self.uniform_skin[i][:, self._clamp_j(i=i, j=j)] for i in range(self.B)])

    def get_full_skin(self, j: int) -> Tensor:
        """return (B, N+skin_samples)"""
        return torch.cat([self.get_uniform_skin(j=j), self.get_dense_skin(j=j)], dim=1)

    def get_bones(self, j: int) -> Tensor:
        """return (B, 6)"""
        return torch.stack([self.bones[i][self._clamp_j(i=i, j=j)] for i in range(self.B)])

    def get_flatten_bones(self) -> Tensor:
        """return (sum_J, 6)"""
        return torch.cat([self.bones[i] for i in range(self.B)])

    def get_flatten_uniform_cond(self) -> Tensor:
        """return (sum_J, N, 6)"""
        return self.uniform_cond[self.get_flatten_indices()]

    def get_flatten_dense_cond(self) -> Tensor:
        """return (sum_J, skin_samples, 6)"""
        return torch.cat(self.dense_cond, dim=0)

    def get_flatten_dense_skin(self) -> Tensor:
        """return (sum_J, skin_samples)"""
        return torch.cat(self.dense_skin, dim=0)

    def get_flatten_full_skin(self) -> Tensor:
        """return (sum_J, N+skin_samples)"""
        # (sum_J, N)
        s = torch.cat(self.uniform_skin, dim=-1).permute(1, 0)
        return torch.cat([s, self.get_flatten_dense_skin()], dim=1)

    def get_flatten_full_cond(self) -> Tensor:
        """return (sum_J, N+skin_samples, 6)"""
        return torch.cat([self.get_flatten_uniform_cond(), self.get_flatten_dense_cond()], dim=1)

    def get_flatten_indices(self) -> List[int]:
        """return (sum_J)"""
        return [i for i in range(self.B) for _ in range(self.get_len(i=i))]

    def true_j(self, i: int, j: int) -> int:
        """return the (clamped) corresponding index in the skeleton"""
        return self.dense_indices[i][self._clamp_j(i=i, j=j)]
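

# Minimal shape sketch (illustrative, not part of the library API): builds a
# tiny VaeInput from random tensors and checks the flattened shapes promised
# by the docstrings above. The sizes (B=2, N=8, skin_samples=4, J=3 and 2)
# are arbitrary assumptions. Run with `python -m <package>.spec` so the
# relative imports above resolve.
if __name__ == "__main__":
    B, N, S = 2, 8, 4
    vae_input = VaeInput(
        dense_cond=[torch.randn(3, S, 6), torch.randn(2, S, 6)],  # J=3, J=2
        dense_skin=[torch.randn(3, S), torch.randn(2, S)],
        dense_indices=[[0, 1, 2], [0, 1]],
        uniform_cond=torch.randn(B, N, 6),
        uniform_skin=[torch.randn(N, 3), torch.randn(N, 2)],
    )
    assert vae_input.B == B and vae_input.max_J == 3
    # sum_J = 3 + 2 = 5
    assert vae_input.get_flatten_full_cond().shape == (5, N + S, 6)
    assert vae_input.get_flatten_full_skin().shape == (5, N + S)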