Spaces:
Sleeping
Sleeping
import os | |
import logging | |
import gradio as gr | |
from gradio.utils import get_space | |
from modal_sandbox import create_modal_sandbox | |
from pathlib import Path | |
import json | |
from datetime import datetime | |
import threading | |
import re | |
from openai import OpenAI, AzureOpenAI | |
from jupyter_handler import JupyterNotebook | |
if not get_space(): | |
try: | |
from dotenv import load_dotenv | |
load_dotenv() | |
except (ImportError, ModuleNotFoundError): | |
pass | |
from jupyter_agent import ( | |
run_interactive_notebook_with_session_state, | |
SessionStateManager, | |
) | |
TMP_DIR = './temp/' | |
# Environment and API key management utilities | |
def get_environment():
    """Return the deployment environment name in lower case.

    The value comes from the ``ENVIRONMENT`` variable; when it is unset the
    function falls back to ``"prod"``.
    """
    configured = os.environ.get("ENVIRONMENT", "prod")
    return configured.lower()
def is_dev_environment():
    """Tell whether the app is running in the development (``"dev"``) environment."""
    current_environment = get_environment()
    return current_environment == "dev"
def get_required_api_keys():
    """Describe every API key the app knows about, together with its current value.

    Returns:
        dict: Maps each environment variable name to a dict with three entries:
        ``value`` (the current environment value, ``None`` when unset),
        ``required`` (bool) and ``description`` (human-readable text).
    """
    # (variable name, required?, description) -- the result keeps this order.
    key_specs = (
        ("MODAL_TOKEN_ID", True, "Modal Token ID for sandbox access"),
        ("MODAL_TOKEN_SECRET", True, "Modal Token Secret for sandbox access"),
        ("HF_TOKEN", False, "Hugging Face Token for model access"),
        ("PROVIDER_API_KEY", True, "AI Provider API Key (Anthropic, OpenAI, etc.)"),
        ("PROVIDER_API_ENDPOINT", True, "AI Provider API Endpoint"),
        ("MODEL_NAME", True, "Model name to use"),
        ("TAVILY_API_KEY", False, "Tavily API Key for web search functionality"),
    )
    return {
        name: {
            "value": os.environ.get(name),
            "required": is_required,
            "description": description,
        }
        for name, is_required, description in key_specs
    }
def get_missing_api_keys():
    """Return the required API keys that currently have no value.

    Returns:
        dict: The entries of ``get_required_api_keys()`` whose ``required``
        flag is set and whose ``value`` is empty or unset.
    """
    return {
        name: spec
        for name, spec in get_required_api_keys().items()
        if spec["required"] and not spec["value"]
    }
def validate_api_key_format(key_name, key_value):
    """Run a basic prefix check on a credential value.

    Args:
        key_name: Environment variable name the value belongs to.
        key_value: The value entered by the user.

    Returns:
        tuple: ``(is_valid, message)``. Names without a known rule are
        accepted as long as the value is not blank.
    """
    if not key_value or not key_value.strip():
        return False, "API key cannot be empty"

    candidate = key_value.strip()

    # key name -> (accepted prefixes, message returned when none of them match)
    prefix_rules = {
        "MODAL_TOKEN_ID": (("ak-",), "Modal Token ID should start with 'ak-'"),
        "MODAL_TOKEN_SECRET": (("as-",), "Modal Token Secret should start with 'as-'"),
        "HF_TOKEN": (("hf_",), "Hugging Face token should start with 'hf_'"),
        "PROVIDER_API_KEY": (
            ("sk-", "gsk_", "csk-"),
            "API key format may be invalid (expected prefixes: sk-, gsk_, csk-)",
        ),
        "PROVIDER_API_ENDPOINT": (
            ("http://", "https://"),
            "API endpoint should start with http:// or https://",
        ),
        "TAVILY_API_KEY": (("tvly-",), "Tavily API key should start with 'tvly-'"),
    }

    rule = prefix_rules.get(key_name)
    if rule is not None:
        accepted_prefixes, complaint = rule
        if not candidate.startswith(accepted_prefixes):
            return False, complaint
    return True, "Valid format"
def apply_user_api_keys(api_keys_dict):
    """Export user-supplied API keys into the process environment.

    Args:
        api_keys_dict: Maps environment variable names to the values entered
            by the user. Empty and whitespace-only values are skipped; all
            other values are stored stripped of surrounding whitespace.
    """
    for name, raw_value in api_keys_dict.items():
        if not raw_value:
            continue
        cleaned = raw_value.strip()
        if not cleaned:
            continue
        os.environ[name] = cleaned
        # Only the key name is logged, never the secret itself.
        logger.info(f"Applied user-provided API key: {name}")
def get_previous_notebooks():
    """List notebooks saved by earlier sessions (development environment only).

    Scans every sub-directory of ``TMP_DIR`` for a ``jupyter-agent.ipynb`` file
    and collects basic facts about it.

    Returns:
        list[dict]: One entry per readable notebook, newest first, with the
        keys ``session_id``, ``path``, ``modified`` (mtime as a timestamp),
        ``size`` (bytes), ``cell_count`` and ``display_name``. Returns an
        empty list outside the dev environment or when ``TMP_DIR`` is missing.
        Notebooks that cannot be read are logged and skipped.
    """
    if not is_dev_environment():
        return []

    tmp_dir = Path(TMP_DIR)
    if not tmp_dir.exists():
        return []

    notebooks = []
    for session_dir in tmp_dir.iterdir():
        if not session_dir.is_dir():
            continue
        notebook_file = session_dir / "jupyter-agent.ipynb"
        if not notebook_file.exists():
            continue

        try:
            stat = notebook_file.stat()
            size = stat.st_size
            modified = stat.st_mtime

            # This file writes notebooks as UTF-8, so read them back the same
            # way instead of relying on the platform's default encoding.
            with open(notebook_file, "r", encoding="utf-8") as f:
                notebook_data = json.load(f)
            cell_count = len(notebook_data.get('cells', []))

            formatted_time = datetime.fromtimestamp(modified).strftime("%Y-%m-%d %H:%M")

            # The hardware info is optional decoration for the display name; a
            # missing or unreadable session state must not hide the notebook.
            config_info = ""
            try:
                session_manager = SessionStateManager(session_dir.name, TMP_DIR)
                session_state = session_manager.load_state()
                if session_state:
                    hardware = session_state.get("hardware_config", {})
                    gpu = hardware.get("gpu_type", "unknown")
                    config_info = f", {gpu}"
            except Exception as state_error:
                logger.debug(f"Could not load session state for {session_dir.name}: {state_error}")

            notebooks.append({
                'session_id': session_dir.name,
                'path': str(notebook_file),
                'modified': modified,
                'size': size,
                'cell_count': cell_count,
                'display_name': f"{session_dir.name} ({cell_count} cells{config_info}, {formatted_time})"
            })
        except Exception as e:
            logger.warning(f"Failed to read notebook info for {session_dir.name}: {e}")

    # Sort by modification time (newest first)
    notebooks.sort(key=lambda x: x['modified'], reverse=True)
    return notebooks
def parse_environment_variables(env_vars_text):
    """
    Turn ``KEY=value`` text into a dictionary.

    Args:
        env_vars_text: Text with one ``KEY=value`` assignment per line. Blank
            lines and lines starting with ``#`` are ignored.

    Returns:
        dict: Parsed variables. Keys and values are stripped of surrounding
        whitespace; only the first ``=`` on a line separates key from value.
        Lines without ``=`` are logged and skipped; lines with an empty key
        are skipped.
    """
    parsed = {}
    if not env_vars_text or not env_vars_text.strip():
        return parsed

    for raw_line in env_vars_text.strip().split('\n'):
        entry = raw_line.strip()
        if not entry or entry.startswith('#'):
            continue

        name, separator, value = entry.partition('=')
        if not separator:
            logger.warning(f"Skipping invalid environment variable format: {entry}")
            continue

        name = name.strip()
        if name:
            parsed[name] = value.strip()
    return parsed
