"""Diffusers-compatible wrapper exposing the BitDance VQ autoencoder via ModelMixin/ConfigMixin."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors.torch import load_file as load_safetensors

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# Support both package-relative and flat (script-style) imports of the VQ model.
try:
    from .transformer.qae import VQModel
except ImportError:  # pragma: no cover
    from transformer.qae import VQModel


class BitDanceImageNetAutoencoder(ModelMixin, ConfigMixin):
    """Thin wrapper that delegates encoding and decoding to the underlying VQModel."""

    @register_to_config
    def __init__(self, ddconfig: Dict[str, Any], num_codebooks: int = 4):
        super().__init__()
        # The wrapped VQ autoencoder that does the actual encode/decode work.
        self.runtime_model = VQModel(ddconfig, num_codebooks)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load config and weights from a local directory, bypassing the hub machinery."""
        del args, kwargs  # Extra diffusers loading options are intentionally ignored.
        model_dir = Path(pretrained_model_name_or_path)
        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(ddconfig=config["ddconfig"], num_codebooks=int(config.get("num_codebooks", 4)))
        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        model.runtime_model.load_state_dict(state, strict=True)
        model.eval()
        return model
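
    # Expected checkpoint layout (a sketch inferred from the loader above; the
    # exact keys inside "ddconfig" depend on VQModel and are left elided):
    #
    #   <pretrained_model_name_or_path>/
    #       config.json                           {"ddconfig": {...}, "num_codebooks": 4}
    #       diffusion_pytorch_model.safetensors   state dict for runtime_model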

    def encode(self, x: torch.Tensor):
        """Encode an image batch into the quantized latent space."""
        return self.runtime_model.encode(x)

    def decode(self, z: torch.Tensor):
        """Decode quantized latents back into image space."""
        return self.runtime_model.decode(z)

    def forward(self, z: torch.Tensor):
        # Calling the module directly decodes latents (the generation path).
        return self.decode(z)
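

# Usage sketch: a minimal round trip through the autoencoder. The checkpoint
# path "checkpoints/bitdance-ae" is a placeholder, the input resolution is
# illustrative, and the structure of encode()'s return value depends on the
# VQModel implementation (assumed here to be latents that decode() accepts).
if __name__ == "__main__":
    ae = BitDanceImageNetAutoencoder.from_pretrained("checkpoints/bitdance-ae")
    with torch.no_grad():
        images = torch.randn(1, 3, 256, 256)  # dummy batch; real size depends on ddconfig
        latents = ae.encode(images)  # assumed to return decodable latents
        recon = ae.decode(latents)   # ae(latents) would be equivalent via forward()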