# Source: Hugging Face Space by RAYAuser — "Update src/streamlit_app.py" (commit 8184244, verified)
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import io
import zipfile
from PIL import Image
import numpy as np
import os
# Configure the page before any other Streamlit call (required ordering).
st.set_page_config(layout="wide", page_title="Synthetic Image Generator - Conditional GAN")
# UI string table for the two supported languages. Both sub-dicts share the
# same keys, so the active language is swapped wholesale via
# st.session_state.lang_texts.
TEXTS = {
    "en": {
        "title": "RAYgan Datasets Generator",
        "description": "Uses a **Conditional Generative Adversarial Network (GAN)** to generate synthetic clothing images.",
        "model_source": "Model source: [RAYAuser/raygan-zalando-datasetsgen](https://huggingface.co/RAYAuser/raygan-zalando-datasetsgen)",
        "sidebar_header": "Generation Options",
        "language_select": "Language",
        "generation_mode_radio": "Generation Mode:",
        "mode_class": "Generate by Class",
        "mode_dataset": "Generate a Full Dataset",
        "select_class": "Choose a Class:",
        "num_images_input": "Number of images to generate:",
        "num_images_per_class": "Number of images per class:",
        "generate_button": "Launch Generation",
        "generation_in_progress": "Generation in progress...",
        "generating_class_info": "Generating {num_images} images for class '{class_name}'...",
        "generating_dataset_info": "Generating a complete dataset of {num_images} images ({num_images_per_class} per class)...",
        "preview_header": "Preview of Generated Images",
        "preview_caption": "Preview {idx}",
        "download_button": "Download ZIP file",
        "generation_success": "Generation complete and images ready for download!",
        "model_not_found_error": "Error Management : The model file could not be found locally.",
        "instructions_header": "Instructions:",
        "instructions_1": "1. Choose your generation mode.",
        "instructions_2": "2. Enter the number of images you want to create.",
        "instructions_3": "3. Click on 'Launch Generation'.",
    },
    "fr": {
        "title": "RAYgan Datasets Generator",
        "description": "Utilise un modèle **Conditional Generative Adversarial Network (GAN)** pour générer des images synthétiques de vêtements.",
        "model_source": "Source du modèle : [RAYAuser/raygan-zalando-datasetsgen](https://huggingface.co/RAYAuser/raygan-zalando-datasetsgen)",
        "sidebar_header": "Options de Génération",
        "language_select": "Langue",
        "generation_mode_radio": "Mode de génération :",
        "mode_class": "Générer par classe",
        "mode_dataset": "Générer un dataset complet",
        "select_class": "Choisir la classe :",
        "num_images_input": "Nombre d'images à générer :",
        "num_images_per_class": "Nombre d'images par classe :",
        "generate_button": "Lancer la génération",
        "generation_in_progress": "Génération en cours...",
        "generating_class_info": "Génération de {num_images} images pour la classe '{class_name}'...",
        "generating_dataset_info": "Génération d'un dataset complet de {num_images} images ({num_images_per_class} par classe)...",
        "preview_header": "Aperçu des images générées",
        "preview_caption": "Aperçu {idx}",
        "download_button": "Télécharger le fichier ZIP",
        "generation_success": "Génération terminée et images prêtes pour le téléchargement !",
        "model_not_found_error": "Gestion des erreurs : Le fichier du modèle n'a pas pu être trouvé localement.",
        "instructions_header": "Instructions :",
        "instructions_1": "1. Choisissez votre mode de génération.",
        "instructions_2": "2. Entrez le nombre d'images que vous souhaitez créer.",
        "instructions_3": "3. Cliquez sur 'Lancer la génération'.",
    }
}
# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Generateur(nn.Module):
    """Conditional DCGAN-style generator.

    Maps a latent vector concatenated with a one-hot class label to a
    single-channel image in [-1, 1] (Tanh output). With a 100-dim latent
    the transposed-conv stack upsamples 1x1 -> 4 -> 8 -> 16 -> 32.
    """

    def __init__(self, z_dim, ngf, num_classes):
        super().__init__()
        # NOTE: layers are created in the exact same order as the trained
        # checkpoint expects ('main.0', 'main.1', ...), so state-dict keys
        # line up when loading.
        layers = [
            nn.ConvTranspose2d(z_dim + num_classes, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)
        self.z_dim = z_dim
        self.num_classes = num_classes

    def forward(self, x, labels):
        """Generate images for latent batch `x` conditioned on `labels`.

        `x` is reshaped to (N, z_dim, 1, 1); `labels` is a LongTensor of
        class indices, expanded to one-hot channel maps and concatenated
        along the channel axis.
        """
        latent = x.view(-1, self.z_dim, 1, 1)
        one_hot = F.one_hot(labels, self.num_classes).float()
        condition = one_hot.view(-1, self.num_classes, 1, 1)
        return self.main(torch.cat([latent, condition], dim=1))
# Generator hyperparameters (must match the trained checkpoint).
Z_DIM = 100          # latent vector size
NGF = 64             # base feature-map width
NUM_CLASSES = 10     # Fashion-MNIST class count
IMAGE_SIZE = 32      # generated image side length

# Checkpoint ships alongside this script.
MODEL_FILE = os.path.join(os.path.dirname(__file__), "Raygan-zalando_datasetsgen.pth")

# Fashion-MNIST label names, indexed by class id.
class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]
# Reverse lookup: class name -> integer label.
class_to_idx = dict(zip(class_names, range(len(class_names))))
@st.cache_resource
def load_model_from_local():
    """Load the generator checkpoint from disk (cached across reruns).

    Returns:
        The Generateur in eval mode on `device`, or None if loading fails
        (an error is shown in the UI in that case).
    """
    try:
        model = Generateur(Z_DIM, NGF, NUM_CLASSES).to(device)
        # weights_only=True: the checkpoint is consumed as a plain state
        # dict of tensors, so refuse arbitrary pickled objects (torch.load
        # is a pickle loader; this also matches the torch>=2.6 default).
        full_state_dict = torch.load(MODEL_FILE, map_location=device, weights_only=True)
        # The checkpoint may also contain discriminator/optimizer entries;
        # keep only the generator's own 'main.*' parameters.
        filtered_state_dict = {
            key: value for key, value in full_state_dict.items() if key.startswith('main')
        }
        model.load_state_dict(filtered_state_dict)
        model.eval()
        return model
    except Exception as e:
        st.error(st.session_state.lang_texts["model_not_found_error"])
        # NOTE(review): detail line is hardcoded French regardless of the
        # selected UI language — confirm whether it should be localized.
        st.error(f"Détails de l'erreur : {e}")
        return None
