import os
from pathlib import Path
import streamlit as st
from PIL import Image
from model_utils import load_stroke_model, predict
# Static assets are located relative to this script file (via __file__), not
# relative to the current working directory.
_STATIC_DIR = Path(__file__).resolve().parent.joinpath("static")
_APP_CSS_PATH = _STATIC_DIR.joinpath("app.css")

# Headline metrics (in percent) for the KD EfficientNet-B0 model:
# mean over 3-fold CV, taken from thesis/results/kd/KD_Efficientnet_b0.
_MODEL_METRICS = dict(
    accuracy=98.0,
    precision=99.5,
    recall=96.4,
    f1=98.0,
)
# Page-level settings: wide layout with the sidebar open on load, plus the
# browser-tab title and icon.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Stroke Classification | Clinical Support",
    page_icon="🩺",
)
@st.cache_data
def _load_app_css() -> str:
    """Return the text of the app stylesheet at ``_APP_CSS_PATH``.

    Memoised with ``st.cache_data`` so reruns reuse the text instead of
    re-reading the file.

    Raises:
        FileNotFoundError: if ``_APP_CSS_PATH`` is not an existing file.
    """
    css_file = _APP_CSS_PATH
    if css_file.is_file():
        return css_file.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Stylesheet not found: {css_file}")
def inject_app_styles() -> None:
    """Inject the app stylesheet into the current page.

    The CSS text from ``_load_app_css`` is wrapped in a ``<style>`` element,
    so the browser applies it instead of displaying it. Previously an empty
    f-string was rendered, so no CSS reached the page.

    Raises:
        FileNotFoundError: propagated from ``_load_app_css`` when the
            stylesheet file is missing.
    """
    st.markdown(f"<style>{_load_app_css()}</style>", unsafe_allow_html=True)
@st.cache_resource(show_spinner=False)
def get_model():
    """Load and return the stroke-model bundle via ``load_stroke_model``.

    Wrapped in ``st.cache_resource`` (without Streamlit's own spinner), so the
    loaded object is reused instead of being reloaded on every rerun. Callers
    unpack the result as ``(model, transform)``.
    """
    bundle = load_stroke_model()
    return bundle
def _resolve_input_image(file_source):
    """Choose the image to analyse: a fresh upload wins over a sidebar sample.

    Args:
        file_source: Value returned by ``st.file_uploader``; ``None`` when
            nothing has been uploaded.

    Returns:
        ``(image, kind)`` where ``image`` is an RGB ``PIL.Image.Image`` and
        ``kind`` is ``"upload"`` or ``"sample"``, or ``(None, None)`` when
        there is nothing to show (including an unreadable upload).
    """
    if file_source is not None:
        # A new upload supersedes any previously selected sample.
        st.session_state.pop("sample_path", None)
        try:
            # The context manager releases Pillow's handle; convert() returns an
            # independent, fully loaded copy, so closing the source is safe.
            with Image.open(file_source) as img:
                return img.convert("RGB"), "upload"
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) or a
            # plain OSError for non-image / truncated files; report instead of
            # crashing the whole page with a traceback.
            st.error(
                "The uploaded file could not be read as an image. "
                "Please upload a PNG or JPG CT slice."
            )
            return None, None

    sample_path = st.session_state.get("sample_path")
    if not sample_path:
        return None, None
    if not os.path.isfile(sample_path):
        # Stale or missing sample file: forget the selection.
        st.session_state.pop("sample_path", None)
        return None, None
    with Image.open(sample_path) as img:
        return img.convert("RGB"), "sample"
def _probability_bars_html(results: dict[str, float]) -> str:
    """Build self-contained HTML for the No-Stroke / Stroke probability bars.

    Args:
        results: Mapping of class label to probability expressed as a fraction
            (e.g. ``0.93``). Only ``"No-Stroke"`` and ``"Stroke"`` are drawn,
            in that order; a missing label is drawn as 0%.

    Returns:
        An HTML fragment with one labelled bar per class. It is styled inline,
        so it renders without relying on any external stylesheet.
    """
    colors = {"No-Stroke": "#15803d", "Stroke": "#b91c1c"}
    rows = []
    for cls in ("No-Stroke", "Stroke"):
        prob = results.get(cls, 0.0)
        # Clamp the bar width so an out-of-range score cannot overflow the track.
        pct = min(max(prob * 100, 0.0), 100.0)
        rows.append(
            '<div style="margin:0.35rem 0 0.65rem">'
            f'<div style="font-size:0.9rem;margin-bottom:0.2rem">{cls} — {prob:.1%}</div>'
            '<div style="background:#e5e7eb;border-radius:999px;height:0.6rem;overflow:hidden">'
            f'<div style="width:{pct:.1f}%;height:100%;background:{colors[cls]}"></div>'
            "</div>"
            "</div>"
        )
    return f'<div>{"".join(rows)}</div>'


def _render_probability_bars(results: dict[str, float]) -> None:
    """Render the per-class probability bars into the current Streamlit container.

    Args:
        results: Mapping of class label to probability (fraction); see
            ``_probability_bars_html``.
    """
    st.html(_probability_bars_html(results))
