import os from pathlib import Path import streamlit as st from PIL import Image from model_utils import load_stroke_model, predict _STATIC_DIR = Path(__file__).resolve().parent / "static" _APP_CSS_PATH = _STATIC_DIR / "app.css" # KD EfficientNet-B0 · 3-fold CV mean (thesis/results/kd/KD_Efficientnet_b0) _MODEL_METRICS = { "accuracy": 98.0, "precision": 99.5, "recall": 96.4, "f1": 98.0, } st.set_page_config( page_title="Stroke Classification | Clinical Support", page_icon="🩺", layout="wide", initial_sidebar_state="expanded", ) @st.cache_data def _load_app_css() -> str: if not _APP_CSS_PATH.is_file(): raise FileNotFoundError(f"Stylesheet not found: {_APP_CSS_PATH}") return _APP_CSS_PATH.read_text(encoding="utf-8") def inject_app_styles() -> None: st.markdown(f"", unsafe_allow_html=True) @st.cache_resource(show_spinner=False) def get_model(): return load_stroke_model() def _resolve_input_image(file_source): if file_source is not None: st.session_state.pop("sample_path", None) return Image.open(file_source).convert("RGB"), "upload" sample_path = st.session_state.get("sample_path") if sample_path and os.path.isfile(sample_path): return Image.open(sample_path).convert("RGB"), "sample" if sample_path: st.session_state.pop("sample_path", None) return None, None def _render_probability_bars(results: dict[str, float]) -> None: colors = {"No-Stroke": "#15803d", "Stroke": "#b91c1c"} rows = [] for cls in ("No-Stroke", "Stroke"): prob = results.get(cls, 0.0) pct = prob * 100 color = colors[cls] rows.append( f'
' f'

{cls} — {prob:.1%}

' f'
' f'
' f"
" ) st.html(f'
{"".join(rows)}
') def _sidebar_specs_html() -> str: m = _MODEL_METRICS return f"""
Model
EfficientNet-B0 (Distilled)
Metrics
Accuracy {m["accuracy"]:.1f}%
Precision {m["precision"]:.1f}%
Recall {m["recall"]:.1f}%
F1 {m["f1"]:.1f}%
Training data
MOH Turkey (15k Augmented Scans)
External validation
Kaggle hold-out set
""" def _model_loading_html() -> str: return """
Loading model Downloading weights from Hugging Face…
""" def _model_ready_html() -> str: return """
Model ready — waiting for a scan
""" def _ensure_model_loaded(status_slot): if st.session_state.get("model_bundle") is not None: status_slot.markdown(_model_ready_html(), unsafe_allow_html=True) return st.session_state.model_bundle status_slot.markdown(_model_loading_html(), unsafe_allow_html=True) try: bundle = get_model() except Exception as e: st.error(f"Model could not be loaded: {e}") st.stop() st.session_state.model_bundle = bundle status_slot.markdown(_model_ready_html(), unsafe_allow_html=True) return bundle inject_app_styles() # --- Sidebar --- with st.sidebar: st.markdown( """
SC
Stroke Classification
AI-assisted CT review powered by knowledge-distilled EfficientNet-B0
""", unsafe_allow_html=True, ) st.markdown(_sidebar_specs_html(), unsafe_allow_html=True) st.markdown("**Validation Samples**") st.caption("Test cases from the external dataset.") if st.button("Stroke", use_container_width=True): st.session_state.sample_path = "assets/sample_stroke.png" if st.button("No Stroke", use_container_width=True): st.session_state.sample_path = "assets/sample_no_stroke.png" st.markdown( """ """, unsafe_allow_html=True, ) # Visible gap above the CT scan / Analysis row (Streamlit main chrome eats plain padding) st.markdown( '', unsafe_allow_html=True, ) scan_col, result_col = st.columns([1, 1], gap="large") with scan_col: with st.container(border=True): st.subheader("CT scan") file_source = st.file_uploader( "Upload a non-contrast or contrast-enhanced axial slice (PNG, JPG).", type=["png", "jpg", "jpeg"], label_visibility="collapsed", ) input_image, source_kind = _resolve_input_image(file_source) if input_image is not None: caption = "Uploaded scan" if source_kind == "upload" else "Kaggle hold-out sample" st.image(input_image, caption=caption, width="stretch") else: st.markdown( """
Get started
""", unsafe_allow_html=True, ) with result_col: with st.container(border=True): st.subheader("Analysis") model_status = st.empty() model, transform = _ensure_model_loaded(model_status) if input_image is None: st.markdown( '

' "Results will appear here after you upload an image or select a sample." "

", unsafe_allow_html=True, ) else: model_status.empty() with st.spinner("Running inference…"): prediction, confidence, results = predict(model, transform, input_image) is_stroke = prediction == "Stroke" verdict_class = "stroke" if is_stroke else "normal" verdict_text = "Stroke detected" if is_stroke else "No stroke detected" st.markdown( f"""
Classification
{verdict_text.upper()}
""", unsafe_allow_html=True, ) st.metric("Model confidence", f"{confidence:.1%}") st.markdown("**Class probabilities**") _render_probability_bars(results) note = ( "Pattern indicates hemorrhage or ischemia. Clinical review required." if is_stroke else "No stroke pattern detected. Clinical review required." ) st.markdown( f'
{note}
', unsafe_allow_html=True, ) st.markdown( """ """, unsafe_allow_html=True, )