| |
| """Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout. |
| |
| Usage: |
| python scripts/convert_da3.py \ |
| --input da3_original.safetensors \ |
| --output models/diffusion_models/da3_comfy.safetensors |
| |
| This applies, *offline*, the exact transform that ComfyUI used to do at load |
| time for DA3: |
| |
| * remap the DINOv2 backbone keys ``backbone.pretrained.*`` (upstream DA3) |
| to the ``Dinov2Model`` runtime layout (``backbone.embeddings.*``, |
| ``backbone.encoder.layer.*``, ``backbone.layernorm.*``); |
| * split each fused ``attn.qkv`` projection into separate query/key/value |
| linears; |
| * drop the unused Gaussian-splat head weights (``gs_head.*``, ``gs_adapter.*``). |
| |
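The head (``head.*``), camera encoder/decoder (``cam_enc.*``, ``cam_dec.*``)
and any other keys are passed through unchanged apart from their ``model.``
prefix: an existing ``model.`` prefix is stripped before processing, and every
key in the output file is written under ``model.``. After conversion the file
loads directly via ComfyUI auto-detection with no in-code remap.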
| """ |
|
|
| import argparse |
| import glob |
| import os |
|
|
| import torch |
|
|
| from safetensors.torch import load_file, save_file |
|
|
# Key prefixes of the Gaussian-splat head weights; every key starting with one
# of these is removed from the converted checkpoint.
DROP_PREFIXES = (
    "gs_head.",
    "gs_adapter.",
)
|
|
|
|
| def remap_backbone_keys(state_dict, prefix="backbone."): |
| """Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` layout.""" |
| pre = prefix + "pretrained." |
| src_keys = [k for k in state_dict.keys() if k.startswith(pre)] |
| if not src_keys: |
| return state_dict |
|
|
| static_renames = { |
| pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight", |
| pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias", |
| pre + "pos_embed": prefix + "embeddings.position_embeddings", |
| pre + "cls_token": prefix + "embeddings.cls_token", |
| pre + "camera_token": prefix + "embeddings.camera_token", |
| pre + "norm.weight": prefix + "layernorm.weight", |
| pre + "norm.bias": prefix + "layernorm.bias", |
| } |
| for src, dst in static_renames.items(): |
| if src in state_dict: |
| state_dict[dst] = state_dict.pop(src) |
|
|
| block_pre = pre + "blocks." |
| block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)] |
| for k in block_keys: |
| rest = k[len(block_pre):] |
| idx_str, _, sub = rest.partition(".") |
| target_block = "{}encoder.layer.{}.".format(prefix, idx_str) |
|
|
| |
| if sub == "attn.qkv.weight": |
| qkv = state_dict.pop(k) |
| c = qkv.shape[0] // 3 |
| state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone() |
| state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone() |
| state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone() |
| continue |
| if sub == "attn.qkv.bias": |
| qkv = state_dict.pop(k) |
| c = qkv.shape[0] // 3 |
| state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone() |
| state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone() |
| state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone() |
| continue |
|
|
| |
| if sub.startswith("attn.proj."): |
| tail = sub[len("attn.proj."):] |
| new = "attention.output.dense." + tail |
| elif sub.startswith("attn.q_norm."): |
| new = "attention.q_norm." + sub[len("attn.q_norm."):] |
| elif sub.startswith("attn.k_norm."): |
| new = "attention.k_norm." + sub[len("attn.k_norm."):] |
| elif sub == "ls1.gamma": |
| new = "layer_scale1.lambda1" |
| elif sub == "ls2.gamma": |
| new = "layer_scale2.lambda1" |
| elif sub.startswith("mlp.w12."): |
| new = "mlp.weights_in." + sub[len("mlp.w12."):] |
| elif sub.startswith("mlp.w3."): |
| new = "mlp.weights_out." + sub[len("mlp.w3."):] |
| elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")): |
| new = sub |
| else: |
| |
| raise ValueError("Unrecognised DA3 backbone key: {}".format(k)) |
|
|
| state_dict[target_block + new] = state_dict.pop(k) |
|
|
| return state_dict |
|
|
|
|
def drop_unused(state_dict):
    """Delete, in place, every key that starts with one of ``DROP_PREFIXES``.

    Returns the same (now mutated) dict so the caller can reassign it.
    """
    # Collect names first; deleting while iterating a dict is not allowed.
    doomed = [name for name in state_dict if name.startswith(DROP_PREFIXES)]
    for name in doomed:
        del state_dict[name]
    return state_dict
|
|
|
|
def load_state_dict(path):
    """Load a DA3 checkpoint into a flat ``{name: tensor}`` dict.

    ``path`` may be:

    * a directory: every ``*.safetensors`` file in it is loaded, in sorted
      order, and the results are merged;
    * a ``.safetensors`` file;
    * anything else, which is read with ``torch.load``. If the loaded object
      is a dict with a ``state_dict``, ``model`` or ``module`` entry that is
      itself a dict, the first such entry (checked in that order) is returned.

    Raises:
        FileNotFoundError: if ``path`` is a directory with no ``.safetensors`` files.
        ValueError: if two files in a directory define the same key, which
            would otherwise silently overwrite one another.
        TypeError: if ``torch.load`` does not yield a dict.
    """
    if os.path.isdir(path):
        files = sorted(glob.glob(os.path.join(path, "*.safetensors")))
        if not files:
            raise FileNotFoundError("No .safetensors files in {}".format(path))
        sd = {}
        for f in files:
            shard = load_file(f)
            duplicates = sorted(sd.keys() & shard.keys())
            if duplicates:
                raise ValueError(
                    "{} key(s) in {} were already loaded from another file in {} "
                    "(e.g. {})".format(len(duplicates), f, path, ", ".join(duplicates[:5]))
                )
            sd.update(shard)
        return sd

    if path.endswith(".safetensors"):
        return load_file(path)

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only convert checkpoints from a trusted source.
    sd = torch.load(path, map_location="cpu", weights_only=False)

    # Unwrap common training-checkpoint containers (first match wins).
    for wrap in ("state_dict", "model", "module"):
        if isinstance(sd, dict) and isinstance(sd.get(wrap), dict):
            sd = sd[wrap]
            break

    if not isinstance(sd, dict):
        raise TypeError(
            "Expected {} to contain a state dict, got {}".format(path, type(sd).__name__)
        )
    return sd
|
|
|
|
def main():
    """Command-line entry point: convert ``--input`` and write ``--output``.

    Steps: load the checkpoint, strip any ``model.`` prefix, remap the DINOv2
    backbone (unless it is already in ComfyUI layout), drop the Gaussian-splat
    head, re-add ``model.`` to every key, and save as safetensors.

    Raises:
        ValueError: if the input has neither ``backbone.pretrained.*`` nor
            ``backbone.embeddings.*`` keys.
        TypeError: if any entry to be saved is not a tensor.
    """
    parser = argparse.ArgumentParser(
        description="Repackage an original Depth Anything 3 checkpoint into ComfyUI's native layout"
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Path to original DA3 .safetensors / .pt / .pth file or directory")
    parser.add_argument("--output", type=str, required=True,
                        help="Output .safetensors file path")
    args = parser.parse_args()

    print("Loading: {}".format(args.input))
    sd = load_state_dict(args.input)
    print(" Loaded {} keys".format(len(sd)))

    # Work on un-prefixed keys; the "model." prefix is re-added before saving.
    if any(k.startswith("model.") for k in sd):
        print(' Stripping "model." prefix for processing')
        sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()}

    if any(k.startswith("backbone.pretrained.") for k in sd):
        print(" Remapping backbone (backbone.pretrained.* -> Dinov2Model layout)...")
        sd = remap_backbone_keys(sd, prefix="backbone.")
        # The remap renames only the keys it knows; anything else under
        # backbone.pretrained. is carried over unchanged. Report it instead of
        # silently writing a partially converted backbone.
        leftover = sorted(k for k in sd if k.startswith("backbone.pretrained."))
        if leftover:
            print(" Warning: {} backbone.pretrained.* key(s) were not remapped and are kept as-is: {}".format(
                len(leftover), ", ".join(leftover)))
    elif any(k.startswith("backbone.embeddings.") for k in sd):
        print(" Backbone already in ComfyUI layout, skipping remap.")
    else:
        raise ValueError("Input does not look like a DA3 checkpoint (no backbone.* keys found)")

    n_before = len(sd)
    sd = drop_unused(sd)
    dropped = n_before - len(sd)
    if dropped:
        print(" Dropped {} unused Gaussian-head keys".format(dropped))

    # Every key in the output file lives under the "model." prefix.
    sd = {"model." + k: v for k, v in sd.items()}

    # safetensors stores only tensors; fail with a readable message rather
    # than an AttributeError from .contiguous() below.
    non_tensors = [k for k, v in sd.items() if not isinstance(v, torch.Tensor)]
    if non_tensors:
        raise TypeError("Cannot save non-tensor entries to safetensors: {}".format(", ".join(non_tensors)))

    # save_file rejects non-contiguous tensors.
    sd = {k: v.contiguous() for k, v in sd.items()}

    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    save_file(sd, args.output)
    print(" Saved {} keys to {}".format(len(sd), args.output))
|
|
|
|
if __name__ == "__main__":
    # Run the converter only when this file is executed as a script, not on import.
    main()
|
|