from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from safetensors.torch import load_file as load_safetensors

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

if TYPE_CHECKING:
    # Never executed at runtime; these imports only keep the companion modules
    # visible to static tooling.
    from .model import BitDance_B as _BD_B_STD
    from .model import BitDance_H as _BD_H_STD
    from .model import BitDance_L as _BD_L_STD
    from .model_parallel import BitDance_B as _BD_B_PAR
    from .model_parallel import BitDance_H as _BD_H_PAR
    from .model_parallel import BitDance_L as _BD_L_PAR
    from .diff_head import DiffHead as _DiffHead
    from .diff_head_parallel import DiffHead as _DiffHeadParallel
    from .layers import TransformerBlock as _TB
    from .layers_parallel import TransformerBlock as _TBP
    from .qae import VQModel as _VQ
    from .gfq import GFQ as _GFQ
    from .sampling import euler_maruyama as _EM
    from .sampling_parallel import euler_maruyama as _EMP
    from .utils import patchify_raster as _PR


class BitDanceImageNetTransformer(ModelMixin, ConfigMixin):
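    """Diffusers-style wrapper around the BitDance class-conditional ImageNet model.

    The wrapper builds one of the BitDance backbones ("BitDance-B", "BitDance-L",
    "BitDance-H") from either the standard or the model-parallel implementation and
    delegates sampling and the forward pass to it, while `ConfigMixin`/`ModelMixin`
    supply the usual diffusers config and checkpoint plumbing.
    """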
    @register_to_config
    def __init__(
        self,
        architecture: str,
        parallel_num: int,
        resolution: int,
        down_size: int,
        latent_dim: int,
        num_classes: int,
        runtime_impl: str,
        parallel_mode: str = "patch",
        time_schedule: str = "logit_normal",
        time_shift: float = 1.0,
        p_std: float = 1.0,
        p_mean: float = 0.0,
    ):
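        """Build the requested BitDance backbone.

        `architecture` selects among "BitDance-B", "BitDance-L" and "BitDance-H";
        `runtime_impl` and `parallel_num` decide whether the model-parallel
        implementation is used. The remaining arguments are forwarded to the
        backbone constructor.
        """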
        super().__init__()

        # Constructor arguments shared by the standard and model-parallel backbones.
        kwargs = dict(
            resolution=resolution,
            down_size=down_size,
            patch_size=1,
            latent_dim=latent_dim,
            diff_batch_mul=4,
            cls_token_num=64,
            num_classes=num_classes,
            grad_checkpointing=False,
            trained_vae="",
            drop_rate=0.0,
            perturb_schedule="constant",
            perturb_rate=0.0,
            perturb_rate_max=0.3,
            time_schedule=time_schedule,
            time_shift=time_shift,
            P_std=p_std,
            P_mean=p_mean,
        )

        # Use the model-parallel implementation when requested explicitly or implied by parallel_num.
        if runtime_impl == "model_parallel.py" or parallel_num > 1:
            from .model_parallel import BitDance_B, BitDance_H, BitDance_L

            ctors = {"BitDance-B": BitDance_B, "BitDance-L": BitDance_L, "BitDance-H": BitDance_H}
            kwargs.update(parallel_num=parallel_num, parallel_mode=parallel_mode)
        else:
            from .model import BitDance_B, BitDance_H, BitDance_L

            ctors = {"BitDance-B": BitDance_B, "BitDance-L": BitDance_L, "BitDance-H": BitDance_H}

        self.runtime_model = ctors[architecture](**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
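        """Load a checkpoint from a local directory.

        The directory must contain `config.json` and
        `diffusion_pytorch_model.safetensors`; any extra positional or keyword
        arguments are ignored.
        """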
        del args, kwargs
        model_dir = Path(pretrained_model_name_or_path)
        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            architecture=config["architecture"],
            parallel_num=int(config["parallel_num"]),
            resolution=int(config["resolution"]),
            down_size=int(config["down_size"]),
            latent_dim=int(config["latent_dim"]),
            num_classes=int(config["num_classes"]),
            runtime_impl=config["runtime_impl"],
            parallel_mode=config.get("parallel_mode", "patch"),
            time_schedule=config.get("time_schedule", "logit_normal"),
            time_shift=float(config.get("time_shift", 1.0)),
            p_std=float(config.get("p_std", 1.0)),
            p_mean=float(config.get("p_mean", 0.0)),
        )
        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        model.runtime_model.load_state_dict(state, strict=True)
        model.eval()
        return model

    @torch.no_grad()
    def sample(
        self,
        class_ids: torch.Tensor,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
    ) -> torch.Tensor:
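        """Generate samples conditioned on a batch of ImageNet class IDs.

        `class_ids` holds integer class labels, `cfg_scale` sets the
        classifier-free guidance strength, and `sample_steps`/`chunk_size` are
        forwarded to the backbone's sampler.
        """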
        return self.runtime_model.sample(
            cond=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

    def forward(self, *args, **kwargs):
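        """Delegate the forward pass directly to the wrapped BitDance backbone."""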
        return self.runtime_model(*args, **kwargs)
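

# Minimal usage sketch (not part of the module's API): the checkpoint directory and
# class IDs below are illustrative placeholders, chosen only to show the call flow.
if __name__ == "__main__":
    ckpt_dir = "./bitdance-imagenet-checkpoint"  # hypothetical local checkpoint directory
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load config + weights, move the model to the chosen device, and sample two classes.
    model = BitDanceImageNetTransformer.from_pretrained(ckpt_dir).to(device)
    class_ids = torch.tensor([207, 980], device=device)  # arbitrary ImageNet label IDs
    samples = model.sample(class_ids, sample_steps=100, cfg_scale=4.6)
    print(tuple(samples.shape))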