| import random | |
| import gradio as gr | |
| from PIL import Image | |
| from model import predict | |
| from datasets import load_dataset | |
# Open the training split in streaming mode so examples are read on demand
# when iterated; `random_example` draws its samples from this object.
dataset = load_dataset("AIOmarRehan/AnimalsDataset", split="train", streaming=True)
def classify_image(img: Image.Image):
    """Classify *img* and return values shaped for the three output widgets.

    Args:
        img: The PIL image supplied by the image input, or ``None`` when
            nothing has been provided.

    Returns:
        A ``(label, confidence, probabilities)`` tuple. When ``img`` is
        ``None`` the placeholder ``("No image uploaded", 0, {})`` is
        returned. Otherwise the confidence and every probability value are
        rounded to three decimal places.
    """
    if img is not None:
        # `predict` comes from the project's `model` module; it is unpacked
        # here as (label, confidence, mapping of class name -> probability).
        predicted, score, distribution = predict(img)
        score_rounded = round(score, 3)
        distribution_rounded = {
            name: round(prob, 3) for name, prob in distribution.items()
        }
        return predicted, score_rounded, distribution_rounded

    return "No image uploaded", 0, {}
def random_example():
    """Draw one sample from the streamed dataset for display and prediction.

    Returns:
        A ``(image, image, label_text)`` tuple. The same RGB image is
        returned twice because the click handler routes it both to the
        classifier input and to the sample preview; ``label_text`` is the
        human-readable label of the sample.

    Raises:
        StopIteration: If the shuffled stream yields no items.
    """
    # A fresh shuffled iterator is created per call; only its first item is
    # consumed.
    item = next(iter(dataset.shuffle(buffer_size=100)))
    img = item["image"].convert("RGB")
    label = item["label"]

    # Streaming datasets may not expose `features` (it can be None), and the
    # label column is not guaranteed to be a ClassLabel. Fall back to the
    # raw label text instead of crashing the button handler.
    features = dataset.features
    label_feature = features.get("label") if features else None
    int2str = getattr(label_feature, "int2str", None)
    label_str = int2str(label) if callable(int2str) else str(label)

    return img, img, label_str
# Build the UI. Components are created in display order; event listeners are
# registered inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    # NOTE(review): the original indentation was lost, so which components
    # belong to this Row is an assumption (upload image + random-sample
    # button) — confirm against the intended layout.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")

    pred_btn = gr.Button("Predict")

    # Prediction outputs.
    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")

    # Preview of the randomly drawn dataset sample and its label.
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    # "Predict" classifies whatever is currently in the image input.
    pred_btn.click(
        fn=classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs],
    )

    # The random button fills the classifier input and the preview with the
    # same image, and shows the sample's label.
    rand_img.click(
        fn=random_example,
        outputs=[input_img, rand_display, rand_label],
    )

if __name__ == "__main__":
    demo.launch()