#!/usr/bin/env python3
"""Convert original Anima transformer safetensors keys to Diffusers format."""

from __future__ import annotations

import argparse
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file

# One-to-one renames for tensors that live outside the repeated blocks.
DIRECT_KEY_MAP = {
    "net.t_embedder.1.linear_1.weight": "core.time_embed.t_embedder.linear_1.weight",
    "net.t_embedder.1.linear_2.weight": "core.time_embed.t_embedder.linear_2.weight",
    "net.t_embedding_norm.weight": "core.time_embed.norm.weight",
    "net.x_embedder.proj.1.weight": "core.patch_embed.proj.weight",
    "net.final_layer.linear.weight": "core.proj_out.weight",
    "net.final_layer.adaln_modulation.1.weight": "core.norm_out.linear_1.weight",
    "net.final_layer.adaln_modulation.2.weight": "core.norm_out.linear_2.weight",
}

# Renames for the part of a key that follows "net.blocks.<index>."; the
# converted key is "core.transformer_blocks.<index>." plus the mapped value.
BLOCK_KEY_MAP = {
    "self_attn.q_proj.weight": "attn1.to_q.weight",
    "self_attn.k_proj.weight": "attn1.to_k.weight",
    "self_attn.v_proj.weight": "attn1.to_v.weight",
    "self_attn.output_proj.weight": "attn1.to_out.0.weight",
    "self_attn.q_norm.weight": "attn1.norm_q.weight",
    "self_attn.k_norm.weight": "attn1.norm_k.weight",
    "cross_attn.q_proj.weight": "attn2.to_q.weight",
    "cross_attn.k_proj.weight": "attn2.to_k.weight",
    "cross_attn.v_proj.weight": "attn2.to_v.weight",
    "cross_attn.output_proj.weight": "attn2.to_out.0.weight",
    "cross_attn.q_norm.weight": "attn2.norm_q.weight",
    "cross_attn.k_norm.weight": "attn2.norm_k.weight",
    "mlp.layer1.weight": "ff.net.0.proj.weight",
    "mlp.layer2.weight": "ff.net.2.weight",
    "adaln_modulation_self_attn.1.weight": "norm1.linear_1.weight",
    "adaln_modulation_self_attn.2.weight": "norm1.linear_2.weight",
    "adaln_modulation_cross_attn.1.weight": "norm2.linear_1.weight",
    "adaln_modulation_cross_attn.2.weight": "norm2.linear_2.weight",
    "adaln_modulation_mlp.1.weight": "norm3.linear_1.weight",
    "adaln_modulation_mlp.2.weight": "norm3.linear_2.weight",
}


def convert_key(key: str) -> str:
    """Return the Diffusers-format name for one original checkpoint key.

    Three kinds of keys are recognised:

    * keys starting with ``net.llm_adapter.`` keep their name with the
      leading ``net.`` removed;
    * keys listed in ``DIRECT_KEY_MAP`` are renamed one-to-one;
    * keys of the form ``net.blocks.<index>.<tail>`` whose ``<tail>`` is in
      ``BLOCK_KEY_MAP`` become ``core.transformer_blocks.<index>.<mapped>``.

    Raises:
        KeyError: If ``key`` matches none of the rules above, including a
            ``net.blocks.`` key that has no tail after the block index.
    """
    if key.startswith("net.llm_adapter."):
        return key.removeprefix("net.")

    if key in DIRECT_KEY_MAP:
        return DIRECT_KEY_MAP[key]

    if key.startswith("net.blocks."):
        # partition() never fails to unpack: a key without a "." after the
        # block index yields an empty tail, which falls through to the
        # KeyError below instead of raising an unrelated ValueError.
        block_index, _, tail = key.removeprefix("net.blocks.").partition(".")
        if tail in BLOCK_KEY_MAP:
            return f"core.transformer_blocks.{block_index}.{BLOCK_KEY_MAP[tail]}"

    raise KeyError(f"No Diffusers key mapping for: {key}")


def convert_checkpoint(source: Path, output: Path, overwrite: bool) -> None:
    """Rename every tensor in ``source`` and save the result to ``output``.

    All tensors are read on the CPU, collected under their converted names,
    and written as a single safetensors file with ``{"format": "pt"}``
    metadata. Missing parent directories of ``output`` are created.

    Args:
        source: Path of the original safetensors file.
        output: Path of the safetensors file to write.
        overwrite: Allow replacing ``output`` when it already exists.

    Raises:
        FileExistsError: If ``output`` exists and ``overwrite`` is false.
        KeyError: If a key in ``source`` has no mapping (see ``convert_key``).
        ValueError: If two source keys convert to the same name.
    """
    # Refuse early, before any tensor is read from disk.
    if output.exists() and not overwrite:
        raise FileExistsError(f"Output exists: {output}. Pass --overwrite to replace it.")

    state_dict = {}
    with safe_open(source, framework="pt", device="cpu") as f:
        for key in f.keys():
            converted_key = convert_key(key)
            # Two source keys landing on one name would silently drop a tensor.
            if converted_key in state_dict:
                raise ValueError(f"Duplicate converted key: {converted_key}")
            state_dict[converted_key] = f.get_tensor(key)

    output.parent.mkdir(parents=True, exist_ok=True)
    save_file(state_dict, output, metadata={"format": "pt"})
    print(f"Saved {len(state_dict)} tensors to {output}")


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments.

    Returns:
        A namespace with ``source`` (optional positional path), ``output``
        (``--output`` path) and ``overwrite`` (``--overwrite`` flag).
    """
    # The module docstring doubles as the --help description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "source",
        nargs="?",
        type=Path,
        default=Path("anima-base-v1.0.safetensors"),
        help="Original Anima transformer safetensors file.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("diffusion_pytorch_model.safetensors"),
        help="Diffusers-format safetensors output path.",
    )
    parser.add_argument("--overwrite", action="store_true", help="Replace output if it already exists.")
    return parser.parse_args()


def main() -> None:
    """Run the conversion using the paths and flags from the command line."""
    args = parse_args()
    convert_checkpoint(args.source, args.output, args.overwrite)


if __name__ == "__main__":
    main()