"""Braindecode Model Explorer — interactive architecture browser. Hugging Face Space that browses every EEG architecture in braindecode. For each model: rendered docstring (figure, references, parameter list) plus live instantiation to inspect param count and layer summary. No pretrained weights are loaded — this is a pure architecture browser. Aesthetic: editorial scientific instrument. IBM Plex (Sans / Serif / Mono), Okabe-Ito colorblind-safe palette, warm-paper background. All visual styling lives in GLOBAL_CSS below; the docstring renderer emits structural HTML only. """ from __future__ import annotations import inspect from typing import Any import gradio as gr import torch from torchinfo import summary import braindecode.models as M from braindecode.models.base import EEGModuleMixin from docstring_renderer import ( get_signature_str, get_source_link, render_docstring_html, ) # --------------------------------------------------------------------------- # Catalog: discover every EEGModuleMixin subclass exported by braindecode. # --------------------------------------------------------------------------- def _discover_models() -> dict[str, type]: catalog: dict[str, type] = {} for name in sorted(getattr(M, "__all__", []) or dir(M)): if name.startswith("_"): continue obj = getattr(M, name, None) if ( inspect.isclass(obj) and issubclass(obj, EEGModuleMixin) and obj is not EEGModuleMixin ): catalog[name] = obj return catalog MODELS: dict[str, type] = _discover_models() MODEL_NAMES: list[str] = sorted(MODELS.keys()) DEFAULT_MODEL = "EEGNetv4" if "EEGNetv4" in MODELS else MODEL_NAMES[0] try: import braindecode as _bd BD_VERSION = getattr(_bd, "__version__", "unknown") except Exception: BD_VERSION = "unknown" # --------------------------------------------------------------------------- # Heuristic defaults for the signal-shape form. Different model families # expect very different inputs (sleep stagers want 30 s @ 100 Hz; motor- # imagery models want ~4 s @ 250 Hz). 
# --------------------------------------------------------------------------- DEFAULTS = { "sleep": dict(n_chans=2, sfreq=100, input_window_seconds=30.0, n_outputs=5), "biot": dict(n_chans=16, sfreq=200, input_window_seconds=10.0, n_outputs=2), "bendr": dict(n_chans=20, sfreq=256, input_window_seconds=4.0, n_outputs=2), "labram": dict(n_chans=22, sfreq=200, input_window_seconds=4.0, n_outputs=2), "default": dict(n_chans=22, sfreq=250, input_window_seconds=4.0, n_outputs=4), } def _defaults_for(name: str) -> dict[str, Any]: lower = name.lower() if "sleep" in lower or name in {"USleep", "AttnSleep", "DeepSleepNet"}: return DEFAULTS["sleep"] if "biot" in lower: return DEFAULTS["biot"] if "bendr" in lower: return DEFAULTS["bendr"] if "labram" in lower or "cbramod" in lower or "eegpt" in lower: return DEFAULTS["labram"] return DEFAULTS["default"] # --------------------------------------------------------------------------- # Global stylesheet — IBM Plex + Okabe-Ito + spatial system. Injected # once via gr.Blocks(css=...). 
# ---------------------------------------------------------------------------
GLOBAL_CSS = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:wght@400;500;600;700&family=IBM+Plex+Serif:wght@500;600&family=IBM+Plex+Mono:wght@400;500;600&display=swap');

:root {
  --bd-blue: #0072B2;
  --bd-green: #009E73;
  --bd-orange: #D55E00;
  --bd-pink: #CC79A7;
  --bd-yellow: #E69F00;
  --bd-skyblue: #56B4E9;
  --bd-paper: #FAFAF7;
  --bd-paper-deep: #F1EFE8;
  --bd-rule: #E5E2D9;
  --bd-ink: #1a1a1a;
  --bd-meta: #6b6b6b;
}

/* Container & background ---------------------------------------------- */
body, .gradio-container {
  background: var(--bd-paper) !important;
  font-family: 'IBM Plex Sans', system-ui, sans-serif !important;
  color: var(--bd-ink);
}
.gradio-container { max-width: 1320px !important; padding: 0 24px !important; }
.gradio-container * { font-family: inherit; }

