Spaces:
Running
Running
| import os | |
| from pathlib import Path | |
| import streamlit as st | |
| from PIL import Image | |
| from model_utils import load_stroke_model, predict | |
# Static assets live in a "static" directory next to this script.
_STATIC_DIR = Path(__file__).resolve().with_name("static")
_APP_CSS_PATH = _STATIC_DIR.joinpath("app.css")

# KD EfficientNet-B0 · 3-fold CV mean (thesis/results/kd/KD_Efficientnet_b0).
# Values are percentages; the sidebar prints them with one decimal place.
_MODEL_METRICS = dict(
    accuracy=98.0,
    precision=99.5,
    recall=96.4,
    f1=98.0,
)
# Page-level settings: wide layout, sidebar expanded on load, and the
# title/icon passed as page_title/page_icon.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Stroke Classification | Clinical Support",
    page_icon="🩺",
)
def _load_app_css() -> str:
    """Return the text of the application stylesheet.

    Raises:
        FileNotFoundError: If ``_APP_CSS_PATH`` is not an existing regular file.
    """
    css_path = _APP_CSS_PATH
    if css_path.is_file():
        return css_path.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Stylesheet not found: {css_path}")
def inject_app_styles() -> None:
    """Emit the app stylesheet into the page wrapped in a ``<style>`` element."""
    stylesheet = _load_app_css()
    st.markdown(f"<style>{stylesheet}</style>", unsafe_allow_html=True)
@st.cache_resource(show_spinner=False)
def get_model():
    """Load the stroke-classification model bundle once per server process.

    The only other cache in this file is ``st.session_state``, which is
    per-session, so without this decorator every new browser session would
    call ``load_stroke_model()`` again. ``st.cache_resource`` keeps a single
    shared instance instead. The built-in spinner is disabled because the
    caller renders its own "Loading model" hint.

    Returns:
        Whatever ``load_stroke_model()`` returns; the caller in this file
        unpacks it as ``(model, transform)``.
    """
    # NOTE(review): the cached bundle is shared by all sessions — assumes it is
    # only read during inference, never mutated. Confirm against model_utils.
    return load_stroke_model()
def _resolve_input_image(file_source):
    """Choose the image to analyse; an upload takes precedence over a sample.

    Args:
        file_source: The value returned by the file uploader, or ``None``
            when nothing is uploaded.

    Returns:
        ``(image, kind)`` where ``image`` is an RGB PIL image and ``kind`` is
        ``"upload"`` or ``"sample"``; ``(None, None)`` when there is no usable
        image (nothing selected, sample file missing, or decoding failed).
    """
    if file_source is not None:
        # An upload supersedes any previously selected sidebar sample.
        st.session_state.pop("sample_path", None)
        try:
            return Image.open(file_source).convert("RGB"), "upload"
        except OSError as exc:
            # Pillow reports unidentifiable or truncated image data as OSError
            # subclasses; show a message instead of crashing the page.
            st.error(f"The uploaded file could not be read as an image: {exc}")
            return None, None

    sample_path = st.session_state.get("sample_path")
    if not sample_path:
        return None, None

    if os.path.isfile(sample_path):
        try:
            # convert() returns a new image, so the source file can be
            # released as soon as the block exits.
            with Image.open(sample_path) as raw:
                return raw.convert("RGB"), "sample"
        except OSError as exc:
            st.error(f"The sample image could not be read: {exc}")

    # Missing or unreadable sample: forget it so later reruns start clean.
    st.session_state.pop("sample_path", None)
    return None, None
def _render_probability_bars(results: dict[str, float]) -> None:
    """Render one horizontal bar per class showing its probability.

    Args:
        results: Maps class label to probability. Values are treated as
            fractions (multiplied by 100 for the bar width); a missing label
            is drawn as 0.0.
    """
    bar_colors = {"No-Stroke": "#15803d", "Stroke": "#b91c1c"}

    def _bar_row(label: str) -> str:
        # Markup for a single class: text label plus a proportionally filled track.
        share = results.get(label, 0.0)
        return (
            f'<div class="prob-row">'
            f'<p class="prob-row__label">{label} — {share:.1%}</p>'
            f'<div class="prob-track">'
            f'<div class="prob-fill" style="width:{share * 100:.1f}%;background:{bar_colors[label]};"></div>'
            f"</div></div>"
        )

    markup = "".join(_bar_row(label) for label in ("No-Stroke", "Stroke"))
    st.html(f'<div class="prob-chart">{markup}</div>')
