# JiT-diffusers/jit_diffusers/modeling_jit_transformer_2d.py
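"""Diffusers-compatible wrapper around the JiT transformer backbone.

Exposes `JiTTransformer2DModel`, a `ModelMixin`/`ConfigMixin` model that adapts
research-style JiT checkpoints (an argparse `args` namespace plus `model` /
`model_ema1` / `model_ema2` state dicts) to the `diffusers` model interface.
"""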
from __future__ import annotations
import argparse
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, Literal, Tuple
import torch
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from .modeling_jit_backbone import JiT_models
def _extract_module_state_dict(
state_dict: Dict[str, torch.Tensor], prefixes: Tuple[str, ...] = ("transformer.", "net.")
) -> Dict[str, torch.Tensor]:
"""Extract module state by stripping the first fully-matching prefix.
Prefix precedence is left-to-right; `"transformer."` is preferred over legacy `"net."`.
"""
for prefix in prefixes:
if all(key.startswith(prefix) for key in state_dict.keys()):
return {k[len(prefix):]: v for k, v in state_dict.items()}
return state_dict
def _build_jit_kwargs(
image_size: int,
num_classes: int,
attn_dropout: float,
proj_dropout: float,
model_name: str | None = None,
) -> Dict[str, object]:
# Keep model_name for backward-compatible internal call signatures.
_ = model_name
return {
"input_size": image_size,
"in_channels": 3,
"num_classes": num_classes,
"attn_drop": attn_dropout,
"proj_drop": proj_dropout,
}
@dataclass
class JiTCheckpointConfig:
model_name: str
image_size: int
num_classes: int
attn_dropout: float
proj_dropout: float
def _config_from_checkpoint(ckpt_args: argparse.Namespace | Mapping[str, Any]) -> JiTCheckpointConfig:
if isinstance(ckpt_args, argparse.Namespace):
args_dict = vars(ckpt_args)
elif isinstance(ckpt_args, Mapping):
args_dict = ckpt_args
else:
raise TypeError(f"Unsupported checkpoint args type: {type(ckpt_args)}")
def _get_first_available(*keys: str, default=None):
for key in keys:
if key in args_dict and args_dict[key] is not None:
return args_dict[key]
return default
model_name = _get_first_available("model", "model_name", "model_type")
image_size = _get_first_available("img_size", "image_size", "sample_size")
num_classes = _get_first_available("class_num", "num_classes", "num_class_embeds")
if model_name is None or image_size is None or num_classes is None:
raise ValueError("Checkpoint args are missing model/image_size/num_classes information.")
return JiTCheckpointConfig(
model_name=str(model_name),
image_size=int(image_size),
num_classes=int(num_classes),
attn_dropout=float(_get_first_available("attn_dropout", "attention_dropout", default=0.0)),
proj_dropout=float(_get_first_available("proj_dropout", "dropout", default=0.0)),
)
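
# `_config_from_checkpoint` example (hypothetical legacy args):
#   {"model": "JiT-B/16", "img_size": 256, "class_num": 1000}
#   -> JiTCheckpointConfig(model_name="JiT-B/16", image_size=256, num_classes=1000,
#                          attn_dropout=0.0, proj_dropout=0.0)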
class JiTTransformer2DModel(ModelMixin, ConfigMixin):
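    """JiT denoising transformer wrapped as a `diffusers` 2D transformer model.

    The config accepts both the diffusers-style names (`model_type`, `sample_size`,
    `num_class_embeds`, `attention_dropout`, `dropout`) and the legacy JiT names
    (`model_name`, `image_size`, `num_classes`, `attn_dropout`, `proj_dropout`);
    when a legacy value is provided it takes precedence.
    """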
@register_to_config
def __init__(
self,
model_type: str = "JiT-B/16",
sample_size: int = 256,
num_class_embeds: int = 1000,
attention_dropout: float = 0.0,
dropout: float = 0.0,
model_name: str | None = None,
image_size: int | None = None,
num_classes: int | None = None,
attn_dropout: float | None = None,
proj_dropout: float | None = None,
):
super().__init__()
resolved_model_type = model_type if model_name is None else model_name
resolved_sample_size = sample_size if image_size is None else image_size
resolved_num_class_embeds = num_class_embeds if num_classes is None else num_classes
resolved_attention_dropout = attention_dropout if attn_dropout is None else attn_dropout
resolved_dropout = dropout if proj_dropout is None else proj_dropout
if resolved_model_type not in JiT_models:
raise ValueError(f"Unknown model '{resolved_model_type}'. Available: {list(JiT_models.keys())}")
self.transformer = JiT_models[resolved_model_type](
**_build_jit_kwargs(
image_size=resolved_sample_size,
num_classes=resolved_num_class_embeds,
attn_dropout=resolved_attention_dropout,
proj_dropout=resolved_dropout,
model_name=resolved_model_type,
)
)
def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
class_labels: torch.Tensor,
return_dict: bool = True,
):
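        """Run the backbone on `sample` conditioned on `timestep` and `class_labels`.

        Scalar or length-1 timesteps are broadcast to the batch size before being
        forwarded to the backbone.
        """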
timestep = torch.as_tensor(timestep, device=sample.device)
if timestep.ndim == 0:
timestep = timestep.repeat(sample.shape[0])
else:
timestep = timestep.reshape(-1)
if timestep.shape[0] == 1 and sample.shape[0] > 1:
timestep = timestep.repeat(sample.shape[0])
denoised = self.transformer(sample, timestep, class_labels)
if not return_dict:
return (denoised,)
return Transformer2DModelOutput(sample=denoised)
@classmethod
def from_jit_checkpoint(
cls,
checkpoint_path: str,
weights: Literal["model", "ema1", "ema2"] = "ema1",
map_location: str = "cpu",
strict: bool = True,
) -> Tuple["JiTTransformer2DModel", Dict[str, object]]:
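        """Build the model from a raw JiT training checkpoint.

        The architecture is inferred from the checkpoint's `args`; `weights` selects
        which stored state dict to load ("model", "ema1", or "ema2"). Returns the
        model together with a metadata dict (checkpoint path, selected weights,
        epoch, and the original training args).
        """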
        # JiT checkpoints store an argparse.Namespace under "args", so full unpickling
        # is needed on newer PyTorch versions where torch.load defaults to weights_only=True.
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
if "args" not in checkpoint:
raise ValueError("Checkpoint is missing 'args', cannot infer JiT architecture config.")
config = _config_from_checkpoint(checkpoint["args"])
model = cls(
model_type=config.model_name,
sample_size=config.image_size,
num_class_embeds=config.num_classes,
attention_dropout=config.attn_dropout,
dropout=config.proj_dropout,
)
key = "model" if weights == "model" else f"model_{weights}"
if key not in checkpoint:
raise ValueError(f"Checkpoint key '{key}' not found. Available keys: {list(checkpoint.keys())}")
model_state = _extract_module_state_dict(checkpoint[key])
model.transformer.load_state_dict(model_state, strict=strict)
metadata = {
"checkpoint_path": checkpoint_path,
"weights": weights,
"epoch": checkpoint.get("epoch"),
"source_args": checkpoint.get("args"),
}
return model, metadata
def to_jit_checkpoint(
self,
ema_mode: Literal["none", "copy_to_both"] = "copy_to_both",
prefix: str = "net.",
) -> Dict[str, object]:
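        """Export the wrapped backbone in the original JiT checkpoint layout.

        With `ema_mode="copy_to_both"`, the current weights are also copied into the
        `model_ema1` and `model_ema2` slots so the result matches the layout expected
        by `from_jit_checkpoint`.
        """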
base_state = {f"{prefix}{k}": v.detach().cpu() for k, v in self.transformer.state_dict().items()}
checkpoint = {"model": base_state}
if ema_mode == "copy_to_both":
checkpoint["model_ema1"] = {k: v.clone() for k, v in base_state.items()}
checkpoint["model_ema2"] = {k: v.clone() for k, v in base_state.items()}
elif ema_mode != "none":
raise ValueError(f"Unsupported ema_mode='{ema_mode}'.")
return checkpoint
@property
def net(self):
return self.transformer
@net.setter
def net(self, module):
self.transformer = module
# Backward-compatible alias.
JiTDiffusersModel = JiTTransformer2DModel
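

# Minimal usage sketch (not part of the module's public surface). It assumes the
# default "JiT-B/16" entry in `JiT_models` builds at 256x256 with 1000 classes,
# and the timestep value convention depends on how the backbone was trained.
if __name__ == "__main__":
    model = JiTTransformer2DModel()  # defaults: JiT-B/16, sample_size=256, 1000 classes
    sample = torch.randn(2, 3, 256, 256)
    timestep = torch.tensor(0.5)  # scalar timesteps are broadcast to the batch in forward()
    class_labels = torch.randint(0, 1000, (2,))
    out = model(sample, timestep, class_labels)
    print(out.sample.shape)

    # Loading a research checkpoint instead (the path is a hypothetical placeholder):
    # model, meta = JiTTransformer2DModel.from_jit_checkpoint("jit_checkpoint.pt", weights="ema1")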