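"""TokenRig: autoregressive rigging with a causal LLM over a combined
skeleton + skin vocabulary. A frozen skin VAE (FSQ codebook) turns per-joint
skin weights into discrete codes, a mesh encoder provides geometry
conditioning, and the LLM generates skeleton tokens followed by
tokens_per_skin skin tokens per joint, which decode() maps back to
per-vertex skin weights."""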
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Tuple

import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor, FloatTensor
from torch.nn.functional import pad
from transformers import AutoModelForCausalLM, AutoConfig, LogitsProcessor, LogitsProcessorList  # type: ignore

from .skin_vae_model import SkinVAEModel
from .skin_vae.autoencoders import SkinFSQCVAEModel
from .spec import ModelSpec, ModelInput, VaeInput, TokenRigResult
from .parse_encoder import MAP_MESH_ENCODER, get_mesh_encoder
from ..rig_package.info.asset import Asset
from ..tokenizer.spec import Tokenizer, DetokenizeOutput
from ..tokenizer.parse import get_tokenizer

# Local snapshot of the base LLM; the config's hub name is used if this path is absent.
LLM_LOCAL_DIR = Path("models/Qwen3-0.6B")
# Prefer the FlashAttention-3 interface; fall back to FlashAttention-2 and
# wrap its single-tensor return so both paths yield an (output, lse) tuple.
try:
    from flash_attn_interface import flash_attn_func  # type: ignore
except Exception:
    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func

    def flash_attn_func(*args, **kwargs):
        res = _flash_attn_func(*args, **kwargs)
        return res, None
class VocabSwitchingLogitsProcessor(LogitsProcessor):
    """Constrains decoding to the skeleton vocabulary until the switch token
    (the tokenizer's EOS) appears, then to the skin vocabulary until the
    expected number of skin tokens has been emitted, at which point the
    global EOS is forced. Assumes all skin-token ids are > switch_token_id."""

    def __init__(self, tokenizer: Tokenizer, switch_token_id, eos_token_id, tokens_per_skin, init):
        self.tokenizer = tokenizer
        self.switch_token_id = switch_token_id
        self.eos_token_id = eos_token_id
        self.tokens_per_skin = tokens_per_skin
        self.init = init

    def __call__(self, input_ids: Tensor, scores: FloatTensor) -> FloatTensor:
        # input_ids shape: (batch_size, seq_len)
        for batch_idx, sequence in enumerate(input_ids):
            mask = torch.full_like(scores[batch_idx], float('-inf'))
            # Prepend the prompt tokens so phase/length checks see the full sequence.
            sequence = torch.cat([self.init, sequence])
            length = len(sequence)
            if self.switch_token_id in sequence:
                # Skin phase: only skin-vocab ids (all > switch_token_id) are allowed.
                mask[self.switch_token_id:] = 0
                where = torch.where(sequence == self.switch_token_id)[0][:1]
                J = self.tokenizer.bones_in_sequence(ids=sequence.detach().cpu().numpy())
                if (length - where) == J * self.tokens_per_skin:
                    # All skin tokens emitted: force the global EOS.
                    mask[:] = float('-inf')
                    mask[self.eos_token_id] = 0
                else:
                    mask[self.eos_token_id] = float('-inf')
            else:
                # Skeleton phase: the tokenizer reports which ids are valid next.
                tokens = self.tokenizer.next_posible_token(ids=sequence.detach().cpu().numpy())
                mask[tokens] = 0
            scores[batch_idx] = scores[batch_idx] + mask
        return scores
class TokenRig(ModelSpec):
def __init__(self, model_config, transform_config, tokenizer_config=None):
assert tokenizer_config is not None
super().__init__(model_config=model_config, transform_config=transform_config, tokenizer_config=tokenizer_config)
cfg = self.model_config
self.tokens_per_skin: int = cfg['tokens_per_skin']
self.tokens_skin_cond: int = cfg['tokens_skin_cond']
self.use_rope: bool = cfg.get('use_rope', True)
self.encode_repeat: int = cfg.get('encode_repeat', 4)
self.skin_warmup_start_epoch: int = cfg.get('skin_warmup_start_epoch', 0)
self.skin_warmup_end_epoch: int = cfg.get('skin_warmup_end_epoch', -1)
        # Frozen pretrained skin VAE; used only to encode/decode skin codes.
        self.vae = SkinVAEModel.load_from_system_checkpoint(cfg['pretrained_vae']).to(torch.bfloat16)
        for param in self.vae.parameters():
            param.requires_grad_(False)
        self.vae.eval()
        self.mesh_encoder = get_mesh_encoder(**cfg['mesh_encoder'])
        assert isinstance(
            self.mesh_encoder,
            (MAP_MESH_ENCODER.michelangelo, MAP_MESH_ENCODER.michelangelo_encoder),
        )
        self.mesh_encoder = self.mesh_encoder.to(torch.bfloat16)
self.tokenizer: Tokenizer = get_tokenizer(**tokenizer_config)
        # Combined vocabulary: tokenizer codebook, then FSQ VAE codebook, then a global EOS.
        self.vocab_size = self.tokenizer.vocab_size + self.vae.vocab_size + 1
        self.eos = self.vocab_size - 1
        _d = cfg['llm'].copy()
        self.hidden_size = _d['hidden_size']
        _d['vocab_size'] = self.vocab_size
        if LLM_LOCAL_DIR.exists():
            _d['pretrained_model_name_or_path'] = str(LLM_LOCAL_DIR)
        llm_config = AutoConfig.from_pretrained(**_d)
        llm_config.torch_dtype = torch.bfloat16
        llm_config.pre_norm = True
        self.llm_config = llm_config
        self.transformer = AutoModelForCausalLM.from_config(
            config=llm_config, attn_implementation="flash_attention_2"
        ).to(torch.bfloat16)
self.output_proj = nn.Sequential(
nn.Linear(self.mesh_encoder.width, self.hidden_size),
nn.RMSNorm(self.hidden_size),
).to(torch.bfloat16)
init_scale = cfg.get('init_scale', None)
if init_scale is not None:
self.initialize_weights(init_scale)
def compile_model(self):
self.vae.compile_model()
self.transformer = torch.compile(self.transformer, dynamic=False)
self.mesh_encoder = torch.compile(self.mesh_encoder, dynamic=False)
def initialize_weights(self, s: float):
        def init_linear(layer, stddev):
            nn.init.normal_(layer.weight, std=stddev)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)
init_scale = s * math.sqrt(1.0 / self.mesh_encoder.width)
for m in self.mesh_encoder.modules():
if isinstance(m, nn.Linear):
init_linear(m, stddev=init_scale)
init_scale = s * math.sqrt(1.0 / self.hidden_size)
for m in self.output_proj.modules():
if isinstance(m, nn.Linear):
init_linear(m, stddev=init_scale)
    def get_skin_warmup_rate(self, steps_per_epoch: int) -> float:
        """Cosine schedule ramping from 0 to 1 across the epochs
        [skin_warmup_start_epoch, skin_warmup_end_epoch]."""
        if self.current_epoch < self.skin_warmup_start_epoch:
            return 0.
        if self.current_epoch > self.skin_warmup_end_epoch:
            return 1.
        start_steps = self.skin_warmup_start_epoch * steps_per_epoch
        end_steps = (self.skin_warmup_end_epoch + 1) * steps_per_epoch
        rate = (self.global_step - start_steps) / (end_steps - start_steps)
        return min(max((1.0 - math.cos(math.pi * rate)) / 2, 0), 1)
    @torch.autocast(device_type='cuda', dtype=torch.bfloat16)
    def training_step(self, batch: Dict) -> Dict:
        # Training is not included in this release; the model is inference-only here.
        raise NotImplementedError()
    def make_start_tokens(self, **kwargs) -> List[List[int]]:
        """Build per-sample prompt token lists: either the given skeleton
        tokens, or [bos] plus a class/joint-count/parents header."""
        skeleton_tokens = kwargs.get('skeleton_tokens')
        skeleton_mask = kwargs.get('skeleton_mask')
        num_joints = kwargs.get('num_joints')
        parents = kwargs.get('parents')
        cls = kwargs.get('cls')
        start_tokens_list = []
        if skeleton_tokens is not None:
            batch_size = len(skeleton_tokens)
        elif cls is not None:
            batch_size = len(cls)
        elif num_joints is not None:
            batch_size = len(num_joints)
        elif parents is not None:
            batch_size = len(parents)
        else:
            raise ValueError("must provide one of skeleton_tokens, cls, num_joints, parents")
for i in range(batch_size):
if skeleton_tokens is not None:
_skeleton_tokens = skeleton_tokens[i]
_skeleton_mask = skeleton_mask[i] if skeleton_mask is not None else None
assert _skeleton_tokens[0] == self.tokenizer.bos
if skeleton_mask is not None:
start_tokens = _skeleton_tokens[_skeleton_mask==1]
else:
start_tokens = _skeleton_tokens
else:
start_tokens = [self.tokenizer.bos]
start_tokens += self.tokenizer.make_cls_head(
cls=cls[i] if cls is not None else None,
num_joints=num_joints[i] if num_joints is not None else None,
parents=parents[i] if parents is not None else None,
)
if isinstance(start_tokens, Tensor):
start_tokens = start_tokens.detach().cpu().numpy().tolist()
start_tokens_list.append(start_tokens)
return start_tokens_list
@torch.autocast(device_type='cuda', dtype=torch.bfloat16)
def generate(
self,
vertices: Tensor,
normals: Tensor,
cls: str|None=None,
skeleton_tokens: np.ndarray|Tensor|None=None,
only_ids: bool=False,
return_decode_dict: bool=False,
num_joints: int|None=None,
parents: Tensor|None=None,
**kwargs,
) -> TokenRigResult:
"""
Do not support batch!
"""
assert isinstance(self.vae.model, SkinFSQCVAEModel)
assert vertices.dim() == 2, 'do not support batch'
assert normals.dim() == 2, 'do not support batch'
if isinstance(skeleton_tokens, np.ndarray):
skeleton_tokens = torch.from_numpy(skeleton_tokens).to(self.device)
cond = torch.cat([vertices, normals], dim=-1).unsqueeze(0)
_, cond_latents = self.vae.model._encode(
x=None,
cond=cond,
num_tokens=self.tokens_per_skin,
cond_tokens=self.tokens_skin_cond,
return_z=False,
)
assert cond_latents is not None
# (1, len, dim)
learned_mesh_cond = encode_mesh_cond(self.mesh_encoder, self.output_proj, self.tokens_skin_cond, {'vertices': vertices, 'normals': normals})
device = cond.device
start_tokens = torch.tensor(self.make_start_tokens(
device=device,
cls=None if cls is None else [cls],
skeleton_tokens=None if skeleton_tokens is None else [skeleton_tokens],
num_joints=None if num_joints is None else [num_joints],
parents=None if parents is None else [parents],
)[0], device=device).unsqueeze(0)
assert start_tokens.shape[0] == 1
start_embed = self.transformer.get_input_embeddings()(start_tokens)
inputs_embeds = torch.cat([learned_mesh_cond, start_embed], dim=1)
results = self.transformer.generate(
inputs_embeds=inputs_embeds,
bos_token_id=self.tokenizer.bos,
eos_token_id=self.eos,
pad_token_id=self.tokenizer.pad,
logits_processor=get_logits_processor(
tokenizer=self.tokenizer,
eos=self.eos,
tokens_per_skin=self.tokens_per_skin,
start_tokens=start_tokens[0],
),
**kwargs,
)
        res = TokenRigResult()
        output_ids = results[0, :]
        # Prepend the prompt tokens so output_ids holds the full sequence.
        for token in reversed(start_tokens[0]):
            output_ids = pad(output_ids, (1, 0), value=token.item())
        res.input_ids = start_tokens[0]
        res.output_ids = output_ids
if only_ids:
return res
res.cond = cond[0]
res.cond_latents = cond_latents[0]
if return_decode_dict:
return res
d = decode(
cond=cond[0],
cond_latents=cond_latents[0],
inputs_ids=output_ids,
tokenizer=self.tokenizer,
tokens_per_skin=self.tokens_per_skin,
vae=self.vae,
)
res.skin_pred = d['skin_pred']
res.detokenize_output = d['detokenize_output']
return res
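    # Usage sketch (hypothetical values; `model` is a loaded TokenRig and
    # `vertices` / `normals` are (N, 3) tensors on the model's device;
    # extra kwargs are forwarded to transformers' generate()):
    #   res = model.generate(vertices=vertices, normals=normals,
    #                        cls='character', max_new_tokens=8192)
    #   skin = res.skin_pred          # (N, J) predicted per-vertex skin weights
    #   skel = res.detokenize_output  # joints, parents, joint_names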
def _debug_export(
self,
batch: Dict,
cond: Tensor,
cond_latents: Tensor,
inputs_ids: Tensor,
id: int=0,
path: str='res.fbx',
):
if inputs_ids.dim() == 2:
assert cond_latents.dim() == cond.dim() == 3, f"Expected 3 dimensions, got {cond_latents.dim()}, {cond.dim()}"
cond = cond[id]
cond_latents = cond_latents[id]
inputs_ids = inputs_ids[id]
res = decode(
cond=cond,
cond_latents=cond_latents,
inputs_ids=inputs_ids,
tokenizer=self.tokenizer,
tokens_per_skin=self.tokens_per_skin,
vae=self.vae,
)
detokenize_output: DetokenizeOutput = res['detokenize_output']
origin_asset: Asset = batch['model_input'][id].asset
asset = Asset.from_data(
vertices=origin_asset.vertices,
faces=origin_asset.faces,
sampled_vertices=batch['vertices'][id].detach().cpu().numpy(),
sampled_skin=res['skin_pred'].detach().cpu().numpy(),
parents=np.array(detokenize_output.parents),
joint_names=detokenize_output.joint_names,
joints=detokenize_output.joints,
)
from ..rig_package.parser.bpy import BpyParser
BpyParser.export_asset(asset, filepath=path)
    def process_fn(self, batch: List[ModelInput]) -> List[Dict]:
        max_length = 0
        for b in batch:
            if b.tokens is not None:
                max_length = max(max_length, b.tokens.shape[0])
        res = []
for b in batch:
if b.tokens is not None:
skeleton_tokens = np.pad(b.tokens, ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=self.tokenizer.pad)
skeleton_mask = np.pad(np.ones_like(b.tokens), ((0, max_length-b.tokens.shape[0])), 'constant', constant_values=0)
else:
skeleton_tokens = None
skeleton_mask = None
_d = {
'vertices': torch.from_numpy(b.asset.sampled_vertices).float(),
'normals': torch.from_numpy(b.asset.sampled_normals).float(),
'non': {
'cls': b.asset.cls,
}
}
if skeleton_mask is not None:
_d.update({
'skeleton_tokens': skeleton_tokens,
'skeleton_mask': skeleton_mask,
})
_d['non'].update({
'parents': b.asset.parents,
'num_bones': b.asset.J,
})
if b.asset.sampled_vertex_groups is not None and 'skin' in b.asset.sampled_vertex_groups:
assert b.asset.meta is not None
_d['non'].update({
'cls': b.asset.cls,
'uniform_skin': torch.from_numpy(b.asset.sampled_vertex_groups['skin']).float(),
'skin_samples': b.asset.skin_samples,
'dense_indices': b.asset.meta['dense_indices'],
'dense_skin': torch.from_numpy(b.asset.meta['dense_skin']).float(),
'dense_vertices': torch.from_numpy(b.asset.meta['dense_vertices']).float(),
'dense_normals': torch.from_numpy(b.asset.meta['dense_normals']).float(),
})
res.append(_d)
return res
def predict_step(
self,
batch: Dict,
no_cls: bool=False,
skeleton_tokens=None,
parents=None,
num_joints=None,
make_asset: bool=False,
**kwargs
) -> Dict:
vertices: Tensor = batch['vertices']
normals : Tensor = batch['normals']
cls = batch['cls']
generate_kwargs = deepcopy(batch['generate_kwargs'])
if vertices.dim() == 2:
vertices = vertices.unsqueeze(0)
normals = normals.unsqueeze(0)
results = []
if skeleton_tokens is None:
skeleton_tokens = [None] * vertices.shape[0]
d = {}
for i in range(vertices.shape[0]):
res = self.generate(
vertices=vertices[i],
normals=normals[i],
skeleton_tokens=skeleton_tokens[i],
cls=None if no_cls else cls[i],
parents=None if parents is None else parents[i],
num_joints=None if num_joints is None else num_joints[i],
**generate_kwargs
)
if make_asset:
assert 'model_input' in batch, "need model_input to make asset (in validate/predict mode)"
assert res.detokenize_output is not None
assert res.skin_pred is not None
asset: Asset = batch['model_input'][i].asset.copy()
res.asset = Asset.from_data(
vertices=asset.vertices,
faces=asset.faces,
sampled_vertices=vertices[i].detach().float().cpu().numpy(),
sampled_skin=res.skin_pred.detach().float().cpu().numpy(),
joints=res.detokenize_output.joints,
parents=np.array(res.detokenize_output.parents),
cls=asset.cls,
path=asset.path,
)
results.append(res)
d['results'] = results
return d
def forward(self, batch: Dict) -> Dict[str, Tensor]:
return self.training_step(batch=batch)
def _check(x: Tensor, s, m=None):
    """Assert that tensor x has shape s; None or -1 entries match any size."""
    assert isinstance(s, (list, tuple)), "Expected shape must be a list or tuple"
    assert x.dim() == len(s), f"Expected {len(s)} dims, got {x.dim()}"
    for i, (dim_actual, dim_expected) in enumerate(zip(x.shape, s)):
        if dim_expected is not None and dim_expected != -1:
            suffix = "" if m is None else f". Message: {m}"
            assert dim_actual == dim_expected, f"Shape mismatch at dim {i}: expected {dim_expected}, got {dim_actual}{suffix}"
def encode_mesh_cond(mesh_encoder, output_proj, tokens_skin_cond, batch: Dict) -> Tensor:
    """Encode sampled vertices/normals with the mesh encoder and project the
    latents to the LLM hidden size. tokens_skin_cond is unused here."""
    vertices = batch['vertices']  # (B, N, 3) or (N, 3)
    normals = batch['normals']    # (B, N, 3) or (N, 3)
    assert isinstance(vertices, Tensor)
    assert isinstance(normals, Tensor)
    if vertices.dim() == 2:
        vertices = vertices.unsqueeze(0)
        normals = normals.unsqueeze(0)
    shape_embed, latents, token_num, pre_pc = mesh_encoder.encode_latents(pc=vertices, feats=normals)  # type: ignore
    latents = output_proj(latents)
    return latents
@torch.no_grad()
def encode(
tokenizer: Tokenizer,
vae: SkinVAEModel,
vae_input: VaeInput,
encode_repeat: int,
tokens_skin_cond: int,
tokens_per_skin: int,
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Returns:
skin_tokens: (B, tokens_per_skin*J)
cond_latents: (B, tokens_skin_cond, vae.latent_channels)
skin_mask: (B, tokens_per_skin*J), 1 -> skin, 0 -> pad
"""
device = vae_input.uniform_cond.device
B = vae_input.B
J = vae_input.max_J
_, cond_latents, codes, _ = vae.encode(vae_input=vae_input, num_tokens=tokens_per_skin, full=True, encode_repeat=encode_repeat)
    codes = codes[:, :tokens_per_skin]
    indices = vae_input.get_flatten_indices()
    skin_tokens = torch.full((B, J * tokens_per_skin), tokenizer.pad, dtype=torch.long, device=device)
    skin_mask = torch.zeros_like(skin_tokens, dtype=torch.long)
    # Scatter the flattened per-joint codes back into per-sample rows,
    # shifted by tokenizer.vocab_size into the skin-vocabulary range.
    j_counters = [0 for _ in range(B)]
    for idx, batch_id in enumerate(indices):
        j = j_counters[batch_id]
        s = j * tokens_per_skin
        t = s + tokens_per_skin
        skin_tokens[batch_id, s:t] = codes[idx] + tokenizer.vocab_size
        skin_mask[batch_id, s:t] = 1
        j_counters[batch_id] += 1
assert cond_latents is not None
_check(cond_latents, (B, tokens_skin_cond, vae.latent_channels))
_check(skin_tokens, (B, J * tokens_per_skin))
_check(skin_mask, (B, J * tokens_per_skin))
return skin_tokens, cond_latents, skin_mask
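# Note on the combined vocabulary: skeleton ids occupy [0, tokenizer.vocab_size),
# skin ids occupy [tokenizer.vocab_size, tokenizer.vocab_size + vae.vocab_size),
# and the last id is the global eos (see TokenRig.__init__); decode() subtracts
# tokenizer.vocab_size to recover raw FSQ code indices.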
def prepare_llm_tokens(
tokenizer: Tokenizer,
eos: int,
skeleton_tokens: Tensor,
skeleton_mask: Tensor,
skin_tokens: Tensor,
skin_mask: Tensor,
cond_latents: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Args:
skeleton_tokens: (B, n)
skeleton_mask: (B, n)
skin_tokens: (B, tokens_per_skin*J)
skin_mask: (B, tokens_per_skin*J)
cond_latents: (B, tokens_skin_cond, vae.latent_channels)
Returns:
llm_tokens: (B, seq_len)
attention_mask: (B, seq_len), 1 -> attend, 0 -> pad
"""
    B = skeleton_tokens.shape[0]
    inputs_ids = torch.full(
        (B, skeleton_tokens.shape[1] + skin_tokens.shape[1] + 1),
        tokenizer.pad, dtype=torch.long, device=skeleton_tokens.device,
    )
num_skeleton = skeleton_mask.sum(dim=1)
num_skin = skin_mask.sum(dim=1)
attention_mask = torch.ones((B, inputs_ids.shape[1]), dtype=torch.float32, device=skeleton_tokens.device)
llm_skeleton_mask = torch.ones_like(attention_mask, dtype=torch.bool)
llm_skin_mask = torch.ones_like(attention_mask, dtype=torch.bool)
for i in range(B):
length = num_skeleton[i] + num_skin[i]
inputs_ids[i, :num_skeleton[i]] = skeleton_tokens[i, :num_skeleton[i]]
inputs_ids[i, num_skeleton[i]:num_skeleton[i]+num_skin[i]] = skin_tokens[i, :num_skin[i]]
inputs_ids[i, num_skeleton[i]+num_skin[i]] = eos # add an eos
attention_mask[i, length+1:] = 0.
llm_skeleton_mask[i, num_skeleton[i]:] = 0
llm_skin_mask[i, :num_skeleton[i]] = 0
llm_skin_mask[i, length+1:] = 0
seq_len = inputs_ids.shape[1]
_check(inputs_ids, (B, seq_len))
_check(attention_mask, (B, seq_len))
return inputs_ids, attention_mask, llm_skeleton_mask, llm_skin_mask
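# Per-sample sequence layout produced above (sketch):
#   [skeleton tokens ...][skin tokens ...][eos][pad ...]
# llm_skeleton_mask marks the first block; llm_skin_mask marks the skin
# tokens plus the eos position.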
def get_logits_processor(tokenizer: Tokenizer, eos: int, tokens_per_skin: int, start_tokens):
processor = VocabSwitchingLogitsProcessor(
tokenizer=tokenizer,
switch_token_id=tokenizer.eos,
eos_token_id=eos,
tokens_per_skin=tokens_per_skin,
init=start_tokens,
)
return LogitsProcessorList([processor])
@torch.no_grad()
def decode(
cond: Tensor,
cond_latents: Tensor,
inputs_ids: Tensor,
tokenizer: Tokenizer,
tokens_per_skin: int,
vae: SkinVAEModel,
encode_repeat: int=1,
) -> Dict:
"""
inputs_ids: (seq_len)
cond: (N, c)
cond_latents: (tokens_skin_cond, dim)
"""
assert cond.dim() == 2, 'do not support batch'
assert cond_latents.dim() == 2, 'do not support batch'
    where_eos = torch.where(inputs_ids == tokenizer.eos)
    if where_eos[0].shape[0] == 0:
        raise ValueError("No EOS token found in inputs_ids")
    where_eos = where_eos[0][:1]
    # Everything up to the tokenizer eos is the skeleton; the rest is skin.
    skeleton_tokens = inputs_ids[:where_eos + 1].detach().cpu().numpy()
    detokenize_output = tokenizer.detokenize(ids=skeleton_tokens)
    J = detokenize_output.joints.shape[0]
    skin_tokens = inputs_ids[where_eos + 1:where_eos + 1 + J * tokens_per_skin]
    if skin_tokens.shape != (J * tokens_per_skin,):
return {
'skin_pred': None,
'detokenize_output': detokenize_output,
}
cond = cond.unsqueeze(0)
cond_latents = cond_latents.unsqueeze(0)
    skin = []
    # Decode in chunks of encode_repeat joints to bound peak memory.
    g = tokens_per_skin * encode_repeat
    for s in range(0, J * tokens_per_skin, g):
        t = min(s + g, J * tokens_per_skin)
        # Shift ids from the combined vocabulary back to raw FSQ code indices.
        indices = skin_tokens[s:t].unsqueeze(0) - tokenizer.vocab_size
        b = (t - s) // tokens_per_skin  # joints in this chunk
        # z: (b, tokens_per_skin, dim)
        z = vae.model.FSQ.indices_to_codes(indices).reshape(b, tokens_per_skin, -1)
        # logits: (b, n, 1) -> skin_pred: (n, b)
        logits = vae.decode(z=z, sampled_cond=cond.repeat(b, 1, 1), cond_tokens=cond_latents.repeat(b, 1, 1))
        skin_pred = logits.reshape(b, logits.shape[1]).permute(1, 0)
        skin.append(skin_pred)
skin = torch.concat(skin, dim=1).float()
return {
'skin_pred': skin,
'detokenize_output': detokenize_output,
}
@torch.no_grad()
def decode_multi(
cond: Tensor,
cond_latents: Tensor,
inputs_ids: List[Tensor],
tokenizer: Tokenizer,
tokens_per_skin: int,
vae: SkinVAEModel,
is_numpy: bool=True,
encode_repeat: int=1,
) -> List[Dict]:
"""
inputs_ids: List[(seq_len)]
cond: (N, c)
cond_latents: (tokens_skin_cond, dim)
"""
assert cond.dim() == 2, 'do not support batch'
assert cond_latents.dim() == 2, 'do not support batch'
B = len(inputs_ids)
res = [{'skin_pred': None, 'detokenize_output': None} for _ in range(B)]
device = cond.device
batch_mapping = []
skin_tokens_list = []
oks = []
oks_J = []
for i in range(B):
where_eos = torch.where(inputs_ids[i] == tokenizer.eos)
if where_eos[0].shape[0] == 0:
print(f"decode_multi: {i} has bad skeleton")
continue
        where_eos = where_eos[0][:1]
        skeleton_tokens = inputs_ids[i][:where_eos + 1].detach().cpu().numpy()
try:
detokenize_output = tokenizer.detokenize(ids=skeleton_tokens)
except Exception as e:
print(f"decode_multi: error while decoding skeleton: {str(e)}")
continue
J = detokenize_output.joints.shape[0]
res[i]['detokenize_output'] = detokenize_output # type: ignore
skin_tokens = inputs_ids[i][where_eos+1:where_eos+1+J*tokens_per_skin]
if skin_tokens.shape != (J*(tokens_per_skin),):
print(f"decode_multi: {i} has bad skin")
continue
batch_mapping.append(torch.full((J,), i, device=device, dtype=torch.long))
skin_tokens_list.append(skin_tokens)
oks.append(i)
oks_J.append(J)
if len(batch_mapping) == 0:
return res
batch_mapping = torch.cat(batch_mapping, dim=0)
# (1, sum_J*tokens_per_skin)
skin_tokens = torch.cat(skin_tokens_list, dim=0).unsqueeze(0)
cond = cond.unsqueeze(0)
cond_latents = cond_latents.unsqueeze(0)
    skin_list = []
    # Decode all collected joints in chunks, mirroring decode().
    g = tokens_per_skin * encode_repeat
    sum_J = batch_mapping.shape[0]
for s in range(0, sum_J*tokens_per_skin, g):
t = min(s+g, sum_J*tokens_per_skin)
# (1, m*tokens_per_skin)
indices = skin_tokens[:, s:t] - tokenizer.vocab_size
# expect: (m, tokens_per_skin, dim)
m = (t-s)//tokens_per_skin
z = vae.model.FSQ.indices_to_codes(indices).reshape(m, tokens_per_skin, -1)
# (m, n, 1)
logits = vae.decode(z=z, sampled_cond=cond.repeat(m, 1, 1), cond_tokens=cond_latents.repeat(m, 1, 1))
skin_pred = logits.reshape(m, logits.shape[1]).permute(1, 0)
skin_list.append(skin_pred)
skin = torch.concat(skin_list, dim=1).float()
    for i, sample_id in enumerate(oks):
        skin_pred = skin[:, batch_mapping == sample_id].reshape(-1, oks_J[i])
        res[sample_id]['skin_pred'] = skin_pred.detach().cpu().numpy() if is_numpy else skin_pred
return res
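# Usage sketch (hypothetical; candidate sequences generated for the same mesh
# via TokenRig.generate(..., return_decode_dict=True) so cond / cond_latents match):
#   outs = decode_multi(cond=res0.cond, cond_latents=res0.cond_latents,
#                       inputs_ids=[r.output_ids for r in candidates],
#                       tokenizer=model.tokenizer,
#                       tokens_per_skin=model.tokens_per_skin, vae=model.vae)
#   skin = outs[0]['skin_pred']  # (N, J) numpy array, or None on failure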