import os

import torch
from safetensors.torch import load_file, save_file

|
def _add_component(merged_state, component_state, prefix, component_name):
    """Copy *component_state* into *merged_state*, namespacing keys under *prefix*.

    Keys that already start with *prefix* are kept as-is. All other keys get
    *prefix* prepended. An empty *prefix* therefore copies keys verbatim.

    Args:
        merged_state: Destination dict, mutated in place.
        component_state: Mapping of tensor name -> tensor for one component.
        prefix: Namespace prefix to apply (e.g. ``"vae."``), or ``""``.
        component_name: Human-readable name used in error messages.

    Raises:
        ValueError: If a (prefixed) key is already present in *merged_state*.
            Silently overwriting would drop a tensor from the merged file.
    """
    for key, tensor in component_state.items():
        new_key = key if key.startswith(prefix) else f"{prefix}{key}"
        if new_key in merged_state:
            raise ValueError(
                f"Key collision while adding {component_name} weights: "
                f"{new_key!r} is already present in the merged state dict"
            )
        merged_state[new_key] = tensor


def merge_model_components(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    *,
    vae_prefix="vae.",
    text_encoder_prefix="text_encoder.",
):
    """
    Merge UNet, VAE, and text encoder into a single safetensors file.

    UNet/model keys are copied unchanged. VAE keys are namespaced under
    *vae_prefix* and text-encoder keys under *text_encoder_prefix*. Keys that
    already carry the relevant prefix are left as-is.

    Args:
        unet_path: Path to the main model/unet safetensors file
        vae_path: Path to the VAE safetensors file
        text_encoder_path: Path to the text encoder/CLIP safetensors file
        output_path: Path where the merged file will be saved
        vae_prefix: Prefix applied to VAE keys (default ``"vae."``).
        text_encoder_prefix: Prefix applied to text-encoder keys
            (default ``"text_encoder."``).

    Raises:
        ValueError: If two tensors would end up under the same key. This is
            checked before anything is written, so no partial file is saved.
    """
    # NOTE(review): the default prefixes match the original script; confirm
    # they match whatever loader will consume the merged file.
    print("Loading UNet/Model weights...")
    unet_state = load_file(unet_path)

    print("Loading VAE weights...")
    vae_state = load_file(vae_path)

    print("Loading Text Encoder weights...")
    text_encoder_state = load_file(text_encoder_path)

    print("Merging state dictionaries...")
    merged_state = {}
    # Empty prefix: UNet keys go in verbatim, exactly as before.
    _add_component(merged_state, unet_state, "", "UNet/Model")
    _add_component(merged_state, vae_state, vae_prefix, "VAE")
    _add_component(
        merged_state, text_encoder_state, text_encoder_prefix, "Text Encoder"
    )

    # len() counts tensors; numel() sums their elements (the real parameter count).
    total_params = sum(tensor.numel() for tensor in merged_state.values())
    print(f"Total tensors in merged model: {len(merged_state)}")
    print(f"Total parameters in merged model: {total_params:,}")
    print(f"Saving merged model to {output_path}...")

    save_file(merged_state, output_path)

    print("✅ Merge complete!")
    print(f"File saved to: {output_path}")

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")


if __name__ == "__main__":
    # Script entry point: merge the three component files listed below.
    # Edit these paths to point at your own checkpoint files.
    component_paths = dict(
        unet_path="flux1-depth-dev.safetensors",
        vae_path="vae/diffusion_pytorch_model.safetensors",
        text_encoder_path="text_encoder/model.safetensors",
        output_path="flux1-depth-dev_merged_model.safetensors",
    )
    merge_model_components(**component_paths)
| |
|