# EurekaAgent / jupyter_agent.py
# Commit 744e5e2 (AdithyaSK): "Eureka agent init - Adithya S K"
from jupyter_handler import JupyterNotebook
import json
import logging
import os
import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
from tavily import TavilyClient
# Phoenix (OpenInference) session tracing is optional. PHOENIX_AVAILABLE
# records whether `using_session` could be imported.
try:
    from openinference.instrumentation import using_session
except ImportError:
    PHOENIX_AVAILABLE = False
    print("Phoenix session tracking not available - missing openinference packages")
else:
    PHOENIX_AVAILABLE = True
    print("Phoenix session tracking imports successful")

# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)

# Tavily web-search client. It stays None when TAVILY_API_KEY is unset or empty.
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if TAVILY_API_KEY:
    tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
else:
    tavily_client = None
def _function_tool(name, description, param_name, param_description):
    """Build one OpenAI-style function-tool schema with a single required string argument."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    param_name: {
                        "type": "string",
                        "description": param_description,
                    }
                },
                "required": [param_name],
            },
        },
    }


# Tool schemas offered to the model on every chat-completions call.
TOOLS = [
    _function_tool(
        "add_and_execute_jupyter_code_cell",
        "A Python code execution environment that runs code in a Jupyter notebook interface. This is stateful - variables and imports persist between executions.",
        "code",
        "The Python code to execute.",
    ),
    _function_tool(
        "edit_and_execute_current_cell",
        "Edit the current/last code cell and execute the new code. Use this to fix errors or modify the previous code instead of creating a new cell.",
        "code",
        "The updated Python code to replace the current cell with and execute.",
    ),
    _function_tool(
        "execute_shell_command",
        "Execute shell/system commands like ls, cat, mkdir, etc. This runs independently of Python and provides terminal-style output.",
        "command",
        "The shell command to execute (e.g., 'ls -la', 'cat file.txt', 'mkdir new_folder').",
    ),
    _function_tool(
        "web_search",
        "Search the web for current information, documentation, tutorials, and solutions to coding problems. Use this to get context before starting tasks or when encountering errors.",
        "query",
        "Search query (max 400 characters). Be specific and include relevant keywords.",
    ),
]

# Upper bound on agent turns per session (also stored in execution_state["max_turns"]).
MAX_TURNS = 20
def create_phoenix_session_context(session_id: str, user_id: str = None, metadata: Dict = None):
    """
    Return a context manager that groups the LLM calls made inside it under one Phoenix session.

    Args:
        session_id: Session identifier passed to ``using_session``.
        user_id: Accepted for API compatibility; not currently forwarded to Phoenix.
        metadata: Accepted for API compatibility; not currently forwarded to Phoenix.

    Returns:
        ``using_session(session_id)`` when the OpenInference import succeeded and
        the call does not raise; otherwise a no-op ``nullcontext()``, so callers can
        always write ``with create_phoenix_session_context(...):``.
    """
    from contextlib import nullcontext

    if PHOENIX_AVAILABLE:
        try:
            logger.debug(f"Creating Phoenix session context for session_id: {session_id}")
            return using_session(session_id)
        except Exception as e:
            # Tracing must never break the agent loop; fall through to the no-op context.
            logger.warning(f"Failed to create Phoenix session context for {session_id}: {e}")
    return nullcontext()
class SessionStateManager:
    """Manage the complete state of one agent session, persisted as a single JSON file.

    The file lives at ``<base_dir>/<session_id>/session_state.json``. The state
    dict is owned by the caller. The ``log_*``, ``add_message`` and ``update_*``
    methods mutate it in place, and nothing is written to disk until
    ``save_state`` runs. ``validate_and_repair_conversation`` calls
    ``save_state`` itself when it changes the history.
    """

    def __init__(self, session_id: str, base_dir: str = './temp/'):
        """Bind to ``<base_dir>/<session_id>/`` and create that directory if it is missing."""
        self.session_id = session_id
        self.base_dir = Path(base_dir)
        self.session_dir = self.base_dir / session_id
        self.state_file = self.session_dir / 'session_state.json'
        self.session_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"SessionStateManager initialized for {session_id}")

    def create_initial_state(self, hardware_config: Dict, api_config: Dict,
                             environment: Dict, system_prompt: str) -> Dict:
        """Build (without saving) a fresh state dict seeded with the system prompt.

        Args:
            hardware_config: Stored verbatim under ``hardware_config``.
            api_config: Stored verbatim under ``api_config``.
            environment: Stored verbatim under ``environment``.
            system_prompt: Content of the first ``conversation_history`` entry.

        Returns:
            The new state. ``created_at`` and ``last_updated`` share one UTC
            ISO-8601 timestamp, and ``execution_state.max_turns`` is ``MAX_TURNS``.
        """
        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
        initial_state = {
            "session_id": self.session_id,
            "created_at": timestamp,
            "last_updated": timestamp,
            "version": "1.0",
            "hardware_config": hardware_config,
            "api_config": api_config,
            "environment": environment,
            "conversation_history": [
                {
                    "role": "system",
                    "content": system_prompt,
                    "timestamp": timestamp,
                    "metadata": {"type": "system_initialization"},
                }
            ],
            "llm_interactions": [],  # one summary per chat-completions call
            "tool_executions": [],  # one record per tool call and its result
            "notebook_data": {
                "cells": [],
                "metadata": {
                    "kernel_info": {"name": "python3"},
                    "language_info": {"name": "python", "version": "3.12"},
                },
                "nbformat": 4,
                "nbformat_minor": 0,
            },
            "execution_state": {
                "current_turn": 0,
                "max_turns": MAX_TURNS,
                "is_running": False,
                "is_paused": False,
                "last_execution_successful": None,
                "sandbox_active": False,
                "sandbox_info": None,
            },
            "session_stats": {
                "total_messages": 1,
                "total_code_executions": 0,
                "total_searches": 0,
                "total_errors": 0,
                "session_duration_seconds": 0,
            },
        }
        logger.info("Created initial session state for %s", self.session_id)
        return initial_state

    def load_state(self) -> Optional[Dict]:
        """Load the saved state, or return None if it is missing or unreadable.

        If the file fails to parse as JSON, it is first copied (best effort) to
        ``session_state.json.corrupted`` so the data is not lost when a later
        save overwrites it.
        """
        if not self.state_file.exists():
            logger.info(f"No existing session state found for {self.session_id}")
            return None
        try:
            with open(self.state_file, 'r', encoding='utf-8') as f:
                state = json.load(f)
            logger.info(f"Loaded session state for {self.session_id} with {len(state.get('conversation_history', []))} messages")
            return state
        except json.JSONDecodeError as e:
            logger.error(f"JSON corruption in session state for {self.session_id}: {str(e)}")
            logger.info(f"Creating backup of corrupted file: {self.state_file}.corrupted")
            try:
                import shutil
                shutil.copy2(self.state_file, str(self.state_file) + ".corrupted")
                logger.info("Backup created successfully")
            except Exception as backup_error:
                logger.warning(f"Failed to create backup: {backup_error}")
            return None
        except Exception as e:
            logger.error(f"Failed to load session state for {self.session_id}: {str(e)}")
            return None

    def save_state(self, state: Dict) -> bool:
        """Write ``state`` to disk, replacing the previous file in one rename.

        First refreshes, in place, ``last_updated`` and the derived
        ``session_stats`` fields (``session_duration_seconds`` and
        ``total_messages``). The state is serialized once. If it holds values
        ``json`` cannot encode, a sanitized copy is written instead; the
        caller's dict is left as is.

        Returns:
            True on success; False if anything failed (the error is logged).
        """
        temp_file = self.state_file.with_suffix('.tmp')
        try:
            now = datetime.datetime.now(datetime.timezone.utc)
            state["last_updated"] = now.isoformat()
            stats = state.setdefault("session_stats", {})
            duration = self._session_duration_seconds(state, now)
            if duration is not None:
                stats["session_duration_seconds"] = duration
            stats["total_messages"] = len(state.get("conversation_history", []))

            # Serialize exactly once. The same text is validated, written and measured.
            try:
                payload = json.dumps(state, indent=2, ensure_ascii=False)
            except (TypeError, ValueError) as e:
                logger.error(f"State contains non-serializable data: {e}")
                logger.info("Attempting to clean non-serializable data...")
                payload = json.dumps(self._clean_non_serializable_data(state),
                                     indent=2, ensure_ascii=False)

            # Write to a sibling temp file, then rename it over the real file
            # (Path.replace / os.replace), so a crash mid-write cannot leave a
            # truncated state file.
            with open(temp_file, 'w', encoding='utf-8') as f:
                f.write(payload)
            temp_file.replace(self.state_file)
            logger.debug(f"Saved session state for {self.session_id} ({len(payload)} characters)")
            return True
        except Exception as e:
            logger.error(f"Failed to save session state for {self.session_id}: {str(e)}")
            if temp_file.exists():
                try:
                    temp_file.unlink()
                except OSError as cleanup_error:
                    logger.debug(f"Could not remove temp file {temp_file}: {cleanup_error}")
            return False

    @staticmethod
    def _session_duration_seconds(state: Dict, now: datetime.datetime) -> Optional[int]:
        """Whole seconds from ``state['created_at']`` to ``now``, or None if unknown.

        A naive ``created_at`` is treated as UTC (the timestamps this class writes
        are UTC). Subtracting a naive datetime from the aware ``now`` would
        otherwise raise and abort the save.
        """
        created_raw = state.get("created_at")
        if not isinstance(created_raw, str):
            return None
        try:
            created_at = datetime.datetime.fromisoformat(created_raw)
        except ValueError:
            return None
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=datetime.timezone.utc)
        return int((now - created_at).total_seconds())

    @staticmethod
    def _is_json_serializable(value) -> bool:
        """Return True if ``json.dumps`` can encode ``value``."""
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return False
        return True

    def _clean_non_serializable_data(self, obj, _active=None):
        """Return a JSON-serializable copy of ``obj``.

        Dicts, lists and tuples are rebuilt recursively. Only the individual leaf
        values ``json`` cannot encode are replaced by a
        ``"<non-serializable: TYPE>"`` placeholder, so their serializable
        siblings survive. Testing whole sub-trees instead would, for example,
        replace the entire conversation history because one message carries a
        raw execution object.

        Dict keys that ``json`` cannot encode are converted with ``str``. A
        container that contains itself is replaced by a placeholder instead of
        recursing forever. ``_active`` holds the ids of the containers on the
        current recursion path.
        """
        if isinstance(obj, (dict, list, tuple)):
            active = set() if _active is None else _active
            if id(obj) in active:
                return f"<non-serializable: {type(obj).__name__}>"
            active.add(id(obj))
            try:
                if isinstance(obj, dict):
                    cleaned = {}
                    for key, value in obj.items():
                        if key is not None and not isinstance(key, (str, int, float, bool)):
                            key = str(key)
                        if (not isinstance(value, (dict, list, tuple))
                                and not self._is_json_serializable(value)):
                            logger.warning(f"Removing non-serializable field: {key}")
                        cleaned[key] = self._clean_non_serializable_data(value, active)
                    return cleaned
                return [self._clean_non_serializable_data(item, active) for item in obj]
            finally:
                active.discard(id(obj))
        if self._is_json_serializable(obj):
            return obj
        return f"<non-serializable: {type(obj).__name__}>"

    def log_llm_interaction(self, state: Dict, request_data: Dict, response_data: Dict,
                            model: str, turn: int) -> None:
        """Append a summary of one chat-completions call to ``state['llm_interactions']``.

        Args:
            state: Session state, mutated in place.
            request_data: The kwargs sent to the API. Only counts of ``messages``
                and ``tools``, plus ``model`` and ``tool_choice``, are recorded.
            response_data: The response as a dict (e.g. ``response.model_dump()``).
                Only the first choice is recorded. Missing, empty or null
                ``choices``/``message`` yield None fields instead of raising.
            model: Model name to record.
            turn: Agent turn number to record.
        """
        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
        choices = response_data.get("choices") or [{}]
        first_choice = choices[0] or {}
        message = first_choice.get("message") or {}
        interaction = {
            "timestamp": timestamp,
            "turn": turn,
            "model": model,
            "request": {
                "messages_count": len(request_data.get("messages") or []),
                "tools_count": len(request_data.get("tools") or []),
                "model": request_data.get("model"),
                "tool_choice": request_data.get("tool_choice"),
            },
            "response": {
                "content": message.get("content"),
                "tool_calls": message.get("tool_calls"),
                "finish_reason": first_choice.get("finish_reason"),
                "usage": response_data.get("usage"),
            },
        }
        state.setdefault("llm_interactions", []).append(interaction)

        logger.debug(f"Logged LLM interaction for turn {turn} in session {self.session_id}")
        logger.debug(f"Phoenix session tracking: session_id={self.session_id}, turn={turn}, model={model}")

        usage = response_data.get("usage")
        if usage:
            logger.info(f"Session {self.session_id} turn {turn}: "
                        f"prompt_tokens={usage.get('prompt_tokens', 0)}, "
                        f"completion_tokens={usage.get('completion_tokens', 0)}, "
                        f"total_tokens={usage.get('total_tokens', 0)}")

    def log_tool_execution(self, state: Dict, tool_call_id: str, tool_name: str,
                           tool_args: Dict, result: str, execution_data: Any = None) -> None:
        """Append a tool-call record to ``state['tool_executions']`` and update stats.

        Args:
            state: Session state, mutated in place.
            tool_call_id: Id of the tool call being answered.
            tool_name: Tool that ran. ``add_and_execute_jupyter_code_cell`` and
                ``web_search`` increment their respective counters.
            tool_args: Parsed tool arguments, stored verbatim.
            result: Text returned to the model. The first 500 characters are kept
                (with ``...`` appended), along with the full length. None is
                treated as "".
            execution_data: Optional raw execution object. Only a small JSON-safe
                summary of it is stored.
        """
        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
        if result is None:
            result = ""

        # Keep only a JSON-safe summary of execution_data so the state file stays serializable.
        safe_execution_data = None
        if execution_data is not None:
            try:
                if hasattr(execution_data, '__dict__'):
                    safe_execution_data = {
                        "type": type(execution_data).__name__,
                        "error": str(execution_data.error) if hasattr(execution_data, 'error') and execution_data.error else None,
                        "has_results": hasattr(execution_data, 'results') and bool(execution_data.results),
                        "has_stdout": hasattr(execution_data, 'logs') and hasattr(execution_data.logs, 'stdout') and bool(execution_data.logs.stdout),
                        "has_stderr": hasattr(execution_data, 'logs') and hasattr(execution_data.logs, 'stderr') and bool(execution_data.logs.stderr),
                    }
                else:
                    safe_execution_data = str(execution_data)[:200]
            except Exception as e:
                logger.warning(f"Failed to serialize execution_data for {tool_call_id}: {e}")
                safe_execution_data = {"serialization_error": str(e)}

        # Falsy/absent execution_data counts as success. Otherwise, success needs an
        # `error` attribute that is None.
        # NOTE(review): a truthy execution_data without `error` (e.g. a plain string)
        # therefore counts as a failure; confirm this is intended.
        success = (not execution_data) or (
            hasattr(execution_data, 'error') and execution_data.error is None
        )

        tool_execution = {
            "timestamp": timestamp,
            "tool_call_id": tool_call_id,
            "tool_name": tool_name,
            "arguments": tool_args,
            "result_summary": result[:500] + "..." if len(result) > 500 else result,
            "result_length": len(result),
            "execution_data": safe_execution_data,
            "success": success,
        }
        state.setdefault("tool_executions", []).append(tool_execution)

        stats = state.setdefault("session_stats", {})
        if tool_name == "add_and_execute_jupyter_code_cell":
            stats["total_code_executions"] = stats.get("total_code_executions", 0) + 1
        elif tool_name == "web_search":
            stats["total_searches"] = stats.get("total_searches", 0) + 1
        if not success:
            stats["total_errors"] = stats.get("total_errors", 0) + 1

        logger.debug(f"Logged tool execution {tool_name} ({tool_call_id}) in session {self.session_id}")

    def add_message(self, state: Dict, role: str, content: str,
                    tool_calls: List = None, tool_call_id: str = None,
                    raw_execution: Any = None, metadata: Dict = None) -> None:
        """Append a timestamped message to ``state['conversation_history']``.

        The optional fields (``tool_calls``, ``tool_call_id``, ``raw_execution``,
        ``metadata``) are only included when truthy. ``raw_execution`` is stored
        as given. If it is not JSON-serializable, ``save_state`` writes a
        placeholder for it.
        """
        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
        message = {
            "role": role,
            "content": content,
            "timestamp": timestamp,
        }
        if tool_calls:
            message["tool_calls"] = tool_calls
        if tool_call_id:
            message["tool_call_id"] = tool_call_id
        if raw_execution:
            message["raw_execution"] = raw_execution
        if metadata:
            message["metadata"] = metadata

        state.setdefault("conversation_history", []).append(message)
        logger.debug(f"Added {role} message to session {self.session_id} conversation history")

    def update_execution_state(self, state: Dict, **kwargs) -> None:
        """Update fields of ``state['execution_state']`` and mirror them to the UI.

        Only keys that already exist in ``state['execution_state']`` are written
        there. Other keys (for example ``current_phase``, which
        ``create_initial_state`` does not define) are not persisted. Every key
        is copied into ``app.EXECUTION_STATES[session_id]`` when an ``app``
        module is loaded and already tracks this session.
        """
        execution_state = state["execution_state"]
        for key, value in kwargs.items():
            if key in execution_state:
                execution_state[key] = value
                logger.debug(f"Updated execution state {key}={value} for session {self.session_id}")
            else:
                logger.debug(f"Execution state key {key!r} is not persisted for session {self.session_id}")

        # Best-effort sync with the UI's global EXECUTION_STATES, if the app module is loaded.
        try:
            import sys
            execution_states = getattr(sys.modules.get('app'), 'EXECUTION_STATES', None)
            if execution_states and self.session_id in execution_states:
                ui_state = execution_states[self.session_id]
                for key, value in kwargs.items():
                    ui_state[key] = value
        except (ImportError, AttributeError, TypeError) as e:
            logger.debug(f"Could not sync execution state to app.EXECUTION_STATES: {e}")

    def update_notebook_data(self, state: Dict, notebook_data: Dict) -> None:
        """Replace ``state['notebook_data']`` with ``notebook_data``."""
        state["notebook_data"] = notebook_data
        logger.debug(f"Updated notebook data for session {self.session_id} ({len(notebook_data.get('cells', []))} cells)")

    def get_conversation_history(self, state: Dict) -> List[Dict]:
        """Return the live ``conversation_history`` list (not a copy), or [] if absent.

        Messages still carry local fields (``timestamp``, ``metadata``,
        ``raw_execution``); strip them before sending to the API.
        """
        return state.get("conversation_history", [])

    def validate_and_repair_conversation(self, state: Dict) -> None:
        """Make tool calls in the conversation consistent; save if anything was removed.

        - An assistant message that issued any tool call which never received a
          ``tool`` response is removed.
        - A ``tool`` message whose ``tool_call_id`` does not answer a call of a
          kept, earlier assistant message is removed too. This covers the
          responses to the *other* calls of a removed assistant message, which
          would otherwise be left dangling and rejected by the API.

        All other messages are kept in their original order.
        """
        conversation = state.get("conversation_history", [])
        if not conversation:
            return

        # Pass 1: tool calls that never got a response.
        pending_tool_calls = set()
        for message in conversation:
            if message.get("role") == "assistant" and message.get("tool_calls"):
                pending_tool_calls.update(tc["id"] for tc in message["tool_calls"])
            elif message.get("role") == "tool" and message.get("tool_call_id"):
                pending_tool_calls.discard(message["tool_call_id"])

        if pending_tool_calls:
            logger.warning(f"Found incomplete tool calls in conversation: {pending_tool_calls}")
            logger.warning("Removing incomplete assistant messages to repair conversation")

        # Pass 2: drop incomplete assistant messages and any tool response left without its call.
        repaired_messages = []
        answerable_call_ids = set()
        for message in conversation:
            role = message.get("role")
            if role == "assistant" and message.get("tool_calls"):
                call_ids = {tc["id"] for tc in message["tool_calls"]}
                if call_ids & pending_tool_calls:
                    logger.debug("Removing assistant message with incomplete tool calls")
                    continue
                answerable_call_ids |= call_ids
            elif role == "tool" and message.get("tool_call_id"):
                if message["tool_call_id"] not in answerable_call_ids:
                    logger.debug("Removing tool response without a matching assistant tool call")
                    continue
            repaired_messages.append(message)

        if len(repaired_messages) == len(conversation):
            return

        state["conversation_history"] = repaired_messages
        logger.info(f"Repaired conversation: {len(conversation)} -> {len(repaired_messages)} messages")
        self.save_state(state)

    def session_exists(self) -> bool:
        """Return True if this session's state file exists on disk."""
        return self.state_file.exists()

    def get_session_summary(self, state: Dict) -> str:
        """Return a multi-line, human-readable summary of ``state``.

        Reads ``created_at`` (required) and ``session_stats``, ``hardware_config``
        and ``api_config`` (optional, defaulting to 0 / 'unknown').
        """
        stats = state.get("session_stats", {})
        created = datetime.datetime.fromisoformat(state["created_at"])
        return "\n".join([
            f"Session {self.session_id}:",
            f"- Created: {created.strftime('%Y-%m-%d %H:%M:%S UTC')}",
            f"- Messages: {stats.get('total_messages', 0)}",
            f"- Code Executions: {stats.get('total_code_executions', 0)}",
            f"- Web Searches: {stats.get('total_searches', 0)}",
            f"- Errors: {stats.get('total_errors', 0)}",
            f"- Duration: {stats.get('session_duration_seconds', 0)}s",
            f"- Hardware: {state.get('hardware_config', {}).get('gpu_type', 'unknown')}",
            f"- Model: {state.get('api_config', {}).get('model_name', 'unknown')}",
        ])
