from __future__ import annotations

import argparse
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, Literal, Tuple

import torch
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from .modeling_jit_backbone import JiT_models


def _extract_module_state_dict(
    state_dict: Dict[str, torch.Tensor], prefixes: Tuple[str, ...] = ("transformer.", "net.")
) -> Dict[str, torch.Tensor]:
    """Extract module state by stripping the first fully-matching prefix.

    Prefix precedence is left-to-right; `"transformer."` is preferred over legacy `"net."`.
    """
    for prefix in prefixes:
        if all(key.startswith(prefix) for key in state_dict.keys()):
            return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict
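
# Illustrative example (key names below are made up, not taken from the real JiT
# state dict): a checkpoint saved as
#   {"transformer.blocks.0.attn.qkv.weight": w, "transformer.final_layer.weight": b}
# is returned as
#   {"blocks.0.attn.qkv.weight": w, "final_layer.weight": b},
# while a dict whose keys do not all share one of the prefixes is returned unchanged.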


def _build_jit_kwargs(
    image_size: int,
    num_classes: int,
    attn_dropout: float,
    proj_dropout: float,
    model_name: str | None = None,
) -> Dict[str, object]:
    # Keep model_name for backward-compatible internal call signatures.
    _ = model_name
    return {
        "input_size": image_size,
        "in_channels": 3,
        "num_classes": num_classes,
        "attn_drop": attn_dropout,
        "proj_drop": proj_dropout,
    }


@dataclass
class JiTCheckpointConfig:
    model_name: str
    image_size: int
    num_classes: int
    attn_dropout: float
    proj_dropout: float


def _config_from_checkpoint(ckpt_args: argparse.Namespace | Mapping[str, Any]) -> JiTCheckpointConfig:
    if isinstance(ckpt_args, argparse.Namespace):
        args_dict = vars(ckpt_args)
    elif isinstance(ckpt_args, Mapping):
        args_dict = ckpt_args
    else:
        raise TypeError(f"Unsupported checkpoint args type: {type(ckpt_args)}")

    def _get_first_available(*keys: str, default=None):
        for key in keys:
            if key in args_dict and args_dict[key] is not None:
                return args_dict[key]
        return default

    model_name = _get_first_available("model", "model_name", "model_type")
    image_size = _get_first_available("img_size", "image_size", "sample_size")
    num_classes = _get_first_available("class_num", "num_classes", "num_class_embeds")
    if model_name is None or image_size is None or num_classes is None:
        raise ValueError("Checkpoint args are missing model/image_size/num_classes information.")
    return JiTCheckpointConfig(
        model_name=str(model_name),
        image_size=int(image_size),
        num_classes=int(num_classes),
        attn_dropout=float(_get_first_available("attn_dropout", "attention_dropout", default=0.0)),
        proj_dropout=float(_get_first_available("proj_dropout", "dropout", default=0.0)),
    )
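
# Illustrative sketch (the key names are the aliases the helper already recognises;
# the concrete values are made up): legacy training args such as
#   {"model": "JiT-B/16", "img_size": 256, "class_num": 1000}
# resolve to
#   JiTCheckpointConfig(model_name="JiT-B/16", image_size=256, num_classes=1000,
#                       attn_dropout=0.0, proj_dropout=0.0)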


class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    """Diffusers-style wrapper around a JiT backbone.

    The diffusers-style arguments (`model_type`, `sample_size`, `num_class_embeds`,
    `attention_dropout`, `dropout`) are the primary configuration; the legacy JiT
    names (`model_name`, `image_size`, `num_classes`, `attn_dropout`, `proj_dropout`)
    are accepted as aliases and take precedence when explicitly provided.
    """

    @register_to_config
    def __init__(
        self,
        model_type: str = "JiT-B/16",
        sample_size: int = 256,
        num_class_embeds: int = 1000,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        model_name: str | None = None,
        image_size: int | None = None,
        num_classes: int | None = None,
        attn_dropout: float | None = None,
        proj_dropout: float | None = None,
    ):
        super().__init__()
        # Legacy JiT argument names override the diffusers-style defaults when given.
        resolved_model_type = model_type if model_name is None else model_name
        resolved_sample_size = sample_size if image_size is None else image_size
        resolved_num_class_embeds = num_class_embeds if num_classes is None else num_classes
        resolved_attention_dropout = attention_dropout if attn_dropout is None else attn_dropout
        resolved_dropout = dropout if proj_dropout is None else proj_dropout
        if resolved_model_type not in JiT_models:
            raise ValueError(f"Unknown model '{resolved_model_type}'. Available: {list(JiT_models.keys())}")
        self.transformer = JiT_models[resolved_model_type](
            **_build_jit_kwargs(
                image_size=resolved_sample_size,
                num_classes=resolved_num_class_embeds,
                attn_dropout=resolved_attention_dropout,
                proj_dropout=resolved_dropout,
                model_name=resolved_model_type,
            )
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        return_dict: bool = True,
    ):
        # Accept scalar, single-element, or per-sample timesteps and broadcast to the batch.
        timestep = torch.as_tensor(timestep, device=sample.device)
        if timestep.ndim == 0:
            timestep = timestep.repeat(sample.shape[0])
        else:
            timestep = timestep.reshape(-1)
            if timestep.shape[0] == 1 and sample.shape[0] > 1:
                timestep = timestep.repeat(sample.shape[0])
        denoised = self.transformer(sample, timestep, class_labels)
        if not return_dict:
            return (denoised,)
        return Transformer2DModelOutput(sample=denoised)

    @classmethod
    def from_jit_checkpoint(
        cls,
        checkpoint_path: str,
        weights: Literal["model", "ema1", "ema2"] = "ema1",
        map_location: str = "cpu",
        strict: bool = True,
    ) -> Tuple["JiTTransformer2DModel", Dict[str, object]]:
        """Instantiate the model from a native JiT training checkpoint.

        Returns the model together with a metadata dict describing the source
        checkpoint (path, selected weight set, epoch, original training args).
        """
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        if "args" not in checkpoint:
            raise ValueError("Checkpoint is missing 'args', cannot infer JiT architecture config.")
        config = _config_from_checkpoint(checkpoint["args"])
        model = cls(
            model_type=config.model_name,
            sample_size=config.image_size,
            num_class_embeds=config.num_classes,
            attention_dropout=config.attn_dropout,
            dropout=config.proj_dropout,
        )
        # "model" holds the online weights; EMA sets are stored as "model_ema1"/"model_ema2".
        key = "model" if weights == "model" else f"model_{weights}"
        if key not in checkpoint:
            raise ValueError(f"Checkpoint key '{key}' not found. Available keys: {list(checkpoint.keys())}")
        model_state = _extract_module_state_dict(checkpoint[key])
        model.transformer.load_state_dict(model_state, strict=strict)
        metadata = {
            "checkpoint_path": checkpoint_path,
            "weights": weights,
            "epoch": checkpoint.get("epoch"),
            "source_args": checkpoint.get("args"),
        }
        return model, metadata
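
    # Illustrative usage (the path and weight choice are hypothetical):
    #     model, meta = JiTTransformer2DModel.from_jit_checkpoint(
    #         "checkpoints/jit-b16.pt", weights="ema1", map_location="cpu"
    #     )
    #     print(meta["epoch"])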

    def to_jit_checkpoint(
        self,
        ema_mode: Literal["none", "copy_to_both"] = "copy_to_both",
        prefix: str = "net.",
    ) -> Dict[str, object]:
        """Export the wrapped backbone in the native JiT checkpoint layout.

        With `ema_mode="copy_to_both"` the current weights are duplicated into the
        `model_ema1` and `model_ema2` slots alongside the online `model` weights.
        """
        base_state = {f"{prefix}{k}": v.detach().cpu() for k, v in self.transformer.state_dict().items()}
        checkpoint = {"model": base_state}
        if ema_mode == "copy_to_both":
            checkpoint["model_ema1"] = {k: v.clone() for k, v in base_state.items()}
            checkpoint["model_ema2"] = {k: v.clone() for k, v in base_state.items()}
        elif ema_mode != "none":
            raise ValueError(f"Unsupported ema_mode='{ema_mode}'.")
        return checkpoint

    @property
    def net(self):
        return self.transformer

    @net.setter
    def net(self, module):
        self.transformer = module


# Backward-compatible alias.
JiTDiffusersModel = JiTTransformer2DModel
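

# Minimal usage sketch (illustrative, not part of the original module). Because of the
# relative import above, run it as a module from within the package, e.g.
# `python -m <package>.<module>` (package and module names here are placeholders).
if __name__ == "__main__":
    # Build the default JiT-B/16 wrapper at 256x256 with 1000 classes.
    model = JiTTransformer2DModel()
    sample = torch.randn(2, 3, model.config.sample_size, model.config.sample_size)
    timestep = torch.full((2,), 0.5)
    class_labels = torch.randint(0, model.config.num_class_embeds, (2,))
    out = model(sample, timestep, class_labels)
    print(type(out).__name__, tuple(out.sample.shape))

    # Round-trip the weights into the native JiT checkpoint layout; with the default
    # ema_mode the online weights are copied into both EMA slots.
    ckpt = model.to_jit_checkpoint()
    print(sorted(ckpt.keys()))  # ['model', 'model_ema1', 'model_ema2']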