| | |
| | """ |
| | Convert DeepFilterNet PyTorch weights to MLX format. |
| | |
| | This script converts pretrained DeepFilterNet models from the original |
| | PyTorch implementation to MLX-compatible format with proper weight mapping. |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import configparser |
| | from pathlib import Path |
| | from typing import Dict, Any, List, Tuple |
| | import re |
| |
|
| | import mlx.core as mx |
| | import numpy as np |
| | import torch |
| |
|
| |
|
| | def convert_weight(weight: torch.Tensor) -> mx.array: |
| | """Convert PyTorch tensor to MLX array.""" |
| | return mx.array(weight.detach().cpu().numpy()) |
| |
|
| |
|
| | def parse_config(config_path: Path) -> Dict[str, Any]: |
| | """Parse DeepFilterNet config.ini file.""" |
| | config = configparser.ConfigParser() |
| | config.read(config_path) |
| | |
| | linear_groups = config.getint("deepfilternet", "linear_groups", fallback=16) |
| | df_order = config.getint( |
| | "df", |
| | "df_order", |
| | fallback=config.getint("deepfilternet", "df_order", fallback=5), |
| | ) |
| | df_lookahead = config.getint( |
| | "df", |
| | "df_lookahead", |
| | fallback=config.getint("deepfilternet", "df_lookahead", fallback=0), |
| | ) |
| |
|
| | result = { |
| | |
| | "sample_rate": config.getint("df", "sr", fallback=48000), |
| | "fft_size": config.getint("df", "fft_size", fallback=960), |
| | "hop_size": config.getint("df", "hop_size", fallback=480), |
| | "nb_erb": config.getint("df", "nb_erb", fallback=32), |
| | "nb_df": config.getint("df", "nb_df", fallback=96), |
| | "df_order": df_order, |
| | "df_lookahead": df_lookahead, |
| | "lsnr_max": config.getint("df", "lsnr_max", fallback=35), |
| | "lsnr_min": config.getint("df", "lsnr_min", fallback=-15), |
| | |
| | |
| | "conv_ch": config.getint("deepfilternet", "conv_ch", fallback=64), |
| | "conv_k_enc": config.getint("deepfilternet", "conv_k_enc", fallback=1), |
| | "conv_k_dec": config.getint("deepfilternet", "conv_k_dec", fallback=1), |
| | "conv_width_factor": config.getint("deepfilternet", "conv_width_factor", fallback=1), |
| | "conv_dec_mode": config.get("deepfilternet", "conv_dec_mode", fallback="transposed"), |
| | "emb_hidden_dim": config.getint("deepfilternet", "emb_hidden_dim", fallback=256), |
| | "emb_num_layers": config.getint("deepfilternet", "emb_num_layers", fallback=3), |
| | "df_hidden_dim": config.getint("deepfilternet", "df_hidden_dim", fallback=256), |
| | "df_num_layers": config.getint("deepfilternet", "df_num_layers", fallback=2), |
| | "gru_groups": config.getint("deepfilternet", "gru_groups", fallback=8), |
| | "linear_groups": linear_groups, |
| | |
| | |
| | "enc_linear_groups": config.getint("deepfilternet", "enc_linear_groups", fallback=linear_groups), |
| | "group_shuffle": config.getboolean("deepfilternet", "group_shuffle", fallback=False), |
| | "mask_pf": config.getboolean("deepfilternet", "mask_pf", fallback=False), |
| | "conv_lookahead": config.getint("deepfilternet", "conv_lookahead", fallback=2), |
| | "conv_depthwise": config.getboolean("deepfilternet", "conv_depthwise", fallback=True), |
| | "convt_depthwise": config.getboolean("deepfilternet", "convt_depthwise", fallback=False), |
| | "enc_concat": config.getboolean("deepfilternet", "enc_concat", fallback=False), |
| | "emb_gru_skip_enc": config.get("deepfilternet", "emb_gru_skip_enc", fallback="none"), |
| | "emb_gru_skip": config.get("deepfilternet", "emb_gru_skip", fallback="none"), |
| | "df_gru_skip": config.get("deepfilternet", "df_gru_skip", fallback="groupedlinear"), |
| | "dfop_method": config.get("deepfilternet", "dfop_method", fallback="real_unfold"), |
| | } |
| | |
| | |
| | conv_kernel = config.get("deepfilternet", "conv_kernel", fallback="1,3") |
| | result["conv_kernel"] = [int(x) for x in conv_kernel.split(",")] |
| | |
| | convt_kernel = config.get("deepfilternet", "convt_kernel", fallback="1,3") |
| | result["convt_kernel"] = [int(x) for x in convt_kernel.split(",")] |
| | |
| | conv_kernel_inp = config.get("deepfilternet", "conv_kernel_inp", fallback="3,3") |
| | result["conv_kernel_inp"] = [int(x) for x in conv_kernel_inp.split(",")] |
| | |
| | return result |
| |
|
| |
|
| | def convert_pytorch_to_mlx( |
| | checkpoint_path: Path, |
| | config_path: Path, |
| | output_dir: Path, |
| | model_name: str = "DeepFilterNet3", |
| | ): |
| | """Convert PyTorch checkpoint to MLX format with proper weight mapping.""" |
| | |
| | print(f"Loading checkpoint from {checkpoint_path}") |
| | ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) |
| | |
| | |
| | if "state_dict" in ckpt: |
| | state_dict = ckpt["state_dict"] |
| | elif "model_state_dict" in ckpt: |
| | state_dict = ckpt["model_state_dict"] |
| | else: |
| | state_dict = ckpt |
| | |
| | print(f"Found {len(state_dict)} parameters in checkpoint") |
| | |
| | |
| | print(f"Parsing config from {config_path}") |
| | config_dict = parse_config(config_path) |
| | config_dict["model_version"] = model_name |
| | |
| | |
| | print("\nPyTorch weight shapes:") |
| | for key, value in list(state_dict.items())[:20]: |
| | print(f" {key}: {tuple(value.shape)}") |
| | print(" ...") |
| | |
| | |
| | print("\nConverting weights to MLX format...") |
| | mlx_weights = {} |
| | |
| | for key, value in state_dict.items(): |
| | |
| | if "num_batches_tracked" in key: |
| | continue |
| | |
| | |
| | mlx_array = convert_weight(value) |
| | mlx_weights[key] = mlx_array |
| | |
| | print(f"Converted {len(mlx_weights)} weights") |
| | |
| | |
| | output_dir.mkdir(parents=True, exist_ok=True) |
| | |
| | |
| | weights_path = output_dir / "model.safetensors" |
| | print(f"Saving weights to {weights_path}") |
| | mx.save_safetensors(str(weights_path), mlx_weights) |
| | |
| | |
| | config_out_path = output_dir / "config.json" |
| | print(f"Saving config to {config_out_path}") |
| | with open(config_out_path, "w") as f: |
| | json.dump(config_dict, f, indent=2) |
| | |
| | print(f"\nConversion complete! Output saved to {output_dir}") |
| | print(f" - model.safetensors: {weights_path.stat().st_size / 1024 / 1024:.1f} MB") |
| | print(f" - config.json") |
| | |
| | return mlx_weights, config_dict |
| |
|
| |
|
| | def main(): |
| | parser = argparse.ArgumentParser(description="Convert DeepFilterNet PyTorch weights to MLX") |
| | parser.add_argument("--input", type=str, required=True, help="Path to DeepFilterNet model directory") |
| | parser.add_argument("--output", type=str, required=True, help="Output directory for MLX model") |
| | parser.add_argument("--name", type=str, default="DeepFilterNet3", help="Model name") |
| | args = parser.parse_args() |
| | |
| | input_dir = Path(args.input) |
| | output_dir = Path(args.output) |
| | |
| | |
| | checkpoint_dir = input_dir / "checkpoints" |
| | if checkpoint_dir.exists(): |
| | |
| | checkpoints = list(checkpoint_dir.glob("*.best")) |
| | if not checkpoints: |
| | checkpoints = list(checkpoint_dir.glob("*.ckpt")) |
| | if checkpoints: |
| | checkpoint_path = checkpoints[0] |
| | else: |
| | raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}") |
| | else: |
| | raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") |
| | |
| | |
| | config_path = input_dir / "config.ini" |
| | if not config_path.exists(): |
| | raise FileNotFoundError(f"Config file not found: {config_path}") |
| | |
| | convert_pytorch_to_mlx(checkpoint_path, config_path, output_dir, args.name) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|