stroke-classification / src /streamlit_app.py
melisklc0's picture
refactor: Update branding and layout in Streamlit app
9ab05dd
import os
from pathlib import Path
import streamlit as st
from PIL import Image
from model_utils import load_stroke_model, predict
_STATIC_DIR = Path(__file__).resolve().parent / "static"
_APP_CSS_PATH = _STATIC_DIR / "app.css"
# KD EfficientNet-B0 · 3-fold CV mean (thesis/results/kd/KD_Efficientnet_b0)
_MODEL_METRICS = {
"accuracy": 98.0,
"precision": 99.5,
"recall": 96.4,
"f1": 98.0,
}
# Page-level settings; issued before any other Streamlit element is rendered.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Stroke Classification | Clinical Support",
    page_icon="🩺",
)
@st.cache_data
def _load_app_css() -> str:
    """Return the contents of the app stylesheet (``_APP_CSS_PATH``).

    The result is memoised by ``st.cache_data``.

    Raises:
        FileNotFoundError: if ``_APP_CSS_PATH`` is not an existing file.
    """
    css_file = _APP_CSS_PATH
    if css_file.is_file():
        return css_file.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Stylesheet not found: {css_file}")
def inject_app_styles() -> None:
    """Embed the app stylesheet in the page as an inline ``<style>`` element."""
    style_tag = "<style>" + _load_app_css() + "</style>"
    st.markdown(style_tag, unsafe_allow_html=True)
@st.cache_resource(show_spinner=False)
def get_model():
    """Load the stroke model bundle via ``load_stroke_model``.

    Wrapped in ``st.cache_resource`` so the loaded object is reused across
    reruns instead of being reloaded each time; the default spinner is
    disabled because the UI shows its own loading hint.
    """
    bundle = load_stroke_model()
    return bundle
def _open_rgb(source):
    """Open *source* with PIL and return an independent RGB copy, or ``None``.

    Shows an ``st.error`` and returns ``None`` when PIL cannot read the data
    (``PIL.UnidentifiedImageError``, truncated files and missing files are all
    ``OSError`` subclasses).
    """
    try:
        # The context manager releases any file handle PIL opened for a path;
        # convert() loads the pixel data, so the returned image outlives it.
        with Image.open(source) as img:
            return img.convert("RGB")
    except OSError as err:
        st.error(f"Could not read the image: {err}")
        return None


def _locate_sample(sample_path):
    """Return an existing file path for *sample_path*, or ``None``.

    The path is tried as given first (relative to the working directory, as
    before). If it is relative and not found, it is retried relative to this
    script's directory, so samples stored beside the script are still found
    when ``streamlit run`` is launched from another directory.
    """
    candidate = Path(sample_path)
    if candidate.is_file():
        return candidate
    if not candidate.is_absolute():
        beside_script = Path(__file__).resolve().parent / candidate
        if beside_script.is_file():
            return beside_script
    return None


def _resolve_input_image(file_source):
    """Choose the image to analyse: the upload if present, else the selected sample.

    Args:
        file_source: Value returned by ``st.file_uploader`` (``None`` when
            nothing is uploaded).

    Returns:
        ``(image, kind)`` where ``image`` is an RGB PIL image and ``kind`` is
        ``"upload"`` or ``"sample"``; ``(None, None)`` when there is nothing
        readable to show.
    """
    if file_source is not None:
        # An upload always wins; drop any previously selected sample.
        st.session_state.pop("sample_path", None)
        image = _open_rgb(file_source)
        return (image, "upload") if image is not None else (None, None)

    sample_path = st.session_state.get("sample_path")
    if not sample_path:
        return None, None

    located = _locate_sample(sample_path)
    if located is None:
        # Previously discarded silently, which made the sample button look broken.
        st.warning(f"Sample image not found: {sample_path}")
        image = None
    else:
        image = _open_rgb(located)

    if image is None:
        # Forget a missing/unreadable sample so it is not retried on every rerun.
        st.session_state.pop("sample_path", None)
        return None, None
    return image, "sample"