def create_notification_html(message, notification_type="info", show_spinner=False):
    """
    Render a notification banner as an HTML snippet.

    Args:
        message: Text shown in the banner (inserted into the markup as-is).
        notification_type: One of 'info', 'success', 'warning', 'error' or
            'loading'; any other value is rendered like 'info'.
        show_spinner: Adds a loading spinner in front of the text. The
            'loading' type always shows the spinner.

    Returns:
        str: The banner markup.
    """
    # notification type -> (accent colour, icon)
    palette = {
        'info': ('#3498db', '🔄'),
        'success': ('#27ae60', '✅'),
        'warning': ('#f39c12', '⚠️'),
        'error': ('#e74c3c', '❌'),
        'loading': ('#6c5ce7', '⏳'),
    }
    color, icon = palette.get(notification_type, palette['info'])

    wants_spinner = show_spinner or notification_type == 'loading'
    if wants_spinner:
        spinner_html = f"""
        <div style="
            display: inline-block;
            width: 20px;
            height: 20px;
            border: 2px solid #f3f3f3;
            border-top: 2px solid {color};
            border-radius: 50%;
            animation: spin 1s linear infinite;
            margin-right: 8px;
        "></div>
        <style>
            @keyframes spin {{
                0% {{ transform: rotate(0deg); }}
                100% {{ transform: rotate(360deg); }}
            }}
        </style>
        """
    else:
        spinner_html = ""

    # "{color}20" appends an alpha channel to the hex colour for a faint tint.
    return f"""
    <div style="
        background-color: {color}20;
        border-left: 4px solid {color};
        padding: 12px 16px;
        margin: 10px 0;
        border-radius: 4px;
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        font-size: 14px;
        color: #2c3e50;
        display: flex;
        align-items: center;
    ">
        {spinner_html}
        <strong>{icon} {message}</strong>
    </div>
    """
def create_progress_notification(message, progress_percent=None):
    """Render a 'loading' banner, optionally followed by a progress bar.

    Args:
        message: Text shown in the banner.
        progress_percent: When given, a bar filled to this percentage and a
            "<n>% complete" caption are appended after the banner.

    Returns:
        str: The HTML snippet.
    """
    banner = create_notification_html(message, "loading", show_spinner=True)
    if progress_percent is None:
        return banner

    progress_bar = f"""
    <div style="
        width: 100%;
        background-color: #e0e0e0;
        border-radius: 5px;
        margin-top: 8px;
        height: 8px;
    ">
        <div style="
            width: {progress_percent}%;
            background-color: #3498db;
            height: 8px;
            border-radius: 5px;
            transition: width 0.3s ease;
        "></div>
    </div>
    <small style="color: #666; margin-top: 4px; display: block;">{progress_percent}% complete</small>
    """
    return banner + progress_bar
def initialize_phoenix_tracing():
    """Set up Phoenix tracing when it is installed and configured.

    Needs both ``PHOENIX_API_KEY`` and ``PHOENIX_COLLECTOR_ENDPOINT`` in the
    environment. Every failure is logged and treated as non-fatal.

    Returns:
        The tracer provider returned by ``phoenix.otel.register``, or ``None``
        when tracing is unavailable, unconfigured, or failed to start.
    """
    try:
        from phoenix.otel import register

        phoenix_api_key = os.getenv("PHOENIX_API_KEY")
        collector_endpoint = os.getenv("PHOENIX_COLLECTOR_ENDPOINT")

        if not phoenix_api_key:
            logger.info("Phoenix API key not found, skipping Phoenix tracing initialization")
            return None
        if not collector_endpoint:
            logger.info("Phoenix collector endpoint not found, skipping Phoenix tracing initialization")
            return None

        logger.info("Initializing Phoenix tracing with session support...")

        # Publish the settings (and the derived auth headers) through the environment.
        auth_header = f"api_key={phoenix_api_key}"
        os.environ.update({
            "PHOENIX_API_KEY": phoenix_api_key,
            "PHOENIX_COLLECTOR_ENDPOINT": collector_endpoint,
            "OTEL_EXPORTER_OTLP_HEADERS": auth_header,
            "PHOENIX_CLIENT_HEADERS": auth_header,
        })

        # Auto-instrumentation stays enabled so OpenAI calls are traced.
        tracer_provider = register(
            project_name="eureka-agent",
            auto_instrument=True,
            set_global_tracer_provider=True
        )

        # Explicitly make sure the OpenAI instrumentor is active as well; this
        # step is optional and must not prevent tracing from starting.
        try:
            from openinference.instrumentation.openai import OpenAIInstrumentor

            if OpenAIInstrumentor().is_instrumented_by_opentelemetry:
                logger.info("OpenAI instrumentation already active")
            else:
                OpenAIInstrumentor().instrument()
                logger.info("OpenAI instrumentation configured for Phoenix session tracking")
        except ImportError:
            logger.warning("OpenAI instrumentation not available - session grouping may not work optimally")
        except Exception as instrument_error:
            logger.warning(f"Failed to configure OpenAI instrumentation: {str(instrument_error)}")

        logger.info("Phoenix tracing initialized successfully with session support")
        return tracer_provider
    except ImportError:
        logger.info("Phoenix not installed, skipping tracing initialization")
        return None
    except Exception as init_error:
        logger.warning(f"Failed to initialize Phoenix tracer (non-critical): {str(init_error)}")
        return None
# Configure logging
# Records are written both to the file ``jupyter_agent.log`` and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('jupyter_agent.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Initialize Phoenix tracing
# This call must come after ``logger`` is bound: initialize_phoenix_tracing()
# logs through the module-level logger. The result is None when tracing is off.
tracer_provider = initialize_phoenix_tracing()

# Credentials are read once here, at import time; later changes to os.environ
# are not reflected in these names.
MODAL_TOKEN_ID = os.environ.get("MODAL_TOKEN_ID")
MODAL_TOKEN_SECRET = os.environ.get("MODAL_TOKEN_SECRET")
HF_TOKEN = os.environ.get("HF_TOKEN")

SANDBOXES = {}  # Modal sandbox per session, keyed by the Gradio session hash
SANDBOX_TIMEOUT = 300  # presumably seconds -- not referenced in the visible code; TODO confirm
STOP_EVENTS = {}  # Store stop events for each session
EXECUTION_STATES = {}  # Store execution states for each session

# GPU configuration options for the UI
# Each entry is a (label, value) pair.
GPU_OPTIONS = [
    ("CPU Only", "cpu"),
    ("NVIDIA T4 (16GB)", "T4"),
    ("NVIDIA L4 (24GB)", "L4"),
    ("NVIDIA A100 40GB", "A100-40GB"),
    ("NVIDIA A100 80GB", "A100-80GB"),
    ("NVIDIA H100 (80GB)", "H100")
]
def initialize_openai_client():
    """Build an OpenAI-compatible client from whatever credentials are configured.

    Credential sources are tried in this order: Azure OpenAI
    (``AZURE_OPENAI_ENDPOINT`` + ``AZURE_OPENAI_API_KEY``), a custom provider
    (``PROVIDER_API_ENDPOINT`` + ``PROVIDER_API_KEY``), then standard OpenAI
    (``OPENAI_API_KEY``). The model comes from ``MODEL_NAME`` and defaults to
    ``"gpt-4"``.

    After construction a tiny chat completion is sent as a connection test; a
    failing test is logged but the client is still returned.

    Returns:
        tuple: ``(client, model_name)``, or ``(None, None)`` when no
        credentials are configured or the client could not be constructed.
    """
    env = os.environ
    has_azure = env.get("AZURE_OPENAI_ENDPOINT") and env.get("AZURE_OPENAI_API_KEY")
    has_provider = env.get("PROVIDER_API_ENDPOINT") and env.get("PROVIDER_API_KEY")
    has_openai = env.get("OPENAI_API_KEY")

    if not (has_azure or has_provider or has_openai):
        logger.warning("No API keys found in environment - client will be initialized later when user provides keys")
        return None, None

    try:
        if has_azure:
            # Option 1: Azure OpenAI
            logger.info("Initializing Azure OpenAI client")
            client = AzureOpenAI(
                api_version="2024-12-01-preview",
                azure_endpoint=env.get("AZURE_OPENAI_ENDPOINT"),
                api_key=env.get("AZURE_OPENAI_API_KEY")
            )
            ready_label = "Azure OpenAI"
        elif has_provider:
            # Option 2: Custom Provider (Cerebras, etc.)
            logger.info("Initializing custom provider OpenAI client")
            client = OpenAI(
                base_url=env.get("PROVIDER_API_ENDPOINT"),
                api_key=env.get("PROVIDER_API_KEY")
            )
            ready_label = "Custom provider"
        else:
            # Option 3: Standard OpenAI
            logger.info("Initializing standard OpenAI client")
            client = OpenAI(
                api_key=env.get("OPENAI_API_KEY")
            )
            ready_label = "OpenAI"

        model_name = env.get("MODEL_NAME", "gpt-4")  # Default fallback
        logger.info(f"{ready_label} client initialized with model: {model_name}")

        if client:
            logger.info("Testing client connection...")
            try:
                # Minimal request, only to confirm the credentials work.
                client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": "Hello"}],
                    max_tokens=5
                )
                logger.info("Client connection test successful")
            except Exception as test_error:
                # Deliberately not re-raised: the caller decides what to do.
                logger.error(f"Client connection test failed: {str(test_error)}")

        return client, model_name
    except Exception as init_error:
        logger.error(f"Failed to initialize OpenAI client: {str(init_error)}")
        logger.warning("Client will be initialized later when user provides valid API keys")
        return None, None
