"""Merge the EMA weights of two expert checkpoints into one AVCDiT model.

Selected parameters are copied from the ``--v_expert`` checkpoint and the
``--a_expert`` checkpoint according to explicit key maps; every remaining
model parameter falls back to the ``--v_expert`` checkpoint when a key with
the same name and shape exists there. The merged weights are saved under
the ``"ema"`` key of the output file.
"""
import argparse
from collections import Counter

import torch
import yaml

from models import AVCDiT_models


def add_exact_keys(mapping, keys):
    """Register an identity mapping (``key -> key``) for every name in *keys*.

    Args:
        mapping: Dict mutated in place; model parameter name -> checkpoint
            parameter name.
        keys: Iterable of parameter names that are identical in the model
            and in the checkpoint.
    """
    for k in keys:
        mapping[k] = k


def add_mlp_block_keys(mapping, mlp_name, num_blocks):
    """Register identity mappings for the MLP parameters of every block.

    Adds ``blocks.{i}.{mlp_name}.{fc1|fc2}.{weight|bias}`` for each ``i`` in
    ``range(num_blocks)``.

    Args:
        mapping: Dict mutated in place; model parameter name -> checkpoint
            parameter name.
        mlp_name: Attribute name of the MLP inside each block (the caller
            in this file passes ``"mlp_v"`` and ``"mlp_a"``).
        num_blocks: Number of blocks to cover.
    """
    for i in range(num_blocks):
        for fc in ["fc1", "fc2"]:
            for param in ["weight", "bias"]:
                k = f"blocks.{i}.{mlp_name}.{fc}.{param}"
                mapping[k] = k


def _load_ema_state(ckpt_path, device):
    """Load the ``"ema"`` state dict from *ckpt_path*, stripping ``_orig_mod.``."""
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only pass checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    # The '_orig_mod.' fragment is presumably left by a torch.compile
    # wrapper around the saved model -- TODO confirm against the trainer.
    return {k.replace('_orig_mod.', ''): v for k, v in ckpt["ema"].items()}


def _copy_mapped_weights(mapping, source_state, model_state, new_state,
                         source_info, tag):
    """Copy the tensors named in *mapping* into *new_state*; return the skips.

    An entry is copied only when the checkpoint key exists, the model key
    exists and both tensors have the same shape. Every entry that fails one
    of these checks is returned as a human-readable description instead of
    being dropped silently.
    """
    skipped = []
    if not mapping:
        return skipped
    for k_model, k_ckpt in mapping.items():
        if k_ckpt not in source_state:
            skipped.append(f"{tag}: '{k_ckpt}' not found in checkpoint")
        elif k_model not in model_state:
            skipped.append(f"{tag}: '{k_model}' not found in model")
        elif source_state[k_ckpt].shape != model_state[k_model].shape:
            skipped.append(
                f"{tag}: shape mismatch for '{k_model}' "
                f"(checkpoint {tuple(source_state[k_ckpt].shape)} vs "
                f"model {tuple(model_state[k_model].shape)})"
            )
        else:
            new_state[k_model] = source_state[k_ckpt]
            source_info[k_model] = tag
    return skipped


