AttentionBaseNet

AttentionBaseNet, from Wimpff et al. (2023) [1].

Architecture-only repository. Documents the braindecode.models.AttentionBaseNet class. No pretrained weights are distributed here. Instantiate the model and train it on your own data.

Quick start

pip install braindecode

from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)

The signal-shape arguments above (n_chans, sfreq, input_window_seconds) and the number of classes (n_outputs) are illustrative; adjust them to match your recording and task.
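To pick these values from your own data, the shape of an epoched array can be mapped onto the constructor arguments. The sketch below uses a hypothetical helper, attention_base_net_kwargs, which is not part of braindecode. It assumes data shaped (n_trials, n_chans, n_times), equal-length trials, and sfreq in Hz.

```python
# Hypothetical helper (not part of braindecode): derive AttentionBaseNet
# constructor arguments from the shape of an epoched recording.
def attention_base_net_kwargs(
    data_shape: tuple[int, int, int], sfreq: float, n_classes: int
) -> dict:
    """Map an array shape (n_trials, n_chans, n_times) to model kwargs."""
    if len(data_shape) != 3:
        raise ValueError("expected shape (n_trials, n_chans, n_times)")
    _, n_chans, n_times = data_shape
    return {
        "n_chans": n_chans,
        "sfreq": sfreq,
        # Window length in seconds = number of samples / sampling rate (Hz).
        "input_window_seconds": n_times / sfreq,
        # The class count cannot be inferred from the shape, so it is passed in.
        "n_outputs": n_classes,
    }


# Hypothetical epochs array shape: 288 trials, 22 channels, 1000 samples.
kwargs = attention_base_net_kwargs((288, 22, 1000), sfreq=250, n_classes=4)
print(kwargs)
# {'n_chans': 22, 'sfreq': 250, 'input_window_seconds': 4.0, 'n_outputs': 4}
```

The returned dictionary can then be unpacked as AttentionBaseNet(**kwargs), which reproduces the Quick start configuration for this shape.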

Documentation

Architecture

Figure: AttentionBaseNet architecture.

Parameters

| Parameter | Type | Description |
| --- | --- | --- |
| n_temporal_filters | int, default=40 | Number of temporal convolutional filters in the first layer, i.e. the number of output channels after the temporal convolution. |
| temp_filter_length | int, default=15 | Length of the temporal filters in the convolutional layers. |
| spatial_expansion | int, default=1 | Multiplicative factor that expands the spatial dimension, increasing model capacity by producing more spatial features. |
| pool_length_inp | int, default=75 | Length of the pooling window in the input layer; determines how much temporal information is aggregated during pooling. |
| pool_stride_inp | int, default=15 | Stride of the pooling operation in the input layer; controls the downsampling factor along the temporal dimension. |
| drop_prob_inp | float, default=0.5 | Dropout probability applied after the input layer, i.e. the probability of zeroing out an element during training to prevent overfitting. |
| ch_dim | int, default=16 | Number of channels in the subsequent convolutional layers; sets the channel depth of the network after the initial layer. |
| attention_mode | str or None, optional | Attention mechanism to apply; if None, no attention is applied. Options: "se" (Squeeze-and-Excitation network); "gsop" (Global Second-Order Pooling); "fca" (Frequency Channel Attention network); "encnet" (context encoding module); "eca" (Efficient Channel Attention for deep convolutional neural networks); "ge" (Gather-Excite); "gct" (Gated Channel Transformation); "srm" (Style-based Recalibration Module); "cbam" (Convolutional Block Attention Module); "cat" (Learning to collaborate channel and temporal attention from multi-information fusion); "catlite" (lite version of "cat", without temporal attention). |
| pool_length | int, default=8 | Window length of the average pooling operation. |
| pool_stride | int, default=8 | Stride of the average pooling operation. |
| drop_prob_attn | float, default=0.5 | Dropout probability used for regularization in the attention layer; must be between 0 and 1. |
| reduction_rate | int, default=4 | Reduction rate used in the attention mechanism to reduce dimensionality and computational cost. |
| use_mlp | bool, default=False | Whether to use an MLP (multi-layer perceptron) within the attention mechanism for further processing. |
| freq_idx | int, default=0 | DCT frequency index used by the "fca" attention mechanism. |
| n_codewords | int, default=4 | Number of codewords (clusters) used by attention mechanisms that rely on quantization or clustering. |
| kernel_size | int, default=9 | Kernel size of the convolutions used in certain attention mechanisms. |
| activation | type[nn.Module], default=nn.ELU | Activation function class to apply; should be a PyTorch activation module class such as nn.ReLU or nn.ELU. |
| extra_params | bool, default=False | Whether to pass additional, custom parameters to the attention mechanism. |
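The pooling parameters determine how much the time axis shrinks. The sketch below estimates the temporal length after each pooling stage using the standard formula for unpadded pooling, floor((L - window) / stride) + 1. It assumes, without checking the braindecode source, that the temporal convolutions preserve length and that the input-layer pooling (pool_length_inp, pool_stride_inp) is followed by a second average pooling (pool_length, pool_stride). The helper names are hypothetical.

```python
def pooled_length(n_in: int, window: int, stride: int) -> int:
    """Length after an unpadded 1-D pooling layer (floor mode)."""
    if n_in < window:
        raise ValueError(f"input length {n_in} is shorter than window {window}")
    return (n_in - window) // stride + 1


def estimate_temporal_lengths(
    sfreq: float,
    input_window_seconds: float,
    pool_length_inp: int = 75,
    pool_stride_inp: int = 15,
    pool_length: int = 8,
    pool_stride: int = 8,
) -> tuple[int, int, int]:
    """Return (input samples, length after input pooling, length after second pooling).

    Assumption: only the two pooling stages shorten the time axis, because the
    temporal convolutions are taken to use length-preserving ("same") padding.
    """
    n_times = round(sfreq * input_window_seconds)
    after_input_pool = pooled_length(n_times, pool_length_inp, pool_stride_inp)
    after_second_pool = pooled_length(after_input_pool, pool_length, pool_stride)
    return n_times, after_input_pool, after_second_pool


print(estimate_temporal_lengths(250, 4.0))  # (1000, 62, 7)
```

With the Quick start settings (250 Hz and 4 s, so 1000 samples) and the default pooling values, this gives 62 and then 7 time steps. Under the same assumptions, the defaults need at least 180 input samples for the second pooling window to fit.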

References

  1. Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023. EEG motor imagery decoding: A framework for comparative analysis with channel attention mechanisms. arXiv preprint arXiv:2310.11198.
  2. Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B. GitHub https://github.com/martinwimpff/channel-attention (accessed 2024-03-28)

Citation

Cite the original architecture paper (see References above) and braindecode:

@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}

License

BSD-3-Clause for the model code (matching braindecode). If you fine-tune from a pretrained checkpoint, the resulting weights inherit the license of that checkpoint and of its training corpus.