def _sidebar_specs_html() -> str:
    """Build the sidebar info-card markup describing the model.

    The four figures come from ``_MODEL_METRICS`` and are formatted with one
    decimal place followed by a percent sign.
    """
    accuracy, precision, recall, f1 = (
        f"{_MODEL_METRICS[key]:.1f}%"
        for key in ("accuracy", "precision", "recall", "f1")
    )
    return f"""
    <dl class="info-card">
    <dt>Model</dt>
    <dd>EfficientNet-B0 (Distilled)</dd>
    <dt>Metrics</dt>
    <dd class="metrics-dd">
    <div class="metrics-grid" aria-label="Model metrics">
    <div class="metrics-grid__cell">
    <span class="metrics-grid__label">Accuracy</span>
    <span class="metrics-grid__value">{accuracy}</span>
    </div>
    <div class="metrics-grid__cell">
    <span class="metrics-grid__label">Precision</span>
    <span class="metrics-grid__value">{precision}</span>
    </div>
    <div class="metrics-grid__cell">
    <span class="metrics-grid__label">Recall</span>
    <span class="metrics-grid__value">{recall}</span>
    </div>
    <div class="metrics-grid__cell">
    <span class="metrics-grid__label">F1</span>
    <span class="metrics-grid__value">{f1}</span>
    </div>
    </div>
    </dd>
    <dt>Training data</dt>
    <dd>MOH Turkey (15k Augmented Scans)</dd>
    <dt>External validation</dt>
    <dd>Kaggle hold-out set</dd>
    </dl>
    """
def _model_loading_html() -> str:
    """Return the hint-box markup shown while the model is being loaded."""
    markup = """
    <div class="hint-box hint-box--loading">
    <span class="hint-box__icon">⏳</span>
    <span>
    <strong>Loading model</strong>
    <span class="hint-box__sub">Downloading weights from Hugging Face…</span>
    </span>
    </div>
    """
    return markup
def _model_ready_html() -> str:
    """Return the hint-box markup shown once the model is available."""
    markup = """
    <div class="hint-box hint-box--ready">
    <span class="hint-box__icon">✅</span>
    <span><strong>Model ready</strong> — waiting for a scan</span>
    </div>
    """
    return markup
def _ensure_model_loaded(status_slot):
    """Return the model bundle, loading it on first use in this session.

    Args:
        status_slot: Streamlit placeholder that receives the "loading" and
            "ready" hints.

    Returns:
        The bundle produced by ``get_model()``. It is stored in
        ``st.session_state.model_bundle`` so later reruns skip the load.
        If loading fails, an error is shown and the script run is stopped.
    """
    cached = st.session_state.get("model_bundle")
    if cached is not None:
        status_slot.markdown(_model_ready_html(), unsafe_allow_html=True)
        return cached

    status_slot.markdown(_model_loading_html(), unsafe_allow_html=True)
    try:
        bundle = get_model()
    except Exception as exc:  # UI boundary: report any load failure to the user
        # Remove the "Loading model" hint; otherwise it stays on screen next
        # to the error after the run is stopped.
        status_slot.empty()
        st.error(f"Model could not be loaded: {exc}")
        st.stop()

    st.session_state.model_bundle = bundle
    status_slot.markdown(_model_ready_html(), unsafe_allow_html=True)
    return bundle
inject_app_styles()

# Sample scans are resolved relative to this script (the directory that holds
# "static"), matching how the stylesheet is located, so the sample buttons do
# not depend on the process working directory.
# NOTE(review): assumes "assets" sits beside "static" — confirm repo layout.
_ASSETS_DIR = _STATIC_DIR.parent / "assets"

