bruAristimunha committed
Commit 1daad10 · verified · 1 Parent(s): 6318a15

Replace with clean markdown card

Files changed (1):
  1. README.md +35 -296
README.md CHANGED
@@ -14,13 +14,12 @@ tags:
 
  # AttentionBaseNet
 
- AttentionBaseNet from Wimpff M et al (2023) .
 
- > **Architecture-only repository.** This repo documents the
  > `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
- > distributed here** instantiate the model and train it on your own
- > data, or fine-tune from a published foundation-model checkpoint
- > separately.
 
  ## Quick start
 
@@ -39,314 +38,54 @@ model = AttentionBaseNet(
  )
  ```
 
- The signal-shape arguments above are example defaults — adjust them
- to match your recording.
 
  ## Documentation
-
- - Full API reference (parameters, references, architecture figure):
-   <https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
- - Interactive browser with live instantiation:
    <https://huggingface.co/spaces/braindecode/model-explorer>
  - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>
 
- ## Architecture description
-
- The block below is the rendered class docstring (parameters,
- references, architecture figure where available).
-
- AttentionBaseNet from Wimpff M et al (2023) [Martin2023]_.
-
- *Tags:* Convolution · Attention/Transformer
-
- .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
-    :align: center
-    :alt: AttentionBaseNet Architecture
-    :width: 640px
-
- .. rubric:: Architectural Overview
-
- AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
- The end-to-end flow is:
-
- - (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial
-   projections (depthwise across electrodes), then condenses time by pooling;
- - (ii) **Channel Expansion** uses a ``1x1`` convolution to set the feature width;
- - (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal
-   convs and an optional channel-attention module (SE/CBAM/ECA/…);
- - (iv) **Classifier** flattens the sequence and applies a linear readout.
-
- This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable
- attention unit that *re-weights channels* (and optionally temporal positions) before
- classification.
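-
- A minimal end-to-end shape check of this flow (an illustrative sketch, not
- part of the original docstring; batch and window sizes are arbitrary):
-
- .. code:: python
-
-     import torch
-     from braindecode.models import AttentionBaseNet
-
-     model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
-     x = torch.randn(8, 22, 1000)  # (batch, electrodes, time samples)
-     with torch.no_grad():
-         logits = model(x)
-     print(logits.shape)  # torch.Size([8, 4]): one score per class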
-
- .. rubric:: Macro Components
-
- - :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**
-
-   *Operations.*
-
-   - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)`` creates a learned
-     FIR-like filter bank with ``n_temporal_filters`` maps.
-   - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=n_temporal_filters``)
-     with kernel ``(n_chans, 1)`` learns per-filter spatial projections over the full montage.
-   - **BatchNorm → ELU → AvgPool → Dropout** stabilize and downsample time.
-   - Output shape: ``(B, F2, 1, T₁)`` with ``F2 = n_temporal_filters x spatial_expansion``.
-
-   *Interpretability/robustness.* Temporal kernels behave as analyzable FIR filters; the
-   depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local
-   integrator that reduces variance on short EEG windows.
-
- - **Channel Expansion**
-
-   *Operations.* A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without changing
-   the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``). This sets the embedding
-   width for the attention block.
-
- - :class:`_ChannelAttentionBlock` **(temporal refinement + channel attention)**
-
-   *Operations.*
-
-   - **Depthwise temporal conv** ``(1, L_a)`` (``groups=ch_dim``) + **pointwise ``1x1``**,
-     BN and activation → preserves shape ``(B, ch_dim, 1, T₁)`` while refining timing.
-   - **Optional attention module** (see *Additional Mechanisms*) applies channel reweighting
-     (some variants also apply temporal gating).
-   - **AvgPool** ``(1, P₂)`` with stride ``(1, S₂)`` and **Dropout** → outputs
-     ``(B, ch_dim, 1, T₂)``.
-
-   *Role.* Emphasizes informative channels (and, in certain modes, salient time steps)
-   before the classifier; complements the convolutional priors with adaptive re-weighting.
-
- - **Classifier (aggregation + readout)**
-
-   *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
-   ``(B, ch_dim·T₂)`` to classes.
125
-
126
- .. rubric:: Convolutional Details
127
-
128
- - **Temporal (where time-domain patterns are learned).**
129
- Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
130
- bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
131
- short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
132
- set the token rate and effective temporal resolution.
133
-
134
- - **Spatial (how electrodes are processed).**
135
- A depthwise spatial conv with kernel ``(n_chans, 1)`` spans the full montage to
136
- learn *per-temporal-filter* spatial projections (no cross-filter mixing at this step),
137
- mirroring the interpretable spatial stage in shallow CNNs.
138
-
139
- - **Spectral (how frequency content is captured).**
140
- No explicit Fourier/wavelet transform is used in the stem—spectral selectivity
141
- emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency
142
- channel attention (DCT-based) summarizes frequencies to drive channel weights.
143
-
144
- .. rubric:: Attention / Sequential Modules
145
-
146
- - **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP,
147
- EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally
148
- include temporal attention.
149
-
150
- - **Shapes.** Input/Output around attention: ``(B, ch_dim, 1, T₁)``. Re-arrangements
151
- (if any) are internal to the module; the block returns the same shape before pooling.
152
-
153
- - **Role.** Re-weights channels (and optionally time) to highlight informative sources
154
- and suppress distractors, improving SNR ahead of the linear head.
-
- .. rubric:: Additional Mechanisms
-
- **Attention variants at a glance:**
-
- - ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- - ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- - ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
- - ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- - ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
- - ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
- - ``"gct"``: Gated Channel Transformation (global context normalization + gating).
- - ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- - ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- - ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
-
- **Auto-compatibility on short inputs:**
-
- If the input duration is too short for the configured kernels/pools, the implementation
- **automatically rescales** temporal lengths/strides downward (with a warning) to keep
- shapes valid and preserve the pipeline semantics.
-
- .. rubric:: Usage and Configuration
-
- - ``n_temporal_filters``, ``temp_filter_length`` and ``spatial_expansion``:
-   control the capacity and the number of spatial projections in the stem.
- - ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``, ``pool_stride``:
-   trade temporal resolution for compute; they determine the final sequence length ``T₂``.
- - ``ch_dim``: width after the ``1x1`` expansion and the effective embedding size for attention.
- - ``attention_mode`` + its specific hyperparameters (``reduction_rate``,
-   ``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``):
-   select and tune the reweighting mechanism.
- - ``drop_prob_inp`` and ``drop_prob_attn``: regularize the stem and attention stages.
- - **Training tips.** Start with moderate pooling (e.g., ``P₁=75, S₁=15``) and ELU
-   activations; enable attention only after the stem learns stable filters. For small
-   datasets, prefer simpler modes (``"se"``, ``"eca"``) before heavier ones
-   (``"gsop"``, ``"encnet"``); a starting configuration is sketched below.
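-
- A hedged starting configuration following these tips (the documented
- defaults plus ``"se"`` attention; values are not tuned for any dataset):
-
- .. code:: python
-
-     from torch import nn
-     from braindecode.models import AttentionBaseNet
-
-     model = AttentionBaseNet(
-         n_chans=22, n_outputs=4, n_times=1000,
-         pool_length_inp=75, pool_stride_inp=15,  # moderate stem pooling
-         activation=nn.ELU,                       # default activation
-         attention_mode="se",                     # simple channel attention first
-         reduction_rate=4,                        # SE bottleneck ratio
-     )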
-
- Parameters
- ----------
- n_temporal_filters : int, optional
-     Number of temporal convolutional filters in the first layer. This defines
-     the number of output channels after the temporal convolution.
-     Default is 40.
- temp_filter_length : int, default=15
-     The length of the temporal filters in the convolutional layers.
- spatial_expansion : int, optional
-     Multiplicative factor to expand the spatial dimensions. Used to increase
-     the capacity of the model by expanding spatial features. Default is 1.
- pool_length_inp : int, optional
-     Length of the pooling window in the input layer. Determines how much
-     temporal information is aggregated during pooling. Default is 75.
- pool_stride_inp : int, optional
-     Stride of the pooling operation in the input layer. Controls the
-     downsampling factor in the temporal dimension. Default is 15.
- drop_prob_inp : float, optional
-     Dropout rate applied after the input layer. This is the probability of
-     zeroing out elements during training to prevent overfitting.
-     Default is 0.5.
- ch_dim : int, optional
-     Number of channels in the subsequent convolutional layers. This controls
-     the depth of the network after the initial layer. Default is 16.
- attention_mode : str, optional
-     The type of attention mechanism to apply. If ``None``, no attention is applied.
-
-     - "se" for Squeeze-and-Excitation network
-     - "gsop" for Global Second-Order Pooling
-     - "fca" for Frequency Channel Attention Network
-     - "encnet" for the context encoding module
-     - "eca" for Efficient channel attention for deep convolutional neural networks
-     - "ge" for Gather-Excite
-     - "gct" for Gated Channel Transformation
-     - "srm" for Style-based Recalibration Module
-     - "cbam" for Convolutional Block Attention Module
-     - "cat" for Learning to collaborate channel and temporal attention
-       from multi-information fusion
-     - "catlite" for the lite version of "cat", without temporal attention
-
- pool_length : int, default=8
-     The length of the window for the average pooling operation.
- pool_stride : int, default=8
-     The stride of the average pooling operation.
- drop_prob_attn : float, default=0.5
-     The dropout rate for regularization of the attention layer. Values should be
-     between 0 and 1.
- reduction_rate : int, default=4
-     The reduction rate used in the attention mechanism to reduce dimensionality
-     and computational complexity.
- use_mlp : bool, default=False
-     Whether an MLP (multi-layer perceptron) should be used within the attention
-     mechanism for further processing.
- freq_idx : int, default=0
-     DCT index used in the "fca" attention mechanism.
- n_codewords : int, default=4
-     The number of codewords (clusters) used in attention mechanisms that employ
-     quantization or clustering strategies.
- kernel_size : int, default=9
-     The kernel size used in certain types of attention mechanisms for convolution
-     operations.
- activation : type[nn.Module], default=nn.ELU
-     Activation function class to apply. Should be a PyTorch activation
-     module class like ``nn.ReLU`` or ``nn.ELU``.
- extra_params : bool, default=False
-     Whether additional, custom parameters should be passed to the attention
-     mechanism.
-
- Notes
- -----
- - Sequence length after each stage is computed internally; the final classifier expects
-   a flattened ``ch_dim x T₂`` vector.
- - Attention operates on the *channel* dimension by design; temporal gating exists only in
-   specific variants (CBAM/CAT).
- - The paper and original code, with more details about the methodological
-   choices, are available in [Martin2023]_ and [MartinCode]_.
-
- .. versionadded:: 0.9
-
- References
- ----------
- .. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
-    EEG motor imagery decoding: A framework for comparative analysis with
-    channel attention mechanisms. arXiv preprint arXiv:2310.11198.
- .. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
-    GitHub: https://github.com/martinwimpff/channel-attention (accessed 2024-03-28).
-
- .. rubric:: Hugging Face Hub integration
-
- When the optional ``huggingface_hub`` package is installed, all models
- automatically gain the ability to be pushed to and loaded from the
- Hugging Face Hub. Install with::
-
-     pip install braindecode[hub]
-
- **Pushing a model to the Hub:**
-
- .. code:: python
-
-     from braindecode.models import AttentionBaseNet
-
-     # Train your model
-     model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
-     # ... training code ...
-
-     # Push to the Hub
-     model.push_to_hub(
-         repo_id="username/my-attentionbasenet-model",
-         commit_message="Initial model upload",
-     )
-
- **Loading a model from the Hub:**
-
- .. code:: python
-
-     from braindecode.models import AttentionBaseNet
-
-     # Load pretrained model
-     model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")
-
-     # Load with a different number of outputs (head is rebuilt automatically)
-     model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model", n_outputs=4)
-
- **Extracting features and replacing the head:**
-
- .. code:: python
-
-     import torch
-
-     # Uses the `model` loaded above
-     x = torch.randn(1, model.n_chans, model.n_times)
-     # Extract encoder features (consistent dict across all models)
-     out = model(x, return_features=True)
-     features = out["features"]
-
-     # Replace the classification head
-     model.reset_head(n_outputs=10)
-
- **Saving and restoring full configuration:**
-
- .. code:: python
-
-     import json
-
-     config = model.get_config()  # all __init__ params
-     with open("config.json", "w") as f:
-         json.dump(config, f)
-
-     model2 = AttentionBaseNet.from_config(config)  # reconstruct (no weights)
-
- All model parameters (both EEG-specific and model-specific, such as
- dropout rates, activation functions, number of filters) are automatically
- saved to the Hub and restored when loading.
-
- See :ref:`load-pretrained-models` for a complete tutorial.
 
  ## Citation
 
- Please cite both the original paper for this architecture (see the
- *References* section above) and braindecode:
 
  ```bibtex
  @article{aristimunha2025braindecode,
 
  # AttentionBaseNet
 
+ AttentionBaseNet from Wimpff M et al. (2023) [Martin2023].
 
+ > **Architecture-only repository.** Documents the
  > `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
+ > distributed here.** Instantiate the model and train it on your own
+ > data.
 
  ## Quick start
 
  )
  ```
 
+ The signal-shape arguments above are illustrative defaults — adjust to
+ match your recording.
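+ 
+ For example, a hypothetical 64-channel recording cut into 2-second
+ windows at 250 Hz (values are illustrative, not from the original card):
+ 
+ ```python
+ from braindecode.models import AttentionBaseNet
+ 
+ model = AttentionBaseNet(
+     n_chans=64,   # electrodes in your montage
+     n_outputs=2,  # classes in your task
+     n_times=500,  # samples per window: 2 s at 250 Hz
+ )
+ ```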
 
  ## Documentation
+ - Full API reference: <https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
+ - Interactive browser (live instantiation, parameter counts):
    <https://huggingface.co/spaces/braindecode/model-explorer>
  - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>
 
+ ## Architecture
 
+ ![AttentionBaseNet architecture](https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg)
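+ 
+ To map the figure onto code, printing the instantiated module lists the
+ stages in order (a quick inspection sketch; exact layer names depend on
+ your braindecode version):
+ 
+ ```python
+ from braindecode.models import AttentionBaseNet
+ 
+ model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
+ print(model)  # conv stem, channel expansion, attention block, classifier head
+ ```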
 
 
 
+ ## Parameters
 
+ | Parameter | Type | Description |
+ |---|---|---|
+ | `n_temporal_filters` | int, optional | Number of temporal convolutional filters in the first layer; defines the number of output channels after the temporal convolution. Default is 40. |
+ | `temp_filter_length` | int, default=15 | Length of the temporal filters in the convolutional layers. |
+ | `spatial_expansion` | int, optional | Multiplicative factor to expand the spatial dimensions; increases model capacity by expanding spatial features. Default is 1. |
+ | `pool_length_inp` | int, optional | Length of the pooling window in the input layer; determines how much temporal information is aggregated during pooling. Default is 75. |
+ | `pool_stride_inp` | int, optional | Stride of the pooling operation in the input layer; controls the temporal downsampling factor. Default is 15. |
+ | `drop_prob_inp` | float, optional | Dropout rate applied after the input layer; the probability of zeroing out elements during training to prevent overfitting. Default is 0.5. |
+ | `ch_dim` | int, optional | Number of channels in the subsequent convolutional layers; controls the depth of the network after the initial layer. Default is 16. |
+ | `attention_mode` | str, optional | The attention mechanism to apply; if `None`, no attention is applied. Options: `"se"` (Squeeze-and-Excitation), `"gsop"` (Global Second-Order Pooling), `"fca"` (Frequency Channel Attention), `"encnet"` (context encoding module), `"eca"` (Efficient Channel Attention), `"ge"` (Gather-Excite), `"gct"` (Gated Channel Transformation), `"srm"` (Style-based Recalibration Module), `"cbam"` (Convolutional Block Attention Module), `"cat"` (collaborative channel and temporal attention from multi-information fusion), `"catlite"` (`"cat"` without temporal attention). |
+ | `pool_length` | int, default=8 | The length of the window for the average pooling operation. |
+ | `pool_stride` | int, default=8 | The stride of the average pooling operation. |
+ | `drop_prob_attn` | float, default=0.5 | Dropout rate for regularization of the attention layer. Values should be between 0 and 1. |
+ | `reduction_rate` | int, default=4 | Reduction rate used in the attention mechanism to reduce dimensionality and computational complexity. |
+ | `use_mlp` | bool, default=False | Whether an MLP (multi-layer perceptron) should be used within the attention mechanism for further processing. |
+ | `freq_idx` | int, default=0 | DCT index used in the `"fca"` attention mechanism. |
+ | `n_codewords` | int, default=4 | Number of codewords (clusters) used in attention mechanisms that employ quantization or clustering strategies. |
+ | `kernel_size` | int, default=9 | Kernel size used in certain types of attention mechanisms for convolution operations. |
+ | `activation` | type[nn.Module], default=nn.ELU | Activation function class to apply. Should be a PyTorch activation module class like `nn.ReLU` or `nn.ELU`. |
+ | `extra_params` | bool, default=False | Whether additional, custom parameters should be passed to the attention mechanism. |
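+ 
+ A hedged example wiring two of these parameters together, using the ECA
+ variant with its documented default kernel size (values are illustrative):
+ 
+ ```python
+ from braindecode.models import AttentionBaseNet
+ 
+ model = AttentionBaseNet(
+     n_chans=22, n_outputs=4, n_times=1000,
+     attention_mode="eca",  # efficient channel attention
+     kernel_size=9,         # width of the 1-D conv over the channel descriptor
+ )
+ ```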
 
+ ## References
 
+ 1. Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023. EEG motor imagery decoding: A framework for comparative analysis with channel attention mechanisms. arXiv preprint arXiv:2310.11198.
+ 2. Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B. Official code on GitHub: https://github.com/martinwimpff/channel-attention (accessed 2024-03-28).
 
 
  ## Citation
 
+ Cite the original architecture paper (see *References* above) and braindecode:
 
  ```bibtex
  @article{aristimunha2025braindecode,