def load_from_two_checkpoints(model, ckpt1_path, ckpt2_path, map1=None,
                              map2=None, device='cuda'):
    """Load *model* with weights drawn from two checkpoints.

    Resolution order for every model parameter:

    1. entries of *map1*, read from checkpoint 1;
    2. entries of *map2*, read from checkpoint 2 (these override step 1 if
       both maps name the same model key);
    3. any parameter still unset falls back to the same-named key of
       checkpoint 1, provided the shapes match.

    Parameters that no step can resolve keep the values *model* already had
    and are absent from the returned dict. Skipped mappings and unresolved
    parameters are reported on stdout.

    Args:
        model: Module to load; must provide ``state_dict`` and
            ``load_state_dict``.
        ckpt1_path: Path of the first checkpoint (must contain ``"ema"``).
        ckpt2_path: Path of the second checkpoint (must contain ``"ema"``).
        map1: Optional dict, model key -> checkpoint-1 key.
        map2: Optional dict, model key -> checkpoint-2 key.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        Dict of model parameter name -> tensor that was loaded into *model*.

    Raises:
        KeyError: If a checkpoint has no ``"ema"`` entry.
    """
    state1 = _load_ema_state(ckpt1_path, device)
    state2 = _load_ema_state(ckpt2_path, device)

    model_state = model.state_dict()
    new_state = {}
    source_info = {}  # key: model param name, value: ckpt source name

    skipped = _copy_mapped_weights(
        map1, state1, model_state, new_state, source_info, "ckpt1")
    skipped += _copy_mapped_weights(
        map2, state2, model_state, new_state, source_info, "ckpt2")

    # Everything not explicitly mapped is taken from checkpoint 1 when a
    # parameter with the same name and shape exists there.
    for k_model, tensor in model_state.items():
        if k_model in new_state:
            continue
        if k_model in state1 and state1[k_model].shape == tensor.shape:
            new_state[k_model] = state1[k_model]
            source_info[k_model] = "fallback_ckpt1"

    # strict=False: new_state may legitimately cover only part of the model.
    model.load_state_dict(new_state, strict=False)
    print(f"Loaded {len(new_state)} / {len(model_state)} parameters")

    for source, count in sorted(Counter(source_info.values()).items()):
        print(f"  {source}: {count}")

    if skipped:
        print(f"Warning: {len(skipped)} mapped parameter(s) were skipped:")
        for reason in skipped:
            print(f"  {reason}")

    unresolved = [k for k in model_state if k not in new_state]
    if unresolved:
        print(
            f"Warning: {len(unresolved)} model parameter(s) received no "
            f"weights and are not part of the merged state:"
        )
        for k in unresolved:
            print(f"  {k}")

    return new_state


def main(args):
    """Build the model named in the config, merge both experts and save.

    Args:
        args: Parsed CLI namespace with ``config``, ``v_expert``,
            ``a_expert`` and ``output`` attributes.
    """
    with open(args.config, "r") as f:
        # safe_load returns None for an empty file; fall back to an empty
        # config so the default model name below still applies.
        config = yaml.safe_load(f) or {}

    model_name = config.get("model", "AVCDiT-B/2")
    print(f"Using model: {model_name}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = AVCDiT_models[model_name](
        context_size=4,
        input_size=28,
        in_channels=4,
        mode="av"
    ).to(device)

    depth = len(model.blocks)

    # Parameters taken from the --v_expert checkpoint.
    map1 = {}
    add_exact_keys(map1, [
        "pos_embed_v",
        "x_embedder_v.proj.weight",
        "x_embedder_v.proj.bias",
        "final_layer.linear.weight",
        "final_layer.linear.bias",
        "final_layer.adaLN_modulation.1.weight",
        "final_layer.adaLN_modulation.1.bias",
    ])
    add_mlp_block_keys(map1, "mlp_v", depth)

    # Parameters taken from the --a_expert checkpoint.
    map2 = {}
    add_exact_keys(map2, [
        "pos_embed_a_cond",
        "pos_embed_a_pred",
        "x_embedder_a.weight",
        "x_embedder_a.bias",
        "final_layer_a.linear.weight",
        "final_layer_a.linear.bias",
        "final_layer_a.adaLN_modulation.1.weight",
        "final_layer_a.adaLN_modulation.1.bias",
    ])
    add_mlp_block_keys(map2, "mlp_a", depth)

    merged_state_dict = load_from_two_checkpoints(
        model,
        ckpt1_path=args.v_expert,
        ckpt2_path=args.a_expert,
        map1=map1,
        map2=map2,
        device=device
    )

    torch.save({"ema": merged_state_dict}, args.output)
    print(f"Merged model saved to {args.output}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--v_expert", type=str, required=True)
    parser.add_argument("--a_expert", type=str, required=True)
    parser.add_argument("--output", type=str, default="experts_merged.pth")
    args = parser.parse_args()
    main(args)