/* Header band --------------------------------------------------------- */
.bd-header {
  display: flex;
  align-items: baseline;
  justify-content: space-between;
  padding: 22px 0 18px 0;
  border-bottom: 1px solid var(--bd-rule);
  margin-bottom: 28px;
  flex-wrap: wrap;
  gap: 12px;
}
.bd-header-title {
  font-family: 'IBM Plex Serif', serif;
  font-size: 26px;
  font-weight: 600;
  color: var(--bd-ink);
  letter-spacing: -0.015em;
}
.bd-header-title .bd-mark { color: var(--bd-blue); font-weight: 500; font-style: italic; }
.bd-header-meta {
  font-family: 'IBM Plex Mono', monospace;
  font-size: 12px;
  color: var(--bd-meta);
  letter-spacing: 0.04em;
  text-transform: uppercase;
}
.bd-header-meta .bd-dot { color: var(--bd-blue); margin: 0 8px; }

/* Info card (model display) ------------------------------------------- */
.bd-info { margin: 0 0 24px 0; padding-bottom: 18px; border-bottom: 2px solid var(--bd-blue); }
.bd-display {
  font-family: 'IBM Plex Serif', serif;
  font-size: 36px;
  font-weight: 600;
  color: var(--bd-blue);
  letter-spacing: -0.02em;
  line-height: 1.1;
  margin: 0 0 6px 0;
}
.bd-tagline {
  font-family: 'IBM Plex Sans', sans-serif;
  font-size: 14px;
  color: var(--bd-meta);
  margin-bottom: 14px;
  font-style: italic;
}
.bd-sig {
  font-family: 'IBM Plex Mono', monospace;
  font-size: 13px;
  line-height: 1.55;
  white-space: pre;
  overflow-x: auto;
  padding: 10px 14px;
  background: var(--bd-paper-deep);
  border-left: 2px solid var(--bd-blue);
  color: #2a2a2a;
  margin: 0 0 10px 0;
}
.bd-sig::-webkit-scrollbar { height: 6px; }
.bd-sig::-webkit-scrollbar-thumb { background: var(--bd-rule); border-radius: 3px; }
.bd-sig::-webkit-scrollbar-thumb:hover { background: var(--bd-blue); }
.bd-source {
  display: inline-block;
  color: var(--bd-meta);
  font-size: 13px;
  text-decoration: none;
  border-bottom: 1px solid transparent;
  transition: all 0.15s ease;
}
.bd-source:hover { color: var(--bd-blue); border-bottom-color: var(--bd-blue); }

/* Stat tile (live param count) ---------------------------------------- */
.bd-stat-card {
  background: var(--bd-paper-deep);
  border: 1px solid var(--bd-rule);
  border-radius: 4px;
  padding: 14px 16px;
  margin-top: 12px;
}
.bd-meta-label {
  font-family: 'IBM Plex Sans', sans-serif;
  font-size: 11px;
  font-weight: 600;
  letter-spacing: 0.1em;
  text-transform: uppercase;
  color: var(--bd-meta);
  margin: 0 0 4px 0;
}
.bd-stat {
  font-family: 'IBM Plex Mono', monospace;
  font-size: 28px;
  font-weight: 600;
  font-variant-numeric: tabular-nums;
  color: var(--bd-blue);
  line-height: 1;
}
.bd-stat-sub {
  font-family: 'IBM Plex Mono', monospace;
  font-size: 11px;
  color: var(--bd-meta);
  margin-top: 6px;
  letter-spacing: 0.02em;
}

/* Section heading separator ------------------------------------------- */
.bd-section-rule { display: flex; align-items: center; gap: 12px; margin: 28px 0 14px 0; }
.bd-section-rule::before, .bd-section-rule::after {
  content: "";
  flex: 1;
  height: 1px;
  background: var(--bd-rule);
}
.bd-section-rule span {
  font-family: 'IBM Plex Sans', sans-serif;
  font-size: 11px;
  font-weight: 600;
  letter-spacing: 0.14em;
  text-transform: uppercase;
  color: var(--bd-meta);
}

