# Harshb11 — commit be51ac0 ("Create app.py", verified).
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import gradio as gr
import numpy as np
from PIL import Image
# ANN model
class ANN(nn.Module):
    """Fully connected MNIST classifier: 784 -> 128 -> 128 -> 128 -> 10.

    Three hidden layers with ReLU activations. The last layer returns raw
    logits (no softmax), which is the input format ``nn.CrossEntropyLoss``
    expects.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map images to class logits.

        Args:
            x: Tensor whose elements can be grouped into rows of
                28 * 28 = 784 values, e.g. an ``(N, 1, 28, 28)`` image batch.

        Returns:
            Tensor of shape ``(N, 10)`` with unnormalised class scores.
        """
        # reshape (not view): view() raises on non-contiguous inputs such as
        # transposed or sliced tensors, reshape copies only when it must.
        x = x.reshape(-1, 28 * 28)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        return self.fc4(x)
# Load dataset. ToTensor() turns each image into a float tensor in [0, 1];
# Normalize with mean 0.5 / std 0.5 then maps it to [-1, 1], i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=128,
    shuffle=True,
)

# Model, loss and optimiser shared (as module globals) with train_model().
model = ANN()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
def train_model(epochs=1):
    """Train the module-level ``model`` on ``trainloader`` for ``epochs`` passes.

    After each epoch, prints the mean per-batch loss. When training finishes,
    the model is switched back to evaluation mode.
    """
    model.train()
    for epoch_no in range(1, epochs + 1):
        epoch_loss = 0.0
        for batch_images, batch_labels in trainloader:
            optimizer.zero_grad()
            logits = model(batch_images)
            batch_loss = loss_fn(logits, batch_labels)
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()
        print(f"Epoch {epoch_no}, Loss: {epoch_loss / len(trainloader):.4f}")
    model.eval()
# Train once at import time (a single pass over MNIST) so the Gradio demo
# below starts with usable weights.
train_model(1)
# Inference
def predict_digit(img, net=None):
    """Classify a hand-drawn digit and return a human-readable prediction.

    By default, Gradio's ``gr.Image`` passes the callback a NumPy array
    (``type="numpy"``), not a PIL image. The input is therefore first
    normalised to a 2-D grayscale float array.

    Accepted inputs:
        - PIL images;
        - ``(H, W)`` and ``(H, W, C)`` arrays;
        - dict payloads carrying an ``"image"`` or ``"composite"`` entry.

    Args:
        img: The drawn image, or ``None`` when the canvas is empty.
        net: Optional classifier to use instead of the module-level ``model``
            (useful for testing). It receives a ``(1, 1, 28, 28)`` float
            tensor scaled to [-1, 1] and must return ``(1, 10)`` logits.

    Returns:
        ``"Prediction: <digit>"``, or a prompt to draw something when no
        image was supplied.

    Raises:
        ValueError: If the image is neither 2-D nor 3-D.
    """
    if isinstance(img, dict):
        # Sketch/editor-style components may wrap the picture in a dict payload.
        img = next(
            (img[key] for key in ("composite", "image") if img.get(key) is not None),
            None,
        )
    if img is None:
        return "Please draw a digit first."
    if hasattr(img, "convert"):
        img = img.convert("L")  # PIL image -> single-channel grayscale

    pixels = np.asarray(img, dtype=np.float32)
    if pixels.ndim == 3:
        if pixels.shape[-1] in (1, 2):
            # Grayscale, optionally with alpha: keep the intensity channel.
            pixels = pixels[..., 0]
        else:
            # RGB(A): ITU-R 601-2 luma weights, the same ones PIL uses for mode "L".
            pixels = pixels[..., :3] @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
    if pixels.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {pixels.shape}")

    img_tensor = torch.tensor(pixels, dtype=torch.float32)[None, None]  # (1, 1, H, W)
    # Box-average down to MNIST resolution; averaging keeps thin strokes visible.
    img_tensor = torch.nn.functional.interpolate(img_tensor, size=(28, 28), mode="area")
    # Match training preprocessing: ToTensor + Normalize((0.5,), (0.5,)) maps [0, 255] to [-1, 1].
    img_tensor = (img_tensor - 127.5) / 127.5

    classifier = model if net is None else net
    with torch.no_grad():
        output = classifier(img_tensor)
    _, predicted = torch.max(output, 1)
    return f"Prediction: {predicted.item()}"
# NOTE(review): `shape`, `source` and `invert_colors` are Gradio 3.x arguments of
# gr.Image that appear to have been removed in Gradio 4.0 — keep gradio pinned
# to 3.x or port this call; confirm against the deployed version.
demo = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(
        image_mode="L",
        shape=(280, 280),
        # Presumably flips dark-on-light canvas strokes to light-on-dark,
        # matching MNIST's white digits on black — confirm against the canvas.
        invert_colors=True,
        source="canvas",
    ),
    outputs="text",
    title="MNIST Digit Recognizer (MLP)",
    description="Draw a digit and the model will try to predict it after training for 1 epoch on MNIST.",
)
demo.launch()