| | |
| | import os, glob, traceback |
| | import numpy as np |
| | from PIL import Image |
| | import gradio as gr |
| |
|
| | |
# Default the Keras backend to TensorFlow unless KERAS_BACKEND is already set
# in the environment. This statement is deliberately placed before the
# `import keras` that follows it.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")
| | import keras |
| | from huggingface_hub import hf_hub_download |
| |
|
# Hugging Face Hub repository id passed to hf_hub_download() in load_model().
HF_MODEL_ID = "Vedag812/xray_cnn"
# Class labels by index: predict_fn() picks index 1 when the model's first
# output value exceeds 0.5, and index 0 otherwise.
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
| |
|
# Process-wide cache for the loaded model; filled on the first load_model() call.
_MODEL_CACHE = {}


def load_model():
    """Return the X-ray classifier, downloading and loading it on first use.

    The model file ``xray_cnn.keras`` is fetched from the Hugging Face Hub
    repository ``HF_MODEL_ID`` and loaded without compiling (inference only).
    Later calls return the same in-memory model object.

    Gradio has no ``cache_resource`` decorator (that name belongs to
    Streamlit), so referencing ``gr.cache_resource`` fails with
    ``AttributeError`` as soon as the module is imported. An explicit
    module-level cache gives the intended load-once behavior instead.

    Returns:
        The loaded Keras model.

    Raises:
        Whatever ``hf_hub_download`` or ``keras.saving.load_model`` raise
        (network failure, missing file, incompatible model format).
    """
    if "model" not in _MODEL_CACHE:
        model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
        # NOTE(security): safe_mode=False permits unsafe deserialization
        # (e.g. Lambda layers executing arbitrary code). It is kept because
        # the saved model may depend on it; only point HF_MODEL_ID at a
        # repository you trust.
        _MODEL_CACHE["model"] = keras.saving.load_model(
            model_path, compile=False, safe_mode=False
        )
    return _MODEL_CACHE["model"]
| |
|
def _infer_input_shape(model):
    """Return the ``(height, width, channels)`` the model expects.

    The shape is read from ``model.inputs[0].shape``; if that fails for any
    reason, ``model.input_shape`` is tried instead. A missing shape, or one
    with fewer than four dimensions, yields the default ``(150, 150, 1)``.
    Any individual dimension reported as ``None`` is replaced by the matching
    default value.
    """
    fallback = (150, 150, 1)

    try:
        dims = tuple(model.inputs[0].shape)
    except Exception:
        dims = getattr(model, "input_shape", None)

    if dims is None or len(dims) < 4:
        return fallback

    # Axis 0 is skipped; axes 1..3 are taken as height, width, channels.
    return tuple(
        default if dims[axis] is None else int(dims[axis])
        for axis, default in enumerate(fallback, start=1)
    )
| |
|
def preprocess(pil_img: Image.Image, target):
    """Turn a PIL image into a single-item batch for the model.

    Args:
        pil_img: Source image; it is converted to 8-bit grayscale ("L").
        target: ``(height, width, channels)`` tuple describing the model input.

    Returns:
        A float32 array of shape ``(1, height, width, channels)`` scaled by
        1/255, with the grayscale plane copied into every channel.
    """
    height, width, channels = target

    # PIL's resize takes (width, height), the reverse of the array layout.
    gray = pil_img.convert("L")
    gray = gray.resize((width, height))
    pixels = np.array(gray).astype("float32") / 255.0

    if channels == 1:
        stacked = pixels[..., np.newaxis]
    elif channels == 3:
        stacked = np.stack((pixels, pixels, pixels), axis=-1)
    else:
        stacked = np.repeat(pixels[..., np.newaxis], channels, axis=-1)

    # Prepend the batch axis.
    return stacked[np.newaxis, ...]
| |
|
def predict_fn(pil_img: Image.Image):
    """Classify an X-ray and return ``(class_scores, markdown_message)``.

    Args:
        pil_img: Image supplied by the Gradio input, or ``None`` when the
            input is empty or has been cleared.

    Returns:
        A tuple ``(scores, message)``. ``scores`` maps each entry of
        ``CLASS_NAMES`` to a float; ``message`` is the text shown beside it.
        When there is no image, or when prediction fails, every score is 0.0
        and the message carries the prompt or the error text.

    This function never raises: any failure while loading the model or
    running inference is reported in the returned message so the UI keeps
    working, and the full traceback is printed to stderr for the logs.
    """
    zero_scores = {name: 0.0 for name in CLASS_NAMES}

    # Guard clause kept outside the try: an empty input is not an error.
    if pil_img is None:
        return zero_scores, "Please upload an image or pick a sample."

    try:
        model = load_model()
        x = preprocess(pil_img, _infer_input_shape(model))
        y = model.predict(x, verbose=0)
        # Only the first output value is read and treated as the score for
        # CLASS_NAMES[1].
        # NOTE(review): assumes a single sigmoid output unit — confirm
        # against the trained model's head.
        prob = float(np.ravel(y)[0])
    except Exception as e:
        # Previously the exception was swallowed with no trace anywhere;
        # log it so failures can be diagnosed from the server logs.
        traceback.print_exc()
        tip = (
            "If this persists, make sure the Space has keras>=3 and tensorflow>=2.16."
        )
        err = f"⚠️ Error during prediction:\n\n{e}\n\n{tip}"
        return zero_scores, err

    idx = int(prob > 0.5)
    conf = prob if idx == 1 else 1 - prob
    probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
    msg = f"Prediction: {CLASS_NAMES[idx]} | Confidence: {conf*100:.2f}%"
    return probs, msg
| |
|
def list_examples():
    """Collect sample images from ``images/`` for the examples gallery.

    Returns:
        A list of single-element lists ``[[path], ...]`` covering every
        ``.jpeg``, ``.jpg`` and ``.png`` file directly inside ``images/``
        (relative to the current working directory), sorted by path.
    """
    patterns = ("images/*.jpeg", "images/*.jpg", "images/*.png")
    matches = [path for pattern in patterns for path in glob.glob(pattern)]
    matches.sort()
    # One row per example, one column per input component.
    return [[path] for path in matches]
| |
|
# Page-level CSS: caps the app width, centres it, and centres the title.
_APP_CSS = """
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
"""

# Two-column layout: upload, controls and samples on the left; results on the right.
with gr.Blocks(css=_APP_CSS) as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp])
            gr.Markdown("### Samples")
            gr.Examples(
                examples=list_examples(),
                inputs=inp,
                examples_per_page=12,
            )
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # Run the classifier both on an explicit button press and whenever the
    # image changes (upload, sample click, or clear).
    for register in (btn.click, inp.change):
        register(predict_fn, inputs=inp, outputs=[probs, out_text])


if __name__ == "__main__":
    demo.launch()
| |
|