---
license: mit
datasets:
- PedroSampaio/fruits-360
language:
- en
base_model:
- google/efficientnet-b0
pipeline_tag: image-classification
tags:
- pytorch
- torchvision
- efficientnet
- image-classification
- fruits
- fruits-360
- transfer-learning
- neptune-ai
widget:
- src: https://images.unsplash.com/photo-1573246123790-a64e870b8b1a?ixlib=rb-1.2.1&auto=format&fit=crop&w=640
  example_title: Apple Example
- src: https://images.unsplash.com/photo-1528825871115-3581a5377919?ixlib=rb-1.2.1&auto=format&fit=crop&w=640
  example_title: Banana Example
---

[DEMO APP](https://huggingface.co/spaces/bhumong/fruit-classifier-app)

# Fruit Classifier - EfficientNet-B0 (Fruits-360 Merged)

This repository contains a fruit image classification model built by fine-tuning **EfficientNet-B0** with PyTorch and torchvision. The model was trained on the **Fruits-360 dataset**, modified so that specific fruit variants were merged into broader categories (e.g., "Apple Red 1" and "Apple 6" were merged into "Apple"), resulting in **76** distinct classes.
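
The card does not spell out the exact merge rule. As a purely hypothetical illustration, the sketch below keeps only the first word of each Fruits-360 folder name, which reproduces the "Apple Red 1" / "Apple 6" → "Apple" example, and then assigns deterministic integer ids to the merged labels. The function names `merge_class_name` and `build_merged_mapping` and the sample folder names are made up for this example.

```python
from collections import defaultdict


def merge_class_name(folder_name: str) -> str:
    """Hypothetical merge rule: keep only the first word of a folder name."""
    return folder_name.split()[0]


def build_merged_mapping(folder_names):
    """Group original folder names under their merged label and assign ids."""
    groups = defaultdict(list)
    for name in folder_names:
        groups[merge_class_name(name)].append(name)
    # Sort the merged labels so the id assignment is the same on every run
    label2id = {label: idx for idx, label in enumerate(sorted(groups))}
    return dict(groups), label2id


# Illustrative folder names only
folders = ["Apple Red 1", "Apple 6", "Banana", "Banana Red", "Cherry 1"]
groups, label2id = build_merged_mapping(folders)
print(groups)
print(label2id)
```

A first-word rule would mishandle multi-word fruit names, so a real merge would likely need manual exceptions; treat this only as an outline of the idea.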

Training progress and metrics were tracked with **Neptune.ai**.

## Model Description

* **Architecture:** EfficientNet-B0 (pre-trained on ImageNet)
* **Fine-tuning Strategy:** Transfer learning. The pre-trained backbone weights were frozen; only the final classifier layer was replaced and trained on the target dataset.
* **Framework:** PyTorch / torchvision
* **Task:** Image Classification
* **Dataset:** Fruits-360 (Merged Classes)
* **Number of Classes:** 76

## Intended Uses & Limitations

* **Intended Use:** Classifying images of fruits belonging to one of the 76 merged categories derived from the Fruits-360 dataset. Suitable for educational purposes, demonstrations, or as a baseline for further development.
* **Limitations:**
    * Trained *only* on the Fruits-360 dataset. Performance on images that differ significantly from this dataset (e.g., different lighting, backgrounds, occlusions, or fruit varieties not present in it) is not guaranteed.
    * Only recognizes the 76 merged classes it was trained on; every input, including non-fruit images, is assigned to one of these classes.
    * Performance may vary with input image quality.
    * Not intended for safety-critical applications without rigorous testing and validation.
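
Because the model always outputs one of its 76 classes, even for images of something else entirely, it can help to inspect the top few predictions and flag low-confidence results. The helper `top_k_with_threshold` below is hypothetical, and its `threshold=0.5` default is an arbitrary placeholder rather than a tuned value. Softmax confidence is only a rough signal, not a reliable out-of-distribution detector.

```python
import torch


def top_k_with_threshold(logits, id2label, k=3, threshold=0.5):
    """Return the top-k (label, probability) pairs and whether the best score
    reaches a (placeholder) confidence threshold."""
    probs = torch.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probs, k=k)
    # id2label loaded from config.json is keyed by strings, hence str(i)
    ranked = [
        (id2label.get(str(i), "Unknown Label"), p)
        for i, p in zip(top_ids.tolist(), top_probs.tolist())
    ]
    confident = ranked[0][1] >= threshold
    return ranked, confident


# Toy example: 4 made-up classes and hand-picked logits
toy_id2label = {"0": "Apple", "1": "Banana", "2": "Cherry", "3": "Kiwi"}
toy_logits = torch.tensor([2.0, 1.0, 0.1, -1.0])
ranked, confident = top_k_with_threshold(toy_logits, toy_id2label, k=2)
print(ranked, confident)
```

With the inference script in this card, the logits for the single input image are `output[0]`, so a call would look like `top_k_with_threshold(output[0], id2label)`.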

## How to Use

You can load the model and its configuration directly from the Hugging Face Hub using `torch`, `torchvision`, and `huggingface_hub` (plus `Pillow` for loading images).
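
The loading code in this section reads two fields from `config.json`: `num_labels` and `id2label`. Because JSON object keys are always strings, `id2label` is keyed by `"0"`, `"1"`, and so on, which is why predictions are looked up with `str(index)`. The sketch below parses a made-up two-class config (the real file is assumed to have the same shape with 76 entries), converts the keys to integers, and checks that the head size and label mapping agree.

```python
import json

# Hypothetical, truncated config with the same shape the loader expects
raw = '{"num_labels": 2, "id2label": {"0": "Apple", "1": "Banana"}}'
config = json.loads(raw)

# JSON keys are always strings; convert them to ints for direct indexing
id2label = {int(k): v for k, v in config["id2label"].items()}
label2id = {v: k for k, v in id2label.items()}

# Sanity check: the classifier head size must match the label mapping
assert config["num_labels"] == len(id2label), "num_labels and id2label disagree"

print(id2label[1])
print(label2id["Apple"])
```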

```python
import json
import os

import torch
import torchvision.models as models
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms


# --- 1. Define Model Loading Function ---
def load_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.json"):
    """Load the fine-tuned state_dict and its config from the Hugging Face Hub."""
    # Download and parse the config file
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, "r") as f:
        config = json.load(f)

    num_labels = config["num_labels"]
    id2label = config["id2label"]  # Label mapping; JSON keys are strings ("0", "1", ...)

    # Build the EfficientNet-B0 architecture without pre-trained weights,
    # since the fine-tuned weights are loaded from the Hub below
    model = models.efficientnet_b0(weights=None)

    # Replace the classifier head to match the number of classes used during training
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

    # Download the weights and load them on the CPU; the caller moves the model to a device
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()  # Evaluation mode: disables dropout, uses running BatchNorm statistics
    print(f"Model loaded successfully from {repo_id} and set to evaluation mode.")
    return model, config, id2label


# --- 2. Define Preprocessing ---
# Use the same transformations as validation during training
IMG_SIZE = (224, 224)  # Standard EfficientNet-B0 input size
# ImageNet statistics, as used for EfficientNet pre-training
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


# --- 3. Load Model ---
repo_id_to_load = "Bhumong/fruit-classifier-efficientnet-b0"  # Hub repository ID
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, config, id2label = load_model_from_hf(repo_id_to_load)
model.to(device)


# --- 4. Prepare Input Image ---
image_path = "path/to/your/fruit_image.jpg"  # <-- REPLACE WITH YOUR IMAGE PATH
input_batch = None

if not os.path.exists(image_path):
    print(f"Warning: Image path not found: {image_path}")
    print("Skipping prediction. Please provide a valid image path.")
else:
    try:
        img = Image.open(image_path).convert("RGB")
        input_tensor = preprocess(img)
        # Add a batch dimension, since the model expects batches: [3, H, W] -> [1, 3, H, W]
        input_batch = input_tensor.unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        input_batch = None


# --- 5. Make Prediction ---
if input_batch is not None:
    with torch.no_grad():  # Disable gradient calculations for inference
        output = model(input_batch)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_catid = torch.max(probabilities, dim=0)

    predicted_label_index = top_catid.item()
    # Use the id2label mapping from config.json (string keys)
    predicted_label = id2label.get(str(predicted_label_index), "Unknown Label")
    confidence = top_prob.item()

    print(f"\nPrediction for: {os.path.basename(image_path)}")
    print(f"Predicted Label Index: {predicted_label_index}")
    print(f"Predicted Label: {predicted_label}")
    print(f"Confidence: {confidence:.4f}")
```
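
In the preprocessing pipeline, `ToTensor()` scales 8-bit pixel values into [0, 1], and `Normalize` then computes `(x - mean) / std` separately for each channel using the ImageNet statistics. The stdlib-only sketch below reproduces that arithmetic for a single RGB pixel; `normalize_pixel` is a hypothetical helper written for illustration, not part of torchvision.

```python
# ImageNet channel statistics, matching the preprocessing pipeline
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Mimic ToTensor() followed by Normalize() for one 8-bit RGB pixel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))


white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 4) for v in white])
print([round(v, 4) for v in black])
```

Note also that passing a `(height, width)` tuple to `transforms.Resize` resizes the image to exactly that size, so the aspect ratio is not preserved.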