from __future__ import annotations import json from pathlib import Path from typing import Any, Dict import torch from safetensors.torch import load_file as load_safetensors from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin try: from .transformer.qae import VQModel except ImportError: # pragma: no cover from transformer.qae import VQModel class BitDanceImageNetAutoencoder(ModelMixin, ConfigMixin): @register_to_config def __init__(self, ddconfig: Dict[str, Any], num_codebooks: int = 4): super().__init__() self.runtime_model = VQModel(ddconfig, num_codebooks) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs): del args, kwargs model_dir = Path(pretrained_model_name_or_path) config = json.loads((model_dir / "config.json").read_text(encoding="utf-8")) model = cls(ddconfig=config["ddconfig"], num_codebooks=int(config.get("num_codebooks", 4))) state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors") model.runtime_model.load_state_dict(state, strict=True) model.eval() return model def encode(self, x: torch.Tensor): return self.runtime_model.encode(x) def decode(self, z: torch.Tensor): return self.runtime_model.decode(z) def forward(self, z: torch.Tensor): return self.decode(z)