---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
---

# BrainModule

BrainModule from Défossez et al. (2023) [brainmagick], also known as SimpleConv.

> **Architecture-only repository.** This repo documents the
> `braindecode.models.BrainModule` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import BrainModule

model = BrainModule(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.

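As a quick sanity check, the expected tensor shapes (listed in the *Notes*
of the docstring further down) can be verified with a dummy batch; the
batch size here is illustrative:

```python
import torch

# 8 windows of 22-channel EEG at 250 Hz for 4 s -> 1000 samples each.
x = torch.randn(8, 22, 1000)
y = model(x)
print(y.shape)  # expected: torch.Size([8, 4]), i.e. (batch, n_outputs)
```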
## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.BrainModule.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/brainmodule.py#L25>

## Architecture description

The block below reproduces the class docstring (parameters, references,
and the architecture figure caption where available).

BrainModule from [brainmagick], also known as SimpleConv.

> A dilated convolutional encoder for EEG decoding, using residual
> connections and optional GLU gating for improved expressivity.

*Architecture figure* (adapted from Extended Data Fig. 4 of [brainmagick]
to highlight only the model part; the image itself is not reproduced
here). For each layer, the authors note first the number of output
channels, while the number of time steps is constant throughout the
layers. The model is composed of a spatial attention layer, then a 1x1
convolution without activation. A "Subject Layer" is selected based on the
subject index *s*; it consists of a 1x1 convolution learned only for that
subject, with no activation. The authors then apply five convolutional
blocks made of three convolutions each. The first two use residual skip
connections and increasing dilation, followed by a BatchNorm layer and a
GELU activation. The third convolution is not residual and uses a GLU
activation (which halves the number of channels) and no normalization.
Finally, the authors apply two 1x1 convolutions with a GELU in between.

The BrainModule (also referred to as SimpleConv) is a deep dilated
convolutional encoder specifically designed to decode perceived speech from
non-invasive brain recordings such as EEG and MEG. It is engineered to
address the high noise levels and inter-individual variability inherent in
non-invasive neuroimaging by using a single architecture trained across
large cohorts while accommodating participant-specific differences.

**Architecture Overview**

The BrainModule integrates three primary mechanisms to align brain activity
with deep speech representations (a sketch of the third follows this list):

1. **Spatial-temporal feature extraction.** The model uses a dedicated
   spatial attention layer to remap sensor data based on physical
   locations, followed by temporal processing through dilated convolutions.
2. **Subject-specific adaptation.** To leverage inter-subject variability,
   the architecture includes a "Subject Layer", a participant-specific
   1x1 convolution that allows the model to share core weights across a
   cohort while learning individual-specific neural patterns.
3. **Dilated residual blocks with gating.** The core encoder employs a
   stack of convolutional blocks featuring skip connections and increasing
   dilation to expand the receptive field without losing temporal
   resolution, supplemented by optional Gated Linear Units (GLU) for
   increased expressivity.

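The following is a minimal, self-contained sketch of one such block,
written for illustration only: it mirrors the pattern described above
(residual dilated convolution, BatchNorm, GELU, then a channel-halving
GLU), but it is not the braindecode implementation.

```python
import torch
from torch import nn


class DilatedResidualBlock(nn.Module):
    """Illustrative block: residual dilated conv -> BatchNorm -> GELU,
    then a non-residual 1x1 conv gated by a GLU (halving channels back)."""

    def __init__(self, channels: int, dilation: int, kernel_size: int = 3):
        super().__init__()
        padding = dilation * (kernel_size - 1) // 2  # keeps n_times constant
        self.conv = nn.Conv1d(channels, channels, kernel_size,
                              dilation=dilation, padding=padding)
        self.norm = nn.BatchNorm1d(channels)
        self.act = nn.GELU()
        # GLU halves the channel dimension, so emit 2x channels before gating.
        self.glu_conv = nn.Conv1d(channels, 2 * channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.act(self.norm(self.conv(x)))
        h = x + h  # residual skip connection (channel counts match)
        return nn.functional.glu(self.glu_conv(h), dim=1)


block = DilatedResidualBlock(channels=320, dilation=4)
out = block(torch.randn(8, 320, 1000))
print(out.shape)  # torch.Size([8, 320, 1000])
```

With kernel size 3 and the dilation growing per block, stacking such
blocks expands the receptive field quickly while `n_times` stays constant.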
**Macro Components**

**`BrainModule.input_projection` (initial processing)**

- **Operations.** Raw M/EEG input $\mathbf{X} \in \mathbb{R}^{C \times T}$
  is first processed through a spatial attention layer that projects
  sensor locations onto a 2D plane using Fourier-parameterized functions.
  This is followed by a subject-specific 1x1 convolution
  $\mathbf{M}_s \in \mathbb{R}^{D_1 \times D_1}$ if subject features are
  enabled (a hypothetical sketch follows). The resulting features are
  projected to the `hidden_dim` (default 320) to ensure compatibility with
  subsequent residual connections.
- **Role.** Converts high-dimensional, subject-dependent sensor data into
  a standardized latent space while preserving spatial and temporal
  relationships.

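As an illustration of the subject-specific 1x1 convolution
$\mathbf{M}_s$, here is a hypothetical reconstruction; the class name,
the identity initialization (mirroring the `subject_layers_id` option),
and the forward signature are assumptions, not the braindecode API.

```python
import torch
from torch import nn


class SubjectLayerSketch(nn.Module):
    """Hypothetical sketch: one D x D matrix per participant, applied as a
    1x1 convolution over the channel dimension, with no activation."""

    def __init__(self, n_subjects: int, dim: int):
        super().__init__()
        # Identity initialization, as with `subject_layers_id=True`.
        self.weights = nn.Parameter(
            torch.eye(dim).repeat(n_subjects, 1, 1))  # (S, D, D)

    def forward(self, x: torch.Tensor, subject_idx: torch.Tensor) -> torch.Tensor:
        # x: (B, D, T); subject_idx: (B,). Select M_s per sample and mix
        # channels: out[b, d, t] = sum_c M_s[d, c] * x[b, c, t].
        return torch.einsum("bdc,bct->bdt", self.weights[subject_idx], x)
```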
**`BrainModule.encoder` (convolutional sequence)**

- **Operations.** Implemented via
  `braindecode.models.brainmodule._ConvSequence`, this component consists
  of a stack of `k` convolutional blocks. Each block typically contains:
  (a) **residual dilated convolutions**, two layers with kernel size 3,
  residual skip connections, and dilation factors that grow exponentially
  (e.g., powers of two with periodic resets, made concrete in the snippet
  after this section) to capture multi-scale temporal context; and
  (b) **GLU gating**, where every `N` layers (defined by `glu`) a Gated
  Linear Unit is applied, halving the channel dimension and introducing
  non-linear gating to filter intermediate representations.
- **Role.** Extracts deep hierarchical temporal features from the brain
  signal, significantly expanding the model's receptive field to align
  with the contextual windows of speech modules like wav2vec 2.0.

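To make the dilation pattern concrete, here is one plausible reading of
the schedule implied by `dilation_growth` and `dilation_period` (the exact
indexing inside `_ConvSequence` may differ), together with the standard
receptive-field arithmetic for stacked stride-1 convolutions:

```python
# Assumed schedule: dilation resets to 1 every `dilation_period` layers.
dilation_growth, dilation_period, depth = 2, 5, 10  # the documented defaults
dilations = [dilation_growth ** (i % dilation_period) for i in range(depth)]
print(dilations)  # [1, 2, 4, 8, 16, 1, 2, 4, 8, 16]

# Receptive field of stacked stride-1 convs: rf = 1 + sum((k - 1) * d).
kernel_size = 3
rf = 1 + sum((kernel_size - 1) * d for d in dilations)
print(rf)  # 125 samples, i.e. 0.5 s at 250 Hz
```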
**Temporal, Spatial, and Spectral Encoding**

- **Temporal:** Increasing dilation factors across layers allow the model to
  integrate information over large time windows without the computational
  cost of standard large kernels, while a 150 ms input shift facilitates
  alignment between stimulus and brain response.
- **Spatial:** The spatial attention layer learns a softmax weighting over
  input sensors based on their 3D coordinates, allowing the model to focus
  on regions typically activated during auditory stimulation (e.g., the
  temporal cortex).
- **Spectral:** Through the optional `n_fft` parameter, the model can
  apply an STFT transformation, converting time-domain signals into a
  spectrogram representation before encoding (see the example below).

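For instance, enabling the spectral front-end should only require the
documented `n_fft` and `fft_complex` arguments; the other values below are
the example defaults from the quick start, and combining them this way is
an assumption rather than a tested recipe:

```python
from braindecode.models import BrainModule

# Encode a power spectrogram instead of the raw time series.
spec_model = BrainModule(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
    n_fft=64,           # STFT size; None would disable the spectral path
    fft_complex=False,  # power spectrogram instead of complex STFT
)
```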
**Additional Mechanisms**

- **Clamping and scaling:** The model relies on clamping input values
  (e.g., at 20 standard deviations) to prevent outliers and large
  electromagnetic artifacts from destabilizing the BatchNorm estimates and
  the optimization process.
- **Scaled subject embeddings:** When `subject_dim` is used, the
  `braindecode.models.brainmodule._ScaledEmbedding` layer scales up the
  effective learning rate for subject-specific features to prevent slow
  convergence in multi-participant training.
- **_ConvSequence and residual logic:** This class handles the actual
  stacking of layers. It is designed to be flexible with the `growth`
  parameter; if the channel size changes between layers (`growth != 1.0`),
  it automatically applies a 1x1 `skip_projection` convolution to the
  residual path so dimensions match for addition.
- **_ChannelDropout:** Unlike standard dropout, which zeroes individual
  neurons, this zeroes entire channels. It includes a rescale feature that
  multiplies the remaining channels by a factor
  `total_channels / active_channels` to maintain the expected value of the
  signal during training (see the first sketch below).
- **_ScaledEmbedding:** This is a clever optimization for multi-subject
  learning. By dividing the initial weights by a scale and then multiplying
  the output by the same scale, it effectively increases the gradient
  magnitude for the embedding weights, allowing subject-specific features
  to learn faster than the shared backbone (see the second sketch below).

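A minimal sketch of the channel-dropout idea described above, written as an
illustration only (`_ChannelDropout` itself additionally supports
type-restricted dropping via `channel_dropout_type`):

```python
import torch
from torch import nn


class ChannelDropoutSketch(nn.Module):
    """Illustrative: zero whole channels, then rescale survivors by
    total_channels / active_channels to keep the expected signal value."""

    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time)
        if not self.training or self.p == 0.0:
            return x
        keep = torch.rand(x.shape[:2], device=x.device) >= self.p
        mask = keep.to(x.dtype).unsqueeze(-1)               # (B, C, 1)
        active = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
        return x * mask * (x.shape[1] / active)             # rescale survivors
```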
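And a hypothetical reconstruction of the `_ScaledEmbedding` trick: the
forward output is unchanged at initialization, but the gradient with
respect to the stored weights is multiplied by `scale`, so the subject
embeddings learn faster under the same optimizer settings.

```python
import torch
from torch import nn


class ScaledEmbeddingSketch(nn.Module):
    """Illustrative: store weights divided by `scale`, multiply outputs by
    `scale`; gradients w.r.t. the stored weights are `scale` times larger."""

    def __init__(self, n_subjects: int, dim: int, scale: float = 10.0):
        super().__init__()
        self.embedding = nn.Embedding(n_subjects, dim)
        with torch.no_grad():
            self.embedding.weight.div_(scale)
        self.scale = scale

    def forward(self, subject_idx: torch.Tensor) -> torch.Tensor:
        return self.embedding(subject_idx) * self.scale
```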
**Parameters**

- `hidden_dim` (int, default=320): Hidden dimension for the convolutional
  layers. Input is projected to this dimension before the convolutional
  blocks.
- `depth` (int, default=10): Number of convolutional blocks. Each block
  contains a dilated convolution with batch normalization and activation,
  followed by a residual connection.
- `kernel_size` (int, default=3): Convolutional kernel size. Must be odd
  for proper padding with dilation.
- `growth` (float, default=1.0): Channel size multiplier:
  `hidden_dim * (growth ** layer_index)`. Values > 1.0 grow channels with
  depth; values < 1.0 shrink them. Note: `growth != 1.0` disables residual
  connections between layers with different channel sizes.
- `dilation_growth` (int, default=2): Dilation multiplier per layer (e.g.,
  2 means the dilation doubles each layer). Grows the receptive field
  exponentially. Requires an odd `kernel_size`.
- `dilation_period` (int, default=5): Reset the dilation to 1 every N
  layers. Prevents the dilation from growing too large and maintains
  local connectivity.
- `conv_drop_prob` (float, default=0.0): Dropout probability for
  convolutional layers.
- `dropout_input` (float, default=0.0): Dropout probability applied to the
  model input only.
- `batch_norm` (bool, default=True): If True, apply batch normalization
  after each convolution.
- `activation` (type[nn.Module], default=nn.GELU): Activation function
  class to use (e.g., nn.GELU, nn.ReLU, nn.ELU).
- `n_subjects` (int, default=200): Number of unique subjects (for
  subject-specific pathways). Only used if `subject_dim > 0`.
- `subject_dim` (int, default=0): Dimension of subject embeddings. If 0,
  no subject-specific features. If > 0, adds subject embeddings to the
  input before encoding.
- `subject_layers` (bool, default=False): If True, apply subject-specific
  linear transformations to the input channels. Each subject has its own
  weight matrix. Requires `subject_dim > 0`.
- `subject_layers_dim` (str, default="input"): Where to apply subject
  layers: "input" or "hidden".
- `subject_layers_id` (bool, default=False): If True, initialize subject
  layers as identity matrices.
- `embedding_scale` (float, default=1.0): Scaling factor for the subject
  embeddings' learning rate.
- `n_fft` (int, optional): FFT size for STFT processing. If None, no STFT
  is applied. If specified, applies a spectrogram transform before
  encoding.
- `fft_complex` (bool, default=True): If True, keep the complex
  spectrogram. If False, use the power spectrogram. Only used when
  `n_fft` is not None.
- `channel_dropout_prob` (float, default=0.0): Probability of dropping
  each channel during training (0.0 to 1.0). If 0.0, no channel dropout is
  applied.
- `channel_dropout_type` (str, optional): If specified with `chs_info`,
  only drop channels of this type (e.g., 'eeg', 'ref', 'eog'). If None
  with a dropout probability > 0, drops any channel.
- `glu` (int, default=2): If > 0, applies Gated Linear Units (GLU) every N
  convolutional layers. GLUs gate intermediate representations for more
  expressivity. If 0, no GLU is applied.
- `glu_context` (int, default=1): Context window size for the GLU gates.
  If > 0, uses contextual information from neighboring time steps for
  gating. Requires `glu > 0`.

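For example, a multi-subject configuration might combine the parameters
above as follows. The cohort size and embedding dimension are hypothetical
placeholders, and, per the *Notes* below, the forward pass then also needs
subject indices (see the API reference for the exact mechanism):

```python
from braindecode.models import BrainModule

# Shared backbone with per-participant input layers and embeddings
# for a hypothetical 30-subject cohort.
model = BrainModule(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
    n_subjects=30,
    subject_dim=16,
    subject_layers=True,
    subject_layers_id=True,  # start each subject layer as the identity
)
```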
**References**

- [brainmagick] Défossez, A., Caucheteux, C., Rapin, J., Kabeli, O., &
  King, J. R. (2023). Decoding speech perception from non-invasive brain
  recordings. Nature Machine Intelligence, 5(10), 1097-1107.

**Notes**

- Input shape: `(batch, n_chans, n_times)`.
- Output shape: `(batch, n_outputs)`.
- The model uses dilated convolutions with `stride=1` to maintain temporal
  resolution while achieving large receptive fields.
- Residual connections are applied at every layer where input and output
  channels match.
- Subject-specific features (`subject_dim > 0`, `subject_layers`) require
  passing subject indices in the forward pass, either as an optional
  parameter or via the batch.
- STFT processing (`n_fft > 0`) automatically transforms the input to the
  spectrogram domain.

*New in braindecode 1.2.*

**Hugging Face Hub integration**

When the optional `huggingface_hub` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with:

```bash
pip install braindecode[hub]
```

**Pushing a model to the Hub:**

```python
from braindecode.models import BrainModule

# Train your model
model = BrainModule(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-brainmodule-model",
    commit_message="Initial model upload",
)
```

**Loading a model from the Hub:**

```python
from braindecode.models import BrainModule

# Load a pretrained model
model = BrainModule.from_pretrained("username/my-brainmodule-model")

# Load with a different number of outputs (the head is rebuilt automatically)
model = BrainModule.from_pretrained("username/my-brainmodule-model", n_outputs=4)
```

**Extracting features and replacing the head:**

```python
import torch

x = torch.randn(1, model.n_chans, model.n_times)
# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)
```

**Saving and restoring the full configuration:**

```python
import json

config = model.get_config()  # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)

model2 = BrainModule.from_config(config)  # reconstruct (no weights)
```

All model parameters (both EEG-specific and model-specific, such as
dropout rates, activation functions, and number of filters) are
automatically saved to the Hub and restored when loading.

See the `load-pretrained-models` tutorial in the braindecode
documentation for a complete walkthrough.

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
Weights derived from pretraining, if you fine-tune from a checkpoint,
inherit the license of that checkpoint and its training corpus.