client, model_name = initialize_openai_client()

# A missing client is not fatal at import time: the user can still supply
# credentials later, and execute_jupyter_agent() then rebuilds the client.
if client is None:
    logger.info("No OpenAI client initialized - waiting for user to provide API keys through UI")

# Empty notebook used as the starting point for every session.
init_notebook = JupyterNotebook()

if not os.path.exists(TMP_DIR):
    os.makedirs(TMP_DIR)
    logger.info(f"Created temporary directory: {TMP_DIR}")
else:
    logger.info(f"Using existing temporary directory: {TMP_DIR}")

# Default notebook file handed to the UI before a session has its own notebook.
with open(TMP_DIR+"jupyter-agent.ipynb", 'w', encoding='utf-8') as f:
    json.dump(JupyterNotebook().data, f, indent=2)
logger.info(f"Initialized default notebook file: {TMP_DIR}jupyter-agent.ipynb")

try:
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        DEFAULT_SYSTEM_PROMPT = f.read()
    logger.info("Loaded system prompt from system_prompt.txt")
except FileNotFoundError:
    logger.warning("system_prompt.txt not found, using fallback system prompt")
    # The fallback must actually be assigned, otherwise the first new session
    # fails with a NameError. It carries the same placeholders that
    # execute_jupyter_agent() substitutes into the prompt.
    DEFAULT_SYSTEM_PROMPT = (
        "You are a data science assistant working in a Jupyter notebook. "
        "Solve the user's task step by step by writing and executing Python code.\n"
        "\n"
        "Hardware: {GPU_TYPE}, {CPU_CORES} CPU cores, {MEMORY_GB}GB RAM, "
        "execution timeout {TIMEOUT_SECONDS} seconds.\n"
        "\n"
        "Available files:\n"
        "{AVAILABLE_FILES}\n"
        "\n"
        "Available packages:\n"
        "{AVAILABLE_PACKAGES}\n"
    )
def _format_gpu_label(gpu_type):
    """Return the display label for a hardware choice, e.g. "NVIDIA T4" or "CPU Only"."""
    if gpu_type in ("T4", "L4", "A100-40GB", "A100-80GB", "H100"):
        return f"NVIDIA {gpu_type}"
    return "CPU Only" if gpu_type == "cpu" else gpu_type.upper()


def _collect_user_api_keys(modal_token_id, modal_token_secret, hf_token, provider_api_key,
                           provider_api_endpoint, user_model_name, tavily_api_key):
    """Map the non-empty credential fields from the UI to their environment variable names."""
    candidates = (
        ("MODAL_TOKEN_ID", modal_token_id),
        ("MODAL_TOKEN_SECRET", modal_token_secret),
        ("HF_TOKEN", hf_token),
        ("PROVIDER_API_KEY", provider_api_key),
        ("PROVIDER_API_ENDPOINT", provider_api_endpoint),
        ("MODEL_NAME", user_model_name),
        ("TAVILY_API_KEY", tavily_api_key),
    )
    return {name: value for name, value in candidates if value}


def _upload_files_to_sandbox(sbx, files, session_id):
    """Upload local files into the sandbox, verify each one, and return the uploaded file names."""
    filenames = []
    if files is None:
        logger.info(f"No files to upload for session {session_id}")
        return filenames

    logger.info(f"Processing {len(files)} uploaded files for session {session_id}")
    for filepath in files:
        local_path = Path(filepath)
        try:
            # The size is used to verify the upload afterwards.
            file_size = os.path.getsize(filepath)
            with open(filepath, "rb") as file:
                logger.info(f"Uploading file {filepath} ({file_size} bytes) to session {session_id}")
                sbx.files.write(local_path.name, file)

            if sbx.files.verify_file_upload(local_path.name, file_size):
                filenames.append(local_path.name)
                logger.debug(f"Successfully uploaded and verified {local_path.name}")
            else:
                logger.error(f"File upload verification failed for {local_path.name}")
                raise RuntimeError(f"File upload verification failed for {local_path.name}")
        except Exception as e:
            logger.error(f"Failed to upload file {filepath} for session {session_id}: {str(e)}")
            raise
    return filenames


def _build_system_prompt(files_section, gpu_info, cpu_cores, memory_gb, timeout_sec, packages_list):
    """Fill the named placeholders of DEFAULT_SYSTEM_PROMPT with the session's configuration."""
    packages_section = "\n".join([f"- {package}" for package in packages_list])
    replacements = (
        ("{AVAILABLE_FILES}", files_section),
        ("{GPU_TYPE}", gpu_info),
        ("{CPU_CORES}", str(cpu_cores)),
        ("{MEMORY_GB}", str(memory_gb)),
        ("{TIMEOUT_SECONDS}", str(timeout_sec)),
        ("{AVAILABLE_PACKAGES}", packages_section),
    )
    system_prompt = DEFAULT_SYSTEM_PROMPT
    for placeholder, replacement in replacements:
        system_prompt = system_prompt.replace(placeholder, replacement)
    return system_prompt