# Seed the session with English so widget labels exist before the user
# has picked a language.
if "lang" not in st.session_state:
    st.session_state.lang = "en"
    st.session_state.lang_texts = TEXTS["en"]

lang_selection = st.sidebar.selectbox(st.session_state.lang_texts["language_select"], ["English", "Français"])
# Map the display name back to a language code and swap the string table.
lang_code = "fr" if lang_selection == "Français" else "en"
st.session_state.lang = lang_code
st.session_state.lang_texts = TEXTS[lang_code]
# Bind the active string table once for readability.
texts = st.session_state.lang_texts

st.title(texts["title"])
st.markdown(texts["description"])
st.markdown(texts["model_source"])

st.sidebar.header(texts["sidebar_header"])
generation_mode = st.sidebar.radio(
    texts["generation_mode_radio"],
    (texts["mode_class"], texts["mode_dataset"])
)

# Per-class mode asks for a class and a count; dataset mode asks for a
# per-class count and derives the total.
if generation_mode == texts["mode_class"]:
    selected_class_name = st.sidebar.selectbox(texts["select_class"], options=class_names)
    num_images_to_generate = st.sidebar.number_input(
        texts["num_images_input"], min_value=1, max_value=1000, value=1, step=1
    )
else:
    num_images_per_class = st.sidebar.number_input(
        texts["num_images_per_class"], min_value=1, max_value=100, value=3, step=1
    )
    num_images_to_generate = num_images_per_class * NUM_CLASSES