def execute_code(sbx, code):
    """Run ``code`` in the sandbox and collect its textual output.

    While the code runs, stdout chunks are streamed to the debug log. The
    returned text is stdout, then stderr, then the error traceback (if any),
    concatenated without separators.
    NOTE(review): if sandbox log chunks lack trailing newlines, the last stdout
    line runs straight into stderr; confirm against the sandbox log format.

    Returns:
        tuple: ``(output, execution)``, where ``execution`` is the object
        returned by ``sbx.run_code``.
    """
    logger.debug(f"Executing code in sandbox ({len(code)} characters)")
    execution = sbx.run_code(code, on_stdout=lambda data: logger.debug(f'stdout: {data}'))

    stdout_chunks = execution.logs.stdout
    stderr_chunks = execution.logs.stderr
    pieces = []
    if stdout_chunks:
        pieces.append("\n".join(stdout_chunks))
        logger.debug(f"Execution produced {len(stdout_chunks)} stdout lines")
    if stderr_chunks:
        pieces.append("\n".join(stderr_chunks))
        logger.debug(f"Execution produced {len(stderr_chunks)} stderr lines")

    error = execution.error
    if error is not None:
        pieces.append(error.traceback)
        logger.warning(f"Execution error: {error.name}: {error.value}")

    output = "".join(pieces)
    logger.debug(f"Code execution completed, output length: {len(output)}")
    return output, execution