def _sidebar_specs_html() -> str:
    """Return the sidebar model card: architecture, headline metrics, data sources.

    Metric values are read from ``_MODEL_METRICS`` and printed with one
    decimal place followed by a percent sign.
    """
    acc, prec, rec, f1 = (
        _MODEL_METRICS[key] for key in ("accuracy", "precision", "recall", "f1")
    )
    return f"""
- Model
- EfficientNet-B0 (Distilled)
- Metrics
-
Accuracy
{acc:.1f}%
Precision
{prec:.1f}%
Recall
{rec:.1f}%
F1
{f1:.1f}%
- Training data
- MOH Turkey (15k Augmented Scans)
- External validation
- Kaggle hold-out set
"""
def _model_loading_html() -> str:
    """Status text shown in the model-status slot while the model is loading."""
    message = """
⏳
Loading model
Downloading weights from Hugging Face…
"""
    return message
def _model_ready_html() -> str:
    """Status text shown in the model-status slot once the model is available."""
    message = """
✅
Model ready — waiting for a scan
"""
    return message
def _ensure_model_loaded(status_slot):
    """Return the ``(model, transform)`` bundle, loading it on first use.

    The bundle is kept in ``st.session_state.model_bundle``, so later reruns in
    the same session skip the loader. While loading, ``status_slot`` shows the
    loading message; afterwards it shows the ready message.

    Args:
        status_slot: Placeholder (e.g. from ``st.empty()``) used for status text.

    If ``get_model`` raises, an error is displayed and the script run is
    halted with ``st.stop()``.
    """
    bundle = st.session_state.get("model_bundle")
    if bundle is None:
        status_slot.markdown(_model_loading_html(), unsafe_allow_html=True)
        try:
            bundle = get_model()
        except Exception as exc:
            st.error(f"Model could not be loaded: {exc}")
            st.stop()
        st.session_state.model_bundle = bundle
    status_slot.markdown(_model_ready_html(), unsafe_allow_html=True)
    return bundle
inject_app_styles()

# --- Sidebar ---
# Sample scans are resolved relative to this script (as _STATIC_DIR is), so the
# sample buttons keep working when Streamlit is launched from another working
# directory. A CWD-relative "assets/..." path would fail os.path.isfile and
# the selection would be silently dropped.
# NOTE(review): assumes the assets/ directory sits beside this script — confirm.
_ASSETS_DIR = Path(__file__).resolve().parent / "assets"

with st.sidebar:
    st.markdown(
        """
SC
Stroke Classification
AI-assisted CT review powered by knowledge-distilled EfficientNet-B0
""",
        unsafe_allow_html=True,
    )
    st.markdown(_sidebar_specs_html(), unsafe_allow_html=True)

    st.markdown("**Validation Samples**")
    st.caption("Test cases from the external dataset.")
    # Buttons only record the choice; the image is loaded by _resolve_input_image.
    if st.button("Stroke", use_container_width=True):
        st.session_state.sample_path = str(_ASSETS_DIR / "sample_stroke.png")
    if st.button("No Stroke", use_container_width=True):
        st.session_state.sample_path = str(_ASSETS_DIR / "sample_no_stroke.png")

    # NOTE(review): this call renders only whitespace; the intended markup
    # appears to be missing.
    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )
# Visible gap above the CT scan / Analysis row. Streamlit's main chrome eats
# plain padding, so emit an explicit fixed-height spacer element.
st.markdown(
    '<div style="height: 1.25rem"></div>',
    unsafe_allow_html=True,
)

scan_col, result_col = st.columns([1, 1], gap="large")

with scan_col:
    with st.container(border=True):
        st.subheader("CT scan")
        file_source = st.file_uploader(
            "Upload a non-contrast or contrast-enhanced axial slice (PNG, JPG).",
            type=["png", "jpg", "jpeg"],
            label_visibility="collapsed",
        )
        input_image, source_kind = _resolve_input_image(file_source)
        if input_image is not None:
            caption = "Uploaded scan" if source_kind == "upload" else "Kaggle hold-out sample"
            st.image(input_image, caption=caption, width="stretch")
        else:
            st.markdown(
                """
Get started
- Upload a CT slice (PNG or JPG)
- Pick a sample from the sidebar
""",
                unsafe_allow_html=True,
            )

with result_col:
    with st.container(border=True):
        st.subheader("Analysis")
        model_status = st.empty()
        # Loaded before checking for an image, so the status slot reports model
        # readiness even while no scan is selected.
        model, transform = _ensure_model_loaded(model_status)
        if input_image is None:
            st.markdown(
                "Results will appear here after you upload an image or select a sample."
            )
        else:
            model_status.empty()
            with st.spinner("Running inference…"):
                prediction, confidence, results = predict(model, transform, input_image)
            is_stroke = prediction == "Stroke"
            # Same palette as the probability bars: red for Stroke, green otherwise.
            verdict_color = "#b91c1c" if is_stroke else "#15803d"
            verdict_text = "Stroke detected" if is_stroke else "No stroke detected"
            st.markdown(
                f"""
<div style="font-size:0.8rem;color:#6b7280">Classification</div>
<div style="font-size:1.5rem;font-weight:700;color:{verdict_color}">{verdict_text.upper()}</div>
""",
                unsafe_allow_html=True,
            )
            st.metric("Model confidence", f"{confidence:.1%}")
            st.markdown("**Class probabilities**")
            _render_probability_bars(results)
            note = (
                "Pattern indicates hemorrhage or ischemia. Clinical review required."
                if is_stroke
                else "No stroke pattern detected. Clinical review required."
            )
            st.caption(note)

# NOTE(review): this call renders only whitespace; the intended footer markup
# appears to be missing.
st.markdown(
    """
""",
    unsafe_allow_html=True,
)