from __future__ import annotations

import argparse
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any, Dict, Literal, Tuple

import torch
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from .modeling_jit_backbone import JiT_models


def _extract_module_state_dict(
    state_dict: Dict[str, torch.Tensor], prefixes: Tuple[str, ...] = ("transformer.", "net.")
) -> Dict[str, torch.Tensor]:
    """Extract module state by stripping the first fully-matching prefix.

    Prefix precedence is left-to-right; `"transformer."` is preferred over legacy `"net."`.
    """
    for prefix in prefixes:
        if all(key.startswith(prefix) for key in state_dict.keys()):
            return {k[len(prefix):]: v for k, v in state_dict.items()}
    return state_dict
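

# Illustrative sketch of the prefix stripping above (the keys are hypothetical, not taken
# from an actual JiT checkpoint): a state dict saved with a wrapper module such as
#   {"net.x_embedder.proj.weight": ..., "net.final_layer.linear.bias": ...}
# is returned as
#   {"x_embedder.proj.weight": ..., "final_layer.linear.bias": ...};
# if no prefix matches every key, the state dict is returned unchanged.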


def _build_jit_kwargs(
    image_size: int,
    num_classes: int,
    attn_dropout: float,
    proj_dropout: float,
    model_name: str | None = None,
) -> Dict[str, object]:
    # Keep model_name for backward-compatible internal call signatures.
    _ = model_name
    return {
        "input_size": image_size,
        "in_channels": 3,
        "num_classes": num_classes,
        "attn_drop": attn_dropout,
        "proj_drop": proj_dropout,
    }


@dataclass
class JiTCheckpointConfig:
    model_name: str
    image_size: int
    num_classes: int
    attn_dropout: float
    proj_dropout: float


def _config_from_checkpoint(ckpt_args: argparse.Namespace | Mapping[str, Any]) -> JiTCheckpointConfig:
    if isinstance(ckpt_args, argparse.Namespace):
        args_dict = vars(ckpt_args)
    elif isinstance(ckpt_args, Mapping):
        args_dict = ckpt_args
    else:
        raise TypeError(f"Unsupported checkpoint args type: {type(ckpt_args)}")

    def _get_first_available(*keys: str, default=None):
        for key in keys:
            if key in args_dict and args_dict[key] is not None:
                return args_dict[key]
        return default

    model_name = _get_first_available("model", "model_name", "model_type")
    image_size = _get_first_available("img_size", "image_size", "sample_size")
    num_classes = _get_first_available("class_num", "num_classes", "num_class_embeds")
    if model_name is None or image_size is None or num_classes is None:
        raise ValueError("Checkpoint args are missing model/image_size/num_classes information.")

    return JiTCheckpointConfig(
        model_name=str(model_name),
        image_size=int(image_size),
        num_classes=int(num_classes),
        attn_dropout=float(_get_first_available("attn_dropout", "attention_dropout", default=0.0)),
        proj_dropout=float(_get_first_available("proj_dropout", "dropout", default=0.0)),
    )


class JiTTransformer2DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        model_type: str = "JiT-B/16",
        sample_size: int = 256,
        num_class_embeds: int = 1000,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
        model_name: str | None = None,
        image_size: int | None = None,
        num_classes: int | None = None,
        attn_dropout: float | None = None,
        proj_dropout: float | None = None,
    ):
        super().__init__()
        resolved_model_type = model_type if model_name is None else model_name
        resolved_sample_size = sample_size if image_size is None else image_size
        resolved_num_class_embeds = num_class_embeds if num_classes is None else num_classes
        resolved_attention_dropout = attention_dropout if attn_dropout is None else attn_dropout
        resolved_dropout = dropout if proj_dropout is None else proj_dropout

        if resolved_model_type not in JiT_models:
            raise ValueError(f"Unknown model '{resolved_model_type}'. Available: {list(JiT_models.keys())}")

        self.transformer = JiT_models[resolved_model_type](
            **_build_jit_kwargs(
                image_size=resolved_sample_size,
                num_classes=resolved_num_class_embeds,
                attn_dropout=resolved_attention_dropout,
                proj_dropout=resolved_dropout,
                model_name=resolved_model_type,
            )
        )

    def forward(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.Tensor,
        return_dict: bool = True,
    ):
        # Normalize `timestep` to a 1-D tensor on the sample's device, one entry per batch item.
        timestep = torch.as_tensor(timestep, device=sample.device)
        if timestep.ndim == 0:
            timestep = timestep.repeat(sample.shape[0])
        else:
            timestep = timestep.reshape(-1)
            if timestep.shape[0] == 1 and sample.shape[0] > 1:
                timestep = timestep.repeat(sample.shape[0])

        denoised = self.transformer(sample, timestep, class_labels)
        if not return_dict:
            return (denoised,)
        return Transformer2DModelOutput(sample=denoised)

    @classmethod
    def from_jit_checkpoint(
        cls,
        checkpoint_path: str,
        weights: Literal["model", "ema1", "ema2"] = "ema1",
        map_location: str = "cpu",
        strict: bool = True,
    ) -> Tuple["JiTTransformer2DModel", Dict[str, object]]:
        # Checkpoints store training `args` (often an argparse.Namespace), which the
        # weights-only loader that newer PyTorch defaults to refuses to unpickle.
        checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
        if "args" not in checkpoint:
            raise ValueError("Checkpoint is missing 'args', cannot infer JiT architecture config.")

        config = _config_from_checkpoint(checkpoint["args"])
        model = cls(
            model_type=config.model_name,
            sample_size=config.image_size,
            num_class_embeds=config.num_classes,
            attention_dropout=config.attn_dropout,
            dropout=config.proj_dropout,
        )

        key = "model" if weights == "model" else f"model_{weights}"
        if key not in checkpoint:
            raise ValueError(f"Checkpoint key '{key}' not found. Available keys: {list(checkpoint.keys())}")

        model_state = _extract_module_state_dict(checkpoint[key])
        model.transformer.load_state_dict(model_state, strict=strict)

        metadata = {
            "checkpoint_path": checkpoint_path,
            "weights": weights,
            "epoch": checkpoint.get("epoch"),
            "source_args": checkpoint.get("args"),
        }
        return model, metadata

    def to_jit_checkpoint(
        self,
        ema_mode: Literal["none", "copy_to_both"] = "copy_to_both",
        prefix: str = "net.",
    ) -> Dict[str, object]:
        base_state = {f"{prefix}{k}": v.detach().cpu() for k, v in self.transformer.state_dict().items()}
        checkpoint = {"model": base_state}
        if ema_mode == "copy_to_both":
            checkpoint["model_ema1"] = {k: v.clone() for k, v in base_state.items()}
            checkpoint["model_ema2"] = {k: v.clone() for k, v in base_state.items()}
        elif ema_mode != "none":
            raise ValueError(f"Unsupported ema_mode='{ema_mode}'.")
        return checkpoint

    @property
    def net(self):
        return self.transformer

    @net.setter
    def net(self, module):
        self.transformer = module


# Backward-compatible alias.
JiTDiffusersModel = JiTTransformer2DModel
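

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. The checkpoint path below is a
    # placeholder; any JiT training checkpoint that contains "args" and "model_ema1"
    # entries should behave the same way.
    model, metadata = JiTTransformer2DModel.from_jit_checkpoint(
        "checkpoints/jit_b16.pt",  # hypothetical path
        weights="ema1",
    )
    model.eval()

    # Run a single denoising forward pass on random inputs sized from the registered config.
    sample = torch.randn(2, 3, model.config.sample_size, model.config.sample_size)
    timestep = torch.tensor([0.5, 0.5])
    class_labels = torch.tensor([1, 7])
    with torch.no_grad():
        out = model(sample, timestep, class_labels)
    print(out.sample.shape)

    # Round-trip the weights back into the legacy checkpoint layout.
    legacy_ckpt = model.to_jit_checkpoint()
    print(sorted(legacy_ckpt.keys()))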