---
license: agpl-3.0
language:
- en
metrics:
- accuracy
pipeline_tag: image-classification
tags:
- digits
- cnn
- mnist
- emnist
- pytorch
- handwriting-recognition
- onnx
---

# Digit & Blank Image Classifier (PyTorch CNN)

A high-accuracy convolutional neural network (99.73% test accuracy) trained to classify handwritten digits from the **MNIST** and **EMNIST Digits** datasets and to detect **blank images** (unfilled boxes) as a distinct class. The model is trained with PyTorch and exported in TorchScript (`.pt`) and ONNX (`.onnx`) formats for reliable and portable inference.

---

## License & Attribution

This model is licensed under the **AGPL-3.0** license to comply with the [Plom Project](https://gitlab.com/plom/plom) licensing requirements.

### Developed as part of the Plom Project

**Authors & Credits**:

- Model: **Deep Shah**, Undergraduate Research Assistant, UBC
- Supervision: **Prof. Andrew Rechnitzer** and **Prof. Colin B. MacDonald**
- Project: [The Plom Project GitLab](https://gitlab.com/plom/plom)

---

## Overview

- **Input**: 1×28×28 grayscale image
- **Output**: 11 class logits; the argmax gives the integer class prediction:
  - 0–9: Digits
  - 10: Blank image
- **Architecture**: CNN with 3 convolutional layers, BatchNorm, ReLU, max pooling, dropout, and 2 fully connected layers
- **Model Format**: TorchScript (`.pt`), ONNX (`.onnx`)
- **Training Dataset**: Combined MNIST, EMNIST Digits, and 5,000 synthetic blank images

---

## Dataset Details

### Datasets Used:

- **MNIST**: 28×28 handwritten digits (0–9), 60,000 training images
- **EMNIST Digits**: 28×28 digits extracted from handwritten characters, 240,000+ training samples
- **Blank Images**: 5,000 synthetic all-black 28×28 images, labeled as class `10` to simulate unfilled regions

### Preprocessing:

- Normalized pixel values to [0, 1]
- Converted images to channel-first format (N, C, H, W)
- Combined and shuffled datasets
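
The preprocessing steps above can be sketched with NumPy roughly as follows. This is a minimal illustration, not the code from `train_digit_classifier.py`: the helper names (`preprocess`, `make_blanks`, `combine_and_shuffle`) are hypothetical, the input is assumed to be `uint8` arrays of shape (N, 28, 28), and random pixels stand in for the real MNIST/EMNIST data.

```python
import numpy as np

BLANK_CLASS = 10  # class index the model card assigns to blank images


def preprocess(images_uint8: np.ndarray) -> np.ndarray:
    """Scale uint8 images of shape (N, 28, 28) to [0, 1] and add a channel axis."""
    scaled = images_uint8.astype(np.float32) / 255.0
    return scaled[:, np.newaxis, :, :]  # (N, H, W) -> (N, C=1, H, W)


def make_blanks(count: int):
    """Create all-black 28x28 images labeled as the blank class."""
    images = np.zeros((count, 28, 28), dtype=np.uint8)
    labels = np.full(count, BLANK_CLASS, dtype=np.int64)
    return images, labels


def combine_and_shuffle(parts, seed: int = 42):
    """Concatenate (images, labels) pairs and shuffle both with one permutation."""
    images = np.concatenate([p[0] for p in parts])
    labels = np.concatenate([p[1] for p in parts])
    order = np.random.default_rng(seed).permutation(len(labels))
    return images[order], labels[order]


# Stand-in digit data (random pixels) in place of the real MNIST/EMNIST arrays.
rng = np.random.default_rng(0)
digits = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
digit_labels = rng.integers(0, 10, size=8, dtype=np.int64)

raw_images, labels = combine_and_shuffle([(digits, digit_labels), make_blanks(4)])
x = preprocess(raw_images)
print(x.shape, x.dtype, float(x.min()), float(x.max()))
```

Shuffling images and labels with a single permutation keeps each image paired with its label, which matters once the blank class is mixed in with the digit data.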

---

## Data Augmentation

To improve generalization and robustness to handwriting variation, the following augmentations were applied during training:

- `RandomRotation(±10°)`
- `RandomAffine`: scale (0.9–1.1), translate (±10%)

These transformations simulate the noise and variation found in real handwritten student submissions.

---

## Model Architecture

```
Input: (1, 28, 28)
→ Conv2D(1 → 32) + BatchNorm + ReLU
→ Conv2D(32 → 64) + BatchNorm + ReLU
→ MaxPool2d(2x2) + Dropout(0.1)
→ Conv2D(64 → 128) + BatchNorm + ReLU
→ MaxPool2d(2x2) + Dropout(0.1)
→ Flatten
→ Linear(128*7*7 → 128) + BatchNorm + ReLU + Dropout(0.2)
→ Linear(128 → 11)
→ Output: class logits (digits 0–9, blank = 10)
```
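
The layer list above could be written as a PyTorch module along these lines. This is a sketch, not the released implementation: the diagram gives no kernel sizes or padding, so `kernel_size=3` with `padding=1` is assumed here because it keeps the spatial size at 28×28 until pooling and yields the 128×7×7 features the first linear layer expects. The class name `DigitBlankCNN` and the use of plain `nn.Dropout` are also assumptions.

```python
import torch
from torch import nn


class DigitBlankCNN(nn.Module):
    """Sketch of the listed layers; kernel_size=3 and padding=1 are assumptions."""

    def __init__(self, num_classes: int = 11):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.1),             # 28x28 -> 14x14
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.1),             # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, num_classes),                  # logits for digits 0-9 and blank
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = DigitBlankCNN().eval()  # eval mode: dropout off, BatchNorm uses running statistics
with torch.no_grad():
    logits = model(torch.zeros(2, 1, 28, 28))
print(logits.shape)
```

Because the weights here are randomly initialized, the output is only useful for checking shapes; actual predictions require the released `.pth`, `.pt`, or `.onnx` files.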

---

## Training Configuration

| Hyperparameter | Value               |
| -------------- | ------------------- |
| Optimizer      | Adam (lr=0.001)     |
| Loss Function  | CrossEntropyLoss    |
| Scheduler      | ReduceLROnPlateau   |
| Early Stopping | Patience = 5        |
| Epochs         | Max 50              |
| Batch Size     | 64                  |
| Device         | CPU or CUDA         |
| Random Seed    | 42                  |
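
To illustrate how "Patience = 5" and "Max 50" epochs interact, here is a minimal early-stopping sketch in plain Python. It assumes the monitored quantity is validation loss, which the table does not state; the `EarlyStopping` class, its `min_delta` parameter, and the loss values are all hypothetical and are not taken from the training script.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta      # hypothetical tolerance, not listed in the table
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


MAX_EPOCHS = 50
# Made-up validation losses standing in for a real training run.
fake_val_losses = [0.30, 0.20, 0.15, 0.16, 0.17, 0.15, 0.18, 0.16] + [0.2] * 42

stopper = EarlyStopping(patience=5)
stopped_at = None
for epoch, val_loss in zip(range(1, MAX_EPOCHS + 1), fake_val_losses):
    # In a real loop: train one epoch, compute val_loss, then call scheduler.step(val_loss).
    if stopper.step(val_loss):
        stopped_at = epoch
        break

print("stopped at epoch", stopped_at, "best val loss", stopper.best)
```

In a full training loop, the same validation loss would typically also be passed to `ReduceLROnPlateau.step`, so the learning rate drops before early stopping ends the run.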

---

## Evaluation Results

| Metric               | Value     |
| -------------------- | --------- |
| Test Accuracy        | 99.73%    |
| Blank Image Accuracy | 100.00%   |

All 5,000 blank images were correctly classified.
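
The two metrics in the table can be computed from true and predicted labels as sketched below. The `accuracy` helper is hypothetical and the label lists are tiny made-up examples, not the model's real predictions.

```python
BLANK_CLASS = 10


def accuracy(y_true, y_pred, only_class=None):
    """Fraction of correct predictions, optionally restricted to one true class."""
    pairs = [(t, p) for t, p in zip(y_true, y_pred)
             if only_class is None or t == only_class]
    if not pairs:
        raise ValueError("no samples to evaluate")
    return sum(t == p for t, p in pairs) / len(pairs)


# Tiny made-up example, not the model's real predictions.
y_true = [3, 7, 10, 10, 1, 9]
y_pred = [3, 1, 10, 10, 1, 9]

print(f"overall: {accuracy(y_true, y_pred):.2%}")
print(f"blank:   {accuracy(y_true, y_pred, only_class=BLANK_CLASS):.2%}")
```

Restricting the calculation to samples whose true label is `10` gives the blank-image accuracy, which is a per-class recall rather than an overall accuracy.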

---

## Inference Examples

### 1. TorchScript (PyTorch)

```python
import torch

# Load TorchScript model
model = torch.jit.load("mnist_emnist_blank_cnn_v1.pt")
model.eval()

# Dummy input (1 image, 1 channel, 28x28) with values in [0, 1], as in preprocessing
img = torch.rand(1, 1, 28, 28)

# Predict
with torch.no_grad():
    out = model(img)
    predicted = out.argmax(dim=1).item()

print("Predicted class:", predicted)
```

### 2. ONNX (ONNX Runtime)

```python
import onnxruntime as ort
import numpy as np

# Load ONNX model
session = ort.InferenceSession("mnist_emnist_blank_cnn_v1.onnx", providers=["CPUExecutionProvider"])

# Read the input tensor name from the model instead of hard-coding it
input_name = session.get_inputs()[0].name

# Dummy input (1 image, 1 channel, 28x28) with values in [0, 1], as in preprocessing
img = np.random.rand(1, 1, 28, 28).astype(np.float32)

# Predict
outputs = session.run(None, {input_name: img})
predicted = int(outputs[0].argmax(axis=1)[0])

print("Predicted class:", predicted)
```

> If the prediction is `10`, the model considers the image to be blank (no digits present).
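
A small helper like the following could turn the integer class into a readable label in downstream code. The function name `describe_prediction` is hypothetical; only the class numbering (0 to 9 for digits, 10 for blank) comes from this card.

```python
BLANK_CLASS = 10


def describe_prediction(class_index: int) -> str:
    """Turn the model's integer class into a readable label."""
    if class_index == BLANK_CLASS:
        return "blank"
    if 0 <= class_index <= 9:
        return f"digit {class_index}"
    raise ValueError(f"unexpected class index: {class_index}")


for idx in (7, 10):
    print(idx, "->", describe_prediction(idx))
```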

---

## Included Files

- `train_digit_classifier.py`: Training script with full documentation
- `mnist_emnist_blank_cnn_v1.pth`: Final trained model weights
- `mnist_emnist_blank_cnn_v1.pt`: TorchScript export for deployment
- `mnist_emnist_blank_cnn_v1.onnx`: ONNX export for deployment
- `requirements.txt`: Required dependencies for training or inference

---

## Intended Use

This model was designed to support the Plom Project's student ID digit detection system, helping automatically identify handwritten digits (and detect blank/unfilled boxes) from scanned exam sheets.

It may also be adapted for other handwritten digit classification tasks or real-time blank field detection applications.

<!-- ---

## Maintainer & Contact

- **Deep Shah**: [Hugging Face Profile](https://huggingface.co/deepshah23)
- For Plom inquiries: [The Plom Project GitLab](https://gitlab.com/plom/plom) -->