| | |
import functools
import glob
import os
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
| |
|
| | |
# Keras picks its backend when it is first imported, so the preference has to be
# set before the import below; setting it later (after `import keras`) is a no-op.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

# True only when the standalone `keras` package imports cleanly and is major version 3+.
KERAS3_AVAILABLE = False
try:
    import keras
except Exception:  # missing or broken install: load_model falls back to tf.keras
    keras = None
else:
    try:
        KERAS3_AVAILABLE = int(keras.__version__.split(".")[0]) >= 3
    except (AttributeError, ValueError):
        # Unparseable version string: keep the module but use the tf.keras path.
        KERAS3_AVAILABLE = False

# Hugging Face Hub repository that hosts the trained model file.
HF_MODEL_ID = "Vedag812/xray_cnn"
# Order matters: predict_fn treats a single sigmoid output as the probability of
# CLASS_NAMES[1] and reports CLASS_NAMES[0] as its complement.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
| |
|
@functools.lru_cache(maxsize=None)
def load_model():
    """Download the classifier from the Hugging Face Hub and load it.

    The result is memoized by ``functools.lru_cache``, so the download and load
    happen once per process. A call that raises is not cached, so a later call
    retries.

    When standalone Keras 3 is available it is tried first. If that load fails,
    the traceback is printed and ``tf.keras`` is used instead.

    Returns:
        The loaded (uncompiled) Keras model.
    """
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")

    if KERAS3_AVAILABLE:
        try:
            # NOTE(security): safe_mode=False lets the loader deserialize arbitrary
            # Python code (e.g. Lambda layers) from the file. This is only acceptable
            # while HF_MODEL_ID points at a trusted repository.
            return keras.saving.load_model(model_path, compile=False, safe_mode=False)
        except Exception:
            # Best-effort: keep the tf.keras fallback, but don't hide why Keras 3 failed.
            traceback.print_exc()

    return tf.keras.models.load_model(model_path, compile=False)
| |
|
def _infer_input_shape(model):
    """Return the ``(H, W, C)`` input size the model expects, as ints.

    The shape is read from ``model.inputs[0].shape.as_list()``, falling back to
    ``model.input_shape``. If neither can be read, or the shape has fewer than
    four entries, ``(150, 150, 1)`` is returned. Otherwise each unknown
    (``None``/falsy) dimension is replaced by the matching default: 150 for
    height and width, 1 for channels.
    """
    defaults = (150, 150, 1)

    try:
        dims = tuple(model.inputs[0].shape.as_list())
    except Exception:
        try:
            dims = tuple(model.input_shape)
        except Exception:
            dims = None

    if not dims or len(dims) < 4:
        return defaults

    # dims[0] is the batch axis; dims[1:4] are height, width, channels.
    return tuple(
        int(dim) if dim else fallback
        for dim, fallback in zip(dims[1:4], defaults)
    )
| |
|
def preprocess(pil_img: Image.Image, target_hw_c):
    """Turn an image into a one-sample float32 batch sized for the model.

    The image is converted to 8-bit grayscale ("L"), resized to width ``W`` and
    height ``H``, and scaled to ``[0, 1]``. The single gray plane is then copied
    ``C`` times along a new trailing axis.

    Args:
        pil_img: source image.
        target_hw_c: ``(H, W, C)`` target size, e.g. from ``_infer_input_shape``.

    Returns:
        ``np.ndarray`` of shape ``(1, H, W, C)`` with dtype float32.
    """
    height, width, channels = target_hw_c

    # PIL's resize takes (width, height).
    gray = np.array(pil_img.convert("L").resize((width, height))).astype("float32") / 255.0

    # Copying the gray plane handles 1-channel, 3-channel and any other channel count
    # the same way.
    return np.repeat(gray[np.newaxis, ..., np.newaxis], channels, axis=-1)
