from copy import deepcopy from pathlib import Path from torch import nn, Tensor, FloatTensor from torch.nn.functional import pad from transformers import AutoModelForCausalLM, AutoConfig, LogitsProcessor, LogitsProcessorList # type: ignore from typing import Dict, List, Tuple import math import numpy as np import torch import torch.nn.functional as F LLM_LOCAL_DIR = Path("models/Qwen3-0.6B") from .skin_vae_model import SkinVAEModel from .skin_vae.autoencoders import SkinFSQCVAEModel from .spec import ModelSpec, ModelInput, VaeInput, TokenRigResult from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder from ..rig_package.info.asset import Asset from ..tokenizer.spec import Tokenizer from ..tokenizer.spec import DetokenizeOutput from ..tokenizer.parse import get_tokenizer try: from flash_attn_interface import flash_attn_func # type: ignore except Exception as e: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func def flash_attn_func(*args, **kwargs): res = _flash_attn_func(*args, **kwargs) return res, None class VocabSwitchingLogitsProcessor(LogitsProcessor): def __init__(self, tokenizer: Tokenizer, switch_token_id, eos_token_id, tokens_per_skin, init): # make sure all skin tokens > switch_token_id self.tokenizer = tokenizer self.switch_token_id = switch_token_id self.eos_token_id = eos_token_id self.tokens_per_skin = tokens_per_skin self.init = init def __call__(self, input_ids: Tensor, scores: FloatTensor) -> FloatTensor: # input_ids shape: (batch_size, seq_len) for batch_idx, sequence in enumerate(input_ids): mask = torch.full_like(scores[batch_idx], float('-inf')) sequence = torch.cat([self.init, sequence]) length = len(sequence) if self.switch_token_id in sequence: mask[self.switch_token_id:] = 0 where = torch.where(sequence == self.switch_token_id)[0][:1] J = self.tokenizer.bones_in_sequence(ids=sequence.detach().cpu().numpy()) if (length-where) == J*self.tokens_per_skin: mask[:] = float('-inf') mask[self.eos_token_id] = 0 else: mask[self.eos_token_id] = float('-inf') else: tokens = self.tokenizer.next_posible_token(ids=sequence.detach().cpu().numpy()) mask[tokens] = 0 scores[batch_idx] = scores[batch_idx] + mask return scores class TokenRig(ModelSpec): def __init__(self, model_config, transform_config, tokenizer_config=None): assert tokenizer_config is not None super().__init__(model_config=model_config, transform_config=transform_config, tokenizer_config=tokenizer_config) cfg = self.model_config self.tokens_per_skin: int = cfg['tokens_per_skin'] self.tokens_skin_cond: int = cfg['tokens_skin_cond'] self.use_rope: bool = cfg.get('use_rope', True) self.encode_repeat: int = cfg.get('encode_repeat', 4) self.skin_warmup_start_epoch: int = cfg.get('skin_warmup_start_epoch', 0) self.skin_warmup_end_epoch: int = cfg.get('skin_warmup_end_epoch', -1) self.vae = SkinVAEModel.load_from_system_checkpoint(cfg['pretrained_vae']).to(torch.bfloat16) for param in self.vae.parameters(): param.requires_grad_(False) self.vae.eval() self.mesh_encoder = get_mesh_encoder(**cfg['mesh_encoder']) assert ( isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo) or isinstance(self.mesh_encoder, MAP_MESH_ENCODER.michelangelo_encoder) ) self.mesh_encoder = self.mesh_encoder.to(torch.bfloat16) self.tokenizer: Tokenizer = get_tokenizer(**tokenizer_config) # (tokenizer codebook, fsq vae codebook) self.vocab_size = self.tokenizer.vocab_size + self.vae.vocab_size + 1 self.eos = self.vocab_size - 1 _d = cfg['llm'].copy() self.hidden_size = _d['hidden_size'] _d['vocab_size'] = self.vocab_size if LLM_LOCAL_DIR.exists(): _d['pretrained_model_name_or_path'] = str(LLM_LOCAL_DIR) llm_config = AutoConfig.from_pretrained(**_d) self.vocab_size = self.tokenizer.vocab_size + self.vae.vocab_size + 1 llm_config.torch_dtype = torch.bfloat16 llm_config.pre_norm = True self.llm_config = llm_config self.transformer = AutoModelForCausalLM.from_config(config=llm_config, attn_implementation="flash_attention_2").to(torch.bfloat16) self.output_proj = nn.Sequential( nn.Linear(self.mesh_encoder.width, self.hidden_size), nn.RMSNorm(self.hidden_size), ).to(torch.bfloat16) init_scale = cfg.get('init_scale', None) if init_scale is not None: self.initialize_weights(init_scale) def compile_model(self): self.vae.compile_model() self.transformer = torch.compile(self.transformer, dynamic=False) self.mesh_encoder = torch.compile(self.mesh_encoder, dynamic=False) def initialize_weights(self, s: float): def init_linear(l, stddev): nn.init.normal_(l.weight, std=stddev) if l.bias is not None: nn.init.constant_(l.bias, 0.0) init_scale = s * math.sqrt(1.0 / self.mesh_encoder.width) for m in self.mesh_encoder.modules(): if isinstance(m, nn.Linear): init_linear(m, stddev=init_scale) init_scale = s * math.sqrt(1.0 / self.hidden_size) for m in self.output_proj.modules(): if isinstance(m, nn.Linear): init_linear(m, stddev=init_scale) def get_skin_warmup_rate(self, steps_per_epoch: int) -> float: if self.current_epoch < self.skin_warmup_start_epoch: return 0. if self.current_epoch > self.skin_warmup_end_epoch: return 1. start_steps = self.skin_warmup_start_epoch * steps_per_epoch end_steps = (self.skin_warmup_end_epoch+1) * steps_per_epoch rate = (self.global_step-start_steps) / (end_steps-start_steps) return min(max((1.0-math.cos(math.pi * rate))/2, 0), 1) @torch.autocast(device_type='cuda', dtype=torch.bfloat16) def training_step(self, batch: Dict) -> Dict: raise NotImplementedError() def make_start_tokens(self, **kwargs) -> List[List[int]]: skeleton_tokens = kwargs.get('skeleton_tokens', None) skeleton_mask = kwargs.get('skeleton_mask', None) num_joints = kwargs.get('num_joints', None) parents = kwargs.get('parents', None) cls = kwargs.get('cls', None) start_tokens_list = [] batch_size = 1 if skeleton_tokens is not None: batch_size = len(skeleton_tokens) elif cls is not None: batch_size = len(cls) elif num_joints is not None: batch_size = len(num_joints) elif parents is not None: batch_size = len(parents) else: assert 0, "must provide one of skeleton_tokens, cls, num_joints, parents" for i in range(batch_size): if skeleton_tokens is not None: _skeleton_tokens = skeleton_tokens[i] _skeleton_mask = skeleton_mask[i] if skeleton_mask is not None else None assert _skeleton_tokens[0] == self.tokenizer.bos if skeleton_mask is not None: start_tokens = _skeleton_tokens[_skeleton_mask==1] else: start_tokens = _skeleton_tokens else: start_tokens = [self.tokenizer.bos] start_tokens += self.tokenizer.make_cls_head( cls=cls[i] if cls is not None else None, num_joints=num_joints[i] if num_joints is not None else None, parents=parents[i] if parents is not None else None, ) if isinstance(start_tokens, Tensor): start_tokens = start_tokens.detach().cpu().numpy().tolist() start_tokens_list.append(start_tokens) return start_tokens_list @torch.autocast(device_type='cuda', dtype=torch.bfloat16) def generate( self, vertices: Tensor, normals: Tensor, cls: str|None=None, skeleton_tokens: np.ndarray|Tensor|None=None, only_ids: bool=False, return_decode_dict: bool=False, num_joints: int|None=None, parents: Tensor|None=None, **kwargs, ) -> TokenRigResult: """ Do not support batch! """ assert isinstance(self.vae.model, SkinFSQCVAEModel) assert vertices.dim() == 2, 'do not support batch' assert normals.dim() == 2, 'do not support batch' if isinstance(skeleton_tokens, np.ndarray): skeleton_tokens = torch.from_numpy(skeleton_tokens).to(self.device) cond = torch.cat([vertices, normals], dim=-1).unsqueeze(0) _, cond_latents = self.vae.model._encode( x=None, cond=cond, num_tokens=self.tokens_per_skin, cond_tokens=self.tokens_skin_cond, return_z=False, ) assert cond_latents is not None # (1, len, dim) learned_mesh_cond = encode_mesh_cond(self.mesh_encoder, self.output_proj, self.tokens_skin_cond, {'vertices': vertices, 'normals': normals}) device = cond.device start_tokens = torch.tensor(self.make_start_tokens( device=device, cls=None if cls is None else [cls], skeleton_tokens=None if skeleton_tokens is None else [skeleton_tokens], num_joints=None if num_joints is None else [num_joints], parents=None if parents is None else [parents], )[0], device=device).unsqueeze(0) assert start_tokens.shape[0] == 1 start_embed = self.transformer.get_input_embeddings()(start_tokens) inputs_embeds = torch.cat([learned_mesh_cond, start_embed], dim=1) results = self.transformer.generate( inputs_embeds=inputs_embeds, bos_token_id=self.tokenizer.bos, eos_token_id=self.eos, pad_token_id=self.tokenizer.pad, logits_processor=get_logits_processor( tokenizer=self.tokenizer, eos=self.eos, tokens_per_skin=self.tokens_per_skin, start_tokens=start_tokens[0], ), **kwargs, ) res = TokenRigResult() output_ids = results[0, :] for token in reversed(start_tokens[0]): v = token.item() output_ids = pad(output_ids, (1, 0), value=v) res.input_ids = start_tokens[0] res.output_ids = output_ids if only_ids: return res res.cond = cond[0] res.cond_latents = cond_latents[0] if return_decode_dict: return res d = decode( cond=cond[0], cond_latents=cond_latents[0], inputs_ids=output_ids, tokenizer=self.tokenizer, tokens_per_skin=self.tokens_per_skin, vae=self.vae, ) res.skin_pred = d['skin_pred'] res.detokenize_output = d['detokenize_output'] return res def _debug_export( self, batch: Dict, cond: Tensor, cond_latents: Tensor, inputs_ids: Tensor, id: int=0, path: str='res.fbx', ): if inputs_ids.dim() == 2: assert cond_latents.dim() == cond.dim() == 3, f"Expected 3 dimensions, got {cond_latents.dim()}, {cond.dim()}" cond = cond[id] cond_latents = cond_latents[id] inputs_ids = inputs_ids[id] res = decode( cond=cond, cond_latents=cond_latents, inputs_ids=inputs_ids, tokenizer=self.tokenizer, tokens_per_skin=self.tokens_per_skin, vae=self.vae, ) detokenize_output: DetokenizeOutput = res['detokenize_output'] origin_asset: Asset = batch['model_input'][id].asset asset = Asset.from_data( vertices=origin_asset.vertices, faces=origin_asset.faces, sampled_vertices=batch['vertices'][id].detach().cpu().numpy(), sampled_skin=res['skin_pred'].detach().cpu().numpy(), parents=np.array(detokenize_output.parents), joint_names=detokenize_output.joint_names, joints=detokenize_output.joints, ) from ..rig_package.parser.bpy import BpyParser BpyParser.export_asset(asset, filepath=path) def process_fn(self, batch: List[ModelInput]) -> List[Dict]: res = [] max_length = 0 for b in batch: if b.tokens is not None: max_length = max(max_length, b.tokens.shape[0]) res = [] for b in batch: if b.tokens is not None: skeleton_tokens = np.pad(b.tokens, ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=self.tokenizer.pad) skeleton_mask = np.pad(np.ones_like(b.tokens), ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=0) else: skeleton_tokens = None skeleton_mask = None _d = { 'vertices': torch.from_numpy(b.asset.sampled_vertices).float(), 'normals': torch.from_numpy(b.asset.sampled_normals).float(), 'non': { 'cls': b.asset.cls, } } if skeleton_mask is not None: _d.update({ 'skeleton_tokens': skeleton_tokens, 'skeleton_mask': skeleton_mask, }) _d['non'].update({ 'parents': b.asset.parents, 'num_bones': b.asset.J, }) if b.asset.sampled_vertex_groups is not None and 'skin' in b.asset.sampled_vertex_groups: assert b.asset.meta is not None _d['non'].update({ 'cls': b.asset.cls, 'uniform_skin': torch.from_numpy(b.asset.sampled_vertex_groups['skin']).float(), 'skin_samples': b.asset.skin_samples, 'dense_indices': b.asset.meta['dense_indices'], 'dense_skin': torch.from_numpy(b.asset.meta['dense_skin']).float(), 'dense_vertices': torch.from_numpy(b.asset.meta['dense_vertices']).float(), 'dense_normals': torch.from_numpy(b.asset.meta['dense_normals']).float(), }) res.append(_d) return res def predict_step( self, batch: Dict, no_cls: bool=False, skeleton_tokens=None, parents=None, num_joints=None, make_asset: bool=False, **kwargs ) -> Dict: vertices: Tensor = batch['vertices'] normals : Tensor = batch['normals'] cls = batch['cls'] generate_kwargs = deepcopy(batch['generate_kwargs']) if vertices.dim() == 2: vertices = vertices.unsqueeze(0) normals = normals.unsqueeze(0) results = [] if skeleton_tokens is None: skeleton_tokens = [None] * vertices.shape[0] d = {} for i in range(vertices.shape[0]): res = self.generate( vertices=vertices[i], normals=normals[i], skeleton_tokens=skeleton_tokens[i], cls=None if no_cls else cls[i], parents=None if parents is None else parents[i], num_joints=None if num_joints is None else num_joints[i], **generate_kwargs ) if make_asset: assert 'model_input' in batch, "need model_input to make asset (in validate/predict mode)" assert res.detokenize_output is not None assert res.skin_pred is not None asset: Asset = batch['model_input'][i].asset.copy() res.asset = Asset.from_data( vertices=asset.vertices, faces=asset.faces, sampled_vertices=vertices[i].detach().float().cpu().numpy(), sampled_skin=res.skin_pred.detach().float().cpu().numpy(), joints=res.detokenize_output.joints, parents=np.array(res.detokenize_output.parents), cls=asset.cls, path=asset.path, ) results.append(res) d['results'] = results return d def forward(self, batch: Dict) -> Dict[str, Tensor]: return self.training_step(batch=batch) def _check(x: Tensor, s, m=None): assert isinstance(s, (list, tuple)), "Expected shape must be a list or tuple" assert x.dim() == len(s), f"Expected {len(s)} dims, got {x.dim()}" for i, (dim_actual, dim_expected) in enumerate(zip(x.shape, s)): if dim_expected is not None and dim_expected != -1: if m is None: assert dim_actual == dim_expected, f"Shape mismatch at dim {i}: expected {dim_expected}, got {dim_actual}" else: assert dim_actual == dim_expected, f"Shape mismatch at dim {i}: expected {dim_expected}, got {dim_actual}. Message: {m}" def encode_mesh_cond(mesh_encoder, output_proj, tokens_skin_cond, batch: Dict) -> Tensor: vertices = batch['vertices'] # (B, N, 3) normals = batch['normals'] # (B, N, 3) assert isinstance(vertices, Tensor) assert isinstance(normals, Tensor) if (len(vertices.shape) == 3): shape_embed, latents, token_num, pre_pc = mesh_encoder.encode_latents(pc=vertices, feats=normals) # type: ignore else: shape_embed, latents, token_num, pre_pc = mesh_encoder.encode_latents(pc=vertices.unsqueeze(0), feats=normals.unsqueeze(0)) # type: ignore latents = output_proj(latents) return latents @torch.no_grad() def encode( tokenizer: Tokenizer, vae: SkinVAEModel, vae_input: VaeInput, encode_repeat: int, tokens_skin_cond: int, tokens_per_skin: int, ) -> Tuple[Tensor, Tensor, Tensor]: """ Returns: skin_tokens: (B, tokens_per_skin*J) cond_latents: (B, tokens_skin_cond, vae.latent_channels) skin_mask: (B, tokens_per_skin*J), 1 -> skin, 0 -> pad """ device = vae_input.uniform_cond.device B = vae_input.B J = vae_input.max_J _, cond_latents, codes, _ = vae.encode(vae_input=vae_input, num_tokens=tokens_per_skin, full=True, encode_repeat=encode_repeat) codes = codes[:, :tokens_per_skin] indices = vae_input.get_flatten_indices() skin_tokens = torch.full((B, J * tokens_per_skin), tokenizer.pad, dtype=torch.long, device=device) skin_mask = torch.zeros_like(skin_tokens, dtype=torch.long) j_counters = [0 for _ in range(B)] for idx, batch_id in enumerate(indices): j = j_counters[batch_id] s = j * tokens_per_skin t = s + tokens_per_skin skin_tokens[batch_id, s:t] = codes[idx] + tokenizer.vocab_size skin_mask[batch_id, s:t] = 1 j_counters[batch_id] += 1 assert cond_latents is not None _check(cond_latents, (B, tokens_skin_cond, vae.latent_channels)) _check(skin_tokens, (B, J * tokens_per_skin)) _check(skin_mask, (B, J * tokens_per_skin)) return skin_tokens, cond_latents, skin_mask def prepare_llm_tokens( tokenizer: Tokenizer, eos: int, skeleton_tokens: Tensor, skeleton_mask: Tensor, skin_tokens: Tensor, skin_mask: Tensor, cond_latents: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """ Args: skeleton_tokens: (B, n) skeleton_mask: (B, n) skin_tokens: (B, tokens_per_skin*J) skin_mask: (B, tokens_per_skin*J) cond_latents: (B, tokens_skin_cond, vae.latent_channels) Returns: llm_tokens: (B, seq_len) attention_mask: (B, seq_len), 1 -> attend, 0 -> pad """ B = skeleton_tokens.shape[0] inputs_ids = torch.ones((B, skeleton_tokens.shape[1] + skin_tokens.shape[1] + 1), dtype=torch.long, device=skeleton_tokens.device) * tokenizer.pad num_skeleton = skeleton_mask.sum(dim=1) num_skin = skin_mask.sum(dim=1) attention_mask = torch.ones((B, inputs_ids.shape[1]), dtype=torch.float32, device=skeleton_tokens.device) llm_skeleton_mask = torch.ones_like(attention_mask, dtype=torch.bool) llm_skin_mask = torch.ones_like(attention_mask, dtype=torch.bool) for i in range(B): length = num_skeleton[i] + num_skin[i] inputs_ids[i, :num_skeleton[i]] = skeleton_tokens[i, :num_skeleton[i]] inputs_ids[i, num_skeleton[i]:num_skeleton[i]+num_skin[i]] = skin_tokens[i, :num_skin[i]] inputs_ids[i, num_skeleton[i]+num_skin[i]] = eos # add an eos attention_mask[i, length+1:] = 0. llm_skeleton_mask[i, num_skeleton[i]:] = 0 llm_skin_mask[i, :num_skeleton[i]] = 0 llm_skin_mask[i, length+1:] = 0 seq_len = inputs_ids.shape[1] _check(inputs_ids, (B, seq_len)) _check(attention_mask, (B, seq_len)) return inputs_ids, attention_mask, llm_skeleton_mask, llm_skin_mask def get_logits_processor(tokenizer: Tokenizer, eos: int, tokens_per_skin: int, start_tokens): processor = VocabSwitchingLogitsProcessor( tokenizer=tokenizer, switch_token_id=tokenizer.eos, eos_token_id=eos, tokens_per_skin=tokens_per_skin, init=start_tokens, ) return LogitsProcessorList([processor]) @torch.no_grad() def decode( cond: Tensor, cond_latents: Tensor, inputs_ids: Tensor, tokenizer: Tokenizer, tokens_per_skin: int, vae: SkinVAEModel, encode_repeat: int=1, ) -> Dict: """ inputs_ids: (seq_len) cond: (N, c) cond_latents: (tokens_skin_cond, dim) """ assert cond.dim() == 2, 'do not support batch' assert cond_latents.dim() == 2, 'do not support batch' where_eos = torch.where(inputs_ids == tokenizer.eos) if where_eos[0].shape[0] == 0: raise ValueError("No EOS token found in inputs_ids") where_eos = where_eos[0][:1] skeleton_tokens = inputs_ids[:where_eos+1] skeleton_tokens = np.array(skeleton_tokens.detach().cpu().numpy()) detokenize_output = tokenizer.detokenize(ids=skeleton_tokens) J = detokenize_output.joints.shape[0] skin_tokens = inputs_ids[where_eos+1:where_eos+1+J*tokens_per_skin] if skin_tokens.shape != (J*(tokens_per_skin),): return { 'skin_pred': None, 'detokenize_output': detokenize_output, } cond = cond.unsqueeze(0) cond_latents = cond_latents.unsqueeze(0) skin = [] g = tokens_per_skin * encode_repeat for s in range(0, J*tokens_per_skin, g): t = min(s+g, J*tokens_per_skin) indices = skin_tokens[s:t].unsqueeze(0) - tokenizer.vocab_size # expect: (b, tokens_per_skin, dim) b = (t-s)//tokens_per_skin z = vae.model.FSQ.indices_to_codes(indices).reshape(b, tokens_per_skin, -1) # (b, n, 1) logits = vae.decode(z=z, sampled_cond=cond.repeat(b, 1, 1), cond_tokens=cond_latents.repeat(b, 1, 1)) skin_pred = logits.reshape(b, logits.shape[1]).permute(1, 0) skin.append(skin_pred) skin = torch.concat(skin, dim=1).float() return { 'skin_pred': skin, 'detokenize_output': detokenize_output, } @torch.no_grad() def decode_multi( cond: Tensor, cond_latents: Tensor, inputs_ids: List[Tensor], tokenizer: Tokenizer, tokens_per_skin: int, vae: SkinVAEModel, is_numpy: bool=True, encode_repeat: int=1, ) -> List[Dict]: """ inputs_ids: List[(seq_len)] cond: (N, c) cond_latents: (tokens_skin_cond, dim) """ assert cond.dim() == 2, 'do not support batch' assert cond_latents.dim() == 2, 'do not support batch' B = len(inputs_ids) res = [{'skin_pred': None, 'detokenize_output': None} for _ in range(B)] device = cond.device batch_mapping = [] skin_tokens_list = [] oks = [] oks_J = [] for i in range(B): where_eos = torch.where(inputs_ids[i] == tokenizer.eos) if where_eos[0].shape[0] == 0: print(f"decode_multi: {i} has bad skeleton") continue where_eos = where_eos[0][:1] skeleton_tokens = inputs_ids[i][:where_eos+1] skeleton_tokens = np.array(skeleton_tokens.detach().cpu().numpy()) try: detokenize_output = tokenizer.detokenize(ids=skeleton_tokens) except Exception as e: print(f"decode_multi: error while decoding skeleton: {str(e)}") continue J = detokenize_output.joints.shape[0] res[i]['detokenize_output'] = detokenize_output # type: ignore skin_tokens = inputs_ids[i][where_eos+1:where_eos+1+J*tokens_per_skin] if skin_tokens.shape != (J*(tokens_per_skin),): print(f"decode_multi: {i} has bad skin") continue batch_mapping.append(torch.full((J,), i, device=device, dtype=torch.long)) skin_tokens_list.append(skin_tokens) oks.append(i) oks_J.append(J) if len(batch_mapping) == 0: return res batch_mapping = torch.cat(batch_mapping, dim=0) # (1, sum_J*tokens_per_skin) skin_tokens = torch.cat(skin_tokens_list, dim=0).unsqueeze(0) cond = cond.unsqueeze(0) cond_latents = cond_latents.unsqueeze(0) skin_list = [] g = tokens_per_skin * encode_repeat sum_J = batch_mapping.shape[0] for s in range(0, sum_J*tokens_per_skin, g): t = min(s+g, sum_J*tokens_per_skin) # (1, m*tokens_per_skin) indices = skin_tokens[:, s:t] - tokenizer.vocab_size # expect: (m, tokens_per_skin, dim) m = (t-s)//tokens_per_skin z = vae.model.FSQ.indices_to_codes(indices).reshape(m, tokens_per_skin, -1) # (m, n, 1) logits = vae.decode(z=z, sampled_cond=cond.repeat(m, 1, 1), cond_tokens=cond_latents.repeat(m, 1, 1)) skin_pred = logits.reshape(m, logits.shape[1]).permute(1, 0) skin_list.append(skin_pred) skin = torch.concat(skin_list, dim=1).float() for (i, id) in enumerate(oks): skin_pred = skin[:, batch_mapping==id].reshape(-1, oks_J[i]) res[id]['skin_pred'] = skin_pred.detach().cpu().numpy() if is_numpy else skin_pred return res