from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor, FloatTensor
from torch.nn.functional import pad
from transformers import AutoModelForCausalLM, AutoConfig, LogitsProcessor, LogitsProcessorList  # type: ignore

from .skin_vae_model import SkinVAEModel
from .skin_vae.autoencoders import SkinFSQCVAEModel
from .spec import ModelSpec, ModelInput, VaeInput, TokenRigResult
from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder
from ..rig_package.info.asset import Asset
from ..tokenizer.spec import Tokenizer, DetokenizeOutput
from ..tokenizer.parse import get_tokenizer

# Local checkpoint directory; used instead of the hub copy when present.
LLM_LOCAL_DIR = Path("models/Qwen3-0.6B")

try:
    # flash-attn 3 exposes flash_attn_func at the top level and returns (out, lse).
    from flash_attn_interface import flash_attn_func  # type: ignore
except Exception:
    # Fall back to flash-attn 2, wrapping it to match the 2-tuple return convention.
    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func

    def flash_attn_func(*args, **kwargs):
        res = _flash_attn_func(*args, **kwargs)
        return res, None
class VocabSwitchingLogitsProcessor(LogitsProcessor):
    """Constrains decoding to the grammar: skeleton tokens first, then exactly
    J * tokens_per_skin skin tokens after the switch token, then EOS.

    Assumes all skin token ids are greater than `switch_token_id`.
    """

    def __init__(self, tokenizer: Tokenizer, switch_token_id, eos_token_id, tokens_per_skin, init):
        self.tokenizer = tokenizer
        self.switch_token_id = switch_token_id
        self.eos_token_id = eos_token_id
        self.tokens_per_skin = tokens_per_skin
        self.init = init  # prompt tokens that generate() does not echo back

    def __call__(self, input_ids: Tensor, scores: FloatTensor) -> FloatTensor:
        # input_ids shape: (batch_size, seq_len)
        for batch_idx, sequence in enumerate(input_ids):
            mask = torch.full_like(scores[batch_idx], float('-inf'))
            sequence = torch.cat([self.init, sequence])
            length = len(sequence)
            if self.switch_token_id in sequence:
                # Skeleton phase is over: only skin tokens (ids above the
                # switch token) are allowed until the expected count is reached.
                mask[self.switch_token_id:] = 0
                where = torch.where(sequence == self.switch_token_id)[0][:1]
                J = self.tokenizer.bones_in_sequence(ids=sequence.detach().cpu().numpy())
                if (length - where) == J * self.tokens_per_skin:
                    # All skin tokens emitted: force EOS.
                    mask[:] = float('-inf')
                    mask[self.eos_token_id] = 0
                else:
                    mask[self.eos_token_id] = float('-inf')
            else:
                # Still in the skeleton phase: ask the tokenizer which ids are valid next.
                tokens = self.tokenizer.next_posible_token(ids=sequence.detach().cpu().numpy())
                mask[tokens] = 0
            scores[batch_idx] = scores[batch_idx] + mask
        return scores
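
# Illustrative sketch (not executed; names mirror definitions in this file): the
# processor is wrapped in a LogitsProcessorList by get_logits_processor() below
# and handed to transformers' generate(), which calls it once per decoding step
# to mask invalid next tokens. `eos_id` and `prompt_ids` are hypothetical names.
#
#   processors = get_logits_processor(
#       tokenizer=tokenizer,
#       eos=eos_id,                      # combined-vocab EOS id
#       tokens_per_skin=tokens_per_skin,
#       start_tokens=prompt_ids,         # (prompt_len,) tensor of prompt ids
#   )
#   out = model.generate(inputs_embeds=embeds, logits_processor=processors)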
class TokenRig(ModelSpec):
    def __init__(self, model_config, transform_config, tokenizer_config=None):
        assert tokenizer_config is not None
        super().__init__(model_config=model_config, transform_config=transform_config, tokenizer_config=tokenizer_config)
        cfg = self.model_config
        self.tokens_per_skin: int = cfg['tokens_per_skin']
        self.tokens_skin_cond: int = cfg['tokens_skin_cond']
        self.use_rope: bool = cfg.get('use_rope', True)
        self.encode_repeat: int = cfg.get('encode_repeat', 4)
        self.skin_warmup_start_epoch: int = cfg.get('skin_warmup_start_epoch', 0)
        self.skin_warmup_end_epoch: int = cfg.get('skin_warmup_end_epoch', -1)

        # Frozen, pretrained skin VAE.
        self.vae = SkinVAEModel.load_from_system_checkpoint(cfg['pretrained_vae']).to(torch.bfloat16)
        for param in self.vae.parameters():
            param.requires_grad_(False)
        self.vae.eval()

        self.mesh_encoder = get_mesh_encoder(**cfg['mesh_encoder'])
        assert isinstance(self.mesh_encoder, (MAP_MESH_ENCODER.michelangelo, MAP_MESH_ENCODER.michelangelo_encoder))
        self.mesh_encoder = self.mesh_encoder.to(torch.bfloat16)

        self.tokenizer: Tokenizer = get_tokenizer(**tokenizer_config)
        # Combined vocabulary: tokenizer codebook + FSQ VAE codebook + one EOS token.
        self.vocab_size = self.tokenizer.vocab_size + self.vae.vocab_size + 1
        self.eos = self.vocab_size - 1

        _d = cfg['llm'].copy()
        self.hidden_size = _d['hidden_size']
        _d['vocab_size'] = self.vocab_size
        if LLM_LOCAL_DIR.exists():
            _d['pretrained_model_name_or_path'] = str(LLM_LOCAL_DIR)
        llm_config = AutoConfig.from_pretrained(**_d)
        llm_config.torch_dtype = torch.bfloat16
        llm_config.pre_norm = True
        self.llm_config = llm_config
        self.transformer = AutoModelForCausalLM.from_config(config=llm_config, attn_implementation="flash_attention_2").to(torch.bfloat16)

        # Projects mesh-encoder latents into the LLM embedding space.
        self.output_proj = nn.Sequential(
            nn.Linear(self.mesh_encoder.width, self.hidden_size),
            nn.RMSNorm(self.hidden_size),
        ).to(torch.bfloat16)

        init_scale = cfg.get('init_scale', None)
        if init_scale is not None:
            self.initialize_weights(init_scale)
    def compile_model(self):
        self.vae.compile_model()
        self.transformer = torch.compile(self.transformer, dynamic=False)
        self.mesh_encoder = torch.compile(self.mesh_encoder, dynamic=False)

    def initialize_weights(self, s: float):
        def init_linear(l, stddev):
            nn.init.normal_(l.weight, std=stddev)
            if l.bias is not None:
                nn.init.constant_(l.bias, 0.0)

        init_scale = s * math.sqrt(1.0 / self.mesh_encoder.width)
        for m in self.mesh_encoder.modules():
            if isinstance(m, nn.Linear):
                init_linear(m, stddev=init_scale)
        init_scale = s * math.sqrt(1.0 / self.hidden_size)
        for m in self.output_proj.modules():
            if isinstance(m, nn.Linear):
                init_linear(m, stddev=init_scale)

    def get_skin_warmup_rate(self, steps_per_epoch: int) -> float:
        """Cosine ramp from 0 to 1 between skin_warmup_start_epoch and skin_warmup_end_epoch."""
        if self.current_epoch < self.skin_warmup_start_epoch:
            return 0.
        if self.current_epoch > self.skin_warmup_end_epoch:
            return 1.
        start_steps = self.skin_warmup_start_epoch * steps_per_epoch
        end_steps = (self.skin_warmup_end_epoch + 1) * steps_per_epoch
        rate = (self.global_step - start_steps) / (end_steps - start_steps)
        return min(max((1.0 - math.cos(math.pi * rate)) / 2, 0), 1)
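
    # Worked example (assumed config: warmup epochs [0, 3], 100 steps per epoch):
    # end_steps = 400, so rate = global_step / 400 and the returned weight is
    # (1 - cos(pi * rate)) / 2 -> 0.0 at step 0, 0.5 at step 200, 1.0 at step 400.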
    def training_step(self, batch: Dict) -> Dict:
        raise NotImplementedError()

    def make_start_tokens(self, **kwargs) -> List[List[int]]:
        """Builds one prompt per sample: either the (unpadded) skeleton tokens,
        or a BOS token followed by a class head built from cls/num_joints/parents."""
        skeleton_tokens = kwargs.get('skeleton_tokens', None)
        skeleton_mask = kwargs.get('skeleton_mask', None)
        num_joints = kwargs.get('num_joints', None)
        parents = kwargs.get('parents', None)
        cls = kwargs.get('cls', None)
        start_tokens_list = []
        if skeleton_tokens is not None:
            batch_size = len(skeleton_tokens)
        elif cls is not None:
            batch_size = len(cls)
        elif num_joints is not None:
            batch_size = len(num_joints)
        elif parents is not None:
            batch_size = len(parents)
        else:
            raise ValueError("must provide one of skeleton_tokens, cls, num_joints, parents")
        for i in range(batch_size):
            if skeleton_tokens is not None:
                _skeleton_tokens = skeleton_tokens[i]
                _skeleton_mask = skeleton_mask[i] if skeleton_mask is not None else None
                assert _skeleton_tokens[0] == self.tokenizer.bos
                if _skeleton_mask is not None:
                    start_tokens = _skeleton_tokens[_skeleton_mask == 1]
                else:
                    start_tokens = _skeleton_tokens
            else:
                start_tokens = [self.tokenizer.bos]
                start_tokens += self.tokenizer.make_cls_head(
                    cls=cls[i] if cls is not None else None,
                    num_joints=num_joints[i] if num_joints is not None else None,
                    parents=parents[i] if parents is not None else None,
                )
            if isinstance(start_tokens, Tensor):
                start_tokens = start_tokens.detach().cpu().numpy().tolist()
            start_tokens_list.append(start_tokens)
        return start_tokens_list
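
    # Prompt layout sketch (illustrative; exact ids depend on the tokenizer):
    # with skeleton_tokens given, the prompt is the unpadded skeleton sequence
    # (starting at BOS, typically ending at the tokenizer EOS, which triggers
    # the skin phase in the logits processor); otherwise it is
    # [bos] + make_cls_head(...), and the skeleton itself is generated.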
    def generate(
        self,
        vertices: Tensor,
        normals: Tensor,
        cls: str | None = None,
        skeleton_tokens: np.ndarray | Tensor | None = None,
        only_ids: bool = False,
        return_decode_dict: bool = False,
        num_joints: int | None = None,
        parents: Tensor | None = None,
        **kwargs,
    ) -> TokenRigResult:
        """Autoregressively generates skeleton and skin tokens for a single mesh.

        Batched input is not supported.
        """
        assert isinstance(self.vae.model, SkinFSQCVAEModel)
        assert vertices.dim() == 2, 'batched input is not supported'
        assert normals.dim() == 2, 'batched input is not supported'
        if isinstance(skeleton_tokens, np.ndarray):
            skeleton_tokens = torch.from_numpy(skeleton_tokens).to(self.device)
        cond = torch.cat([vertices, normals], dim=-1).unsqueeze(0)
        # Conditioning latents from the frozen VAE encoder.
        _, cond_latents = self.vae.model._encode(
            x=None,
            cond=cond,
            num_tokens=self.tokens_per_skin,
            cond_tokens=self.tokens_skin_cond,
            return_z=False,
        )
        assert cond_latents is not None
        # (1, len, dim) mesh conditioning projected into the LLM input space.
        learned_mesh_cond = encode_mesh_cond(
            self.mesh_encoder, self.output_proj, self.tokens_skin_cond,
            {'vertices': vertices, 'normals': normals},
        )
        device = cond.device
        start_tokens = torch.tensor(self.make_start_tokens(
            device=device,
            cls=None if cls is None else [cls],
            skeleton_tokens=None if skeleton_tokens is None else [skeleton_tokens],
            num_joints=None if num_joints is None else [num_joints],
            parents=None if parents is None else [parents],
        )[0], device=device).unsqueeze(0)
        assert start_tokens.shape[0] == 1
        start_embed = self.transformer.get_input_embeddings()(start_tokens)
        inputs_embeds = torch.cat([learned_mesh_cond, start_embed], dim=1)
        results = self.transformer.generate(
            inputs_embeds=inputs_embeds,
            bos_token_id=self.tokenizer.bos,
            eos_token_id=self.eos,
            pad_token_id=self.tokenizer.pad,
            logits_processor=get_logits_processor(
                tokenizer=self.tokenizer,
                eos=self.eos,
                tokens_per_skin=self.tokens_per_skin,
                start_tokens=start_tokens[0],
            ),
            **kwargs,
        )
        res = TokenRigResult()
        output_ids = results[0, :]
        # generate() with inputs_embeds does not echo the prompt, so prepend it.
        for token in reversed(start_tokens[0]):
            output_ids = pad(output_ids, (1, 0), value=token.item())
        res.input_ids = start_tokens[0]
        res.output_ids = output_ids
        if only_ids:
            return res
        res.cond = cond[0]
        res.cond_latents = cond_latents[0]
        if return_decode_dict:
            return res
        d = decode(
            cond=cond[0],
            cond_latents=cond_latents[0],
            inputs_ids=output_ids,
            tokenizer=self.tokenizer,
            tokens_per_skin=self.tokens_per_skin,
            vae=self.vae,
        )
        res.skin_pred = d['skin_pred']
        res.detokenize_output = d['detokenize_output']
        return res
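
    # Usage sketch (illustrative; assumes a constructed TokenRig `model` and
    # (N, 3) float tensors already on model.device; "biped" is a hypothetical
    # class label, and max_new_tokens is an ordinary transformers generate()
    # kwarg forwarded through **kwargs):
    #
    #   res = model.generate(vertices=verts, normals=norms, cls="biped",
    #                        max_new_tokens=8192)
    #   res.output_ids        # prompt + generated skeleton/skin token ids
    #   res.skin_pred         # (N, J) per-vertex skin weights, or None on failure
    #   res.detokenize_output # joints, parents, joint names from the tokenizer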
    def _debug_export(
        self,
        batch: Dict,
        cond: Tensor,
        cond_latents: Tensor,
        inputs_ids: Tensor,
        id: int = 0,
        path: str = 'res.fbx',
    ):
        if inputs_ids.dim() == 2:
            assert cond_latents.dim() == cond.dim() == 3, f"Expected 3 dimensions, got {cond_latents.dim()}, {cond.dim()}"
            cond = cond[id]
            cond_latents = cond_latents[id]
            inputs_ids = inputs_ids[id]
        res = decode(
            cond=cond,
            cond_latents=cond_latents,
            inputs_ids=inputs_ids,
            tokenizer=self.tokenizer,
            tokens_per_skin=self.tokens_per_skin,
            vae=self.vae,
        )
        detokenize_output: DetokenizeOutput = res['detokenize_output']
        origin_asset: Asset = batch['model_input'][id].asset
        asset = Asset.from_data(
            vertices=origin_asset.vertices,
            faces=origin_asset.faces,
            sampled_vertices=batch['vertices'][id].detach().cpu().numpy(),
            sampled_skin=res['skin_pred'].detach().cpu().numpy(),
            parents=np.array(detokenize_output.parents),
            joint_names=detokenize_output.joint_names,
            joints=detokenize_output.joints,
        )
        from ..rig_package.parser.bpy import BpyParser
        BpyParser.export_asset(asset, filepath=path)
    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        # Pad every token sequence in the batch to the length of the longest one.
        max_length = 0
        for b in batch:
            if b.tokens is not None:
                max_length = max(max_length, b.tokens.shape[0])
        res = []
        for b in batch:
            if b.tokens is not None:
                skeleton_tokens = np.pad(b.tokens, (0, max_length - b.tokens.shape[0]), 'constant', constant_values=self.tokenizer.pad)
                skeleton_mask = np.pad(np.ones_like(b.tokens), (0, max_length - b.tokens.shape[0]), 'constant', constant_values=0)
            else:
                skeleton_tokens = None
                skeleton_mask = None
            _d = {
                'vertices': torch.from_numpy(b.asset.sampled_vertices).float(),
                'normals': torch.from_numpy(b.asset.sampled_normals).float(),
                'non': {
                    'cls': b.asset.cls,
                },
            }
            if skeleton_mask is not None:
                _d.update({
                    'skeleton_tokens': skeleton_tokens,
                    'skeleton_mask': skeleton_mask,
                })
                _d['non'].update({
                    'parents': b.asset.parents,
                    'num_bones': b.asset.J,
                })
            if b.asset.sampled_vertex_groups is not None and 'skin' in b.asset.sampled_vertex_groups:
                assert b.asset.meta is not None
                _d['non'].update({
                    'cls': b.asset.cls,
                    'uniform_skin': torch.from_numpy(b.asset.sampled_vertex_groups['skin']).float(),
                    'skin_samples': b.asset.skin_samples,
                    'dense_indices': b.asset.meta['dense_indices'],
                    'dense_skin': torch.from_numpy(b.asset.meta['dense_skin']).float(),
                    'dense_vertices': torch.from_numpy(b.asset.meta['dense_vertices']).float(),
                    'dense_normals': torch.from_numpy(b.asset.meta['dense_normals']).float(),
                })
            res.append(_d)
        return res
    def predict_step(
        self,
        batch: Dict,
        no_cls: bool = False,
        skeleton_tokens=None,
        parents=None,
        num_joints=None,
        make_asset: bool = False,
        **kwargs,
    ) -> Dict:
        vertices: Tensor = batch['vertices']
        normals: Tensor = batch['normals']
        cls = batch['cls']
        generate_kwargs = deepcopy(batch['generate_kwargs'])
        if vertices.dim() == 2:
            vertices = vertices.unsqueeze(0)
            normals = normals.unsqueeze(0)
        results = []
        if skeleton_tokens is None:
            skeleton_tokens = [None] * vertices.shape[0]
        d = {}
        for i in range(vertices.shape[0]):
            res = self.generate(
                vertices=vertices[i],
                normals=normals[i],
                skeleton_tokens=skeleton_tokens[i],
                cls=None if no_cls else cls[i],
                parents=None if parents is None else parents[i],
                num_joints=None if num_joints is None else num_joints[i],
                **generate_kwargs,
            )
            if make_asset:
                assert 'model_input' in batch, "need model_input to make asset (in validate/predict mode)"
                assert res.detokenize_output is not None
                assert res.skin_pred is not None
                asset: Asset = batch['model_input'][i].asset.copy()
                res.asset = Asset.from_data(
                    vertices=asset.vertices,
                    faces=asset.faces,
                    sampled_vertices=vertices[i].detach().float().cpu().numpy(),
                    sampled_skin=res.skin_pred.detach().float().cpu().numpy(),
                    joints=res.detokenize_output.joints,
                    parents=np.array(res.detokenize_output.parents),
                    cls=asset.cls,
                    path=asset.path,
                )
            results.append(res)
        d['results'] = results
        return d

    def forward(self, batch: Dict) -> Dict[str, Tensor]:
        return self.training_step(batch=batch)
def _check(x: Tensor, s, m=None):
    """Asserts that x has the expected shape; None or -1 entries in s match any size."""
    assert isinstance(s, (list, tuple)), "Expected shape must be a list or tuple"
    assert x.dim() == len(s), f"Expected {len(s)} dims, got {x.dim()}"
    for i, (dim_actual, dim_expected) in enumerate(zip(x.shape, s)):
        if dim_expected is not None and dim_expected != -1:
            suffix = "" if m is None else f". Message: {m}"
            assert dim_actual == dim_expected, f"Shape mismatch at dim {i}: expected {dim_expected}, got {dim_actual}{suffix}"
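
# Example (illustrative): _check(torch.zeros(2, 8, 16), (2, None, 16)) passes,
# since None (or -1) entries match any size, while _check(torch.zeros(2, 8), (2, 4))
# raises an AssertionError naming the mismatched dimension.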
def encode_mesh_cond(mesh_encoder, output_proj, tokens_skin_cond, batch: Dict) -> Tensor:
    vertices = batch['vertices']  # (B, N, 3) or (N, 3)
    normals = batch['normals']    # (B, N, 3) or (N, 3)
    assert isinstance(vertices, Tensor)
    assert isinstance(normals, Tensor)
    if vertices.dim() == 2:
        vertices = vertices.unsqueeze(0)
        normals = normals.unsqueeze(0)
    shape_embed, latents, token_num, pre_pc = mesh_encoder.encode_latents(pc=vertices, feats=normals)  # type: ignore
    latents = output_proj(latents)
    return latents
def encode(
    tokenizer: Tokenizer,
    vae: SkinVAEModel,
    vae_input: VaeInput,
    encode_repeat: int,
    tokens_skin_cond: int,
    tokens_per_skin: int,
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Returns:
        skin_tokens: (B, tokens_per_skin*J)
        cond_latents: (B, tokens_skin_cond, vae.latent_channels)
        skin_mask: (B, tokens_per_skin*J), 1 -> skin, 0 -> pad
    """
    device = vae_input.uniform_cond.device
    B = vae_input.B
    J = vae_input.max_J
    _, cond_latents, codes, _ = vae.encode(vae_input=vae_input, num_tokens=tokens_per_skin, full=True, encode_repeat=encode_repeat)
    codes = codes[:, :tokens_per_skin]
    indices = vae_input.get_flatten_indices()
    # Scatter each bone's VAE codes into its slot; shift ids past the tokenizer vocab.
    skin_tokens = torch.full((B, J * tokens_per_skin), tokenizer.pad, dtype=torch.long, device=device)
    skin_mask = torch.zeros_like(skin_tokens, dtype=torch.long)
    j_counters = [0 for _ in range(B)]
    for idx, batch_id in enumerate(indices):
        j = j_counters[batch_id]
        s = j * tokens_per_skin
        t = s + tokens_per_skin
        skin_tokens[batch_id, s:t] = codes[idx] + tokenizer.vocab_size
        skin_mask[batch_id, s:t] = 1
        j_counters[batch_id] += 1
    assert cond_latents is not None
    _check(cond_latents, (B, tokens_skin_cond, vae.latent_channels))
    _check(skin_tokens, (B, J * tokens_per_skin))
    _check(skin_mask, (B, J * tokens_per_skin))
    return skin_tokens, cond_latents, skin_mask
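
# Token layout sketch (illustrative, B=1, J=2, tokens_per_skin=3): bone codes
# are written contiguously from the front, so skin_tokens[0] is
#   [c00, c01, c02, c10, c11, c12] + tokenizer.vocab_size
# with skin_mask[0] = [1, 1, 1, 1, 1, 1]; unused slots (when a sample has fewer
# than max_J bones) stay at tokenizer.pad with mask 0.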
def prepare_llm_tokens(
    tokenizer: Tokenizer,
    eos: int,
    skeleton_tokens: Tensor,
    skeleton_mask: Tensor,
    skin_tokens: Tensor,
    skin_mask: Tensor,
    cond_latents: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Args:
        skeleton_tokens: (B, n)
        skeleton_mask: (B, n)
        skin_tokens: (B, tokens_per_skin*J)
        skin_mask: (B, tokens_per_skin*J)
        cond_latents: (B, tokens_skin_cond, vae.latent_channels)
    Returns:
        inputs_ids: (B, seq_len) -- skeleton tokens, then skin tokens, then EOS, then pad
        attention_mask: (B, seq_len), 1 -> attend, 0 -> pad
        llm_skeleton_mask: (B, seq_len), 1 where the position holds a skeleton token
        llm_skin_mask: (B, seq_len), 1 where the position holds a skin token or the EOS
    """
    B = skeleton_tokens.shape[0]
    inputs_ids = torch.ones((B, skeleton_tokens.shape[1] + skin_tokens.shape[1] + 1), dtype=torch.long, device=skeleton_tokens.device) * tokenizer.pad
    num_skeleton = skeleton_mask.sum(dim=1)
    num_skin = skin_mask.sum(dim=1)
    attention_mask = torch.ones((B, inputs_ids.shape[1]), dtype=torch.float32, device=skeleton_tokens.device)
    llm_skeleton_mask = torch.ones_like(attention_mask, dtype=torch.bool)
    llm_skin_mask = torch.ones_like(attention_mask, dtype=torch.bool)
    for i in range(B):
        length = num_skeleton[i] + num_skin[i]
        inputs_ids[i, :num_skeleton[i]] = skeleton_tokens[i, :num_skeleton[i]]
        inputs_ids[i, num_skeleton[i]:num_skeleton[i] + num_skin[i]] = skin_tokens[i, :num_skin[i]]
        inputs_ids[i, num_skeleton[i] + num_skin[i]] = eos  # append an EOS
        attention_mask[i, length + 1:] = 0.
        llm_skeleton_mask[i, num_skeleton[i]:] = 0
        llm_skin_mask[i, :num_skeleton[i]] = 0
        llm_skin_mask[i, length + 1:] = 0
    seq_len = inputs_ids.shape[1]
    _check(inputs_ids, (B, seq_len))
    _check(attention_mask, (B, seq_len))
    return inputs_ids, attention_mask, llm_skeleton_mask, llm_skin_mask
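
# Packed sequence sketch (illustrative, one sample): with 5 valid skeleton
# tokens and 2 bones * 2 skin tokens, inputs_ids[i] is
#   [s0, s1, s2, s3, s4, k0, k1, k2, k3, EOS, pad, ...]
# attention_mask is 1 through the EOS and 0 over the padding; llm_skeleton_mask
# covers s0..s4, and llm_skin_mask covers k0..k3 plus the EOS.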
def get_logits_processor(tokenizer: Tokenizer, eos: int, tokens_per_skin: int, start_tokens):
    processor = VocabSwitchingLogitsProcessor(
        tokenizer=tokenizer,
        switch_token_id=tokenizer.eos,  # the tokenizer EOS marks the skeleton->skin switch
        eos_token_id=eos,
        tokens_per_skin=tokens_per_skin,
        init=start_tokens,
    )
    return LogitsProcessorList([processor])
def decode(
    cond: Tensor,
    cond_latents: Tensor,
    inputs_ids: Tensor,
    tokenizer: Tokenizer,
    tokens_per_skin: int,
    vae: SkinVAEModel,
    encode_repeat: int = 1,
) -> Dict:
    """
    Args:
        inputs_ids: (seq_len)
        cond: (N, c)
        cond_latents: (tokens_skin_cond, dim)
    """
    assert cond.dim() == 2, 'batched input is not supported'
    assert cond_latents.dim() == 2, 'batched input is not supported'
    where_eos = torch.where(inputs_ids == tokenizer.eos)
    if where_eos[0].shape[0] == 0:
        raise ValueError("No EOS token found in inputs_ids")
    where_eos = where_eos[0][:1]
    # Everything up to and including the tokenizer EOS is the skeleton.
    skeleton_tokens = inputs_ids[:where_eos + 1].detach().cpu().numpy()
    detokenize_output = tokenizer.detokenize(ids=skeleton_tokens)
    J = detokenize_output.joints.shape[0]
    skin_tokens = inputs_ids[where_eos + 1:where_eos + 1 + J * tokens_per_skin]
    if skin_tokens.shape != (J * tokens_per_skin,):
        # Generation stopped before producing all expected skin tokens.
        return {
            'skin_pred': None,
            'detokenize_output': detokenize_output,
        }
    cond = cond.unsqueeze(0)
    cond_latents = cond_latents.unsqueeze(0)
    skin = []
    # Decode in chunks of `encode_repeat` bones at a time to bound peak memory.
    g = tokens_per_skin * encode_repeat
    for s in range(0, J * tokens_per_skin, g):
        t = min(s + g, J * tokens_per_skin)
        # Map token ids back into the VAE codebook range.
        indices = skin_tokens[s:t].unsqueeze(0) - tokenizer.vocab_size
        b = (t - s) // tokens_per_skin
        # (b, tokens_per_skin, dim)
        z = vae.model.FSQ.indices_to_codes(indices).reshape(b, tokens_per_skin, -1)
        # (b, n, 1)
        logits = vae.decode(z=z, sampled_cond=cond.repeat(b, 1, 1), cond_tokens=cond_latents.repeat(b, 1, 1))
        skin_pred = logits.reshape(b, logits.shape[1]).permute(1, 0)
        skin.append(skin_pred)
    skin = torch.concat(skin, dim=1).float()
    return {
        'skin_pred': skin,
        'detokenize_output': detokenize_output,
    }
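
# Usage sketch (illustrative; pairs with TokenRig.generate(..., return_decode_dict=True),
# which stores the tensors needed here on the result object):
#
#   d = decode(cond=res.cond, cond_latents=res.cond_latents,
#              inputs_ids=res.output_ids, tokenizer=model.tokenizer,
#              tokens_per_skin=model.tokens_per_skin, vae=model.vae)
#   skin = d['skin_pred']  # (N, J), or None if the skin tokens were incomplete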
def decode_multi(
    cond: Tensor,
    cond_latents: Tensor,
    inputs_ids: List[Tensor],
    tokenizer: Tokenizer,
    tokens_per_skin: int,
    vae: SkinVAEModel,
    is_numpy: bool = True,
    encode_repeat: int = 1,
) -> List[Dict]:
    """
    Args:
        inputs_ids: List[(seq_len)]
        cond: (N, c)
        cond_latents: (tokens_skin_cond, dim)
    """
    assert cond.dim() == 2, 'batched input is not supported'
    assert cond_latents.dim() == 2, 'batched input is not supported'
    B = len(inputs_ids)
    res = [{'skin_pred': None, 'detokenize_output': None} for _ in range(B)]
    device = cond.device
    batch_mapping = []
    skin_tokens_list = []
    oks = []
    oks_J = []
    # First pass: detokenize each sequence's skeleton and collect valid skin tokens.
    for i in range(B):
        where_eos = torch.where(inputs_ids[i] == tokenizer.eos)
        if where_eos[0].shape[0] == 0:
            print(f"decode_multi: {i} has bad skeleton")
            continue
        where_eos = where_eos[0][:1]
        skeleton_tokens = inputs_ids[i][:where_eos + 1].detach().cpu().numpy()
        try:
            detokenize_output = tokenizer.detokenize(ids=skeleton_tokens)
        except Exception as e:
            print(f"decode_multi: error while decoding skeleton: {str(e)}")
            continue
        J = detokenize_output.joints.shape[0]
        res[i]['detokenize_output'] = detokenize_output  # type: ignore
        skin_tokens = inputs_ids[i][where_eos + 1:where_eos + 1 + J * tokens_per_skin]
        if skin_tokens.shape != (J * tokens_per_skin,):
            print(f"decode_multi: {i} has bad skin")
            continue
        batch_mapping.append(torch.full((J,), i, device=device, dtype=torch.long))
        skin_tokens_list.append(skin_tokens)
        oks.append(i)
        oks_J.append(J)
    if len(batch_mapping) == 0:
        return res
    batch_mapping = torch.cat(batch_mapping, dim=0)
    # (1, sum_J*tokens_per_skin)
    skin_tokens = torch.cat(skin_tokens_list, dim=0).unsqueeze(0)
    cond = cond.unsqueeze(0)
    cond_latents = cond_latents.unsqueeze(0)
    skin_list = []
    # Second pass: decode all bones from all valid sequences in shared chunks.
    g = tokens_per_skin * encode_repeat
    sum_J = batch_mapping.shape[0]
    for s in range(0, sum_J * tokens_per_skin, g):
        t = min(s + g, sum_J * tokens_per_skin)
        # (1, m*tokens_per_skin)
        indices = skin_tokens[:, s:t] - tokenizer.vocab_size
        m = (t - s) // tokens_per_skin
        # (m, tokens_per_skin, dim)
        z = vae.model.FSQ.indices_to_codes(indices).reshape(m, tokens_per_skin, -1)
        # (m, n, 1)
        logits = vae.decode(z=z, sampled_cond=cond.repeat(m, 1, 1), cond_tokens=cond_latents.repeat(m, 1, 1))
        skin_pred = logits.reshape(m, logits.shape[1]).permute(1, 0)
        skin_list.append(skin_pred)
    skin = torch.concat(skin_list, dim=1).float()
    # Scatter decoded columns back to their source sequences.
    for i, id in enumerate(oks):
        skin_pred = skin[:, batch_mapping == id].reshape(-1, oks_J[i])
        res[id]['skin_pred'] = skin_pred.detach().cpu().numpy() if is_numpy else skin_pred
    return res
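
# decode_multi follows the same chunked decoding as decode() but batches many
# sampled sequences that share one mesh, e.g. (illustrative; seq_a and seq_b
# are hypothetical output_ids tensors from two generation runs):
#
#   outs = decode_multi(cond=res.cond, cond_latents=res.cond_latents,
#                       inputs_ids=[seq_a, seq_b], tokenizer=model.tokenizer,
#                       tokens_per_skin=model.tokens_per_skin, vae=model.vae)
#   outs[0]['skin_pred']  # numpy (N, J_a) array, or None if that sample failed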