Spaces:
Running on Zero
| from dataclasses import dataclass | |
| from torch import Tensor | |
| from typing import Dict, Optional, List, Tuple | |
| import io | |
| import os | |
| import torch | |
| from ..rig_package.info.asset import Asset | |
| from ..model.tokenrig import TokenRig | |
| PORT = 59875 | |
| SERVER = f"http://localhost:{PORT}" | |
| TMP_CKPT_DIR = "./tmp_ckpt" | |
| BPY_PORT = 59876 | |
| BPY_SERVER = f"http://localhost:{BPY_PORT}" | |
@dataclass
class TensorPacket:
    """Serializable bundle of model inputs/outputs passed between processes.

    Make sure tensors stay on CPU when the packet is serialized; use
    `to_device` to move the tensor fields for compute.

    Fixes vs. original: the `@dataclass` decorator was missing even though
    `dataclass` is imported and the class body is written in dataclass field
    style (without it, keyword construction raises TypeError); `from_bytes`
    takes `cls` but lacked `@classmethod`.
    """

    # Flags controlling downstream behavior — semantics not visible here;
    # NOTE(review): confirm meaning of `validate` / `know_skeleton` at call sites.
    validate: bool = False
    know_skeleton: bool = False
    # Optional tensor payloads; any of these may be absent (None).
    learned_mesh_cond: Optional[Tensor] = None
    cond_latents: Optional[Tensor] = None
    mesh_cond: Optional[Tensor] = None
    vertices: Optional[Tensor] = None
    # Project-level asset descriptors accompanying the tensors.
    assets: Optional[List[Asset]] = None
    output_ids: Optional[Tensor] = None
    start_embed_list: Optional[List[Tensor]] = None
    start_tokens_list: Optional[List[List[int]]] = None

    def to_device(self, device) -> None:
        """Move every non-None tensor field (and tensor-list field) to `device`, in place."""
        if self.learned_mesh_cond is not None:
            self.learned_mesh_cond = self.learned_mesh_cond.to(device)
        if self.cond_latents is not None:
            self.cond_latents = self.cond_latents.to(device)
        if self.mesh_cond is not None:
            self.mesh_cond = self.mesh_cond.to(device)
        if self.vertices is not None:
            self.vertices = self.vertices.to(device)
        if self.output_ids is not None:
            self.output_ids = self.output_ids.to(device)
        if self.start_embed_list is not None:
            self.start_embed_list = [x.to(device) for x in self.start_embed_list]

    def B(self) -> int:
        """Batch size, read from dim 0 of `learned_mesh_cond` (must be set)."""
        assert self.learned_mesh_cond is not None
        return self.learned_mesh_cond.shape[0]

    def to_bytes(self) -> bytes:
        """Serialize this packet via the module-level `object_to_bytes` helper."""
        return object_to_bytes(self)

    @classmethod
    def from_bytes(cls, bytes) -> 'TensorPacket':
        """Reconstruct a packet from bytes produced by `to_bytes`."""
        return bytes_to_object(bytes)
def object_to_bytes(t) -> bytes:
    """Serialize any torch-saveable object to raw bytes via ``torch.save``."""
    with io.BytesIO() as sink:
        torch.save(t, sink)
        # getvalue() must run before the buffer closes at the end of the block.
        return sink.getvalue()
def bytes_to_object(b, map_location=None):
    """Deserialize bytes produced by ``object_to_bytes`` back into an object.

    ``map_location`` is forwarded to ``torch.load`` to retarget tensor devices.
    Note: ``weights_only=False`` unpickles arbitrary objects — only use on
    trusted, locally-produced payloads.
    """
    stream = io.BytesIO(b)
    return torch.load(stream, map_location=map_location, weights_only=False)
def get_model(
    ckpt_path: str,
    hf_path: Optional[str] = None,
    device='cuda',
) -> TokenRig:
    """Load a TokenRig model from a system checkpoint and move it to ``device``.

    If ``hf_path`` is given, the transformer backbone's weights are replaced by
    a locally cached Hugging Face model loaded in bfloat16 with flash-attention-2.
    """
    model = TokenRig.load_from_system_checkpoint(checkpoint_path=ckpt_path)
    if hf_path is not None:
        # Imported lazily so transformers is only required on this path.
        from transformers import AutoModel
        hf_backbone = AutoModel.from_pretrained(
            hf_path,
            local_files_only=True,
            _attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        # Overwrite only the inner transformer weights; the rest of the
        # checkpointed model is kept as loaded.
        model.transformer.model.load_state_dict(hf_backbone.state_dict())
    return model.to(device)