|
import torch |
|
import numpy as np |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class CVAE(torch.nn.Module):
    """Conditional VAE over flattened 28x28 images with a 10-way label embedding.

    Submodule names and the positions inside each ``Sequential`` determine the
    ``state_dict`` keys, so they must stay as they are for saved checkpoints
    to load:

    * ``label_embed`` -- ``Embedding(10, 10)``: class index -> 10-d vector
    * ``encoder``     -- ``Linear(28*28 + 10, 400)`` followed by ``ReLU``
    * ``fc_mu``       -- ``Linear(400, latent_dim)``
    * ``fc_logvar``   -- ``Linear(400, latent_dim)``
    * ``decoder``     -- ``Linear(latent_dim + 10, 400)``, ``ReLU``,
      ``Linear(400, 28*28)``, ``Sigmoid``

    No ``forward`` is defined here; callers use ``decode`` directly.
    """

    def __init__(self, latent_dim=20):
        """Build all layers; ``latent_dim`` is the width of the latent code."""
        super().__init__()
        nn = torch.nn
        n_pixels = 28 * 28
        label_dim = 10
        hidden = 400

        self.latent_dim = latent_dim
        self.label_embed = nn.Embedding(10, label_dim)

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed is unchanged.
        encoder_layers = [nn.Linear(n_pixels + label_dim, hidden), nn.ReLU()]
        self.encoder = nn.Sequential(*encoder_layers)
        self.fc_mu = nn.Linear(hidden, latent_dim)
        self.fc_logvar = nn.Linear(hidden, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim + label_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_pixels),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``std = exp(0.5 * logvar)``.

        ``eps`` is standard-normal noise with the same shape as ``std``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z, y):
        """Decode latent codes ``z`` conditioned on integer labels ``y``.

        The label embedding is concatenated to ``z`` along dim 1 and the result
        is passed through ``self.decoder``, whose final layer is a sigmoid.
        """
        conditioned = torch.cat((z, self.label_embed(y)), dim=1)
        return self.decoder(conditioned)
|
|
|
# Build the network with its default latent size and restore the trained
# weights onto the CPU.
model = CVAE()

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code while it is being loaded.
# The file is passed straight to load_state_dict, so it is expected to hold a
# state_dict rather than a pickled module.
model.load_state_dict(
    torch.load("cvae_mnist.pth", map_location="cpu", weights_only=True)
)

# Put the module in evaluation mode; the app only samples from it.
model.eval()
|
|
|
|
|
def generate_digit_images(digit, num_images=5):
    """Sample images of one digit from the module-level ``model``.

    Args:
        digit: Class label to condition on. Anything ``int()`` accepts works;
            the dropdown in this file passes strings such as ``"7"``.
        num_images: How many samples to draw. Defaults to 5, which matches
            the five image outputs of the Gradio interface in this file.

    Returns:
        A list of ``num_images`` arrays of dtype ``uint8`` and shape
        ``(28, 28)``, each the decoder output scaled from [0, 1] to [0, 255].

    Raises:
        gr.Error: If ``digit`` is missing, is not an integer, or lies outside
            the label range of the model's embedding table. Gradio shows the
            message to the user instead of a bare traceback.
    """
    try:
        label = int(digit)
    except (TypeError, ValueError):
        # A dropdown with nothing selected passes None.
        raise gr.Error("Please choose a digit before generating images.") from None

    num_classes = model.label_embed.num_embeddings
    if not 0 <= label < num_classes:
        raise gr.Error(f"Digit must be between 0 and {num_classes - 1}.")

    # Draw every latent code at once and decode them as a single batch; the
    # latent width comes from the model rather than a hard-coded 20.
    z = torch.randn(num_images, model.latent_dim)
    y = torch.full((num_images,), label, dtype=torch.long)
    with torch.no_grad():
        out = model.decode(z, y)

    pixels = (out.view(num_images, 28, 28).numpy() * 255).astype(np.uint8)
    return list(pixels)
|
|
|
|
|
# Gradio UI: one dropdown for the digit and five grayscale image outputs,
# one per array returned by generate_digit_images.
iface = gr.Interface(
    fn=generate_digit_images,
    inputs=gr.Dropdown(
        choices=[str(i) for i in range(10)],
        # Explicit initial selection so the first submit never sends None,
        # regardless of the Gradio version's default for an unset value.
        value="0",
        label="Choose a digit (0–9)",
    ),
    outputs=[gr.Image(image_mode='L') for _ in range(5)],
    title="Conditional VAE Handwritten Digit Generator",
    description="Generates 5 images of the digit you select (0–9) using a Conditional Variational Autoencoder trained on MNIST.",
)

# Start the web server when the script is run.
iface.launch()