# Source: Hugging Face Space (status: Running)
import os
import tempfile
import warnings
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Silence every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# Import Mostly AI SDK. The import is optional: MOSTLY_AI_AVAILABLE records
# whether it succeeded so the rest of the module can still be loaded.
try:
    from mostlyai.sdk import MostlyAI

    MOSTLY_AI_AVAILABLE = True
except ImportError:
    MOSTLY_AI_AVAILABLE = False
    print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
class SyntheticDataGenerator:
    """Wrapper around the Mostly AI SDK client: initialize, train, generate.

    Every public method reports failure through its return value instead of
    raising, so callers can show the message directly in a UI.
    """

    def __init__(self):
        # SDK client; set by initialize_mostly_ai().
        self.mostly = None
        # Object returned by the SDK's train() call; set by train_generator().
        self.generator = None
        # DataFrame most recently handed to train_generator().
        self.original_data = None

    def initialize_mostly_ai(self) -> Tuple[bool, str]:
        """Create the SDK client in local mode on port 8080.

        Returns:
            ``(ok, message)``. ``ok`` is False when the SDK could not be
            imported or the client constructor raised.
        """
        if not MOSTLY_AI_AVAILABLE:
            return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
        try:
            self.mostly = MostlyAI(local=True, local_port=8080)
        except Exception as e:  # UI boundary: report instead of crashing
            return False, f"Failed to initialize Mostly AI SDK: {e}"
        return True, "Mostly AI SDK initialized successfully."

    def train_generator(
        self,
        data: pd.DataFrame,
        name: str,
        epochs: int = 10,
        max_training_time: int = 30,
        batch_size: int = 32,
        value_protection: bool = True,
        rare_category_protection: bool = False,
        flexible_generation: bool = False,
        model_size: str = "MEDIUM",
        target_accuracy: float = 0.95,
        validation_split: float = 0.2,
        learning_rate: float = 0.001,
        early_stopping_patience: int = 10,
        dropout_rate: float = 0.1,
        weight_decay: float = 0.0001,
    ) -> Tuple[bool, str]:
        """Train a single-table generator on ``data`` and keep it on the instance.

        Args:
            data: Training table.
            name: Table name placed in the training config; also echoed in
                the success message.
            epochs: Sent as ``max_epochs``.
            max_training_time: Sent as ``max_training_time``.
            batch_size, value_protection, rare_category_protection,
            flexible_generation, model_size, target_accuracy,
            validation_split, learning_rate, early_stopping_patience,
            dropout_rate, weight_decay: Forwarded unchanged under the same
                key in ``tabular_model_configuration``.

        Returns:
            ``(ok, message)``. ``ok`` is False when the SDK has not been
            initialized or the SDK call raised.
        """
        if self.mostly is None:
            return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
        try:
            self.original_data = data
            # NOTE(review): several keys below (e.g. target_accuracy,
            # validation_split, learning_rate, early_stopping_patience,
            # dropout_rate, weight_decay, model_size) may not be fields of the
            # SDK's model configuration - confirm against the SDK reference.
            model_configuration = {
                "max_epochs": epochs,
                "max_training_time": max_training_time,
                "value_protection": value_protection,
                "batch_size": batch_size,
                "rare_category_protection": rare_category_protection,
                "flexible_generation": flexible_generation,
                "model_size": model_size,  # "SMALL" | "MEDIUM" | "LARGE"
                "target_accuracy": target_accuracy,  # early stop once target met
                "validation_split": validation_split,
                "learning_rate": learning_rate,
                "early_stopping_patience": early_stopping_patience,
                "dropout_rate": dropout_rate,
                "weight_decay": weight_decay,
            }
            train_config = {
                "tables": [
                    {
                        "name": name,
                        "data": data,
                        "tabular_model_configuration": model_configuration,
                    }
                ]
            }
            self.generator = self.mostly.train(config=train_config)
            return True, f"Training completed successfully. Model name: {name}"
        except Exception as e:  # UI boundary: report instead of crashing
            return False, f"Training failed with error: {e}"

    def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
        """Generate ``size`` synthetic rows from the trained generator.

        Returns:
            ``(dataframe, message)``; ``dataframe`` is None when no generator
            has been trained yet or the SDK call raised.
        """
        if self.generator is None:
            return None, "No trained generator available. Please train a model first."
        try:
            # UI widgets may deliver the count as a float; pass a plain int on.
            synthetic_data = self.mostly.generate(self.generator, size=int(size))
            df = synthetic_data.data()
            return df, f"Synthetic data generated successfully. {len(df)} records created."
        except Exception as e:  # UI boundary: report instead of crashing
            return None, f"Synthetic data generation failed with error: {e}"

    def estimate_memory_usage(self, df: pd.DataFrame) -> str:
        """Return a short multi-line report on the in-memory size of ``df``.

        The training estimate is simply four times the DataFrame's own
        footprint; the status label is "Good" below 100 MB, "Large" below
        500 MB and "Very Large" otherwise.
        """
        if df is None or df.empty:
            return "No data available to analyze."
        memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
        rows, cols = len(df), len(df.columns)
        estimated_training_mb = memory_mb * 4
        if memory_mb < 100:
            status = "Good"
        elif memory_mb < 500:
            status = "Large"
        else:
            status = "Very Large"
        # Built line by line so the report never depends on source indentation.
        return "\n".join(
            [
                "Memory Usage Estimate:",
                f"- Data size: {memory_mb:.1f} MB",
                f"- Estimated training memory: {estimated_training_mb:.1f} MB",
                f"- Status: {status}",
                f"- Rows: {rows:,} | Columns: {cols}",
            ]
        )
