---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- foundation-model
- convolutional
---

# InterpolatedSignalJEPA

Channel-interpolating wrapper around the `SignalJEPA` backbone.

> **Architecture-only repository.** This repo documents the
> `braindecode.models.InterpolatedSignalJEPA` class. **No pretrained weights
> are distributed here**: instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import InterpolatedSignalJEPA

model = InterpolatedSignalJEPA(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults; adjust them
to match your recording.
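
The window length in samples follows directly from these two arguments; as a quick sanity check (assuming the usual braindecode convention of inputs shaped `(batch, n_chans, n_times)` with `n_times = sfreq * input_window_seconds`):

```python
# Samples per window implied by the example arguments above.
sfreq = 250                 # sampling rate (Hz)
input_window_seconds = 4.0  # window length (s)

n_times = int(sfreq * input_window_seconds)
print(n_times)  # 1000: each input window is 1000 samples long
```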

## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.InterpolatedSignalJEPA.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/interpolated.py#L1>
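
The channel interpolation that gives this model its name can be pictured as re-projecting signals from the user's montage onto the backbone's canonical channel set. Below is a conceptual numpy sketch using inverse-distance weights; it is illustrative only and is not necessarily how braindecode's `ChannelInterpolationLayer` is implemented:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 3D sensor positions: 3 user channels projected onto a
# 5-channel canonical montage (coordinates are arbitrary here).
user_locs = rng.standard_normal((3, 3))       # (n_user_chans, 3)
canonical_locs = rng.standard_normal((5, 3))  # (n_canonical_chans, 3)

# Inverse-distance weights: each canonical channel becomes a convex
# combination of the user channels, favouring the nearest ones.
dists = np.linalg.norm(canonical_locs[:, None, :] - user_locs[None, :, :], axis=-1)
weights = 1.0 / (dists + 1e-6)
weights /= weights.sum(axis=1, keepdims=True)  # each row sums to 1

# Interpolate a (n_user_chans, n_times) signal onto the canonical set.
x = rng.standard_normal((3, 1000))
x_canonical = weights @ x
print(x_canonical.shape)  # (5, 1000)
```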

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>Channel-interpolating wrapper around :class:`SignalJEPA`.</p>
<p>Accepts arbitrary user <span class="docutils literal">chs_info</span> and projects them to the
backbone's canonical channel set via
:class:`~braindecode.modules.ChannelInterpolationLayer`.</p>
<p>For all other parameters and behavior see the backbone
documentation reproduced below.</p>
<p>Architecture introduced in signal-JEPA for self-supervised pre-training
(Guetschel, P. et al., 2024) [1]_.</p>
<p><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#d9534f;color:white;font-size:11px;font-weight:600;margin-right:4px;">Foundation Model</span>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;border:1px solid #343a40;color:#343a40;font-size:11px;font-weight:600;margin-right:4px;">Channel</span></p>

This model is not meant for classification but for SSL pre-training.
Its output shape depends on the input shape.
For classification purposes, three variants of this model are available:

* :class:`SignalJEPA_Contextual`
* :class:`SignalJEPA_PostLocal`
* :class:`SignalJEPA_PreLocal`

The classification architectures can either be instantiated from scratch
(random parameters) or from a pre-trained :class:`SignalJEPA` model.

.. versionadded:: 0.9

.. rubric:: Pretrained Weights

Two checkpoint variants are published on HuggingFace:

- ``braindecode/signal-jepa``: full encoder + pre-trained channel embedding
  table (62 rows, one per pre-training channel). Use when your channel
  names are a subset of the pre-training set (``channel_embedding='pretrain_aligned'``).
- ``braindecode/signal-jepa_without-chans``: same encoder, channel
  embedding weights stripped. Use when your channel set differs from
  pre-training; the table is freshly initialized from your channel
  locations (``channel_embedding='scratch'``, the default).

.. important::

   **Pre-trained Weights Available**

   .. code:: python

      from braindecode.models import SignalJEPA

      # Load encoder + pre-trained channel embeddings (62 channels):
      model = SignalJEPA.from_pretrained("braindecode/signal-jepa")

      # Select a subset of the 62 pre-training channels:
      model = SignalJEPA.from_pretrained(
          "braindecode/signal-jepa",
          chs_info=[{"ch_name": "Fp1", "loc": [...]}, {"ch_name": "Cz", "loc": [...]}],
      )

      # Arbitrary channel set (channel embedding re-initialized from your locs):
      model = SignalJEPA.from_pretrained(
          "braindecode/signal-jepa_without-chans",
          chs_info=[{"ch_name": "A", "loc": [...]}, ...],
          strict=False,
      )

   To push your own trained model to the Hub:

   .. code:: python

      model.push_to_hub(
          repo_id="username/my-sjepa-model",
          commit_message="Upload trained SignalJEPA model",
      )

   Requires installing ``braindecode[hub]`` for Hub integration.

.. rubric:: Usage

.. code:: python

   from braindecode.models import SignalJEPA

   model = SignalJEPA(
       chs_info=[{"ch_name": "Fp1", "loc": [...]}, ...],
       input_window_seconds=16.0,
       sfreq=128,
   )

   # eeg_data: torch.Tensor of shape (batch, n_chans, n_times)
   # Forward: (batch, n_chans, n_times) -> (batch, n_chans * n_patches, emb_dim)
   features = model(eeg_data)

.. warning::

   Pre-trained at **128 Hz** on EEG bandpass-filtered between
   **0.5 and 40 Hz** and rescaled by a factor of :math:`10^{6}`
   (volts to microvolts). Apply the same preprocessing to your
   data to match the pre-training distribution.

References
----------
.. [1] Guetschel, P., Moreau, T., & Tangermann, M. (2024).
   S-JEPA: towards seamless cross-dataset transfer through dynamic spatial attention.
   In 9th Graz Brain-Computer Interface Conference. https://doi.org/10.3217/978-3-99161-014-4-003

.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

   pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code:: python

   from braindecode.models import SignalJEPA

   # Train your model
   model = SignalJEPA(n_chans=22, n_outputs=4, n_times=1000)
   # ... training code ...

   # Push to the Hub
   model.push_to_hub(
       repo_id="username/my-signaljepa-model",
       commit_message="Initial model upload",
   )

**Loading a model from the Hub:**

.. code:: python

   from braindecode.models import SignalJEPA

   # Load pretrained model
   model = SignalJEPA.from_pretrained("username/my-signaljepa-model")

   # Load with a different number of outputs (head is rebuilt automatically)
   model = SignalJEPA.from_pretrained("username/my-signaljepa-model", n_outputs=4)

**Extracting features and replacing the head:**

.. code:: python

   import torch

   x = torch.randn(1, model.n_chans, model.n_times)
   # Extract encoder features (consistent dict across all models)
   out = model(x, return_features=True)
   features = out["features"]

   # Replace the classification head
   model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code:: python

   import json

   config = model.get_config()  # all __init__ params
   with open("config.json", "w") as f:
       json.dump(config, f)

   model2 = SignalJEPA.from_config(config)  # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as
dropout rates, activation functions, number of filters) are automatically
saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>
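
The docstring's warning above prescribes 128 Hz data, band-pass filtered to 0.5-40 Hz and rescaled from volts to microvolts. A minimal SciPy sketch of that pipeline, assuming a raw recording at 250 Hz (the filter order and design here are illustrative choices, not taken from the paper):

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt, resample_poly

fs_in, fs_out = 250, 128                     # recording rate -> pre-training rate
x = np.random.randn(22, 10 * fs_in) * 1e-5   # (n_chans, n_times), volts

# Zero-phase 0.5-40 Hz band-pass (4th-order Butterworth, illustrative).
sos = butter(4, [0.5, 40.0], btype="bandpass", fs=fs_in, output="sos")
x = sosfiltfilt(sos, x, axis=-1)

x = resample_poly(x, fs_out, fs_in, axis=-1)  # 250 Hz -> 128 Hz
x = x * 1e6                                   # volts -> microvolts
print(x.shape)  # (22, 1280): 10 s of 22-channel EEG at 128 Hz
```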

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
If you fine-tune from a published checkpoint, the resulting weights
inherit the license of that checkpoint and its training corpus.