# AttentionBaseNet

AttentionBaseNet from Wimpff M et al (2023) [Martin2023].

> **Architecture-only repository.** Documents the
> `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
> distributed here.** Instantiate the model and train it on your own
> data.

## Quick start

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    n_outputs=4,
    n_times=1000,
)
```

The signal-shape arguments above are illustrative defaults; adjust them to
match your recording.
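
As a sanity check, you can continue from the snippet above and run a dummy
batch through the model; a minimal sketch assuming the illustrative values:

```python
import torch

# Dummy EEG batch: (batch, n_chans, n_times)
x = torch.randn(8, 22, 1000)

with torch.no_grad():
    logits = model(x)

print(logits.shape)  # torch.Size([8, 4]): one score per class
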
## Documentation

- Full API reference: <https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
- Interactive browser (live instantiation, parameter counts):
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>

## Architecture

![AttentionBaseNet architecture](https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg)

### Architectural overview

AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
The end-to-end flow is:

1. `_FeatureExtractor` learns a temporal filter bank and per-filter spatial
   projections (depthwise across electrodes), then condenses time by pooling;
2. **channel expansion** uses a `1x1` convolution to set the feature width;
3. `_ChannelAttentionBlock` refines features via depthwise-pointwise temporal
   convs and an optional channel-attention module (SE/CBAM/ECA/...);
4. a **classifier** flattens the sequence and applies a linear readout.

This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable
attention unit that *re-weights channels* (and optionally temporal positions) before
classification.

### Macro components

- `_FeatureExtractor` **(shallow conv stem → condensed feature map)**
  - **Temporal conv** (`torch.nn.Conv2d`) with kernel `(1, L_t)` creates a learned
    FIR-like filter bank with `n_temporal_filters` maps.
  - **Depthwise spatial conv** (`torch.nn.Conv2d`, `groups=n_temporal_filters`) with
    kernel `(n_chans, 1)` learns per-filter spatial projections over the full montage.
  - **BatchNorm → ELU → AvgPool → Dropout** stabilize and downsample time.
  - Output shape: `(B, F2, 1, T₁)` with `F2 = n_temporal_filters × spatial_expansion`.

  *Interpretability/robustness.* Temporal kernels behave as analyzable FIR filters; the
  depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local
  integrator that reduces variance on short EEG windows.

- **Channel expansion**
  - A `1x1` conv → BN → activation maps `F2 → ch_dim` without changing the temporal
    length `T₁` (shape: `(B, ch_dim, 1, T₁)`). This sets the embedding width for the
    attention block.

- `_ChannelAttentionBlock` **(temporal refinement + channel attention)**
  - A **depthwise temporal conv** `(1, L_a)` (`groups=ch_dim`) plus a **pointwise `1x1`**
    conv, BN, and activation preserve the shape `(B, ch_dim, 1, T₁)` while refining timing.
  - An **optional attention module** (see *Additional mechanisms* below) applies channel
    reweighting; some variants also apply temporal gating.
  - **AvgPool `(1, P₂)`** with stride `(1, S₂)` and **Dropout** produce `(B, ch_dim, 1, T₂)`.

  *Role.* Emphasizes informative channels (and, in certain modes, salient time steps)
  before the classifier; complements the convolutional priors with adaptive re-weighting.

- **Classifier (aggregation + readout)**
  - `torch.nn.Flatten` → `torch.nn.Linear` maps `(B, ch_dim·T₂)` to class scores.
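
To trace these shapes on a concrete input, one option is to hook the model's
top-level submodules. A minimal sketch (the printed stage names are whatever
braindecode defines internally; the hooks only report output shapes):

```python
import torch
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)

# Print the output shape of every top-level stage during one forward pass.
for name, module in model.named_children():
    module.register_forward_hook(
        lambda mod, args, out, name=name: print(name, tuple(out.shape))
    )

with torch.no_grad():
    model(torch.randn(1, 22, 1000))
```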

### Convolutional details

- **Temporal (where time-domain patterns are learned).** Wide kernels in the stem
  (`(1, L_t)`) act as a learned filter bank for oscillatory bands/transients; the
  attention block's depthwise temporal conv (`(1, L_a)`) sharpens short-term dynamics
  after downsampling. Pool sizes/strides (`P₁, S₁` then `P₂, S₂`) set the token rate
  and effective temporal resolution.
- **Spatial (how electrodes are processed).** A depthwise spatial conv with kernel
  `(n_chans, 1)` spans the full montage to learn *per-temporal-filter* spatial
  projections (no cross-filter mixing at this step), mirroring the interpretable
  spatial stage in shallow CNNs.
- **Spectral (how frequency content is captured).** No explicit Fourier/wavelet
  transform is used in the stem; spectral selectivity emerges from learned temporal
  kernels. When `attention_mode="fca"`, a frequency channel attention (DCT-based)
  summarizes frequencies to drive channel weights.
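
The pooling stages determine the final sequence length `T₂` by standard pooling
arithmetic. A back-of-the-envelope helper (a sketch that ignores the temporal
convolutions' edge effects and the auto-rescaling described under *Additional
mechanisms*; the exact bookkeeping lives in the braindecode source):

```python
def pooled_length(n_times: int, window: int, stride: int) -> int:
    """Length after average pooling with no padding."""
    return (n_times - window) // stride + 1

# Illustrative defaults: n_times=1000, stem pool (75, 15), attention pool (8, 8).
t1 = pooled_length(1000, 75, 15)  # rough T1 after the stem
t2 = pooled_length(t1, 8, 8)      # rough T2 after the attention block
print(t1, t2)  # 62 7, so the classifier flattens roughly ch_dim * 7 features
```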

### Attention / sequential modules

- **Type.** Channel attention chosen by `attention_mode` (SE, ECA, CBAM, CAT, GSoP,
  EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally
  include temporal attention.
- **Shapes.** Input/output around attention: `(B, ch_dim, 1, T₁)`. Re-arrangements
  (if any) are internal to the module; the block returns the same shape before pooling.
- **Role.** Re-weights channels (and optionally time) to highlight informative sources
  and suppress distractors, improving SNR ahead of the linear head.

### Additional mechanisms

**Attention variants at a glance:**

- `"se"`: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- `"gsop"`: global second-order pooling (covariance-aware channel weights).
- `"fca"`: Frequency Channel Attention (DCT summary; uses `seq_len` and `freq_idx`).
- `"encnet"`: EncNet with learned codewords (uses `n_codewords`).
- `"eca"`: Efficient Channel Attention (local 1-D conv over the channel descriptor; uses `kernel_size`).
- `"ge"`: Gather-Excite (context pooling with optional MLP; can use `extra_params`).
- `"gct"`: Gated Channel Transformation (global context normalization + gating).
- `"srm"`: style-based recalibration (mean-std descriptors; optional MLP).
- `"cbam"`: channel then temporal attention (uses `kernel_size`).
- `"cat"` / `"catlite"`: collaborative (channel ± temporal) attention; *lite* omits temporal attention.
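
Because every variant is selected by a single constructor argument, their parameter
overhead is easy to compare. A sketch over a few of the lighter modes (`"fca"` also
requires `seq_len`, so it is omitted here):

```python
from braindecode.models import AttentionBaseNet

for mode in (None, "se", "eca", "gct", "srm", "cbam"):
    model = AttentionBaseNet(
        n_chans=22, n_outputs=4, n_times=1000, attention_mode=mode
    )
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{mode!s:>6}: {n_params:,} parameters")
```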

**Auto-compatibility on short inputs:** if the input duration is too short for the
configured kernels/pools, the implementation **automatically rescales** temporal
lengths/strides downward (with a warning) to keep shapes valid and preserve the
pipeline semantics.
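
A quick way to observe this rescaling (a sketch; the exact warning text and rescaling
rule come from the braindecode implementation):

```python
import warnings
from braindecode.models import AttentionBaseNet

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # 50 samples is shorter than the default stem pooling window of 75,
    # so the temporal hyperparameters should be rescaled with a warning.
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=50)

for w in caught:
    print(w.message)
```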

### Usage and configuration

- `n_temporal_filters`, `temp_filter_length`, and `spatial_expansion` control the
  capacity and the number of spatial projections in the stem.
- `pool_length_inp`/`pool_stride_inp`, then `pool_length`/`pool_stride`, trade temporal
  resolution for compute; they determine the final sequence length `T₂`.
- `ch_dim` sets the width after the `1x1` expansion and the effective embedding size
  for attention.
- `attention_mode` plus its specific hyperparameters (`reduction_rate`, `kernel_size`,
  `seq_len`, `freq_idx`, `n_codewords`, `use_mlp`) select and tune the reweighting
  mechanism.
- `drop_prob_inp` and `drop_prob_attn` regularize the stem and attention stages.
- **Training tips.** Start with moderate pooling (e.g., `P₁=75, S₁=15`) and ELU
  activations; enable attention only after the stem learns stable filters. For small
  datasets, prefer simpler modes (`"se"`, `"eca"`) before heavier ones (`"gsop"`,
  `"encnet"`).
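
Putting the knobs together, a configuration sketch that enables SE attention with a
slightly lighter stem (the values are arbitrary; parameter semantics are in the table
below):

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    n_outputs=4,
    n_times=1000,
    n_temporal_filters=20,   # lighter stem than the default 40
    spatial_expansion=2,     # two spatial projections per temporal filter
    attention_mode="se",     # simple squeeze-and-excitation reweighting
    reduction_rate=4,        # SE bottleneck ratio
    drop_prob_inp=0.5,
    drop_prob_attn=0.5,
)
```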

## Parameters

| Parameter | Type | Description |
|---|---|---|
| `n_temporal_filters` | int, optional | Number of temporal convolutional filters in the first layer; defines the number of output channels after the temporal convolution. Default is 40. |
| `temp_filter_length` | int, default=15 | Length of the temporal filters in the convolutional layers. |
| `spatial_expansion` | int, optional | Multiplicative factor to expand the spatial dimensions; increases model capacity by expanding spatial features. Default is 1. |
| `pool_length_inp` | int, optional | Length of the pooling window in the input layer; determines how much temporal information is aggregated during pooling. Default is 75. |
| `pool_stride_inp` | int, optional | Stride of the pooling operation in the input layer; controls the downsampling factor in the temporal dimension. Default is 15. |
| `drop_prob_inp` | float, optional | Dropout rate applied after the input layer, i.e. the probability of zeroing out elements during training to prevent overfitting. Default is 0.5. |
| `ch_dim` | int, optional | Number of channels in the subsequent convolutional layers; controls the depth of the network after the initial layer. Default is 16. |
| `attention_mode` | str, optional | Attention mechanism to apply; if `None`, no attention is applied. One of `"se"` (Squeeze-and-Excitation), `"gsop"` (Global Second-Order Pooling), `"fca"` (Frequency Channel Attention), `"encnet"` (context encoding module), `"eca"` (Efficient Channel Attention), `"ge"` (Gather-Excite), `"gct"` (Gated Channel Transformation), `"srm"` (Style-based Recalibration Module), `"cbam"` (Convolutional Block Attention Module), `"cat"` (collaborative channel and temporal attention from multi-information fusion), or `"catlite"` (`"cat"` without temporal attention). |
| `pool_length` | int, default=8 | Window length of the average pooling operation. |
| `pool_stride` | int, default=8 | Stride of the average pooling operation. |
| `drop_prob_attn` | float, default=0.5 | Dropout rate for regularization of the attention layer. Values should be between 0 and 1. |
| `reduction_rate` | int, default=4 | Reduction rate used in the attention mechanism to reduce dimensionality and computational complexity. |
| `use_mlp` | bool, default=False | Whether an MLP (multi-layer perceptron) is used within the attention mechanism for further processing. |
| `freq_idx` | int, default=0 | DCT index used in the `"fca"` attention mechanism. |
| `n_codewords` | int, default=4 | Number of codewords (clusters) used in attention mechanisms that employ quantization or clustering strategies. |
| `kernel_size` | int, default=9 | Kernel size used by certain attention mechanisms for convolution operations. |
| `activation` | type[nn.Module], default=`nn.ELU` | Activation function class to apply; should be a PyTorch activation module class like `nn.ReLU` or `nn.ELU`. |
| `extra_params` | bool, default=False | Whether additional, custom parameters should be passed to the attention mechanism. |

## Notes

- The sequence length after each stage is computed internally; the final classifier
  expects a flattened `ch_dim × T₂` vector.
- Attention operates on the *channel* dimension by design; temporal gating exists only
  in specific variants (CBAM/CAT).
- The paper and the original code, with more details about the methodological choices,
  are available at [Martin2023] and [MartinCode] (see *References* below).
- Added in braindecode 0.9.

## Hugging Face Hub integration

When the optional `huggingface_hub` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with:

```bash
pip install braindecode[hub]
```

**Pushing a model to the Hub:**

```python
from braindecode.models import AttentionBaseNet

# Train your model
model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
# ... training code ...

# Push to the Hub
model.push_to_hub(
    repo_id="username/my-attentionbasenet-model",
    commit_message="Initial model upload",
)
```

**Loading a model from the Hub:**

```python
from braindecode.models import AttentionBaseNet

# Load pretrained model
model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

# Load with a different number of outputs (head is rebuilt automatically)
model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model", n_outputs=4)
```

**Extracting features and replacing the head:**

```python
import torch

x = torch.randn(8, 22, 1000)  # dummy EEG batch: (batch, n_chans, n_times)

# Extract encoder features (consistent dict across all models)
out = model(x, return_features=True)
features = out["features"]

# Replace the classification head
model.reset_head(n_outputs=10)
```

**Saving the configuration:**

```python
import json

config = model.get_config()  # all __init__ params
with open("config.json", "w") as f:
    json.dump(config, f)
```

The configuration is saved to the Hub and restored when loading.

See the braindecode `load-pretrained-models` tutorial for a complete walkthrough.

## References

- [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023. EEG motor
  imagery decoding: A framework for comparative analysis with channel attention
  mechanisms. arXiv preprint arXiv:2310.11198.
- [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B. GitHub:
  <https://github.com/martinwimpff/channel-attention> (accessed 2024-03-28).

## Citation

Cite the original architecture paper (see *References* above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,