"""Gradio demo: classify animal images, with random samples from the training set."""

import random

import gradio as gr
from PIL import Image
from model import predict
from datasets import load_dataset

# Loaded once at import time; every "Random Dataset Image" click reuses it.
dataset = load_dataset("AIOmarRehan/AnimalsDataset", split="train")


def classify_image(img: Image.Image):
    """Run the model on *img* and format the result for the UI.

    Args:
        img: The image currently in the upload widget, or ``None`` when the
            user clicks "Predict" without providing one.

    Returns:
        A ``(label, confidence, probabilities)`` triple that feeds the
        ``output_label``, ``output_conf`` and ``output_probs`` components.
        Numeric values are rounded to three decimals.
    """
    if img is None:
        return "No image uploaded", 0, {}

    label, confidence, probs = predict(img)
    # Coerce to built-in types: the model may hand back NumPy/tensor scalars,
    # which round() preserves and the standard json encoder cannot serialize.
    return (
        str(label),
        round(float(confidence), 3),
        {str(k): round(float(v), 3) for k, v in probs.items()},
    )


# Random example from the dataset
def random_example():
    """Pick a random training sample.

    Returns:
        ``(image, image, label_name)``. The image is returned twice: once for
        the upload widget (so "Predict" classifies it) and once for the
        sample preview. ``label_name`` is the ground-truth class name decoded
        from the dataset's ``label`` feature.
    """
    item = dataset[random.randrange(len(dataset))]
    img = item["image"].convert("RGB")
    label = dataset.features["label"].int2str(item["label"])
    return img, img, label


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")
        pred_btn = gr.Button("Predict")

    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    # Predict button uses whatever image is currently in input_img
    pred_btn.click(
        classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs],
    )

    # Random button picks a dataset image and loads it into both image slots
    rand_img.click(
        random_example,
        outputs=[input_img, rand_display, rand_label],
    )


if __name__ == "__main__":
    demo.launch()