Kyle Howells committed on
Commit
78e496a
·
1 Parent(s): 3fd7249

Add DeepFilterNet3 MLX model weights, config, and conversion script

Browse files

- model.safetensors: Pre-converted MLX weights (8.3MB)
- config.json: Model architecture configuration
- convert_deepfilternet.py: PyTorch to MLX safetensors conversion script

Files changed (3) hide show
  1. config.json +46 -0
  2. convert_deepfilternet.py +196 -0
  3. model.safetensors +3 -0
config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "sample_rate": 48000,
3
+ "fft_size": 960,
4
+ "hop_size": 480,
5
+ "nb_erb": 32,
6
+ "nb_df": 96,
7
+ "df_order": 5,
8
+ "df_lookahead": 2,
9
+ "lsnr_max": 35,
10
+ "lsnr_min": -15,
11
+ "conv_ch": 64,
12
+ "conv_k_enc": 1,
13
+ "conv_k_dec": 1,
14
+ "conv_width_factor": 1,
15
+ "conv_dec_mode": "transposed",
16
+ "emb_hidden_dim": 256,
17
+ "emb_num_layers": 3,
18
+ "df_hidden_dim": 256,
19
+ "df_num_layers": 2,
20
+ "gru_groups": 8,
21
+ "linear_groups": 16,
22
+ "enc_linear_groups": 32,
23
+ "group_shuffle": false,
24
+ "mask_pf": false,
25
+ "conv_lookahead": 2,
26
+ "conv_depthwise": true,
27
+ "convt_depthwise": false,
28
+ "enc_concat": false,
29
+ "emb_gru_skip_enc": "none",
30
+ "emb_gru_skip": "none",
31
+ "df_gru_skip": "groupedlinear",
32
+ "dfop_method": "df",
33
+ "conv_kernel": [
34
+ 1,
35
+ 3
36
+ ],
37
+ "convt_kernel": [
38
+ 1,
39
+ 3
40
+ ],
41
+ "conv_kernel_inp": [
42
+ 3,
43
+ 3
44
+ ],
45
+ "model_version": "DeepFilterNet3"
46
+ }
convert_deepfilternet.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert DeepFilterNet PyTorch weights to MLX format.
4
+
5
+ This script converts pretrained DeepFilterNet models from the original
6
+ PyTorch implementation to MLX-compatible format with proper weight mapping.
7
+ """
8
+
9
+ import argparse
10
+ import json
11
+ import configparser
12
+ from pathlib import Path
13
+ from typing import Dict, Any, List, Tuple
14
+ import re
15
+
16
+ import mlx.core as mx
17
+ import numpy as np
18
+ import torch
19
+
20
+
21
def convert_weight(weight: torch.Tensor) -> mx.array:
    """Convert a single PyTorch tensor into an MLX array.

    The tensor is detached from any autograd graph and moved to host
    memory first, then handed to MLX through an intermediate NumPy buffer.
    """
    host_buffer = weight.detach().cpu().numpy()
    return mx.array(host_buffer)
24
+
25
+
26
def parse_config(config_path: Path) -> Dict[str, Any]:
    """Parse a DeepFilterNet ``config.ini`` file into a flat settings dict.

    Every option has a DeepFilterNet3 default, so a partial (or even
    missing) config file still yields a complete dictionary.
    """
    cfg = configparser.ConfigParser()
    cfg.read(config_path)

    # linear_groups is read first because enc_linear_groups falls back to it:
    # DeepFilterNet2 configs do not expose enc_linear_groups separately, and it
    # must follow linear_groups to keep grouped-linear tensor shapes aligned.
    linear_groups = cfg.getint("deepfilternet", "linear_groups", fallback=16)

    # df_order / df_lookahead moved sections between model versions; prefer
    # the [df] section and fall back to the legacy [deepfilternet] location.
    df_order = cfg.getint(
        "df",
        "df_order",
        fallback=cfg.getint("deepfilternet", "df_order", fallback=5),
    )
    df_lookahead = cfg.getint(
        "df",
        "df_lookahead",
        fallback=cfg.getint("deepfilternet", "df_lookahead", fallback=0),
    )

    result: Dict[str, Any] = {}

    # [df] section: STFT / filter geometry and LSNR range.
    for key, option, default in (
        ("sample_rate", "sr", 48000),
        ("fft_size", "fft_size", 960),
        ("hop_size", "hop_size", 480),
        ("nb_erb", "nb_erb", 32),
        ("nb_df", "nb_df", 96),
    ):
        result[key] = cfg.getint("df", option, fallback=default)
    result["df_order"] = df_order
    result["df_lookahead"] = df_lookahead
    result["lsnr_max"] = cfg.getint("df", "lsnr_max", fallback=35)
    result["lsnr_min"] = cfg.getint("df", "lsnr_min", fallback=-15)

    # [deepfilternet] section: integer-valued architecture options.
    for key, default in (
        ("conv_ch", 64),
        ("conv_k_enc", 1),
        ("conv_k_dec", 1),
        ("conv_width_factor", 1),
    ):
        result[key] = cfg.getint("deepfilternet", key, fallback=default)
    result["conv_dec_mode"] = cfg.get("deepfilternet", "conv_dec_mode", fallback="transposed")
    for key, default in (
        ("emb_hidden_dim", 256),
        ("emb_num_layers", 3),
        ("df_hidden_dim", 256),
        ("df_num_layers", 2),
        ("gru_groups", 8),
    ):
        result[key] = cfg.getint("deepfilternet", key, fallback=default)
    result["linear_groups"] = linear_groups
    result["enc_linear_groups"] = cfg.getint(
        "deepfilternet", "enc_linear_groups", fallback=linear_groups
    )

    # Boolean switches.
    for key, default in (
        ("group_shuffle", False),
        ("mask_pf", False),
    ):
        result[key] = cfg.getboolean("deepfilternet", key, fallback=default)
    result["conv_lookahead"] = cfg.getint("deepfilternet", "conv_lookahead", fallback=2)
    for key, default in (
        ("conv_depthwise", True),
        ("convt_depthwise", False),
        ("enc_concat", False),
    ):
        result[key] = cfg.getboolean("deepfilternet", key, fallback=default)

    # Skip-connection / DF-operator mode strings.
    for key, default in (
        ("emb_gru_skip_enc", "none"),
        ("emb_gru_skip", "none"),
        ("df_gru_skip", "groupedlinear"),
        ("dfop_method", "real_unfold"),
    ):
        result[key] = cfg.get("deepfilternet", key, fallback=default)

    # Kernel sizes are stored as comma-separated strings, e.g. "1,3".
    for key, default in (
        ("conv_kernel", "1,3"),
        ("convt_kernel", "1,3"),
        ("conv_kernel_inp", "3,3"),
    ):
        raw = cfg.get("deepfilternet", key, fallback=default)
        result[key] = [int(part) for part in raw.split(",")]

    return result
93
+
94
+
95
def convert_pytorch_to_mlx(
    checkpoint_path: Path,
    config_path: Path,
    output_dir: Path,
    model_name: str = "DeepFilterNet3",
):
    """Convert a PyTorch checkpoint to MLX format with a direct weight mapping.

    Args:
        checkpoint_path: Path to the PyTorch ``.ckpt``/``.best`` checkpoint.
        config_path: Path to the DeepFilterNet ``config.ini``.
        output_dir: Directory where ``model.safetensors`` and ``config.json``
            are written (created if necessary).
        model_name: Value stored as ``model_version`` in the output config.

    Returns:
        Tuple of (mlx_weights dict, config dict).
    """
    print(f"Loading checkpoint from {checkpoint_path}")
    # NOTE(security): weights_only=False uses pickle under the hood — only run
    # this on checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Checkpoints may wrap the state dict under different keys, or be a bare
    # state dict themselves.
    if "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    elif "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt

    print(f"Found {len(state_dict)} parameters in checkpoint")

    # Parse config and record which model produced these weights.
    print(f"Parsing config from {config_path}")
    config_dict = parse_config(config_path)
    config_dict["model_version"] = model_name

    # Print the first few weight shapes for debugging.
    print("\nPyTorch weight shapes:")
    for key, value in list(state_dict.items())[:20]:
        print(f"  {key}: {tuple(value.shape)}")
    if len(state_dict) > 20:
        print("  ...")

    # Convert weights — direct key mapping, since the MLX architecture is
    # built to match the original parameter names.
    print("\nConverting weights to MLX format...")
    mlx_weights = {}
    for key, value in state_dict.items():
        # BatchNorm bookkeeping buffers are not needed for inference.
        if "num_batches_tracked" in key:
            continue
        mlx_weights[key] = convert_weight(value)

    print(f"Converted {len(mlx_weights)} weights")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Save weights.
    weights_path = output_dir / "model.safetensors"
    print(f"Saving weights to {weights_path}")
    mx.save_safetensors(str(weights_path), mlx_weights)

    # Save config.
    config_out_path = output_dir / "config.json"
    print(f"Saving config to {config_out_path}")
    with open(config_out_path, "w") as f:
        json.dump(config_dict, f, indent=2)

    print(f"\nConversion complete! Output saved to {output_dir}")
    print(f"  - model.safetensors: {weights_path.stat().st_size / 1024 / 1024:.1f} MB")
    # Plain string: the original used an f-string with no placeholders.
    print("  - config.json")

    return mlx_weights, config_dict
161
+
162
+
163
def find_checkpoint(checkpoint_dir: Path) -> Path:
    """Locate the preferred checkpoint file in *checkpoint_dir*.

    ``*.best`` checkpoints are preferred over ``*.ckpt``; within a category
    the lexicographically first file is chosen so the result is deterministic
    (plain ``glob`` order is filesystem-dependent).

    Raises:
        FileNotFoundError: if the directory or any checkpoint is missing.
    """
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")
    for pattern in ("*.best", "*.ckpt"):
        candidates = sorted(checkpoint_dir.glob(pattern))
        if candidates:
            return candidates[0]
    raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")


def main():
    """CLI entry point: convert a DeepFilterNet model directory to MLX format."""
    parser = argparse.ArgumentParser(description="Convert DeepFilterNet PyTorch weights to MLX")
    parser.add_argument("--input", type=str, required=True, help="Path to DeepFilterNet model directory")
    parser.add_argument("--output", type=str, required=True, help="Output directory for MLX model")
    parser.add_argument("--name", type=str, default="DeepFilterNet3", help="Model name")
    args = parser.parse_args()

    input_dir = Path(args.input)
    output_dir = Path(args.output)

    # Find checkpoint (prefers *.best, deterministic within each category).
    checkpoint_path = find_checkpoint(input_dir / "checkpoints")

    # Find config.
    config_path = input_dir / "config.ini"
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    convert_pytorch_to_mlx(checkpoint_path, config_path, output_dir, args.name)


if __name__ == "__main__":
    main()
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fca0af2f25cad49d74fc9ac5f9155813416e4b350ec344bb433b4a48a9a76d38
3
+ size 8682709