comfyui
Depth-Anything-3 / convert_da3.py
TalmajM's picture
Upload convert_da3.py with huggingface_hub
24dbacf verified
#!/usr/bin/env python3
"""Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout.
Usage:
python scripts/convert_da3.py \
--input da3_original.safetensors \
--output models/diffusion_models/da3_comfy.safetensors
This applies, *offline*, the exact transform that ComfyUI used to do at load
time for DA3:
* remap the DINOv2 backbone keys ``backbone.pretrained.*`` (upstream DA3)
to the ``Dinov2Model`` runtime layout (``backbone.embeddings.*``,
``backbone.encoder.layer.*``, ``backbone.layernorm.*``);
* split each fused ``attn.qkv`` projection into separate query/key/value
linears;
* drop the unused Gaussian-splat head weights (``gs_head.*``, ``gs_adapter.*``).
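The head (``head.*``), camera encoder/decoder (``cam_enc.*``, ``cam_dec.*``)
and any other keys are passed through without renaming. A leading ``model.``
prefix is stripped for processing, and every key in the output file is written
with a ``model.`` prefix whether or not the input had one. After conversion the
file loads directly via ComfyUI auto-detection with no in-code remap.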
"""
import argparse
import glob
import os
import torch
from safetensors.torch import load_file, save_file
# Key prefixes of the Gaussian-splat head weights; any state-dict key that
# starts with one of these is removed by drop_unused().
DROP_PREFIXES = ("gs_head.", "gs_adapter.")
def remap_backbone_keys(state_dict, prefix="backbone."):
    """Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` layout.

    The remap is done in place: entries are popped from ``state_dict`` and
    re-inserted under their new names, and the same dict object is returned.

    Args:
        state_dict: Mapping of parameter name to tensor. Fused ``attn.qkv``
            entries must support ``.shape``, row slicing and ``.clone()``.
        prefix: Key prefix of the backbone, including the trailing dot.

    Returns:
        ``state_dict`` itself. It is returned untouched when it holds no
        ``<prefix>pretrained.`` keys.

    Raises:
        ValueError: If a ``<prefix>pretrained.blocks.*`` key has a sub-key this
            function does not know how to rename, or if a fused QKV tensor has
            a leading dimension that is not a multiple of 3.
    """
    pre = prefix + "pretrained."
    if not any(k.startswith(pre) for k in state_dict):
        return state_dict

    # One-to-one renames for the non-block (embedding / final norm) weights.
    # Keys under ``pre`` that are neither listed here nor under ``blocks.``
    # are left under their original name.
    static_renames = {
        pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight",
        pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias",
        pre + "pos_embed": prefix + "embeddings.position_embeddings",
        pre + "cls_token": prefix + "embeddings.cls_token",
        pre + "camera_token": prefix + "embeddings.camera_token",
        pre + "norm.weight": prefix + "layernorm.weight",
        pre + "norm.bias": prefix + "layernorm.bias",
    }
    for src, dst in static_renames.items():
        if src in state_dict:
            state_dict[dst] = state_dict.pop(src)

    # Sub-keys inside a block that are renamed wholesale.
    exact_renames = {
        "ls1.gamma": "layer_scale1.lambda1",
        "ls2.gamma": "layer_scale2.lambda1",
    }
    # (old sub-key prefix, new sub-key prefix); the remainder is preserved.
    prefix_renames = (
        ("attn.proj.", "attention.output.dense."),
        ("attn.q_norm.", "attention.q_norm."),
        ("attn.k_norm.", "attention.k_norm."),
        ("mlp.w12.", "mlp.weights_in."),
        ("mlp.w3.", "mlp.weights_out."),
    )
    # Sub-key prefixes whose names are identical in both layouts.
    passthrough = ("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")

    block_pre = pre + "blocks."
    # Snapshot the keys first: the loop below pops and inserts entries.
    block_keys = [k for k in state_dict if k.startswith(block_pre)]
    for k in block_keys:
        rest = k[len(block_pre):]  # e.g. "5.attn.qkv.weight"
        idx_str, _, sub = rest.partition(".")
        target_block = "{}encoder.layer.{}.".format(prefix, idx_str)

        # Fused QKV -> split query/key/value linears (same logic for the
        # weight and the bias: the split is along the first dimension).
        if sub in ("attn.qkv.weight", "attn.qkv.bias"):
            qkv = state_dict[k]
            rows = qkv.shape[0]
            if rows % 3 != 0:
                # An uneven split would silently give query/key/value
                # different sizes, so refuse it here.
                raise ValueError(
                    "Fused QKV tensor {} has leading dimension {}, "
                    "which is not a multiple of 3".format(k, rows)
                )
            state_dict.pop(k)
            c = rows // 3
            kind = sub.rsplit(".", 1)[1]  # "weight" or "bias"
            attn = target_block + "attention.attention."
            # clone() so each part owns its storage instead of viewing qkv.
            state_dict[attn + "query." + kind] = qkv[:c].clone()
            state_dict[attn + "key." + kind] = qkv[c:2 * c].clone()
            state_dict[attn + "value." + kind] = qkv[2 * c:].clone()
            continue

        # Sub-key remap (suffix preserved).
        new = exact_renames.get(sub)
        if new is None:
            for old_head, new_head in prefix_renames:
                if sub.startswith(old_head):
                    new = new_head + sub[len(old_head):]
                    break
        if new is None and sub.startswith(passthrough):
            new = sub
        if new is None:
            # Fail fast instead of writing a file with a half-converted backbone.
            raise ValueError("Unrecognised DA3 backbone key: {}".format(k))
        state_dict[target_block + new] = state_dict.pop(k)
    return state_dict
def drop_unused(state_dict):
    """Remove every entry whose key starts with one of ``DROP_PREFIXES``.

    The dict is modified in place and the same object is returned.
    """
    # Collect first, delete afterwards, so the dict is never resized while
    # it is being iterated.
    doomed = [name for name in state_dict if name.startswith(DROP_PREFIXES)]
    for name in doomed:
        del state_dict[name]
    return state_dict
def load_state_dict(path):
    """Load a checkpoint from ``path`` and return its state dict.

    Three kinds of input are handled:

    * a directory: every ``*.safetensors`` file directly inside it is loaded
      in sorted filename order and merged into one dict;
    * a path ending in ``.safetensors``: loaded with ``load_file``;
    * anything else: loaded with ``torch.load`` onto the CPU, then unwrapped
      from a single ``"state_dict"`` / ``"model"`` / ``"module"`` nesting
      level if present.

    Args:
        path: File or directory path of the checkpoint.

    Returns:
        A dict mapping parameter names to the loaded values.

    Raises:
        FileNotFoundError: If ``path`` is a directory without ``.safetensors``
            files.
        TypeError: If a ``torch.load`` checkpoint does not contain a dict.
    """
    if os.path.isdir(path):
        files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
        if not files:
            raise FileNotFoundError("No .safetensors files in {}".format(path))
        sd = {}
        for f in files:
            # On duplicate keys the shard that sorts last wins.
            sd.update(load_file(f))
        return sd
    if path.endswith(".safetensors"):
        return load_file(path)

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which
    # can execute code embedded in the file. Only run this on checkpoints
    # from a trusted source; prefer .safetensors inputs where available.
    sd = torch.load(path, map_location="cpu", weights_only=False)
    # Unwrap common nesting (e.g. {"model": ...} / {"state_dict": ...}).
    for wrap in ("state_dict", "model", "module"):
        if isinstance(sd, dict) and wrap in sd and isinstance(sd[wrap], dict):
            sd = sd[wrap]
            break
    if not isinstance(sd, dict):
        # Fail here with a clear message rather than later on an unrelated
        # attribute error when the caller iterates the keys.
        raise TypeError(
            "Checkpoint {} does not contain a state dict (got {})".format(
                path, type(sd).__name__
            )
        )
    return sd
def main():
    """Command-line entry point: load a DA3 checkpoint, convert it, save it.

    Reads ``--input`` and ``--output`` from the command line, remaps the
    backbone keys when they are in the upstream layout, drops the unused
    Gaussian-head weights and writes the result as a ``.safetensors`` file
    whose keys all carry the ``model.`` prefix.

    Raises:
        ValueError: If the input has neither ``backbone.pretrained.*`` nor
            ``backbone.embeddings.*`` keys.
    """
    cli = argparse.ArgumentParser(
        description="Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout"
    )
    cli.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to original DA3 .safetensors / .pt / .pth file or directory",
    )
    cli.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output .safetensors file path",
    )
    opts = cli.parse_args()

    print("Loading: {}".format(opts.input))
    weights = load_state_dict(opts.input)
    print(" Loaded {} keys".format(len(weights)))

    # Upstream checkpoints keep everything under a "model." string prefix.
    # Work on bare "backbone.*" names and put the prefix back before saving,
    # because the ComfyUI loader expects it on the saved file.
    model_prefix = "model."
    if any(name.startswith(model_prefix) for name in weights):
        print(' Stripping "model." prefix for processing')
        bare = {}
        for name, value in weights.items():
            if name.startswith(model_prefix):
                name = name[len(model_prefix):]
            bare[name] = value
        weights = bare

    upstream_layout = any(name.startswith("backbone.pretrained.") for name in weights)
    if upstream_layout:
        print(" Remapping backbone (backbone.pretrained.* -> Dinov2Model layout)...")
        weights = remap_backbone_keys(weights, prefix="backbone.")
    elif any(name.startswith("backbone.embeddings.") for name in weights):
        print(" Backbone already in ComfyUI layout, skipping remap.")
    else:
        raise ValueError("Input does not look like a DA3 checkpoint (no backbone.* keys found)")

    count_before = len(weights)
    weights = drop_unused(weights)
    removed = count_before - len(weights)
    if removed:
        print(" Dropped {} unused Gaussian-head keys".format(removed))

    # Restore the "model." prefix and make every tensor contiguous in a
    # single pass; safetensors refuses non-contiguous tensors.
    weights = {"model." + name: value.contiguous() for name, value in weights.items()}

    out_dir = os.path.dirname(os.path.abspath(opts.output))
    os.makedirs(out_dir, exist_ok=True)
    save_file(weights, opts.output)
    print(" Saved {} keys to {}".format(len(weights), opts.output))
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()