import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

# Number of samples generated per request; also sizes the output components.
NUM_IMAGES = 5


# Conditional VAE definition (same as training)
class CVAE(torch.nn.Module):
    """Conditional VAE over flattened 28x28 images, conditioned on a digit label.

    Only ``decode`` is used by this app. The encoder layers are kept so the
    training checkpoint's state dict loads without missing/unexpected keys,
    which is why attribute names and layer order must stay unchanged.

    Args:
        latent_dim: Size of the latent vector ``z`` fed to the decoder.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        # 10 labels, each mapped to a 10-dim learned embedding.
        self.label_embed = torch.nn.Embedding(10, 10)
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28 + 10, 400),
            torch.nn.ReLU(),
        )
        self.fc_mu = torch.nn.Linear(400, latent_dim)
        self.fc_logvar = torch.nn.Linear(400, latent_dim)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim + 10, 400),
            torch.nn.ReLU(),
            torch.nn.Linear(400, 28 * 28),
            torch.nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        """Decode latents ``z`` conditioned on integer labels ``y``.

        Args:
            z: Float tensor of shape (batch, latent_dim).
            y: Long tensor of shape (batch,) with label indices in [0, 10).

        Returns:
            Tensor of shape (batch, 784) with values in [0, 1].
        """
        y_embed = self.label_embed(y)
        inputs = torch.cat([z, y_embed], dim=1)
        return self.decoder(inputs)


model = CVAE()
# NOTE: torch.load unpickles the file, so only load a trusted checkpoint.
# On torch >= 1.13 consider passing weights_only=True (the default from 2.6).
model.load_state_dict(torch.load("cvae_mnist.pth", map_location='cpu'))
model.eval()


# Image generation function
def generate_digit_images(digit):
    """Sample ``NUM_IMAGES`` grayscale images of ``digit`` from the CVAE prior.

    Args:
        digit: The selected digit, as delivered by the dropdown (a string
            "0"-"9"). Any value ``int()`` accepts in the label range works.

    Returns:
        A list of ``NUM_IMAGES`` uint8 arrays of shape (28, 28), one per
        Gradio output component.

    Raises:
        gr.Error: If no digit was selected or it is outside the label range.
    """
    num_classes = model.label_embed.num_embeddings
    try:
        label = int(digit)
    except (TypeError, ValueError):
        # e.g. the dropdown was left empty and Gradio passed None.
        raise gr.Error("Please choose a digit (0–9).") from None
    if not 0 <= label < num_classes:
        # Out-of-range labels would otherwise crash inside the embedding.
        raise gr.Error("Please choose a digit (0–9).")

    # One batched decoder pass instead of NUM_IMAGES single-sample passes.
    z = torch.randn(NUM_IMAGES, model.latent_dim)
    y = torch.full((NUM_IMAGES,), label, dtype=torch.long)
    with torch.no_grad():
        out = model.decode(z, y)

    # Scale [0, 1] intensities to 8-bit grayscale, then split per image.
    batch = (out.view(NUM_IMAGES, 28, 28).numpy() * 255).astype(np.uint8)
    return list(batch)


# Launch Gradio app
iface = gr.Interface(
    fn=generate_digit_images,
    inputs=gr.Dropdown(choices=[str(i) for i in range(10)], label="Choose a digit (0–9)"),
    outputs=[gr.Image(image_mode='L') for _ in range(NUM_IMAGES)],
    title="Conditional VAE Handwritten Digit Generator",
    description=(
        f"Generates {NUM_IMAGES} images of the digit you select (0–9) using a "
        "Conditional Variational Autoencoder trained on MNIST."
    ),
)
iface.launch()