---
license: gpl-3.0
language:
- en
metrics:
- accuracy
pipeline_tag: image-classification
tags:
- digits
- cnn
- mnist
- emnist
- pytorch
- handwriting-recognition
- onnx
---
# Digit & Blank Image Classifier (PyTorch CNN)
A high-accuracy convolutional neural network (CNN) trained to classify handwritten digits from the **MNIST** and **EMNIST Digits** datasets and to detect **blank images** (unfilled boxes) as a separate, eleventh class. The model is trained with PyTorch and exported in TorchScript (`.pt`) and ONNX (`.onnx`) formats for reliable, portable inference.
---
## License & Attribution
This model is licensed under the **AGPL-3.0** license to comply with the [Plom Project](https://gitlab.com/plom/plom) licensing requirements.
### Developed as part of the Plom Project
**Authors & Credits**:
- Model: **Deep Shah**, Undergraduate Research Assistant, UBC
- Supervision: **Prof. Andrew Rechnitzer** and **Prof. Colin B. MacDonald**
- Project: [The Plom Project GitLab](https://gitlab.com/plom/plom)
---
## Overview
- **Input**: 1×28×28 grayscale image
- **Output**: 11 class logits; the argmax is the predicted class:
  - 0–9: Digits
  - 10: Blank image
- **Architecture**: CNN with three convolutional layers (each followed by BatchNorm and ReLU), max pooling and dropout, then two fully connected layers
- **Model Format**: TorchScript (`.pt`), ONNX (`.onnx`)
- **Training Dataset**: Combined MNIST, EMNIST Digits, and 5,000 synthetic blank images
---
## Dataset Details
### Datasets Used:
- **MNIST** – 28×28 handwritten digits (0–9), 60,000 training images
- **EMNIST Digits** – 28×28 digits extracted from handwritten characters, 240,000+ training samples
- **Blank Images** – 5,000 synthetic all-black 28×28 images, labeled as class `10` to simulate unfilled regions
### Preprocessing:
- Normalized pixel values to [0, 1]
- Converted images to channel-first format (N, C, H, W)
- Combined and shuffled datasets
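The combining and shuffling step could look roughly like the sketch below. It is a hypothetical reconstruction, not the code in `train_digit_classifier.py`: it assumes each digit dataset yields `(tensor, int)` pairs already scaled to [0, 1] (as torchvision datasets do with `ToTensor()`), and the names `BlankDataset` and `make_loader` are illustrative.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset

BLANK_LABEL = 10  # class index reserved for blank (unfilled) boxes


class BlankDataset(Dataset):
    """Hypothetical dataset of all-black 28x28 images labelled as class 10."""

    def __init__(self, n: int = 5000):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx):
        # Channel-first (C, H, W) zeros, matching the normalized digit images
        return torch.zeros(1, 28, 28), BLANK_LABEL


def make_loader(digit_datasets, num_blank=5000, batch_size=64, seed=42):
    """Concatenate the digit datasets with synthetic blanks and shuffle reproducibly."""
    combined = ConcatDataset([*digit_datasets, BlankDataset(num_blank)])
    gen = torch.Generator().manual_seed(seed)
    return DataLoader(combined, batch_size=batch_size, shuffle=True, generator=gen)


# Stand-in for MNIST/EMNIST: 100 random "digits" with integer labels 0-9
fake_digits = [(torch.rand(1, 28, 28), i % 10) for i in range(100)]
loader = make_loader([fake_digits], num_blank=50)
labels = torch.cat([yb for _, yb in loader])
print(len(loader.dataset), int((labels == BLANK_LABEL).sum()))  # 150 50
```

In real use, the stand-in list would be replaced by the MNIST and EMNIST Digits training sets.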
---
## Data Augmentation
To improve generalization and robustness to handwriting variation:
- `RandomRotation(±10°)`
- `RandomAffine`: scale (0.9–1.1), translate (±10%)
These transformations simulate the handwriting noise and variation found in real student submissions.
---
## Model Architecture
```
Input: (1, 28, 28)
↓ Conv2D(1 → 32) + BatchNorm + ReLU
↓ Conv2D(32 → 64) + BatchNorm + ReLU
↓ MaxPool2d(2x2) + Dropout(0.1)
↓ Conv2D(64 → 128) + BatchNorm + ReLU
↓ MaxPool2d(2x2) + Dropout(0.1)
↓ Flatten
↓ Linear(128*7*7 → 128) + BatchNorm + ReLU + Dropout(0.2)
↓ Linear(128 → 11)
→ Output: class logits (digits 0–9, blank = 10)
```
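The diagram can be translated into an `nn.Module` along these lines. Kernel sizes and padding are not given above, so 3×3 kernels with `padding=1` are an assumption; that choice keeps the feature maps at 28×28 until pooling, which is what makes the `128*7*7` flatten size work out after two 2×2 pools. The class name `DigitBlankCNN` is hypothetical, and its parameter names will not necessarily match the keys in `mnist_emnist_blank_cnn_v1.pth`.

```python
import torch
import torch.nn as nn


class DigitBlankCNN(nn.Module):
    """Hypothetical re-creation of the architecture diagram (3x3 kernels assumed)."""

    def __init__(self, num_classes: int = 11):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.1),  # 28x28 -> 14x14
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.1),  # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, num_classes),  # logits: digits 0-9, blank = 10
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = DigitBlankCNN().eval()
with torch.no_grad():
    logits = model(torch.rand(2, 1, 28, 28))
print(logits.shape)  # torch.Size([2, 11])
```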
---
## Training Configuration
| Hyperparameter | Value |
| -------------- | ------------------- |
| Optimizer | Adam (lr=0.001) |
| Loss Function | CrossEntropyLoss |
| Scheduler | ReduceLROnPlateau |
| Early Stopping | Patience = 5 |
| Epochs | Max 50 |
| Batch Size | 64 |
| Device | CPU or CUDA |
| Random Seed | 42 |
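A training loop consistent with this table might look like the sketch below. Several details are assumptions: early stopping and `ReduceLROnPlateau` are both driven by validation loss, the scheduler uses PyTorch's default factor and patience, and the best weights are restored at the end. The demo uses random data and a placeholder linear model rather than the real CNN and datasets.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train(model, train_loader, val_loader, device="cpu", max_epochs=50, patience=5):
    """Adam + CrossEntropyLoss with LR reduction and early stopping on validation loss."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    best_loss, best_state, bad_epochs = float("inf"), None, 0

    for _ in range(max_epochs):
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        # Average validation loss over all samples
        model.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                total += criterion(model(x), y).item() * y.size(0)
                n += y.size(0)
        val_loss = total / n
        scheduler.step(val_loss)

        if val_loss < best_loss:  # improvement: remember these weights
            best_loss, best_state, bad_epochs = val_loss, copy.deepcopy(model.state_dict()), 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:  # no improvement for `patience` epochs
                break

    model.load_state_dict(best_state)
    return best_loss


torch.manual_seed(42)
x = torch.rand(256, 1, 28, 28)
y = torch.randint(0, 11, (256,))
train_loader = DataLoader(TensorDataset(x[:192], y[:192]), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(x[192:], y[192:]), batch_size=64)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 11))  # placeholder, not the real CNN
best = train(toy_model, train_loader, val_loader, max_epochs=3)
print(f"best validation loss: {best:.4f}")
```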
---
## Evaluation Results
| Metric | Value |
| -------------------- | --------- |
| Test Accuracy | 99.73% |
| Blank Image Accuracy | 100.00% |
All 5,000 blank images were correctly classified.
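Both metrics can be computed from model logits as shown below. Here "Blank Image Accuracy" is interpreted as the fraction of images whose true label is `10` that are also predicted as `10`; that reading is an assumption, and the helper `accuracy_report` is hypothetical.

```python
import torch

BLANK = 10  # class index used for blank images


def accuracy_report(logits: torch.Tensor, labels: torch.Tensor) -> dict:
    """Overall accuracy plus accuracy restricted to truly blank samples."""
    preds = logits.argmax(dim=1)
    overall = (preds == labels).float().mean().item()
    is_blank = labels == BLANK
    blank_acc = (preds[is_blank] == BLANK).float().mean().item() if is_blank.any() else float("nan")
    return {"test_accuracy": overall, "blank_accuracy": blank_acc}


# Toy check: 4 samples, one digit misclassified, both blanks correct
labels = torch.tensor([3, 10, 10, 7])
preds = torch.tensor([3, 10, 10, 1])
logits = torch.zeros(4, 11)
logits[torch.arange(4), preds] = 1.0  # make each chosen class the argmax
report = accuracy_report(logits, labels)
print(report)  # {'test_accuracy': 0.75, 'blank_accuracy': 1.0}
```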
---
## Inference Examples
### 1. TorchScript (PyTorch)
```python
import torch

# Load the TorchScript model
model = torch.jit.load("mnist_emnist_blank_cnn_v1.pt", map_location="cpu")
model.eval()

# Dummy input: 1 image, 1 channel, 28x28, pixel values in [0, 1] like the training data
img = torch.rand(1, 1, 28, 28)

# Predict: the class with the highest logit
with torch.no_grad():
    out = model(img)
predicted = out.argmax(dim=1).item()
print("Predicted class:", predicted)
```
### 2. ONNX (ONNX Runtime)
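```python
import numpy as np
import onnxruntime as ort

# Load the ONNX model
session = ort.InferenceSession("mnist_emnist_blank_cnn_v1.onnx", providers=["CPUExecutionProvider"])
# Look up the model's input name instead of hard-coding it
input_name = session.get_inputs()[0].name

# Dummy input: 1 image, 1 channel, 28x28, float32 pixel values in [0, 1]
img = np.random.rand(1, 1, 28, 28).astype(np.float32)

# Predict: the class with the highest logit
outputs = session.run(None, {input_name: img})
predicted = int(outputs[0].argmax(axis=1)[0])
print("Predicted class:", predicted)
```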
> If the prediction is `10`, the model considers the image to be blank (no digits present).
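For a real scan rather than random noise, the image has to match the training format: a 1×1×28×28 float32 array scaled to [0, 1]. The sketch below assumes Pillow and NumPy. It also assumes that scanned boxes have dark ink on a light background and therefore inverts them to match MNIST's light-on-dark convention, under which blanks are all-black. The helper name `preprocess` is hypothetical, and Plom's own pipeline may crop or center digits differently.

```python
import numpy as np
from PIL import Image


def preprocess(img: Image.Image, invert: bool = True) -> np.ndarray:
    """Convert a scanned box to a (1, 1, 28, 28) float32 array in [0, 1]."""
    gray = img.convert("L").resize((28, 28))
    arr = np.asarray(gray, dtype=np.float32) / 255.0
    if invert:
        # Assumed: dark ink on white paper -> light digit on black background
        arr = 1.0 - arr
    return arr[np.newaxis, np.newaxis, :, :]  # (N, C, H, W)


# An empty white box becomes an all-zero input, like the synthetic blank class
x = preprocess(Image.new("L", (100, 100), color=255))
print(x.shape, x.dtype)  # (1, 1, 28, 28) float32
```

The resulting array can be passed directly to the ONNX session, or wrapped with `torch.from_numpy(x)` for the TorchScript model.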
---
## Included Files
- `train_digit_classifier.py`: Training script with full documentation
- `mnist_emnist_blank_cnn_v1.pth`: Final trained model weights
- `mnist_emnist_blank_cnn_v1.pt`: TorchScript export for deployment
- `mnist_emnist_blank_cnn_v1.onnx`: ONNX export for deployment
- `requirements.txt`: Required dependencies for training or inference
---
## Intended Use
This model was designed to support the Plom Project’s student ID digit detection system, helping automatically identify handwritten digits (and detect blank/unfilled boxes) from scanned exam sheets.
It may also be adapted for other handwritten digit classification tasks or real-time blank field detection applications.
<!-- ---
## Maintainer & Contact
- **Deep Shah**: [Hugging Face Profile](https://huggingface.co/deepshah23)
- For Plom inquiries: [The Plom Project GitLab](https://gitlab.com/plom/plom) -->