from jupyter_handler import JupyterNotebook import json import logging import os import datetime from pathlib import Path from typing import Dict, List, Any, Optional from tavily import TavilyClient # Phoenix tracing imports try: from openinference.instrumentation import using_session PHOENIX_AVAILABLE = True print("Phoenix session tracking imports successful") except ImportError: PHOENIX_AVAILABLE = False print("Phoenix session tracking not available - missing openinference packages") # Configure logging for utils module logger = logging.getLogger(__name__) # Initialize Tavily client TAVILY_API_KEY = os.getenv("TAVILY_API_KEY") tavily_client = TavilyClient(api_key=TAVILY_API_KEY) if TAVILY_API_KEY else None TOOLS = [ { "type": "function", "function": { "name": "add_and_execute_jupyter_code_cell", "description": "A Python code execution environment that runs code in a Jupyter notebook interface. This is stateful - variables and imports persist between executions.", "parameters": { "type": "object", "properties": { "code": { "type": "string", "description": "The Python code to execute." } }, "required": ["code"] } } }, { "type": "function", "function": { "name": "edit_and_execute_current_cell", "description": "Edit the current/last code cell and execute the new code. Use this to fix errors or modify the previous code instead of creating a new cell.", "parameters": { "type": "object", "properties": { "code": { "type": "string", "description": "The updated Python code to replace the current cell with and execute." } }, "required": ["code"] } } }, { "type": "function", "function": { "name": "execute_shell_command", "description": "Execute shell/system commands like ls, cat, mkdir, etc. This runs independently of Python and provides terminal-style output.", "parameters": { "type": "object", "properties": { "command": { "type": "string", "description": "The shell command to execute (e.g., 'ls -la', 'cat file.txt', 'mkdir new_folder')." } }, "required": ["command"] } } }, { "type": "function", "function": { "name": "web_search", "description": "Search the web for current information, documentation, tutorials, and solutions to coding problems. Use this to get context before starting tasks or when encountering errors.", "parameters": { "type": "object", "properties": { "query": { "type": "string", "description": "Search query (max 400 characters). Be specific and include relevant keywords." } }, "required": ["query"] } } }, ] # TOOLS = TOOLS[:1] MAX_TURNS = 20 def create_phoenix_session_context(session_id: str, user_id: str = None, metadata: Dict = None): """ Create a Phoenix session context for tracing LLM interactions. Args: session_id: Unique identifier for the session user_id: Optional user identifier metadata: Additional metadata to include in traces Returns: Context manager for Phoenix session tracking """ if not PHOENIX_AVAILABLE: # Return a no-op context manager if Phoenix is not available from contextlib import nullcontext return nullcontext() try: # Use using_session for proper session grouping in Phoenix # This ensures all LLM calls within this context are grouped under the same session logger.debug(f"Creating Phoenix session context for session_id: {session_id}") return using_session(session_id) except Exception as e: logger.warning(f"Failed to create Phoenix session context for {session_id}: {e}") # Fallback to no-op context if Phoenix session creation fails from contextlib import nullcontext return nullcontext() class SessionStateManager: """Manages comprehensive session state in a single JSON file""" def __init__(self, session_id: str, base_dir: str = './temp/'): self.session_id = session_id self.base_dir = Path(base_dir) self.session_dir = self.base_dir / session_id self.state_file = self.session_dir / 'session_state.json' self.session_dir.mkdir(parents=True, exist_ok=True) logger.info(f"SessionStateManager initialized for {session_id}") def create_initial_state(self, hardware_config: Dict, api_config: Dict, environment: Dict, system_prompt: str) -> Dict: """Create initial session state structure""" timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() initial_state = { "session_id": self.session_id, "created_at": timestamp, "last_updated": timestamp, "version": "1.0", "hardware_config": hardware_config, "api_config": api_config, "environment": environment, "conversation_history": [ { "role": "system", "content": system_prompt, "timestamp": timestamp, "metadata": {"type": "system_initialization"} } ], "llm_interactions": [], # Complete API call logs "tool_executions": [], # All tool calls and results "notebook_data": { "cells": [], "metadata": { "kernel_info": {"name": "python3"}, "language_info": {"name": "python", "version": "3.12"}, }, "nbformat": 4, "nbformat_minor": 0 }, "execution_state": { "current_turn": 0, "max_turns": MAX_TURNS, "is_running": False, "is_paused": False, "last_execution_successful": None, "sandbox_active": False, "sandbox_info": None }, "session_stats": { "total_messages": 1, "total_code_executions": 0, "total_searches": 0, "total_errors": 0, "session_duration_seconds": 0 } } logger.info("Created initial session state for %s", self.session_id) return initial_state def load_state(self) -> Optional[Dict]: """Load session state from file with improved error handling""" if not self.state_file.exists(): logger.info(f"No existing session state found for {self.session_id}") return None try: with open(self.state_file, 'r', encoding='utf-8') as f: state = json.load(f) logger.info(f"Loaded session state for {self.session_id} with {len(state.get('conversation_history', []))} messages") return state except json.JSONDecodeError as e: logger.error(f"JSON corruption in session state for {self.session_id}: {str(e)}") logger.info(f"Creating backup of corrupted file: {self.state_file}.corrupted") try: import shutil shutil.copy2(self.state_file, str(self.state_file) + ".corrupted") logger.info(f"Backup created successfully") except Exception as backup_error: logger.warning(f"Failed to create backup: {backup_error}") return None except Exception as e: logger.error(f"Failed to load session state for {self.session_id}: {str(e)}") return None def save_state(self, state: Dict) -> bool: """Save session state to file with improved error handling""" try: # Update last_updated timestamp state["last_updated"] = datetime.datetime.now(datetime.timezone.utc).isoformat() # Update session stats if "session_stats" not in state: state["session_stats"] = {} created_at = datetime.datetime.fromisoformat(state["created_at"]) current_time = datetime.datetime.now(datetime.timezone.utc) state["session_stats"]["session_duration_seconds"] = int((current_time - created_at).total_seconds()) state["session_stats"]["total_messages"] = len(state.get("conversation_history", [])) # Validate JSON serializability before writing try: json.dumps(state, ensure_ascii=False) except (TypeError, ValueError) as e: logger.error(f"State contains non-serializable data: {e}") logger.info("Attempting to clean non-serializable data...") state = self._clean_non_serializable_data(state) # Write to temporary file first, then rename for atomic operation temp_file = self.state_file.with_suffix('.tmp') with open(temp_file, 'w', encoding='utf-8') as f: json.dump(state, f, indent=2, ensure_ascii=False) # Atomic rename temp_file.replace(self.state_file) logger.debug(f"Saved session state for {self.session_id} ({len(json.dumps(state))} characters)") return True except Exception as e: logger.error(f"Failed to save session state for {self.session_id}: {str(e)}") # Clean up temp file if it exists temp_file = self.state_file.with_suffix('.tmp') if temp_file.exists(): try: temp_file.unlink() except Exception: pass return False def _clean_non_serializable_data(self, obj): """Recursively clean non-serializable data from objects""" if isinstance(obj, dict): cleaned = {} for key, value in obj.items(): try: json.dumps(value) cleaned[key] = self._clean_non_serializable_data(value) except (TypeError, ValueError): logger.warning(f"Removing non-serializable field: {key}") cleaned[key] = f"" return cleaned elif isinstance(obj, list): cleaned = [] for item in obj: try: json.dumps(item) cleaned.append(self._clean_non_serializable_data(item)) except (TypeError, ValueError): cleaned.append(f"") return cleaned else: return obj def log_llm_interaction(self, state: Dict, request_data: Dict, response_data: Dict, model: str, turn: int) -> None: """Log complete LLM API interaction""" timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() interaction = { "timestamp": timestamp, "turn": turn, "model": model, "request": { "messages_count": len(request_data.get("messages", [])), "tools_count": len(request_data.get("tools", [])), "model": request_data.get("model"), "tool_choice": request_data.get("tool_choice") }, "response": { "content": response_data.get("choices", [{}])[0].get("message", {}).get("content"), "tool_calls": response_data.get("choices", [{}])[0].get("message", {}).get("tool_calls"), "finish_reason": response_data.get("choices", [{}])[0].get("finish_reason"), "usage": response_data.get("usage") } } if "llm_interactions" not in state: state["llm_interactions"] = [] state["llm_interactions"].append(interaction) # Log Phoenix session information for easy debugging logger.debug(f"Logged LLM interaction for turn {turn} in session {self.session_id}") logger.debug(f"Phoenix session tracking: session_id={self.session_id}, turn={turn}, model={model}") # Log usage information if available for monitoring usage = response_data.get("usage") if usage: logger.info(f"Session {self.session_id} turn {turn}: " f"prompt_tokens={usage.get('prompt_tokens', 0)}, " f"completion_tokens={usage.get('completion_tokens', 0)}, " f"total_tokens={usage.get('total_tokens', 0)}") def log_tool_execution(self, state: Dict, tool_call_id: str, tool_name: str, tool_args: Dict, result: str, execution_data: Any = None) -> None: """Log tool execution with full details""" timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() # Safely serialize execution_data to prevent JSON corruption safe_execution_data = None if execution_data is not None: try: # Convert execution_data to a safe, serializable format if hasattr(execution_data, '__dict__'): safe_execution_data = { "type": type(execution_data).__name__, "error": str(execution_data.error) if hasattr(execution_data, 'error') and execution_data.error else None, "has_results": hasattr(execution_data, 'results') and bool(execution_data.results), "has_stdout": hasattr(execution_data, 'logs') and hasattr(execution_data.logs, 'stdout') and bool(execution_data.logs.stdout), "has_stderr": hasattr(execution_data, 'logs') and hasattr(execution_data.logs, 'stderr') and bool(execution_data.logs.stderr) } else: # For simple types, convert to string safely safe_execution_data = str(execution_data)[:200] # Limit length except Exception as e: logger.warning(f"Failed to serialize execution_data for {tool_call_id}: {e}") safe_execution_data = {"serialization_error": str(e)} tool_execution = { "timestamp": timestamp, "tool_call_id": tool_call_id, "tool_name": tool_name, "arguments": tool_args, "result_summary": result[:500] + "..." if len(result) > 500 else result, "result_length": len(result), "execution_data": safe_execution_data, "success": execution_data is None or (hasattr(execution_data, 'error') and execution_data.error is None) if execution_data else True } if "tool_executions" not in state: state["tool_executions"] = [] state["tool_executions"].append(tool_execution) # Update stats if tool_name == "add_and_execute_jupyter_code_cell": state["session_stats"]["total_code_executions"] = state["session_stats"].get("total_code_executions", 0) + 1 elif tool_name == "web_search": state["session_stats"]["total_searches"] = state["session_stats"].get("total_searches", 0) + 1 if not tool_execution["success"]: state["session_stats"]["total_errors"] = state["session_stats"].get("total_errors", 0) + 1 logger.debug(f"Logged tool execution {tool_name} ({tool_call_id}) in session {self.session_id}") def add_message(self, state: Dict, role: str, content: str, tool_calls: List = None, tool_call_id: str = None, raw_execution: Any = None, metadata: Dict = None) -> None: """Add message to conversation history with full context""" timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() message = { "role": role, "content": content, "timestamp": timestamp } if tool_calls: message["tool_calls"] = tool_calls if tool_call_id: message["tool_call_id"] = tool_call_id if raw_execution: message["raw_execution"] = raw_execution if metadata: message["metadata"] = metadata state["conversation_history"].append(message) logger.debug(f"Added {role} message to session {self.session_id} conversation history") def update_execution_state(self, state: Dict, **kwargs) -> None: """Update execution state fields""" for key, value in kwargs.items(): if key in state["execution_state"]: state["execution_state"][key] = value logger.debug(f"Updated execution state {key}={value} for session {self.session_id}") # Try to sync with global EXECUTION_STATES for UI consistency (if available) try: import sys if 'app' in sys.modules: execution_states = getattr(sys.modules['app'], 'EXECUTION_STATES', None) if execution_states and self.session_id in execution_states: for key, value in kwargs.items(): execution_states[self.session_id][key] = value except (ImportError, AttributeError): pass # Ignore if we can't sync with global state def update_notebook_data(self, state: Dict, notebook_data: Dict) -> None: """Update notebook data in session state""" state["notebook_data"] = notebook_data logger.debug(f"Updated notebook data for session {self.session_id} ({len(notebook_data.get('cells', []))} cells)") def get_conversation_history(self, state: Dict) -> List[Dict]: """Get conversation history suitable for LLM API calls""" return state.get("conversation_history", []) def validate_and_repair_conversation(self, state: Dict) -> None: """Validate and repair conversation history to ensure tool calls have responses""" conversation = state.get("conversation_history", []) if not conversation: return pending_tool_calls = set() valid_messages = [] for message in conversation: if message.get("role") == "assistant" and message.get("tool_calls"): # Track tool calls for tool_call in message["tool_calls"]: pending_tool_calls.add(tool_call["id"]) valid_messages.append(message) elif message.get("role") == "tool" and message.get("tool_call_id"): # Remove from pending when we find a response pending_tool_calls.discard(message["tool_call_id"]) valid_messages.append(message) else: # Regular message (system, user, assistant without tool calls) valid_messages.append(message) # If there are incomplete tool calls, remove the assistant messages that created them if pending_tool_calls: logger.warning(f"Found incomplete tool calls in conversation: {pending_tool_calls}") logger.warning("Removing incomplete assistant messages to repair conversation") repaired_messages = [] for message in valid_messages: if (message.get("role") == "assistant" and message.get("tool_calls") and any(tc["id"] in pending_tool_calls for tc in message["tool_calls"])): logger.debug("Removing assistant message with incomplete tool calls") continue repaired_messages.append(message) # Update conversation history state["conversation_history"] = repaired_messages logger.info(f"Repaired conversation: {len(conversation)} -> {len(repaired_messages)} messages") # Save the repaired state self.save_state(state) def session_exists(self) -> bool: """Check if session state file exists""" return self.state_file.exists() def get_session_summary(self, state: Dict) -> str: """Get human-readable session summary""" stats = state.get("session_stats", {}) created = datetime.datetime.fromisoformat(state["created_at"]) return f"""Session {self.session_id}: - Created: {created.strftime('%Y-%m-%d %H:%M:%S UTC')} - Messages: {stats.get('total_messages', 0)} - Code Executions: {stats.get('total_code_executions', 0)} - Web Searches: {stats.get('total_searches', 0)} - Errors: {stats.get('total_errors', 0)} - Duration: {stats.get('session_duration_seconds', 0)}s - Hardware: {state.get('hardware_config', {}).get('gpu_type', 'unknown')} - Model: {state.get('api_config', {}).get('model_name', 'unknown')}""" def execute_code(sbx, code): logger.debug(f"Executing code in sandbox ({len(code)} characters)") execution = sbx.run_code(code, on_stdout=lambda data: logger.debug(f'stdout: {data}')) output = "" if len(execution.logs.stdout) > 0: output += "\n".join(execution.logs.stdout) logger.debug(f"Execution produced {len(execution.logs.stdout)} stdout lines") if len(execution.logs.stderr) > 0: output += "\n".join(execution.logs.stderr) logger.debug(f"Execution produced {len(execution.logs.stderr)} stderr lines") if execution.error is not None: output += execution.error.traceback logger.warning(f"Execution error: {execution.error.name}: {execution.error.value}") logger.debug(f"Code execution completed, output length: {len(output)}") return output, execution def parse_exec_result_llm(execution, max_code_output=1000): logger.debug(f"Parsing execution result for LLM (max_output: {max_code_output})") output = [] def truncate_if_needed(text): if len(text) > max_code_output: return (text[:max_code_output] + f"\n[Output is truncated as it is more than {max_code_output} characters]") return text if execution.results: results_text_parts = [] plot_count = 0 for result in execution.results: if hasattr(result, 'text') and result.text: results_text_parts.append(result.text) elif hasattr(result, 'png') and result.png: plot_count += 1 results_text_parts.append(f"[Plot {plot_count} generated and displayed]") elif hasattr(result, 'html') and result.html: results_text_parts.append("[HTML output generated]") if results_text_parts: results_text = "\n".join(results_text_parts) output.append(truncate_if_needed(results_text)) logger.debug(f"Added {len(execution.results)} execution results (including {plot_count} plots)") if execution.logs.stdout: stdout_text = "\n".join(execution.logs.stdout) output.append(truncate_if_needed(stdout_text)) logger.debug(f"Added stdout output ({len(execution.logs.stdout)} lines)") if execution.logs.stderr: stderr_text = "\n".join(execution.logs.stderr) output.append(truncate_if_needed(stderr_text)) logger.debug(f"Added stderr output ({len(execution.logs.stderr)} lines)") if execution.error is not None: output.append(truncate_if_needed(execution.error.traceback)) logger.debug(f"Added error traceback: {execution.error.name}") final_output = "\n".join(output) logger.debug(f"Parsed execution result for LLM: {len(final_output)} characters") return final_output def clean_messages_for_api(messages): """ Create a clean copy of messages without raw_execution fields and metadata for API calls. Also validates that tool calls have corresponding tool responses. This prevents 413 errors and API validation errors. """ logger.debug(f"Cleaning {len(messages)} messages for API call") cleaned_messages = [] raw_execution_count = 0 metadata_count = 0 pending_tool_calls = set() for message in messages: cleaned_message = message.copy() # Remove raw_execution data if "raw_execution" in cleaned_message: cleaned_message.pop("raw_execution") raw_execution_count += 1 # Remove metadata and timestamp if "metadata" in cleaned_message: cleaned_message.pop("metadata") metadata_count += 1 if "timestamp" in cleaned_message: cleaned_message.pop("timestamp") # Track tool calls and responses for validation if cleaned_message.get("role") == "assistant" and cleaned_message.get("tool_calls"): for tool_call in cleaned_message["tool_calls"]: pending_tool_calls.add(tool_call["id"]) elif cleaned_message.get("role") == "tool" and cleaned_message.get("tool_call_id"): pending_tool_calls.discard(cleaned_message["tool_call_id"]) cleaned_messages.append(cleaned_message) # If there are pending tool calls without responses, remove the assistant message with tool calls if pending_tool_calls: logger.warning(f"Found {len(pending_tool_calls)} tool calls without responses: {pending_tool_calls}") logger.warning("Removing incomplete tool call messages to prevent API errors") # Remove messages with incomplete tool calls filtered_messages = [] for message in cleaned_messages: if (message.get("role") == "assistant" and message.get("tool_calls") and any(tc["id"] in pending_tool_calls for tc in message["tool_calls"])): logger.debug("Removing assistant message with incomplete tool calls") continue filtered_messages.append(message) cleaned_messages = filtered_messages logger.debug(f"Cleaned messages: removed raw_execution from {raw_execution_count}, metadata from {metadata_count}") logger.debug(f"Final cleaned message count: {len(cleaned_messages)}") return cleaned_messages def web_search(query): """ Perform web search using Tavily API with automatic year addition and formatting. Args: query (str): Search query (max 400 characters) Returns: str: Formatted search results for LLM consumption """ if not tavily_client: logger.error("Tavily client not initialized - API key missing") return "❌ Search unavailable: Tavily API key not configured" # Validate query length if len(query) > 400: logger.warning(f"Query too long ({len(query)} chars), truncating to 400") query = query[:400] # Add current year to query for more recent results current_year = datetime.datetime.now().year if str(current_year) not in query: # Only add year if query has room for it year_addition = f" {current_year}" if len(query + year_addition) <= 400: query += year_addition logger.debug(f"Added current year to query: {current_year}") logger.info(f"Performing Tavily search: '{query}' ({len(query)} chars)") try: # Perform search with optimized parameters response = tavily_client.search( query=query, search_depth="basic", # Use basic for faster results max_results=5, # Limit results to avoid overwhelming context include_answer=True, # Include AI-generated answer include_raw_content=False, # Don't include raw content to save tokens include_images=False # Don't include images ) logger.info(f"Search completed: {len(response.get('results', []))} results found") # Format results for LLM consumption formatted_results = format_search_results_for_llm(response) logger.debug(f"Formatted search results: {len(formatted_results)} characters") return formatted_results except Exception as e: logger.error(f"Tavily search failed: {str(e)}") return f"❌ Search failed: {str(e)}" def format_search_results_for_llm(response): """Format Tavily search results for LLM consumption""" query = response.get('query', 'Unknown query') results = response.get('results', []) answer = response.get('answer', '') formatted = f"πŸ” **Web Search Results for:** {query}\n\n" if answer: formatted += f"**Quick Answer:** {answer}\n\n" if results: formatted += f"**Found {len(results)} relevant sources:**\n\n" for i, result in enumerate(results, 1): title = result.get('title', 'Untitled') url = result.get('url', '') content = result.get('content', '') score = result.get('score', 0) # Truncate content to reasonable length # if len(content) > 300: # content = content[:300] + "..." formatted += f"**{i}. {title}** (Relevance: {score:.2f})\n" formatted += f" πŸ”— {url}\n" formatted += f" πŸ“„ {content}\n\n" else: formatted += "No results found.\n" return formatted def run_interactive_notebook_with_session_state(client, model, session_state_manager, session_state, sbx, stop_event=None, tools=None): logger.info(f"Starting interactive notebook with session state for {session_state_manager.session_id}") # Get conversation history from session state messages = session_state_manager.get_conversation_history(session_state) notebook = JupyterNotebook(messages) # Update execution state session_state_manager.update_execution_state(session_state, is_running=True, sandbox_active=True, current_phase="initializing") # Use provided tools or default to all tools if tools is None: tools = TOOLS try: sbx_info = sbx.get_info() notebook.add_sandbox_countdown(sbx_info.started_at, sbx_info.end_at) # Store sandbox info in session state session_state["execution_state"]["sandbox_info"] = { "started_at": sbx_info.started_at.isoformat(), "end_at": sbx_info.end_at.isoformat(), "timeout_seconds": int((sbx_info.end_at - sbx_info.started_at).total_seconds()) } logger.debug(f"Added sandbox countdown: {sbx_info.started_at} to {sbx_info.end_at}") except Exception as e: logger.warning(f"Failed to get sandbox info: {str(e)}") logger.debug("Initial notebook yield in 'generating' mode") # Update notebook data in session state session_state_manager.update_notebook_data(session_state, notebook.data) # Save initial state session_state_manager.save_state(session_state) yield notebook.render(mode="generating"), notebook.data, messages max_code_output = 1000 turns = session_state["execution_state"]["current_turn"] done = False previous_execution_had_error = False previous_execution_had_warnings = False logger.info(f"Starting interactive loop from turn {turns} with max_output={max_code_output}, max_turns={MAX_TURNS}") while not done and (turns <= MAX_TURNS) and (stop_event is None or not stop_event.is_set()): turns += 1 logger.info(f"Starting turn {turns}/{MAX_TURNS}") try: # Update phase to generating session_state_manager.update_execution_state(session_state, current_phase="generating") # Refresh messages from session state before API call messages = session_state_manager.get_conversation_history(session_state) logger.debug(f"Making API call to {model} with {len(messages)} messages") # Prepare request data for logging request_data = { "messages": clean_messages_for_api(messages), "model": model, "tools": tools, "tool_choice": "auto" } # Prepare session metadata for Phoenix tracing session_metadata = { "turn": turns, "max_turns": MAX_TURNS, "model": model, "tools_count": len(tools), "messages_count": len(messages), "current_phase": "generating" } # Add hardware config if available if "hardware_config" in session_state: hw_config = session_state["hardware_config"] session_metadata.update({ "gpu_type": hw_config.get("gpu_type", "unknown"), "cpu_cores": hw_config.get("cpu_cores", "unknown"), "memory_gb": hw_config.get("memory_gb", "unknown") }) # Wrap OpenAI API call with Phoenix session context for proper grouping with create_phoenix_session_context( session_id=session_state_manager.session_id, user_id=None, # Could be extracted from request context if available metadata=session_metadata ): logger.debug(f"Making OpenAI API call with Phoenix session context: {session_state_manager.session_id}") response = client.chat.completions.create(**request_data) logger.debug("API call successful within Phoenix session context") # Log the complete LLM interaction session_state_manager.log_llm_interaction( session_state, request_data, response.model_dump(), model, turns ) except Exception as e: # Handle inference client errors logger.error(f"Inference failed on turn {turns}: {str(e)}") # Add detailed error information to the notebook error_message = str(e) if "429" in error_message or "too_many_requests" in error_message.lower(): detailed_error = f"""**API Rate Limit Exceeded** 🚫 The inference service has reached its rate limit. This typically means: - Too many requests have been sent in a short period - Daily quota has been exceeded - Service is temporarily overloaded **What you can try:** - Wait a few minutes and try again - If using Cerebras API, check your daily quota - Try using a different model or service - Contact support if the issue persists **Technical details:** ``` {error_message} ```""" elif "401" in error_message or "unauthorized" in error_message.lower(): detailed_error = f"""**Authentication Error** πŸ” There's an issue with API authentication: - API key might be missing or invalid - API key might have expired - Insufficient permissions **Technical details:** ``` {error_message} ```""" elif "500" in error_message or "internal" in error_message.lower(): detailed_error = f"""**Server Error** πŸ”§ The inference service encountered an internal error: - Service might be temporarily unavailable - Try again in a few moments - If the issue persists, it's likely a service-side problem **Technical details:** ``` {error_message} ```""" else: detailed_error = f"""**Inference Service Error** ⚠️ An error occurred while communicating with the AI service: **Technical details:** ``` {error_message} ``` **What you can try:** - Check your internet connection - Try again in a few moments - If the problem persists, contact support""" notebook.add_error(detailed_error) # Add error to session state session_state_manager.add_message( session_state, "assistant", detailed_error, metadata={"type": "error", "error_type": "api_error", "turn": turns} ) # Update execution state session_state_manager.update_execution_state( session_state, is_running=False, last_execution_successful=False ) # Update notebook data and save state session_state_manager.update_notebook_data(session_state, notebook.data) session_state_manager.save_state(session_state) yield notebook.render(mode="error"), notebook.data, messages return # Get the response content and tool calls full_response = response.choices[0].message.content or "" tool_calls = response.choices[0].message.tool_calls or [] logger.debug(f"Turn {turns}: Response content length: {len(full_response)}, Tool calls: {len(tool_calls)}") # Add markdown cell for assistant's thinking if full_response.strip(): logger.debug(f"Adding assistant response as markdown ({len(full_response)} chars)") notebook.add_markdown(full_response, "assistant") else: logger.debug("Skipping empty assistant response") # Handle tool calls and add assistant message to session state only if tool_calls: logger.info(f"Processing {len(tool_calls)} tool calls on turn {turns}") # Add assistant message to session state (messages will be derived from this) session_state_manager.add_message( session_state, "assistant", full_response, tool_calls=[{ "id": tc.id, "type": "function", "function": {"name": tc.function.name, "arguments": tc.function.arguments} } for tc in tool_calls], metadata={"turn": turns, "type": "thinking"} ) logger.debug(f"Added assistant message with {len(tool_calls)} tool calls to session state") elif full_response.strip(): # If no tool calls but we have content, add regular assistant message session_state_manager.add_message( session_state, "assistant", full_response, metadata={"turn": turns, "type": "thinking"} ) logger.debug("Added regular assistant message to session state") for i, tool_call in enumerate(tool_calls): logger.debug(f"Processing tool call {i+1}/{len(tool_calls)}: {tool_call.function.name}") if tool_call.function.name == "add_and_execute_jupyter_code_cell": # Update phase to executing code session_state_manager.update_execution_state(session_state, current_phase="executing_code") logger.debug(f"Processing code execution tool call: {tool_call.id}") tool_args = json.loads(tool_call.function.arguments) code = tool_args["code"] logger.debug(f"Code to execute: {len(code)} characters") # Determine if we should reuse the last cell or create a new one # Reuse if there were errors (not just warnings) in the previous execution should_reuse_cell = (previous_execution_had_error and notebook.get_last_cell_type() == "code") if should_reuse_cell: logger.info("Reusing last code cell due to previous execution error") # Update the existing cell's code instead of creating a new one notebook.update_last_code_cell(code) else: logger.debug("Creating new code cell") # Create a new cell (normal behavior) notebook.add_code(code) logger.debug("Yielding notebook in 'executing' mode") yield notebook.render(mode="executing"), notebook.data, messages try: # Check for stop event before execution if stop_event and stop_event.is_set(): logger.info("Stop event detected before code execution") stopped_message = """**Execution Stopped** ⏸️ The execution was stopped by user request before the code could run.""" notebook.add_markdown(stopped_message, "assistant") yield notebook.render(mode="stopped"), notebook.data, messages return # Execution sandbox call - might timeout logger.info("Executing code in sandbox") execution = sbx.run_code(code) notebook.append_execution(execution) # Update error and warning tracking for next iteration previous_execution_had_error = notebook.has_execution_error(execution) previous_execution_had_warnings = notebook.has_execution_warnings(execution) # Log tool execution in session state tool_args = json.loads(tool_call.function.arguments) tool_response_content = parse_exec_result_llm(execution, max_code_output=max_code_output) session_state_manager.log_tool_execution( session_state, tool_call.id, "add_and_execute_jupyter_code_cell", tool_args, tool_response_content, execution ) if previous_execution_had_error: logger.warning("Code execution resulted in error") elif previous_execution_had_warnings: logger.info("Code execution completed with warnings") else: logger.info("Code execution completed successfully") except Exception as e: # Handle sandbox timeout/execution errors logger.error(f"Code execution failed: {str(e)}") # Add detailed error information for code execution failures error_message = str(e) if "timeout" in error_message.lower(): detailed_error = f"""**Code Execution Timeout** ⏰ The code execution took too long and was terminated: - Code may have entered an infinite loop - Processing large datasets can cause timeouts - Complex computations may exceed time limits **What you can try:** - Optimize your code for better performance - Break down complex operations into smaller steps - Increase the timeout limit in settings - Check for infinite loops or blocking operations **Technical details:** ``` {error_message} ```""" else: detailed_error = f"""**Code Execution Failed** πŸ’₯ An error occurred while executing the code in the sandbox: **Technical details:** ``` {error_message} ``` **What you can try:** - Check the code for syntax errors - Verify all required packages are available - Try simplifying the code - Check the sandbox logs for more details""" notebook.add_error(detailed_error) yield notebook.render(mode="error"), notebook.data, messages return # Prepare tool response (already computed above) raw_execution = notebook.parse_exec_result_nb(execution) logger.debug(f"Tool response: {len(tool_response_content)} chars content, {len(raw_execution)} raw outputs") # Add tool response to session state only session_state_manager.add_message( session_state, "tool", tool_response_content, tool_call_id=tool_call.id, raw_execution=raw_execution, metadata={"turn": turns, "execution_successful": not previous_execution_had_error} ) elif tool_call.function.name == "web_search": # Update phase to searching session_state_manager.update_execution_state(session_state, current_phase="searching") logger.debug(f"Processing search tool call: {tool_call.id}") tool_args = json.loads(tool_call.function.arguments) query = tool_args["query"] logger.debug(f"Search query: '{query}' ({len(query)} chars)") # Add search status to notebook notebook.add_markdown("πŸ” **Searching the web...**", "assistant") yield notebook.render(mode="generating"), notebook.data, messages try: # Perform search search_results = web_search(query) logger.info("Search completed successfully") # Log search tool execution tool_args = json.loads(tool_call.function.arguments) session_state_manager.log_tool_execution( session_state, tool_call.id, "web_search", tool_args, search_results ) # Add search results to notebook notebook.add_markdown(search_results, "assistant") # Add tool response to session state only session_state_manager.add_message( session_state, "tool", search_results, tool_call_id=tool_call.id, metadata={"turn": turns, "search_successful": True} ) except Exception as e: error_message = f"❌ Search failed: {str(e)}" logger.error(f"Search tool call failed: {str(e)}") # Log failed search tool_args = json.loads(tool_call.function.arguments) session_state_manager.log_tool_execution( session_state, tool_call.id, "web_search", tool_args, error_message ) # Add error to notebook notebook.add_markdown(error_message, "assistant") # Add error response to session state only session_state_manager.add_message( session_state, "tool", error_message, tool_call_id=tool_call.id, metadata={"turn": turns, "search_successful": False, "error": str(e)} ) elif tool_call.function.name == "edit_and_execute_current_cell": # Update phase to executing code session_state_manager.update_execution_state(session_state, current_phase="executing_code") logger.debug(f"Processing edit current cell tool call: {tool_call.id}") tool_args = json.loads(tool_call.function.arguments) code = tool_args["code"] logger.debug(f"Code to execute in current cell: {len(code)} characters") # Check if we have a code cell to edit if notebook.get_last_cell_type() == "code": logger.info("Editing last code cell with new code") notebook.update_last_code_cell(code) else: logger.info("No code cell to edit, creating new cell") notebook.add_code(code) logger.debug("Yielding notebook in 'executing' mode") yield notebook.render(mode="executing"), notebook.data, messages try: # Check for stop event before execution if stop_event and stop_event.is_set(): logger.info("Stop event detected before code execution") stopped_message = """**Execution Stopped** ⏸️ The execution was stopped by user request before the code could run.""" notebook.add_markdown(stopped_message, "assistant") yield notebook.render(mode="stopped"), notebook.data, messages return # Execution sandbox call - might timeout logger.info("Executing edited code in sandbox") execution = sbx.run_code(code) notebook.append_execution(execution) # Update error and warning tracking for next iteration previous_execution_had_error = notebook.has_execution_error(execution) previous_execution_had_warnings = notebook.has_execution_warnings(execution) # Log tool execution in session state tool_response_content = parse_exec_result_llm(execution, max_code_output=max_code_output) session_state_manager.log_tool_execution( session_state, tool_call.id, "edit_and_execute_current_cell", tool_args, tool_response_content, execution ) if previous_execution_had_error: logger.warning("Edited code execution resulted in error") elif previous_execution_had_warnings: logger.info("Edited code execution completed with warnings") else: logger.info("Edited code execution completed successfully") except Exception as e: # Handle sandbox timeout/execution errors logger.error(f"Edited code execution failed: {str(e)}") # Add detailed error information for code execution failures error_message = str(e) if "timeout" in error_message.lower(): detailed_error = f"""**Code Execution Timeout** ⏰ The edited code execution took too long and was terminated: - Code may have entered an infinite loop - Processing large datasets can cause timeouts - Complex computations may exceed time limits **What you can try:** - Optimize your code for better performance - Break down complex operations into smaller steps - Increase the timeout limit in settings - Check for infinite loops or blocking operations **Technical details:** ``` {error_message} ```""" else: detailed_error = f"""**Code Execution Failed** πŸ’₯ An error occurred while executing the edited code in the sandbox: **Technical details:** ``` {error_message} ``` **What you can try:** - Check the code for syntax errors - Verify all required packages are available - Try simplifying the code - Check the sandbox logs for more details""" notebook.add_error(detailed_error) yield notebook.render(mode="error"), notebook.data, messages return # Prepare tool response raw_execution = notebook.parse_exec_result_nb(execution) logger.debug(f"Tool response: {len(tool_response_content)} chars content, {len(raw_execution)} raw outputs") # Add tool response to session state only session_state_manager.add_message( session_state, "tool", tool_response_content, tool_call_id=tool_call.id, raw_execution=raw_execution, metadata={"turn": turns, "execution_successful": not previous_execution_had_error, "action": "edit_cell"} ) elif tool_call.function.name == "execute_shell_command": # Update phase to executing shell command session_state_manager.update_execution_state(session_state, current_phase="executing_shell") logger.debug(f"Processing shell command tool call: {tool_call.id}") tool_args = json.loads(tool_call.function.arguments) command = tool_args["command"] logger.debug(f"Shell command to execute: '{command}'") # Add shell command to notebook with special styling notebook.add_shell_command(command) logger.debug("Yielding notebook in 'executing' mode") yield notebook.render(mode="executing"), notebook.data, messages try: # Check for stop event before execution if stop_event and stop_event.is_set(): logger.info("Stop event detected before shell execution") stopped_message = """**Execution Stopped** ⏸️ The execution was stopped by user request before the shell command could run.""" notebook.add_markdown(stopped_message, "assistant") yield notebook.render(mode="stopped"), notebook.data, messages return # Execute shell command in sandbox using raw shell execution logger.info(f"Executing raw shell command in sandbox: {command}") try: # Use the new raw shell execution method if hasattr(sbx, 'run_shell'): shell_execution = sbx.run_shell(command, timeout=60) logger.info("Shell command executed using raw shell method") else: # Fallback: Execute shell command using Python subprocess within sandbox shell_code = f""" import subprocess import sys try: result = subprocess.run( {repr(command)}, shell=True, capture_output=True, text=True, timeout=60 ) if result.stdout: print("STDOUT:") print(result.stdout) if result.stderr: print("STDERR:") print(result.stderr) print(f"Exit code: {{result.returncode}}") except subprocess.TimeoutExpired: print("Error: Command timed out after 60 seconds") except Exception as e: print(f"Error executing command: {{e}}") """ shell_execution = sbx.run_code(shell_code) logger.info("Shell command executed via Python subprocess fallback") # Add shell execution results to notebook notebook.append_shell_execution(shell_execution) # Prepare response content for LLM shell_response_content = parse_exec_result_llm(shell_execution, max_code_output=max_code_output) # Log tool execution in session state session_state_manager.log_tool_execution( session_state, tool_call.id, "execute_shell_command", tool_args, shell_response_content, shell_execution ) # Check for errors shell_had_error = notebook.has_execution_error(shell_execution) if shell_had_error: logger.warning("Shell command execution resulted in error") else: logger.info("Shell command execution completed successfully") except Exception as shell_error: logger.error(f"Shell command execution failed: {str(shell_error)}") # Create error message detailed_error = f"""**Shell Command Failed** πŸ”§ An error occurred while executing the shell command: **Command:** `{command}` **Technical details:** ``` {str(shell_error)} ``` **What you can try:** - Check if the command exists in the sandbox environment - Verify command syntax - Try a simpler version of the command - Check if required tools/packages are installed""" notebook.add_error(detailed_error) # Log failed execution session_state_manager.log_tool_execution( session_state, tool_call.id, "execute_shell_command", tool_args, detailed_error ) yield notebook.render(mode="error"), notebook.data, messages return except Exception as e: # Handle general execution errors logger.error(f"Shell command execution failed: {str(e)}") detailed_error = f"""**Shell Execution Error** ⚠️ An unexpected error occurred while executing the shell command: **Command:** `{command}` **Technical details:** ``` {str(e)} ```""" notebook.add_error(detailed_error) yield notebook.render(mode="error"), notebook.data, messages return # Prepare tool response for LLM and session state raw_execution = notebook.parse_exec_result_nb(shell_execution) logger.debug(f"Shell tool response: {len(shell_response_content)} chars content") # Add tool response to session state session_state_manager.add_message( session_state, "tool", shell_response_content, tool_call_id=tool_call.id, raw_execution=raw_execution, metadata={"turn": turns, "command": command, "execution_successful": not shell_had_error, "action": "shell_command"} ) else: logger.warning(f"Unknown tool call function: {tool_call.function.name}") if not tool_calls: logger.info(f"No tool calls on turn {turns}, conversation ending") if len(full_response.strip())==0: logger.error("Assistant provided no content and no tool calls") notebook.add_error(f"No tool call and empty assistant response:\n{response.model_dump_json(indent=2)}") # Only add the final assistant message if we didn't already add it above # (in the elif full_response.strip() block) if full_response.strip(): # Since we're now only using session state, we can safely add the message # The session state manager will handle any deduplication if needed session_state_manager.add_message( session_state, "assistant", full_response, metadata={"turn": turns, "type": "final_response"} ) logger.debug("Added final assistant response to session state") done = True # Update session state after each turn session_state_manager.update_execution_state( session_state, current_turn=turns, last_execution_successful=not previous_execution_had_error ) session_state_manager.update_notebook_data(session_state, notebook.data) session_state_manager.save_state(session_state) if done: logger.info(f"Interactive notebook completed after {turns} turns") session_state_manager.update_execution_state( session_state, is_running=False, sandbox_active=True ) session_state_manager.save_state(session_state) yield notebook.render(mode="done"), notebook.data, messages else: logger.debug(f"Turn {turns} completed, yielding in 'generating' mode") yield notebook.render(mode="generating"), notebook.data, messages if turns > MAX_TURNS: logger.warning(f"Interactive notebook reached maximum turns ({MAX_TURNS})") error_msg = f"**Maximum Turns Reached** πŸ”„\n\nThe conversation has reached the maximum number of turns ({MAX_TURNS}). This is a safety limit to prevent infinite loops.\n\n**What you can try:**\n- Start a new conversation\n- Clear the notebook and begin fresh\n- Contact support if you need a higher turn limit" notebook.add_error(error_msg) # Add error to session state session_state_manager.add_message( session_state, "assistant", error_msg, metadata={"type": "error", "error_type": "max_turns_exceeded", "turn": turns} ) # Update final state session_state_manager.update_execution_state( session_state, is_running=False, last_execution_successful=False ) session_state_manager.update_notebook_data(session_state, notebook.data) session_state_manager.save_state(session_state) yield notebook.render(mode="error"), notebook.data, messages elif stop_event and stop_event.is_set(): logger.info("Interactive notebook stopped by user") # Add a stopped message to the notebook stopped_message = """**Execution Stopped** ⏸️ The execution was stopped by user request. You can resume by clicking Run again.""" notebook.add_markdown(stopped_message, "assistant") # Add stopped message to session state session_state_manager.add_message( session_state, "assistant", stopped_message, metadata={"type": "status", "status_type": "stopped_by_user", "turn": turns} ) # Update state to indicate pause session_state_manager.update_execution_state( session_state, is_running=False, is_paused=True ) session_state_manager.update_notebook_data(session_state, notebook.data) session_state_manager.save_state(session_state) yield notebook.render(mode="stopped"), notebook.data, messages def run_interactive_notebook(client, model, messages, sbx, stop_event=None, tools=None): """Backward compatibility wrapper for the new session state system""" logger.warning("Using legacy run_interactive_notebook - this should be replaced with session state version") # Create a temporary session for backward compatibility import uuid temp_session_id = str(uuid.uuid4())[:8] session_manager = SessionStateManager(temp_session_id) # Create basic session state session_state = session_manager.create_initial_state( hardware_config={"gpu_type": "unknown", "cpu_cores": 2, "memory_gb": 8, "timeout_sec": 300}, api_config={"model_name": model, "provider_type": "unknown"}, environment={"variables": "", "files_uploaded": []}, system_prompt=messages[0].get("content", "") if messages and messages[0].get("role") == "system" else "" ) # Initialize conversation history with provided messages session_state["conversation_history"] = messages # Use the new session-based function yield from run_interactive_notebook_with_session_state( client, model, session_manager, session_state, sbx, stop_event, tools )