# mnist / app.py
# Author: debojit01 — commit 388fa58 ("Update app.py", verified)
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
# Conditional VAE definition (same as training)
class CVAE(torch.nn.Module):
    """Conditional variational autoencoder for 28x28 MNIST digits.

    This app only uses the decoding path: ``decode`` maps a latent vector
    plus a class label to a flattened image. The encoder layers
    (``encoder``, ``fc_mu``, ``fc_logvar``) are still declared because the
    saved training checkpoint presumably contains their weights, and
    ``load_state_dict`` is strict by default. No ``encode``/``forward``
    method is defined here.

    Every size argument defaults to the original hard-coded value. Building
    the class with defaults therefore yields the same parameter names and
    shapes as before, which is what checkpoint loading relies on.

    Args:
        latent_dim: Size of the latent vector ``z``.
        num_classes: Number of distinct labels (10 digits for MNIST).
        embed_dim: Size of the learned label embedding.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        image_dim: Number of pixels in a flattened image (28 * 28).
    """

    def __init__(self, latent_dim=20, num_classes=10, embed_dim=10,
                 hidden_dim=400, image_dim=28 * 28):
        super().__init__()
        self.latent_dim = latent_dim
        self.label_embed = torch.nn.Embedding(num_classes, embed_dim)
        # Input width fits a flattened image plus the label embedding
        # (presumably concatenated as in `decode`; confirm against training).
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(image_dim + embed_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        self.fc_mu = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = torch.nn.Linear(hidden_dim, latent_dim)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim + embed_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, image_dim),
            torch.nn.Sigmoid(),  # squashes pixel intensities into [0, 1]
        )

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode latent vectors conditioned on class labels.

        Args:
            z: Float tensor of shape (batch, latent_dim).
            y: Integer label tensor of shape (batch,), values in
                [0, num_classes).

        Returns:
            Tensor of shape (batch, image_dim) with values in [0, 1].
        """
        y_embed = self.label_embed(y)
        inputs = torch.cat([z, y_embed], dim=1)
        return self.decoder(inputs)
# Load the trained weights once at import time and switch to inference mode.
# The checkpoint is resolved next to this script rather than relative to the
# current working directory, so the app also starts when launched from
# another folder. The result of torch.load is passed straight to
# load_state_dict, so the file is expected to hold a bare state_dict;
# weights_only=True restricts unpickling to tensors and plain containers
# instead of executing arbitrary pickled code.
CHECKPOINT_PATH = Path(__file__).resolve().parent / "cvae_mnist.pth"
model = CVAE()
model.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
)
model.eval()
# Image generation function
def generate_digit_images(digit, num_images=5, net=None):
    """Sample ``num_images`` handwritten images of ``digit`` from the CVAE.

    Latent vectors come from a standard normal prior. They are decoded in a
    single batched forward pass, all conditioned on the chosen label.

    Args:
        digit: The selected digit; anything ``int()`` accepts ("7", 7).
            The dropdown's choices are strings; it may pass None when
            nothing has been selected.
        num_images: How many samples to return. The default of 5 matches
            the five ``gr.Image`` outputs of the interface.
        net: Model exposing ``latent_dim`` and ``decode(z, y)``. Defaults to
            the module-level ``model``; mainly useful for testing.

    Returns:
        A list of ``num_images`` uint8 arrays of shape (28, 28) holding
        grey levels 0-255.

    Raises:
        ValueError: If no digit was chosen, it is not an integer, or it is
            outside 0-9.
    """
    if net is None:
        net = model
    if digit is None or str(digit).strip() == "":
        raise ValueError("Please choose a digit (0-9).")
    label = int(digit)
    if not 0 <= label <= 9:
        raise ValueError(f"Digit must be between 0 and 9, got {label}.")

    # Size z from the model instead of hard-coding 20, so it stays in sync
    # with the network's latent_dim.
    z = torch.randn(num_images, net.latent_dim)
    y = torch.full((num_images,), label, dtype=torch.long)
    with torch.no_grad():
        out = net.decode(z, y)

    # The decoder ends in a sigmoid, so values should already lie in [0, 1].
    # Clip defensively anyway, then round (not truncate) to 8-bit grey levels.
    pixels = out.reshape(num_images, 28, 28).numpy()
    pixels = np.rint(np.clip(pixels, 0.0, 1.0) * 255).astype(np.uint8)
    return list(pixels)
# Launch Gradio app
# The five gr.Image outputs receive, in order, the five arrays returned by
# generate_digit_images for the selected dropdown value.
iface = gr.Interface(
    fn=generate_digit_images,
    inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (0–9)"),
    outputs=[gr.Image(image_mode='L') for _ in range(5)],
    title="Conditional VAE Handwritten Digit Generator",
    description="Generates 5 images of the digit you select (0–9) using a Conditional Variational Autoencoder trained on MNIST."
)

# Start the web server only when this file is run as a script
# (`python app.py`), so importing the module (from a test or another
# script) does not block on a running server.
if __name__ == "__main__":
    iface.launch()