#!/usr/bin/env python3 """Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout. Usage: python scripts/convert_da3.py \ --input da3_original.safetensors \ --output models/diffusion_models/da3_comfy.safetensors This applies, *offline*, the exact transform that ComfyUI used to do at load time for DA3: * remap the DINOv2 backbone keys ``backbone.pretrained.*`` (upstream DA3) to the ``Dinov2Model`` runtime layout (``backbone.embeddings.*``, ``backbone.encoder.layer.*``, ``backbone.layernorm.*``); * split each fused ``attn.qkv`` projection into separate query/key/value linears; * drop the unused Gaussian-splat head weights (``gs_head.*``, ``gs_adapter.*``). The head (``head.*``), camera encoder/decoder (``cam_enc.*``, ``cam_dec.*``) and any other keys are passed through unchanged. After conversion the file loads directly via ComfyUI auto-detection with no in-code remap. """ import argparse import glob import os import torch from safetensors.torch import load_file, save_file DROP_PREFIXES = ("gs_head.", "gs_adapter.") def remap_backbone_keys(state_dict, prefix="backbone."): """Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` layout.""" pre = prefix + "pretrained." src_keys = [k for k in state_dict.keys() if k.startswith(pre)] if not src_keys: return state_dict static_renames = { pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight", pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias", pre + "pos_embed": prefix + "embeddings.position_embeddings", pre + "cls_token": prefix + "embeddings.cls_token", pre + "camera_token": prefix + "embeddings.camera_token", pre + "norm.weight": prefix + "layernorm.weight", pre + "norm.bias": prefix + "layernorm.bias", } for src, dst in static_renames.items(): if src in state_dict: state_dict[dst] = state_dict.pop(src) block_pre = pre + "blocks." block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)] for k in block_keys: rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight" idx_str, _, sub = rest.partition(".") target_block = "{}encoder.layer.{}.".format(prefix, idx_str) # Fused QKV -> split query/key/value linears. if sub == "attn.qkv.weight": qkv = state_dict.pop(k) c = qkv.shape[0] // 3 state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone() state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone() state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone() continue if sub == "attn.qkv.bias": qkv = state_dict.pop(k) c = qkv.shape[0] // 3 state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone() state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone() state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone() continue # Sub-key remap (suffix preserved). if sub.startswith("attn.proj."): tail = sub[len("attn.proj."):] new = "attention.output.dense." + tail elif sub.startswith("attn.q_norm."): new = "attention.q_norm." + sub[len("attn.q_norm."):] elif sub.startswith("attn.k_norm."): new = "attention.k_norm." + sub[len("attn.k_norm."):] elif sub == "ls1.gamma": new = "layer_scale1.lambda1" elif sub == "ls2.gamma": new = "layer_scale2.lambda1" elif sub.startswith("mlp.w12."): new = "mlp.weights_in." + sub[len("mlp.w12."):] elif sub.startswith("mlp.w3."): new = "mlp.weights_out." + sub[len("mlp.w3."):] elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")): new = sub else: # Unrecognised key -- leave as-is so a later load can complain. raise ValueError("Unrecognised DA3 backbone key: {}".format(k)) state_dict[target_block + new] = state_dict.pop(k) return state_dict def drop_unused(state_dict): for k in list(state_dict.keys()): if k.startswith(DROP_PREFIXES): state_dict.pop(k) return state_dict def load_state_dict(path): if os.path.isdir(path): sd = {} files = sorted(glob.glob(os.path.join(path, "*.safetensors"))) if not files: raise FileNotFoundError("No .safetensors files in {}".format(path)) for f in files: sd.update(load_file(f)) return sd if path.endswith(".safetensors"): return load_file(path) sd = torch.load(path, map_location="cpu", weights_only=False) # Unwrap common nesting (e.g. {"model": ...} / {"state_dict": ...}). for wrap in ("state_dict", "model", "module"): if isinstance(sd, dict) and wrap in sd and isinstance(sd[wrap], dict): sd = sd[wrap] break return sd def main(): parser = argparse.ArgumentParser( description="Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout" ) parser.add_argument("--input", type=str, required=True, help="Path to original DA3 .safetensors / .pt / .pth file or directory") parser.add_argument("--output", type=str, required=True, help="Output .safetensors file path") args = parser.parse_args() print("Loading: {}".format(args.input)) sd = load_state_dict(args.input) print(" Loaded {} keys".format(len(sd))) # Original DA3 checkpoints store everything under a "model." string prefix # (e.g. "model.backbone.pretrained.*"). Strip it so the remap works on bare # "backbone.*" keys, then re-add it at the end: ComfyUI's loader resolves the # diffusion-model prefix to "model." for DA3, so the saved file must keep it. if any(k.startswith("model.") for k in sd): print(' Stripping "model." prefix for processing') sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()} if any(k.startswith("backbone.pretrained.") for k in sd): print(" Remapping backbone (backbone.pretrained.* -> Dinov2Model layout)...") sd = remap_backbone_keys(sd, prefix="backbone.") elif any(k.startswith("backbone.embeddings.") for k in sd): print(" Backbone already in ComfyUI layout, skipping remap.") else: raise ValueError("Input does not look like a DA3 checkpoint (no backbone.* keys found)") n_before = len(sd) sd = drop_unused(sd) dropped = n_before - len(sd) if dropped: print(" Dropped {} unused Gaussian-head keys".format(dropped)) # Re-add the "model." prefix expected by ComfyUI's diffusion-model loader. sd = {"model." + k: v for k, v in sd.items()} # safetensors requires contiguous tensors; the qkv split slices are cloned # above but enforce contiguity defensively for all tensors. sd = {k: v.contiguous() for k, v in sd.items()} os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True) save_file(sd, args.output) print(" Saved {} keys to {}".format(len(sd), args.output)) if __name__ == "__main__": main()