"""Gradio demo: classify animal images, optionally sampled from a streamed dataset.

The app exposes two actions:

* **Predict** -- run ``model.predict`` on the image in the upload slot and
  show the predicted label, its confidence and the full probability table.
* **Random Dataset Image** -- pull one example from the streamed
  ``AIOmarRehan/AnimalsDataset`` train split, put it into the upload slot and
  show it next to its ground-truth label.
"""

import random

import gradio as gr
from datasets import load_dataset
from PIL import Image

from model import predict

# Hugging Face Hub dataset used by the "Random Dataset Image" button.
DATASET_NAME = "AIOmarRehan/AnimalsDataset"

# Size of the streaming shuffle buffer. Each click builds a fresh shuffled
# stream whose buffer is filled with up to this many examples before one is
# yielded, so larger values give more varied samples but make every click
# slower (more images fetched).
SHUFFLE_BUFFER_SIZE = 100

# streaming=True: examples are fetched lazily as an iterator advances instead
# of downloading the whole split up front.
dataset = load_dataset(DATASET_NAME, split="train", streaming=True)


def classify_image(img: Image.Image):
    """Run the classifier on the image currently in the upload slot.

    Args:
        img: PIL image from the upload widget, or ``None`` when nothing has
            been uploaded.

    Returns:
        A ``(label, confidence, probabilities)`` tuple feeding the Label,
        Number and JSON outputs. When no image is present this is
        ``("No image uploaded", 0, {})``. Otherwise confidence and every
        probability are rounded to three decimals.
    """
    if img is None:
        return "No image uploaded", 0, {}

    label, confidence, probs = predict(img)
    # float(...) normalises NumPy scalars (e.g. float32), which are not
    # JSON-serialisable, before they reach the gr.JSON / gr.Number outputs.
    return (
        label,
        round(float(confidence), 3),
        {name: round(float(p), 3) for name, p in probs.items()},
    )


def _label_to_str(label):
    """Return a human-readable name for a dataset label value.

    Uses the label feature's ``int2str`` when the streamed dataset exposes one
    (a ``ClassLabel``). Falls back to ``str(label)`` because a streaming
    dataset may report ``features`` as ``None``, and the label column may not
    be a ``ClassLabel`` at all.
    """
    features = dataset.features
    label_feature = features.get("label") if features is not None else None
    int2str = getattr(label_feature, "int2str", None)
    if int2str is None:
        return str(label)
    return int2str(label)


def random_example():
    """Fetch one example from a freshly shuffled view of the streamed dataset.

    Returns:
        ``(image, image, label)``. The RGB image appears twice: one copy fills
        the upload slot so it can be classified, the other fills the sample
        preview. The last element is the ground-truth label as a string.

    Raises:
        gr.Error: if the stream yields no examples.
    """
    # Re-shuffle on every click with an explicit random seed so consecutive
    # clicks draw different samples. The pick comes from the shuffle buffer,
    # i.e. from roughly the first SHUFFLE_BUFFER_SIZE streamed examples.
    shuffled = dataset.shuffle(
        seed=random.randrange(2**32), buffer_size=SHUFFLE_BUFFER_SIZE
    )
    item = next(iter(shuffled), None)
    if item is None:
        # Surface a readable UI error instead of leaking StopIteration.
        raise gr.Error("The dataset stream returned no examples.")

    img = item["image"].convert("RGB")
    return img, img, _label_to_str(item["label"])


with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        # Button (despite the name) that loads a random dataset sample.
        rand_img = gr.Button("Random Dataset Image")

    pred_btn = gr.Button("Predict")

    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    pred_btn.click(
        classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs],
    )

    rand_img.click(
        random_example,
        outputs=[input_img, rand_display, rand_label],
    )


if __name__ == "__main__":
    demo.launch()