/* Docstring rendering (consumed by render_docstring_html) ------------- */
.bd-doc {
  font-family: 'IBM Plex Sans', sans-serif;
  font-size: 16px;
  line-height: 1.65;
  color: var(--bd-ink);
}
.bd-doc p, .bd-doc li { font-size: 16px; margin: 8px 0; }
.bd-doc h1, .bd-doc h2, .bd-doc h3 {
  font-family: 'IBM Plex Serif', serif;
  color: var(--bd-blue);
  margin-top: 1.4em;
  margin-bottom: 0.45em;
  letter-spacing: -0.01em;
}
.bd-doc h1 { font-size: 24px; font-weight: 600; }
.bd-doc h2 { font-size: 20px; font-weight: 600; }
.bd-doc h3 { font-size: 17px; font-weight: 600; font-family: 'IBM Plex Sans', sans-serif; }
.bd-doc pre {
  background: var(--bd-paper-deep);
  padding: 12px 14px;
  border-radius: 4px;
  font-family: 'IBM Plex Mono', monospace;
  font-size: 13px;
  line-height: 1.55;
  overflow-x: auto;
  border-left: 2px solid var(--bd-blue);
}
.bd-doc code {
  background: rgba(0, 114, 178, 0.08);
  padding: 1px 6px;
  border-radius: 3px;
  font-family: 'IBM Plex Mono', monospace;
  font-size: 14px;
  color: #0a4d77;
}
.bd-doc pre code { background: transparent; padding: 0; color: inherit; font-size: inherit; }
.bd-doc img {
  max-width: 100%;
  display: block;
  margin: 18px auto;
  border-radius: 4px;
  box-shadow: 0 2px 14px rgba(0, 114, 178, 0.10);
}
.bd-doc table {
  border-collapse: collapse;
  margin: 14px 0;
  font-size: 14px;
  font-variant-numeric: tabular-nums;
  width: 100%;
}
.bd-doc th, .bd-doc td {
  border: 1px solid var(--bd-rule);
  padding: 7px 12px;
  text-align: left;
  vertical-align: top;
}
.bd-doc th {
  background: var(--bd-paper-deep);
  font-weight: 600;
  color: var(--bd-meta);
  font-size: 12px;
  letter-spacing: 0.04em;
  text-transform: uppercase;
}
.bd-doc .admonition {
  border-left: 3px solid var(--bd-blue);
  background: rgba(0, 114, 178, 0.05);
  padding: 10px 16px;
  margin: 16px 0;
  border-radius: 0 4px 4px 0;
  font-size: 15px;
}
.bd-doc .admonition.important { border-color: var(--bd-orange); background: rgba(213, 94, 0, 0.05); }
.bd-doc .admonition.note { border-color: var(--bd-green); background: rgba(0, 158, 115, 0.05); }
.bd-doc .admonition-title { font-weight: 600; margin-bottom: 4px; }
.bd-doc dl.field-list {
  display: grid;
  grid-template-columns: max-content auto;
  gap: 6px 16px;
  font-size: 15px;
  margin: 12px 0;
}
.bd-doc dl.field-list dt {
  font-weight: 600;
  color: var(--bd-meta);
  font-size: 13px;
  letter-spacing: 0.03em;
  text-transform: uppercase;
  padding-top: 2px;
}
.bd-doc a {
  color: var(--bd-blue);
  text-decoration: none;
  border-bottom: 1px solid rgba(0, 114, 178, 0.3);
}
.bd-doc a:hover { border-bottom-color: var(--bd-blue); }

/* Inline badge produced by docstring_renderer ------------------------- */
.bd-badge {
  display: inline-block;
  padding: 3px 10px;
  border-radius: 3px;
  color: white;
  font-family: 'IBM Plex Sans', sans-serif;
  font-size: 12px;
  font-weight: 600;
  letter-spacing: 0.02em;
  margin: 0 4px 4px 0;
}

/* Form labels --------------------------------------------------------- */
label > span, .gradio-container .label-wrap span {
  font-family: 'IBM Plex Sans', sans-serif;
  font-size: 12px !important;
  font-weight: 600 !important;
  letter-spacing: 0.06em !important;
  text-transform: uppercase !important;
  color: var(--bd-meta) !important;
}
input[type="number"], textarea, select {
  font-family: 'IBM Plex Mono', monospace !important;
  font-size: 14px !important;
}