def parse_exec_result_llm(execution, max_code_output=1000):
    """Render a sandbox execution as compact text for the LLM.

    Sections are emitted in this order, each only if non-empty:
      1. rich results: a result's ``text`` if present; otherwise a numbered
         "[Plot N generated and displayed]" for PNG output or
         "[HTML output generated]" for HTML;
      2. stdout lines;
      3. stderr lines;
      4. the error traceback.
    Each section is truncated to ``max_code_output`` characters with a notice
    appended. Sections are joined with newlines.
    """
    logger.debug(f"Parsing execution result for LLM (max_output: {max_code_output})")

    def clip(text):
        if len(text) <= max_code_output:
            return text
        return text[:max_code_output] + f"\n[Output is truncated as it is more than {max_code_output} characters]"

    sections = []

    if execution.results:
        rendered = []
        plots_seen = 0
        for item in execution.results:
            if getattr(item, 'text', None):
                rendered.append(item.text)
            elif getattr(item, 'png', None):
                plots_seen += 1
                rendered.append(f"[Plot {plots_seen} generated and displayed]")
            elif getattr(item, 'html', None):
                rendered.append("[HTML output generated]")
        if rendered:
            sections.append(clip("\n".join(rendered)))
        logger.debug(f"Added {len(execution.results)} execution results (including {plots_seen} plots)")

    for stream_name, lines in (("stdout", execution.logs.stdout), ("stderr", execution.logs.stderr)):
        if lines:
            sections.append(clip("\n".join(lines)))
            logger.debug(f"Added {stream_name} output ({len(lines)} lines)")

    error = execution.error
    if error is not None:
        sections.append(clip(error.traceback))
        logger.debug(f"Added error traceback: {error.name}")

    final_output = "\n".join(sections)
    logger.debug(f"Parsed execution result for LLM: {len(final_output)} characters")
    return final_output