generate_button = st.sidebar.button(texts["generate_button"])
def _generate_image(model, class_idx):
    """Sample one synthetic image for `class_idx`; returns a grayscale PIL image."""
    noise = torch.randn(1, Z_DIM, device=device)
    labels = torch.tensor([class_idx], device=device)
    with torch.no_grad():
        generated_image = model(noise, labels)
    # Map the generator's Tanh output from [-1, 1] to [0, 1], then to uint8.
    image_tensor = (generated_image.cpu().squeeze() + 1) / 2
    return Image.fromarray((image_tensor.numpy() * 255).astype(np.uint8))


def _png_bytes(image_pil):
    """Encode a PIL image as PNG and return the raw bytes."""
    img_byte_arr = io.BytesIO()
    image_pil.save(img_byte_arr, format='PNG')
    return img_byte_arr.getvalue()


if generate_button:
    model = load_model_from_local()
    if model is not None:
        st.subheader(st.session_state.lang_texts["generation_in_progress"])
        all_generated_images = []  # subset shown in the preview grid
        progress_bar = st.progress(0)
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            if generation_mode == st.session_state.lang_texts["mode_class"]:
                selected_class_idx = class_to_idx[selected_class_name]
                st.info(st.session_state.lang_texts["generating_class_info"].format(num_images=num_images_to_generate, class_name=selected_class_name))
                safe_name = selected_class_name.replace('/', '_')
                for i in range(num_images_to_generate):
                    image_pil = _generate_image(model, selected_class_idx)
                    all_generated_images.append(image_pil)
                    zipf.writestr(f"generated_images/{safe_name}_{i+1}.png", _png_bytes(image_pil))
                    progress_bar.progress((i + 1) / num_images_to_generate)
            else:
                st.info(st.session_state.lang_texts["generating_dataset_info"].format(num_images=num_images_to_generate, num_images_per_class=num_images_per_class))
                total_generated = 0
                for class_idx, class_name in enumerate(class_names):
                    # BUGFIX: sanitize the directory component too — the raw
                    # "T-shirt/top" previously created a spurious nested
                    # "T-shirt/top/" folder inside the archive.
                    safe_name = class_name.replace('/', '_')
                    for i in range(num_images_per_class):
                        image_pil = _generate_image(model, class_idx)
                        if i < 3:  # cap the on-screen preview at 3 per class
                            all_generated_images.append(image_pil)
                        zipf.writestr(f"generated_dataset/{safe_name}/{safe_name}_{i+1}.png", _png_bytes(image_pil))
                        total_generated += 1
                        progress_bar.progress(total_generated / num_images_to_generate)
        st.subheader(st.session_state.lang_texts["preview_header"])
        cols = st.columns(3)
        for idx, img in enumerate(all_generated_images):
            with cols[idx % 3]:
                st.image(img, caption=st.session_state.lang_texts["preview_caption"].format(idx=idx+1), use_container_width=True)
        download_file_name = f"fashion_mnist_synthetique_{generation_mode.replace(' ', '_')}.zip"
        st.download_button(
            label=st.session_state.lang_texts["download_button"],
            data=zip_buffer.getvalue(),
            file_name=download_file_name,
            mime="application/zip"
        )
        st.success(st.session_state.lang_texts["generation_success"])
    else:
        st.info(st.session_state.lang_texts["model_not_found_error"])
# Static how-to footer in the sidebar, rendered in the active language.
st.sidebar.markdown("---")
st.sidebar.markdown(f"**{st.session_state.lang_texts['instructions_header']}**")
for step_key in ("instructions_1", "instructions_2", "instructions_3"):
    st.sidebar.markdown(st.session_state.lang_texts[step_key])