AIOmarRehan commited on
Commit
30b7c97
·
verified ·
1 Parent(s): 92d6cc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -17
app.py CHANGED
@@ -4,37 +4,29 @@ from PIL import Image
4
  from model import predict
5
  from datasets import load_dataset
6
 
7
- dataset = load_dataset(
8
- "AIOmarRehan/AnimalsDataset",
9
- split="train",
10
- streaming=True
11
- )
12
 
13
  def classify_image(img: Image.Image):
14
-
15
  if img is None:
16
  return "No image uploaded", 0, {}
17
-
18
  label, confidence, probs = predict(img)
19
-
20
  return (
21
  label,
22
  round(confidence, 3),
23
  {k: round(v, 3) for k, v in probs.items()}
24
  )
25
 
 
26
  def random_example():
27
-
28
- item = next(iter(dataset.shuffle(buffer_size=1500)))
29
-
30
  img = item["image"].convert("RGB")
31
- label = item["label"]
32
-
33
- label_str = dataset.features["label"].int2str(label)
34
-
35
- return img, img, label_str
36
-
37
 
 
38
  demo = gr.Blocks()
39
 
40
  with demo:
@@ -53,12 +45,14 @@ with demo:
53
  rand_display = gr.Image(type="pil", label="Random Dataset Sample")
54
  rand_label = gr.Textbox(label="Sample Label")
55
 
 
56
  pred_btn.click(
57
  classify_image,
58
  inputs=input_img,
59
  outputs=[output_label, output_conf, output_probs]
60
  )
61
 
 
62
  rand_img.click(
63
  random_example,
64
  outputs=[input_img, rand_display, rand_label]
 
4
  from model import predict
5
  from datasets import load_dataset
6
 
7
+ # Load the full dataset (non-streaming)
8
+ dataset = load_dataset("AIOmarRehan/AnimalsDataset", split="train")
 
 
 
9
 
10
  def classify_image(img: Image.Image):
 
11
  if img is None:
12
  return "No image uploaded", 0, {}
13
+
14
  label, confidence, probs = predict(img)
 
15
  return (
16
  label,
17
  round(confidence, 3),
18
  {k: round(v, 3) for k, v in probs.items()}
19
  )
20
 
21
+ # Random example from the dataset
22
  def random_example():
23
+ item = random.choice(dataset)
 
 
24
  img = item["image"].convert("RGB")
25
+ label = dataset.features["label"].int2str(item["label"])
26
+ # Return image twice: once for input_img (for prediction), once for display
27
+ return img, img, label
 
 
 
28
 
29
+ # Gradio UI
30
  demo = gr.Blocks()
31
 
32
  with demo:
 
45
  rand_display = gr.Image(type="pil", label="Random Dataset Sample")
46
  rand_label = gr.Textbox(label="Sample Label")
47
 
48
+ # Predict button uses whatever image is currently in input_img
49
  pred_btn.click(
50
  classify_image,
51
  inputs=input_img,
52
  outputs=[output_label, output_conf, output_probs]
53
  )
54
 
55
+ # Random button picks a dataset image
56
  rand_img.click(
57
  random_example,
58
  outputs=[input_img, rand_display, rand_label]