| | import gradio as gr |
| | from transformers import AutoModelForImageClassification, AutoFeatureExtractor |
| | import torch |
| | from PIL import Image |
| | import numpy as np |
| |
|
| | |
# Hub checkpoint id; the model and its preprocessor are both loaded from this
# same id so preprocessing matches what the checkpoint was published with.
model_name = "microsoft/beit-base-patch16-224"

# Build the classifier and its feature extractor side by side (the tuple is
# evaluated left to right, exactly like two consecutive assignments).
model, feature_extractor = (
    AutoModelForImageClassification.from_pretrained(model_name),
    AutoFeatureExtractor.from_pretrained(model_name),
)

# Index -> human-readable label mapping stored in the model's config; used to
# turn the argmax class index into a display name.
labels = model.config.id2label
| |
|
| | |
def classify_image(image):
    """Classify a single image with the module-level model and describe the result.

    Uses the module-level ``feature_extractor`` for preprocessing, ``model`` for
    inference and ``labels`` to map the predicted index to a class name.

    Args:
        image: The image to classify, as a ``PIL.Image.Image`` or a NumPy array
            (presumably H x W x C as delivered by Gradio's image input -- TODO
            confirm). ``None`` (the user submitted without uploading an image)
            is handled and produces a prompt message instead of an error.

    Returns:
        ``"Predicted class: <label> (ID: <index>)"`` for the top-1 class, or a
        short message asking for an image when ``image`` is ``None``. Indices
        missing from ``labels`` are reported as ``"Unknown Class (ID: <index>)"``.
    """
    # Gradio passes None when the image input is empty; without this guard the
    # preprocessor call below raises and the UI shows only a generic error.
    if image is None:
        return "No image provided. Please upload an image."

    # The model is assumed to expect 3-channel RGB input (TODO confirm against
    # the checkpoint's preprocessor config). Normalize so RGBA, grayscale and
    # palette images don't reach the preprocessor with the wrong channel count
    # (a bare np.array() on a "P"-mode image would even yield palette indices).
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
    elif isinstance(image, np.ndarray):
        if image.ndim == 2:  # H x W grayscale -> H x W x 1
            image = image[..., np.newaxis]
        if image.ndim == 3 and image.shape[-1] == 1:  # single channel -> RGB
            image = np.repeat(image, 3, axis=-1)
        elif image.ndim == 3 and image.shape[-1] == 4:  # drop alpha channel
            image = image[..., :3]

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Top-1 class over the last (class) dimension; .item() assumes one image.
    predicted_class_idx = logits.argmax(-1).item()

    class_name = labels.get(
        predicted_class_idx, f"Unknown Class (ID: {predicted_class_idx})"
    )
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
| |
|
| | |
# Wire classify_image into a simple Gradio UI: one image input, one text output.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)

# Start the web server only when this file is run as a script, so importing
# the module (tests, other tooling) doesn't launch it. `demo` stays defined at
# module level because tools such as Gradio's reload mode look it up by name.
if __name__ == "__main__":
    demo.launch()
| |
|
| |
|