SkinTokens / src /model /spec.py
pookiefoof's picture
Public release: SkinTokens · TokenRig demo
9d7cf7f
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from numpy import ndarray
from omegaconf import OmegaConf
from typing import Dict, List, Optional, final
from torch import Tensor
import numpy as np
import lightning.pytorch as pl
import torch
from ..data.transform import Transform
from ..rig_package.info.asset import Asset
from ..tokenizer.spec import DetokenizeOutput
@dataclass
class ModelInput:
    """A single raw sample, as received in the list given to ``process_fn``."""

    # The asset belonging to this sample.
    asset: Asset
    # Optional array of tokens for the sample; stays ``None`` when not supplied.
    tokens: Optional[ndarray] = None
class ModelSpec(pl.LightningModule, ABC):
    """Abstract Lightning module that keeps its own configuration dictionaries.

    The constructor normalises the model / transform / tokenizer configs to
    plain containers, stores them on the instance and registers them as
    hyper-parameters. Subclasses implement ``process_fn`` (and usually
    ``predict_step``).
    """

    model_config: Dict
    transform_config: Dict
    tokenizer_config: Dict|None

    @abstractmethod
    def __init__(self, model_config, transform_config, tokenizer_config=None):
        """Store the configs as plain containers and register them as hparams.

        Args:
            model_config: ``dict`` or an OmegaConf config object.
            transform_config: ``dict`` or an OmegaConf config object.
            tokenizer_config: ``dict``, an OmegaConf config object, or ``None``.

        The method is abstract, so concrete subclasses must define their own
        ``__init__`` (which is expected to call this one via ``super()``).
        """
        super().__init__()
        # Anything that is not already a dict is converted through OmegaConf,
        # with interpolations resolved.
        model_cfg = (
            model_config
            if isinstance(model_config, dict)
            else OmegaConf.to_container(model_config, resolve=True)
        )
        transform_cfg = (
            transform_config
            if isinstance(transform_config, dict)
            else OmegaConf.to_container(transform_config, resolve=True)
        )
        # The tokenizer config is optional: ``None`` is passed through untouched.
        if tokenizer_config is None or isinstance(tokenizer_config, dict):
            tokenizer_cfg = tokenizer_config
        else:
            tokenizer_cfg = OmegaConf.to_container(tokenizer_config, resolve=True)

        self.model_config = model_cfg # type: ignore
        self.transform_config = transform_cfg # type: ignore
        self.tokenizer_config = tokenizer_cfg # type: ignore

        # Kept as direct calls inside ``__init__`` and in the original order.
        # NOTE(review): the three dicts are registered one after another; if
        # Lightning merges them into a single flat ``hparams`` mapping, keys
        # with the same name would overwrite each other - confirm.
        self.save_hyperparameters(model_cfg)
        self.save_hyperparameters(transform_cfg)
        self.save_hyperparameters(tokenizer_cfg)

    @final
    def _process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        """Run ``process_fn`` and post-process the batch depending on the mode.

        While a trainer is attached and training, the ``asset`` attribute of
        every input is deleted after processing. Otherwise array/tensor entries
        are checked for a consistent shape across the batch, and a deep copy of
        each input is stored under ``n_batch[i]['non']['model_input']``.

        Args:
            batch: the raw samples.

        Returns:
            The list of dictionaries produced by ``process_fn``.
        """
        n_batch = self.process_fn(batch)

        if self._trainer is not None and self.trainer.training:
            # Training path: the asset is dropped from each input
            # (presumably it is no longer needed once processed).
            for b in batch:
                del b.asset
            return n_batch

        # Guard against an empty result: there is no first element to use as
        # the shape reference, and nothing to validate.
        if n_batch:
            reference = n_batch[0]
            for k, value in reference.items():
                if not isinstance(value, (ndarray, Tensor)):
                    continue
                s = value.shape
                for other in n_batch[1:]:
                    assert other[k].shape == s, f"{k} has different shape in batch"

        # Attach a copy of the raw input, reusing an existing 'non' dict if any.
        for i, b in enumerate(batch):
            non = n_batch[i].get('non', {})
            non['model_input'] = deepcopy(b)
            n_batch[i]['non'] = non
        return n_batch

    @abstractmethod
    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        """
        Fetch data from dataloader and turn it into Tensor objects.
        """
        raise NotImplementedError()

    def compile_model(self):
        """
        Compile the model. Do this before training and after loading state dicts.

        The base implementation does nothing.
        """

    @classmethod
    def load_from_system_checkpoint(cls, checkpoint_path: str, strict: bool=True, **kwargs):
        """Build an instance from a checkpoint file and load its weights.

        Args:
            checkpoint_path: path handed to ``torch.load``.
            strict: forwarded to ``load_state_dict``.
            **kwargs: may contain ``model_config``, ``transform_config`` and
                ``tokenizer_config``. Any of them that is missing or ``None``
                is read from ``ckpt['hyper_parameters']`` under the same name.

        Returns:
            The constructed model, after ``on_load_checkpoint(ckpt)`` was called.

        Raises:
            KeyError: if the checkpoint lacks ``state_dict`` or a config entry
                that was not supplied through ``kwargs``.
        """
        # NOTE(security): ``weights_only=False`` lets the checkpoint unpickle
        # arbitrary Python objects. Only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        state_dict = ckpt['state_dict']

        configs = {}
        for name in ('model_config', 'transform_config', 'tokenizer_config'):
            value = kwargs.get(name, None)
            if value is None:
                value = ckpt['hyper_parameters'][name]
            configs[name] = value

        # Normalise parameter names: remove every "_orig_mod." fragment, then a
        # leading "model." prefix if present.
        new_state_dict = {}
        for k, v in state_dict.items():
            k = k.replace("_orig_mod.", "").removeprefix("model.")
            new_state_dict[k] = v

        model = cls(**configs)
        missing, unexpected = model.load_state_dict(new_state_dict, strict=strict)
        if missing or unexpected:
            print(f"[Warning] Missing keys: {missing}")
            print(f"[Warning] Unexpected keys: {unexpected}")
        model.on_load_checkpoint(ckpt)
        return model

    def _transform_from_config(self, key: str) -> Transform|None:
        """Parse ``transform_config[key]`` into a ``Transform``; ``None`` if unset."""
        cfg = self.transform_config.get(key, None)
        if cfg is None:
            return None
        return Transform.parse(**cfg)

    def get_train_transform(self) -> Transform|None:
        """Return the transform built from the ``train_transform`` entry, if any."""
        return self._transform_from_config('train_transform')

    def get_validate_transform(self) -> Transform|None:
        """Return the transform built from the ``validate_transform`` entry, if any."""
        return self._transform_from_config('validate_transform')

    def get_predict_transform(self) -> Transform|None:
        """Return the transform built from the ``predict_transform`` entry, if any."""
        return self._transform_from_config('predict_transform')

    def predict_step(self, batch: Dict, no_cls: bool=False, skeleton_tokens=None) -> Dict:
        """Prediction hook; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError()
@dataclass
class VaeInput:
    """Batch of skinning samples: per-joint dense samples plus uniform samples.

    The per-item lists are indexed by batch position; joint positions ``j``
    passed to the accessors are clamped to each item's last valid position.
    Shapes quoted below follow the field comments.
    """

    dense_cond: List[Tensor]  # per item: (J, skin_samples, 6)
    dense_skin: List[Tensor]  # per item: (J, skin_samples)
    dense_indices: List[List[int]]  # per item: J indices into the ground truth
    uniform_cond: Tensor  # (B, N, 6)
    uniform_skin: List[Tensor]  # per item: (N, J)

    @property
    def B(self):
        """Batch size, taken from the leading axis of ``uniform_cond``."""
        return self.uniform_cond.size(0)

    @property
    def max_J(self):
        """Largest number of entries in any item's ``dense_indices``."""
        return max(map(len, self.dense_indices))

    def get_len(self, i) -> int:
        """Number of dense indices held by batch item ``i``."""
        item_indices = self.dense_indices[i]
        return len(item_indices)

    def _clamp_j(self, i: int, j: int) -> int:
        """Limit ``j`` to the last valid position of batch item ``i``."""
        last = len(self.dense_indices[i]) - 1
        return last if last < j else j

    def _clamped(self, j: int) -> List[int]:
        """Clamped position of ``j`` for every batch item, in batch order."""
        return [self._clamp_j(i=i, j=j) for i in range(self.B)]

    def get_dense_cond(self, j: int) -> Tensor:
        """Dense conditions at position ``j`` for all items: (B, skin_samples, 6)."""
        picked = [self.dense_cond[i][c] for i, c in enumerate(self._clamped(j))]
        return torch.stack(picked)

    def get_dense_skin(self, j: int) -> Tensor:
        """Dense skin values at position ``j`` for all items: (B, skin_samples)."""
        picked = [self.dense_skin[i][c] for i, c in enumerate(self._clamped(j))]
        return torch.stack(picked)

    def get_full_cond(self, j: int) -> Tensor:
        """Uniform then dense conditions, joined on axis 1: (B, N+skin_samples, 6)."""
        dense = self.get_dense_cond(j=j)
        return torch.cat((self.uniform_cond, dense), dim=1)

    def get_uniform_skin(self, j: int) -> Tensor:
        """Column ``j`` (clamped) of every item's uniform skin: (B, N)."""
        columns = [self.uniform_skin[i][:, c] for i, c in enumerate(self._clamped(j))]
        return torch.stack(columns)

    def get_full_skin(self, j: int) -> Tensor:
        """Uniform then dense skin values, joined on axis 1: (B, N+skin_samples)."""
        uniform = self.get_uniform_skin(j=j)
        dense = self.get_dense_skin(j=j)
        return torch.cat((uniform, dense), dim=1)

    def get_flatten_uniform_cond(self) -> Tensor:
        """Uniform conditions repeated once per joint of their item: (sum_J, N, 6)."""
        owners = self.get_flatten_indices()
        return self.uniform_cond[owners]

    def get_flatten_dense_cond(self) -> Tensor:
        """All items' dense conditions stacked along axis 0: (sum_J, skin_samples, 6)."""
        return torch.cat(self.dense_cond, dim=0)

    def get_flatten_dense_skin(self) -> Tensor:
        """All items' dense skin values stacked along axis 0: (sum_J, skin_samples)."""
        return torch.cat(self.dense_skin, dim=0)

    def get_flatten_full_skin(self) -> Tensor:
        """Uniform then dense skin values per joint: (sum_J, N+skin_samples)."""
        # Join the (N, J) blocks on the joint axis, then flip to (sum_J, N).
        per_vertex = torch.cat(self.uniform_skin, dim=-1)
        per_joint = per_vertex.permute(1, 0)
        return torch.cat((per_joint, self.get_flatten_dense_skin()), dim=1)

    def get_flatten_full_cond(self) -> Tensor:
        """Uniform then dense conditions per joint: (sum_J, N+skin_samples, 6)."""
        uniform = self.get_flatten_uniform_cond()
        dense = self.get_flatten_dense_cond()
        return torch.cat((uniform, dense), dim=1)

    def get_flatten_indices(self) -> List[int]:
        """Batch position of every flattened joint; length is sum_J."""
        owners: List[int] = []
        for i in range(self.B):
            owners.extend([i] * self.get_len(i=i))
        return owners

    def true_j(self, i: int, j: int) -> int:
        """Skeleton index stored at the (clamped) position ``j`` of item ``i``."""
        position = self._clamp_j(i=i, j=j)
        return self.dense_indices[i][position]
