bruAristimunha committed on
Commit 90a6a4f · verified · 1 Parent(s): 3abf28b

Add architecture-only model card

Files changed (1)
  1. README.md +364 -0
README.md ADDED

---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---

# ATCNet

ATCNet from Altaheri et al. (2022).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.ATCNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import ATCNet

model = ATCNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.

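As a quick sanity check, the snippet below runs a dummy batch and a single training step through the freshly built model (a minimal sketch: braindecode models take input shaped `(batch, n_chans, n_times)`, so 4.0 s at 250 Hz corresponds to 1000 samples; labels and optimizer settings are placeholders).

```python
import torch
from torch import nn

from braindecode.models import ATCNet

model = ATCNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)

# Dummy batch: 8 trials, 22 channels, 4.0 s * 250 Hz = 1000 samples.
x = torch.randn(8, 22, 1000)
y = torch.randint(0, 4, (8,))

logits = model(x)                      # shape: (8, 4)
loss = nn.functional.cross_entropy(logits, y)

# One optimization step, as in any standard PyTorch training loop.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss.backward()
optimizer.step()
optimizer.zero_grad()

print(logits.shape, float(loss))
```
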
## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.ATCNet.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/atcnet.py#L15>

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>ATCNet from Altaheri et al. (2022) [1]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#6c757d;color:white;font-size:11px;font-weight:600;margin-right:4px;">Recurrent</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>

.. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
   :align: center
   :alt: ATCNet Architecture
   :width: 650px

.. rubric:: Architectural Overview

ATCNet is a *convolution-first* architecture augmented with a *lightweight attention–TCN*
sequence module. The end-to-end flow is:

- (i) :class:`_ConvBlock` learns temporal filter-banks and spatial projections (EEGNet-style),
  downsampling time to a compact feature map;
- (ii) Sliding Windows carve overlapping temporal windows from this map;
- (iii) for each window, :class:`_AttentionBlock` applies small multi-head self-attention
  over time, followed by a :class:`_TCNResidualBlock` stack (causal, dilated);
- (iv) window-level features are aggregated (mean of window logits or concatenation)
  and mapped via a max-norm–constrained linear layer.

Relative to ViT, ATCNet replaces linear patch projection with learned *temporal–spatial*
convolutions; it processes *parallel* window encoders (attention → TCN) instead of a deep
stack; and swaps the MLP head for a TCN suited to 1-D EEG sequences.

.. rubric:: Macro Components

- :class:`_ConvBlock` **(Shallow conv stem → feature map)**

  *Operations.*

  - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
    FIR-like filter bank (``F1`` maps).
  - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
    ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
  - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
  - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
    **BN → ELU → AvgPool → Dropout**.

  The output shape is ``(B, F2, T_c, 1)`` with ``F2 = F1·D`` and ``T_c = T/(P1·P2)``.
  Temporal kernels behave as FIR filters; the depthwise-spatial conv yields frequency-specific
  topographies. Pooling acts as a local integrator, reducing variance and imposing a
  useful inductive bias on short EEG windows.

- **Sliding-Window Sequencer**

  From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
  of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
  ``(B, F2, T_w)`` forwarded to its own attention–TCN branch. This creates *parallel*
  encoders over shifted contexts and is key to robustness on nonstationary EEG
  (a stand-alone sketch of the per-window branch follows this list).

- :class:`_AttentionBlock` **(small MHA on temporal positions)**

  Attention here is *local to a window* and purely temporal.

  *Operations.*

  - Rearrange to ``(B, T_w, F2)``,
  - normalization with :class:`torch.nn.LayerNorm`,
  - custom multi-head attention :class:`_MHA` (``num_heads=H``, per-head dim ``d_h``) + residual add,
  - dropout with :class:`torch.nn.Dropout`,
  - rearrange back to ``(B, F2, T_w)``.

  *Role.* Re-weights evidence across the window, letting the model emphasize informative
  segments (onsets, bursts) before causal convolutions aggregate history.

- :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**

  *Operations.*

  - Two :class:`braindecode.modules.CausalConv1d` layers per block with dilation ``1, 2, 4, …``,
  - :class:`torch.nn.BatchNorm1d` + :class:`torch.nn.ELU` + :class:`torch.nn.Dropout` after each
    convolution, plus a residual connection (identity or 1x1 mapping).
  - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).

  *Role.* Efficient long-range temporal integration with stable gradients; the dilated
  receptive field complements attention's soft selection.

- **Aggregation & Classifier**

  *Operations.*

  - Either (a) map each window feature ``(B, F2)`` to logits via :class:`braindecode.modules.MaxNormLinear`
    and **average** across windows (default, matching the official code), or
  - (b) **concatenate** all window features ``(B, n·F2)`` and apply a single :class:`MaxNormLinear`.

  The max-norm constraint regularizes the readout.

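The per-window branch described in this list (window slicing → attention over time → causal TCN → last-step feature → readout) can be sketched end to end with stock PyTorch layers. This is an illustrative simplification only: :class:`torch.nn.MultiheadAttention`, a single left-padded :class:`torch.nn.Conv1d` and :class:`torch.nn.Linear` stand in for ``_MHA``, the dilated ``CausalConv1d`` stack and ``MaxNormLinear``, and parameters are shared across windows purely for brevity.

.. code:: python

    import torch
    import torch.nn.functional as F
    from torch import nn

    B, F2, T_c, n, n_outputs = 8, 32, 17, 5, 4
    feature_map = torch.randn(B, F2, T_c)          # conv-stem output, squeezed to (B, F2, T_c)

    T_w = T_c - n + 1                              # shared window width
    norm = nn.LayerNorm(F2)
    mha = nn.MultiheadAttention(F2, num_heads=2, batch_first=True)
    tcn = nn.Conv1d(F2, F2, kernel_size=4)         # stand-in for the dilated causal stack
    head = nn.Linear(F2, n_outputs)                # stand-in for MaxNormLinear

    window_logits = []
    for i in range(n):                             # one branch per shifted window
        w = feature_map[:, :, i:i + T_w]           # (B, F2, T_w)
        s = w.permute(0, 2, 1)                     # attention over temporal positions
        s = s + mha(norm(s), norm(s), norm(s))[0]  # residual add
        s = s.permute(0, 2, 1)                     # back to (B, F2, T_w)
        s = F.elu(tcn(F.pad(s, (3, 0))))           # causal: left-pad only, no future leakage
        window_logits.append(head(s[..., -1]))     # last causal step -> (B, n_outputs)

    logits = torch.stack(window_logits).mean(dim=0)  # concat=False: average window logits
    print(logits.shape)                              # torch.Size([8, 4])
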
.. rubric:: Convolutional Details

- **Temporal.** Temporal structure is learned in three places:

  - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
  - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
  - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
    (long-range dependencies).

  The minimum sequence length required by the TCN stack is
  ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
  when inputs are shorter to preserve feasibility.

- **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
  producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
  This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.

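To make the temporal/spatial split above concrete, here is a rough stand-alone re-implementation of the conv stem in plain PyTorch. Layer ordering follows the operation list earlier in this docstring; the ``padding`` choices and the absence of the auto-scaling logic are simplifying assumptions, so this is not braindecode's exact code.

.. code:: python

    import torch
    from torch import nn

    B, n_chans, T = 8, 22, 1000
    F1, D, L_t, L_r, P1, P2, p_drop = 16, 2, 64, 16, 8, 7, 0.3
    F2 = F1 * D

    stem = nn.Sequential(
        # temporal filter bank along the time axis
        nn.Conv2d(1, F1, (L_t, 1), padding=(L_t // 2, 0), bias=False),
        # depthwise spatial conv across the full montage
        nn.Conv2d(F1, F2, (1, n_chans), groups=F1, bias=False),
        nn.BatchNorm2d(F2),
        nn.ELU(),
        nn.AvgPool2d((P1, 1)),
        nn.Dropout(p_drop),
        # refining temporal conv on the pooled map
        nn.Conv2d(F2, F2, (L_r, 1), padding=(L_r // 2, 0), bias=False),
        nn.BatchNorm2d(F2),
        nn.ELU(),
        nn.AvgPool2d((P2, 1)),
        nn.Dropout(p_drop),
    )

    x = torch.randn(B, 1, T, n_chans)   # (batch, 1, time, channels)
    print(stem(x).shape)                # roughly (B, F2, T/(P1*P2), 1)
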
.. rubric:: Attention / Sequential Modules

- **Type.** Multi-head self-attention with ``H`` heads and per-head dim ``d_h`` implemented
  in :class:`_MHA`, allowing ``embed_dim = H·d_h`` independent of input and output dims.
- **Shapes.** ``(B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w)``. Attention operates along
  the **temporal** axis within a window; channels/features stay in the embedding dim ``F2``.
- **Role.** Highlights salient temporal positions prior to causal convolution; small attention
  keeps compute modest while improving context modeling over pooled features.

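The point about ``embed_dim = H·d_h`` being decoupled from the feature width ``F2`` can be illustrated with explicit query/key/value projections (a deliberately stripped-down sketch, not the actual ``_MHA`` code):

.. code:: python

    import torch
    from torch import nn

    B, F2, T_w = 8, 32, 13
    H, d_h = 2, 8                      # internal width H*d_h = 16, independent of F2 = 32
    x = torch.randn(B, T_w, F2)        # window already rearranged to (B, T_w, F2)

    to_qkv = nn.Linear(F2, 3 * H * d_h)
    to_out = nn.Linear(H * d_h, F2)

    q, k, v = to_qkv(x).chunk(3, dim=-1)
    # split heads: (B, H, T_w, d_h)
    q, k, v = (t.reshape(B, T_w, H, d_h).transpose(1, 2) for t in (q, k, v))
    attn = torch.softmax(q @ k.transpose(-2, -1) / d_h ** 0.5, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(B, T_w, H * d_h)
    y = x + to_out(out)                # residual add back at width F2
    print(y.shape)                     # torch.Size([8, 13, 32])
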
.. rubric:: Additional Mechanisms

- **Parallel encoders over shifted windows.** Improves montage/phase robustness by
  ensembling nearby contexts rather than committing to a single segmentation.
- **Max-norm classifier.** Enforces weight norm constraints at the readout, a common
  stabilization trick in EEG decoding.
- **ViT vs. ATCNet (design choices).** Convolutional *nonlinear* projection rather than
  linear patchification; attention followed by **TCN** (not MLP); *parallel* window
  encoders rather than stacked encoders.

.. rubric:: Usage and Configuration

- ``conv_block_n_filters`` (F1), ``conv_block_depth_mult`` (D) → capacity of the stem
  (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
- Pool sizes ``P1, P2`` trade temporal resolution for stability/compute; they set
  ``T_c = T/(P1·P2)`` and thus the window width ``T_w``.
- ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
- ``num_heads``, ``head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
- ``tcn_depth``, ``tcn_kernel_size`` govern the receptive field; larger values demand
  longer inputs (see the minimum length above). The implementation warns and *rescales*
  kernels/pools/windows if inputs are too short.
- **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
  the official code; ``concat=True`` mirrors the paper's concatenation variant.

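A small worked example of this shape bookkeeping, using the default values from the parameter list below (4.5 s at 250 Hz, ``P1=8``, ``P2=7``, ``n_windows=5``, ``F1=16``, ``D=2``, ``Kt=4``, ``L=2``):

.. code:: python

    sfreq, input_window_seconds = 250, 4.5
    P1, P2, n_windows = 8, 7, 5
    F1, D = 16, 2
    K_t, L = 4, 2

    T = int(sfreq * input_window_seconds)     # 1125 samples per trial
    T_c = T // (P1 * P2)                      # ~20 condensed time steps
    T_w = T_c - n_windows + 1                 # ~16 steps per window
    F2 = F1 * D                               # 32 feature maps into attention/TCN
    min_len = (K_t - 1) * 2 ** (L - 1) + 1    # 7: minimum window length for the TCN stack
    print(T, T_c, T_w, F2, min_len, T_w >= min_len)
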
Parameters
----------
input_window_seconds : float, optional
    Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a
    dataset.
sfreq : int, optional
    Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the
    BCI-IV 2a dataset.
conv_block_n_filters : int
    Number of temporal filters in the first convolutional layer of the
    convolutional block, denoted F1 in figure 2 of the paper [1]_. Defaults
    to 16 as in [1]_.
conv_block_kernel_length_1 : int
    Length of temporal filters in the first convolutional layer of the
    convolutional block, denoted Kc in table 1 of the paper [1]_. Defaults
    to 64 as in [1]_.
conv_block_kernel_length_2 : int
    Length of temporal filters in the last convolutional layer of the
    convolutional block. Defaults to 16 as in [1]_.
conv_block_pool_size_1 : int
    Length of the first average pooling kernel in the convolutional block.
    Defaults to 8 as in [1]_.
conv_block_pool_size_2 : int
    Length of the second average pooling kernel in the convolutional block,
    denoted P2 in table 1 of the paper [1]_. Defaults to 7 as in [1]_.
conv_block_depth_mult : int
    Depth multiplier of the depthwise convolution in the convolutional block,
    denoted D in table 1 of the paper [1]_. Defaults to 2 as in [1]_.
conv_block_dropout : float
    Dropout probability used in the convolutional block, denoted pc in
    table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
n_windows : int
    Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
head_dim : int
    Embedding dimension used in each self-attention head, denoted dh in
    table 1 of the paper [1]_. Defaults to 8 as in [1]_.
num_heads : int
    Number of attention heads, denoted H in table 1 of the paper [1]_.
    Defaults to 2 as in [1]_.
att_dropout : float
    Dropout probability used in the attention block, denoted pa in table 1
    of the paper [1]_. Defaults to 0.5 as in [1]_.
tcn_depth : int
    Depth of the Temporal Convolutional Network block (i.e. number of TCN
    residual blocks), denoted L in table 1 of the paper [1]_. Defaults to 2
    as in [1]_.
tcn_kernel_size : int
    Temporal kernel size used in the TCN block, denoted Kt in table 1 of the
    paper [1]_. Defaults to 4 as in [1]_.
tcn_dropout : float
    Dropout probability used in the TCN block, denoted pt in table 1
    of the paper [1]_. Defaults to 0.3 as in [1]_.
tcn_activation : torch.nn.Module
    Nonlinear activation to use. Defaults to nn.ELU().
concat : bool
    When ``True``, concatenates each sliding-window embedding before
    feeding it to a fully-connected layer, as done in [1]_. When ``False``,
    maps each sliding window to ``n_outputs`` logits and averages them.
    Defaults to ``False``, contrary to what is reported in [1]_, but
    matching what the official code does [2]_.
max_norm_const : float
    Maximum L2-norm constraint imposed on the weights of the last
    fully-connected layer. Defaults to 0.25.

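For instance, the parameters above map onto the constructor like this (the values are simply the documented defaults written out explicitly, shown only to illustrate the naming):

.. code:: python

    from braindecode.models import ATCNet

    model = ATCNet(
        n_chans=22,
        n_outputs=4,
        input_window_seconds=4.5,
        sfreq=250,
        conv_block_n_filters=16,    # F1
        conv_block_depth_mult=2,    # D  -> F2 = 32
        conv_block_pool_size_1=8,   # P1
        conv_block_pool_size_2=7,   # P2
        n_windows=5,                # n
        num_heads=2,                # H
        head_dim=8,                 # dh
        tcn_depth=2,                # L
        tcn_kernel_size=4,          # Kt
        concat=False,               # average per-window logits, as in the official code
    )
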
Notes
-----
- Inputs substantially shorter than the implied minimum length trigger **automatic
  downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
- The attention–TCN sequence operates **per window**; the last causal step is used as the
  window feature, aligning the temporal semantics across windows.

.. versionadded:: 1.1

   - More detailed documentation of the model.

References
----------
.. [1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022).
   *Physics-informed attention temporal convolutional network for EEG-based motor imagery classification.*
   IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
.. [2] Official EEG-ATCNet implementation (TensorFlow):
   https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py

.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code:: python

    from braindecode.models import ATCNet

    # Train your model
    model = ATCNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-atcnet-model",
        commit_message="Initial model upload",
    )

**Loading a model from the Hub:**

.. code:: python

    from braindecode.models import ATCNet

    # Load pretrained model
    model = ATCNet.from_pretrained("username/my-atcnet-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = ATCNet.from_pretrained("username/my-atcnet-model", n_outputs=4)

**Extracting features and replacing the head:**

.. code:: python

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)
    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code:: python

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = ATCNet.from_config(config)  # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as
dropout rates, activation functions, number of filters) are automatically
saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
Pretraining-derived weights, if you fine-tune from a checkpoint,
inherit the license of that checkpoint and its training corpus.