# --- App state ---
# One generator instance shared by all callbacks defined in this module.
generator = SyntheticDataGenerator()
# Most recent synthetic DataFrame; read when preparing the CSV download.
_last_synth_df: Optional[pd.DataFrame] = None
# ---- Gradio wrappers ----
def initialize_sdk() -> str:
    """Initialize the shared generator's SDK client and return a status line.

    The message from ``initialize_mostly_ai`` is prefixed with "Success: " or
    "Error: " depending on the outcome.
    """
    ok, msg = generator.initialize_mostly_ai()
    prefix = "Success: " if ok else "Error: "
    return prefix + msg
def train_model(
    data: pd.DataFrame,
    model_name: str,
    epochs: int,
    max_training_time: int,
    batch_size: int,
    value_protection: bool,
    rare_category_protection: bool,
    flexible_generation: bool,
    model_size: str,
    target_accuracy: float,
    validation_split: float,
    learning_rate: float,
    early_stopping_patience: int,
    dropout_rate: float,
    weight_decay: float,
) -> str:
    """Train the shared generator on ``data`` and return a status line.

    All hyper-parameters are forwarded unchanged to
    ``generator.train_generator``; ``model_name`` is passed as its ``name``.

    Returns:
        The training message prefixed with "Success: " or "Error: ". When
        ``data`` is None or empty an error string is returned without
        calling the generator.
    """
    if data is None or data.empty:
        return "Error: No data provided. Please upload or create sample data first."

    ok, msg = generator.train_generator(
        data=data,
        name=model_name,
        epochs=epochs,
        max_training_time=max_training_time,
        batch_size=batch_size,
        value_protection=value_protection,
        rare_category_protection=rare_category_protection,
        flexible_generation=flexible_generation,
        model_size=model_size,
        target_accuracy=target_accuracy,
        validation_split=validation_split,
        learning_rate=learning_rate,
        early_stopping_patience=early_stopping_patience,
        dropout_rate=dropout_rate,
        weight_decay=weight_decay,
    )
    prefix = "Success: " if ok else "Error: "
    return prefix + msg
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
    """Generate ``size`` synthetic rows with the shared generator.

    On success a copy of the result is stored in the module-level
    ``_last_synth_df`` so it can be exported later.

    Returns:
        ``(dataframe, status)``; ``dataframe`` is None and ``status`` starts
        with "Error: " when generation failed.
    """
    global _last_synth_df
    synth_df, message = generator.generate_synthetic_data(size)
    if synth_df is None:
        return None, f"Error: {message}"
    # Keep an independent copy so later changes to the returned frame do not
    # alter what gets exported.
    _last_synth_df = synth_df.copy()
    return synth_df, f"Success: {message}"
