Spaces:
Sleeping
Sleeping
| """Braindecode Model Explorer — interactive architecture browser. | |
| Hugging Face Space that browses every EEG architecture in braindecode. | |
| For each model: rendered docstring (figure, references, parameter list) | |
| plus live instantiation to inspect param count and layer summary. | |
| No pretrained weights are loaded — this is a pure architecture browser. | |
| Aesthetic: editorial scientific instrument. IBM Plex (Sans / Serif / | |
| Mono), Okabe-Ito colorblind-safe palette, warm-paper background. All | |
| visual styling lives in GLOBAL_CSS below; the docstring renderer emits | |
| structural HTML only. | |
| """ | |
| from __future__ import annotations | |
| import inspect | |
| from typing import Any | |
| import gradio as gr | |
| import torch | |
| from torchinfo import summary | |
| import braindecode.models as M | |
| from braindecode.models.base import EEGModuleMixin | |
| from docstring_renderer import ( | |
| get_signature_str, | |
| get_source_link, | |
| render_docstring_html, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Catalog: discover every EEGModuleMixin subclass exported by braindecode. | |
| # --------------------------------------------------------------------------- | |
def _discover_models() -> dict[str, type]:
    """Map exported names of ``braindecode.models`` to EEGModuleMixin subclasses.

    Candidate names come from the module's ``__all__`` when it is defined and
    non-empty, otherwise from ``dir()``. Underscore-prefixed names and the
    ``EEGModuleMixin`` base itself are skipped. Keys are inserted in sorted
    order.
    """
    exported = getattr(M, "__all__", []) or dir(M)

    def _is_model_class(candidate) -> bool:
        # Must be a concrete subclass of the mixin, not the mixin itself.
        return (
            inspect.isclass(candidate)
            and issubclass(candidate, EEGModuleMixin)
            and candidate is not EEGModuleMixin
        )

    public = (
        (attr, getattr(M, attr, None))
        for attr in sorted(exported)
        if not attr.startswith("_")
    )
    return {attr: obj for attr, obj in public if _is_model_class(obj)}
# Catalog is built once at import time; the UI only reads from it afterwards.
MODELS: dict[str, type] = _discover_models()
MODEL_NAMES: list[str] = sorted(MODELS)
if not MODEL_NAMES:
    # Without this guard the fallback below fails with an opaque IndexError.
    raise RuntimeError(
        "No EEGModuleMixin subclasses were found in braindecode.models; "
        "check the installed braindecode version."
    )
# Land on EEGNetv4 when this braindecode version exports it, else the first
# model alphabetically.
DEFAULT_MODEL = "EEGNetv4" if "EEGNetv4" in MODELS else MODEL_NAMES[0]
try:
    import braindecode as _bd

    BD_VERSION = getattr(_bd, "__version__", "unknown")
except Exception:  # noqa: BLE001 - version is display-only; never block startup
    BD_VERSION = "unknown"
| # --------------------------------------------------------------------------- | |
| # Heuristic defaults for the signal-shape form. Different model families | |
| # expect very different inputs (sleep stagers want 30 s @ 100 Hz; motor- | |
| # imagery models want ~4 s @ 250 Hz). | |
| # --------------------------------------------------------------------------- | |
| DEFAULTS = { | |
| "sleep": dict(n_chans=2, sfreq=100, input_window_seconds=30.0, n_outputs=5), | |
| "biot": dict(n_chans=16, sfreq=200, input_window_seconds=10.0, n_outputs=2), | |
| "bendr": dict(n_chans=20, sfreq=256, input_window_seconds=4.0, n_outputs=2), | |
| "labram": dict(n_chans=22, sfreq=200, input_window_seconds=4.0, n_outputs=2), | |
| "default": dict(n_chans=22, sfreq=250, input_window_seconds=4.0, n_outputs=4), | |
| } | |
| def _defaults_for(name: str) -> dict[str, Any]: | |
| lower = name.lower() | |
| if "sleep" in lower or name in {"USleep", "AttnSleep", "DeepSleepNet"}: | |
| return DEFAULTS["sleep"] | |
| if "biot" in lower: | |
| return DEFAULTS["biot"] | |
| if "bendr" in lower: | |
| return DEFAULTS["bendr"] | |
| if "labram" in lower or "cbramod" in lower or "eegpt" in lower: | |
| return DEFAULTS["labram"] | |
| return DEFAULTS["default"] | |
| # --------------------------------------------------------------------------- | |
| # Global stylesheet — IBM Plex + Okabe-Ito + spatial system. Injected | |
| # once via gr.Blocks(css=...). | |
| # --------------------------------------------------------------------------- | |
# Single stylesheet for the whole app, passed once to gr.Blocks(css=...) in
# build_app. The `.bd-*` class names are the contract with the HTML fragment
# helpers in this file and with docstring_renderer (which emits structural
# HTML only; the `.bd-doc` rules below style its output). The @import pulls in
# IBM Plex Serif (used for titles/headings) alongside Sans/Mono, which the
# gradio theme in build_app also requests. Colours are the Okabe-Ito palette.
GLOBAL_CSS = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;500;600;700&family=IBM+Plex+Serif:wght@500;600&family=IBM+Plex+Mono:wght@400;500;600&display=swap');
:root {
--bd-blue: #0072B2;
--bd-green: #009E73;
--bd-orange: #D55E00;
--bd-pink: #CC79A7;
--bd-yellow: #E69F00;
--bd-skyblue: #56B4E9;
--bd-paper: #FAFAF7;
--bd-paper-deep: #F1EFE8;
--bd-rule: #E5E2D9;
--bd-ink: #1a1a1a;
--bd-meta: #6b6b6b;
}
/* Container & background ---------------------------------------------- */
body, .gradio-container {
background: var(--bd-paper) !important;
font-family: 'IBM Plex Sans', system-ui, sans-serif !important;
color: var(--bd-ink);
}
.gradio-container { max-width: 1320px !important; padding: 0 24px !important; }
.gradio-container * { font-family: inherit; }
/* Header band --------------------------------------------------------- */
.bd-header {
display: flex;
align-items: baseline;
justify-content: space-between;
padding: 22px 0 18px 0;
border-bottom: 1px solid var(--bd-rule);
margin-bottom: 28px;
flex-wrap: wrap;
gap: 12px;
}
.bd-header-title {
font-family: 'IBM Plex Serif', serif;
font-size: 26px;
font-weight: 600;
color: var(--bd-ink);
letter-spacing: -0.015em;
}
.bd-header-title .bd-mark {
color: var(--bd-blue);
font-weight: 500;
font-style: italic;
}
.bd-header-meta {
font-family: 'IBM Plex Mono', monospace;
font-size: 12px;
color: var(--bd-meta);
letter-spacing: 0.04em;
text-transform: uppercase;
}
.bd-header-meta .bd-dot { color: var(--bd-blue); margin: 0 8px; }
/* Info card (model display) ------------------------------------------- */
.bd-info {
margin: 0 0 24px 0;
padding-bottom: 18px;
border-bottom: 2px solid var(--bd-blue);
}
.bd-display {
font-family: 'IBM Plex Serif', serif;
font-size: 36px;
font-weight: 600;
color: var(--bd-blue);
letter-spacing: -0.02em;
line-height: 1.1;
margin: 0 0 6px 0;
}
.bd-tagline {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 14px;
color: var(--bd-meta);
margin-bottom: 14px;
font-style: italic;
}
.bd-sig {
font-family: 'IBM Plex Mono', monospace;
font-size: 13px;
line-height: 1.55;
white-space: pre;
overflow-x: auto;
padding: 10px 14px;
background: var(--bd-paper-deep);
border-left: 2px solid var(--bd-blue);
color: #2a2a2a;
margin: 0 0 10px 0;
}
.bd-sig::-webkit-scrollbar { height: 6px; }
.bd-sig::-webkit-scrollbar-thumb { background: var(--bd-rule); border-radius: 3px; }
.bd-sig::-webkit-scrollbar-thumb:hover { background: var(--bd-blue); }
.bd-source {
display: inline-block;
color: var(--bd-meta);
font-size: 13px;
text-decoration: none;
border-bottom: 1px solid transparent;
transition: all 0.15s ease;
}
.bd-source:hover { color: var(--bd-blue); border-bottom-color: var(--bd-blue); }
/* Stat tile (live param count) ---------------------------------------- */
.bd-stat-card {
background: var(--bd-paper-deep);
border: 1px solid var(--bd-rule);
border-radius: 4px;
padding: 14px 16px;
margin-top: 12px;
}
.bd-meta-label {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 11px;
font-weight: 600;
letter-spacing: 0.1em;
text-transform: uppercase;
color: var(--bd-meta);
margin: 0 0 4px 0;
}
.bd-stat {
font-family: 'IBM Plex Mono', monospace;
font-size: 28px;
font-weight: 600;
font-variant-numeric: tabular-nums;
color: var(--bd-blue);
line-height: 1;
}
.bd-stat-sub {
font-family: 'IBM Plex Mono', monospace;
font-size: 11px;
color: var(--bd-meta);
margin-top: 6px;
letter-spacing: 0.02em;
}
/* Section heading separator ------------------------------------------- */
.bd-section-rule {
display: flex; align-items: center;
gap: 12px;
margin: 28px 0 14px 0;
}
.bd-section-rule::before, .bd-section-rule::after {
content: ""; flex: 1;
height: 1px; background: var(--bd-rule);
}
.bd-section-rule span {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 11px;
font-weight: 600;
letter-spacing: 0.14em;
text-transform: uppercase;
color: var(--bd-meta);
}
/* Docstring rendering (consumed by render_docstring_html) ------------- */
.bd-doc {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 16px;
line-height: 1.65;
color: var(--bd-ink);
}
.bd-doc p, .bd-doc li { font-size: 16px; margin: 8px 0; }
.bd-doc h1, .bd-doc h2, .bd-doc h3 {
font-family: 'IBM Plex Serif', serif;
color: var(--bd-blue);
margin-top: 1.4em;
margin-bottom: 0.45em;
letter-spacing: -0.01em;
}
.bd-doc h1 { font-size: 24px; font-weight: 600; }
.bd-doc h2 { font-size: 20px; font-weight: 600; }
.bd-doc h3 { font-size: 17px; font-weight: 600;
font-family: 'IBM Plex Sans', sans-serif; }
.bd-doc pre {
background: var(--bd-paper-deep);
padding: 12px 14px;
border-radius: 4px;
font-family: 'IBM Plex Mono', monospace;
font-size: 13px;
line-height: 1.55;
overflow-x: auto;
border-left: 2px solid var(--bd-blue);
}
.bd-doc code {
background: rgba(0, 114, 178, 0.08);
padding: 1px 6px;
border-radius: 3px;
font-family: 'IBM Plex Mono', monospace;
font-size: 14px;
color: #0a4d77;
}
.bd-doc pre code { background: transparent; padding: 0; color: inherit; font-size: inherit; }
.bd-doc img {
max-width: 100%;
display: block;
margin: 18px auto;
border-radius: 4px;
box-shadow: 0 2px 14px rgba(0, 114, 178, 0.10);
}
.bd-doc table {
border-collapse: collapse;
margin: 14px 0;
font-size: 14px;
font-variant-numeric: tabular-nums;
width: 100%;
}
.bd-doc th, .bd-doc td {
border: 1px solid var(--bd-rule);
padding: 7px 12px;
text-align: left;
vertical-align: top;
}
.bd-doc th { background: var(--bd-paper-deep); font-weight: 600; color: var(--bd-meta); font-size: 12px; letter-spacing: 0.04em; text-transform: uppercase; }
.bd-doc .admonition {
border-left: 3px solid var(--bd-blue);
background: rgba(0, 114, 178, 0.05);
padding: 10px 16px;
margin: 16px 0;
border-radius: 0 4px 4px 0;
font-size: 15px;
}
.bd-doc .admonition.important { border-color: var(--bd-orange); background: rgba(213, 94, 0, 0.05); }
.bd-doc .admonition.note { border-color: var(--bd-green); background: rgba(0, 158, 115, 0.05); }
.bd-doc .admonition-title { font-weight: 600; margin-bottom: 4px; }
.bd-doc dl.field-list {
display: grid; grid-template-columns: max-content auto;
gap: 6px 16px; font-size: 15px; margin: 12px 0;
}
.bd-doc dl.field-list dt { font-weight: 600; color: var(--bd-meta); font-size: 13px; letter-spacing: 0.03em; text-transform: uppercase; padding-top: 2px; }
.bd-doc a { color: var(--bd-blue); text-decoration: none; border-bottom: 1px solid rgba(0, 114, 178, 0.3); }
.bd-doc a:hover { border-bottom-color: var(--bd-blue); }
/* Inline badge produced by docstring_renderer ------------------------- */
.bd-badge {
display: inline-block;
padding: 3px 10px;
border-radius: 3px;
color: white;
font-family: 'IBM Plex Sans', sans-serif;
font-size: 12px;
font-weight: 600;
letter-spacing: 0.02em;
margin: 0 4px 4px 0;
}
/* Form labels --------------------------------------------------------- */
label > span, .gradio-container .label-wrap span {
font-family: 'IBM Plex Sans', sans-serif;
font-size: 12px !important;
font-weight: 600 !important;
letter-spacing: 0.06em !important;
text-transform: uppercase !important;
color: var(--bd-meta) !important;
}
input[type="number"], textarea, select {
font-family: 'IBM Plex Mono', monospace !important;
font-size: 14px !important;
}
/* Primary button ------------------------------------------------------ */
button.primary, button[variant="primary"], .bd-cta {
background: var(--bd-blue) !important;
color: white !important;
font-family: 'IBM Plex Sans', sans-serif !important;
font-size: 13px !important;
font-weight: 600 !important;
letter-spacing: 0.06em !important;
text-transform: uppercase !important;
border-radius: 4px !important;
padding: 11px 18px !important;
border: none !important;
transition: background 0.15s ease;
}
button.primary:hover, button[variant="primary"]:hover, .bd-cta:hover {
background: #005a8c !important;
}
/* Footer -------------------------------------------------------------- */
.bd-footer {
margin: 40px 0 20px 0;
padding-top: 18px;
border-top: 1px solid var(--bd-rule);
font-family: 'IBM Plex Sans', sans-serif;
font-size: 12px;
color: var(--bd-meta);
letter-spacing: 0.04em;
display: flex;
justify-content: space-between;
flex-wrap: wrap;
gap: 8px;
}
.bd-footer a { color: var(--bd-blue); text-decoration: none; }
.bd-footer a:hover { text-decoration: underline; }
"""
| # --------------------------------------------------------------------------- | |
| # HTML fragments | |
| # --------------------------------------------------------------------------- | |
| import html as _html | |
def _info_card(name: str) -> str:
    """Render the primary model card: name, tagline, signature and source link.

    Parameters
    ----------
    name : str
        Key into ``MODELS``.

    Returns
    -------
    str
        HTML fragment styled by the ``.bd-info`` rules. Every interpolated
        value is HTML-escaped.
    """
    import re as _re

    cls = MODELS[name]
    sig = _html.escape(get_signature_str(cls))
    # Escape the URL too, so a quote in it cannot break out of the href attribute.
    link = _html.escape(get_source_link(cls) or "#")

    doc_lines = (cls.__doc__ or "").strip().splitlines()
    tagline = ""
    if doc_lines:
        # Strip rST citation markers such as [Foo2023]_ from the summary line.
        first = _re.sub(r"\[\w+\]_", "", doc_lines[0].strip()).strip()
        # Truncate before escaping so an HTML entity is never cut in half.
        tagline = _html.escape(first[:200])

    # Built outside the big f-string: a backslash inside an f-string
    # replacement field is a SyntaxError before Python 3.12.
    tagline_html = f'<div class="bd-tagline">{tagline}</div>' if tagline else ""
    return (
        '<div class="bd-info">'
        f' <div class="bd-display">{_html.escape(name)}</div>'
        f" {tagline_html}"
        f' <pre class="bd-sig">{sig}</pre>'
        f' <a class="bd-source" href="{link}" target="_blank"'
        ' rel="noopener noreferrer">↗ Source on GitHub</a>'
        "</div>"
    )
| def _stat_tile(params: int | None = None, *, n_chans: int | None = None, | |
| n_times: int | None = None, out_shape=None) -> str: | |
| """Live parameter count + input/output shapes.""" | |
| if params is None: | |
| return ( | |
| '<div class="bd-stat-card">' | |
| '<div class="bd-meta-label">Parameters</div>' | |
| '<div class="bd-stat" style="color: var(--bd-meta);">—</div>' | |
| '<div class="bd-stat-sub">press build to instantiate</div>' | |
| '</div>' | |
| ) | |
| pretty_out = ( | |
| f"({', '.join(str(d) for d in out_shape)})" | |
| if isinstance(out_shape, tuple) | |
| else str(out_shape) | |
| ) | |
| return ( | |
| '<div class="bd-stat-card">' | |
| '<div class="bd-meta-label">Parameters</div>' | |
| f'<div class="bd-stat">{params:,}</div>' | |
| f'<div class="bd-stat-sub">in (b, {n_chans}, {n_times}) → {pretty_out}</div>' | |
| '</div>' | |
| ) | |
def _header_band() -> str:
    """Return the top banner: title on the left, version / model count / "no weights" on the right."""
    separator = '<span class="bd-dot">•</span>'
    meta = separator.join(
        (f"v{BD_VERSION}", f"{len(MODELS)} architectures", "no weights")
    )
    title = (
        '<div class="bd-header-title">braindecode '
        '<span class="bd-mark">model explorer</span></div>'
    )
    return f'<div class="bd-header">{title}<div class="bd-header-meta">{meta}</div></div>'
| def _section_rule(label: str) -> str: | |
| return f'<div class="bd-section-rule"><span>{label}</span></div>' | |
| def _footer() -> str: | |
| return ( | |
| '<div class="bd-footer">' | |
| '<div>An architecture browser for <a href="https://braindecode.org">braindecode</a>. ' | |
| 'No pretrained weights served here — see ' | |
| '<a href="https://huggingface.co/braindecode">huggingface.co/braindecode</a>.</div>' | |
| '<div><a href="https://github.com/braindecode/braindecode">github.com/braindecode/braindecode</a></div>' | |
| '</div>' | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Event handlers | |
| # --------------------------------------------------------------------------- | |
def show_model(name: str):
    """Refresh the model panel and signal form after a dropdown change.

    Parameters
    ----------
    name : str
        Selected dropdown value. It may not be a key of ``MODELS``, e.g. when
        the dropdown has been cleared.

    Returns
    -------
    tuple
        Seven values, in the order wired in ``build_app``:
        1. info-card HTML;
        2. rendered docstring HTML;
        3. a reset stat tile;
        4-7. updates for the n_chans, sfreq, window and n_outputs fields.
    """
    if name not in MODELS:
        # gr.update() with no arguments leaves a component unchanged. A bare {}
        # is not an update marker and would be treated as the Number's value.
        # A fresh update is built per output so no dict is shared between them.
        return ("", "", _stat_tile(), *(gr.update() for _ in range(4)))
    info = _info_card(name)
    doc_html = render_docstring_html(MODELS[name].__doc__)
    d = _defaults_for(name)
    return (
        info,
        doc_html,
        _stat_tile(),  # numbers from the previously built model would mislead
        gr.update(value=d["n_chans"]),
        gr.update(value=d["sfreq"]),
        gr.update(value=d["input_window_seconds"]),
        gr.update(value=d["n_outputs"]),
    )
def instantiate(name, n_chans, sfreq, window_s, n_outputs):
    """Build the selected model and report its size and layer structure.

    Parameters
    ----------
    name : str
        Key into ``MODELS`` (the dropdown value).
    n_chans, sfreq, window_s, n_outputs : float or None
        Values of the signal-configuration ``gr.Number`` fields. A field the
        user has cleared arrives as ``None``.

    Returns
    -------
    tuple[str, str]
        ``(stat_html, layer_summary_md)``: the stat-tile HTML, and Markdown
        with either the torchinfo summary or an error message.
    """
    if name not in MODELS:
        return _stat_tile(), "Pick a model first."
    if any(v is None for v in (n_chans, sfreq, window_s, n_outputs)):
        # int()/round() below would raise a TypeError on None.
        return _stat_tile(), "❌ Fill in every signal-configuration field first."

    cls = MODELS[name]
    n_chans = int(n_chans)
    n_times = int(round(window_s * sfreq))
    if n_chans < 1 or n_times < 1:
        return _stat_tile(), (
            "❌ Need at least one channel and one sample "
            f"(got n_chans={n_chans}, n_times={n_times})."
        )

    kwargs = dict(
        n_chans=n_chans,
        sfreq=float(sfreq),
        input_window_seconds=float(window_s),
        n_outputs=int(n_outputs),
    )
    # Pass only the arguments this constructor explicitly declares.
    sig_params = set(inspect.signature(cls.__init__).parameters)
    kwargs = {k: v for k, v in kwargs.items() if k in sig_params}
    try:
        model = cls(**kwargs)
    except Exception as exc:  # noqa: BLE001 - surface any constructor error in the UI
        err = f"❌ **Failed to instantiate `{name}`** with `{kwargs}`:\n```\n{exc}\n```"
        return _stat_tile(), err

    # Inference mode: the dummy forward pass below should not run Dropout or
    # update BatchNorm running statistics.
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())

    try:
        info = summary(
            model,
            input_size=(1, n_chans, n_times),
            depth=3,
            verbose=0,
            col_names=("output_size", "num_params"),
        )
        summary_str = str(info)
    except Exception as exc:  # noqa: BLE001 - summary is best-effort
        summary_str = f"(torchinfo summary unavailable: {exc})"

    out_shape: Any = "?"
    try:
        x = torch.randn(2, n_chans, n_times)
        with torch.no_grad():
            y = model(x)
        out_shape = tuple(y.shape) if hasattr(y, "shape") else type(y).__name__
    except Exception as exc:  # noqa: BLE001 - report the failure in the tile
        out_shape = f"forward failed: {exc}"

    stat = _stat_tile(
        params=n_params, n_chans=n_chans, n_times=n_times, out_shape=out_shape
    )
    return stat, f"```\n{summary_str}\n```"
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
def build_app() -> gr.Blocks:
    """Assemble the Gradio UI.

    Layout:
    - Left column: model picker, signal form and stat tile.
    - Right column: info card, rendered docstring and layer-summary accordion.

    Wiring:
    - Changing the dropdown runs ``show_model`` (card, docs, form defaults).
    - The "Build network" button runs ``instantiate`` (stat tile and summary).

    Returns
    -------
    gr.Blocks
        The app, not yet launched.
    """
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        font=[gr.themes.GoogleFont("IBM Plex Sans"), "system-ui", "sans-serif"],
        font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "monospace"],
    )
    # Seed the form from the same heuristic show_model applies on dropdown
    # changes, so the initial values always match DEFAULT_MODEL.
    init = _defaults_for(DEFAULT_MODEL)
    with gr.Blocks(
        title="Braindecode Model Explorer",
        theme=theme,
        css=GLOBAL_CSS,
    ) as app:
        gr.HTML(_header_band())
        with gr.Row(equal_height=False):
            # ---------- LEFT: controls + stat tile ----------
            with gr.Column(scale=1, min_width=280):
                model_dd = gr.Dropdown(
                    choices=MODEL_NAMES,
                    value=DEFAULT_MODEL,
                    label="Architecture",
                    interactive=True,
                    filterable=True,
                )
                gr.HTML(_section_rule("Signal configuration"))
                with gr.Group():
                    n_chans = gr.Number(
                        value=init["n_chans"], label="n_chans", precision=0
                    )
                    sfreq = gr.Number(value=init["sfreq"], label="sfreq · Hz")
                    window_s = gr.Number(
                        value=init["input_window_seconds"], label="window · seconds"
                    )
                    n_outputs = gr.Number(
                        value=init["n_outputs"], label="n_outputs", precision=0
                    )
                run_btn = gr.Button(
                    "Build network", variant="primary", elem_classes="bd-cta"
                )
                stat_html = gr.HTML(_stat_tile())
            # ---------- RIGHT: model info + docstring ----------
            with gr.Column(scale=3):
                info_html = gr.HTML(_info_card(DEFAULT_MODEL))
                gr.HTML(_section_rule("Architecture documentation"))
                doc_html = gr.HTML(
                    render_docstring_html(MODELS[DEFAULT_MODEL].__doc__)
                )
                with gr.Accordion("Layer summary (after build)", open=False):
                    summary_md = gr.Markdown(
                        "_Press **Build network** to populate the summary._"
                    )
        gr.HTML(_footer())
        # ---------- wiring ----------
        model_dd.change(
            show_model,
            inputs=model_dd,
            outputs=[info_html, doc_html, stat_html, n_chans, sfreq, window_s, n_outputs],
        )
        run_btn.click(
            instantiate,
            inputs=[model_dd, n_chans, sfreq, window_s, n_outputs],
            outputs=[stat_html, summary_md],
        )
    return app
if __name__ == "__main__":
    # Bind on all interfaces: the HF Spaces sandbox blocks localhost-only
    # binds, so the front-door proxy could not reach us otherwise. Local runs
    # are unaffected.
    demo = build_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)