import random

import gradio as gr
from PIL import Image
from model import predict
from datasets import load_dataset

# Load the full training split eagerly (NO streaming) so that len() and
# random-access indexing are available for sampling below.
dataset = load_dataset("AIOmarRehan/AnimalsDataset", split="train")


def classify_image(img: Image.Image):
    """Run the model on an uploaded image.

    Args:
        img: The PIL image from the upload component, or ``None`` when the
            user clicked Predict without uploading anything.

    Returns:
        A ``(label, confidence, probs)`` tuple feeding the Label, Number and
        JSON output components. The confidence and every probability are
        converted to built-in ``float`` and rounded to 3 decimal places.
    """
    # Handle empty input safely instead of passing None to the model.
    if img is None:
        return "No image uploaded", 0, {}

    label, confidence, probs = predict(img)
    # float() normalises numpy/torch scalars into built-in floats so that
    # round() yields a plain float the JSON component can always serialise.
    # NOTE(review): predict's exact return types are not visible here.
    return (
        label,
        round(float(confidence), 3),
        {k: round(float(v), 3) for k, v in probs.items()},
    )


def random_example():
    """Pick a uniformly random sample from the dataset.

    Returns:
        ``(image, label_name)``: the sample's image converted to RGB and its
        integer label decoded via ``dataset.features["label"].int2str``
        (presumably a ``ClassLabel`` feature -- confirm). When the dataset is
        empty, returns ``(None, "Dataset is empty")`` instead of raising.
    """
    # random.randrange(0) would raise ValueError; report it in the UI instead.
    if len(dataset) == 0:
        return None, "Dataset is empty"

    item = dataset[random.randrange(len(dataset))]
    # Convert to RGB regardless of the mode the image is stored in.
    img = item["image"].convert("RGB")
    label_str = dataset.features["label"].int2str(item["label"])
    return img, label_str


with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        # Despite its name, this is the button that draws a random sample.
        rand_img = gr.Button("Random Dataset Image")

    with gr.Row():
        pred_btn = gr.Button("Predict")

    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    # Display random dataset sample.
    # NOTE(review): the sample is shown separately; Predict only classifies
    # the uploaded image in input_img, not this sample.
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    # Actions
    pred_btn.click(
        classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs],
    )

    rand_img.click(
        random_example,
        outputs=[rand_display, rand_label],
    )


if __name__ == "__main__":
    demo.launch()