def _render_probability_bars(results: dict[str, float]) -> None:
    """Render one horizontal bar per class showing its predicted probability.

    Args:
        results: Mapping from class name ("No-Stroke", "Stroke") to
            probability. Missing classes are drawn as 0%. Values are
            presumably in [0, 1] (output of ``predict``) — TODO confirm.
    """
    colors = {"No-Stroke": "#15803d", "Stroke": "#b91c1c"}
    rows = []
    for cls, color in colors.items():
        prob = results.get(cls, 0.0)
        # Clamp only the bar width so an out-of-range score cannot overflow the
        # track; the label still reports the raw value.
        width = min(max(prob * 100, 0.0), 100.0)
        # Name and percentage go in separate spans with a space between them;
        # previously they rendered glued together, e.g. "No-Stroke98.0%".
        rows.append(
            '<div class="prob-row">'
            f'<p class="prob-row__label"><span>{cls}</span> <span>{prob:.1%}</span></p>'
            '<div class="prob-track">'
            f'<div class="prob-fill" style="width:{width:.1f}%;background:{color};"></div>'
            "</div></div>"
        )
    st.html(f'<div class="prob-chart">{"".join(rows)}</div>')
def _sidebar_specs_html() -> str:
    """Build the sidebar model-spec card as an HTML ``<dl>`` block.

    The "Metrics" entry is a 2x2-style grid with one cell per value in
    ``_MODEL_METRICS``, each formatted as a percentage with one decimal.
    """
    metric_cells = "\n".join(
        '<div class="metrics-grid__cell">\n'
        f'<span class="metrics-grid__label">{label}</span>\n'
        f'<span class="metrics-grid__value">{_MODEL_METRICS[key]:.1f}%</span>\n'
        "</div>"
        for label, key in (
            ("Accuracy", "accuracy"),
            ("Precision", "precision"),
            ("Recall", "recall"),
            ("F1", "f1"),
        )
    )
    return (
        "\n"
        '<dl class="info-card">\n'
        "<dt>Model</dt>\n"
        "<dd>EfficientNet-B0 (Distilled)</dd>\n"
        "<dt>Metrics</dt>\n"
        '<dd class="metrics-dd">\n'
        '<div class="metrics-grid" aria-label="Model metrics">\n'
        f"{metric_cells}\n"
        "</div>\n"
        "</dd>\n"
        "<dt>Training data</dt>\n"
        "<dd>MOH Turkey (15k Augmented Scans)</dd>\n"
        "<dt>External validation</dt>\n"
        "<dd>Kaggle hold-out set</dd>\n"
        "</dl>\n"
    )
def _model_loading_html() -> str:
return """
<div class="hint-box hint-box--loading">
<span class="hint-box__icon">⏳</span>
<span>
<strong>Loading model</strong>
<span class="hint-box__sub">Downloading weights from Hugging Face…</span>
</span>
</div>
"""
def _model_ready_html() -> str:
return """
<div class="hint-box hint-box--ready">
<span class="hint-box__icon">✅</span>
<span><strong>Model ready</strong> — waiting for a scan</span>
</div>
"""
def _ensure_model_loaded(status_slot):
    """Return the model bundle, loading it on first use in this session.

    The bundle is kept in ``st.session_state.model_bundle``; ``get_model`` is
    called only when that key is missing or ``None``. The caller unpacks the
    result as ``(model, transform)``.

    Args:
        status_slot: An ``st.empty()`` placeholder that shows the loading /
            ready hint, or the error message if loading fails.

    Returns:
        Whatever ``get_model()`` returned.

    Halts the script run via ``st.stop()`` if loading raises.
    """
    bundle = st.session_state.get("model_bundle")
    if bundle is not None:
        status_slot.markdown(_model_ready_html(), unsafe_allow_html=True)
        return bundle

    status_slot.markdown(_model_loading_html(), unsafe_allow_html=True)
    try:
        bundle = get_model()
    except Exception as e:  # UI boundary: surface any load failure to the user
        # Render the error *into* the slot so the "Loading model" hint does not
        # linger beside it.
        status_slot.error(f"Model could not be loaded: {e}")
        st.stop()

    st.session_state.model_bundle = bundle
    status_slot.markdown(_model_ready_html(), unsafe_allow_html=True)
    return bundle
inject_app_styles()

# --- Sidebar ---
# (button label, sample image path) pairs, rendered top to bottom.
_SAMPLE_BUTTONS = (
    ("Stroke", "assets/sample_stroke.png"),
    ("No Stroke", "assets/sample_no_stroke.png"),
)

