Diffusers
Safetensors
AnimaPipeline
Anima-BaseV1-UnofficalDiffusers / transformer /convert_anima_to_diffusers.py
SleepVeryHard's picture
Upload 19 files
11d757a verified
#!/usr/bin/env python3
"""Convert original Anima transformer safetensors keys to Diffusers format."""
from __future__ import annotations
import argparse
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
# One-to-one renames for tensors that live outside the repeated transformer
# blocks. Every entry maps "net.<original stem>.weight" to
# "core.<Diffusers stem>.weight", so only the stems are listed below.
DIRECT_KEY_MAP = {
    f"net.{original_stem}.weight": f"core.{diffusers_stem}.weight"
    for original_stem, diffusers_stem in (
        ("t_embedder.1.linear_1", "time_embed.t_embedder.linear_1"),
        ("t_embedder.1.linear_2", "time_embed.t_embedder.linear_2"),
        ("t_embedding_norm", "time_embed.norm"),
        ("x_embedder.proj.1", "patch_embed.proj"),
        ("final_layer.linear", "proj_out"),
        ("final_layer.adaln_modulation.1", "norm_out.linear_1"),
        ("final_layer.adaln_modulation.2", "norm_out.linear_2"),
    )
}
# Renames applied to the part of a key that follows "net.blocks.<index>.".
# The attention and adaLN-modulation entries follow a regular pattern, so they
# are generated; the two MLP entries are spelled out.
BLOCK_KEY_MAP = {
    # self_attn -> attn1 and cross_attn -> attn2: projections plus q/k norms.
    **{
        f"{original_attn}.{original_part}.weight": f"{diffusers_attn}.{diffusers_part}.weight"
        for original_attn, diffusers_attn in (
            ("self_attn", "attn1"),
            ("cross_attn", "attn2"),
        )
        for original_part, diffusers_part in (
            ("q_proj", "to_q"),
            ("k_proj", "to_k"),
            ("v_proj", "to_v"),
            ("output_proj", "to_out.0"),
            ("q_norm", "norm_q"),
            ("k_norm", "norm_k"),
        )
    },
    # Feed-forward layers.
    "mlp.layer1.weight": "ff.net.0.proj.weight",
    "mlp.layer2.weight": "ff.net.2.weight",
    # adaln_modulation_<stage>.<n> -> norm<k>.linear_<n> for n in 1, 2.
    **{
        f"adaln_modulation_{stage}.{layer}.weight": f"{norm_name}.linear_{layer}.weight"
        for stage, norm_name in (
            ("self_attn", "norm1"),
            ("cross_attn", "norm2"),
            ("mlp", "norm3"),
        )
        for layer in (1, 2)
    },
}
def convert_key(key: str) -> str:
    """Translate one original Anima tensor name into its Diffusers-format name.

    Three families of keys are recognised, checked in this order:

    * ``net.llm_adapter.*`` keys only lose their leading ``net.`` prefix.
    * Keys listed in ``DIRECT_KEY_MAP`` are replaced by their mapped name.
    * ``net.blocks.<index>.<tail>`` keys whose ``<tail>`` is listed in
      ``BLOCK_KEY_MAP`` become
      ``core.transformer_blocks.<index>.<mapped tail>``.

    Args:
        key: Tensor name as stored in the original checkpoint.

    Returns:
        The corresponding Diffusers-format tensor name.

    Raises:
        KeyError: If ``key`` matches none of the rules above, including a
            ``net.blocks.`` key with nothing after the block index.
    """
    if key.startswith("net.llm_adapter."):
        return key.removeprefix("net.")

    direct_name = DIRECT_KEY_MAP.get(key)
    if direct_name is not None:
        return direct_name

    if key.startswith("net.blocks."):
        # partition() always yields three parts, so a truncated key such as
        # "net.blocks.3" reaches the KeyError below instead of failing with an
        # unrelated ValueError from unpacking a one-element split().
        block_index, _, tail = key.removeprefix("net.blocks.").partition(".")
        block_name = BLOCK_KEY_MAP.get(tail)
        if block_name is not None:
            return f"core.transformer_blocks.{block_index}.{block_name}"

    raise KeyError(f"No Diffusers key mapping for: {key}")
def convert_checkpoint(source: Path, output: Path, overwrite: bool) -> None:
    """Rewrite a checkpoint's tensor names and save it as a new safetensors file.

    Every tensor in ``source`` is loaded on the CPU, renamed with
    ``convert_key`` and written to ``output`` with ``{"format": "pt"}``
    metadata. Missing parent directories of ``output`` are created, and a
    one-line summary is printed when the file has been saved.

    Args:
        source: Path of the original safetensors file to read.
        output: Path of the converted safetensors file to write.
        overwrite: Allow replacing ``output`` when it already exists.

    Raises:
        FileExistsError: If ``output`` exists and ``overwrite`` is false.
        ValueError: If two source keys convert to the same name.
        KeyError: Propagated from ``convert_key`` for an unmapped key.
    """
    output_present = output.exists()
    if output_present and not overwrite:
        raise FileExistsError(f"Output exists: {output}. Pass --overwrite to replace it.")

    renamed_tensors = {}
    with safe_open(source, framework="pt", device="cpu") as checkpoint:
        for original_name in checkpoint.keys():
            new_name = convert_key(original_name)
            # Two source tensors landing on one name would silently drop
            # data, so stop instead of overwriting the earlier entry.
            if new_name in renamed_tensors:
                raise ValueError(f"Duplicate converted key: {new_name}")
            renamed_tensors[new_name] = checkpoint.get_tensor(original_name)

    output.parent.mkdir(parents=True, exist_ok=True)
    save_file(renamed_tensors, output, metadata={"format": "pt"})
    tensor_count = len(renamed_tensors)
    print(f"Saved {tensor_count} tensors to {output}")
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        Namespace with ``source`` and ``output`` as ``Path`` values and
        ``overwrite`` as a boolean flag that is false unless ``--overwrite``
        is given.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Both path-valued arguments share type=Path; "source" is an optional
    # positional (nargs="?") so the script can run with no arguments at all.
    path_arguments = (
        (
            "source",
            {
                "nargs": "?",
                "default": Path("anima-base-v1.0.safetensors"),
                "help": "Original Anima transformer safetensors file.",
            },
        ),
        (
            "--output",
            {
                "default": Path("diffusion_pytorch_model.safetensors"),
                "help": "Diffusers-format safetensors output path.",
            },
        ),
    )
    for argument_name, settings in path_arguments:
        cli.add_argument(argument_name, type=Path, **settings)

    cli.add_argument(
        "--overwrite",
        action="store_true",
        help="Replace output if it already exists.",
    )
    return cli.parse_args()
def main() -> None:
    """Parse the command line and run the checkpoint conversion."""
    options = parse_args()
    convert_checkpoint(
        source=options.source,
        output=options.output,
        overwrite=options.overwrite,
    )


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()