def execute_jupyter_agent(
    user_input, files, message_history, gpu_type, cpu_cores, memory_gb, timeout_sec, env_vars_text,
    modal_token_id, modal_token_secret, hf_token, provider_api_key, provider_api_endpoint, user_model_name,
    tavily_api_key, enable_web_search, request: gr.Request
):
    """Run one agent turn for the calling Gradio session, streaming UI updates.

    The generator validates and applies user-supplied credentials, creates or
    reuses the session's Modal sandbox, uploads files, creates or continues the
    persisted session state, and then streams the agent's notebook output.

    Args:
        user_input: The user's message for this turn (may be empty).
        files: Paths of files to upload into the sandbox, or None.
        message_history: Conversation history currently held by the UI.
        gpu_type, cpu_cores, memory_gb, timeout_sec: Sandbox hardware settings.
        env_vars_text: Extra sandbox environment variables, ``KEY=value`` per line.
        modal_token_id, modal_token_secret, hf_token, provider_api_key,
        provider_api_endpoint, user_model_name, tavily_api_key: Optional
            credentials entered in the UI; non-empty values are validated and
            exported to the process environment.
        enable_web_search: Offer the search tool in addition to code execution
            (only effective when a Tavily API key is available).
        request: Gradio request; its ``session_hash`` identifies the session.

    Yields:
        tuple: ``(html, message_history, notebook_path)`` where ``html`` is
        either a notification banner or the rendered notebook.

    Raises:
        Exception: Sandbox creation, file upload, agent execution and final
            save errors are logged and re-raised.
    """
    global client, model_name

    session_id = request.session_hash
    default_notebook_path = TMP_DIR + "jupyter-agent.ipynb"
    session_dir = os.path.join(TMP_DIR, session_id)
    save_dir = os.path.join(session_dir, 'jupyter-agent.ipynb')

    logger.info(f"Starting execution for session {session_id}")
    logger.info(f"Hardware config: GPU={gpu_type}, CPU={cpu_cores}, Memory={memory_gb}GB, Timeout={timeout_sec}s")
    logger.info(f"User input length: {len(user_input)} characters")

    # Refuse to start a second run while one is active for this session.
    if EXECUTION_STATES.get(session_id, {}).get("running", False):
        error_message = "❌ Execution already in progress for this session. Please wait for it to complete or stop it first."
        error_notification = create_notification_html(error_message, "warning")
        current_path = save_dir if os.path.exists(save_dir) else default_notebook_path
        yield error_notification, message_history, current_path
        return

    session_manager = SessionStateManager(session_id, TMP_DIR)

    # A stored state means the user is continuing an earlier conversation.
    existing_session_state = session_manager.load_state()
    is_continuing_session = existing_session_state is not None
    if is_continuing_session:
        logger.info(f"Found existing session state for {session_id} - continuing from previous state")
    else:
        logger.info(f"No existing session state found for {session_id} - starting new session")

    user_api_keys = _collect_user_api_keys(
        modal_token_id, modal_token_secret, hf_token, provider_api_key,
        provider_api_endpoint, user_model_name, tavily_api_key
    )

    if client is None and not user_api_keys:
        missing_keys = get_missing_api_keys()
        if missing_keys:
            missing_lines = "\n".join([f"• {key}: {config['description']}" for key, config in missing_keys.items()])
            error_message = f"""❌ Missing Required API Keys
Please provide the following API keys to continue:
{missing_lines}
You can either:
1. Add them to your .env file, or
2. Enter them in the API Keys section above"""
            error_notification = create_notification_html(error_message, "error")
            yield error_notification, message_history, default_notebook_path
            return

    if user_api_keys:
        validation_message = "🔍 Validating API keys..."
        validation_notification = create_progress_notification(validation_message)
        yield validation_notification, message_history, default_notebook_path

        validation_errors = []
        for key, value in user_api_keys.items():
            is_valid, message = validate_api_key_format(key, value)
            if not is_valid:
                validation_errors.append(f"{key}: {message}")

        if validation_errors:
            error_message = "❌ API Key Validation Failed:\n" + "\n".join(f"• {error}" for error in validation_errors)
            error_notification = create_notification_html(error_message, "error")
            yield error_notification, message_history, default_notebook_path
            return

        # NOTE(review): apply_user_api_keys() writes into os.environ, which is
        # shared by every session served by this process.
        logger.info(f"Applying user-provided API keys: {list(user_api_keys.keys())}")
        apply_user_api_keys(user_api_keys)

        # Rebuild the client when any setting it depends on was changed.
        if any(key in user_api_keys for key in ["PROVIDER_API_KEY", "PROVIDER_API_ENDPOINT", "MODEL_NAME"]):
            try:
                reinit_message = "🔄 Reinitializing AI client with new credentials..."
                reinit_notification = create_progress_notification(reinit_message)
                yield reinit_notification, message_history, default_notebook_path

                client, model_name = initialize_openai_client()
                if client is None:
                    error_message = "Failed to initialize client with provided API keys. Please check your credentials."
                    logger.error(error_message)
                    error_notification = create_notification_html(error_message, "error")
                    yield error_notification, message_history, default_notebook_path
                    return

                logger.info("Reinitialized OpenAI client with user-provided keys")
                success_message = "✅ API credentials validated and applied successfully!"
                success_notification = create_notification_html(success_message, "success")
                yield success_notification, message_history, default_notebook_path
            except Exception as e:
                error_message = f"Failed to initialize client with provided API keys: {str(e)}"
                logger.error(error_message)
                error_notification = create_notification_html(error_message, "error")
                yield error_notification, message_history, default_notebook_path
                return

    # Without a client the agent loop cannot run. Retry once (the import-time
    # attempt may have failed), and otherwise stop before a sandbox is created.
    if client is None:
        client, model_name = initialize_openai_client()
        if client is None:
            error_message = "AI client is not initialized. Please provide valid provider API credentials in the API Keys section."
            logger.error(error_message)
            error_notification = create_notification_html(error_message, "error")
            yield error_notification, message_history, default_notebook_path
            return

    # Initialize or reset stop event for this session
    STOP_EVENTS[session_id] = threading.Event()
    EXECUTION_STATES[session_id] = {"running": True, "paused": False, "current_phase": "initializing"}

    try:
        # Create the session's notebook file early so the UI has a file to point at.
        os.makedirs(session_dir, exist_ok=True)
        with open(save_dir, 'w', encoding='utf-8') as f:
            json.dump(init_notebook.data, f, indent=2)
        logger.info(f"Initialized notebook for session {session_id}")

        gpu_info = _format_gpu_label(gpu_type)

        if session_id not in SANDBOXES:
            logger.info(f"Creating new Modal sandbox for session {session_id}")

            init_message = f"Initializing {gpu_info} sandbox with {cpu_cores} CPU cores and {memory_gb}GB RAM..."
            notification_html = create_progress_notification(init_message)
            yield notification_html, message_history, save_dir

            # Read the Modal credentials at call time so tokens entered in the
            # UI (exported to os.environ above) are used, not only the values
            # that were present when the module was imported.
            environment_vars = {}
            current_modal_token_id = os.environ.get("MODAL_TOKEN_ID")
            current_modal_token_secret = os.environ.get("MODAL_TOKEN_SECRET")
            if current_modal_token_id and current_modal_token_secret:
                environment_vars.update({
                    "MODAL_TOKEN_ID": current_modal_token_id,
                    "MODAL_TOKEN_SECRET": current_modal_token_secret
                })
                logger.debug(f"Modal credentials configured for session {session_id}")

            # User-provided variables are applied last and may override the above.
            user_env_vars = parse_environment_variables(env_vars_text)
            if user_env_vars:
                environment_vars.update(user_env_vars)
                logger.info(f"Added {len(user_env_vars)} custom environment variables for session {session_id}")
                logger.debug(f"Custom environment variables: {list(user_env_vars.keys())}")

            try:
                SANDBOXES[session_id] = create_modal_sandbox(
                    gpu_config=gpu_type,
                    cpu_cores=cpu_cores,
                    memory_gb=memory_gb,
                    timeout=int(timeout_sec),
                    environment_vars=environment_vars
                )
                logger.info(f"Successfully created Modal sandbox for session {session_id}")

                success_message = f"✨ {gpu_info} sandbox ready! Environment initialized with all packages."
                success_notification_html = create_notification_html(success_message, "success")
                yield success_notification_html, message_history, save_dir
            except Exception as e:
                logger.error(f"Failed to create Modal sandbox for session {session_id}: {str(e)}")
                error_message = f"Failed to initialize sandbox: {str(e)}"
                error_notification_html = create_notification_html(error_message, "error")
                yield error_notification_html, message_history, save_dir
                raise
        else:
            logger.info(f"Reusing existing Modal sandbox for session {session_id}")
            reuse_message = f"Using existing {gpu_info} sandbox - ready to execute!"
            reuse_notification_html = create_notification_html(reuse_message, "success")
            yield reuse_notification_html, message_history, save_dir

        sbx = SANDBOXES[session_id]
        logger.debug(f"Notebook will be saved to: {save_dir}")

        # Initial notebook render
        initial_html = init_notebook.render()
        yield initial_html, message_history, save_dir

        filenames = _upload_files_to_sandbox(sbx, files, session_id)

        if is_continuing_session:
            session_state = existing_session_state

            # Validate and repair conversation history to prevent API errors
            session_manager.validate_and_repair_conversation(session_state)
            message_history = session_manager.get_conversation_history(session_state)
            logger.info(f"Continuing session {session_id} with {len(message_history)} existing messages")

            if user_input and user_input.strip():
                # Skip the input when it is already the last user message, so
                # a repeated submission does not duplicate it.
                last_message = message_history[-1] if message_history else None
                already_present = bool(
                    last_message
                    and last_message.get("role") == "user"
                    and last_message.get("content") == user_input
                )
                if already_present:
                    logger.debug(f"User input already present in session {session_id}")
                else:
                    session_manager.add_message(session_state, "user", user_input)
                    message_history = session_manager.get_conversation_history(session_state)
                    logger.info(f"Added new user input to existing session {session_id}")

                    continue_message = "🔄 Continuing conversation with new input..."
                    continue_notification = create_progress_notification(continue_message)
                    yield continue_notification, message_history, save_dir
        else:
            logger.info(f"Initializing new session {session_id}")

            # An empty upload list is reported like "no files" rather than as
            # a dangling bullet.
            if filenames:
                files_section = "- " + "\n- ".join(filenames)
                logger.info(f"System prompt includes {len(filenames)} files: {filenames}")
            else:
                files_section = "- None"

            system_prompt = _build_system_prompt(
                files_section, gpu_info, cpu_cores, memory_gb, timeout_sec, sbx.available_packages
            )

            hardware_config = {
                "gpu_type": gpu_type,
                "cpu_cores": cpu_cores,
                "memory_gb": memory_gb,
                "timeout_sec": timeout_sec
            }
            api_config = {
                "model_name": model_name or user_model_name or "unknown",
                "provider_endpoint": os.environ.get("PROVIDER_API_ENDPOINT") or provider_api_endpoint,
                "provider_type": "openai_compatible"
            }
            environment_config = {
                "variables": env_vars_text or "",
                "files_uploaded": filenames if filenames else []
            }

            session_state = session_manager.create_initial_state(
                hardware_config, api_config, environment_config, system_prompt
            )

            if user_input and user_input.strip():
                session_manager.add_message(session_state, "user", user_input)

            message_history = session_manager.get_conversation_history(session_state)
            session_manager.save_state(session_state)
            logger.info(f"Created new session {session_id} with {len(message_history)} messages")

        logger.debug(f"Session {session_id} ready with {len(message_history)} messages")

        # Determine which tools to use based on web search toggle
        from jupyter_agent import TOOLS
        if enable_web_search:
            tavily_key = os.environ.get("TAVILY_API_KEY") or tavily_api_key
            if tavily_key:
                selected_tools = TOOLS  # Use all tools (code + search)
                logger.info(f"Web search enabled for session {session_id} - using all tools")
            else:
                selected_tools = TOOLS[:1]  # Use only code execution tool
                logger.warning(f"Web search enabled but no Tavily API key found for session {session_id} - using code tool only")
        else:
            selected_tools = TOOLS[:1]  # Use only code execution tool
            logger.info(f"Web search disabled for session {session_id} - using code tool only")

        logger.info(f"Starting interactive notebook execution for session {session_id}")

        # Phoenix session grouping is optional; fall back to a no-op context.
        try:
            from jupyter_agent import create_phoenix_session_context
            phoenix_available = True
        except ImportError:
            phoenix_available = False

        if phoenix_available:
            session_level_metadata = {
                "agent_type": "eureka-agent",
                "session_type": "jupyter_execution",
                "gpu_type": gpu_type,
                "cpu_cores": cpu_cores,
                "memory_gb": memory_gb,
                "timeout_sec": timeout_sec,
                "web_search_enabled": enable_web_search,
                "tools_available": len(selected_tools)
            }
            if model_name:
                session_level_metadata["model"] = model_name

            session_context = create_phoenix_session_context(
                session_id=session_id,
                user_id=None,  # Could add user identification if available
                metadata=session_level_metadata
            )
        else:
            from contextlib import nullcontext
            session_context = nullcontext()

        # Defaults for the final save/yield in case the agent loop produces
        # no update at all.
        notebook_html = initial_html
        notebook_data = init_notebook.data

        with session_context:
            logger.debug(f"Starting session-level Phoenix tracing for {session_id}")

            try:
                for notebook_html, notebook_data, messages in run_interactive_notebook_with_session_state(
                    client, model_name, session_manager, session_state, sbx, STOP_EVENTS[session_id], selected_tools
                ):
                    message_history = messages
                    logger.debug(f"Interactive notebook yield for session {session_id}")

                    session_manager.update_notebook_data(session_state, notebook_data)
                    session_manager.save_state(session_state)

                    # Keep the downloadable notebook file in sync with the UI.
                    with open(save_dir, 'w', encoding='utf-8') as f:
                        json.dump(notebook_data, f, indent=2)
                    yield notebook_html, message_history, save_dir
            except Exception as e:
                logger.error(f"Error during interactive notebook execution for session {session_id}: {str(e)}")
                session_manager.update_execution_state(session_state, is_running=False, last_execution_successful=False)
                session_manager.save_state(session_state)
                raise

            # Final save and cleanup
            try:
                session_manager.update_execution_state(session_state, is_running=False)
                session_manager.save_state(session_state)
                logger.info(f"Final session state saved for session {session_id}")

                with open(save_dir, 'w', encoding='utf-8') as f:
                    json.dump(notebook_data, f, indent=2)
            except Exception as e:
                logger.error(f"Failed to save final session state for session {session_id}: {str(e)}")
                raise

            yield notebook_html, message_history, save_dir

        logger.info(f"Completed execution for session {session_id}")
    finally:
        # Always release the "running" flag -- also when an error is raised or
        # the generator is closed early -- so the session is not left locked
        # out by the "already in progress" check above.
        if session_id in EXECUTION_STATES:
            EXECUTION_STATES[session_id]["running"] = False