with st.sidebar:
    st.markdown(
        "\n"
        '<div class="brand-block">\n'
        '<div class="brand-mark">SC</div>\n'
        "<div>\n"
        '<div class="brand-title">Stroke Classification</div>\n'
        '<div class="brand-sub">AI-assisted CT review powered by knowledge-distilled EfficientNet-B0</div>\n'
        "</div>\n"
        "</div>\n",
        unsafe_allow_html=True,
    )
    st.markdown(_sidebar_specs_html(), unsafe_allow_html=True)

    st.markdown("**Validation Samples**")
    st.caption("Test cases from the external dataset.")
    # A click stores the sample path; the main panel picks it up on this run.
    for button_label, button_target in _SAMPLE_BUTTONS:
        if st.button(button_label, use_container_width=True):
            st.session_state.sample_path = button_target

    st.markdown(
        "\n"
        '<div class="sidebar-disclaimer">\n'
        '<p class="sidebar-disclaimer__label">Disclaimer</p>\n'
        "<p>For research and decision support only — not a standalone diagnostic device.\n"
        "A qualified clinician must interpret all findings.</p>\n"
        "</div>\n",
        unsafe_allow_html=True,
    )
# Visible gap above the CT scan / Analysis row (Streamlit main chrome eats plain padding)
st.markdown('<div class="main-cards-top-gap" aria-hidden="true"></div>', unsafe_allow_html=True)

scan_col, result_col = st.columns([1, 1], gap="large")

# Left card: upload widget plus a preview of whichever image will be analysed.
with scan_col, st.container(border=True):
    st.subheader("CT scan")
    file_source = st.file_uploader(
        "Upload a non-contrast or contrast-enhanced axial slice (PNG, JPG).",
        type=["png", "jpg", "jpeg"],
        label_visibility="collapsed",
    )
    input_image, source_kind = _resolve_input_image(file_source)

    if input_image is None:
        st.markdown(
            "\n"
            '<div class="welcome-panel">\n'
            '<span class="welcome-panel__title">Get started</span>\n'
            '<ul class="welcome-panel__list">\n'
            "<li>Upload a CT slice (PNG or JPG)</li>\n"
            "<li>Pick a sample from the sidebar</li>\n"
            "</ul>\n"
            "</div>\n",
            unsafe_allow_html=True,
        )
    else:
        if source_kind == "upload":
            caption = "Uploaded scan"
        else:
            caption = "Kaggle hold-out sample"
        st.image(input_image, caption=caption, width="stretch")
# Right card: model status, then the verdict and probabilities once an image exists.
with result_col, st.container(border=True):
    st.subheader("Analysis")
    model_status = st.empty()
    model, transform = _ensure_model_loaded(model_status)

    if input_image is not None:
        model_status.empty()
        with st.spinner("Running inference…"):
            prediction, confidence, results = predict(model, transform, input_image)

        is_stroke = prediction == "Stroke"
        if is_stroke:
            verdict_class, verdict_text = "stroke", "Stroke detected"
            note = "Pattern indicates hemorrhage or ischemia. Clinical review required."
        else:
            verdict_class, verdict_text = "normal", "No stroke detected"
            note = "No stroke pattern detected. Clinical review required."

        st.markdown(
            "\n"
            f'<div class="verdict-box {verdict_class}">\n'
            '<div class="verdict-label">Classification</div>\n'
            f'<div class="verdict-value">{verdict_text.upper()}</div>\n'
            "</div>\n",
            unsafe_allow_html=True,
        )
        st.metric("Model confidence", f"{confidence:.1%}")
        st.markdown("**Class probabilities**")
        _render_probability_bars(results)
        st.markdown(f'<div class="alert-box alert-box--muted">{note}</div>', unsafe_allow_html=True)
    else:
        st.markdown(
            '<p class="empty-state">Results will appear here after you upload an image or select a sample.</p>',
            unsafe_allow_html=True,
        )

# Page footer.
st.markdown(
    "\n"
    '<p class="footer-note">\n'
    "Stroke Classification System · Melis Kılıç &amp; Esra Koç<br>\n"
    "ONNX inference · Streamlit\n"
    "</p>\n",
    unsafe_allow_html=True,
)