def clean_messages_for_api(messages):
    """
    Return API-ready copies of ``messages``. The input list and its dicts are not modified.

    Each message is shallow-copied and stripped of the local bookkeeping keys
    ``raw_execution``, ``metadata`` and ``timestamp``. ``raw_execution`` in
    particular can be large enough to trigger HTTP 413 errors.

    Tool-call consistency is then enforced so the API does not reject the request:
    - an assistant message with any tool call that never received a ``tool``
      response is dropped;
    - a ``tool`` message whose ``tool_call_id`` does not answer a call made by a
      kept, earlier assistant message is dropped as well. Otherwise, dropping an
      assistant message would leave the responses to its other calls dangling.
    """
    logger.debug(f"Cleaning {len(messages)} messages for API call")
    cleaned_messages = []
    raw_execution_count = 0
    metadata_count = 0
    pending_tool_calls = set()

    # Pass 1: strip local fields and find tool calls that never got a response.
    for message in messages:
        cleaned_message = message.copy()
        if "raw_execution" in cleaned_message:
            del cleaned_message["raw_execution"]
            raw_execution_count += 1
        if "metadata" in cleaned_message:
            del cleaned_message["metadata"]
            metadata_count += 1
        cleaned_message.pop("timestamp", None)

        if cleaned_message.get("role") == "assistant" and cleaned_message.get("tool_calls"):
            pending_tool_calls.update(tc["id"] for tc in cleaned_message["tool_calls"])
        elif cleaned_message.get("role") == "tool" and cleaned_message.get("tool_call_id"):
            pending_tool_calls.discard(cleaned_message["tool_call_id"])
        cleaned_messages.append(cleaned_message)

    if pending_tool_calls:
        logger.warning(f"Found {len(pending_tool_calls)} tool calls without responses: {pending_tool_calls}")
        logger.warning("Removing incomplete tool call messages to prevent API errors")

    # Pass 2: drop incomplete assistant messages and tool responses left without their call.
    filtered_messages = []
    answerable_call_ids = set()
    orphan_count = 0
    for message in cleaned_messages:
        role = message.get("role")
        if role == "assistant" and message.get("tool_calls"):
            call_ids = {tc["id"] for tc in message["tool_calls"]}
            if call_ids & pending_tool_calls:
                logger.debug("Removing assistant message with incomplete tool calls")
                continue
            answerable_call_ids |= call_ids
        elif role == "tool" and message.get("tool_call_id"):
            if message["tool_call_id"] not in answerable_call_ids:
                orphan_count += 1
                continue
        filtered_messages.append(message)

    if orphan_count:
        logger.warning(f"Removed {orphan_count} tool responses without a matching assistant tool call")

    logger.debug(f"Cleaned messages: removed raw_execution from {raw_execution_count}, metadata from {metadata_count}")
    logger.debug(f"Final cleaned message count: {len(filtered_messages)}")
    return filtered_messages