def clear(msg_state, request: gr.Request):
    """Reset the notebook view for this session without discarding session data.

    Signals any running execution to stop, marks the in-memory execution
    state as idle, and hands the UI a freshly rendered initial notebook
    together with an empty message list.

    Args:
        msg_state: Current message list from the UI; it is replaced by an
            empty list rather than read.
        request: Gradio request; its ``session_hash`` identifies the session.

    Returns:
        Tuple of (initial notebook HTML, empty message list).
    """
    sid = request.session_hash
    logger.info(f"Clearing notebook for session {sid}")

    # Ask a running agent loop (if any) to halt.
    if sid in STOP_EVENTS:
        stop_signal = STOP_EVENTS[sid]
        stop_signal.set()

    # Flip the tracked state back to idle; the persisted session is untouched.
    if sid in EXECUTION_STATES:
        tracked = EXECUTION_STATES[sid]
        tracked["running"] = False
        tracked["paused"] = False
        tracked["current_phase"] = "ready"

    fresh_messages = []
    logger.info(f"Reset notebook display for session {sid}")
    return init_notebook.render(), fresh_messages
def stop_execution(request: gr.Request):
    """Signal the running execution of this session to stop.

    Args:
        request: Gradio request; its ``session_hash`` identifies the session.

    Returns:
        A status string for the status textbox describing what happened.
    """
    sid = request.session_hash
    logger.info(f"Stopping execution for session {sid}")

    # Both registries must know the session, otherwise there is nothing to stop.
    if sid not in STOP_EVENTS or sid not in EXECUTION_STATES:
        logger.warning(f"No execution session found for {sid}")
        return "❌ No execution session found"

    tracked = EXECUTION_STATES[sid]
    if not tracked.get("running", False):
        logger.info(f"No active execution to stop for session {sid}")
        return "⚪ No active execution to stop"

    STOP_EVENTS[sid].set()
    logger.info(f"Stop signal sent for session {sid}")

    # Mirror the stop in the in-memory execution state ...
    tracked["running"] = False
    tracked["paused"] = True
    tracked["current_phase"] = "stopping"

    # ... and in the persisted session state, when one can be loaded.
    manager = SessionStateManager(sid, TMP_DIR)
    persisted = manager.load_state()
    if persisted:
        manager.update_execution_state(
            persisted, is_running=False, is_paused=True, current_phase="stopping"
        )
        manager.save_state(persisted)

    return "⏸️ Execution stopped - click Run to resume with new input"
def shutdown_sandbox(request: gr.Request):
    """Shut down the session's sandbox while preserving all session data and files.

    Steps: set the session's stop event, kill and unregister the sandbox,
    then log what session data and files remain on disk. Nothing is deleted,
    and the execution state / stop event entries are kept.

    Args:
        request: Gradio request; its ``session_hash`` identifies the session.

    Returns:
        A single ``gr.Button`` update for the shutdown button. The click
        handler wires exactly one output, so both the success and the error
        path return one value. On error the button stays visible only if a
        sandbox is still registered for the session.
    """
    session_id = request.session_hash
    logger.info(f"Shutting down sandbox for {session_id} (preserving all session data and files)")
    try:
        # 1. Stop any running execution first
        if session_id in STOP_EVENTS:
            STOP_EVENTS[session_id].set()
            logger.info(f"Stopped execution for session {session_id}")

        # 2. Shutdown Modal sandbox only
        if session_id in SANDBOXES:
            logger.info(f"Killing Modal sandbox for session {session_id}")
            SANDBOXES[session_id].kill()
            SANDBOXES.pop(session_id)
            logger.info(f"Successfully shutdown sandbox for session {session_id}")

        # 3. Log what's being preserved (but don't remove anything)
        session_manager = SessionStateManager(session_id, TMP_DIR)
        if session_manager.session_exists():
            logger.info(f"Preserving session data for {session_id}")
            # Load session state to show what's being preserved
            session_state = session_manager.load_state()
            if session_state:
                stats = session_state.get("session_stats", {})
                llm_interactions = len(session_state.get("llm_interactions", []))
                tool_executions = len(session_state.get("tool_executions", []))
                logger.info(f"Preserving session {session_id}: "
                            f"{stats.get('total_messages', 0)} messages, "
                            f"{llm_interactions} LLM interactions, "
                            f"{tool_executions} tool executions, "
                            f"{stats.get('total_code_executions', 0)} code runs")

            # Log all preserved files
            if session_manager.session_dir.exists():
                try:
                    preserved_files = [
                        file_path.name
                        for file_path in session_manager.session_dir.iterdir()
                        if file_path.is_file()
                    ]
                    if preserved_files:
                        logger.info(f"Preserving {len(preserved_files)} files in {session_id}: {preserved_files}")
                    else:
                        logger.info(f"No files found in session {session_id}")
                except OSError as e:
                    logger.warning(f"Could not check session directory {session_id}: {e}")

        # 4. Keep execution tracking data (don't clear anything)
        logger.info(f"Preserving execution state and stop events for {session_id}")
        logger.info(f"Sandbox shutdown completed for session {session_id} (all data preserved)")
        return gr.Button(visible=False)
    except Exception as e:
        logger.error(f"Error during shutdown for session {session_id}: {str(e)}")
        # Return ONE value to match the single wired output. Keep the button
        # available for a retry only while the sandbox is still registered.
        return gr.Button(visible=session_id in SANDBOXES)
# continue_execution function removed - functionality integrated into execute_jupyter_agent | |
def get_execution_status(request: gr.Request):
    """Return the status line for the current session's execution.

    The result is derived from the in-memory ``EXECUTION_STATES`` entry and,
    while running, from the session's stop event and ``current_phase``.

    Args:
        request: Gradio request; its ``session_hash`` identifies the session.

    Returns:
        One of the status strings: ready, stopping, a phase-specific running
        message, or paused.
    """
    session_id = request.session_hash
    state = EXECUTION_STATES.get(session_id)
    if state is None:
        return "⚪ Ready"

    # Use .get(): load_previous_notebook may create a state dict that holds
    # only "loaded_notebook", so "running" is not guaranteed to exist.
    if state.get("running", False):
        stop_event = STOP_EVENTS.get(session_id)
        if stop_event is not None and stop_event.is_set():
            return "⏸️ Stopping..."
        # More detailed phase information, falling back to a generic label
        phase_labels = {
            "generating": "🟢 Generating response...",
            "executing_code": "🟢 Executing code...",
            "searching": "🟢 Searching web...",
        }
        return phase_labels.get(state.get("current_phase", "running"), "🟢 Running")

    if state.get("paused", False):
        return "⏸️ Paused - Click Run to continue"
    return "⚪ Ready"
