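"""Gradio demo: train a MOSTLY AI generator on an uploaded CSV, sample
privacy-safe synthetic data from it, and compare the two distributions."""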
import os
import tempfile
import warnings
from typing import Optional, Tuple
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
warnings.filterwarnings("ignore")
# Import Mostly AI SDK
try:
from mostlyai.sdk import MostlyAI
MOSTLY_AI_AVAILABLE = True
except ImportError:
MOSTLY_AI_AVAILABLE = False
print("Warning: Mostly AI SDK not available. Please install with: pip install mostlyai[local]")
class SyntheticDataGenerator:
    def __init__(self):
        self.mostly = None  # MostlyAI client, set by initialize_mostly_ai()
        self.generator = None  # trained generator handle, set by train_generator()
        self.original_data = None  # most recent training DataFrame
def initialize_mostly_ai(self) -> Tuple[bool, str]:
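        """Start the MOSTLY AI SDK in local mode. Returns (success, message)."""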
if not MOSTLY_AI_AVAILABLE:
return False, "Mostly AI SDK not installed. Please install with: pip install mostlyai[local]"
try:
self.mostly = MostlyAI(local=True, local_port=8080)
return True, "Mostly AI SDK initialized successfully."
except Exception as e:
return False, f"Failed to initialize Mostly AI SDK: {str(e)}"
def train_generator(
self,
data: pd.DataFrame,
name: str,
epochs: int = 10,
max_training_time: int = 30,
batch_size: int = 32,
value_protection: bool = True,
rare_category_protection: bool = False,
flexible_generation: bool = False,
model_size: str = "MEDIUM",
target_accuracy: float = 0.95,
validation_split: float = 0.2,
learning_rate: float = 0.001,
early_stopping_patience: int = 10,
dropout_rate: float = 0.1,
weight_decay: float = 0.0001,
) -> Tuple[bool, str]:
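        """Train a MOSTLY AI generator on `data` with the given configuration. Returns (success, message)."""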
if not self.mostly:
return False, "Mostly AI SDK not initialized. Please initialize the SDK first."
try:
self.original_data = data
train_config = {
"tables": [
{
"name": name,
"data": data,
"tabular_model_configuration": {
"max_epochs": epochs,
"max_training_time": max_training_time,
"value_protection": value_protection,
"batch_size": batch_size,
"rare_category_protection": rare_category_protection,
"flexible_generation": flexible_generation,
"model_size": model_size, # "SMALL" | "MEDIUM" | "LARGE"
"target_accuracy": target_accuracy, # early stop once target met
"validation_split": validation_split,
"learning_rate": learning_rate,
"early_stopping_patience": early_stopping_patience,
"dropout_rate": dropout_rate,
"weight_decay": weight_decay,
},
}
]
}
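            # With its default settings, mostly.train() blocks until training
            # finishes and returns a generator handle usable with mostly.generate().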
self.generator = self.mostly.train(config=train_config)
return True, f"Training completed successfully. Model name: {name}"
except Exception as e:
return False, f"Training failed with error: {str(e)}"
def generate_synthetic_data(self, size: int) -> Tuple[Optional[pd.DataFrame], str]:
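        """Sample `size` rows from the trained generator. Returns (DataFrame or None, message)."""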
if not self.generator:
return None, "No trained generator available. Please train a model first."
try:
synthetic_data = self.mostly.generate(self.generator, size=size)
df = synthetic_data.data()
return df, f"Synthetic data generated successfully. {len(df)} records created."
except Exception as e:
return None, f"Synthetic data generation failed with error: {str(e)}"
def estimate_memory_usage(self, df: pd.DataFrame) -> str:
if df is None or df.empty:
return "No data available to analyze."
        memory_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)
        rows, cols = len(df), len(df.columns)
        # Rough rule of thumb used here: training needs roughly 4x the raw frame in memory.
        estimated_training_mb = memory_mb * 4
        status = "Good" if memory_mb < 100 else ("Large" if memory_mb < 500 else "Very Large")
        # Assemble the Markdown line by line so stray indentation cannot turn it into a code block.
        return (
            "Memory Usage Estimate:\n"
            f"- Data size: {memory_mb:.1f} MB\n"
            f"- Estimated training memory: {estimated_training_mb:.1f} MB\n"
            f"- Status: {status}\n"
            f"- Rows: {rows:,} | Columns: {cols}"
        )
# --- App state ---
generator = SyntheticDataGenerator()
_last_synth_df: Optional[pd.DataFrame] = None  # latest generated frame, kept for the CSV download
# ---- Gradio wrappers ----
def initialize_sdk() -> str:
ok, msg = generator.initialize_mostly_ai()
return ("Success: " if ok else "Error: ") + msg
def train_model(
data: pd.DataFrame,
model_name: str,
epochs: int,
max_training_time: int,
batch_size: int,
value_protection: bool,
rare_category_protection: bool,
flexible_generation: bool,
model_size: str,
target_accuracy: float,
validation_split: float,
learning_rate: float,
early_stopping_patience: int,
dropout_rate: float,
weight_decay: float,
) -> str:
if data is None or data.empty:
return "Error: No data provided. Please upload or create sample data first."
ok, msg = generator.train_generator(
data=data,
name=model_name,
epochs=epochs,
max_training_time=max_training_time,
batch_size=batch_size,
value_protection=value_protection,
rare_category_protection=rare_category_protection,
flexible_generation=flexible_generation,
model_size=model_size,
target_accuracy=target_accuracy,
validation_split=validation_split,
learning_rate=learning_rate,
early_stopping_patience=early_stopping_patience,
dropout_rate=dropout_rate,
weight_decay=weight_decay,
)
return ("Success: " if ok else "Error: ") + msg
def generate_data(size: int) -> Tuple[Optional[pd.DataFrame], str]:
global _last_synth_df
synth_df, message = generator.generate_synthetic_data(size)
if synth_df is not None:
_last_synth_df = synth_df.copy()
return synth_df, f"Success: {message}"
else:
return None, f"Error: {message}"
def download_csv_prepare() -> Optional[str]:
"""Return a path to the latest synthetic CSV; used as output to gr.File."""
global _last_synth_df
if _last_synth_df is None or _last_synth_df.empty:
return None
    path = os.path.join(tempfile.gettempdir(), "synthetic_data.csv")
_last_synth_df.to_csv(path, index=False)
return path
def create_comparison_plot(original_df: pd.DataFrame, synthetic_df: pd.DataFrame):
if original_df is None or synthetic_df is None:
return None
numeric_cols = original_df.select_dtypes(include=[np.number]).columns.tolist()
if not numeric_cols:
return None
    # Lay the histograms out on a grid of up to 3 columns; ceil-divide for the row count.
    n_cols = min(3, len(numeric_cols))
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=numeric_cols[: n_rows * n_cols])
for i, col in enumerate(numeric_cols[: n_rows * n_cols]):
row = i // n_cols + 1
col_idx = i % n_cols + 1
fig.add_trace(go.Histogram(x=original_df[col], name=f"Original {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
fig.add_trace(go.Histogram(x=synthetic_df[col], name=f"Synthetic {col}", opacity=0.7, nbinsx=20), row=row, col=col_idx)
    fig.update_layout(
        title="Original vs Synthetic Data Comparison",
        height=300 * n_rows,
        showlegend=True,
        barmode="overlay",  # overlay the paired histograms so the 0.7 opacity shows both
    )
return fig
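# A quick numeric fidelity check could complement the plot (a sketch, not wired
# into the UI): the absolute gap between per-column means of the two frames.
#
#   numeric = original_df.select_dtypes(include=[np.number]).columns
#   mean_gap = (original_df[numeric].mean() - synthetic_df[numeric].mean()).abs()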
# ---- UI ----
def create_interface():
with gr.Blocks(title="MOSTLY AI Synthetic Data Generator", theme=gr.themes.Soft()) as demo:
gr.Image(
value="https://img.mailinblue.com/8225865/images/content_library/original/6880d164e4e4ea1a183ad4c0.png",
show_label=False,
elem_id="header-image",
)
gr.Markdown(
"""
# Synthetic Data SDK by MOSTLY AI Demo Space
[Documentation](https://mostly-ai.github.io/mostlyai/) | [Technical White Paper](https://arxiv.org/abs/2508.00718) | [Usage Examples](https://mostly-ai.github.io/mostlyai/usage/) | [Free Cloud Service](https://app.mostly.ai/)
A Python toolkit for generating high-fidelity, privacy-safe synthetic data.
"""
)
with gr.Tab("Quick Start"):
gr.Markdown("### Initialize the SDK and upload your data")
with gr.Row():
with gr.Column():
init_btn = gr.Button("Initialize Mostly AI SDK", variant="primary")
init_status = gr.Textbox(label="Initialization Status", interactive=False)
with gr.Column():
gr.Markdown(
"""
**Next Steps:**
1. Initialize the SDK
2. Go to the "Upload Data and Train Model" tab to upload your CSV file
3. Train a model on your data
4. Generate synthetic data
"""
)
with gr.Tab("Upload Data and Train Model"):
gr.Markdown("### Upload your CSV file to generate synthetic data")
gr.Markdown(
"""
**File Requirements:**
- Format: CSV with header row
- Size: Optimized for Hugging Face Spaces (2 vCPU, 16GB RAM)
"""
)
file_upload = gr.File(label="Upload CSV File", file_types=[".csv"], file_count="single")
uploaded_data = gr.Dataframe(label="Uploaded Data", interactive=False)
memory_info = gr.Markdown(label="Memory Usage Info", visible=False)
with gr.Row():
with gr.Column(scale=1):
model_name = gr.Textbox(
value="My Synthetic Model",
label="Generator Name",
placeholder="Enter a name for your generator",
info="Appears in training runs and saved generators."
)
epochs = gr.Slider(
1, 200, value=100, step=1, label="Training Epochs",
info="Maximum number of passes over the training data."
)
max_training_time = gr.Slider(
1, 1000, value=60, step=1, label="Maximum Training Time (minutes)",
info="Upper bound in minutes; training stops if exceeded."
)
batch_size = gr.Slider(
8, 1024, value=32, step=8, label="Batch Size",
info="Number of rows per optimization step. Larger can speed up but needs more memory."
)
value_protection = gr.Checkbox(
label="Value Protection",
info="Adds protections to reduce memorization of unique or sensitive values.",
value=False
)
rare_category_protection = gr.Checkbox(
label="Rare Category Protection",
info="Prevents overfitting to infrequent categories to improve privacy and robustness.",
value=False
)
with gr.Column(scale=1):
flexible_generation = gr.Checkbox(
label="Flexible Generation",
info="Allows generation when inputs slightly differ from training schema.",
value=True
)
model_size = gr.Dropdown(
choices=["SMALL", "MEDIUM", "LARGE"],
value="MEDIUM",
label="Model Size",
info="Sets model capacity. Larger can improve fidelity but uses more compute."
)
target_accuracy = gr.Slider(
0.50, 0.999, value=0.95, step=0.001, label="Target Accuracy",
info="Stop early when validation accuracy reaches this threshold."
)
validation_split = gr.Slider(
0.05, 0.5, value=0.2, step=0.01, label="Validation Split",
info="Fraction of the dataset held out for validation during training."
)
early_stopping_patience = gr.Slider(
0, 50, value=10, step=1, label="Early Stopping Patience (epochs)",
info="Stop if no validation improvement after this many epochs."
)
with gr.Column(scale=1):
learning_rate = gr.Number(
value=0.001, precision=6, label="Learning Rate",
info="Step size for the optimizer. Typical range: 1e-4 to 1e-2."
)
dropout_rate = gr.Slider(
0.0, 0.6, value=0.1, step=0.01, label="Dropout Rate",
info="Regularization to reduce overfitting by randomly dropping units."
)
weight_decay = gr.Number(
value=0.0001, precision=6, label="Weight Decay",
info="L2 regularization strength applied to model weights."
)
train_btn = gr.Button("Train Model", variant="primary")
train_status = gr.Textbox(label="Training Status", interactive=False)
with gr.Tab("Generate Data"):
gr.Markdown("### Generate synthetic data from your trained model")
with gr.Row():
with gr.Column():
                    gen_size = gr.Slider(
                        10, 1000, value=100, step=10, label="Number of Records to Generate",
                        info="How many synthetic rows to create in the table."
                    )
generate_btn = gr.Button("Generate Synthetic Data", variant="primary")
with gr.Column():
gen_status = gr.Textbox(label="Generation Status", interactive=False)
synthetic_data = gr.Dataframe(label="Synthetic Data", interactive=False)
with gr.Row():
csv_download_btn = gr.Button("Download CSV", variant="secondary")
with gr.Group(visible=False) as csv_group:
csv_file = gr.File(label="Synthetic CSV", interactive=False)
comparison_plot = gr.Plot(label="Data Comparison")
# ---- Events ----
init_btn.click(initialize_sdk, outputs=[init_status])
train_btn.click(
train_model,
inputs=[
uploaded_data, model_name,
epochs, max_training_time, batch_size,
value_protection, rare_category_protection, flexible_generation,
model_size, target_accuracy, validation_split,
learning_rate, early_stopping_patience, dropout_rate, weight_decay
],
outputs=[train_status],
)
generate_btn.click(generate_data, inputs=[gen_size], outputs=[synthetic_data, gen_status])
synthetic_data.change(create_comparison_plot, inputs=[uploaded_data, synthetic_data], outputs=[comparison_plot])
def _prepare_csv_for_download():
path = download_csv_prepare()
if path:
return path, gr.update(visible=True)
else:
return None, gr.update(visible=False)
csv_download_btn.click(
_prepare_csv_for_download,
outputs=[csv_file, csv_group],
)
def process_uploaded_file(file):
if file is None:
return None, "No file uploaded.", gr.update(visible=False)
            try:
                # gr.File may deliver a filepath string or a tempfile wrapper with a
                # .name attribute, depending on the Gradio version; handle both.
                path = file if isinstance(file, str) else file.name
                df = pd.read_csv(path)
                success_msg = f"File uploaded successfully. {len(df)} rows × {len(df.columns)} columns"
mem_info = generator.estimate_memory_usage(df)
return df, success_msg, gr.update(value=mem_info, visible=True)
except Exception as e:
return None, f"Error reading file: {str(e)}", gr.update(visible=False)
file_upload.change(process_uploaded_file, inputs=[file_upload], outputs=[uploaded_data, train_status, memory_info])
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)