Commit 010c443 by bruAristimunha (verified)
1 parent: 8408b55

Add architecture-only model card

  1. README.md +295 -0
README.md ADDED
@@ -0,0 +1,295 @@
+ ---
+ license: bsd-3-clause
+ library_name: braindecode
+ pipeline_tag: feature-extraction
+ tags:
+ - eeg
+ - biosignal
+ - pytorch
+ - neuroscience
+ - braindecode
+ - convolutional
+ - sleep-staging
+ ---
+
+ # DeepSleepNet
+
+ DeepSleepNet from Supratak et al. (2017).
+
+ > **Architecture-only repository.** This repo documents the
+ > `braindecode.models.DeepSleepNet` class. **No pretrained weights are
+ > distributed here.** Instantiate the model and train it on your own
+ > data, or fine-tune from a foundation-model checkpoint published
+ > separately.
+
+ ## Quick start
+
+ ```bash
+ pip install braindecode
+ ```
+
+ ```python
+ from braindecode.models import DeepSleepNet
+
+ model = DeepSleepNet(
+     n_chans=2,
+     sfreq=100,
+     input_window_seconds=30.0,
+     n_outputs=5,
+ )
+ ```
+
+ The signal-shape arguments above are example defaults; adjust them
+ to match your recording.
+
+ ## Documentation
+
+ - Full API reference (parameters, references, architecture figure):
+   <https://braindecode.org/stable/generated/braindecode.models.DeepSleepNet.html>
+ - Interactive browser with live instantiation:
+   <https://huggingface.co/spaces/braindecode/model-explorer>
+ - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/deepsleepnet.py#L12>
+
+ ## Architecture description
+
+ The block below is the rendered class docstring (parameters,
+ references, architecture figure where available).
+
+ <div class='bd-doc'><main>
+ <p>DeepSleepNet from Supratak et al. (2017) [Supratak2017]_.</p>
+ <span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#6c757d;color:white;font-size:11px;font-weight:600;margin-right:4px;">Recurrent</span>
+
+
+
+ .. figure:: https://raw.githubusercontent.com/akaraspt/deepsleepnet/master/img/deepsleepnet.png
+    :align: center
+    :alt: DeepSleepNet Architecture
+    :width: 700px
+
+ DeepSleepNet is a deep learning model for automatic sleep stage scoring
+ based on raw single-channel EEG. It consists of two main parts:
+
+ 1. **Representation learning**: two CNNs with different filter sizes
+    extract time-invariant features from each 30-s EEG epoch.
+ 2. **Sequence residual learning**: bidirectional LSTMs learn temporal
+    information such as stage transition rules, combined with a residual
+    shortcut from the CNN features.
+
+ .. rubric:: Representation Learning
+
+ Two parallel CNN paths process the raw input simultaneously:
+
+ - **Small-filter path**: first conv uses filter length ≈ Fs/2 and
+   stride ≈ Fs/16, capturing *when* characteristic transients occur
+   (temporal precision).
+ - **Large-filter path**: first conv uses filter length ≈ 4·Fs and
+   stride ≈ Fs/2, capturing *which* frequency components dominate
+   (frequency precision).
+
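As a rough check on the filter sizing rules above, the following minimal sketch derives first-layer kernel lengths and strides from a sampling rate. The helper name `first_conv_sizes` is hypothetical and not part of braindecode, and truncating Fs/16 to an integer is an assumption made here for illustration.

```python
def first_conv_sizes(sfreq: int) -> dict:
    """Approximate first-conv kernel/stride per path from the sampling rate.

    Hypothetical helper for illustration only. It applies the rules of thumb
    kernel ~ Fs/2, stride ~ Fs/16 (small path) and
    kernel ~ 4*Fs, stride ~ Fs/2 (large path), using integer division.
    """
    return {
        "small": {"kernel": sfreq // 2, "stride": sfreq // 16},
        "large": {"kernel": 4 * sfreq, "stride": sfreq // 2},
    }


sizes = first_conv_sizes(100)
print(sizes["small"])  # {'kernel': 50, 'stride': 6}
print(sizes["large"])  # {'kernel': 400, 'stride': 50}
```

For Fs = 100 Hz these values coincide with the documented defaults of `small_first_kernel_size`, `small_first_stride`, `large_first_kernel_size`, and `large_first_stride`.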
+ Each path consists of four convolutional layers (1-D convolution →
+ :class:`~torch.nn.BatchNorm2d` → activation, configurable via the
+ per-path activation settings) and two :class:`~torch.nn.MaxPool2d`
+ layers with :class:`~torch.nn.Dropout` after the first pooling.
+ Outputs from both paths are **concatenated** to form the epoch
+ embedding.
+
+ .. rubric:: Sequence Residual Learning
+
+ Two layers of bidirectional LSTMs encode temporal dependencies across
+ epochs. A **residual shortcut** (fully connected →
+ :class:`~torch.nn.BatchNorm1d` → :class:`~torch.nn.ReLU`) projects
+ the CNN features to the BiLSTM output dimension and is **added** to
+ the BiLSTM output, improving gradient flow and preserving salient
+ CNN evidence.
+
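The sequence residual idea can be sketched in plain PyTorch as follows. This is an illustrative module under stated assumptions, not the braindecode implementation: the class name `ResidualBiLSTM`, the small feature sizes, and the choice to feed each embedding as a sequence of length 1 are all made up for the example.

```python
import torch
from torch import nn


class ResidualBiLSTM(nn.Module):
    """Illustrative sequence-residual block (hypothetical, for explanation)."""

    def __init__(self, in_features: int, hidden_size: int = 512, num_layers: int = 2):
        super().__init__()
        self.bilstm = nn.LSTM(
            input_size=in_features,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Shortcut: fully connected -> BatchNorm1d -> ReLU, projecting to
        # 2 * hidden_size so it matches the concatenated LSTM directions.
        self.shortcut = nn.Sequential(
            nn.Linear(in_features, 2 * hidden_size),
            nn.BatchNorm1d(2 * hidden_size),
            nn.ReLU(),
        )

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (batch, in_features), one CNN embedding per epoch.
        seq_out, _ = self.bilstm(feats.unsqueeze(1))  # (batch, 1, 2 * hidden_size)
        # Add the projected CNN features to the BiLSTM output.
        return seq_out.squeeze(1) + self.shortcut(feats)


block = ResidualBiLSTM(in_features=64, hidden_size=16)
out = block(torch.randn(4, 64))
print(out.shape)  # torch.Size([4, 32])
```

The output width is `2 * hidden_size` because the forward and backward LSTM states are concatenated, which is why the shortcut projects to the same width before the addition.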
+ .. rubric:: Implementation Differences
+
+ .. note::
+
+    **Peephole connections.** The original implementation uses
+    TensorFlow ``LSTMCell`` with ``use_peepholes=True``, which allows
+    gates to inspect the cell state. :class:`torch.nn.LSTM` does not
+    support peepholes; this implementation uses standard LSTM gates.
+
+    **Sequence length.** The original model processes **sequences of
+    epochs** through the BiLSTM to capture cross-epoch transition rules.
+    This implementation processes **single epochs** (sequence length 1),
+    so the BiLSTM acts as a nonlinear feature transform with a residual
+    connection. To leverage multi-epoch context, batch consecutive
+    epochs as a sequence externally.
+
+    **Activation.** The original uses :class:`~torch.nn.ReLU` for both
+    CNN paths. This implementation defaults to :class:`~torch.nn.ELU`
+    for the large-filter path (``activation_large``), which can be
+    overridden.
+
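One way to read "batch consecutive epochs as a sequence externally" is to build windows of consecutive epoch indices per recording. The sketch below shows only that indexing step; the helper `consecutive_windows` and its `stride` argument are hypothetical, and how the resulting sequences are consumed depends on the downstream model.

```python
def consecutive_windows(n_epochs: int, seq_len: int, stride: int = 1) -> list:
    """Return index lists for windows of `seq_len` consecutive epochs.

    Hypothetical helper: call it once per recording so that no window
    spans two recordings.
    """
    if seq_len < 1 or stride < 1:
        raise ValueError("seq_len and stride must be positive")
    return [
        list(range(start, start + seq_len))
        for start in range(0, n_epochs - seq_len + 1, stride)
    ]


windows = consecutive_windows(n_epochs=6, seq_len=3, stride=2)
print(windows)  # [[0, 1, 2], [2, 3, 4]]
```

A recording shorter than `seq_len` yields no windows, so very short recordings would need padding or a smaller `seq_len`.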
+ .. rubric:: Training (from the paper)
+
+ - **Two-step procedure.** (i) Pre-train the CNN part on a
+   class-balanced training set using oversampling; (ii) fine-tune the
+   whole network with sequential batches using a lower learning rate
+   for the CNNs and a higher one for the sequence residual part.
+ - **Dropout** with probability 0.5 is used throughout the model.
+ - **L2 weight decay** (λ = 10⁻³) is applied only to the first
+   convolutional layers of both CNN paths.
+ - **Gradient clipping** rescales gradients when their global norm
+   exceeds a threshold.
+ - **State handling.** BiLSTM states are reinitialized per subject so
+   that temporal context does not leak across recordings.
+
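The class-balancing step of the pre-training stage can be illustrated with a small oversampling routine. This is a minimal sketch assuming simple random oversampling with replacement up to the size of the largest class; the helper `oversample_indices` is hypothetical and is not taken from the paper's code or from braindecode.

```python
import random
from collections import defaultdict


def oversample_indices(labels, seed=0):
    """Return sample indices in which every class appears equally often.

    Minority classes are resampled with replacement up to the size of the
    largest class. Hypothetical helper for illustration.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    target = max(len(idxs) for idxs in by_class.values())
    balanced = []
    for idxs in by_class.values():
        balanced.extend(idxs)  # keep every original sample once
        balanced.extend(rng.choices(idxs, k=target - len(idxs)))  # top up
    rng.shuffle(balanced)
    return balanced


labels = [0, 0, 0, 0, 1, 2, 2]
balanced = oversample_indices(labels)
counts = {c: sum(labels[i] == c for i in balanced) for c in set(labels)}
print(counts)  # {0: 4, 1: 4, 2: 4}
```

The returned indices could be used to select training epochs for the CNN pre-training step, while the fine-tuning step would keep the original epoch order.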
+ Parameters
+ ----------
+ activation_large : type[nn.Module], default=nn.ELU
+     Activation class for the large-filter CNN path.
+ activation_small : type[nn.Module], default=nn.ReLU
+     Activation class for the small-filter CNN path.
+ return_feats : bool, default=False
+     If True, return features before the final linear layer.
+ drop_prob : float, default=0.5
+     Dropout probability applied throughout the network.
+ bilstm_hidden_size : int, default=512
+     Hidden size of the BiLSTM. The residual FC output dimension is
+     ``2 * bilstm_hidden_size`` to match the concatenated directions.
+ bilstm_num_layers : int, default=2
+     Number of stacked BiLSTM layers.
+ small_n_filters_1 : int, default=64
+     First-conv output channels for the small-filter path.
+ small_n_filters_2 : int, default=128
+     Deep-conv (conv2--conv4) output channels for the small-filter path.
+ small_first_kernel_size : int, default=50
+     First-conv kernel size for the small path (paper: Fs/2).
+ small_first_stride : int, default=6
+     First-conv stride for the small path (paper: Fs/16).
+ small_first_padding : int, default=22
+     First-conv padding for the small path.
+ small_pool1_kernel_size : int, default=8
+     First max-pool kernel for the small path.
+ small_pool1_stride : int, default=8
+     First max-pool stride for the small path.
+ small_pool1_padding : int, default=2
+     First max-pool padding for the small path.
+ small_deep_kernel_size : int, default=8
+     Deep-conv kernel size for the small path.
+ small_pool2_kernel_size : int, default=4
+     Second max-pool kernel for the small path.
+ small_pool2_stride : int, default=4
+     Second max-pool stride for the small path.
+ small_pool2_padding : int, default=1
+     Second max-pool padding for the small path.
+ large_n_filters_1 : int, default=64
+     First-conv output channels for the large-filter path.
+ large_n_filters_2 : int, default=128
+     Deep-conv (conv2--conv4) output channels for the large-filter path.
+ large_first_kernel_size : int, default=400
+     First-conv kernel size for the large path (paper: 4*Fs).
+ large_first_stride : int, default=50
+     First-conv stride for the large path (paper: Fs/2).
+ large_first_padding : int, default=175
+     First-conv padding for the large path.
+ large_pool1_kernel_size : int, default=4
+     First max-pool kernel for the large path.
+ large_pool1_stride : int, default=4
+     First max-pool stride for the large path.
+ large_pool1_padding : int, default=0
+     First max-pool padding for the large path.
+ large_deep_kernel_size : int, default=6
+     Deep-conv kernel size for the large path.
+ large_pool2_kernel_size : int, default=2
+     Second max-pool kernel for the large path.
+ large_pool2_stride : int, default=2
+     Second max-pool stride for the large path.
+ large_pool2_padding : int, default=1
+     Second max-pool padding for the large path.
+
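To see how the default kernel, stride, and padding values shape the temporal axis, the standard output-length formula for a convolution or max-pool (dilation 1, floor rounding) can be applied by hand. The sketch below assumes a 3000-sample input (30 s at 100 Hz) and that the first convolution is followed directly by the first pooling layer; it stops there because the padding of the deeper convolutions is not listed above.

```python
def out_length(n_in: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a 1-D conv or max-pool (dilation 1, floor mode)."""
    return (n_in + 2 * padding - kernel) // stride + 1


n_times = 3000  # assumed input length: 30 s at 100 Hz

# Small-filter path: first conv, then first max-pool.
small = out_length(n_times, kernel=50, stride=6, padding=22)
small_pooled = out_length(small, kernel=8, stride=8, padding=2)

# Large-filter path: first conv, then first max-pool.
large = out_length(n_times, kernel=400, stride=50, padding=175)
large_pooled = out_length(large, kernel=4, stride=4, padding=0)

print(small, small_pooled)  # 500 63
print(large, large_pooled)  # 60 15
```

The small path keeps a much finer temporal grid than the large path after the first stage, consistent with its role of capturing when transients occur.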
+ References
+ ----------
+ .. [Supratak2017] Supratak, A., Dong, H., Wu, C., & Guo, Y. (2017).
+    DeepSleepNet: A model for automatic sleep stage scoring based
+    on raw single-channel EEG. IEEE Transactions on Neural Systems
+    and Rehabilitation Engineering, 25(11), 1998-2008.
+
+ .. rubric:: Hugging Face Hub integration
+
+ When the optional ``huggingface_hub`` package is installed, all models
+ automatically gain the ability to be pushed to and loaded from the
+ Hugging Face Hub. Install with::
+
+     pip install braindecode[hub]
+
+ **Pushing a model to the Hub:**
+
+ .. code::
+     from braindecode.models import DeepSleepNet
+
+     # Train your model
+     model = DeepSleepNet(n_chans=22, n_outputs=4, n_times=1000)
+     # ... training code ...
+
+     # Push to the Hub
+     model.push_to_hub(
+         repo_id="username/my-deepsleepnet-model",
+         commit_message="Initial model upload",
+     )
+
+ **Loading a model from the Hub:**
+
+ .. code::
+     from braindecode.models import DeepSleepNet
+
+     # Load pretrained model
+     model = DeepSleepNet.from_pretrained("username/my-deepsleepnet-model")
+
+     # Load with a different number of outputs (head is rebuilt automatically)
+     model = DeepSleepNet.from_pretrained("username/my-deepsleepnet-model", n_outputs=4)
+
+ **Extracting features and replacing the head:**
+
+ .. code::
+     import torch
+
+     x = torch.randn(1, model.n_chans, model.n_times)
+     # Extract encoder features (consistent dict across all models)
+     out = model(x, return_features=True)
+     features = out["features"]
+
+     # Replace the classification head
+     model.reset_head(n_outputs=10)
+
+ **Saving and restoring full configuration:**
+
+ .. code::
+     import json
+
+     config = model.get_config()  # all __init__ params
+     with open("config.json", "w") as f:
+         json.dump(config, f)
+
+     model2 = DeepSleepNet.from_config(config)  # reconstruct (no weights)
+
+ All model parameters (both EEG-specific and model-specific such as
+ dropout rates, activation functions, number of filters) are automatically
+ saved to the Hub and restored when loading.
+
+ See :ref:`load-pretrained-models` for a complete tutorial.</main>
+ </div>
+
+ ## Citation
+
+ Please cite both the original paper for this architecture (see the
+ *References* section above) and braindecode:
+
+ ```bibtex
+ @article{aristimunha2025braindecode,
+   title = {Braindecode: a deep learning library for raw electrophysiological data},
+   author = {Aristimunha, Bruno and others},
+   journal = {Zenodo},
+   year = {2025},
+   doi = {10.5281/zenodo.17699192},
+ }
+ ```
+
+ ## License
+
+ BSD-3-Clause for the model code (matching braindecode).
+ Pretraining-derived weights, if you fine-tune from a checkpoint,
+ inherit the licence of that checkpoint and its training corpus.