Training dataset: PedroSampaio/fruits-360 (Hugging Face dataset, 90.4k rows)
This repository contains a fruit image classification model based on a fine-tuned EfficientNet-B0 architecture, built with PyTorch and torchvision. The model was trained on the Fruits-360 dataset, modified so that specific fruit variants were merged into broader categories (e.g., "Apple Red 1" and "Apple 6" were merged into "Apple"). This results in 76 distinct classes; the authoritative count is the `num_labels` value in the repository's `config.json`.
Training progress and metrics were tracked using Neptune.ai.
You can load the model and its configuration directly from the Hugging Face Hub using torch, torchvision, and huggingface_hub.
import torch
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights # Or the specific version used
from PIL import Image
from torchvision import transforms
import json
import requests
from huggingface_hub import hf_hub_download
import os
# --- 1. Define Model Loading Function ---
def load_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.json"):
    """Load the fine-tuned EfficientNet-B0 classifier and its config from the Hugging Face Hub.

    Args:
        repo_id: Hub repository ID (``"user/name"``) holding the weights and the config.
        model_filename: Name of the state-dict file inside the repository.
        config_filename: Name of the JSON config file inside the repository. It must
            contain the keys ``num_labels`` and ``id2label``.

    Returns:
        A ``(model, config, id2label)`` tuple:
        - the model, in evaluation mode and on the CPU (move it with ``model.to(device)``);
        - the parsed config dict;
        - the ``id2label`` mapping taken from the config.

    Raises:
        KeyError: If the config lacks ``num_labels`` or ``id2label``.
    """
    # Download and parse the config that describes the classification head.
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    num_labels = config["num_labels"]
    # Loaded from JSON, so if this is a JSON object its keys are strings.
    id2label = config["id2label"]

    # Build the bare architecture. The fine-tuned weights are loaded below,
    # so no pre-trained weights are requested here.
    model = models.efficientnet_b0(weights=None)
    # Replace the final linear layer so its output size matches the label count used in training.
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    # Map to the CPU because the freshly built model lives there; mapping to CUDA
    # would only allocate GPU memory for a temporary copy.
    # weights_only=True refuses arbitrary pickled objects, which matters because
    # this file comes from a remote repository.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # Disable dropout / use running batch-norm statistics for inference.
    print(f"Model loaded successfully from {repo_id} and set to evaluation mode.")
    return model, config, id2label
# --- 2. Define Preprocessing ---
# These must mirror the validation-time transforms used during training.
IMG_SIZE = (224, 224)  # (height, width) fed to EfficientNet-B0
# Per-channel RGB normalization statistics from ImageNet, commonly used with EfficientNet.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        # Resize to exactly IMG_SIZE; a (h, w) tuple does not preserve aspect ratio.
        transforms.Resize(IMG_SIZE),
        # PIL image -> float tensor of shape (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
# --- 3. Load Model ---
# Hub repository hosting the fine-tuned weights and config.json (replace with your repo ID).
repo_id_to_load = "Bhumong/fruit-classifier-efficientnet-b0"
model, config, id2label = load_model_from_hf(repo_id_to_load)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)  # nn.Module.to moves parameters in place and returns the same module
# --- 4. Prepare Input Image ---
# Example: Load an image file (replace with your image path)
image_path = "path/to/your/fruit_image.jpg"  # <-- REPLACE WITH YOUR IMAGE PATH
# input_batch stays None whenever the image cannot be loaded; the prediction step checks for that.
if not os.path.isfile(image_path):
    print(f"Warning: Image path not found: {image_path}")
    print("Skipping prediction. Please provide a valid image path.")
    input_batch = None
else:
    try:
        # The with-block closes the file handle that Image.open keeps open.
        # convert() returns a new, fully loaded image, so img stays usable
        # after the file is closed.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert("RGB")
        input_tensor = preprocess(img)
        # Add a leading batch dimension of size 1 (the model expects batches).
        input_batch = input_tensor.unsqueeze(0).to(device)
    except Exception as e:  # Script boundary: report the problem and skip prediction.
        print(f"Error processing image {image_path}: {e}")
        input_batch = None
# --- 5. Make Prediction ---
if input_batch is not None:
    with torch.no_grad():  # Inference only, so autograd tracking is unnecessary.
        logits = model(input_batch)
        # Class probabilities for the first (and only) image in the batch.
        probabilities = logits[0].softmax(dim=0)
        top_prob, top_catid = probabilities.max(dim=0)
    confidence = top_prob.item()
    predicted_label_index = top_catid.item()
    # id2label was parsed from config.json, whose object keys are strings, hence str().
    predicted_label = id2label.get(str(predicted_label_index), "Unknown Label")

    print(f"\nPrediction for: {os.path.basename(image_path)}")
    print(f"Predicted Label Index: {predicted_label_index}")
    print(f"Predicted Label: {predicted_label}")
    print(f"Confidence: {confidence:.4f}")
Base model: google/efficientnet-b0 (the loading code above builds the matching architecture with torchvision's `efficientnet_b0`)