def download_csv_prepare() -> Optional[str]:
    """Return a path to the latest synthetic CSV; used as output to gr.File.

    Writes the module-level ``_last_synth_df`` (without the index) to
    ``synthetic_data.csv`` in the system temporary directory, overwriting any
    previous export.

    Returns:
        The CSV path, or None when there is no synthetic data to export.
    """
    if _last_synth_df is None or _last_synth_df.empty:
        return None
    # Ask the platform for its temp directory instead of assuming "/tmp".
    path = os.path.join(tempfile.gettempdir(), "synthetic_data.csv")
    _last_synth_df.to_csv(path, index=False)
    return path
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame):
    """Build a grid of histograms comparing original and synthetic columns.

    One subplot is drawn for each column that is numeric in ``original_df``
    and also present in ``synthetic_df``, laid out at most three per row,
    with one histogram trace per dataset.

    Returns:
        A plotly figure, or None when either frame is None or there is no
        shared numeric column to plot.
    """
    if original_df is None or synthetic_df is None:
        return None

    # Skip columns the synthetic frame lacks (e.g. the empty frame left after
    # a failed generation) instead of failing with KeyError.
    numeric_cols = [
        col
        for col in original_df.select_dtypes(include=[np.number]).columns
        if col in synthetic_df.columns
    ]
    if not numeric_cols:
        return None

    n_cols = min(3, len(numeric_cols))
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols  # ceiling division
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols)

    for i, col in enumerate(numeric_cols):
        # Fill the grid row by row; plotly subplot indices are 1-based.
        row, col_idx = divmod(i, n_cols)
        row += 1
        col_idx += 1
        fig.add_trace(
            go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20),
            row=row,
            col=col_idx,
        )
        fig.add_trace(
            go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20),
            row=row,
            col=col_idx,
        )

    fig.update_layout(title="Original vs Synthetic Data Comparison", height=300 * n_rows, showlegend=True)
    return fig