@dataclass
class TokenRigResult:
    """Container of optional result fields; every field defaults to ``None``."""

    # [vertices, normals]
    cond: Optional[Tensor] = None
    # (len, dim)
    cond_latents: Optional[Tensor] = None
    # (l,)
    input_ids: Optional[Tensor] = None
    # (l,)
    output_ids: Optional[Tensor] = None
    # (N, J)
    skin_pred: Optional[Tensor] = None
    detokenize_output: Optional[DetokenizeOutput] = None
    asset: Optional[Asset] = None
@dataclass
class BoneVaeInput:
    """Batch of skinning samples with per-joint bone features.

    Holds per-joint dense samples, per-joint bone tensors and uniform samples.
    Joint positions ``j`` passed to the accessors are clamped to each item's
    last valid position. Shapes quoted below follow the field comments.
    """

    dense_cond: List[Tensor]  # per item: (J, skin_samples, 6)
    dense_skin: List[Tensor]  # per item: (J, skin_samples)
    dense_indices: List[List[int]]  # per item: J indices into the ground truth
    bones: List[Tensor]  # per item: (J, 6)
    uniform_cond: Tensor  # (B, N, 6)
    uniform_skin: List[Tensor]  # per item: (N, J)

    @property
    def total_samples(self) -> int:
        """Axis-1 length of the first dense condition plus that of ``uniform_cond``."""
        dense_count = self.dense_cond[0].shape[1]
        uniform_count = self.uniform_cond.shape[1]
        return dense_count + uniform_count

    @property
    def B(self) -> int:
        """Batch size, taken from the leading axis of ``uniform_cond``."""
        return self.uniform_cond.size(0)

    @property
    def max_J(self) -> int:
        """Largest number of entries in any item's ``dense_indices``."""
        return max(map(len, self.dense_indices))

    def get_len(self, i) -> int:
        """Number of dense indices held by batch item ``i``."""
        item_indices = self.dense_indices[i]
        return len(item_indices)

    def _clamp_j(self, i: int, j: int) -> int:
        """Limit ``j`` to the last valid position of batch item ``i``."""
        last = len(self.dense_indices[i]) - 1
        return last if last < j else j

    def _clamped(self, j: int) -> List[int]:
        """Clamped position of ``j`` for every batch item, in batch order."""
        return [self._clamp_j(i=i, j=j) for i in range(self.B)]

    def get_dense_cond(self, j: int) -> Tensor:
        """Dense conditions at position ``j`` for all items: (B, skin_samples, 6)."""
        picked = [self.dense_cond[i][c] for i, c in enumerate(self._clamped(j))]
        return torch.stack(picked)

    def get_dense_skin(self, j: int) -> Tensor:
        """Dense skin values at position ``j`` for all items: (B, skin_samples)."""
        picked = [self.dense_skin[i][c] for i, c in enumerate(self._clamped(j))]
        return torch.stack(picked)

    def get_full_cond(self, j: int) -> Tensor:
        """Uniform then dense conditions, joined on axis 1: (B, N+skin_samples, 6)."""
        dense = self.get_dense_cond(j=j)
        return torch.cat((self.uniform_cond, dense), dim=1)

    def get_uniform_skin(self, j: int) -> Tensor:
        """Column ``j`` (clamped) of every item's uniform skin: (B, N)."""
        columns = [self.uniform_skin[i][:, c] for i, c in enumerate(self._clamped(j))]
        return torch.stack(columns)

    def get_full_skin(self, j: int) -> Tensor:
        """Uniform then dense skin values, joined on axis 1: (B, N+skin_samples)."""
        uniform = self.get_uniform_skin(j=j)
        dense = self.get_dense_skin(j=j)
        return torch.cat((uniform, dense), dim=1)

    def get_bones(self, j: int) -> Tensor:
        """Bone row at position ``j`` (clamped) for all items: (B, D).

        NOTE(review): the ``bones`` field comment gives D = 6, while the
        previous docstring of this method said 3 - confirm the actual width.
        """
        rows = [self.bones[i][c] for i, c in enumerate(self._clamped(j))]
        return torch.stack(rows)

    def get_flatten_bones(self) -> Tensor:
        """Bone rows of the first ``B`` items joined along axis 0: (sum_J, D).

        NOTE(review): the ``bones`` field comment gives D = 6, while the
        previous docstring of this method said 3 - confirm the actual width.
        """
        per_item = [self.bones[item] for item in range(self.B)]
        return torch.cat(per_item)

    def get_flatten_uniform_cond(self) -> Tensor:
        """Uniform conditions repeated once per joint of their item: (sum_J, N, 6)."""
        owners = self.get_flatten_indices()
        return self.uniform_cond[owners]

    def get_flatten_dense_cond(self) -> Tensor:
        """All items' dense conditions stacked along axis 0: (sum_J, skin_samples, 6)."""
        return torch.cat(self.dense_cond, dim=0)

    def get_flatten_dense_skin(self) -> Tensor:
        """All items' dense skin values stacked along axis 0: (sum_J, skin_samples)."""
        return torch.cat(self.dense_skin, dim=0)

    def get_flatten_full_skin(self) -> Tensor:
        """Uniform then dense skin values per joint: (sum_J, N+skin_samples)."""
        # Join the (N, J) blocks on the joint axis, then flip to (sum_J, N).
        per_vertex = torch.cat(self.uniform_skin, dim=-1)
        per_joint = per_vertex.permute(1, 0)
        return torch.cat((per_joint, self.get_flatten_dense_skin()), dim=1)

    def get_flatten_full_cond(self) -> Tensor:
        """Uniform then dense conditions per joint: (sum_J, N+skin_samples, 6)."""
        uniform = self.get_flatten_uniform_cond()
        dense = self.get_flatten_dense_cond()
        return torch.cat((uniform, dense), dim=1)

    def get_flatten_indices(self) -> List[int]:
        """Batch position of every flattened joint; length is sum_J."""
        owners: List[int] = []
        for i in range(self.B):
            owners.extend([i] * self.get_len(i=i))
        return owners

    def true_j(self, i: int, j: int) -> int:
        """Skeleton index stored at the (clamped) position ``j`` of item ``i``."""
        position = self._clamp_j(i=i, j=j)
        return self.dense_indices[i][position]