|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict |
|
|
|
|
|
import torch |
|
|
from safetensors.torch import load_file as load_safetensors |
|
|
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
|
|
|
try: |
|
|
from .transformer.qae import VQModel |
|
|
except ImportError: |
|
|
from transformer.qae import VQModel |
|
|
|
|
|
|
|
|
class BitDanceImageNetAutoencoder(ModelMixin, ConfigMixin):
    """Diffusers-compatible wrapper around the project's ``VQModel``.

    The wrapped network is stored as ``self.runtime_model``. ``encode`` and
    ``decode`` delegate to it directly, and calling the module (``forward``)
    decodes. Constructor arguments are recorded in the diffusers config via
    ``@register_to_config``.
    """

    @register_to_config
    def __init__(self, ddconfig: Dict[str, Any], num_codebooks: int = 4):
        """Build the wrapped ``VQModel``.

        Args:
            ddconfig: Architecture config forwarded unchanged to ``VQModel``
                (the keys it expects are defined by ``VQModel``, not here).
            num_codebooks: Forwarded as ``VQModel``'s second positional
                argument.
        """
        super().__init__()
        self.runtime_model = VQModel(ddconfig, num_codebooks)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load the model from a local checkpoint directory.

        The directory must contain:

        * ``config.json`` with a ``"ddconfig"`` entry and an optional
          ``"num_codebooks"`` entry (defaults to 4).
        * ``diffusion_pytorch_model.safetensors`` whose keys match
          ``runtime_model``'s state dict exactly (loaded with ``strict=True``).

        Only a subset of the usual diffusers loading kwargs is honored:

        * ``subfolder``: optional sub-directory of
          ``pretrained_model_name_or_path`` that holds the files above.
        * ``torch_dtype``: if a ``torch.dtype``, the loaded model is cast to
          it. Other values (e.g. strings) are ignored.

        All other positional and keyword arguments are accepted for
        call-compatibility with ``ModelMixin.from_pretrained`` but ignored.
        Only local paths are supported; there is no hub download here.

        Returns:
            The loaded model, switched to eval mode.

        Raises:
            FileNotFoundError: If ``config.json`` or the weights file is missing.
            KeyError: If ``config.json`` has no ``"ddconfig"`` entry.
            RuntimeError: From ``load_state_dict(strict=True)`` on a key or
                shape mismatch.
        """
        del args  # extra positional arguments are not used
        subfolder = kwargs.get("subfolder")
        torch_dtype = kwargs.get("torch_dtype")

        model_dir = Path(pretrained_model_name_or_path)
        if subfolder:
            model_dir = model_dir / subfolder

        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            ddconfig=config["ddconfig"],
            num_codebooks=int(config.get("num_codebooks", 4)),
        )

        state = load_safetensors(str(model_dir / "diffusion_pytorch_model.safetensors"))
        # Weights are stored for the inner VQModel, so load into it rather
        # than into the wrapper (whose keys would carry a "runtime_model." prefix).
        model.runtime_model.load_state_dict(state, strict=True)
        # Parameters now hold copies; drop the loaded dict before any cast
        # allocates new tensors, to keep peak memory down.
        del state

        if isinstance(torch_dtype, torch.dtype):
            # Honor the dtype a pipeline or user requested instead of
            # silently keeping the default init dtype.
            model = model.to(dtype=torch_dtype)

        model.eval()
        return model

    def encode(self, x: torch.Tensor):
        """Encode ``x`` with the wrapped ``VQModel``.

        The return value is whatever ``VQModel.encode`` returns (not visible
        in this file).
        """
        return self.runtime_model.encode(x)

    def decode(self, z: torch.Tensor):
        """Decode ``z`` with the wrapped ``VQModel``.

        The return value is whatever ``VQModel.decode`` returns.
        """
        return self.runtime_model.decode(z)

    def forward(self, z: torch.Tensor):
        """Module call: ``model(z)`` is equivalent to ``model.decode(z)``.

        Note that the input is treated as a latent or code, not as an image.
        """
        return self.decode(z)
|
|
|