AIOmarRehan commited on
Commit
eb4c49e
·
verified ·
1 Parent(s): 87d27f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -12
app.py CHANGED
@@ -1,27 +1,53 @@
 
1
  import gradio as gr
2
  from PIL import Image
3
  from model import predict
 
 
 
 
4
 
5
  def classify_image(img: Image.Image):
6
  label, confidence, probs = predict(img)
7
-
8
  return (
9
  label,
10
  round(confidence, 3),
11
  {k: round(v, 3) for k, v in probs.items()}
12
  )
13
 
14
- demo = gr.Interface(
15
- fn=classify_image,
16
- inputs=gr.Image(type="pil", label="Upload an image"),
17
- outputs=[
18
- gr.Label(label="Predicted Class"),
19
- gr.Number(label="Confidence"),
20
- gr.JSON(label="All Probabilities")
21
- ],
22
- title="Animal Image Classifier",
23
- description="Upload an image and the model will predict the animal."
24
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  if __name__ == "__main__":
27
  demo.launch()
 
1
+ import random
2
  import gradio as gr
3
  from PIL import Image
4
  from model import predict
5
+ from datasets import load_dataset
6
+
7
+ # Load the HF dataset once
8
+ dataset = load_dataset("AIOmarRehan/AnimalsDataset", split="train")
9
 
10
  def classify_image(img: Image.Image):
11
  label, confidence, probs = predict(img)
 
12
  return (
13
  label,
14
  round(confidence, 3),
15
  {k: round(v, 3) for k, v in probs.items()}
16
  )
17
 
18
+ # Pick a random example
19
+ def random_example():
20
+ # Choose a random row
21
+ idx = random.randint(0, len(dataset) - 1)
22
+ item = dataset[idx]
23
+ img = item["image"].convert("RGB") # PIL Image
24
+ label = item["label"] # numeric label from dataset
25
+ label_str = dataset.features["label"].int2str(label) # class name
26
+ return img, label_str
27
+
28
+ demo = gr.Blocks()
29
+
30
+ with demo:
31
+ gr.Markdown("## Animal Image Classifier with Random Dataset Samples")
32
+
33
+ with gr.Row():
34
+ input_img = gr.Image(type="pil", label="Upload an image")
35
+ rand_img = gr.Button("Random Dataset Image")
36
+
37
+ with gr.Row():
38
+ pred_btn = gr.Button("Predict")
39
+
40
+ output_label = gr.Label(label="Predicted Class")
41
+ output_conf = gr.Number(label="Confidence")
42
+ output_probs = gr.JSON(label="All Probabilities")
43
+
44
+ # Display random dataset sample
45
+ rand_display = gr.Image(type="pil", label="Random Dataset Sample")
46
+ rand_label = gr.Textbox(label="Sample Label")
47
+
48
+ # Actions
49
+ pred_btn.click(classify_image, inputs=input_img, outputs=[output_label, output_conf, output_probs])
50
+ rand_img.click(random_example, outputs=[rand_display, rand_label])
51
 
52
  if __name__ == "__main__":
53
  demo.launch()