RyanTietjen commited on
Commit
6ffc1e4
·
verified ·
1 Parent(s): 5544c2e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ryan Tietjen
3
+ Aug 2024
4
+ Demo application for a food classificiation demonstration
5
+ """
6
+ from model import vit_b_16
7
+ from timeit import default_timer as timer
8
+ import torch
9
+ import os
10
+ import gradio as gr
11
+
12
+ with open("demo/class_names.txt", 'r') as file:
13
+ class_names = [line.strip() for line in file]
14
+
15
+
16
+ model, transforms = vit_b_16(num_classes=101,
17
+ seed=31,
18
+ freeze_gradients=True,
19
+ unfreeze_blocks=0)
20
+
21
+ model.load_state_dict(torch.load('demo/vit_b_16_unfreeze_one_encoder_block_10_total_epochs.pth',
22
+ weights_only=True))
23
+
24
+
25
+ def predict_single_image(img):
26
+ start_time = timer()
27
+
28
+ model.eval()
29
+
30
+ #Add batch dim
31
+ img = transforms(img).unsqueeze(dim=0)
32
+
33
+ with torch.inference_mode():
34
+ # Obtain prediction logits -> prediction probabilities from image
35
+ logits = model(img)
36
+ probabilities = torch.softmax(logits, dim=1)
37
+
38
+ class_probabilities = {}
39
+ for i in range(len(class_names)):
40
+ class_probabilities[class_names[i]] = float(probabilities[0][i])
41
+
42
+ end_time = timer()
43
+
44
+ pred_time = round(end_time - start_time, 3)
45
+
46
+ return class_probabilities, pred_time
47
+
48
+
49
+ title = "Food Image Classification With PyTorch by Ryan Tietjen"
50
+ description = f"""
51
+ Determines what type of food is presented in a given image.
52
+ This model is capable of classifying [101 different types of food](https://github.com/RyanTietjen/Food-Classifier-pytorch-ver.-/blob/main/demo/class_names.txt) by
53
+ utilizing a [pre-trained Vision Transformer](https://pytorch.org/vision/stable/models/generated/torchvision.models.vit_b_16.html#torchvision.models.ViT_B_16_Weights),
54
+ and fine-tuning the results for specific food categories.
55
+ This model achieved a Top-1 accuracy of 91.55% and a Top-5 accuracy of 98.56%
56
+ """
57
+
58
+ sample_list = [["demo/samples/" + sample] for sample in os.listdir("demo/samples")]
59
+
60
+ #Gradio interface
61
+ demo = gr.Interface(
62
+ fn=predict_single_image,
63
+ inputs=gr.Image(type="pil"),
64
+ outputs=[
65
+ gr.Label(num_top_classes=5, label="Predictions"),
66
+ gr.Number(label="Prediction time (s)"),
67
+ ],
68
+ examples=sample_list,
69
+ title=title,
70
+ description=description,
71
+ )
72
+
73
+ demo.launch(share=True)