"""Streamlit app: generate synthetic Fashion-MNIST-style clothing images.

Loads a pre-trained conditional DCGAN generator from a local ``.pth`` file
(model: RAYAuser/raygan-zalando-datasetsgen) and lets the user generate
images either for a single class or as a full dataset (N images per class),
bundled into a downloadable ZIP archive.  The UI is bilingual (EN/FR).
"""
import io
import os
import zipfile

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

st.set_page_config(layout="wide", page_title="Synthetic Image Generator - Conditional GAN")

# All user-facing strings, keyed by language code.  The two sub-dicts must
# stay key-for-key identical so language switching never raises KeyError.
TEXTS = {
    "en": {
        "title": "RAYgan Datasets Generator",
        "description": "Uses a **Conditional Generative Adversarial Network (GAN)** to generate synthetic clothing images.",
        "model_source": "Model source: [RAYAuser/raygan-zalando-datasetsgen](https://huggingface.co/RAYAuser/raygan-zalando-datasetsgen)",
        "sidebar_header": "Generation Options",
        "language_select": "Language",
        "generation_mode_radio": "Generation Mode:",
        "mode_class": "Generate by Class",
        "mode_dataset": "Generate a Full Dataset",
        "select_class": "Choose a Class:",
        "num_images_input": "Number of images to generate:",
        "num_images_per_class": "Number of images per class:",
        "generate_button": "Launch Generation",
        "generation_in_progress": "Generation in progress...",
        "generating_class_info": "Generating {num_images} images for class '{class_name}'...",
        "generating_dataset_info": "Generating a complete dataset of {num_images} images ({num_images_per_class} per class)...",
        "preview_header": "Preview of Generated Images",
        "preview_caption": "Preview {idx}",
        "download_button": "Download ZIP file",
        "generation_success": "Generation complete and images ready for download!",
        "model_not_found_error": "Error Management : The model file could not be found locally.",
        "instructions_header": "Instructions:",
        "instructions_1": "1. Choose your generation mode.",
        "instructions_2": "2. Enter the number of images you want to create.",
        "instructions_3": "3. Click on 'Launch Generation'.",
    },
    "fr": {
        "title": "RAYgan Datasets Generator",
        "description": "Utilise un modèle **Conditional Generative Adversarial Network (GAN)** pour générer des images synthétiques de vêtements.",
        "model_source": "Source du modèle : [RAYAuser/raygan-zalando-datasetsgen](https://huggingface.co/RAYAuser/raygan-zalando-datasetsgen)",
        "sidebar_header": "Options de Génération",
        "language_select": "Langue",
        "generation_mode_radio": "Mode de génération :",
        "mode_class": "Générer par classe",
        "mode_dataset": "Générer un dataset complet",
        "select_class": "Choisir la classe :",
        "num_images_input": "Nombre d'images à générer :",
        "num_images_per_class": "Nombre d'images par classe :",
        "generate_button": "Lancer la génération",
        "generation_in_progress": "Génération en cours...",
        "generating_class_info": "Génération de {num_images} images pour la classe '{class_name}'...",
        "generating_dataset_info": "Génération d'un dataset complet de {num_images} images ({num_images_per_class} par classe)...",
        "preview_header": "Aperçu des images générées",
        "preview_caption": "Aperçu {idx}",
        "download_button": "Télécharger le fichier ZIP",
        "generation_success": "Génération terminée et images prêtes pour le téléchargement !",
        "model_not_found_error": "Gestion des erreurs : Le fichier du modèle n'a pas pu être trouvé localement.",
        "instructions_header": "Instructions :",
        "instructions_1": "1. Choisissez votre mode de génération.",
        "instructions_2": "2. Entrez le nombre d'images que vous souhaitez créer.",
        "instructions_3": "3. Cliquez sur 'Lancer la génération'.",
    },
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Generateur(nn.Module):
    """Conditional DCGAN generator: (noise z, class label) -> 1x32x32 image in [-1, 1]."""

    def __init__(self, z_dim, ngf, num_classes):
        super().__init__()
        # Four transposed-conv stages upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32.
        self.main = nn.Sequential(
            nn.ConvTranspose2d(z_dim + num_classes, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),  # outputs in [-1, 1]; rescaled to [0, 255] at display time
        )
        self.z_dim = z_dim
        self.num_classes = num_classes

    def forward(self, x, labels):
        """Generate images conditioned on integer class ``labels``.

        ``x`` is latent noise of shape (batch, z_dim); the class is injected
        by concatenating a one-hot map onto the noise along the channel axis.
        """
        x = x.view(-1, self.z_dim, 1, 1)
        labels_reshaped = (
            F.one_hot(labels, self.num_classes).float().view(-1, self.num_classes, 1, 1)
        )
        x = torch.cat([x, labels_reshaped], 1)
        return self.main(x)


Z_DIM = 100        # latent noise dimension
NGF = 64           # base width of generator feature maps
NUM_CLASSES = 10   # Fashion-MNIST class count
IMAGE_SIZE = 32
MODEL_FILE = os.path.join(os.path.dirname(__file__), "Raygan-zalando_datasetsgen.pth")

class_names = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
class_to_idx = {name: i for i, name in enumerate(class_names)}


@st.cache_resource
def load_model_from_local():
    """Load the generator weights from MODEL_FILE (cached across reruns).

    Returns the model in eval mode, or None (after showing an error in the
    UI) if the file is missing or the weights do not match the architecture.
    """
    try:
        model = Generateur(Z_DIM, NGF, NUM_CLASSES).to(device)
        # NOTE(review): torch.load unpickles arbitrary objects; once the
        # deployed torch version supports it, pass weights_only=True here.
        full_state_dict = torch.load(MODEL_FILE, map_location=device)
        # Keep only the generator's own parameters (keys under 'main'),
        # dropping anything else the checkpoint may carry.
        filtered_state_dict = {
            key: value for key, value in full_state_dict.items() if key.startswith("main")
        }
        model.load_state_dict(filtered_state_dict)
        model.eval()
        return model
    except Exception as e:
        st.error(st.session_state.lang_texts["model_not_found_error"])
        st.error(f"Détails de l'erreur : {e}")
        return None


def _generate_pil_image(model, class_idx):
    """Sample one image for ``class_idx`` and return it as a grayscale PIL image."""
    noise = torch.randn(1, Z_DIM, device=device)
    labels = torch.tensor([class_idx]).to(device)
    with torch.no_grad():
        generated_image = model(noise, labels)
    # Map Tanh output [-1, 1] -> [0, 1] -> uint8 [0, 255].
    image_tensor = (generated_image.cpu().squeeze() + 1) / 2
    return Image.fromarray((image_tensor.numpy() * 255).astype(np.uint8))


def _save_png_to_zip(zipf, image_pil, archive_path):
    """Encode ``image_pil`` as PNG and store it at ``archive_path`` in the ZIP."""
    img_byte_arr = io.BytesIO()
    image_pil.save(img_byte_arr, format="PNG")
    zipf.writestr(archive_path, img_byte_arr.getvalue())


# --- Language selection -----------------------------------------------------
if "lang" not in st.session_state:
    st.session_state.lang = "en"
    st.session_state.lang_texts = TEXTS["en"]

lang_selection = st.sidebar.selectbox(
    st.session_state.lang_texts["language_select"], ["English", "Français"]
)
if lang_selection == "Français":
    st.session_state.lang = "fr"
    st.session_state.lang_texts = TEXTS["fr"]
else:
    st.session_state.lang = "en"
    st.session_state.lang_texts = TEXTS["en"]

# --- Page header ------------------------------------------------------------
st.title(st.session_state.lang_texts["title"])
st.markdown(st.session_state.lang_texts["description"])
st.markdown(st.session_state.lang_texts["model_source"])

# --- Sidebar: generation options -------------------------------------------
st.sidebar.header(st.session_state.lang_texts["sidebar_header"])
generation_mode = st.sidebar.radio(
    st.session_state.lang_texts["generation_mode_radio"],
    (st.session_state.lang_texts["mode_class"], st.session_state.lang_texts["mode_dataset"]),
)

if generation_mode == st.session_state.lang_texts["mode_class"]:
    selected_class_name = st.sidebar.selectbox(
        st.session_state.lang_texts["select_class"], options=class_names
    )
    num_images_to_generate = st.sidebar.number_input(
        st.session_state.lang_texts["num_images_input"],
        min_value=1, max_value=1000, value=1, step=1,
    )
else:
    num_images_per_class = st.sidebar.number_input(
        st.session_state.lang_texts["num_images_per_class"],
        min_value=1, max_value=100, value=3, step=1,
    )
    num_images_to_generate = num_images_per_class * NUM_CLASSES

generate_button = st.sidebar.button(st.session_state.lang_texts["generate_button"])

# --- Generation -------------------------------------------------------------
if generate_button:
    model = load_model_from_local()
    if model is not None:
        st.subheader(st.session_state.lang_texts["generation_in_progress"])
        all_generated_images = []  # subset kept for the on-page preview
        progress_bar = st.progress(0)
        zip_buffer = io.BytesIO()

        with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
            if generation_mode == st.session_state.lang_texts["mode_class"]:
                selected_class_idx = class_to_idx[selected_class_name]
                st.info(st.session_state.lang_texts["generating_class_info"].format(
                    num_images=num_images_to_generate, class_name=selected_class_name))
                safe_name = selected_class_name.replace('/', '_')
                for i in range(num_images_to_generate):
                    image_pil = _generate_pil_image(model, selected_class_idx)
                    all_generated_images.append(image_pil)
                    _save_png_to_zip(
                        zipf, image_pil, f"generated_images/{safe_name}_{i+1}.png"
                    )
                    progress_bar.progress((i + 1) / num_images_to_generate)
            else:
                st.info(st.session_state.lang_texts["generating_dataset_info"].format(
                    num_images=num_images_to_generate,
                    num_images_per_class=num_images_per_class))
                total_generated = 0
                for class_idx, class_name in enumerate(class_names):
                    # Sanitize the folder name too: "T-shirt/top" would otherwise
                    # create an unintended extra directory level inside the ZIP.
                    safe_name = class_name.replace('/', '_')
                    for i in range(num_images_per_class):
                        image_pil = _generate_pil_image(model, class_idx)
                        if i < 3:  # cap the preview at 3 images per class
                            all_generated_images.append(image_pil)
                        _save_png_to_zip(
                            zipf, image_pil,
                            f"generated_dataset/{safe_name}/{safe_name}_{i+1}.png",
                        )
                        total_generated += 1
                        progress_bar.progress(total_generated / num_images_to_generate)

        # --- Preview and download -------------------------------------------
        st.subheader(st.session_state.lang_texts["preview_header"])
        cols = st.columns(3)
        for idx, img in enumerate(all_generated_images):
            with cols[idx % 3]:
                st.image(
                    img,
                    caption=st.session_state.lang_texts["preview_caption"].format(idx=idx + 1),
                    use_container_width=True,
                )

        download_file_name = f"fashion_mnist_synthetique_{generation_mode.replace(' ', '_')}.zip"
        st.download_button(
            label=st.session_state.lang_texts["download_button"],
            data=zip_buffer.getvalue(),
            file_name=download_file_name,
            mime="application/zip",
        )
        st.success(st.session_state.lang_texts["generation_success"])
    else:
        st.info(st.session_state.lang_texts["model_not_found_error"])

# --- Sidebar footer: instructions -------------------------------------------
st.sidebar.markdown("---")
st.sidebar.markdown(f"**{st.session_state.lang_texts['instructions_header']}**")
st.sidebar.markdown(st.session_state.lang_texts["instructions_1"])
st.sidebar.markdown(st.session_state.lang_texts["instructions_2"])
st.sidebar.markdown(st.session_state.lang_texts["instructions_3"])