| |
|
def predict_fn(pil_img: Image.Image):
    """Classify one chest X-ray image and format the result for the Gradio outputs.

    Args:
        pil_img: image from the ``gr.Image`` input, or ``None`` when the input has
            been cleared.

    Returns:
        ``(probs, message)``:

        - ``probs``: a dict mapping each name in ``CLASS_NAMES`` to a probability,
          for ``gr.Label``. It is ``None`` (clears the label) when there is no image.
        - ``message``: Markdown status text.

        Errors are reported in ``message`` with every probability set to 0.0; they
        are not raised.
    """
    # The input's change event also fires when the image is cleared. Reset the
    # outputs instead of running the model on None and showing a misleading error.
    if pil_img is None:
        return None, ""

    try:
        model = load_model()
        x = preprocess(pil_img, _infer_input_shape(model))
        preds = np.asarray(model.predict(x, verbose=0))

        if preds.ndim >= 1 and preds.shape[-1] == len(CLASS_NAMES):
            # One score per class. Presumably softmax, ordered like CLASS_NAMES
            # (TODO confirm against the training code). Take column 1 of sample 0.
            prob = float(preds.reshape(-1, len(CLASS_NAMES))[0, 1])
        else:
            # Single sigmoid unit: the score is the probability of CLASS_NAMES[1].
            prob = float(preds.ravel()[0])

        pred_idx = int(prob > 0.5)
        confidence = prob if pred_idx == 1 else 1 - prob
        probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
        msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
        return probs, msg
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show a short
        # message in the app instead of failing the event.
        traceback.print_exc()
        tip = (
            "Tip: if this keeps happening, the Space may need keras>=3 to load a model "
            "saved with newer Keras. I handled both paths here, but if your model was saved "
            "with a very new version, updating the Space deps can help."
        )
        err_text = "⚠️ Error during prediction:\n\n" + str(e) + "\n\n" + tip
        return {name: 0.0 for name in CLASS_NAMES}, err_text
| |
|
def list_examples(directory="images"):
    """Return sample image paths in the ``[[path], ...]`` shape ``gr.Examples`` takes.

    Regular files directly inside *directory* with a ``.jpeg``, ``.jpg`` or ``.png``
    extension are included. The extension match is case-insensitive, so files such
    as ``IM-0001.JPEG`` are not skipped on case-sensitive filesystems. Hidden files
    (leading ``.``) are skipped, as ``glob`` skips them. A missing or unreadable
    directory yields ``[]``.

    Args:
        directory: folder to scan; relative paths resolve against the current
            working directory. Defaults to ``"images"``.

    Returns:
        One single-element list per image path, sorted by path.
    """
    allowed_suffixes = {".jpeg", ".jpg", ".png"}
    try:
        with os.scandir(directory) as entries:
            names = [
                entry.name
                for entry in entries
                if not entry.name.startswith(".") and entry.is_file()
            ]
    except OSError:
        # Best-effort, like glob: no samples rather than a crash at UI build time.
        return []

    paths = sorted(
        os.path.join(directory, name)
        for name in names
        if os.path.splitext(name)[1].lower() in allowed_suffixes
    )
    return [[p] for p in paths]
| |
|
# Build the Gradio UI: upload controls and samples on the left, results on the right.
# Components appear on screen in the order they are created inside these contexts.
with gr.Blocks(css="""
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
.card {border:1px solid #e5e7eb; border-radius:16px; padding:16px;}
""") as demo:
    # NOTE(review): the `.card` CSS rule above is not referenced by any component here.
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample from the gallery. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        with gr.Column(scale=2):
            # Requests a PIL image in mode "L" (grayscale) for predict_fn.
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp], value="Clear")
            gr.Markdown("### Samples")
            # Sample rows come from list_examples() and target the `inp` component.
            gr.Examples(
                examples=list_examples(),
                inputs=inp,
                examples_per_page=12,
            )
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # Prediction runs on an explicit click and whenever the input image changes.
    # NOTE(review): a change event presumably also fires when the image is cleared,
    # which would call predict_fn with None. Confirm predict_fn handles that.
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])


if __name__ == "__main__":
    demo.launch()
| |
|