bruAristimunha committed · Commit 32f9a50 · verified · 1 Parent(s): fa132bf

Add architecture-only model card

Files changed (1)
  1. README.md +343 -0
README.md ADDED
@@ -0,0 +1,343 @@
+ ---
+ license: bsd-3-clause
+ library_name: braindecode
+ pipeline_tag: feature-extraction
+ tags:
+ - eeg
+ - biosignal
+ - pytorch
+ - neuroscience
+ - braindecode
+ - convolutional
+ - transformer
+ ---
+
+ # SSTDPN
+
+ SSTDPN from Can Han et al. (2025).
+
+ > **Architecture-only repository.** This repo documents the
+ > `braindecode.models.SSTDPN` class. **No pretrained weights are
+ > distributed here**; instantiate the model and train it on your own
+ > data, or fine-tune from a published foundation-model checkpoint
+ > separately.
+
+ ## Quick start
+
+ ```bash
+ pip install braindecode
+ ```
+
+ ```python
+ from braindecode.models import SSTDPN
+
+ model = SSTDPN(
+     n_chans=22,
+     sfreq=250,
+     input_window_seconds=4.0,
+     n_outputs=4,
+ )
+ ```
+
+ The signal-shape arguments above are example defaults; adjust them
+ to match your recording.
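
As a rough illustration of how the signal-shape arguments relate, the sketch below derives the number of samples per trial from the sampling frequency and the window length. It assumes the usual relation `n_times = sfreq * input_window_seconds`; the variable values are the example defaults from the snippet above, not properties of any particular dataset.

```python
# Example recording parameters (taken from the quick-start snippet above).
sfreq = 250                   # sampling frequency in Hz
input_window_seconds = 4.0    # trial length in seconds

# Assumed relation: samples per trial = sampling rate * window length.
n_times = int(round(sfreq * input_window_seconds))
print(n_times)
```

With these values the window spans 1000 samples, which matches the `n_times=1000` used in the Hub example further down.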
+
+ ## Documentation
+
+ - Full API reference (parameters, references, architecture figure):
+   <https://braindecode.org/stable/generated/braindecode.models.SSTDPN.html>
+ - Interactive browser with live instantiation:
+   <https://huggingface.co/spaces/braindecode/model-explorer>
+ - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/sstdpn.py#L17>
+
+ ## Architecture description
+
+ The block below is the rendered class docstring (parameters,
+ references, architecture figure where available).
+
+ <div class='bd-doc'><main>
+ <p>SSTDPN from Can Han et al. (2025) [Han2025]_.</p>
+ <span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span>
+
+ .. figure:: https://raw.githubusercontent.com/hancan16/SST-DPN/refs/heads/main/figs/framework.png
+    :align: center
+    :alt: SSTDPN Architecture
+    :width: 1000px
+
+ The **Spatial-Spectral and Temporal Dual Prototype Network** (SST-DPN)
+ is an end-to-end 1D convolutional architecture designed for motor imagery (MI) EEG decoding,
+ aiming to address challenges related to discriminative feature extraction and
+ small sample sizes [Han2025]_.
+
+ The framework addresses three key challenges: extracting multi-channel spatial–spectral
+ features, capturing long-term temporal features, and learning from small samples [Han2025]_.
+
+ .. rubric:: Architectural Overview
+
+ SST-DPN consists of a feature extractor (:class:`_SSTEncoder`, comprising Adaptive Spatial-Spectral
+ Fusion and Multi-scale Variance Pooling) followed by Dual Prototype Learning classification [Han2025]_.
+
+ 1. **Adaptive Spatial-Spectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
+    multi-channel spatial-spectral representation, followed by :class:`_SpatSpectralAttn`
+    (Spatial-Spectral Attention) to model relationships and highlight key spatial-spectral
+    channels [Han2025]_.
+
+ 2. **Multi-scale Variance Pooling (MVP)**: Applies :class:`_MultiScaleVarPooler` with variance pooling
+    at multiple temporal scales to capture long-range temporal dependencies, serving as an
+    efficient alternative to transformers [Han2025]_.
+
+ 3. **Dual Prototype Learning (DPL)**: A training strategy that employs two sets of
+    prototypes, Inter-class Separation Prototypes (proto_sep) and Intra-class Compact
+    Prototypes (proto_cpt), to optimize the feature space, enhancing generalization ability and
+    preventing overfitting on small datasets [Han2025]_. During inference (forward pass),
+    classification decisions are based on the similarity (dot product) between the
+    feature vector and proto_sep for each class [Han2025]_.
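
The inference rule in item 3 can be sketched in a few lines of plain Python. This is only an illustration of dot-product scoring against one prototype per class; the function name `classify` and the toy values are hypothetical, and the real model operates on batched tensors.

```python
def classify(features, proto_sep):
    """Score a feature vector against each class prototype by dot product.

    Returns the list of per-class scores (logits) and the index of the
    highest-scoring class.
    """
    logits = [sum(f * p for f, p in zip(features, proto)) for proto in proto_sep]
    predicted = max(range(len(logits)), key=logits.__getitem__)
    return logits, predicted


# Toy 3-dimensional feature vector and one made-up prototype per class.
z = [0.5, -1.0, 2.0]
prototypes = [
    [1.0, 0.0, 0.0],  # class 0
    [0.0, 1.0, 0.0],  # class 1
    [0.0, 0.0, 1.0],  # class 2
]
logits, predicted = classify(z, prototypes)
print(logits, predicted)
```

The predicted class is simply the prototype with the largest dot product, so no separate linear classification layer is needed in this formulation.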
+
+ .. rubric:: Macro Components
+
+ - `SSTDPN.encoder` **(Feature Extractor)**
+
+   - *Operations.* Combines Adaptive Spatial-Spectral Fusion and Multi-scale Variance Pooling
+     via an internal :class:`_SSTEncoder`.
+   - *Role.* Maps the raw MI-EEG trial :math:`X_i \in \mathbb{R}^{C \times T}` to the
+     feature space :math:`z_i \in \mathbb{R}^d`.
+
+ - `_SSTEncoder.temporal_conv` **(Depthwise Temporal Convolution for Spectral Extraction)**
+
+   - *Operations.* Internal :class:`_DepthwiseTemporalConv1d` applying separate temporal
+     convolution filters to each channel with kernel size `temporal_conv_kernel_size` and
+     depth multiplier `n_spectral_filters_temporal` (equivalent to :math:`F_1` in the paper).
+   - *Role.* Extracts multiple distinct spectral bands from each EEG channel independently.
+
+ - `_SSTEncoder.spt_attn` **(Spatial-Spectral Attention for Channel Gating)**
+
+   - *Operations.* Internal :class:`_SpatSpectralAttn` module using Global Context Embedding
+     via variance-based pooling, followed by adaptive channel normalization and gating.
+   - *Role.* Reweights channels in the spatial-spectral dimension to extract efficient and
+     discriminative features by emphasizing task-relevant regions and frequency bands.
+
+ - `_SSTEncoder.chan_conv` **(Pointwise Fusion across Channels)**
+
+   - *Operations.* A 1D pointwise convolution with `n_fused_filters` output channels
+     (equivalent to :math:`F_2` in the paper), followed by BatchNorm and the specified
+     `activation` function (default: ELU).
+   - *Role.* Fuses the weighted spatial-spectral features across all electrodes to produce
+     a fused representation :math:`X_{fused} \in \mathbb{R}^{F_2 \times T}`.
+
+ - `_SSTEncoder.mvp` **(Multi-scale Variance Pooling for Temporal Extraction)**
+
+   - *Operations.* Internal :class:`_MultiScaleVarPooler` using :class:`_VariancePool1D`
+     layers at multiple scales (`mvp_kernel_sizes`), followed by concatenation.
+   - *Role.* Captures long-range temporal features at multiple time scales. The variance
+     operation leverages the prior that variance represents EEG spectral power.
+
+ - `SSTDPN.proto_sep` / `SSTDPN.proto_cpt` **(Dual Prototypes)**
+
+   - *Operations.* Learnable vectors optimized during training using prototype learning losses.
+     The `proto_sep` (Inter-class Separation Prototype) is constrained to a maximum L2 norm
+     (:math:`\lVert s_i \rVert_2 \leq` `proto_sep_maxnorm`) during the forward pass.
+   - *Role.* `proto_sep` achieves inter-class separation; `proto_cpt` enhances intra-class compactness.
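
To make the norm constraint on `proto_sep` concrete, here is a minimal sketch of a max-norm projection: a vector whose L2 norm exceeds the limit is rescaled onto the limit, and shorter vectors are left unchanged. The helper name `clip_to_max_norm` is hypothetical, and the library may implement the constraint differently (for example with a tensor renormalization routine); only the inequality itself comes from the description above.

```python
import math


def clip_to_max_norm(vector, max_norm=1.0):
    """Rescale `vector` so that its L2 norm does not exceed `max_norm`."""
    norm = math.sqrt(sum(v * v for v in vector))
    if norm <= max_norm:
        return list(vector)  # already inside the allowed ball
    scale = max_norm / norm
    return [v * scale for v in vector]


long_proto = clip_to_max_norm([3.0, 4.0], max_norm=1.0)   # norm 5.0, gets rescaled
short_proto = clip_to_max_norm([0.3, 0.4], max_norm=1.0)  # norm 0.5, unchanged
print(long_proto, short_proto)
```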
+
+ .. rubric:: How the information is encoded temporally, spatially, and spectrally
+
+ * **Temporal.**
+   The initial :class:`_DepthwiseTemporalConv1d` uses a large kernel (e.g., 75). The MVP module employs
+   large pooling kernels (e.g., 50, 100, 200 samples) to capture long-term temporal
+   features effectively. Large-kernel pooling layers are reported to outperform transformer
+   modules for this EEG decoding task [Han2025]_.
+
+ * **Spatial.**
+   The initial convolution in :class:`_DepthwiseTemporalConv1d` uses the groups parameter :math:`h=1`,
+   meaning the :math:`F_1` temporal filters are shared across channels.
+   In other words, every electrode channel is filtered independently with the same :math:`F_1`
+   temporal filters to produce the spatial-spectral representation.
+   The Spatial-Spectral Attention mechanism then explicitly models the relationships among these
+   channels in the spatial-spectral dimension, allowing for finer-grained spatial feature
+   modeling compared to conventional GCNs according to the authors [Han2025]_.
+
+ * **Spectral.**
+   Spectral information is implicitly extracted via the :math:`F_1` filters in :class:`_DepthwiseTemporalConv1d`.
+   Furthermore, the use of Variance Pooling (in MVP) explicitly leverages the neurophysiological
+   prior that the **variance of EEG signals represents their spectral power**, which is an
+   important feature for distinguishing different MI classes [Han2025]_.
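
The idea of pooling variance at several temporal scales can be sketched on a single 1D signal. This is a simplified illustration, not the library's `_MultiScaleVarPooler`: it assumes non-overlapping windows (stride equal to the kernel size) and population variance, and the function names are hypothetical.

```python
from statistics import pvariance


def variance_pool(signal, kernel_size):
    """Population variance of each non-overlapping window of `kernel_size` samples."""
    n_windows = len(signal) // kernel_size
    return [
        pvariance(signal[i * kernel_size:(i + 1) * kernel_size])
        for i in range(n_windows)
    ]


def multi_scale_variance_pool(signal, kernel_sizes):
    """Concatenate the variance-pooled outputs obtained at each scale."""
    pooled = []
    for kernel_size in kernel_sizes:
        pooled.extend(variance_pool(signal, kernel_size))
    return pooled


# First half oscillates (high power), second half is flat (zero power).
signal = [0.0, 2.0, 0.0, 2.0, 1.0, 1.0, 1.0, 1.0]
features = multi_scale_variance_pool(signal, kernel_sizes=[2, 4])
print(features)
```

The oscillating half of the toy signal yields non-zero variance at both scales while the flat half yields zero, which is the sense in which windowed variance acts as a proxy for signal power.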
+
+ .. rubric:: Additional Mechanisms
+
+ - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatial-spectral relationships
+   at the channel level, distinct from applying attention to deep feature dimensions,
+   which is common in comparison methods like :class:`ATCNet`.
+ - **Regularization.** Dual Prototype Learning acts as a regularization technique
+   by optimizing the feature space to be compact within classes and separated between
+   classes. This enhances model generalization and classification performance, particularly
+   useful for limited data typical of MI-EEG tasks, without requiring external transfer
+   learning data, according to [Han2025]_.
+
+ Notes
+ -----
+ * The DPL loss functions (:math:`\mathcal{L}_S`, :math:`\mathcal{L}_C`, :math:`\mathcal{L}_{EF}`)
+   and the optimization of the Intra-class Compact Prototypes (ICPs) are typically handled outside
+   the primary ``forward`` method, within the training strategy (see Ref. 52 in [Han2025]_).
+ * The default parameters are configured based on the BCI Competition IV 2a dataset.
+ * According to the authors, the use of Prototype Learning (PL) methods is novel in the field of
+   EEG-MI decoding [Han2025]_.
+ * **Lowest FLOPs:** Reports the lowest number of floating-point operations (FLOPs, 9.65 M) among
+   the compared SOTA methods, including braindecode models like :class:`ATCNet` (29.81 M) and
+   :class:`EEGConformer` (63.86 M), demonstrating computational efficiency [Han2025]_.
+ * **Transformer Alternative:** Multi-scale Variance Pooling (MVP) provides an accuracy
+   improvement over temporal attention transformer modules in ablation studies, offering a more
+   efficient alternative to transformer-based approaches like :class:`EEGConformer` [Han2025]_.
+
+ .. warning::
+
+    **Important:** To utilize the full potential of SSTDPN with Dual Prototype Learning (DPL),
+    users must implement the DPL optimization strategy outside the model's forward method.
+    For implementation details and training strategies, please consult the official code at
+    [Han2025Code]_:
+    https://github.com/hancan16/SST-DPN/blob/main/train.py
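
To give a feel for what a prototype-based training term computed outside `forward` might look like, the sketch below implements a generic intra-class compactness penalty: the mean squared Euclidean distance between each feature vector and the compactness prototype of its class. This is a swapped-in, center-loss-style illustration and is not claimed to be the paper's :math:`\mathcal{L}_C`; the function name and toy values are hypothetical, and the official training script remains the reference.

```python
def compactness_loss(features, labels, proto_cpt):
    """Mean squared distance between each feature and its class prototype.

    Generic illustration only; the exact loss terms used for DPL are
    defined in the paper and the official training code.
    """
    total = 0.0
    for z, y in zip(features, labels):
        total += sum((zi - ci) ** 2 for zi, ci in zip(z, proto_cpt[y]))
    return total / len(features)


# Two toy samples with 2-dimensional features and one prototype per class.
features = [[1.0, 0.0], [0.0, 1.0]]
labels = [0, 1]
proto_cpt = [[1.0, 1.0], [0.0, 1.0]]
loss = compactness_loss(features, labels, proto_cpt)
print(loss)
```

In a real training loop a term of this kind would be computed from the features returned by the model and combined with the other losses before backpropagation.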
+
+ Parameters
+ ----------
+ n_spectral_filters_temporal : int, optional
+     Number of spectral filters extracted per channel via temporal convolution.
+     These represent the temporal spectral bands (equivalent to :math:`F_1` in the paper).
+     Default is 9.
+
+ n_fused_filters : int, optional
+     Number of output filters after pointwise fusion convolution.
+     These fuse the spectral filters across all channels (equivalent to :math:`F_2` in the paper).
+     Default is 48.
+
+ temporal_conv_kernel_size : int, optional
+     Kernel size for the temporal convolution layer. Controls the receptive field for extracting
+     spectral information. Default is 75 samples.
+
+ mvp_kernel_sizes : list[int], optional
+     Kernel sizes for the Multi-scale Variance Pooling (MVP) module.
+     Larger kernels capture long-term temporal dependencies.
+
+ return_features : bool, optional
+     If True, the forward pass returns (features, logits). If False, returns only logits.
+     Default is False.
+
+ proto_sep_maxnorm : float, optional
+     Maximum L2 norm constraint for Inter-class Separation Prototypes during forward pass.
+     This constraint acts as an implicit force to push features away from the origin. Default is 1.0.
+
+ proto_cpt_std : float, optional
+     Standard deviation for Intra-class Compactness Prototype initialization. Default is 0.01.
+
+ spt_attn_global_context_kernel : int, optional
+     Kernel size for global context embedding in Spatial-Spectral Attention module.
+     Default is 250 samples.
+
+ spt_attn_epsilon : float, optional
+     Small epsilon value for numerical stability in Spatial-Spectral Attention. Default is 1e-5.
+
+ spt_attn_mode : str, optional
+     Embedding computation mode for Spatial-Spectral Attention ('var', 'l2', or 'l1').
+     Default is 'var' (variance-based mean-var operation).
+
+ activation : nn.Module, optional
+     Activation function to apply after the pointwise fusion convolution in :class:`_SSTEncoder`.
+     Should be a PyTorch activation module class. Default is nn.ELU.
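
As a small worked example of how the filter-count defaults combine, the arithmetic below tracks the channel dimension through the encoder for a 22-channel recording. It assumes that a depthwise convolution with depth multiplier :math:`F_1` yields `n_chans * F_1` spatial-spectral channels and that the pointwise fusion is a kernel-size-1 convolution; the weight count ignores any bias term and is only an estimate.

```python
n_chans = 22                       # electrodes, as in the quick-start example
n_spectral_filters_temporal = 9    # F_1 (default)
n_fused_filters = 48               # F_2 (default)

# Depthwise temporal convolution: F_1 filters per electrode channel.
n_spatial_spectral = n_chans * n_spectral_filters_temporal

# Pointwise (kernel size 1) fusion: one weight per input/output channel pair.
pointwise_weights = n_spatial_spectral * n_fused_filters

print(n_spatial_spectral, pointwise_weights)
```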
+
+ References
+ ----------
+ .. [Han2025] Han, C., Liu, C., Wang, J., Wang, Y., Cai, C.,
+    & Qian, D. (2025). A spatial–spectral and temporal dual
+    prototype network for motor imagery brain–computer
+    interface. Knowledge-Based Systems, 315, 113315.
+ .. [Han2025Code] Han, C., Liu, C., Wang, J., Wang, Y.,
+    Cai, C., & Qian, D. (2025). A spatial–spectral and
+    temporal dual prototype network for motor imagery
+    brain–computer interface. Knowledge-Based Systems,
+    315, 113315. GitHub repository.
+    https://github.com/hancan16/SST-DPN.
+
+ .. rubric:: Hugging Face Hub integration
+
+ When the optional ``huggingface_hub`` package is installed, all models
+ automatically gain the ability to be pushed to and loaded from the
+ Hugging Face Hub. Install with::
+
+     pip install braindecode[hub]
+
+ **Pushing a model to the Hub:**
+
+ .. code:: python
+
+     from braindecode.models import SSTDPN
+
+     # Train your model
+     model = SSTDPN(n_chans=22, n_outputs=4, n_times=1000)
+     # ... training code ...
+
+     # Push to the Hub
+     model.push_to_hub(
+         repo_id="username/my-sstdpn-model",
+         commit_message="Initial model upload",
+     )
+
+ **Loading a model from the Hub:**
+
+ .. code:: python
+
+     from braindecode.models import SSTDPN
+
+     # Load pretrained model
+     model = SSTDPN.from_pretrained("username/my-sstdpn-model")
+
+     # Load with a different number of outputs (head is rebuilt automatically)
+     model = SSTDPN.from_pretrained("username/my-sstdpn-model", n_outputs=4)
+
+ **Extracting features and replacing the head:**
+
+ .. code:: python
+
+     import torch
+
+     x = torch.randn(1, model.n_chans, model.n_times)
+     # Extract encoder features (consistent dict across all models)
+     out = model(x, return_features=True)
+     features = out["features"]
+
+     # Replace the classification head
+     model.reset_head(n_outputs=10)
+
+ **Saving and restoring full configuration:**
+
+ .. code:: python
+
+     import json
+
+     config = model.get_config()  # all __init__ params
+     with open("config.json", "w") as f:
+         json.dump(config, f)
+
+     model2 = SSTDPN.from_config(config)  # reconstruct (no weights)
+
+ All model parameters (both EEG-specific and model-specific such as
+ dropout rates, activation functions, number of filters) are automatically
+ saved to the Hub and restored when loading.
+
+ See :ref:`load-pretrained-models` for a complete tutorial.</main>
+ </div>
+
+ ## Citation
+
+ Please cite both the original paper for this architecture (see the
+ *References* section above) and braindecode:
+
+ ```bibtex
+ @article{aristimunha2025braindecode,
+   title = {Braindecode: a deep learning library for raw electrophysiological data},
+   author = {Aristimunha, Bruno and others},
+   journal = {Zenodo},
+   year = {2025},
+   doi = {10.5281/zenodo.17699192},
+ }
+ ```
+
+ ## License
+
+ BSD-3-Clause for the model code (matching braindecode).
+ Pretraining-derived weights, if you fine-tune from a checkpoint,
+ inherit the license of that checkpoint and its training corpus.