# --- Sidebar ---
with st.sidebar:
    st.markdown(
        """
        <div class="brand-block">
        <div class="brand-mark">SC</div>
        <div>
        <div class="brand-title">Stroke Classification</div>
        <div class="brand-sub">AI-assisted CT review powered by knowledge-distilled EfficientNet-B0</div>
        </div>
        </div>
        """,
        unsafe_allow_html=True,
    )
    st.markdown(_sidebar_specs_html(), unsafe_allow_html=True)

    st.markdown("**Validation Samples**")
    st.caption("Test cases from the external dataset.")
    # width="stretch" replaces use_container_width=True and matches the
    # st.image call elsewhere in this file.
    if st.button("Stroke", width="stretch"):
        st.session_state.sample_path = str(_ASSETS_DIR / "sample_stroke.png")
    if st.button("No Stroke", width="stretch"):
        st.session_state.sample_path = str(_ASSETS_DIR / "sample_no_stroke.png")

    st.markdown(
        """
        <div class="sidebar-disclaimer">
        <p class="sidebar-disclaimer__label">Disclaimer</p>
        <p>For research and decision support only — not a standalone diagnostic device.
        A qualified clinician must interpret all findings.</p>
        </div>
        """,
        unsafe_allow_html=True,
    )
# Visible gap above the CT scan / Analysis row (Streamlit main chrome eats plain padding)
st.markdown(
    '<div class="main-cards-top-gap" aria-hidden="true"></div>',
    unsafe_allow_html=True,
)

# Two equal-width columns: scan preview on the left, analysis on the right.
scan_col, result_col = st.columns([1, 1], gap="large")

with scan_col, st.container(border=True):
    st.subheader("CT scan")
    file_source = st.file_uploader(
        "Upload a non-contrast or contrast-enhanced axial slice (PNG, JPG).",
        type=["png", "jpg", "jpeg"],
        label_visibility="collapsed",
    )
    input_image, source_kind = _resolve_input_image(file_source)

    if input_image is None:
        # Nothing selected yet: explain the two ways to provide a scan.
        st.markdown(
            """
            <div class="welcome-panel">
            <span class="welcome-panel__title">Get started</span>
            <ul class="welcome-panel__list">
            <li>Upload a CT slice (PNG or JPG)</li>
            <li>Pick a sample from the sidebar</li>
            </ul>
            </div>
            """,
            unsafe_allow_html=True,
        )
    else:
        if source_kind == "upload":
            caption = "Uploaded scan"
        else:
            caption = "Kaggle hold-out sample"
        st.image(input_image, caption=caption, width="stretch")
with result_col, st.container(border=True):
    st.subheader("Analysis")
    # The placeholder shows the loading/ready hint and is cleared once there
    # is a scan to analyse.
    model_status = st.empty()
    model, transform = _ensure_model_loaded(model_status)

    if input_image is not None:
        model_status.empty()
        with st.spinner("Running inference…"):
            prediction, confidence, results = predict(model, transform, input_image)

        is_stroke = prediction == "Stroke"
        if is_stroke:
            verdict_class, verdict_text = "stroke", "Stroke detected"
            note = "Pattern indicates hemorrhage or ischemia. Clinical review required."
        else:
            verdict_class, verdict_text = "normal", "No stroke detected"
            note = "No stroke pattern detected. Clinical review required."
        headline = verdict_text.upper()

        st.markdown(
            f"""
            <div class="verdict-box {verdict_class}">
            <div class="verdict-label">Classification</div>
            <div class="verdict-value">{headline}</div>
            </div>
            """,
            unsafe_allow_html=True,
        )
        st.metric("Model confidence", f"{confidence:.1%}")
        st.markdown("**Class probabilities**")
        _render_probability_bars(results)
        st.markdown(
            f'<div class="alert-box alert-box--muted">{note}</div>',
            unsafe_allow_html=True,
        )
    else:
        # No scan yet: show a placeholder message instead of results.
        st.markdown(
            '<p class="empty-state">'
            "Results will appear here after you upload an image or select a sample."
            "</p>",
            unsafe_allow_html=True,
        )
# Footer credit line, rendered last on the page.
_FOOTER_HTML = """
<p class="footer-note">
Stroke Classification System · Melis Kılıç & Esra Koç<br>
ONNX inference · Streamlit
</p>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)