import os
import warnings
from typing import Optional, Tuple

import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

warnings.filterwarnings("ignore")

# Import the MOSTLY AI SDK; fall back gracefully if it is not installed.
try:
    from mostlyai.sdk import MostlyAI
    MOSTLY_AI_AVAILABLE = True
except ImportError:
    MOSTLY_AI_AVAILABLE = False
    print("Warning: MOSTLY AI SDK not available. Please install with: pip install mostlyai[local]")


class SyntheticDataGenerator:
    """Thin wrapper around the MOSTLY AI SDK: initialize, train, generate."""

    def __init__(self):
        self.mostly = None
        self.generator = None
        self.original_data = None

    def initialize_mostly_ai(self) -> Tuple[bool, str]:
        """Start the SDK in local mode. Returns (ok, message)."""
        if not MOSTLY_AI_AVAILABLE:
            return False, "MOSTLY AI SDK not installed. Please install with: pip install mostlyai[local]"
        try:
            self.mostly = MostlyAI(local=True, local_port=8080)
            return True, "MOSTLY AI SDK initialized successfully."
        except Exception as e:
            return False, f"Failed to initialize MOSTLY AI SDK: {str(e)}"

    def train_generator(
        self,
        data: pd.DataFrame,
        name: str,
        epochs: int = 10,
        max_training_time: int = 30,
        batch_size: int = 32,
        value_protection: bool = True,
        rare_category_protection: bool = False,
        flexible_generation: bool = False,
        model_size: str = "MEDIUM",
        target_accuracy: float = 0.95,
        validation_split: float = 0.2,
        learning_rate: float = 0.001,
        early_stopping_patience: int = 10,
        dropout_rate: float = 0.1,
        weight_decay: float = 0.0001,
    ) -> Tuple[bool, str]:
        """Train a single-table generator on `data`. Returns (ok, message)."""
        if not self.mostly:
            return False, "MOSTLY AI SDK not initialized. Please initialize the SDK first."
        try:
            self.original_data = data
            # NOTE: the advanced keys below (target_accuracy, validation_split,
            # learning_rate, early_stopping_patience, dropout_rate, weight_decay)
            # are forwarded verbatim; whether the SDK accepts them depends on
            # the installed mostlyai version.
            train_config = {
                "tables": [
                    {
                        "name": name,
                        "data": data,
                        "tabular_model_configuration": {
                            "max_epochs": epochs,
                            "max_training_time": max_training_time,
                            "value_protection": value_protection,
                            "batch_size": batch_size,
                            "rare_category_protection": rare_category_protection,
                            "flexible_generation": flexible_generation,
                            "model_size": model_size,  # "SMALL" | "MEDIUM" | "LARGE"
                            "target_accuracy": target_accuracy,  # early stop once target met
                            "validation_split": validation_split,
                            "learning_rate": learning_rate,
                            "early_stopping_patience": early_stopping_patience,
                            "dropout_rate": dropout_rate,
                            "weight_decay": weight_decay,
                        },
                    }
                ]
            }
            self.generator = self.mostly.train(config=train_config)
            return True, f"Training completed successfully. Model name: {name}"
        except Exception as e:
            return False, f"Training failed with error: {str(e)}"

    def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
        """Sample `size` rows from the trained generator. Returns (df, message)."""
        if not self.generator:
            return None, "No trained generator available. Please train a model first."
        try:
            synthetic_data = self.mostly.generate(self.generator, size=size)
            df = synthetic_data.data()
            return df, f"Synthetic data generated successfully. {len(df)} records created."
        except Exception as e:
            return None, f"Synthetic data generation failed with error: {str(e)}"

    def estimate_memory_usage(self, df: pd.DataFrame) -> str:
        """Summarize the DataFrame's memory footprint and a rough training estimate."""
        if df is None or df.empty:
            return "No data available to analyze."
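        # Rough sizing heuristic (an editorial assumption, not a figure from
        # the SDK): peak training memory is taken to be about 4x the in-memory
        # size of the DataFrame, which is what the 4x multiplier below encodes.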
        memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
        rows, cols = len(df), len(df.columns)
        estimated_training_mb = memory_mb * 4
        status = "Good" if memory_mb < 100 else ("Large" if memory_mb < 500 else "Very Large")
        return f"""
Memory Usage Estimate:
- Data size: {memory_mb:.1f} MB
- Estimated training memory: {estimated_training_mb:.1f} MB
- Status: {status}
- Rows: {rows:,} | Columns: {cols}
""".strip()


# --- App state ---
generator = SyntheticDataGenerator()
_last_synth_df: Optional[pd.DataFrame] = None


# ---- Gradio wrappers ----
def initialize_sdk() -> str:
    ok, msg = generator.initialize_mostly_ai()
    return ("Success: " if ok else "Error: ") + msg


def train_model(
    data: pd.DataFrame,
    model_name: str,
    epochs: int,
    max_training_time: int,
    batch_size: int,
    value_protection: bool,
    rare_category_protection: bool,
    flexible_generation: bool,
    model_size: str,
    target_accuracy: float,
    validation_split: float,
    learning_rate: float,
    early_stopping_patience: int,
    dropout_rate: float,
    weight_decay: float,
) -> str:
    if data is None or data.empty:
        return "Error: No data provided. Please upload or create sample data first."
    ok, msg = generator.train_generator(
        data=data,
        name=model_name,
        epochs=epochs,
        max_training_time=max_training_time,
        batch_size=batch_size,
        value_protection=value_protection,
        rare_category_protection=rare_category_protection,
        flexible_generation=flexible_generation,
        model_size=model_size,
        target_accuracy=target_accuracy,
        validation_split=validation_split,
        learning_rate=learning_rate,
        early_stopping_patience=early_stopping_patience,
        dropout_rate=dropout_rate,
        weight_decay=weight_decay,
    )
    return ("Success: " if ok else "Error: ") + msg


def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    global _last_synth_df
    synth_df, message = generator.generate_synthetic_data(size)
    if synth_df is not None:
        _last_synth_df = synth_df.copy()
        return synth_df, f"Success: {message}"
    else:
        return None, f"Error: {message}"


def download_csv_prepare() -> Optional[str]:
    """Return a path to the latest synthetic CSV; used as output to gr.File."""
    global _last_synth_df
    if _last_synth_df is None or _last_synth_df.empty:
        return None
    os.makedirs("/tmp", exist_ok=True)
    path = "/tmp/synthetic_data.csv"
    _last_synth_df.to_csv(path, index=False)
    return path


def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame):
    if original_df is None or synthetic_df is None:
        return None
    numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
    if not numeric_cols:
        return None
    n_cols = min(3, len(numeric_cols))
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols[: n_rows * n_cols])
    for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
        row = i // n_cols + 1
        col_idx = i % n_cols + 1
        fig.add_trace(
            go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20),
            row=row, col=col_idx,
        )
        fig.add_trace(
            go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20),
            row=row, col=col_idx,
        )
    fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
    return fig


# ---- UI ----
def create_interface():
    with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
        gr.Image(
            value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
            show_label=False,
            elem_id="header-image",
        )
        gr.Markdown(
            """
            # Synthetic Data SDK by MOSTLY AI Demo Space
            [Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)

            A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
            """
        )

        with gr.Tab("Quick Start"):
            gr.Markdown("### Initialize the SDK and upload your data")
            with gr.Row():
                with gr.Column():
                    init_btn = gr.Button("Initialize MOSTLY AI SDK", variant="primary")
                    init_status = gr.Textbox(label="Initialization Status", interactive=False)
                with gr.Column():
                    gr.Markdown(
                        """
                        **Next Steps:**
                        1. Initialize the SDK
                        2. Go to the "Upload Data and Train Model" tab to upload your CSV file
                        3. Train a model on your data
                        4. Generate synthetic data
                        """
                    )

        with gr.Tab("Upload Data and Train Model"):
            gr.Markdown("### Upload your CSV file to generate synthetic data")
            gr.Markdown(
                """
                **File Requirements:**
                - Format: CSV with a header row
                - Size: optimized for Hugging Face Spaces (2 vCPU, 16 GB RAM)
                """
            )
            file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single")
            uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
            memory_info = gr.Markdown(label="Memory Usage Info", visible=False)

            with gr.Row():
                with gr.Column(scale=1):
                    model_name = gr.Textbox(
                        value="My Synthetic Model",
                        label="Generator Name",
                        placeholder="Enter a name for your generator",
                        info="Appears in training runs and saved generators.",
                    )
                    epochs = gr.Slider(
                        1, 200, value=100, step=1,
                        label="Training Epochs",
                        info="Maximum number of passes over the training data.",
                    )
                    max_training_time = gr.Slider(
                        1, 1000, value=60, step=1,
                        label="Maximum Training Time (minutes)",
                        info="Upper bound in minutes; training stops if exceeded.",
                    )
                    batch_size = gr.Slider(
                        8, 1024, value=32, step=8,
                        label="Batch Size",
                        info="Number of rows per optimization step. Larger values can speed up training but need more memory.",
                    )
                    value_protection = gr.Checkbox(
                        label="Value Protection",
                        info="Adds protections to reduce memorization of unique or sensitive values.",
                        value=False,
                    )
                    rare_category_protection = gr.Checkbox(
                        label="Rare Category Protection",
                        info="Prevents overfitting to infrequent categories to improve privacy and robustness.",
                        value=False,
                    )
                with gr.Column(scale=1):
                    flexible_generation = gr.Checkbox(
                        label="Flexible Generation",
                        info="Allows generation when inputs slightly differ from the training schema.",
                        value=True,
                    )
                    model_size = gr.Dropdown(
                        choices=["SMALL", "MEDIUM", "LARGE"],
                        value="MEDIUM",
                        label="Model Size",
                        info="Sets model capacity. Larger models can improve fidelity but use more compute.",
                    )
                    target_accuracy = gr.Slider(
                        0.50, 0.999, value=0.95, step=0.001,
                        label="Target Accuracy",
                        info="Stop early when validation accuracy reaches this threshold.",
                    )
                    validation_split = gr.Slider(
                        0.05, 0.5, value=0.2, step=0.01,
                        label="Validation Split",
                        info="Fraction of the dataset held out for validation during training.",
                    )
                    early_stopping_patience = gr.Slider(
                        0, 50, value=10, step=1,
                        label="Early Stopping Patience (epochs)",
                        info="Stop if there is no validation improvement after this many epochs.",
                    )
                with gr.Column(scale=1):
                    learning_rate = gr.Number(
                        value=0.001, precision=6,
                        label="Learning Rate",
                        info="Step size for the optimizer. Typical range: 1e-4 to 1e-2.",
                    )
                    dropout_rate = gr.Slider(
                        0.0, 0.6, value=0.1, step=0.01,
                        label="Dropout Rate",
                        info="Regularization to reduce overfitting by randomly dropping units.",
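                        # See the version-compatibility note in
                        # SyntheticDataGenerator.train_generator: this knob is
                        # forwarded to the SDK as-is.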
                    )
                    weight_decay = gr.Number(
                        value=0.0001, precision=6,
                        label="Weight Decay",
                        info="L2 regularization strength applied to model weights.",
                    )

            train_btn = gr.Button("Train Model", variant="primary")
            train_status = gr.Textbox(label="Training Status", interactive=False)

        with gr.Tab("Generate Data"):
            gr.Markdown("### Generate synthetic data from your trained model")
            with gr.Row():
                with gr.Column():
                    gen_size = gr.Slider(
                        10, 1000, value=100, step=10,
                        label="Number of Records to Generate",
                        info="How many synthetic rows to create in the table.",
                    )
                    generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
                with gr.Column():
                    gen_status = gr.Textbox(label="Generation Status", interactive=False)
            synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
            with gr.Row():
                csv_download_btn = gr.Button("Download CSV", variant="secondary")
            with gr.Group(visible=False) as csv_group:
                csv_file = gr.File(label="Synthetic CSV", interactive=False)
            comparison_plot = gr.Plot(label="Data Comparison")

        # ---- Events ----
        init_btn.click(initialize_sdk, outputs=[init_status])
        train_btn.click(
            train_model,
            inputs=[
                uploaded_data, model_name, epochs, max_training_time, batch_size,
                value_protection, rare_category_protection, flexible_generation,
                model_size, target_accuracy, validation_split, learning_rate,
                early_stopping_patience, dropout_rate, weight_decay,
            ],
            outputs=[train_status],
        )
        generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
        synthetic_data.change(
            create_comparison_plot,
            inputs=[uploaded_data, synthetic_data],
            outputs=[comparison_plot],
        )

        def _prepare_csv_for_download():
            path = download_csv_prepare()
            if path:
                return path, gr.update(visible=True)
            else:
                return None, gr.update(visible=False)

        csv_download_btn.click(
            _prepare_csv_for_download,
            outputs=[csv_file, csv_group],
        )

        def process_uploaded_file(file):
            if file is None:
                return None, "No file uploaded.", gr.update(visible=False)
            try:
                # gr.File may hand back a str path or a tempfile-like object
                # with a .name attribute, depending on the Gradio version;
                # handle both.
                path = file if isinstance(file, str) else file.name
                df = pd.read_csv(path)
                success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
                mem_info = generator.estimate_memory_usage(df)
                return df, success_msg, gr.update(value=mem_info, visible=True)
            except Exception as e:
                return None, f"Error reading file: {str(e)}", gr.update(visible=False)

        file_upload.change(
            process_uploaded_file,
            inputs=[file_upload],
            outputs=[uploaded_data, train_status, memory_info],
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
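# ---------------------------------------------------------------------------
# For reference: the bare SDK flow this app wraps, as a minimal sketch (never
# executed by the app). The filename "data.csv" and table name "data" are
# illustrative placeholders; the call shapes mirror train_generator and
# generate_synthetic_data above.
#
#   from mostlyai.sdk import MostlyAI
#   import pandas as pd
#
#   mostly = MostlyAI(local=True)
#   df = pd.read_csv("data.csv")
#   g = mostly.train(config={"tables": [{"name": "data", "data": df}]})
#   sd = mostly.generate(g, size=1000)
#   print(sd.data().head())
# ---------------------------------------------------------------------------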