AIOmarRehan's picture
Update app.py
eb4c49e verified
raw
history blame
1.63 kB
import random
import gradio as gr
from PIL import Image
from model import predict
from datasets import load_dataset
# Fetch the Hugging Face dataset a single time at import; the random-sample
# handler below reads rows from this shared object on every click.
dataset = load_dataset(path="AIOmarRehan/AnimalsDataset", split="train")
def classify_image(img: Image.Image):
    """Run the model on an uploaded image and format the result for the UI.

    Args:
        img: Image from the upload widget, or ``None`` when the user pressed
            "Predict" without providing one.

    Returns:
        A ``(label, confidence, probs)`` tuple: the predicted label, its
        confidence rounded to 3 decimals, and the model's probability mapping
        with every value rounded to 3 decimals.

    Raises:
        gr.Error: If no image was provided. Gradio shows this as a readable
            message instead of a stack trace from ``predict``.
    """
    # The empty gr.Image component delivers None; fail with a user-facing
    # message rather than letting the model choke on it.
    if img is None:
        raise gr.Error("Please upload an image first.")

    label, confidence, probs = predict(img)
    # float() normalizes numpy/torch scalars (in case `predict` returns them;
    # not visible from here) into plain Python floats that the Number and
    # JSON outputs can serialize. It is a no-op for ordinary floats.
    return (
        label,
        round(float(confidence), 3),
        {k: round(float(v), 3) for k, v in probs.items()},
    )
# Draw one row from the training split for the preview panel.
def random_example():
    """Return a randomly chosen dataset image together with its class name.

    Returns:
        ``(image, class_name)``: the sampled row's image converted to RGB,
        and its integer label translated to a string through the dataset's
        ``label`` feature (``int2str``).
    """
    # randrange(n) draws from [0, n), the same range as randint(0, n - 1).
    row = dataset[random.randrange(len(dataset))]
    label_feature = dataset.features["label"]
    return row["image"].convert("RGB"), label_feature.int2str(row["label"])
# Assemble the UI: an upload + predict panel, and a separate panel that
# previews a random dataset row next to its ground-truth label.
with gr.Blocks() as demo:
    gr.Markdown("## Animal Image Classifier with Random Dataset Samples")

    with gr.Row():
        input_img = gr.Image(type="pil", label="Upload an image")
        rand_img = gr.Button("Random Dataset Image")

    with gr.Row():
        pred_btn = gr.Button("Predict")
        output_label = gr.Label(label="Predicted Class")
        output_conf = gr.Number(label="Confidence")
        output_probs = gr.JSON(label="All Probabilities")

    # Preview of the sampled dataset row. It is display-only: the Predict
    # button reads from `input_img`, not from this component.
    rand_display = gr.Image(type="pil", label="Random Dataset Sample")
    rand_label = gr.Textbox(label="Sample Label")

    # Event wiring.
    pred_btn.click(
        fn=classify_image,
        inputs=input_img,
        outputs=[output_label, output_conf, output_probs],
    )
    rand_img.click(fn=random_example, outputs=[rand_display, rand_label])

if __name__ == "__main__":
    demo.launch()