import os

import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
| |
|
def inspect_keys(file_path, max_keys=10):
    """Print a summary of the tensor keys in a safetensors file.

    Reads only the file header via ``safe_open`` instead of ``load_file``,
    so inspecting a multi-GB checkpoint does not materialize any tensor
    data in memory.

    Args:
        file_path: Path to the safetensors file to inspect.
        max_keys: How many keys to print (all keys are still returned).

    Returns:
        list[str]: Every tensor key stored in the file.
    """
    # safe_open parses only the JSON header; no tensor payload is loaded.
    with safe_open(file_path, framework="pt") as f:
        keys = list(f.keys())
    print(f"\n{os.path.basename(file_path)} - Total keys: {len(keys)}")
    print(f"First {max_keys} keys:")
    for k in keys[:max_keys]:
        print(f"  {k}")
    return keys
| |
|
def merge_for_comfyui(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    model_type="flux"
):
    """
    Merge components into ComfyUI-compatible safetensors checkpoint.

    Each input file is loaded exactly once and reused for both key
    inspection and merging (the previous flow loaded every file twice:
    once in inspect_keys and once via load_file — doubling I/O and peak
    memory on multi-GB checkpoints).

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        model_type: Type of model (flux, sd15, sdxl); selects the text
            encoder key prefix.

    Raises:
        ValueError: If any input file contains no tensors.
    """

    def _summarize(path, state, max_keys=10):
        # Same per-file report inspect_keys() prints, but driven by an
        # already-loaded state dict so the file is only read once.
        keys = list(state.keys())
        print(f"\n{os.path.basename(path)} - Total keys: {len(keys)}")
        print(f"First {max_keys} keys:")
        for k in keys[:max_keys]:
            print(f"  {k}")
        return keys

    print("=" * 60)
    print("STEP 1: Loading weights...")
    print("=" * 60)

    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("\n" + "=" * 60)
    print("STEP 2: Inspecting input files...")
    print("=" * 60)

    unet_keys = _summarize(unet_path, unet_state)
    vae_keys = _summarize(vae_path, vae_state)
    text_encoder_keys = _summarize(text_encoder_path, text_encoder_state)

    # Guard before indexing [0] below — an empty file would otherwise die
    # with a bare IndexError.
    if not (unet_keys and vae_keys and text_encoder_keys):
        raise ValueError("One or more input files contain no tensors")

    print("\n" + "=" * 60)
    print("STEP 3: Merging with proper key structure...")
    print("=" * 60)

    merged_state = {}
    collisions = []

    def _put(key, value):
        # Later components win on collision (same as dict assignment), but
        # record the overlap instead of silently overwriting.
        if key in merged_state:
            collisions.append(key)
        merged_state[key] = value

    sample_unet_key = unet_keys[0]
    sample_vae_key = vae_keys[0]
    sample_te_key = text_encoder_keys[0]

    print("\nDetected key patterns:")
    print(f"  UNet: {sample_unet_key}")
    print(f"  VAE: {sample_vae_key}")
    print(f"  Text Encoder: {sample_te_key}")

    # UNet / transformer weights live under model.diffusion_model.* in
    # ComfyUI checkpoints; keep keys that already carry a known prefix.
    for key, value in unet_state.items():
        if key.startswith('model.') or key.startswith('diffusion_model.'):
            _put(key, value)
        else:
            _put(f'model.diffusion_model.{key}', value)

    # VAE weights go under first_stage_model.*.
    for key, value in vae_state.items():
        if key.startswith('first_stage_model.') or key.startswith('vae.'):
            _put(key, value)
        elif key.startswith('decoder.') or key.startswith('encoder.'):
            _put(f'first_stage_model.{key}', value)
        else:
            _put(f'first_stage_model.decoder.{key}', value)

    # Text encoder prefix depends on the model family.
    for key, value in text_encoder_state.items():
        if key.startswith('cond_stage_model.') or key.startswith('text_encoder.'):
            _put(key, value)
        elif model_type.lower() == "flux":
            _put(f'text_encoders.{key}', value)
        else:
            _put(f'cond_stage_model.transformer.{key}', value)

    print(f"\nMerged state contains {len(merged_state)} parameters")
    if collisions:
        print(f"⚠️  {len(collisions)} duplicate keys were overwritten "
              f"(first: {collisions[0]})")

    print("\n" + "=" * 60)
    print("STEP 4: Saving merged checkpoint...")
    print("=" * 60)

    save_file(merged_state, output_path)

    print("\n✅ Merge complete!")
    print(f"File saved to: {output_path}")

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")

    print("\n" + "=" * 60)
    print("STEP 5: Verifying merged file...")
    print("=" * 60)
    inspect_keys(output_path, max_keys=20)
| |
|
| |
|
def simple_merge_keep_structure(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path
):
    """
    Simple merge that preserves original key structure.

    Use this if the files already have proper ComfyUI keys. Duplicate keys
    between components are reported instead of silently overwritten (later
    files still win, matching dict.update semantics), since an overlap
    usually means two components were exported with conflicting names.

    Args:
        unet_path: Path to the main model/transformer safetensors.
        vae_path: Path to the VAE safetensors.
        text_encoder_path: Path to the text encoder/CLIP safetensors.
        output_path: Path for the merged checkpoint.
    """
    print("Loading all components...")

    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("Merging...")
    merged_state = dict(unet_state)
    for state in (vae_state, text_encoder_state):
        # Surface overlapping keys before they are clobbered.
        overlap = merged_state.keys() & state.keys()
        if overlap:
            print(f"⚠️  {len(overlap)} duplicate keys overwritten "
                  f"(first: {sorted(overlap)[0]})")
        merged_state.update(state)

    print(f"Saving {len(merged_state)} parameters...")
    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Example: assemble a single FLUX checkpoint from its component files.
    # Paths are relative to this script's working directory.
    component_paths = {
        "unet_path": "../flux1-depth-dev.safetensors",
        "vae_path": "../vae/diffusion_pytorch_model.safetensors",
        "text_encoder_path": "../text_encoder/model.safetensors",
    }
    merge_for_comfyui(
        output_path="../flux1-depth-dev_merged_model.safetensors",
        model_type="flux",
        **component_paths,
    )
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|