def is_sandbox_active(request: gr.Request):
    """Tell whether a sandbox is registered in ``SANDBOXES`` for this session."""
    return request.session_hash in SANDBOXES
def get_sandbox_status_and_visibility(request: gr.Request):
    """Return the sandbox status label and a matching button visibility update.

    Returns:
        Tuple of (status string, ``gr.Button`` update); the button is visible
        exactly when a sandbox is registered for the session.
    """
    has_sandbox = request.session_hash in SANDBOXES
    label = "🟢 Sandbox active" if has_sandbox else "⚪ No sandbox active"
    return label, gr.Button(visible=has_sandbox)
def update_sandbox_button_visibility(request: gr.Request):
    """Return a button update that is visible only while the session has a sandbox."""
    has_sandbox = request.session_hash in SANDBOXES
    return gr.Button(visible=has_sandbox)
def reset_ui_after_shutdown(request: gr.Request):
    """Produce reset values for the UI components after a shutdown.

    The notebook display and message state are always reset. The two status
    strings and the button update are reset to their idle values only when
    the session is absent from all three registries; otherwise they reflect
    the current state.

    Returns:
        5-tuple of (notebook HTML, message list, execution status,
        sandbox status, shutdown button update).
    """
    sid = request.session_hash
    registries = (SANDBOXES, EXECUTION_STATES, STOP_EVENTS)

    # Fully cleared means no registry still tracks this session.
    if not any(sid in registry for registry in registries):
        return (
            init_notebook.render(),
            [],
            "⚪ Ready",
            "⚪ No sandbox active",
            gr.Button(visible=False),
        )

    # Something is still tracked: report live status values instead.
    live_status = get_execution_status(request)
    live_sandbox_status, live_button = get_sandbox_status_and_visibility(request)
    return (
        init_notebook.render(),
        [],
        live_status,
        live_sandbox_status,
        live_button,
    )
def _extract_message_body(content, header):
    """Return the text after the first ``header`` in ``content``, stripped of HTML tags and rules."""
    # maxsplit=1: only the first occurrence is the role header; later
    # occurrences of the same word belong to the message text itself.
    body = content.split(header, 1)[1]
    body = re.sub(r'<[^>]+>', '', body)
    body = re.sub(r'-{3,}', '', body)
    return body.strip()


def reconstruct_message_history_from_notebook(notebook_data):
    """Reconstruct a chat message history from a notebook's markdown cells.

    Markdown cells are classified by the marker words they contain:

    * ``System`` together with ``IMPORTANT EXECUTION GUIDELINES`` -> system
      prompt (HTML tags removed, runs of newlines collapsed). If several cells
      match, the last one wins.
    * ``User`` without ``Assistant``/``System`` -> user message.
    * ``Assistant`` -> assistant message.

    Non-markdown cells are ignored, and messages that are empty after cleanup
    are dropped.

    Args:
        notebook_data: Notebook dict; ``cells`` is read from it, and each
            cell's ``source`` may be a string or a list of strings.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts, with the system
        prompt (if any) first, followed by user/assistant messages in cell order.
    """
    system_prompt = None
    conversation = []

    for cell in notebook_data.get('cells', []):
        if cell.get('cell_type', '') != 'markdown':
            continue

        content = cell.get('source', '')
        if isinstance(content, list):
            content = ''.join(content)

        if 'System' in content and 'IMPORTANT EXECUTION GUIDELINES' in content:
            # System cell: keep the whole text, minus markup
            cleaned = re.sub(r'<[^>]+>', '', content)
            system_prompt = re.sub(r'\n+', '\n', cleaned).strip()
        elif 'User' in content and not any(word in content for word in ['Assistant', 'System']):
            user_content = _extract_message_body(content, 'User')
            if user_content:
                conversation.append({"role": "user", "content": user_content})
        elif 'Assistant' in content:
            assistant_content = _extract_message_body(content, 'Assistant')
            if assistant_content:
                conversation.append({"role": "assistant", "content": assistant_content})

    message_history = []
    if system_prompt:
        message_history.append({"role": "system", "content": system_prompt})
    message_history.extend(conversation)
    return message_history
def load_previous_notebook(notebook_choice, request: gr.Request):
    """Load a previous notebook with its session configuration (dev only).

    The session ID is taken from the first space-separated token of
    ``notebook_choice``. The notebook JSON is read from
    ``TMP_DIR/<session_id>/jupyter-agent.ipynb``, the saved session state is
    loaded for hardware/environment/model settings, and the notebook, its
    reconstructed message history, and the session config are stored under
    ``EXECUTION_STATES[<current session>]["loaded_notebook"]``.

    Args:
        notebook_choice: Dropdown value; empty or ``"None"`` means no selection.
        request: Gradio request; its ``session_hash`` identifies the current session.

    Returns:
        A 16-tuple matching the outputs of the load button: notebook HTML,
        message history, status message, gpu_type, cpu_cores, memory_gb,
        timeout_sec, env_vars, six API-key/model fields, tavily_api_key and
        enable_web_search. Failures return the initial notebook, an empty
        history, and a message, with the remaining fields set to
        ``None``/``""``/``False``.
    """
    def _failure(message):
        # Single source of truth for the 16-value "nothing loaded" result.
        return (init_notebook.render(), [], message,
                None, None, None, None, None, "", "", "", "", "", "", "", False)

    if not is_dev_environment():
        return _failure("Load previous notebooks is only available in development mode")
    if not notebook_choice or notebook_choice == "None":
        return _failure("Please select a notebook to load")

    try:
        # Parse the notebook choice to get the session ID
        session_id = notebook_choice.split(" ")[0]
        notebook_path = Path(TMP_DIR) / session_id / "jupyter-agent.ipynb"
        if not notebook_path.exists():
            return _failure(f"Notebook file not found: {notebook_path}")

        # Notebooks are written with encoding='utf-8'; read them the same way
        # so the result does not depend on the platform's default encoding.
        with open(notebook_path, 'r', encoding='utf-8') as f:
            notebook_data = json.load(f)

        # Load session state and extract config for UI restoration
        temp_session_manager = SessionStateManager(session_id, TMP_DIR)
        session_state = temp_session_manager.load_state()
        session_config = None  # For backward compatibility
        if session_state:
            session_config = {
                "hardware": session_state.get("hardware_config", {}),
                "environment_vars": session_state.get("environment", {}).get("variables", ""),
                "api_keys": {
                    "model_name": session_state.get("api_config", {}).get("model_name", "")
                }
            }

        # Create a new JupyterNotebook instance with the loaded data
        loaded_notebook = JupyterNotebook()
        loaded_notebook.data = notebook_data

        # Reconstruct message history from notebook cells
        message_history = reconstruct_message_history_from_notebook(notebook_data)

        # Store the loaded notebook info in session for continue functionality
        session_id_hash = request.session_hash
        EXECUTION_STATES.setdefault(session_id_hash, {})["loaded_notebook"] = {
            "notebook_data": notebook_data,
            "message_history": message_history,
            "original_session": session_id,
            "session_config": session_config
        }
        logger.info(f"Successfully loaded notebook from {notebook_path}")
        logger.info(f"Reconstructed message history with {len(message_history)} messages")

        # Prepare configuration values to restore UI state
        config_loaded = ""
        gpu_type = None
        cpu_cores = None
        memory_gb = None
        timeout_sec = None
        env_vars = ""
        modal_token_id = ""
        modal_token_secret = ""
        hf_token = ""
        provider_api_key = ""
        provider_api_endpoint = ""
        model_name = ""
        if session_config:
            hardware = session_config.get("hardware", {})
            gpu_type = hardware.get("gpu_type")
            cpu_cores = hardware.get("cpu_cores")
            memory_gb = hardware.get("memory_gb")
            timeout_sec = hardware.get("timeout_sec")
            env_vars = session_config.get("environment_vars", "")
            # session_config["api_keys"] built above only carries model_name,
            # so the credential lookups fall back to "".
            api_keys = session_config.get("api_keys", {})
            modal_token_id = api_keys.get("modal_token_id", "")
            modal_token_secret = api_keys.get("modal_token_secret", "")
            hf_token = api_keys.get("hf_token", "")
            provider_api_key = api_keys.get("provider_api_key", "")
            provider_api_endpoint = api_keys.get("provider_api_endpoint", "")
            model_name = api_keys.get("model_name", "")
            config_loaded = f"✅ Configuration restored: GPU={gpu_type}, CPU={cpu_cores}, Memory={memory_gb}GB, Timeout={timeout_sec}s"

        success_message = f"✅ Loaded notebook: {session_id} ({len(notebook_data.get('cells', []))} cells, {len(message_history)} messages)"
        if config_loaded:
            success_message += f"\n{config_loaded}"

        return (loaded_notebook.render(), message_history, success_message,
                gpu_type, cpu_cores, memory_gb, timeout_sec, env_vars,
                modal_token_id, modal_token_secret, hf_token, provider_api_key, provider_api_endpoint, model_name,
                "", False)  # Default empty tavily_api_key and False for enable_web_search
    except Exception as e:
        logger.error(f"Failed to load notebook {notebook_choice}: {str(e)}")
        return _failure(f"❌ Failed to load notebook: {str(e)}")
