File size: 2,578 Bytes
9d7cf7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from dataclasses import dataclass
from torch import Tensor
from typing import Dict, Optional, List, Tuple

import io
import os
import torch

from ..rig_package.info.asset import Asset
from ..model.tokenrig import TokenRig

# Local HTTP endpoint for the model service (loopback only).
PORT = 59875
SERVER = f"http://localhost:{PORT}"
# Scratch directory for temporary checkpoints — presumably written/read by
# the serving code; TODO confirm against callers.
TMP_CKPT_DIR = "./tmp_ckpt"

# Separate local endpoint — "BPY" suggests a Blender (bpy) helper process;
# NOTE(review): confirm which service actually listens here.
BPY_PORT = 59876
BPY_SERVER = f"http://localhost:{BPY_PORT}"

@dataclass
class TensorPacket:
    """Container bundling tensors and metadata for transfer between services.

    Instances are pickled via :func:`object_to_bytes` / :func:`bytes_to_object`
    (``torch.save``/``torch.load`` under the hood), so all tensor fields should
    stay on CPU while serialized; call :meth:`to_device` after receiving.
    """
    validate: bool = False
    know_skeleton: bool = False
    learned_mesh_cond: Optional[Tensor] = None
    cond_latents: Optional[Tensor] = None
    mesh_cond: Optional[Tensor] = None
    vertices: Optional[Tensor] = None
    # Forward-reference string so the dataclass does not require Asset to be
    # importable at class-creation time; resolves identically for type checkers.
    assets: 'Optional[List[Asset]]' = None
    output_ids: Optional[Tensor] = None
    start_embed_list: Optional[List[Tensor]] = None
    start_tokens_list: Optional[List[List[int]]] = None

    # Plain-Tensor fields moved by to_device (unannotated, so NOT a dataclass
    # field). start_embed_list holds a list of tensors and is handled apart.
    _TENSOR_FIELDS = (
        "learned_mesh_cond",
        "cond_latents",
        "mesh_cond",
        "vertices",
        "output_ids",
    )

    def to_device(self, device):
        """Move every non-None tensor field onto *device*, in place."""
        for name in self._TENSOR_FIELDS:
            value = getattr(self, name)
            if value is not None:
                setattr(self, name, value.to(device))
        if self.start_embed_list is not None:
            self.start_embed_list = [x.to(device) for x in self.start_embed_list]

    @property
    def B(self):
        """Batch size: leading dimension of ``learned_mesh_cond`` (must be set)."""
        assert self.learned_mesh_cond is not None
        return self.learned_mesh_cond.shape[0]

    def to_bytes(self):
        """Serialize this packet to raw bytes (see module-level helper)."""
        return object_to_bytes(self)

    @classmethod
    def from_bytes(cls, bytes) -> 'TensorPacket':
        """Deserialize a packet produced by :meth:`to_bytes`.

        NOTE(review): parameter name shadows the ``bytes`` builtin; kept
        unchanged for keyword-argument compatibility with existing callers.
        """
        return bytes_to_object(bytes)


def object_to_bytes(t):
    """Serialize *t* to raw bytes via ``torch.save`` (pickle-based)."""
    with io.BytesIO() as sink:
        torch.save(t, sink)
        payload = sink.getvalue()
    return payload

def bytes_to_object(b, map_location=None):
    """Inverse of :func:`object_to_bytes`: reconstruct an object from bytes.

    NOTE(review): ``weights_only=False`` lets ``torch.load`` unpickle
    arbitrary Python objects — never feed this untrusted bytes.
    """
    stream = io.BytesIO(b)
    return torch.load(stream, weights_only=False, map_location=map_location)

def get_model(
    ckpt_path: str,
    hf_path: Optional[str]=None,
    device='cuda',
) -> TokenRig:
    """Build a TokenRig model from a system checkpoint.

    When *hf_path* is given, the inner transformer weights are replaced with
    a locally cached HuggingFace model's state dict (loaded in bfloat16 with
    flash-attention 2) before moving everything to *device*.
    """
    model = TokenRig.load_from_system_checkpoint(checkpoint_path=ckpt_path)
    if hf_path is not None:
        from transformers import AutoModel
        pretrained = AutoModel.from_pretrained(
            hf_path,
            local_files_only=True,
            _attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        model.transformer.model.load_state_dict(pretrained.state_dict())
    return model.to(device)