def web_search(query):
    """
    Search the web with Tavily and return the results formatted for the LLM.

    The query is capped at 400 characters. If the current year does not
    already appear in it and still fits, " <year>" is appended to bias the
    search toward recent pages.

    Args:
        query (str): Search query (max 400 characters)

    Returns:
        str: Formatted search results, or a "❌ ..." message when the client is
        not configured or the search raises.
    """
    if not tavily_client:
        logger.error("Tavily client not initialized - API key missing")
        return "❌ Search unavailable: Tavily API key not configured"

    limit = 400
    if len(query) > limit:
        logger.warning(f"Query too long ({len(query)} chars), truncating to 400")
        query = query[:limit]

    year = datetime.datetime.now().year
    suffix = f" {year}"
    if str(year) not in query and len(query) + len(suffix) <= limit:
        query = query + suffix
        logger.debug(f"Added current year to query: {year}")

    logger.info(f"Performing Tavily search: '{query}' ({len(query)} chars)")
    search_options = {
        "search_depth": "basic",  # basic depth for faster results
        "max_results": 5,  # few results keep the LLM context small
        "include_answer": True,  # Tavily's AI-generated answer
        "include_raw_content": False,  # raw page content would cost many tokens
        "include_images": False,
    }
    try:
        response = tavily_client.search(query=query, **search_options)
        logger.info(f"Search completed: {len(response.get('results', []))} results found")
        formatted_results = format_search_results_for_llm(response)
        logger.debug(f"Formatted search results: {len(formatted_results)} characters")
        return formatted_results
    except Exception as e:
        logger.error(f"Tavily search failed: {str(e)}")
        return f"❌ Search failed: {str(e)}"
def format_search_results_for_llm(response):
    """Format a Tavily search response as Markdown text for the LLM.

    Args:
        response: Mapping returned by the Tavily search call. Reads ``query``,
            ``answer`` and ``results``; each result may provide ``title``,
            ``url``, ``content`` and ``score``. Missing keys get defaults.

    Returns:
        str: A header, the quick answer when present, then one numbered entry
        per result (title, relevance, URL, full content), or
        "No results found." when there are no results. A missing or null score
        is shown as 0.00. A score that cannot be formatted as a float is shown
        verbatim instead of raising.
    """
    query = response.get('query', 'Unknown query')
    results = response.get('results', [])
    answer = response.get('answer', '')

    parts = [f"🔍 **Web Search Results for:** {query}\n\n"]
    if answer:
        parts.append(f"**Quick Answer:** {answer}\n\n")

    if results:
        parts.append(f"**Found {len(results)} relevant sources:**\n\n")
        for i, result in enumerate(results, 1):
            title = result.get('title', 'Untitled')
            url = result.get('url', '')
            content = result.get('content', '')  # included in full, not truncated
            score = result.get('score')
            if score is None:
                score = 0
            try:
                relevance = f"{score:.2f}"
            except (TypeError, ValueError):
                # One odd score must not make the whole search fail.
                relevance = str(score)
            parts.append(f"**{i}. {title}** (Relevance: {relevance})\n")
            parts.append(f" 🔗 {url}\n")
            parts.append(f" 📄 {content}\n\n")
    else:
        parts.append("No results found.\n")

    return "".join(parts)
