Spaces:
Build error
Build error
| import gradio as gr | |
| from torch.nn.functional import softmax | |
| import torch | |
| from transformers import ViTFeatureExtractor | |
| from transformers import MobileViTFeatureExtractor | |
| from transformers import MobileViTForImageClassification | |
| from transformers import ViTForImageClassification | |
# Loaded (feature_extractor, model) pairs keyed by model_type, so each
# checkpoint is read from disk once per process instead of on every request.
_MODEL_CACHE = {}


def _load_model(model_type):
    """Return the cached ``(feature_extractor, model)`` pair for *model_type*.

    Raises:
        ValueError: if *model_type* is not "ViT" or "MobileViT".
    """
    cached = _MODEL_CACHE.get(model_type)
    if cached is not None:
        return cached
    if model_type == "ViT":
        path = "./models/vit-base-garbage/"
        extractor_cls, model_cls = ViTFeatureExtractor, ViTForImageClassification
    elif model_type == "MobileViT":
        path = "./models/apple/mobilevit-small-garbage/"
        extractor_cls, model_cls = (
            MobileViTFeatureExtractor,
            MobileViTForImageClassification,
        )
    else:
        # Previously an unknown name fell through and crashed later with
        # UnboundLocalError; fail early with an explicit message instead.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'ViT' or 'MobileViT'"
        )
    cached = (extractor_cls.from_pretrained(path), model_cls.from_pretrained(path))
    _MODEL_CACHE[model_type] = cached
    return cached


def predict(model_type, inp):
    """Classify an image with the selected fine-tuned garbage classifier.

    Args:
        model_type: "ViT" or "MobileViT" (the dropdown selection).
        inp: the image to classify; the UI passes a PIL image (``type="pil"``).

    Returns:
        dict mapping every class label in the model config to its softmax
        probability, the format ``gr.Label`` consumes.

    Raises:
        ValueError: if *model_type* is not a supported model name.
    """
    feature_extractor, model = _load_model(model_type)
    inputs = feature_extractor(inp, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Only one image is in the batch: softmax over the class dimension of row 0.
    probabilities = torch.softmax(logits[0], dim=-1)
    # id2label maps class index -> name, so labels line up with logit indices
    # regardless of label2id's insertion order, and every class is included
    # instead of a hard-coded count of 6.
    id2label = model.config.id2label
    return {id2label[i]: float(p) for i, p in enumerate(probabilities)}
# Gradio UI: pick a model, upload an image, and show the top-3 predicted classes.
# gr.Image / gr.Label replace the deprecated gr.inputs.Image / gr.outputs.Label
# namespaces, which were removed in Gradio 4 and break the build there.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(["ViT", "MobileViT"], label="Model Name", value="ViT"),
        gr.Image(type="pil"),
    ],
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["ViT", "paper567.jpg"],
        ["ViT", "trash105.jpg"],
        ["ViT", "plastic202.jpg"],
        ["MobileViT", "metal382.jpg"],
    ],
)
demo.launch()