SkinTokens / src /server /spec.py
pookiefoof's picture
Public release: SkinTokens 路 TokenRig demo
9d7cf7f
from dataclasses import dataclass
from torch import Tensor
from typing import Dict, Optional, List, Tuple
import io
import os
import torch
from ..rig_package.info.asset import Asset
from ..model.tokenrig import TokenRig
PORT = 59875
SERVER = f"http://localhost:{PORT}"
TMP_CKPT_DIR = "./tmp_ckpt"
BPY_PORT = 59876
BPY_SERVER = f"http://localhost:{BPY_PORT}"
@dataclass
class TensorPacket:
"""make sure stays on cpu"""
validate: bool=False
know_skeleton: bool=False
learned_mesh_cond: Optional[Tensor]=None
cond_latents: Optional[Tensor]=None
mesh_cond: Optional[Tensor]=None
vertices: Optional[Tensor]=None
assets: Optional[List[Asset]]=None
output_ids: Optional[Tensor]=None
start_embed_list: Optional[List[Tensor]]=None
start_tokens_list: Optional[List[List[int]]]=None
def to_device(self, device):
if self.learned_mesh_cond is not None:
self.learned_mesh_cond = self.learned_mesh_cond.to(device)
if self.cond_latents is not None:
self.cond_latents = self.cond_latents.to(device)
if self.mesh_cond is not None:
self.mesh_cond = self.mesh_cond.to(device)
if self.vertices is not None:
self.vertices = self.vertices.to(device)
if self.output_ids is not None:
self.output_ids = self.output_ids.to(device)
if self.start_embed_list is not None:
self.start_embed_list = [x.to(device) for x in self.start_embed_list]
@property
def B(self):
assert self.learned_mesh_cond is not None
return self.learned_mesh_cond.shape[0]
def to_bytes(self):
return object_to_bytes(self)
@classmethod
def from_bytes(cls, bytes) -> 'TensorPacket':
return bytes_to_object(bytes)
def object_to_bytes(t):
buffer = io.BytesIO()
torch.save(t, buffer)
return buffer.getvalue()
def bytes_to_object(b, map_location=None):
return torch.load(io.BytesIO(b), weights_only=False, map_location=map_location)
def get_model(
ckpt_path: str,
hf_path: Optional[str]=None,
device='cuda',
) -> TokenRig:
model = TokenRig.load_from_system_checkpoint(checkpoint_path=ckpt_path)
if hf_path is not None:
from transformers import AutoModel
a = AutoModel.from_pretrained(
hf_path,
local_files_only=True,
_attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
model.transformer.model.load_state_dict(a.state_dict())
model = model.to(device)
return model