Diffusers
Safetensors
AnimaPipeline
File size: 3,991 Bytes
11d757a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
"""Convert original Anima transformer safetensors keys to Diffusers format."""

from __future__ import annotations

import argparse
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file


# Whole-key renames for tensors that sit outside the repeated block stack.
# Original keys carry a ``net.`` prefix; their Diffusers names live under ``core.``.
# Only ``.weight`` tensors are listed. convert_key raises KeyError for any key not
# covered here, by BLOCK_KEY_MAP, or by the ``net.llm_adapter.`` passthrough, so a
# source checkpoint containing other tensors (e.g. biases) will fail to convert.
# NOTE(review): presumably the source layers are bias-free — confirm against the model.
DIRECT_KEY_MAP = {
    # t_embedder / t_embedding_norm -> core.time_embed.*
    "net.t_embedder.1.linear_1.weight": "core.time_embed.t_embedder.linear_1.weight",
    "net.t_embedder.1.linear_2.weight": "core.time_embed.t_embedder.linear_2.weight",
    "net.t_embedding_norm.weight": "core.time_embed.norm.weight",
    # x_embedder projection -> core.patch_embed
    "net.x_embedder.proj.1.weight": "core.patch_embed.proj.weight",
    # final_layer linear + its adaln_modulation -> core.proj_out / core.norm_out
    "net.final_layer.linear.weight": "core.proj_out.weight",
    "net.final_layer.adaln_modulation.1.weight": "core.norm_out.linear_1.weight",
    "net.final_layer.adaln_modulation.2.weight": "core.norm_out.linear_2.weight",
}

# Per-block renames. Keys are the suffix that follows ``net.blocks.<index>.`` in the
# original checkpoint; values are the suffix placed after
# ``core.transformer_blocks.<index>.`` by convert_key. The block index is carried
# over unchanged.
BLOCK_KEY_MAP = {
    # self_attn -> attn1: q/k/v/output projections plus q/k norms
    "self_attn.q_proj.weight": "attn1.to_q.weight",
    "self_attn.k_proj.weight": "attn1.to_k.weight",
    "self_attn.v_proj.weight": "attn1.to_v.weight",
    "self_attn.output_proj.weight": "attn1.to_out.0.weight",
    "self_attn.q_norm.weight": "attn1.norm_q.weight",
    "self_attn.k_norm.weight": "attn1.norm_k.weight",
    # cross_attn -> attn2: same layout as attn1
    "cross_attn.q_proj.weight": "attn2.to_q.weight",
    "cross_attn.k_proj.weight": "attn2.to_k.weight",
    "cross_attn.v_proj.weight": "attn2.to_v.weight",
    "cross_attn.output_proj.weight": "attn2.to_out.0.weight",
    "cross_attn.q_norm.weight": "attn2.norm_q.weight",
    "cross_attn.k_norm.weight": "attn2.norm_k.weight",
    # mlp layer1 / layer2 -> ff.net.0.proj / ff.net.2
    "mlp.layer1.weight": "ff.net.0.proj.weight",
    "mlp.layer2.weight": "ff.net.2.weight",
    # adaln_modulation for self-attn / cross-attn / mlp -> norm1 / norm2 / norm3
    "adaln_modulation_self_attn.1.weight": "norm1.linear_1.weight",
    "adaln_modulation_self_attn.2.weight": "norm1.linear_2.weight",
    "adaln_modulation_cross_attn.1.weight": "norm2.linear_1.weight",
    "adaln_modulation_cross_attn.2.weight": "norm2.linear_2.weight",
    "adaln_modulation_mlp.1.weight": "norm3.linear_1.weight",
    "adaln_modulation_mlp.2.weight": "norm3.linear_2.weight",
}


def convert_key(key: str) -> str:
    """Translate one original Anima checkpoint key into its Diffusers name.

    Three families of keys are handled:

    * ``net.llm_adapter.*`` keys pass through with only the ``net.`` prefix removed.
    * Exact keys listed in ``DIRECT_KEY_MAP``.
    * Per-block keys ``net.blocks.<index>.<tail>`` whose ``<tail>`` is listed in
      ``BLOCK_KEY_MAP``; they become
      ``core.transformer_blocks.<index>.<BLOCK_KEY_MAP[tail]>``.

    Args:
        key: Tensor name as stored in the original safetensors file.

    Returns:
        The corresponding Diffusers tensor name.

    Raises:
        KeyError: If ``key`` belongs to none of the families above, including
            malformed block keys such as ``net.blocks.3`` (no tail) or
            ``net.blocks.x.<tail>`` (non-integer index).
    """
    if key.startswith("net.llm_adapter."):
        return key.removeprefix("net.")

    if key in DIRECT_KEY_MAP:
        return DIRECT_KEY_MAP[key]

    if key.startswith("net.blocks."):
        # partition (instead of split + tuple unpacking) never raises: a key with
        # no tail yields tail == "" and falls through to the KeyError below rather
        # than leaking a ValueError from the unpacking.
        block_index, _, tail = key.removeprefix("net.blocks.").partition(".")
        # The index must be a plain ASCII integer to form a valid module path;
        # anything else would silently produce a bogus Diffusers key.
        if block_index.isascii() and block_index.isdigit() and tail in BLOCK_KEY_MAP:
            return f"core.transformer_blocks.{block_index}.{BLOCK_KEY_MAP[tail]}"

    raise KeyError(f"No Diffusers key mapping for: {key}")


