BiliSakura committed
Commit cf85d6b (verified)
1 Parent(s): 62cb9d5

Update all files for BitDance-ImageNet-diffusers

BitDance_B_16x/modeling_autoencoder.py ADDED
@@ -0,0 +1,43 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any, Dict
+
+ import torch
+ from safetensors.torch import load_file as load_safetensors
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.models.modeling_utils import ModelMixin
+
+ try:
+     from .transformer.qae import VQModel
+ except ImportError:  # pragma: no cover
+     from transformer.qae import VQModel
+
+
+ class BitDanceImageNetAutoencoder(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(self, ddconfig: Dict[str, Any], num_codebooks: int = 4):
+         super().__init__()
+         self.runtime_model = VQModel(ddconfig, num_codebooks)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
+         del args, kwargs
+         model_dir = Path(pretrained_model_name_or_path)
+         config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
+         model = cls(ddconfig=config["ddconfig"], num_codebooks=int(config.get("num_codebooks", 4)))
+         state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
+         model.runtime_model.load_state_dict(state, strict=True)
+         model.eval()
+         return model
+
+     def encode(self, x: torch.Tensor):
+         return self.runtime_model.encode(x)
+
+     def decode(self, z: torch.Tensor):
+         return self.runtime_model.decode(z)
+
+     def forward(self, z: torch.Tensor):
+         return self.decode(z)
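
For context, a minimal usage sketch of the wrapper added above, assuming a local checkout of this repository laid out as in the commit. The import path, directory argument, batch size, and latent shape below are illustrative assumptions, and the exact value returned by encode() is defined by VQModel in transformer/qae.py, which is not part of this file.

import torch

from BitDance_B_16x.modeling_autoencoder import BitDanceImageNetAutoencoder

# Directory holding config.json and diffusion_pytorch_model.safetensors (assumed local path).
autoencoder = BitDanceImageNetAutoencoder.from_pretrained("BitDance_B_16x")

with torch.no_grad():
    # forward() simply delegates to decode(); the latent shape here is a placeholder
    # and in practice is determined by ddconfig in config.json.
    z = torch.randn(1, 4, 16, 16)
    image = autoencoder(z)

    # encode() is also exposed and forwards to VQModel.encode; its return structure
    # (e.g. quantized latents plus auxiliary outputs) is defined by that class.
    # latents = autoencoder.encode(torch.randn(1, 3, 256, 256))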