def get_notebook_options():
    """Build the choice list for the previous-notebook dropdown (dev only).

    Returns:
        A one-item placeholder list outside dev mode or when no notebooks
        exist; otherwise ``"None"`` followed by the display names of at most
        the first 20 entries from ``get_previous_notebooks()``.
    """
    if not is_dev_environment():
        return ["Load previous notebooks is only available in development mode"]

    previous = get_previous_notebooks()
    if not previous:
        return ["No previous notebooks found"]

    # Cap the list at 20 entries
    display_names = [entry['display_name'] for entry in previous[:20]]
    return ["None", *display_names]
def refresh_notebook_options():
    """Return a dropdown update with freshly computed choices and the value reset to "None"."""
    current_choices = get_notebook_options()
    return gr.Dropdown(choices=current_choices, value="None")
# Legacy session configuration functions removed - replaced by SessionStateManager | |
# All session data is now stored in a single comprehensive session_state.json file | |
# Custom stylesheet for the app: full-viewport-height containers, a
# disabled look for the Run button while executing (`.button-executing`),
# and a pulsing animation class for a running status (`.status-running`).
# NOTE(review): `css` is not passed to the `gr.Blocks()` or `demo.launch()`
# calls in this file's visible code — confirm whether it should be applied.
css = """
#component-0 {
height: 100vh;
overflow-y: auto;
padding: 20px;
}
.gradio-container {
height: 100vh !important;
}
.contain {
height: 100vh !important;
}
/* Button states for execution control */
.button-executing {
opacity: 0.6 !important;
pointer-events: none !important;
cursor: not-allowed !important;
}
.button-executing::after {
content: " ⏳";
}
.status-running {
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { opacity: 1; }
50% { opacity: 0.5; }
100% { opacity: 1; }
}
"""
# Create the interface
# Layout, event wiring and launch of the Gradio app. Components are created
# in display order; event handlers are attached after all components exist.
with gr.Blocks() as demo:
    # Per-session message history passed into and out of the agent
    msg_state = gr.State(value=[])
    # Environment info display
    env_info = gr.Markdown(f"""
**Environment**: {get_environment().upper()} | **Features**: {"Development features enabled" if is_dev_environment() else "Production mode"}
""")
    html_output = gr.HTML(value=JupyterNotebook().render())
    user_input = gr.Textbox(
        # value="train a 5 neuron neural network to classify the iris dataset",
        value="can you finetune llama 3.2 1b on tiny stories dataset and using unsloth",
        lines=3,
        label="Agent task"
    )
    with gr.Accordion("Upload files ⬆ | Download notebook⬇", open=False):
        files = gr.File(label="Upload files to use", file_count="multiple")
        file = gr.File(TMP_DIR+"jupyter-agent.ipynb", label="Download Jupyter Notebook")
    with gr.Row():
        # Web Search Configuration
        with gr.Accordion("🔍 Web Search Settings", open=False):
            with gr.Row():
                enable_web_search = gr.Checkbox(
                    label="Enable Web Search",
                    value=bool(os.environ.get("TAVILY_API_KEY")),  # Default to True if API key is available
                    info="Allow the agent to search the web for current information and documentation"
                )
                # Show web search status with better formatting
                # (computed once at build time from the environment)
                tavily_status = "✅ Available" if os.environ.get("TAVILY_API_KEY") else "❌ API Key Required"
                gr.Markdown(f"**Status:** {tavily_status}")
            gr.Markdown("""
**Web Search Features:**
- 🌐 Search for current tutorials, documentation, and best practices
- 🐛 Find solutions to error messages and debugging help
- 📚 Access up-to-date library documentation and examples
- 💡 Get recent examples and code snippets from the web
⚠️ **Note**: Web search requires a Tavily API key. Get one free at [tavily.com](https://tavily.com)
""")
        # Previous notebooks section (dev only)
        if is_dev_environment():
            with gr.Accordion("📂 Load Previous Notebook (Dev Only)", open=False):
                notebook_dropdown = gr.Dropdown(
                    choices=get_notebook_options(),
                    value="None",
                    label="Select Previous Notebook",
                    info="Load a previously created notebook session"
                )
                with gr.Row():
                    load_notebook_btn = gr.Button("📖 Load Selected", variant="secondary")
                    refresh_notebooks_btn = gr.Button("🔄 Refresh List", variant="secondary")
                load_status = gr.Textbox(
                    label="Load Status",
                    interactive=False,
                    visible=False
                )
    # Check for missing API keys and show input fields conditionally
    missing_keys = get_missing_api_keys()
    # API Key Configuration (shown only if keys are missing)
    # Every key always gets a component (hidden when not needed) so the event
    # handlers below can reference all seven entries unconditionally.
    if missing_keys:
        with gr.Accordion("🔑 Required API Keys (Missing from .env)", open=True):
            gr.Markdown("""
**⚠️ Some required API keys are missing from your .env file.**
Please provide them below to use the application:
""")
            api_key_components = {}
            if "MODAL_TOKEN_ID" in missing_keys:
                api_key_components["modal_token_id"] = gr.Textbox(
                    label="Modal Token ID",
                    placeholder="ak-...",
                    info="Modal Token ID for sandbox access",
                    type="password"
                )
            else:
                api_key_components["modal_token_id"] = gr.Textbox(visible=False)
            if "MODAL_TOKEN_SECRET" in missing_keys:
                api_key_components["modal_token_secret"] = gr.Textbox(
                    label="Modal Token Secret",
                    placeholder="as-...",
                    info="Modal Token Secret for sandbox access",
                    type="password"
                )
            else:
                api_key_components["modal_token_secret"] = gr.Textbox(visible=False)
            if "HF_TOKEN" in missing_keys:
                api_key_components["hf_token"] = gr.Textbox(
                    label="Hugging Face Token (Optional)",
                    placeholder="hf_...",
                    info="Hugging Face Token for model access",
                    type="password"
                )
            else:
                api_key_components["hf_token"] = gr.Textbox(visible=False)
            if "PROVIDER_API_KEY" in missing_keys:
                api_key_components["provider_api_key"] = gr.Textbox(
                    label="AI Provider API Key",
                    placeholder="sk-, gsk_, or csk-...",
                    info="API Key for your AI provider (Anthropic, OpenAI, Cerebras, etc.)",
                    type="password"
                )
            else:
                api_key_components["provider_api_key"] = gr.Textbox(visible=False)
            if "PROVIDER_API_ENDPOINT" in missing_keys:
                api_key_components["provider_api_endpoint"] = gr.Textbox(
                    label="AI Provider API Endpoint",
                    placeholder="https://api.anthropic.com/v1/",
                    info="API endpoint for your AI provider"
                )
            else:
                api_key_components["provider_api_endpoint"] = gr.Textbox(visible=False)
            if "MODEL_NAME" in missing_keys:
                api_key_components["model_name"] = gr.Textbox(
                    label="Model Name",
                    placeholder="claude-sonnet-4-20250514",
                    info="Name of the model to use"
                )
            else:
                api_key_components["model_name"] = gr.Textbox(visible=False)
            if "TAVILY_API_KEY" in missing_keys:
                api_key_components["tavily_api_key"] = gr.Textbox(
                    label="Tavily API Key (Optional)",
                    placeholder="tvly-...",
                    info="Tavily API Key for web search functionality",
                    type="password"
                )
            else:
                api_key_components["tavily_api_key"] = gr.Textbox(visible=False)
    else:
        # Create hidden components when no keys are missing
        api_key_components = {
            "modal_token_id": gr.Textbox(visible=False),
            "modal_token_secret": gr.Textbox(visible=False),
            "hf_token": gr.Textbox(visible=False),
            "provider_api_key": gr.Textbox(visible=False),
            "provider_api_endpoint": gr.Textbox(visible=False),
            "model_name": gr.Textbox(visible=False),
            "tavily_api_key": gr.Textbox(visible=False)
        }
    with gr.Accordion("Hardware Configuration ⚙️", open=False):
        with gr.Row():
            with gr.Column():
                env_vars = gr.Textbox(
                    label="Environment Variables",
                    placeholder="Enter environment variables (one per line):\nAPI_KEY=your_key_here\nDATA_PATH=/path/to/data\nDEBUG=true",
                    lines=5,
                    info="Add custom environment variables for the sandbox. Format: KEY=value (one per line)"
                )
                # NOTE: this rebinds `env_info`, which already names the
                # environment banner created at the top of the layout.
                env_info = gr.Markdown("""
**Environment Variables Info:**
- Variables will be available in the sandbox environment
- Use KEY=value format, one per line
- Common examples: API keys, data paths, configuration flags
- Variables are session-specific and not persisted between sessions
⚠️ **Security**: Avoid sensitive credentials in shared environments
""")
            with gr.Column():
                with gr.Row():
                    gpu_type = gr.Dropdown(
                        choices=GPU_OPTIONS,
                        value="cpu",
                        label="GPU Type",
                        info="Select hardware acceleration"
                    )
                    cpu_cores = gr.Slider(
                        minimum=0.25,
                        maximum=16,
                        value=2.0,
                        step=0.25,
                        label="CPU Cores",
                        info="Number of CPU cores"
                    )
                with gr.Row():
                    memory_gb = gr.Slider(
                        minimum=0.5,
                        maximum=64,
                        value=8.0,
                        step=0.5,
                        label="Memory (GB)",
                        info="RAM allocation"
                    )
                    timeout_sec = gr.Slider(
                        minimum=60,
                        maximum=1800,
                        value=300,
                        step=60,
                        label="Timeout (seconds)",
                        info="Maximum execution time"
                    )
                hardware_info = gr.Markdown("""
**Hardware Options:**
- **CPU Only**: Free, good for basic tasks
- **T4**: Low-cost GPU, good for small models
- **L4**: Mid-range GPU, better performance
- **A100 40/80GB**: High-end GPU for large models
- **H100**: Latest flagship GPU for maximum performance
⚠️ **Note**: GPU instances cost more. Choose based on your workload.
""")
    # with gr.Accordion("Environment Variables 🔧", open=False):
    with gr.Row():
        generate_btn = gr.Button("Run!", variant="primary")
        stop_btn = gr.Button("⏸️ Stop", variant="secondary")
        # continue_btn removed - Run button handles continuation automatically
        clear_btn = gr.Button("Clear Notebook", variant="stop")
        # Hidden by default; the periodic button timer below toggles it
        shutdown_btn = gr.Button("🔴 Shutdown Sandbox", variant="stop", visible=False)
    # Status display
    status_display = gr.Textbox(
        value="⚪ Ready",
        label="Execution Status",
        interactive=False,
        max_lines=1
    )
    # Run: execute_jupyter_agent receives the task, uploads, history,
    # hardware settings, API-key fields and the web-search toggle, and
    # updates the notebook view, history and download file.
    generate_btn.click(
        fn=execute_jupyter_agent,
        inputs=[
            user_input, files, msg_state, gpu_type, cpu_cores, memory_gb, timeout_sec, env_vars,
            api_key_components["modal_token_id"], api_key_components["modal_token_secret"],
            api_key_components["hf_token"], api_key_components["provider_api_key"],
            api_key_components["provider_api_endpoint"], api_key_components["model_name"],
            api_key_components["tavily_api_key"], enable_web_search
        ],
        outputs=[html_output, msg_state, file],
        show_progress="hidden",
    )
    stop_btn.click(
        fn=stop_execution,
        outputs=[status_display],
        show_progress="hidden",
    )
    # continue_btn.click handler removed - Run button handles continuation automatically
    clear_btn.click(fn=clear, inputs=[msg_state], outputs=[html_output, msg_state])
    # Exactly one output is wired here, so shutdown_sandbox must return a
    # single button update.
    shutdown_btn.click(
        fn=shutdown_sandbox,
        outputs=[shutdown_btn],
        show_progress="hidden",
    )
    # Add event handlers for notebook loading (dev only)
    if is_dev_environment():
        # The 16 outputs below must stay in the same order as the tuple
        # returned by load_previous_notebook.
        load_notebook_btn.click(
            fn=load_previous_notebook,
            inputs=[notebook_dropdown],
            outputs=[
                html_output, msg_state, load_status,
                gpu_type, cpu_cores, memory_gb, timeout_sec, env_vars,
                api_key_components["modal_token_id"], api_key_components["modal_token_secret"],
                api_key_components["hf_token"], api_key_components["provider_api_key"],
                api_key_components["provider_api_endpoint"], api_key_components["model_name"],
                api_key_components["tavily_api_key"], enable_web_search
            ],
            show_progress="hidden"
        )
        refresh_notebooks_btn.click(
            fn=refresh_notebook_options,
            outputs=[notebook_dropdown],
            show_progress="hidden"
        )
        # Show/hide load status based on selection
        notebook_dropdown.change(
            fn=lambda choice: gr.Textbox(visible=choice != "None"),
            inputs=[notebook_dropdown],
            outputs=[load_status]
        )
    # Periodic status update using timer
    status_timer = gr.Timer(2.0)  # Update every 2 seconds
    status_timer.tick(
        fn=get_execution_status,
        outputs=[status_display],
        show_progress="hidden"
    )
    # Update button visibility periodically
    button_timer = gr.Timer(3.0)  # Check every 3 seconds
    button_timer.tick(
        fn=update_sandbox_button_visibility,
        outputs=[shutdown_btn],
        show_progress="hidden"
    )
    # Client-side script run on page load: removes the `dark` class from
    # elements, defines helpers that disable/enable the Run button based on
    # the status textbox value, and installs a MutationObserver that
    # re-evaluates the status after DOM changes.
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js=""" () => {
if (document.querySelectorAll('.dark').length) {
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
}
// Add execution state management functions
window.setExecutionState = function(isExecuting) {
// Find Run button by text content since variant attribute might not be reliable
const buttons = document.querySelectorAll('button');
let runButton = null;
let stopButton = null;
buttons.forEach(button => {
const text = button.textContent.trim().toLowerCase();
if (text.includes('run') && !text.includes('stop')) {
runButton = button;
} else if (text.includes('stop') || text.includes('⏸️')) {
stopButton = button;
}
});
if (runButton) {
if (isExecuting) {
runButton.classList.add('button-executing');
runButton.disabled = true;
runButton.style.opacity = '0.6';
runButton.style.cursor = 'not-allowed';
runButton.style.pointerEvents = 'none';
if (runButton.textContent.indexOf('⏳') === -1) {
runButton.textContent = runButton.textContent.replace('!', '! ⏳');
}
} else {
runButton.classList.remove('button-executing');
runButton.disabled = false;
runButton.style.opacity = '1';
runButton.style.cursor = 'pointer';
runButton.style.pointerEvents = 'auto';
runButton.textContent = runButton.textContent.replace(' ⏳', '');
}
}
// Also update stop button visibility/state
if (stopButton) {
stopButton.style.display = isExecuting ? 'block' : 'inline-block';
}
};
// Monitor for status changes and update button states
window.monitorExecutionStatus = function() {
// Try multiple ways to find the status element
let statusElement = document.querySelector('input[label*="Execution Status"], input[label*="Status"], textarea[label*="Status"]');
if (!statusElement) {
// Fallback: look for any input that might contain status
const allInputs = document.querySelectorAll('input, textarea');
allInputs.forEach(input => {
if (input.value && (input.value.includes('🟢') || input.value.includes('⚪') || input.value.includes('⏸️'))) {
statusElement = input;
}
});
}
if (statusElement) {
const status = statusElement.value || '';
const isRunning = status.includes('🟢') || status.includes('Running') || status.includes('Generating') || status.includes('Executing');
const isReady = status.includes('⚪') || status.includes('Ready');
window.setExecutionState(isRunning);
// Add visual indicator to status element
if (isRunning) {
statusElement.style.background = '#e3f2fd';
statusElement.style.borderColor = '#2196f3';
} else if (isReady) {
statusElement.style.background = '#f5f5f5';
statusElement.style.borderColor = '#ccc';
} else {
statusElement.style.background = '#fff3e0';
statusElement.style.borderColor = '#ff9800';
}
}
};
// Set up mutation observer to watch for status changes
const observer = new MutationObserver(function(mutations) {
mutations.forEach(function(mutation) {
if (mutation.type === 'childList' || mutation.type === 'attributes') {
setTimeout(window.monitorExecutionStatus, 100);
}
});
});
// Start observing
observer.observe(document.body, {
childList: true,
subtree: true,
attributes: true
});
}
"""
    )
logger.info("Starting Gradio application")
# Launch with server-side rendering disabled
demo.launch(ssr_mode=False)