/* Primary button ------------------------------------------------------ */
button.primary, button[variant="primary"], .bd-cta {
  background: var(--bd-blue) !important;
  color: white !important;
  font-family: 'IBM Plex Sans', sans-serif !important;
  font-size: 13px !important;
  font-weight: 600 !important;
  letter-spacing: 0.06em !important;
  text-transform: uppercase !important;
  border-radius: 4px !important;
  padding: 11px 18px !important;
  border: none !important;
  transition: background 0.15s ease;
}
button.primary:hover, button[variant="primary"]:hover, .bd-cta:hover {
  background: #005a8c !important;
}

/* Footer -------------------------------------------------------------- */
.bd-footer {
  margin: 40px 0 20px 0;
  padding-top: 18px;
  border-top: 1px solid var(--bd-rule);
  font-family: 'IBM Plex Sans', sans-serif;
  font-size: 12px;
  color: var(--bd-meta);
  letter-spacing: 0.04em;
  display: flex;
  justify-content: space-between;
  flex-wrap: wrap;
  gap: 8px;
}
.bd-footer a { color: var(--bd-blue); text-decoration: none; }
.bd-footer a:hover { text-decoration: underline; }
"""

# ---------------------------------------------------------------------------
# HTML fragments. The class names emitted here are the ones styled in
# GLOBAL_CSS above — keep the two in sync.
# ---------------------------------------------------------------------------
import html as _html
import re as _re

# rST citation references such as ``[Lawhern2018]_`` or ``[EEG-ITNet]_``.
# They are noise in a one-line tagline. The label may contain any
# non-space, non-bracket characters, not just word characters.
_CITE_MARKER_RE = _re.compile(r"\[[^\[\]\s]+\]_")


def _tagline(cls: type, max_len: int = 200) -> str:
    """Return the HTML-escaped first docstring line of *cls* ("" if none).

    rST citation markers are stripped, and the text is cut to *max_len*
    characters before it is escaped.
    """
    lines = (cls.__doc__ or "").strip().splitlines()
    if not lines:
        return ""
    first = _CITE_MARKER_RE.sub("", lines[0].strip()).strip()
    return _html.escape(first[:max_len])


def _info_card(name: str) -> str:
    """Primary visual anchor — display name, tagline, signature, source link."""
    cls = MODELS[name]
    sig = _html.escape(get_signature_str(cls))
    # The link is placed inside an href attribute, so quotes must be escaped too.
    link = _html.escape(get_source_link(cls) or "#", quote=True)
    tagline = _tagline(cls)
    tagline_html = f'<div class="bd-tagline">{tagline}</div>' if tagline else ""
    return (
        '<div class="bd-info">'
        f'<div class="bd-display">{_html.escape(name)}</div>'
        f"{tagline_html}"
        f'<div class="bd-sig">{sig}</div>'
        f'<a class="bd-source" href="{link}" target="_blank" '
        'rel="noopener noreferrer">↗ Source on GitHub</a>'
        "</div>"
    )


def _stat_tile(
    params: int | None = None,
    *,
    n_chans: int | None = None,
    n_times: int | None = None,
    out_shape=None,
) -> str:
    """Live parameter count + input/output shapes.

    With ``params=None`` this renders the placeholder shown before a build.
    ``out_shape`` is either a shape tuple or free text, such as a forward-pass
    error message. Free text is HTML-escaped because exception messages can
    contain ``<`` / ``>``.
    """
    if params is None:
        return (
            '<div class="bd-stat-card">'
            '<div class="bd-meta-label">Parameters</div>'
            '<div class="bd-stat">&mdash;</div>'
            '<div class="bd-stat-sub">press build to instantiate</div>'
            "</div>"
        )
    if isinstance(out_shape, tuple):
        pretty_out = f"({', '.join(str(d) for d in out_shape)})"
    else:
        pretty_out = str(out_shape)
    return (
        '<div class="bd-stat-card">'
        '<div class="bd-meta-label">Parameters</div>'
        f'<div class="bd-stat">{params:,}</div>'
        f'<div class="bd-stat-sub">in (b, {n_chans}, {n_times}) → '
        f"{_html.escape(pretty_out)}</div>"
        "</div>"
    )


def _header_band() -> str:
    """Top band: wordmark on the left, version / catalog size on the right."""
    dot = '<span class="bd-dot">&middot;</span>'
    return (
        '<div class="bd-header">'
        '<div class="bd-header-title">'
        '<span class="bd-mark">braindecode</span> model explorer</div>'
        f'<div class="bd-header-meta">v{BD_VERSION}{dot}{len(MODELS)} architectures'
        f"{dot}no weights</div>"
        "</div>"
    )


def _section_rule(label: str) -> str:
    """Uppercase section label between two hairlines (styled by .bd-section-rule)."""
    return f'<div class="bd-section-rule"><span>{_html.escape(label)}</span></div>'


def _footer() -> str:
    """Footer markup — currently empty.

    NOTE(review): this returns an empty string, so the ``.bd-footer`` rules in
    GLOBAL_CSS style nothing. Restore the intended footer or drop those rules.
    """
    return ""


# ---------------------------------------------------------------------------
# Event handlers
# ---------------------------------------------------------------------------
def show_model(name: str):
    """Refresh the info card, docstring, stat tile and signal form for *name*.

    Returns one value per output wired in build_app():
    ``(info_html, doc_html, stat_html, n_chans, sfreq, window_s, n_outputs)``.
    """
    if name not in MODELS:
        # Cleared or unknown selection: blank the panels and leave the form
        # alone. gr.update() with no arguments is Gradio's "no change" value;
        # a bare {} is not recognised as an update and would be fed to
        # gr.Number as a literal value.
        return "", "", _stat_tile(), gr.update(), gr.update(), gr.update(), gr.update()
    d = _defaults_for(name)
    return (
        _info_card(name),
        render_docstring_html(MODELS[name].__doc__),
        _stat_tile(),  # reset stat tile when switching models
        gr.update(value=d["n_chans"]),
        gr.update(value=d["sfreq"]),
        gr.update(value=d["input_window_seconds"]),
        gr.update(value=d["n_outputs"]),
    )


def _check_signal_form(**fields: Any) -> str | None:
    """Return a Markdown error for unusable signal fields, or None if all are OK.

    A cleared gr.Number can hand the handler ``None``. Passing that, or a
    non-positive value, into the arithmetic in instantiate() would raise a raw
    TypeError or build a degenerate model.
    """
    missing = [key for key, value in fields.items() if value is None]
    if missing:
        return f"⚠️ Fill in every signal field (missing: {', '.join(missing)})."
    non_positive = [key for key, value in fields.items() if value <= 0]
    if non_positive:
        return f"⚠️ Signal fields must be positive (check: {', '.join(non_positive)})."
    return None


def _model_n_times(model: torch.nn.Module, fallback: int) -> int:
    """Return the input length *model* reports via ``n_times``, else *fallback*.

    Preferring the model's own value keeps the probe input consistent with
    however the model derived its length from window × sfreq.
    NOTE(review): this assumes braindecode models expose a positive int
    ``n_times``; any other case falls back silently.
    """
    try:
        value = model.n_times
    except Exception:  # noqa: BLE001 — attribute absent or not inferable
        return fallback
    if isinstance(value, int) and value > 0:
        return value
    return fallback


def instantiate(name, n_chans, sfreq, window_s, n_outputs):
    """Build the model and return ``(stat_html, layer_summary_md)``.

    Only the signal kwargs that ``cls.__init__`` declares are passed. Failures
    in the form, the build, the torchinfo summary, or the forward pass are
    reported in the UI rather than raised.
    """
    if name not in MODELS:
        return _stat_tile(), "Pick a model first."
    form_error = _check_signal_form(
        n_chans=n_chans, sfreq=sfreq, window_s=window_s, n_outputs=n_outputs
    )
    if form_error is not None:
        return _stat_tile(), form_error
    requested_n_times = int(round(window_s * sfreq))
    if requested_n_times < 1:
        return _stat_tile(), "⚠️ window · seconds × sfreq must cover at least one sample."

    cls = MODELS[name]
    kwargs = dict(
        n_chans=int(n_chans),
        sfreq=float(sfreq),
        input_window_seconds=float(window_s),
        n_outputs=int(n_outputs),
    )
    sig_params = set(inspect.signature(cls.__init__).parameters)
    kwargs = {k: v for k, v in kwargs.items() if k in sig_params}
    try:
        model = cls(**kwargs)
    except Exception as exc:  # noqa: BLE001
        err = f"❌ **Failed to instantiate `{name}`** with `{kwargs}`:\n```\n{exc}\n```"
        return _stat_tile(), err

    # Probe in inference mode so the random input neither updates BatchNorm
    # running statistics nor passes through active dropout.
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    n_times = _model_n_times(model, fallback=requested_n_times)

    try:
        info = summary(
            model,
            input_size=(1, int(n_chans), n_times),
            depth=3,
            verbose=0,
            col_names=("output_size", "num_params"),
        )
        summary_str = str(info)
    except Exception as exc:  # noqa: BLE001
        summary_str = f"(torchinfo summary unavailable: {exc})"

    out_shape: Any = "?"
    try:
        x = torch.randn(2, int(n_chans), n_times)
        with torch.no_grad():
            y = model(x)
        out_shape = tuple(y.shape) if hasattr(y, "shape") else type(y).__name__
    except Exception as exc:  # noqa: BLE001
        out_shape = f"forward failed: {exc}"

    stat = _stat_tile(
        params=n_params, n_chans=int(n_chans), n_times=n_times, out_shape=out_shape
    )
    return stat, f"```\n{summary_str}\n```"


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def build_app() -> gr.Blocks:
    """Assemble the two-column Blocks UI and wire its events."""
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        font=[gr.themes.GoogleFont("IBM Plex Sans"), "system-ui", "sans-serif"],
        font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "monospace"],
    )
    # Seed the form with the same preset show_model() applies on selection,
    # so the initial values match the initially selected architecture.
    initial = _defaults_for(DEFAULT_MODEL)

    with gr.Blocks(
        title="Braindecode Model Explorer",
        theme=theme,
        css=GLOBAL_CSS,
    ) as app:
        gr.HTML(_header_band())

        with gr.Row(equal_height=False):
            # ---------- LEFT: controls + stat tile ----------
            with gr.Column(scale=1, min_width=280):
                model_dd = gr.Dropdown(
                    choices=MODEL_NAMES,
                    value=DEFAULT_MODEL,
                    label="Architecture",
                    interactive=True,
                    filterable=True,
                )
                gr.HTML(_section_rule("Signal configuration"))
                with gr.Group():
                    n_chans = gr.Number(
                        value=initial["n_chans"], label="n_chans", precision=0
                    )
                    sfreq = gr.Number(value=initial["sfreq"], label="sfreq · Hz")
                    window_s = gr.Number(
                        value=initial["input_window_seconds"], label="window · seconds"
                    )
                    n_outputs = gr.Number(
                        value=initial["n_outputs"], label="n_outputs", precision=0
                    )
                run_btn = gr.Button(
                    "Build network", variant="primary", elem_classes="bd-cta"
                )
                stat_html = gr.HTML(_stat_tile())

            # ---------- RIGHT: model info + docstring ----------
            with gr.Column(scale=3):
                info_html = gr.HTML(_info_card(DEFAULT_MODEL))
                gr.HTML(_section_rule("Architecture documentation"))
                doc_html = gr.HTML(
                    render_docstring_html(MODELS[DEFAULT_MODEL].__doc__)
                )
                with gr.Accordion("Layer summary (after build)", open=False):
                    summary_md = gr.Markdown(
                        "_Press **Build network** to populate the summary._"
                    )

        gr.HTML(_footer())

        # ---------- wiring ----------
        model_dd.change(
            show_model,
            inputs=model_dd,
            outputs=[info_html, doc_html, stat_html, n_chans, sfreq, window_s, n_outputs],
        )
        run_btn.click(
            instantiate,
            inputs=[model_dd, n_chans, sfreq, window_s, n_outputs],
            outputs=[stat_html, summary_md],
        )

    return app


if __name__ == "__main__":
    # On HF Spaces the sandbox blocks localhost-only binds; expose on 0.0.0.0
    # so the front-door proxy can reach us. Locally this still works fine.
    build_app().launch(server_name="0.0.0.0", server_port=7860)