|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import io |
|
import zipfile |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
|
|
# Page-wide settings. Kept as the very first Streamlit call of the script
# (older Streamlit releases reject set_page_config after other st.* calls).
st.set_page_config(
    page_title="Synthetic Image Generator - Conditional GAN",
    layout="wide",
)
|
|
|
# User-facing strings, keyed first by language code ("en" / "fr"), then by a
# message id. Both languages define the same ids in the same order. Entries
# containing {num_images}, {class_name}, {num_images_per_class} or {idx} are
# templates filled in with str.format by the UI code.
TEXTS = dict(
    en=dict(
        title="RAYgan Datasets Generator",
        description="Uses a **Conditional Generative Adversarial Network (GAN)** to generate synthetic clothing images.",
        model_source="Model source: [RAYAuser/raygan-zalando-datasetsgen](https://huggingface.co/RAYAuser/raygan-zalando-datasetsgen)",
        sidebar_header="Generation Options",
        language_select="Language",
        generation_mode_radio="Generation Mode:",
        mode_class="Generate by Class",
        mode_dataset="Generate a Full Dataset",
        select_class="Choose a Class:",
        num_images_input="Number of images to generate:",
        num_images_per_class="Number of images per class:",
        generate_button="Launch Generation",
        generation_in_progress="Generation in progress...",
        generating_class_info="Generating {num_images} images for class '{class_name}'...",
        generating_dataset_info="Generating a complete dataset of {num_images} images ({num_images_per_class} per class)...",
        preview_header="Preview of Generated Images",
        preview_caption="Preview {idx}",
        download_button="Download ZIP file",
        generation_success="Generation complete and images ready for download!",
        model_not_found_error="Error Management : The model file could not be found locally.",
        instructions_header="Instructions:",
        instructions_1="1. Choose your generation mode.",
        instructions_2="2. Enter the number of images you want to create.",
        instructions_3="3. Click on 'Launch Generation'.",
    ),
    fr=dict(
        title="RAYgan Datasets Generator",
        description="Utilise un modèle **Conditional Generative Adversarial Network (GAN)** pour générer des images synthétiques de vêtements.",
        model_source="Source du modèle : [RAYAuser/raygan-zalando-datasetsgen](https://huggingface.co/RAYAuser/raygan-zalando-datasetsgen)",
        sidebar_header="Options de Génération",
        language_select="Langue",
        generation_mode_radio="Mode de génération :",
        mode_class="Générer par classe",
        mode_dataset="Générer un dataset complet",
        select_class="Choisir la classe :",
        num_images_input="Nombre d'images à générer :",
        num_images_per_class="Nombre d'images par classe :",
        generate_button="Lancer la génération",
        generation_in_progress="Génération en cours...",
        generating_class_info="Génération de {num_images} images pour la classe '{class_name}'...",
        generating_dataset_info="Génération d'un dataset complet de {num_images} images ({num_images_per_class} par classe)...",
        preview_header="Aperçu des images générées",
        preview_caption="Aperçu {idx}",
        download_button="Télécharger le fichier ZIP",
        generation_success="Génération terminée et images prêtes pour le téléchargement !",
        model_not_found_error="Gestion des erreurs : Le fichier du modèle n'a pas pu être trouvé localement.",
        instructions_header="Instructions :",
        instructions_1="1. Choisissez votre mode de génération.",
        instructions_2="2. Entrez le nombre d'images que vous souhaitez créer.",
        instructions_3="3. Cliquez sur 'Lancer la génération'.",
    ),
)
|
|
|
# Inference device: CUDA when PyTorch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
class Generateur(nn.Module):
    """Conditional generator mapping (noise, class label) to a 1-channel image.

    The class label is one-hot encoded and concatenated to the noise along the
    channel axis, giving a ``(B, z_dim + num_classes, 1, 1)`` input. Four
    transposed convolutions upsample it 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32,
    and the final ``Tanh`` keeps pixel values in [-1, 1].

    Args:
        z_dim: Number of noise values per sample.
        ngf: Base number of feature maps (layers use ngf*8, ngf*4, ngf*2).
        num_classes: Number of conditioning classes (length of the one-hot).

    Note:
        All parameters/buffers live under ``self.main`` (state-dict keys
        ``main.<index>.*``); changing the layer list would break loading
        previously saved weights.
    """

    def __init__(self, z_dim, ngf, num_classes):
        super().__init__()
        self.main = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(z_dim + num_classes, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32, single output channel
            nn.ConvTranspose2d(ngf * 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )
        self.z_dim = z_dim
        self.num_classes = num_classes

    def forward(self, x, labels):
        """Generate one image per (noise, label) pair.

        Args:
            x: Noise holding ``z_dim`` values per sample, e.g. shape
                ``(B, z_dim)`` or ``(B, z_dim, 1, 1)``.
            labels: Integer class indices in ``[0, num_classes)``, shape ``(B,)``.
                Any integer dtype is accepted.

        Returns:
            Tensor of shape ``(B, 1, 32, 32)`` with values in [-1, 1].
        """
        # reshape (unlike view) also accepts noise that is not contiguous.
        x = x.reshape(-1, self.z_dim, 1, 1)
        # F.one_hot only accepts int64 indices, so cast first; then match the
        # noise dtype so torch.cat / the convolutions also work for non-fp32
        # (e.g. half-precision) models.
        one_hot = F.one_hot(labels.long(), self.num_classes).to(x.dtype)
        one_hot = one_hot.reshape(-1, self.num_classes, 1, 1)
        return self.main(torch.cat([x, one_hot], 1))
|
|
|
# Generator hyper-parameters. They must match the architecture the saved
# weights were trained with, otherwise the strict load_state_dict fails.
Z_DIM = 100        # noise values per sample
NGF = 64           # base number of generator feature maps
NUM_CLASSES = 10   # number of conditioning classes
IMAGE_SIZE = 32    # output image side length, in pixels

# Weights file, looked up in the same directory as this script.
MODEL_FILE = os.path.join(os.path.dirname(__file__), "Raygan-zalando_datasetsgen.pth")

# Class names in label order: position i is the label index fed to the generator.
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# Reverse lookup: class name -> label index.
class_to_idx = dict(zip(class_names, range(len(class_names))))
|
|
|
@st.cache_resource
def _load_generator_cached():
    """Build the generator and load its weights from ``MODEL_FILE``.

    Cached with ``st.cache_resource``, so a successful load is reused across
    reruns and sessions. Errors are deliberately NOT caught here: Streamlit
    does not cache a call that raises, so a failed load is retried on the next
    run instead of a ``None`` being cached for the life of the process.

    Returns:
        The ``Generateur`` in eval mode, moved to ``device``.
    """
    model = Generateur(Z_DIM, NGF, NUM_CLASSES).to(device)
    full_state_dict = torch.load(MODEL_FILE, map_location=device)
    # Keep only the generator's ``main.*`` entries: any other keys in the
    # checkpoint would make the strict load_state_dict below fail.
    generator_state = {
        key: value for key, value in full_state_dict.items() if key.startswith('main')
    }
    model.load_state_dict(generator_state)
    model.eval()
    return model


def load_model_from_local():
    """Return the cached generator, or ``None`` after reporting a load failure.

    Any exception raised while building or loading the model is shown in the
    UI (a localized headline followed by the exception text) and converted
    to ``None`` so the caller can fall back gracefully.
    """
    try:
        return _load_generator_cached()
    except Exception as e:  # UI boundary: surface every failure to the user.
        st.error(st.session_state.lang_texts["model_not_found_error"])
        # The detail prefix used to be French-only; pick it by UI language.
        if st.session_state.get("lang") == "fr":
            st.error(f"Détails de l'erreur : {e}")
        else:
            st.error(f"Error details: {e}")
        return None
|
|
|
# Initialise the UI language once per browser session (English by default).
if "lang" not in st.session_state:
    st.session_state.lang = "en"
    st.session_state.lang_texts = TEXTS["en"]

# Display name -> language code, in the order shown in the selectbox.
_LANGUAGE_OPTIONS = {"English": "en", "Français": "fr"}

# The selectbox label is itself translated ("Language" / "Langue"). Streamlit
# derives a key-less widget's identity from its parameters, label included, so
# after a language switch the selectbox is a *new* widget that starts from
# ``index``. With the old fixed index 0 the app reverted to English on the next
# interaction (also dropping that run's button click). Seeding ``index`` from
# the current language keeps the selection stable.
_language_codes = list(_LANGUAGE_OPTIONS.values())
lang_selection = st.sidebar.selectbox(
    st.session_state.lang_texts["language_select"],
    list(_LANGUAGE_OPTIONS),
    index=_language_codes.index(st.session_state.lang),
)

st.session_state.lang = _LANGUAGE_OPTIONS[lang_selection]
st.session_state.lang_texts = TEXTS[st.session_state.lang]
|
|
|
_ui_texts = st.session_state.lang_texts

# Page header (main area).
st.title(_ui_texts["title"])
for _header_key in ("description", "model_source"):
    st.markdown(_ui_texts[_header_key])

# Sidebar controls.
st.sidebar.header(_ui_texts["sidebar_header"])

generation_mode = st.sidebar.radio(
    _ui_texts["generation_mode_radio"],
    (_ui_texts["mode_class"], _ui_texts["mode_dataset"]),
)

# Each mode shows its own inputs; both define ``num_images_to_generate``
# (the total number of images to produce).
if generation_mode != _ui_texts["mode_class"]:
    # Full dataset: the same number of images for every class.
    num_images_per_class = st.sidebar.number_input(
        _ui_texts["num_images_per_class"], min_value=1, max_value=100, value=3, step=1
    )
    num_images_to_generate = num_images_per_class * NUM_CLASSES
else:
    # Single class: choose the class, then how many images to produce.
    selected_class_name = st.sidebar.selectbox(_ui_texts["select_class"], options=class_names)
    num_images_to_generate = st.sidebar.number_input(
        _ui_texts["num_images_input"], min_value=1, max_value=1000, value=1, step=1
    )

generate_button = st.sidebar.button(_ui_texts["generate_button"])
|
|
|
def _generate_image(model, class_idx):
    """Sample one image of class ``class_idx`` from ``model``.

    Returns:
        ``(image_pil, png_bytes)``: the 8-bit PIL image and its PNG encoding.
    """
    noise = torch.randn(1, Z_DIM, device=device)
    labels = torch.tensor([class_idx], device=device)
    with torch.no_grad():
        generated_image = model(noise, labels)
    # The generator ends in Tanh ([-1, 1]); map to [0, 1], then to 0..255.
    image_tensor = (generated_image.cpu().squeeze() + 1) / 2
    image_pil = Image.fromarray((image_tensor.numpy() * 255).astype(np.uint8))
    png_buffer = io.BytesIO()
    image_pil.save(png_buffer, format='PNG')
    return image_pil, png_buffer.getvalue()


if generate_button:
    model = load_model_from_local()

    if model is not None:
        texts = st.session_state.lang_texts
        st.subheader(texts["generation_in_progress"])

        all_generated_images = []  # images shown in the preview grid
        progress_bar = st.progress(0)
        zip_buffer = io.BytesIO()

        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            if generation_mode == texts["mode_class"]:
                selected_class_idx = class_to_idx[selected_class_name]
                # '/' is a path separator inside the archive, so replace it.
                safe_name = selected_class_name.replace('/', '_')
                st.info(texts["generating_class_info"].format(
                    num_images=num_images_to_generate, class_name=selected_class_name))
                for i in range(num_images_to_generate):
                    image_pil, png_bytes = _generate_image(model, selected_class_idx)
                    all_generated_images.append(image_pil)
                    zipf.writestr(f"generated_images/{safe_name}_{i+1}.png", png_bytes)
                    progress_bar.progress((i + 1) / num_images_to_generate)
            else:
                st.info(texts["generating_dataset_info"].format(
                    num_images=num_images_to_generate,
                    num_images_per_class=num_images_per_class))
                total_generated = 0
                for class_idx, class_name in enumerate(class_names):
                    # Sanitize the folder name too: the raw "T-shirt/top"
                    # produced a nested "T-shirt/top/" directory in the zip.
                    safe_name = class_name.replace('/', '_')
                    for i in range(num_images_per_class):
                        image_pil, png_bytes = _generate_image(model, class_idx)
                        if i < 3:  # preview at most 3 images per class
                            all_generated_images.append(image_pil)
                        zipf.writestr(
                            f"generated_dataset/{safe_name}/{safe_name}_{i+1}.png",
                            png_bytes,
                        )
                        total_generated += 1
                        progress_bar.progress(total_generated / num_images_to_generate)

        # The ZipFile is closed at this point; the archive's central directory
        # is only written on close, so the buffer must be read after the `with`.
        st.subheader(texts["preview_header"])
        cols = st.columns(3)
        for idx, img in enumerate(all_generated_images):
            with cols[idx % 3]:
                st.image(
                    img,
                    caption=texts["preview_caption"].format(idx=idx + 1),
                    use_container_width=True,
                )

        download_file_name = f"fashion_mnist_synthetique_{generation_mode.replace(' ', '_')}.zip"
        st.download_button(
            label=texts["download_button"],
            data=zip_buffer.getvalue(),
            file_name=download_file_name,
            mime="application/zip",
        )
        st.success(texts["generation_success"])
    else:
        st.info(st.session_state.lang_texts["model_not_found_error"])
|
|
|
# Sidebar footer: a separator, then the numbered usage steps in the UI language.
st.sidebar.markdown("---")
_help_texts = st.session_state.lang_texts
st.sidebar.markdown(f"**{_help_texts['instructions_header']}**")
for _step_key in ("instructions_1", "instructions_2", "instructions_3"):
    st.sidebar.markdown(_help_texts[_step_key])
|
|