def _plan_key_mapping(source_keys: list[str]) -> dict[str, str]:
    """Resolve every source key to its Diffusers name before any tensor is loaded.

    Args:
        source_keys: Tensor names read from the original checkpoint.

    Returns:
        Dict from converted (Diffusers) key to original key, in source order.

    Raises:
        ValueError: If two source keys convert to the same Diffusers key.
        KeyError: If one or more source keys have no mapping; the message lists
            all of them, not just the first one encountered.
    """
    converted_to_source: dict[str, str] = {}
    unmapped: list[str] = []
    for key in source_keys:
        try:
            converted_key = convert_key(key)
        except KeyError:
            unmapped.append(key)
            continue
        if converted_key in converted_to_source:
            raise ValueError(
                f"Duplicate converted key: {converted_key} "
                f"(from {converted_to_source[converted_key]} and {key})"
            )
        converted_to_source[converted_key] = key

    if unmapped:
        raise KeyError(
            f"No Diffusers key mapping for {len(unmapped)} key(s): {', '.join(unmapped)}"
        )
    return converted_to_source


def convert_checkpoint(source: Path, output: Path, overwrite: bool) -> None:
    """Rename every tensor in ``source`` to Diffusers keys and save it to ``output``.

    All key renames are resolved and validated before any tensor data is read.
    The result is written to a temporary sibling file that is moved over
    ``output`` only after a successful save. An interrupted or failed run
    therefore never leaves a truncated ``output`` behind, nor destroys an existing
    one when ``overwrite`` is set.

    Args:
        source: Original Anima transformer safetensors file.
        output: Destination path for the Diffusers-format safetensors file.
        overwrite: Replace ``output`` if it already exists.

    Raises:
        FileExistsError: ``output`` exists and ``overwrite`` is false.
        KeyError: Some source keys have no Diffusers mapping (all are listed).
        ValueError: Two source keys map to the same Diffusers key.
    """
    if output.exists() and not overwrite:
        raise FileExistsError(f"Output exists: {output}. Pass --overwrite to replace it.")

    with safe_open(source, framework="pt", device="cpu") as f:
        # Validate the full key mapping first so a bad checkpoint fails fast,
        # before any tensor data is loaded.
        plan = _plan_key_mapping(list(f.keys()))
        state_dict = {
            converted_key: f.get_tensor(source_key)
            for converted_key, source_key in plan.items()
        }

    output.parent.mkdir(parents=True, exist_ok=True)
    tmp_output = output.with_name(output.name + ".tmp")
    try:
        save_file(state_dict, tmp_output, metadata={"format": "pt"})
        # Path.replace uses os.replace: an atomic rename on POSIX when source and
        # target share a filesystem (same directory here).
        tmp_output.replace(output)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-written temp file.
        tmp_output.unlink(missing_ok=True)
        raise
    print(f"Saved {len(state_dict)} tensors to {output}")


def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        A namespace with ``source`` (Path, optional positional), ``output`` (Path)
        and ``overwrite`` (bool flag).
    """
    default_source = Path("anima-base-v1.0.safetensors")
    default_output = Path("diffusion_pytorch_model.safetensors")

    cli = argparse.ArgumentParser(description=__doc__)
    # Positional input file; may be omitted, in which case the default is used.
    cli.add_argument("source", nargs="?", type=Path, default=default_source,
                     help="Original Anima transformer safetensors file.")
    cli.add_argument("--output", type=Path, default=default_output,
                     help="Diffusers-format safetensors output path.")
    cli.add_argument("--overwrite", action="store_true",
                     help="Replace output if it already exists.")

    namespace = cli.parse_args()
    return namespace


def main() -> None:
    """Command-line entry point: parse arguments, then run the conversion."""
    options = parse_args()
    convert_checkpoint(options.source, options.output, overwrite=options.overwrite)


if __name__ == "__main__":
    main()