# ---- UI ----
def create_interface():
    """Build the Gradio Blocks app and wire its events.

    The app has three tabs: SDK initialization, CSV upload plus training
    settings, and synthetic data generation with CSV export and a comparison
    plot.

    NOTE(review): the nesting of rows/columns/groups below was reconstructed
    from the order of the components - confirm it against the intended layout.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
        gr.Image(
            value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
            show_label=False,
            elem_id="header-image",
        )
        gr.Markdown(
            """
            # Synthetic Data SDK by MOSTLY AI Demo Space
            [Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)
            A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
            """
        )

        with gr.Tab("Quick Start"):
            gr.Markdown("### Initialize the SDK and upload your data")
            with gr.Row():
                with gr.Column():
                    init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
                    init_status = gr.Textbox(label="Initialization Status", interactive=False)
                with gr.Column():
                    gr.Markdown(
                        """
                        **Next Steps:**
                        1. Initialize the SDK
                        2. Go to the "Upload Data and Train Model" tab to upload your CSV file
                        3. Train a model on your data
                        4. Generate synthetic data
                        """
                    )

        with gr.Tab("Upload Data and Train Model"):
            gr.Markdown("### Upload your CSV file to generate synthetic data")
            gr.Markdown(
                """
                **File Requirements:**
                - Format: CSV with header row
                - Size: Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
                """
            )
            file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single")
            uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
            # Hidden until a file has been read successfully.
            memory_info = gr.Markdown(label="Memory Usage Info", visible=False)

            with gr.Row():
                with gr.Column(scale=1):
                    model_name = gr.Textbox(
                        value="My Synthetic Model",
                        label="Generator Name",
                        placeholder="Enter a name for your generator",
                        info="Appears in training runs and saved generators.",
                    )
                    epochs = gr.Slider(
                        1, 200, value=100, step=1, label="Training Epochs",
                        info="Maximum number of passes over the training data.",
                    )
                    max_training_time = gr.Slider(
                        1, 1000, value=60, step=1, label="Maximum Training Time (minutes)",
                        info="Upper bound in minutes; training stops if exceeded.",
                    )
                    batch_size = gr.Slider(
                        8, 1024, value=32, step=8, label="Batch Size",
                        info="Number of rows per optimization step. Larger can speed up but needs more memory.",
                    )
                    value_protection = gr.Checkbox(
                        label="Value Protection",
                        info="Adds protections to reduce memorization of unique or sensitive values.",
                        value=False,
                    )
                    rare_category_protection = gr.Checkbox(
                        label="Rare Category Protection",
                        info="Prevents overfitting to infrequent categories to improve privacy and robustness.",
                        value=False,
                    )
                with gr.Column(scale=1):
                    flexible_generation = gr.Checkbox(
                        label="Flexible Generation",
                        info="Allows generation when inputs slightly differ from training schema.",
                        value=True,
                    )
                    model_size = gr.Dropdown(
                        choices=["SMALL", "MEDIUM", "LARGE"],
                        value="MEDIUM",
                        label="Model Size",
                        info="Sets model capacity. Larger can improve fidelity but uses more compute.",
                    )
                    target_accuracy = gr.Slider(
                        0.50, 0.999, value=0.95, step=0.001, label="Target Accuracy",
                        info="Stop early when validation accuracy reaches this threshold.",
                    )
                    validation_split = gr.Slider(
                        0.05, 0.5, value=0.2, step=0.01, label="Validation Split",
                        info="Fraction of the dataset held out for validation during training.",
                    )
                    early_stopping_patience = gr.Slider(
                        0, 50, value=10, step=1, label="Early Stopping Patience (epochs)",
                        info="Stop if no validation improvement after this many epochs.",
                    )
                with gr.Column(scale=1):
                    learning_rate = gr.Number(
                        value=0.001, precision=6, label="Learning Rate",
                        info="Step size for the optimizer. Typical range: 1e-4 to 1e-2.",
                    )
                    dropout_rate = gr.Slider(
                        0.0, 0.6, value=0.1, step=0.01, label="Dropout Rate",
                        info="Regularization to reduce overfitting by randomly dropping units.",
                    )
                    weight_decay = gr.Number(
                        value=0.0001, precision=6, label="Weight Decay",
                        info="L2 regularization strength applied to model weights.",
                    )

            train_btn = gr.Button("Train Model", variant="primary")
            train_status = gr.Textbox(label="Training Status", interactive=False)

        with gr.Tab("Generate Data"):
            gr.Markdown("### Generate synthetic data from your trained model")
            with gr.Row():
                with gr.Column():
                    gen_size = gr.Slider(
                        10, 1000, value=100, step=10, label="Number of Records to Generate",
                        info="How many synthetic rows to create in the table.",
                    )
                    generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
                with gr.Column():
                    gen_status = gr.Textbox(label="Generation Status", interactive=False)
            synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
            with gr.Row():
                csv_download_btn = gr.Button("Download CSV", variant="secondary")
            # Hidden until a CSV has actually been prepared.
            with gr.Group(visible=False) as csv_group:
                csv_file = gr.File(label="Synthetic CSV", interactive=False)
            comparison_plot = gr.Plot(label="Data Comparison")

        # ---- Events ----
        init_btn.click(initialize_sdk, outputs=[init_status])

        # The input order must match train_model's positional parameters.
        train_btn.click(
            train_model,
            inputs=[
                uploaded_data, model_name,
                epochs, max_training_time, batch_size,
                value_protection, rare_category_protection, flexible_generation,
                model_size, target_accuracy, validation_split,
                learning_rate, early_stopping_patience, dropout_rate, weight_decay,
            ],
            outputs=[train_status],
        )

        generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])

        # Redraw the comparison whenever the synthetic table changes.
        synthetic_data.change(
            create_comparison_plot,
            inputs=[uploaded_data, synthetic_data],
            outputs=[comparison_plot],
        )

        def _prepare_csv_for_download():
            """Write the CSV and reveal the download group when a file exists."""
            path = download_csv_prepare()
            if path:
                return path, gr.update(visible=True)
            return None, gr.update(visible=False)

        csv_download_btn.click(
            _prepare_csv_for_download,
            outputs=[csv_file, csv_group],
        )

        def process_uploaded_file(file):
            """Read the uploaded CSV; return (dataframe, status, memory-info update)."""
            if file is None:
                return None, "No file uploaded.", gr.update(visible=False)
            try:
                # Accept either an object exposing .name or a plain path
                # string - presumably which one arrives depends on the Gradio
                # version and the File component's `type` setting.
                csv_path = getattr(file, "name", file)
                df = pd.read_csv(csv_path)
                success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
                mem_info = generator.estimate_memory_usage(df)
                return df, success_msg, gr.update(value=mem_info, visible=True)
            except Exception as e:  # UI boundary: show the parse error to the user
                return None, f"Error reading file: {str(e)}", gr.update(visible=False)

        # The upload status is shown in the training status box.
        file_upload.change(
            process_uploaded_file,
            inputs=[file_upload],
            outputs=[uploaded_data, train_status, memory_info],
        )

    return demo
if __name__ == "__main__":
    demo = create_interface()
    # Listen on all interfaces, port 7860.
    # NOTE(review): share=True additionally requests a public share link -
    # confirm this is wanted in the target deployment.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)