AIOmarRehan's picture
Update model.py
1c8dc00 verified
raw
history blame
1.74 kB
import random
import gradio as gr
from PIL import Image
from model import predict
from datasets import load_dataset
# Load the "train" split without streaming, so the resulting dataset
# supports len() and integer indexing (both are used by random_example).
dataset = load_dataset(
    path="AIOmarRehan/AnimalsDataset",
    split="train",
)
def classify_image(img: Image.Image):
    """Classify an uploaded image and shape the result for the UI outputs.

    Args:
        img: The image to classify, or ``None`` when nothing was uploaded.

    Returns:
        A 3-tuple ``(label, confidence, probabilities)``. The confidence and
        every value of the probabilities mapping are rounded to three
        decimal places. When ``img`` is ``None`` the placeholder
        ``("No image uploaded", 0, {})`` is returned and ``predict`` is
        never called.
    """
    if img is not None:
        predicted, score, distribution = predict(img)
        score_rounded = round(score, 3)
        # Round each entry so the JSON panel stays readable.
        rounded_distribution = {}
        for class_name, probability in distribution.items():
            rounded_distribution[class_name] = round(probability, 3)
        return predicted, score_rounded, rounded_distribution

    # Nothing was uploaded: return placeholders instead of failing.
    return "No image uploaded", 0, {}
# Draw one sample uniformly at random from the loaded dataset
def random_example():
    """Return a randomly chosen dataset image together with its class name.

    Returns:
        A ``(image, label_name)`` pair: the sample's ``"image"`` entry
        converted to RGB, and its ``"label"`` value translated to a string
        through the dataset's ``label`` feature (``int2str``).
    """
    last_index = len(dataset) - 1
    sample = dataset[random.randint(0, last_index)]
    picture = sample["image"].convert("RGB")
    class_id = sample["label"]
    class_name = dataset.features["label"].int2str(class_id)
    return picture, class_name
# Build the Gradio UI. Components and event listeners are declared inside the
# Blocks context so they are attached to `demo`.
demo = gr.Blocks()
with demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")
    # NOTE(review): assumes the upload widget and the random-sample button
    # share this row — confirm against the intended layout.
    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")
    # NOTE(review): assumes only the Predict button sits in this row and the
    # output widgets below are stacked outside it — confirm.
    with gr.Row():
        pred_btn = gr.Button("Predict")
    # Prediction outputs, in the same order classify_image returns them.
    output_label = gr.Label(label="Predicted Class")
    output_conf = gr.Number(label="Confidence")
    output_probs = gr.JSON(label="All Probabilities")
    # Display random dataset sample
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")
    # Actions
    # "Predict" sends the uploaded image to classify_image and spreads its
    # three return values over the three output widgets.
    pred_btn.click(
        classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs]
    )
    # "Random Dataset Image" takes no inputs; it fills the sample image and
    # sample label widgets with the pair returned by random_example.
    rand_img.click(
        random_example,
        outputs=[rand_display, rand_label]
    )
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()