def run_interactive_notebook_with_session_state(client, model, session_state_manager, session_state, sbx, stop_event=None, tools=None):
logger.info(f"Starting interactive notebook with session state for {session_state_manager.session_id}")
# Get conversation history from session state
messages = session_state_manager.get_conversation_history(session_state)
notebook = JupyterNotebook(messages)
# Update execution state
session_state_manager.update_execution_state(session_state, is_running=True, sandbox_active=True, current_phase="initializing")
# Use provided tools or default to all tools
if tools is None:
tools = TOOLS
try:
sbx_info = sbx.get_info()
notebook.add_sandbox_countdown(sbx_info.started_at, sbx_info.end_at)
# Store sandbox info in session state
session_state["execution_state"]["sandbox_info"] = {
"started_at": sbx_info.started_at.isoformat(),
"end_at": sbx_info.end_at.isoformat(),
"timeout_seconds": int((sbx_info.end_at - sbx_info.started_at).total_seconds())
}
logger.debug(f"Added sandbox countdown: {sbx_info.started_at} to {sbx_info.end_at}")
except Exception as e:
logger.warning(f"Failed to get sandbox info: {str(e)}")
logger.debug("Initial notebook yield in 'generating' mode")
# Update notebook data in session state
session_state_manager.update_notebook_data(session_state, notebook.data)
# Save initial state
session_state_manager.save_state(session_state)
yield notebook.render(mode="generating"), notebook.data, messages
max_code_output = 1000
turns = session_state["execution_state"]["current_turn"]
done = False
previous_execution_had_error = False
previous_execution_had_warnings = False
logger.info(f"Starting interactive loop from turn {turns} with max_output={max_code_output}, max_turns={MAX_TURNS}")
while not done and (turns <= MAX_TURNS) and (stop_event is None or not stop_event.is_set()):
turns += 1
logger.info(f"Starting turn {turns}/{MAX_TURNS}")
try:
# Update phase to generating
session_state_manager.update_execution_state(session_state, current_phase="generating")
# Refresh messages from session state before API call
messages = session_state_manager.get_conversation_history(session_state)
logger.debug(f"Making API call to {model} with {len(messages)} messages")
# Prepare request data for logging
request_data = {
"messages": clean_messages_for_api(messages),
"model": model,
"tools": tools,
"tool_choice": "auto"
}
# Prepare session metadata for Phoenix tracing
session_metadata = {
"turn": turns,
"max_turns": MAX_TURNS,
"model": model,
"tools_count": len(tools),
"messages_count": len(messages),
"current_phase": "generating"
}
# Add hardware config if available
if "hardware_config" in session_state:
hw_config = session_state["hardware_config"]
session_metadata.update({
"gpu_type": hw_config.get("gpu_type", "unknown"),
"cpu_cores": hw_config.get("cpu_cores", "unknown"),
"memory_gb": hw_config.get("memory_gb", "unknown")
})
# Wrap OpenAI API call with Phoenix session context for proper grouping
with create_phoenix_session_context(
session_id=session_state_manager.session_id,
user_id=None, # Could be extracted from request context if available
metadata=session_metadata
):
logger.debug(f"Making OpenAI API call with Phoenix session context: {session_state_manager.session_id}")
response = client.chat.completions.create(**request_data)
logger.debug("API call successful within Phoenix session context")
# Log the complete LLM interaction
session_state_manager.log_llm_interaction(
session_state, request_data, response.model_dump(), model, turns
)
except Exception as e:
# Handle inference client errors
logger.error(f"Inference failed on turn {turns}: {str(e)}")
# Add detailed error information to the notebook
error_message = str(e)
if "429" in error_message or "too_many_requests" in error_message.lower():
detailed_error = f"""**API Rate Limit Exceeded** 🚫
The inference service has reached its rate limit. This typically means:
- Too many requests have been sent in a short period
- Daily quota has been exceeded
- Service is temporarily overloaded
**What you can try:**
- Wait a few minutes and try again
- If using Cerebras API, check your daily quota
- Try using a different model or service
- Contact support if the issue persists
**Technical details:**
```
{error_message}
```"""
elif "401" in error_message or "unauthorized" in error_message.lower():
detailed_error = f"""**Authentication Error** 🔐
There's an issue with API authentication:
- API key might be missing or invalid
- API key might have expired
- Insufficient permissions
**Technical details:**
```
{error_message}
```"""
elif "500" in error_message or "internal" in error_message.lower():
detailed_error = f"""**Server Error** 🔧
The inference service encountered an internal error:
- Service might be temporarily unavailable
- Try again in a few moments
- If the issue persists, it's likely a service-side problem
**Technical details:**
```
{error_message}
```"""
else:
detailed_error = f"""**Inference Service Error** ⚠️
An error occurred while communicating with the AI service:
**Technical details:**
```
{error_message}
```
**What you can try:**
- Check your internet connection
- Try again in a few moments
- If the problem persists, contact support"""
notebook.add_error(detailed_error)
# Add error to session state
session_state_manager.add_message(
session_state, "assistant", detailed_error,
metadata={"type": "error", "error_type": "api_error", "turn": turns}
)
# Update execution state
session_state_manager.update_execution_state(
session_state, is_running=False, last_execution_successful=False
)
# Update notebook data and save state
session_state_manager.update_notebook_data(session_state, notebook.data)
session_state_manager.save_state(session_state)
yield notebook.render(mode="error"), notebook.data, messages
return
# Get the response content and tool calls
full_response = response.choices[0].message.content or ""
tool_calls = response.choices[0].message.tool_calls or []
logger.debug(f"Turn {turns}: Response content length: {len(full_response)}, Tool calls: {len(tool_calls)}")
# Add markdown cell for assistant's thinking
if full_response.strip():
logger.debug(f"Adding assistant response as markdown ({len(full_response)} chars)")
notebook.add_markdown(full_response, "assistant")
else:
logger.debug("Skipping empty assistant response")
# Handle tool calls and add assistant message to session state only
if tool_calls:
logger.info(f"Processing {len(tool_calls)} tool calls on turn {turns}")
# Add assistant message to session state (messages will be derived from this)
session_state_manager.add_message(
session_state, "assistant", full_response,
tool_calls=[{
"id": tc.id,
"type": "function",
"function": {"name": tc.function.name, "arguments": tc.function.arguments}
} for tc in tool_calls],
metadata={"turn": turns, "type": "thinking"}
)
logger.debug(f"Added assistant message with {len(tool_calls)} tool calls to session state")
elif full_response.strip():
# If no tool calls but we have content, add regular assistant message
session_state_manager.add_message(
session_state, "assistant", full_response,
metadata={"turn": turns, "type": "thinking"}
)
logger.debug("Added regular assistant message to session state")
for i, tool_call in enumerate(tool_calls):
logger.debug(f"Processing tool call {i+1}/{len(tool_calls)}: {tool_call.function.name}")
if tool_call.function.name == "add_and_execute_jupyter_code_cell":
# Update phase to executing code
session_state_manager.update_execution_state(session_state, current_phase="executing_code")
logger.debug(f"Processing code execution tool call: {tool_call.id}")
tool_args = json.loads(tool_call.function.arguments)
code = tool_args["code"]
logger.debug(f"Code to execute: {len(code)} characters")
# Determine if we should reuse the last cell or create a new one
# Reuse if there were errors (not just warnings) in the previous execution
should_reuse_cell = (previous_execution_had_error and
notebook.get_last_cell_type() == "code")
if should_reuse_cell:
logger.info("Reusing last code cell due to previous execution error")
# Update the existing cell's code instead of creating a new one
notebook.update_last_code_cell(code)
else:
logger.debug("Creating new code cell")
# Create a new cell (normal behavior)
notebook.add_code(code)
logger.debug("Yielding notebook in 'executing' mode")
yield notebook.render(mode="executing"), notebook.data, messages
try:
# Check for stop event before execution
if stop_event and stop_event.is_set():
logger.info("Stop event detected before code execution")
stopped_message = """**Execution Stopped** ⏸️
The execution was stopped by user request before the code could run."""
notebook.add_markdown(stopped_message, "assistant")
yield notebook.render(mode="stopped"), notebook.data, messages
return
# Execution sandbox call - might timeout
logger.info("Executing code in sandbox")
execution = sbx.run_code(code)
notebook.append_execution(execution)
# Update error and warning tracking for next iteration
previous_execution_had_error = notebook.has_execution_error(execution)
previous_execution_had_warnings = notebook.has_execution_warnings(execution)
# Log tool execution in session state
tool_args = json.loads(tool_call.function.arguments)
tool_response_content = parse_exec_result_llm(execution, max_code_output=max_code_output)
session_state_manager.log_tool_execution(
session_state, tool_call.id, "add_and_execute_jupyter_code_cell",
tool_args, tool_response_content, execution
)
if previous_execution_had_error:
logger.warning("Code execution resulted in error")
elif previous_execution_had_warnings:
logger.info("Code execution completed with warnings")
else:
logger.info("Code execution completed successfully")
except Exception as e:
# Handle sandbox timeout/execution errors
logger.error(f"Code execution failed: {str(e)}")
# Add detailed error information for code execution failures
error_message = str(e)
if "timeout" in error_message.lower():
detailed_error = f"""**Code Execution Timeout** ⏰
The code execution took too long and was terminated:
- Code may have entered an infinite loop
- Processing large datasets can cause timeouts
- Complex computations may exceed time limits
**What you can try:**
- Optimize your code for better performance
- Break down complex operations into smaller steps
- Increase the timeout limit in settings
- Check for infinite loops or blocking operations
**Technical details:**
```
{error_message}
```"""
else:
detailed_error = f"""**Code Execution Failed** 💥
An error occurred while executing the code in the sandbox:
**Technical details:**
```
{error_message}
```
**What you can try:**
- Check the code for syntax errors
- Verify all required packages are available
- Try simplifying the code
- Check the sandbox logs for more details"""
notebook.add_error(detailed_error)
yield notebook.render(mode="error"), notebook.data, messages
return
# Prepare tool response (already computed above)
raw_execution = notebook.parse_exec_result_nb(execution)
logger.debug(f"Tool response: {len(tool_response_content)} chars content, {len(raw_execution)} raw outputs")
# Add tool response to session state only
session_state_manager.add_message(
session_state, "tool", tool_response_content,
tool_call_id=tool_call.id, raw_execution=raw_execution,
metadata={"turn": turns, "execution_successful": not previous_execution_had_error}
)
elif tool_call.function.name == "web_search":
# Update phase to searching
session_state_manager.update_execution_state(session_state, current_phase="searching")
logger.debug(f"Processing search tool call: {tool_call.id}")
tool_args = json.loads(tool_call.function.arguments)
query = tool_args["query"]
logger.debug(f"Search query: '{query}' ({len(query)} chars)")
# Add search status to notebook
notebook.add_markdown("🔍 **Searching the web...**", "assistant")
yield notebook.render(mode="generating"), notebook.data, messages
try:
# Perform search
search_results = web_search(query)
logger.info("Search completed successfully")
# Log search tool execution
tool_args = json.loads(tool_call.function.arguments)
session_state_manager.log_tool_execution(
session_state, tool_call.id, "web_search",
tool_args, search_results
)
# Add search results to notebook
notebook.add_markdown(search_results, "assistant")
# Add tool response to session state only
session_state_manager.add_message(
session_state, "tool", search_results,
tool_call_id=tool_call.id,
metadata={"turn": turns, "search_successful": True}
)
except Exception as e:
error_message = f"❌ Search failed: {str(e)}"
logger.error(f"Search tool call failed: {str(e)}")
# Log failed search
tool_args = json.loads(tool_call.function.arguments)
session_state_manager.log_tool_execution(
session_state, tool_call.id, "web_search",
tool_args, error_message
)
# Add error to notebook
notebook.add_markdown(error_message, "assistant")
# Add error response to session state only
session_state_manager.add_message(
session_state, "tool", error_message,
tool_call_id=tool_call.id,
metadata={"turn": turns, "search_successful": False, "error": str(e)}
)
elif tool_call.function.name == "edit_and_execute_current_cell":
# Update phase to executing code
session_state_manager.update_execution_state(session_state, current_phase="executing_code")
logger.debug(f"Processing edit current cell tool call: {tool_call.id}")
tool_args = json.loads(tool_call.function.arguments)
code = tool_args["code"]
logger.debug(f"Code to execute in current cell: {len(code)} characters")
# Check if we have a code cell to edit
if notebook.get_last_cell_type() == "code":
logger.info("Editing last code cell with new code")
notebook.update_last_code_cell(code)
else:
logger.info("No code cell to edit, creating new cell")
notebook.add_code(code)
logger.debug("Yielding notebook in 'executing' mode")
yield notebook.render(mode="executing"), notebook.data, messages
try:
# Check for stop event before execution
if stop_event and stop_event.is_set():
logger.info("Stop event detected before code execution")
stopped_message = """**Execution Stopped** ⏸️
The execution was stopped by user request before the code could run."""
notebook.add_markdown(stopped_message, "assistant")
yield notebook.render(mode="stopped"), notebook.data, messages
return
# Execution sandbox call - might timeout
logger.info("Executing edited code in sandbox")
execution = sbx.run_code(code)
notebook.append_execution(execution)
# Update error and warning tracking for next iteration
previous_execution_had_error = notebook.has_execution_error(execution)
previous_execution_had_warnings = notebook.has_execution_warnings(execution)
# Log tool execution in session state
tool_response_content = parse_exec_result_llm(execution, max_code_output=max_code_output)
session_state_manager.log_tool_execution(
session_state, tool_call.id, "edit_and_execute_current_cell",
tool_args, tool_response_content, execution
)
if previous_execution_had_error:
logger.warning("Edited code execution resulted in error")
elif previous_execution_had_warnings:
logger.info("Edited code execution completed with warnings")
else:
logger.info("Edited code execution completed successfully")
except Exception as e:
# Handle sandbox timeout/execution errors
logger.error(f"Edited code execution failed: {str(e)}")
# Add detailed error information for code execution failures
error_message = str(e)
if "timeout" in error_message.lower():
detailed_error = f"""**Code Execution Timeout** ⏰
The edited code execution took too long and was terminated:
- Code may have entered an infinite loop
- Processing large datasets can cause timeouts
- Complex computations may exceed time limits
**What you can try:**
- Optimize your code for better performance
- Break down complex operations into smaller steps
- Increase the timeout limit in settings
- Check for infinite loops or blocking operations
**Technical details:**
```
{error_message}
```"""
else:
detailed_error = f"""**Code Execution Failed** 💥
An error occurred while executing the edited code in the sandbox:
**Technical details:**
```
{error_message}
```
**What you can try:**
- Check the code for syntax errors
- Verify all required packages are available
- Try simplifying the code
- Check the sandbox logs for more details"""
notebook.add_error(detailed_error)
yield notebook.render(mode="error"), notebook.data, messages
return
# Prepare tool response
raw_execution = notebook.parse_exec_result_nb(execution)
logger.debug(f"Tool response: {len(tool_response_content)} chars content, {len(raw_execution)} raw outputs")
# Add tool response to session state only
session_state_manager.add_message(
session_state, "tool", tool_response_content,
tool_call_id=tool_call.id, raw_execution=raw_execution,
metadata={"turn": turns, "execution_successful": not previous_execution_had_error, "action": "edit_cell"}
)
elif tool_call.function.name == "execute_shell_command":
# Update phase to executing shell command
session_state_manager.update_execution_state(session_state, current_phase="executing_shell")
logger.debug(f"Processing shell command tool call: {tool_call.id}")
tool_args = json.loads(tool_call.function.arguments)
command = tool_args["command"]
logger.debug(f"Shell command to execute: '{command}'")
# Add shell command to notebook with special styling
notebook.add_shell_command(command)
logger.debug("Yielding notebook in 'executing' mode")
yield notebook.render(mode="executing"), notebook.data, messages
try:
# Check for stop event before execution
if stop_event and stop_event.is_set():
logger.info("Stop event detected before shell execution")
stopped_message = """**Execution Stopped** ⏸️
The execution was stopped by user request before the shell command could run."""
notebook.add_markdown(stopped_message, "assistant")
yield notebook.render(mode="stopped"), notebook.data, messages
return
# Execute shell command in sandbox using raw shell execution
logger.info(f"Executing raw shell command in sandbox: {command}")
try:
# Use the new raw shell execution method
if hasattr(sbx, 'run_shell'):
shell_execution = sbx.run_shell(command, timeout=60)
logger.info("Shell command executed using raw shell method")
else:
# Fallback: Execute shell command using Python subprocess within sandbox
shell_code = f"""
import subprocess
import sys
try:
result = subprocess.run(
{repr(command)},
shell=True,
capture_output=True,
text=True,
timeout=60
)
if result.stdout:
print("STDOUT:")
print(result.stdout)
if result.stderr:
print("STDERR:")
print(result.stderr)
print(f"Exit code: {{result.returncode}}")
except subprocess.TimeoutExpired:
print("Error: Command timed out after 60 seconds")
except Exception as e:
print(f"Error executing command: {{e}}")
"""
shell_execution = sbx.run_code(shell_code)
logger.info("Shell command executed via Python subprocess fallback")
# Add shell execution results to notebook
notebook.append_shell_execution(shell_execution)
# Prepare response content for LLM
shell_response_content = parse_exec_result_llm(shell_execution, max_code_output=max_code_output)
# Log tool execution in session state
session_state_manager.log_tool_execution(
session_state, tool_call.id, "execute_shell_command",
tool_args, shell_response_content, shell_execution
)
# Check for errors
shell_had_error = notebook.has_execution_error(shell_execution)
if shell_had_error:
logger.warning("Shell command execution resulted in error")
else:
logger.info("Shell command execution completed successfully")
except Exception as shell_error:
logger.error(f"Shell command execution failed: {str(shell_error)}")
# Create error message
detailed_error = f"""**Shell Command Failed** 🔧
An error occurred while executing the shell command:
**Command:** `{command}`
**Technical details:**
```
{str(shell_error)}
```
**What you can try:**
- Check if the command exists in the sandbox environment
- Verify command syntax
- Try a simpler version of the command
- Check if required tools/packages are installed"""
notebook.add_error(detailed_error)
# Log failed execution
session_state_manager.log_tool_execution(
session_state, tool_call.id, "execute_shell_command",
tool_args, detailed_error
)
yield notebook.render(mode="error"), notebook.data, messages
return
except Exception as e:
# Handle general execution errors
logger.error(f"Shell command execution failed: {str(e)}")
detailed_error = f"""**Shell Execution Error** ⚠️
An unexpected error occurred while executing the shell command:
**Command:** `{command}`
**Technical details:**
```
{str(e)}
```"""
notebook.add_error(detailed_error)
yield notebook.render(mode="error"), notebook.data, messages
return
# Prepare tool response for LLM and session state
raw_execution = notebook.parse_exec_result_nb(shell_execution)
logger.debug(f"Shell tool response: {len(shell_response_content)} chars content")
# Add tool response to session state
session_state_manager.add_message(
session_state, "tool", shell_response_content,
tool_call_id=tool_call.id, raw_execution=raw_execution,
metadata={"turn": turns, "command": command, "execution_successful": not shell_had_error, "action": "shell_command"}
)
else:
logger.warning(f"Unknown tool call function: {tool_call.function.name}")
if not tool_calls:
logger.info(f"No tool calls on turn {turns}, conversation ending")
if len(full_response.strip())==0:
logger.error("Assistant provided no content and no tool calls")
notebook.add_error(f"No tool call and empty assistant response:\n{response.model_dump_json(indent=2)}")
# Only add the final assistant message if we didn't already add it above
# (in the elif full_response.strip() block)
if full_response.strip():
# Since we're now only using session state, we can safely add the message
# The session state manager will handle any deduplication if needed
session_state_manager.add_message(
session_state, "assistant", full_response,
metadata={"turn": turns, "type": "final_response"}
)
logger.debug("Added final assistant response to session state")
done = True
# Update session state after each turn
session_state_manager.update_execution_state(
session_state, current_turn=turns, last_execution_successful=not previous_execution_had_error
)
session_state_manager.update_notebook_data(session_state, notebook.data)
session_state_manager.save_state(session_state)
if done:
logger.info(f"Interactive notebook completed after {turns} turns")
session_state_manager.update_execution_state(
session_state, is_running=False, sandbox_active=True
)
session_state_manager.save_state(session_state)
yield notebook.render(mode="done"), notebook.data, messages
else:
logger.debug(f"Turn {turns} completed, yielding in 'generating' mode")
yield notebook.render(mode="generating"), notebook.data, messages
if turns > MAX_TURNS:
logger.warning(f"Interactive notebook reached maximum turns ({MAX_TURNS})")
error_msg = f"**Maximum Turns Reached** 🔄\n\nThe conversation has reached the maximum number of turns ({MAX_TURNS}). This is a safety limit to prevent infinite loops.\n\n**What you can try:**\n- Start a new conversation\n- Clear the notebook and begin fresh\n- Contact support if you need a higher turn limit"
notebook.add_error(error_msg)
# Add error to session state
session_state_manager.add_message(
session_state, "assistant", error_msg,
metadata={"type": "error", "error_type": "max_turns_exceeded", "turn": turns}
)
# Update final state
session_state_manager.update_execution_state(
session_state, is_running=False, last_execution_successful=False
)
session_state_manager.update_notebook_data(session_state, notebook.data)
session_state_manager.save_state(session_state)
yield notebook.render(mode="error"), notebook.data, messages
elif stop_event and stop_event.is_set():
logger.info("Interactive notebook stopped by user")
# Add a stopped message to the notebook
stopped_message = """**Execution Stopped** ⏸️
The execution was stopped by user request. You can resume by clicking Run again."""
notebook.add_markdown(stopped_message, "assistant")
# Add stopped message to session state
session_state_manager.add_message(
session_state, "assistant", stopped_message,
metadata={"type": "status", "status_type": "stopped_by_user", "turn": turns}
)
# Update state to indicate pause
session_state_manager.update_execution_state(
session_state, is_running=False, is_paused=True
)
session_state_manager.update_notebook_data(session_state, notebook.data)
session_state_manager.save_state(session_state)
yield notebook.render(mode="stopped"), notebook.data, messages
def run_interactive_notebook(client, model, messages, sbx, stop_event=None, tools=None):
    """Legacy entry point kept for callers that predate session state.

    Wraps ``messages`` in a throwaway session and delegates to
    ``run_interactive_notebook_with_session_state``, re-yielding every
    update that generator produces.

    Args:
        client: LLM client, forwarded unchanged to the session-based runner.
        model: Model name; forwarded unchanged and also recorded as
            ``api_config["model_name"]`` in the temporary session.
        messages: Initial conversation. If its first entry has role
            ``"system"``, that entry's ``"content"`` becomes the session's
            system prompt. The list object itself (not a copy) is stored as
            the session's ``"conversation_history"``.
        sbx: Sandbox handle, forwarded unchanged.
        stop_event: Optional stop signal, forwarded unchanged.
        tools: Optional tool definitions, forwarded unchanged.

    Yields:
        Exactly the items yielded by
        ``run_interactive_notebook_with_session_state``.
    """
    logger.warning("Using legacy run_interactive_notebook - this should be replaced with session state version")

    import uuid

    # The session only lives for the duration of this call, so a short
    # 8-character slice of a random UUID is enough to label it.
    manager = SessionStateManager(str(uuid.uuid4())[:8])

    # Pull the system prompt out of a leading "system" message, if any.
    system_prompt = ""
    if messages:
        head = messages[0]
        if head.get("role") == "system":
            system_prompt = head.get("content", "")

    # Placeholder hardware/API/environment values: the legacy call path does
    # not receive this information from its caller.
    session_state = manager.create_initial_state(
        hardware_config={"gpu_type": "unknown", "cpu_cores": 2, "memory_gb": 8, "timeout_sec": 300},
        api_config={"model_name": model, "provider_type": "unknown"},
        environment={"variables": "", "files_uploaded": []},
        system_prompt=system_prompt,
    )

    # Seed the session with the caller's conversation (stored by reference).
    session_state["conversation_history"] = messages

    yield from run_interactive_notebook_with_session_state(
        client, model, manager, session_state, sbx, stop_event, tools
    )