import datetime
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Any, Optional

from tavily import TavilyClient

from jupyter_handler import JupyterNotebook

# Phoenix tracing imports
try:
    from openinference.instrumentation import using_session
    PHOENIX_AVAILABLE = True
    print("Phoenix session tracking imports successful")
except ImportError:
    PHOENIX_AVAILABLE = False
    print("Phoenix session tracking not available - missing openinference packages")

# Configure logging for utils module
logger = logging.getLogger(__name__)

# Initialize Tavily client
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
tavily_client = TavilyClient(api_key=TAVILY_API_KEY) if TAVILY_API_KEY else None
TOOLS = [ | |
{ | |
"type": "function", | |
"function": { | |
"name": "add_and_execute_jupyter_code_cell", | |
"description": "A Python code execution environment that runs code in a Jupyter notebook interface. This is stateful - variables and imports persist between executions.", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"code": { | |
"type": "string", | |
"description": "The Python code to execute." | |
} | |
}, | |
"required": ["code"] | |
} | |
} | |
}, | |
{ | |
"type": "function", | |
"function": { | |
"name": "edit_and_execute_current_cell", | |
"description": "Edit the current/last code cell and execute the new code. Use this to fix errors or modify the previous code instead of creating a new cell.", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"code": { | |
"type": "string", | |
"description": "The updated Python code to replace the current cell with and execute." | |
} | |
}, | |
"required": ["code"] | |
} | |
} | |
}, | |
{ | |
"type": "function", | |
"function": { | |
"name": "execute_shell_command", | |
"description": "Execute shell/system commands like ls, cat, mkdir, etc. This runs independently of Python and provides terminal-style output.", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"command": { | |
"type": "string", | |
"description": "The shell command to execute (e.g., 'ls -la', 'cat file.txt', 'mkdir new_folder')." | |
} | |
}, | |
"required": ["command"] | |
} | |
} | |
}, | |
{ | |
"type": "function", | |
"function": { | |
"name": "web_search", | |
"description": "Search the web for current information, documentation, tutorials, and solutions to coding problems. Use this to get context before starting tasks or when encountering errors.", | |
"parameters": { | |
"type": "object", | |
"properties": { | |
"query": { | |
"type": "string", | |
"description": "Search query (max 400 characters). Be specific and include relevant keywords." | |
} | |
}, | |
"required": ["query"] | |
} | |
} | |
}, | |
] | |
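
# The schemas above use the OpenAI function-calling format, so TOOLS can be passed
# straight to a chat completion request and the returned tool calls dispatched by
# name, as the main loop below does. A minimal sketch (client and model are
# placeholders):
#
#     response = client.chat.completions.create(
#         model=model, messages=messages, tools=TOOLS, tool_choice="auto"
#     )
#     for tool_call in response.choices[0].message.tool_calls or []:
#         args = json.loads(tool_call.function.arguments)  # e.g. {"code": "print(1)"}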
# TOOLS = TOOLS[:1] | |
MAX_TURNS = 20 | |
def create_phoenix_session_context(session_id: str, user_id: str = None, metadata: Dict = None):
    """
    Create a Phoenix session context for tracing LLM interactions.

    Args:
        session_id: Unique identifier for the session
        user_id: Optional user identifier (accepted for future use; not yet attached to traces)
        metadata: Additional metadata (accepted for future use; not yet attached to traces)

    Returns:
        Context manager that groups LLM calls under the session, or a no-op
        context manager if Phoenix is unavailable or session creation fails.
    """
    if not PHOENIX_AVAILABLE:
        # Return a no-op context manager if Phoenix is not available
        from contextlib import nullcontext
        return nullcontext()
    try:
        # Use using_session for proper session grouping in Phoenix.
        # This ensures all LLM calls within this context are grouped under the same session.
        logger.debug(f"Creating Phoenix session context for session_id: {session_id}")
        return using_session(session_id)
    except Exception as e:
        logger.warning(f"Failed to create Phoenix session context for {session_id}: {e}")
        # Fall back to a no-op context if Phoenix session creation fails
        from contextlib import nullcontext
        return nullcontext()
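
# Example (sketch): wrap an LLM call so Phoenix groups it under one session.
# `openai_client` is a placeholder for whatever OpenAI-compatible client is in use.
#
#     with create_phoenix_session_context(session_id="session-123"):
#         response = openai_client.chat.completions.create(...)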
class SessionStateManager: | |
"""Manages comprehensive session state in a single JSON file""" | |
def __init__(self, session_id: str, base_dir: str = './temp/'): | |
self.session_id = session_id | |
self.base_dir = Path(base_dir) | |
self.session_dir = self.base_dir / session_id | |
self.state_file = self.session_dir / 'session_state.json' | |
self.session_dir.mkdir(parents=True, exist_ok=True) | |
logger.info(f"SessionStateManager initialized for {session_id}") | |
def create_initial_state(self, hardware_config: Dict, api_config: Dict, | |
environment: Dict, system_prompt: str) -> Dict: | |
"""Create initial session state structure""" | |
timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() | |
initial_state = { | |
"session_id": self.session_id, | |
"created_at": timestamp, | |
"last_updated": timestamp, | |
"version": "1.0", | |
"hardware_config": hardware_config, | |
"api_config": api_config, | |
"environment": environment, | |
"conversation_history": [ | |
{ | |
"role": "system", | |
"content": system_prompt, | |
"timestamp": timestamp, | |
"metadata": {"type": "system_initialization"} | |
} | |
], | |
"llm_interactions": [], # Complete API call logs | |
"tool_executions": [], # All tool calls and results | |
"notebook_data": { | |
"cells": [], | |
"metadata": { | |
"kernel_info": {"name": "python3"}, | |
"language_info": {"name": "python", "version": "3.12"}, | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
}, | |
"execution_state": { | |
"current_turn": 0, | |
"max_turns": MAX_TURNS, | |
"is_running": False, | |
"is_paused": False, | |
"last_execution_successful": None, | |
"sandbox_active": False, | |
"sandbox_info": None | |
}, | |
"session_stats": { | |
"total_messages": 1, | |
"total_code_executions": 0, | |
"total_searches": 0, | |
"total_errors": 0, | |
"session_duration_seconds": 0 | |
} | |
} | |
logger.info("Created initial session state for %s", self.session_id) | |
return initial_state | |
def load_state(self) -> Optional[Dict]: | |
"""Load session state from file with improved error handling""" | |
if not self.state_file.exists(): | |
logger.info(f"No existing session state found for {self.session_id}") | |
return None | |
try: | |
with open(self.state_file, 'r', encoding='utf-8') as f: | |
state = json.load(f) | |
logger.info(f"Loaded session state for {self.session_id} with {len(state.get('conversation_history', []))} messages") | |
return state | |
except json.JSONDecodeError as e: | |
logger.error(f"JSON corruption in session state for {self.session_id}: {str(e)}") | |
logger.info(f"Creating backup of corrupted file: {self.state_file}.corrupted") | |
try: | |
import shutil | |
shutil.copy2(self.state_file, str(self.state_file) + ".corrupted") | |
logger.info(f"Backup created successfully") | |
except Exception as backup_error: | |
logger.warning(f"Failed to create backup: {backup_error}") | |
return None | |
except Exception as e: | |
logger.error(f"Failed to load session state for {self.session_id}: {str(e)}") | |
return None | |
def save_state(self, state: Dict) -> bool: | |
"""Save session state to file with improved error handling""" | |
try: | |
# Update last_updated timestamp | |
state["last_updated"] = datetime.datetime.now(datetime.timezone.utc).isoformat() | |
# Update session stats | |
if "session_stats" not in state: | |
state["session_stats"] = {} | |
created_at = datetime.datetime.fromisoformat(state["created_at"]) | |
current_time = datetime.datetime.now(datetime.timezone.utc) | |
state["session_stats"]["session_duration_seconds"] = int((current_time - created_at).total_seconds()) | |
state["session_stats"]["total_messages"] = len(state.get("conversation_history", [])) | |
# Validate JSON serializability before writing | |
try: | |
json.dumps(state, ensure_ascii=False) | |
except (TypeError, ValueError) as e: | |
logger.error(f"State contains non-serializable data: {e}") | |
logger.info("Attempting to clean non-serializable data...") | |
state = self._clean_non_serializable_data(state) | |
# Write to temporary file first, then rename for atomic operation | |
temp_file = self.state_file.with_suffix('.tmp') | |
with open(temp_file, 'w', encoding='utf-8') as f: | |
json.dump(state, f, indent=2, ensure_ascii=False) | |
# Atomic rename | |
temp_file.replace(self.state_file) | |
logger.debug(f"Saved session state for {self.session_id} ({len(json.dumps(state))} characters)") | |
return True | |
except Exception as e: | |
logger.error(f"Failed to save session state for {self.session_id}: {str(e)}") | |
# Clean up temp file if it exists | |
temp_file = self.state_file.with_suffix('.tmp') | |
if temp_file.exists(): | |
try: | |
temp_file.unlink() | |
except Exception: | |
pass | |
return False | |
def _clean_non_serializable_data(self, obj): | |
"""Recursively clean non-serializable data from objects""" | |
if isinstance(obj, dict): | |
cleaned = {} | |
for key, value in obj.items(): | |
try: | |
json.dumps(value) | |
cleaned[key] = self._clean_non_serializable_data(value) | |
except (TypeError, ValueError): | |
logger.warning(f"Removing non-serializable field: {key}") | |
cleaned[key] = f"<non-serializable: {type(value).__name__}>" | |
return cleaned | |
elif isinstance(obj, list): | |
cleaned = [] | |
for item in obj: | |
try: | |
json.dumps(item) | |
cleaned.append(self._clean_non_serializable_data(item)) | |
except (TypeError, ValueError): | |
cleaned.append(f"<non-serializable: {type(item).__name__}>") | |
return cleaned | |
else: | |
return obj | |
def log_llm_interaction(self, state: Dict, request_data: Dict, response_data: Dict, | |
model: str, turn: int) -> None: | |
"""Log complete LLM API interaction""" | |
timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() | |
interaction = { | |
"timestamp": timestamp, | |
"turn": turn, | |
"model": model, | |
"request": { | |
"messages_count": len(request_data.get("messages", [])), | |
"tools_count": len(request_data.get("tools", [])), | |
"model": request_data.get("model"), | |
"tool_choice": request_data.get("tool_choice") | |
}, | |
"response": { | |
"content": response_data.get("choices", [{}])[0].get("message", {}).get("content"), | |
"tool_calls": response_data.get("choices", [{}])[0].get("message", {}).get("tool_calls"), | |
"finish_reason": response_data.get("choices", [{}])[0].get("finish_reason"), | |
"usage": response_data.get("usage") | |
} | |
} | |
if "llm_interactions" not in state: | |
state["llm_interactions"] = [] | |
state["llm_interactions"].append(interaction) | |
# Log Phoenix session information for easy debugging | |
logger.debug(f"Logged LLM interaction for turn {turn} in session {self.session_id}") | |
logger.debug(f"Phoenix session tracking: session_id={self.session_id}, turn={turn}, model={model}") | |
# Log usage information if available for monitoring | |
usage = response_data.get("usage") | |
if usage: | |
logger.info(f"Session {self.session_id} turn {turn}: " | |
f"prompt_tokens={usage.get('prompt_tokens', 0)}, " | |
f"completion_tokens={usage.get('completion_tokens', 0)}, " | |
f"total_tokens={usage.get('total_tokens', 0)}") | |
def log_tool_execution(self, state: Dict, tool_call_id: str, tool_name: str, | |
tool_args: Dict, result: str, execution_data: Any = None) -> None: | |
"""Log tool execution with full details""" | |
timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() | |
# Safely serialize execution_data to prevent JSON corruption | |
safe_execution_data = None | |
if execution_data is not None: | |
try: | |
# Convert execution_data to a safe, serializable format | |
if hasattr(execution_data, '__dict__'): | |
safe_execution_data = { | |
"type": type(execution_data).__name__, | |
"error": str(execution_data.error) if hasattr(execution_data, 'error') and execution_data.error else None, | |
"has_results": hasattr(execution_data, 'results') and bool(execution_data.results), | |
"has_stdout": hasattr(execution_data, 'logs') and hasattr(execution_data.logs, 'stdout') and bool(execution_data.logs.stdout), | |
"has_stderr": hasattr(execution_data, 'logs') and hasattr(execution_data.logs, 'stderr') and bool(execution_data.logs.stderr) | |
} | |
else: | |
# For simple types, convert to string safely | |
safe_execution_data = str(execution_data)[:200] # Limit length | |
except Exception as e: | |
logger.warning(f"Failed to serialize execution_data for {tool_call_id}: {e}") | |
safe_execution_data = {"serialization_error": str(e)} | |
        tool_execution = {
            "timestamp": timestamp,
            "tool_call_id": tool_call_id,
            "tool_name": tool_name,
            "arguments": tool_args,
            "result_summary": result[:500] + "..." if len(result) > 500 else result,
            "result_length": len(result),
            "execution_data": safe_execution_data,
            # Successful if no execution object was supplied, or if the execution
            # finished without an error set on it.
            "success": execution_data is None
                       or (hasattr(execution_data, "error") and execution_data.error is None),
        }
if "tool_executions" not in state: | |
state["tool_executions"] = [] | |
state["tool_executions"].append(tool_execution) | |
# Update stats | |
if tool_name == "add_and_execute_jupyter_code_cell": | |
state["session_stats"]["total_code_executions"] = state["session_stats"].get("total_code_executions", 0) + 1 | |
elif tool_name == "web_search": | |
state["session_stats"]["total_searches"] = state["session_stats"].get("total_searches", 0) + 1 | |
if not tool_execution["success"]: | |
state["session_stats"]["total_errors"] = state["session_stats"].get("total_errors", 0) + 1 | |
logger.debug(f"Logged tool execution {tool_name} ({tool_call_id}) in session {self.session_id}") | |
def add_message(self, state: Dict, role: str, content: str, | |
tool_calls: List = None, tool_call_id: str = None, | |
raw_execution: Any = None, metadata: Dict = None) -> None: | |
"""Add message to conversation history with full context""" | |
timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() | |
message = { | |
"role": role, | |
"content": content, | |
"timestamp": timestamp | |
} | |
if tool_calls: | |
message["tool_calls"] = tool_calls | |
if tool_call_id: | |
message["tool_call_id"] = tool_call_id | |
if raw_execution: | |
message["raw_execution"] = raw_execution | |
if metadata: | |
message["metadata"] = metadata | |
state["conversation_history"].append(message) | |
logger.debug(f"Added {role} message to session {self.session_id} conversation history") | |
def update_execution_state(self, state: Dict, **kwargs) -> None: | |
"""Update execution state fields""" | |
for key, value in kwargs.items(): | |
if key in state["execution_state"]: | |
state["execution_state"][key] = value | |
logger.debug(f"Updated execution state {key}={value} for session {self.session_id}") | |
# Try to sync with global EXECUTION_STATES for UI consistency (if available) | |
try: | |
import sys | |
if 'app' in sys.modules: | |
execution_states = getattr(sys.modules['app'], 'EXECUTION_STATES', None) | |
if execution_states and self.session_id in execution_states: | |
for key, value in kwargs.items(): | |
execution_states[self.session_id][key] = value | |
except (ImportError, AttributeError): | |
pass # Ignore if we can't sync with global state | |
def update_notebook_data(self, state: Dict, notebook_data: Dict) -> None: | |
"""Update notebook data in session state""" | |
state["notebook_data"] = notebook_data | |
logger.debug(f"Updated notebook data for session {self.session_id} ({len(notebook_data.get('cells', []))} cells)") | |
def get_conversation_history(self, state: Dict) -> List[Dict]: | |
"""Get conversation history suitable for LLM API calls""" | |
return state.get("conversation_history", []) | |
def validate_and_repair_conversation(self, state: Dict) -> None: | |
"""Validate and repair conversation history to ensure tool calls have responses""" | |
conversation = state.get("conversation_history", []) | |
if not conversation: | |
return | |
pending_tool_calls = set() | |
valid_messages = [] | |
for message in conversation: | |
if message.get("role") == "assistant" and message.get("tool_calls"): | |
# Track tool calls | |
for tool_call in message["tool_calls"]: | |
pending_tool_calls.add(tool_call["id"]) | |
valid_messages.append(message) | |
elif message.get("role") == "tool" and message.get("tool_call_id"): | |
# Remove from pending when we find a response | |
pending_tool_calls.discard(message["tool_call_id"]) | |
valid_messages.append(message) | |
else: | |
# Regular message (system, user, assistant without tool calls) | |
valid_messages.append(message) | |
# If there are incomplete tool calls, remove the assistant messages that created them | |
if pending_tool_calls: | |
logger.warning(f"Found incomplete tool calls in conversation: {pending_tool_calls}") | |
logger.warning("Removing incomplete assistant messages to repair conversation") | |
repaired_messages = [] | |
for message in valid_messages: | |
if (message.get("role") == "assistant" and | |
message.get("tool_calls") and | |
any(tc["id"] in pending_tool_calls for tc in message["tool_calls"])): | |
logger.debug("Removing assistant message with incomplete tool calls") | |
continue | |
repaired_messages.append(message) | |
# Update conversation history | |
state["conversation_history"] = repaired_messages | |
logger.info(f"Repaired conversation: {len(conversation)} -> {len(repaired_messages)} messages") | |
# Save the repaired state | |
self.save_state(state) | |
def session_exists(self) -> bool: | |
"""Check if session state file exists""" | |
return self.state_file.exists() | |
def get_session_summary(self, state: Dict) -> str: | |
"""Get human-readable session summary""" | |
stats = state.get("session_stats", {}) | |
created = datetime.datetime.fromisoformat(state["created_at"]) | |
return f"""Session {self.session_id}: | |
- Created: {created.strftime('%Y-%m-%d %H:%M:%S UTC')} | |
- Messages: {stats.get('total_messages', 0)} | |
- Code Executions: {stats.get('total_code_executions', 0)} | |
- Web Searches: {stats.get('total_searches', 0)} | |
- Errors: {stats.get('total_errors', 0)} | |
- Duration: {stats.get('session_duration_seconds', 0)}s | |
- Hardware: {state.get('hardware_config', {}).get('gpu_type', 'unknown')} | |
- Model: {state.get('api_config', {}).get('model_name', 'unknown')}""" | |
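
# Typical lifecycle of SessionStateManager (sketch; the config dicts are illustrative):
#
#     manager = SessionStateManager("session-123")
#     state = manager.load_state()
#     if state is None:
#         state = manager.create_initial_state(
#             hardware_config={"gpu_type": "cpu"},
#             api_config={"model_name": "my-model"},
#             environment={},
#             system_prompt="You are a helpful data analyst.",
#         )
#     manager.add_message(state, "user", "Plot a sine wave.")
#     manager.save_state(state)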
def execute_code(sbx, code): | |
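    """Run a code string in the sandbox and return (combined_output, execution).

    The returned string concatenates stdout, stderr, and any error traceback so it
    can be logged or fed back to the model as-is.
    """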
    logger.debug(f"Executing code in sandbox ({len(code)} characters)")
    execution = sbx.run_code(code, on_stdout=lambda data: logger.debug(f'stdout: {data}'))
    output = ""
    if len(execution.logs.stdout) > 0:
        output += "\n".join(execution.logs.stdout)
        logger.debug(f"Execution produced {len(execution.logs.stdout)} stdout lines")
    if len(execution.logs.stderr) > 0:
        if output:
            output += "\n"
        output += "\n".join(execution.logs.stderr)
        logger.debug(f"Execution produced {len(execution.logs.stderr)} stderr lines")
    if execution.error is not None:
        if output:
            output += "\n"
        output += execution.error.traceback
        logger.warning(f"Execution error: {execution.error.name}: {execution.error.value}")
    logger.debug(f"Code execution completed, output length: {len(output)}")
    return output, execution
def parse_exec_result_llm(execution, max_code_output=1000): | |
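    """Summarize a sandbox execution as plain text for the LLM.

    Rich outputs (plots, HTML) are replaced with short placeholders, and each
    section is truncated to at most `max_code_output` characters.
    """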
logger.debug(f"Parsing execution result for LLM (max_output: {max_code_output})") | |
output = [] | |
def truncate_if_needed(text): | |
if len(text) > max_code_output: | |
return (text[:max_code_output] + f"\n[Output is truncated as it is more than {max_code_output} characters]") | |
return text | |
if execution.results: | |
results_text_parts = [] | |
plot_count = 0 | |
for result in execution.results: | |
if hasattr(result, 'text') and result.text: | |
results_text_parts.append(result.text) | |
elif hasattr(result, 'png') and result.png: | |
plot_count += 1 | |
results_text_parts.append(f"[Plot {plot_count} generated and displayed]") | |
elif hasattr(result, 'html') and result.html: | |
results_text_parts.append("[HTML output generated]") | |
if results_text_parts: | |
results_text = "\n".join(results_text_parts) | |
output.append(truncate_if_needed(results_text)) | |
logger.debug(f"Added {len(execution.results)} execution results (including {plot_count} plots)") | |
if execution.logs.stdout: | |
stdout_text = "\n".join(execution.logs.stdout) | |
output.append(truncate_if_needed(stdout_text)) | |
logger.debug(f"Added stdout output ({len(execution.logs.stdout)} lines)") | |
if execution.logs.stderr: | |
stderr_text = "\n".join(execution.logs.stderr) | |
output.append(truncate_if_needed(stderr_text)) | |
logger.debug(f"Added stderr output ({len(execution.logs.stderr)} lines)") | |
if execution.error is not None: | |
output.append(truncate_if_needed(execution.error.traceback)) | |
logger.debug(f"Added error traceback: {execution.error.name}") | |
final_output = "\n".join(output) | |
logger.debug(f"Parsed execution result for LLM: {len(final_output)} characters") | |
return final_output | |
def clean_messages_for_api(messages): | |
""" | |
Create a clean copy of messages without raw_execution fields and metadata for API calls. | |
Also validates that tool calls have corresponding tool responses. | |
This prevents 413 errors and API validation errors. | |
""" | |
logger.debug(f"Cleaning {len(messages)} messages for API call") | |
cleaned_messages = [] | |
raw_execution_count = 0 | |
metadata_count = 0 | |
pending_tool_calls = set() | |
for message in messages: | |
cleaned_message = message.copy() | |
# Remove raw_execution data | |
if "raw_execution" in cleaned_message: | |
cleaned_message.pop("raw_execution") | |
raw_execution_count += 1 | |
# Remove metadata and timestamp | |
if "metadata" in cleaned_message: | |
cleaned_message.pop("metadata") | |
metadata_count += 1 | |
if "timestamp" in cleaned_message: | |
cleaned_message.pop("timestamp") | |
# Track tool calls and responses for validation | |
if cleaned_message.get("role") == "assistant" and cleaned_message.get("tool_calls"): | |
for tool_call in cleaned_message["tool_calls"]: | |
pending_tool_calls.add(tool_call["id"]) | |
elif cleaned_message.get("role") == "tool" and cleaned_message.get("tool_call_id"): | |
pending_tool_calls.discard(cleaned_message["tool_call_id"]) | |
cleaned_messages.append(cleaned_message) | |
# If there are pending tool calls without responses, remove the assistant message with tool calls | |
if pending_tool_calls: | |
logger.warning(f"Found {len(pending_tool_calls)} tool calls without responses: {pending_tool_calls}") | |
logger.warning("Removing incomplete tool call messages to prevent API errors") | |
# Remove messages with incomplete tool calls | |
filtered_messages = [] | |
for message in cleaned_messages: | |
if (message.get("role") == "assistant" and | |
message.get("tool_calls") and | |
any(tc["id"] in pending_tool_calls for tc in message["tool_calls"])): | |
logger.debug("Removing assistant message with incomplete tool calls") | |
continue | |
filtered_messages.append(message) | |
cleaned_messages = filtered_messages | |
logger.debug(f"Cleaned messages: removed raw_execution from {raw_execution_count}, metadata from {metadata_count}") | |
logger.debug(f"Final cleaned message count: {len(cleaned_messages)}") | |
return cleaned_messages | |
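
# Illustrative example of the cleaning: a stored tool message such as
#     {"role": "tool", "tool_call_id": "call_1", "content": "ok",
#      "raw_execution": [...], "timestamp": "...", "metadata": {...}}
# is sent to the API as
#     {"role": "tool", "tool_call_id": "call_1", "content": "ok"}
# and any assistant message whose tool_calls never received a matching "tool"
# response is dropped entirely.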
def web_search(query): | |
""" | |
Perform web search using Tavily API with automatic year addition and formatting. | |
Args: | |
query (str): Search query (max 400 characters) | |
Returns: | |
str: Formatted search results for LLM consumption | |
""" | |
if not tavily_client: | |
logger.error("Tavily client not initialized - API key missing") | |
return "❌ Search unavailable: Tavily API key not configured" | |
# Validate query length | |
if len(query) > 400: | |
logger.warning(f"Query too long ({len(query)} chars), truncating to 400") | |
query = query[:400] | |
# Add current year to query for more recent results | |
current_year = datetime.datetime.now().year | |
if str(current_year) not in query: | |
# Only add year if query has room for it | |
year_addition = f" {current_year}" | |
if len(query + year_addition) <= 400: | |
query += year_addition | |
logger.debug(f"Added current year to query: {current_year}") | |
logger.info(f"Performing Tavily search: '{query}' ({len(query)} chars)") | |
try: | |
# Perform search with optimized parameters | |
response = tavily_client.search( | |
query=query, | |
search_depth="basic", # Use basic for faster results | |
max_results=5, # Limit results to avoid overwhelming context | |
include_answer=True, # Include AI-generated answer | |
include_raw_content=False, # Don't include raw content to save tokens | |
include_images=False # Don't include images | |
) | |
logger.info(f"Search completed: {len(response.get('results', []))} results found") | |
# Format results for LLM consumption | |
formatted_results = format_search_results_for_llm(response) | |
logger.debug(f"Formatted search results: {len(formatted_results)} characters") | |
return formatted_results | |
except Exception as e: | |
logger.error(f"Tavily search failed: {str(e)}") | |
return f"❌ Search failed: {str(e)}" | |
def format_search_results_for_llm(response): | |
"""Format Tavily search results for LLM consumption""" | |
query = response.get('query', 'Unknown query') | |
results = response.get('results', []) | |
answer = response.get('answer', '') | |
formatted = f"🔍 **Web Search Results for:** {query}\n\n" | |
if answer: | |
formatted += f"**Quick Answer:** {answer}\n\n" | |
if results: | |
formatted += f"**Found {len(results)} relevant sources:**\n\n" | |
for i, result in enumerate(results, 1): | |
title = result.get('title', 'Untitled') | |
url = result.get('url', '') | |
content = result.get('content', '') | |
score = result.get('score', 0) | |
# Truncate content to reasonable length | |
# if len(content) > 300: | |
# content = content[:300] + "..." | |
formatted += f"**{i}. {title}** (Relevance: {score:.2f})\n" | |
formatted += f" 🔗 {url}\n" | |
formatted += f" 📄 {content}\n\n" | |
else: | |
formatted += "No results found.\n" | |
return formatted | |
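
# Example usage (sketch; requires TAVILY_API_KEY to be set, and the query is
# illustrative):
#
#     results = web_search("pandas groupby aggregate multiple columns")
#     print(results)  # markdown-formatted quick answer plus up to 5 sources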
def run_interactive_notebook_with_session_state(client, model, session_state_manager, session_state, sbx, stop_event=None, tools=None): | |
logger.info(f"Starting interactive notebook with session state for {session_state_manager.session_id}") | |
# Get conversation history from session state | |
messages = session_state_manager.get_conversation_history(session_state) | |
notebook = JupyterNotebook(messages) | |
# Update execution state | |
session_state_manager.update_execution_state(session_state, is_running=True, sandbox_active=True, current_phase="initializing") | |
# Use provided tools or default to all tools | |
if tools is None: | |
tools = TOOLS | |
try: | |
sbx_info = sbx.get_info() | |
notebook.add_sandbox_countdown(sbx_info.started_at, sbx_info.end_at) | |
# Store sandbox info in session state | |
session_state["execution_state"]["sandbox_info"] = { | |
"started_at": sbx_info.started_at.isoformat(), | |
"end_at": sbx_info.end_at.isoformat(), | |
"timeout_seconds": int((sbx_info.end_at - sbx_info.started_at).total_seconds()) | |
} | |
logger.debug(f"Added sandbox countdown: {sbx_info.started_at} to {sbx_info.end_at}") | |
except Exception as e: | |
logger.warning(f"Failed to get sandbox info: {str(e)}") | |
logger.debug("Initial notebook yield in 'generating' mode") | |
# Update notebook data in session state | |
session_state_manager.update_notebook_data(session_state, notebook.data) | |
# Save initial state | |
session_state_manager.save_state(session_state) | |
yield notebook.render(mode="generating"), notebook.data, messages | |
max_code_output = 1000 | |
turns = session_state["execution_state"]["current_turn"] | |
done = False | |
previous_execution_had_error = False | |
previous_execution_had_warnings = False | |
logger.info(f"Starting interactive loop from turn {turns} with max_output={max_code_output}, max_turns={MAX_TURNS}") | |
while not done and (turns <= MAX_TURNS) and (stop_event is None or not stop_event.is_set()): | |
turns += 1 | |
logger.info(f"Starting turn {turns}/{MAX_TURNS}") | |
try: | |
# Update phase to generating | |
session_state_manager.update_execution_state(session_state, current_phase="generating") | |
# Refresh messages from session state before API call | |
messages = session_state_manager.get_conversation_history(session_state) | |
logger.debug(f"Making API call to {model} with {len(messages)} messages") | |
# Prepare request data for logging | |
request_data = { | |
"messages": clean_messages_for_api(messages), | |
"model": model, | |
"tools": tools, | |
"tool_choice": "auto" | |
} | |
# Prepare session metadata for Phoenix tracing | |
session_metadata = { | |
"turn": turns, | |
"max_turns": MAX_TURNS, | |
"model": model, | |
"tools_count": len(tools), | |
"messages_count": len(messages), | |
"current_phase": "generating" | |
} | |
# Add hardware config if available | |
if "hardware_config" in session_state: | |
hw_config = session_state["hardware_config"] | |
session_metadata.update({ | |
"gpu_type": hw_config.get("gpu_type", "unknown"), | |
"cpu_cores": hw_config.get("cpu_cores", "unknown"), | |
"memory_gb": hw_config.get("memory_gb", "unknown") | |
}) | |
# Wrap OpenAI API call with Phoenix session context for proper grouping | |
with create_phoenix_session_context( | |
session_id=session_state_manager.session_id, | |
user_id=None, # Could be extracted from request context if available | |
metadata=session_metadata | |
): | |
logger.debug(f"Making OpenAI API call with Phoenix session context: {session_state_manager.session_id}") | |
response = client.chat.completions.create(**request_data) | |
logger.debug("API call successful within Phoenix session context") | |
# Log the complete LLM interaction | |
session_state_manager.log_llm_interaction( | |
session_state, request_data, response.model_dump(), model, turns | |
) | |
except Exception as e: | |
# Handle inference client errors | |
logger.error(f"Inference failed on turn {turns}: {str(e)}") | |
# Add detailed error information to the notebook | |
error_message = str(e) | |
if "429" in error_message or "too_many_requests" in error_message.lower(): | |
detailed_error = f"""**API Rate Limit Exceeded** 🚫 | |
The inference service has reached its rate limit. This typically means: | |
- Too many requests have been sent in a short period | |
- Daily quota has been exceeded | |
- Service is temporarily overloaded | |
**What you can try:** | |
- Wait a few minutes and try again | |
- If using Cerebras API, check your daily quota | |
- Try using a different model or service | |
- Contact support if the issue persists | |
**Technical details:** | |
``` | |
{error_message} | |
```""" | |
elif "401" in error_message or "unauthorized" in error_message.lower(): | |
detailed_error = f"""**Authentication Error** 🔐 | |
There's an issue with API authentication: | |
- API key might be missing or invalid | |
- API key might have expired | |
- Insufficient permissions | |
**Technical details:** | |
``` | |
{error_message} | |
```""" | |
elif "500" in error_message or "internal" in error_message.lower(): | |
detailed_error = f"""**Server Error** 🔧 | |
The inference service encountered an internal error: | |
- Service might be temporarily unavailable | |
- Try again in a few moments | |
- If the issue persists, it's likely a service-side problem | |
**Technical details:** | |
``` | |
{error_message} | |
```""" | |
else: | |
detailed_error = f"""**Inference Service Error** ⚠️ | |
An error occurred while communicating with the AI service: | |
**Technical details:** | |
``` | |
{error_message} | |
``` | |
**What you can try:** | |
- Check your internet connection | |
- Try again in a few moments | |
- If the problem persists, contact support""" | |
notebook.add_error(detailed_error) | |
# Add error to session state | |
session_state_manager.add_message( | |
session_state, "assistant", detailed_error, | |
metadata={"type": "error", "error_type": "api_error", "turn": turns} | |
) | |
# Update execution state | |
session_state_manager.update_execution_state( | |
session_state, is_running=False, last_execution_successful=False | |
) | |
# Update notebook data and save state | |
session_state_manager.update_notebook_data(session_state, notebook.data) | |
session_state_manager.save_state(session_state) | |
yield notebook.render(mode="error"), notebook.data, messages | |
return | |
# Get the response content and tool calls | |
full_response = response.choices[0].message.content or "" | |
tool_calls = response.choices[0].message.tool_calls or [] | |
logger.debug(f"Turn {turns}: Response content length: {len(full_response)}, Tool calls: {len(tool_calls)}") | |
# Add markdown cell for assistant's thinking | |
if full_response.strip(): | |
logger.debug(f"Adding assistant response as markdown ({len(full_response)} chars)") | |
notebook.add_markdown(full_response, "assistant") | |
else: | |
logger.debug("Skipping empty assistant response") | |
# Handle tool calls and add assistant message to session state only | |
if tool_calls: | |
logger.info(f"Processing {len(tool_calls)} tool calls on turn {turns}") | |
# Add assistant message to session state (messages will be derived from this) | |
session_state_manager.add_message( | |
session_state, "assistant", full_response, | |
tool_calls=[{ | |
"id": tc.id, | |
"type": "function", | |
"function": {"name": tc.function.name, "arguments": tc.function.arguments} | |
} for tc in tool_calls], | |
metadata={"turn": turns, "type": "thinking"} | |
) | |
logger.debug(f"Added assistant message with {len(tool_calls)} tool calls to session state") | |
elif full_response.strip(): | |
# If no tool calls but we have content, add regular assistant message | |
session_state_manager.add_message( | |
session_state, "assistant", full_response, | |
metadata={"turn": turns, "type": "thinking"} | |
) | |
logger.debug("Added regular assistant message to session state") | |
for i, tool_call in enumerate(tool_calls): | |
logger.debug(f"Processing tool call {i+1}/{len(tool_calls)}: {tool_call.function.name}") | |
if tool_call.function.name == "add_and_execute_jupyter_code_cell": | |
# Update phase to executing code | |
session_state_manager.update_execution_state(session_state, current_phase="executing_code") | |
logger.debug(f"Processing code execution tool call: {tool_call.id}") | |
tool_args = json.loads(tool_call.function.arguments) | |
code = tool_args["code"] | |
logger.debug(f"Code to execute: {len(code)} characters") | |
# Determine if we should reuse the last cell or create a new one | |
# Reuse if there were errors (not just warnings) in the previous execution | |
should_reuse_cell = (previous_execution_had_error and | |
notebook.get_last_cell_type() == "code") | |
if should_reuse_cell: | |
logger.info("Reusing last code cell due to previous execution error") | |
# Update the existing cell's code instead of creating a new one | |
notebook.update_last_code_cell(code) | |
else: | |
logger.debug("Creating new code cell") | |
# Create a new cell (normal behavior) | |
notebook.add_code(code) | |
logger.debug("Yielding notebook in 'executing' mode") | |
yield notebook.render(mode="executing"), notebook.data, messages | |
try: | |
# Check for stop event before execution | |
if stop_event and stop_event.is_set(): | |
logger.info("Stop event detected before code execution") | |
stopped_message = """**Execution Stopped** ⏸️ | |
The execution was stopped by user request before the code could run.""" | |
notebook.add_markdown(stopped_message, "assistant") | |
yield notebook.render(mode="stopped"), notebook.data, messages | |
return | |
# Execution sandbox call - might timeout | |
logger.info("Executing code in sandbox") | |
execution = sbx.run_code(code) | |
notebook.append_execution(execution) | |
# Update error and warning tracking for next iteration | |
previous_execution_had_error = notebook.has_execution_error(execution) | |
previous_execution_had_warnings = notebook.has_execution_warnings(execution) | |
# Log tool execution in session state | |
tool_args = json.loads(tool_call.function.arguments) | |
tool_response_content = parse_exec_result_llm(execution, max_code_output=max_code_output) | |
session_state_manager.log_tool_execution( | |
session_state, tool_call.id, "add_and_execute_jupyter_code_cell", | |
tool_args, tool_response_content, execution | |
) | |
if previous_execution_had_error: | |
logger.warning("Code execution resulted in error") | |
elif previous_execution_had_warnings: | |
logger.info("Code execution completed with warnings") | |
else: | |
logger.info("Code execution completed successfully") | |
except Exception as e: | |
# Handle sandbox timeout/execution errors | |
logger.error(f"Code execution failed: {str(e)}") | |
# Add detailed error information for code execution failures | |
error_message = str(e) | |
if "timeout" in error_message.lower(): | |
detailed_error = f"""**Code Execution Timeout** ⏰ | |
The code execution took too long and was terminated: | |
- Code may have entered an infinite loop | |
- Processing large datasets can cause timeouts | |
- Complex computations may exceed time limits | |
**What you can try:** | |
- Optimize your code for better performance | |
- Break down complex operations into smaller steps | |
- Increase the timeout limit in settings | |
- Check for infinite loops or blocking operations | |
**Technical details:** | |
``` | |
{error_message} | |
```""" | |
else: | |
detailed_error = f"""**Code Execution Failed** 💥 | |
An error occurred while executing the code in the sandbox: | |
**Technical details:** | |
``` | |
{error_message} | |
``` | |
**What you can try:** | |
- Check the code for syntax errors | |
- Verify all required packages are available | |
- Try simplifying the code | |
- Check the sandbox logs for more details""" | |
notebook.add_error(detailed_error) | |
yield notebook.render(mode="error"), notebook.data, messages | |
return | |
# Prepare tool response (already computed above) | |
raw_execution = notebook.parse_exec_result_nb(execution) | |
logger.debug(f"Tool response: {len(tool_response_content)} chars content, {len(raw_execution)} raw outputs") | |
# Add tool response to session state only | |
session_state_manager.add_message( | |
session_state, "tool", tool_response_content, | |
tool_call_id=tool_call.id, raw_execution=raw_execution, | |
metadata={"turn": turns, "execution_successful": not previous_execution_had_error} | |
) | |
elif tool_call.function.name == "web_search": | |
# Update phase to searching | |
session_state_manager.update_execution_state(session_state, current_phase="searching") | |
logger.debug(f"Processing search tool call: {tool_call.id}") | |
tool_args = json.loads(tool_call.function.arguments) | |
query = tool_args["query"] | |
logger.debug(f"Search query: '{query}' ({len(query)} chars)") | |
# Add search status to notebook | |
notebook.add_markdown("🔍 **Searching the web...**", "assistant") | |
yield notebook.render(mode="generating"), notebook.data, messages | |
try: | |
# Perform search | |
search_results = web_search(query) | |
logger.info("Search completed successfully") | |
# Log search tool execution | |
tool_args = json.loads(tool_call.function.arguments) | |
session_state_manager.log_tool_execution( | |
session_state, tool_call.id, "web_search", | |
tool_args, search_results | |
) | |
# Add search results to notebook | |
notebook.add_markdown(search_results, "assistant") | |
# Add tool response to session state only | |
session_state_manager.add_message( | |
session_state, "tool", search_results, | |
tool_call_id=tool_call.id, | |
metadata={"turn": turns, "search_successful": True} | |
) | |
except Exception as e: | |
error_message = f"❌ Search failed: {str(e)}" | |
logger.error(f"Search tool call failed: {str(e)}") | |
# Log failed search | |
tool_args = json.loads(tool_call.function.arguments) | |
session_state_manager.log_tool_execution( | |
session_state, tool_call.id, "web_search", | |
tool_args, error_message | |
) | |
# Add error to notebook | |
notebook.add_markdown(error_message, "assistant") | |
# Add error response to session state only | |
session_state_manager.add_message( | |
session_state, "tool", error_message, | |
tool_call_id=tool_call.id, | |
metadata={"turn": turns, "search_successful": False, "error": str(e)} | |
) | |
elif tool_call.function.name == "edit_and_execute_current_cell": | |
# Update phase to executing code | |
session_state_manager.update_execution_state(session_state, current_phase="executing_code") | |
logger.debug(f"Processing edit current cell tool call: {tool_call.id}") | |
tool_args = json.loads(tool_call.function.arguments) | |
code = tool_args["code"] | |
logger.debug(f"Code to execute in current cell: {len(code)} characters") | |
# Check if we have a code cell to edit | |
if notebook.get_last_cell_type() == "code": | |
logger.info("Editing last code cell with new code") | |
notebook.update_last_code_cell(code) | |
else: | |
logger.info("No code cell to edit, creating new cell") | |
notebook.add_code(code) | |
logger.debug("Yielding notebook in 'executing' mode") | |
yield notebook.render(mode="executing"), notebook.data, messages | |
try: | |
# Check for stop event before execution | |
if stop_event and stop_event.is_set(): | |
logger.info("Stop event detected before code execution") | |
stopped_message = """**Execution Stopped** ⏸️ | |
The execution was stopped by user request before the code could run.""" | |
notebook.add_markdown(stopped_message, "assistant") | |
yield notebook.render(mode="stopped"), notebook.data, messages | |
return | |
# Execution sandbox call - might timeout | |
logger.info("Executing edited code in sandbox") | |
execution = sbx.run_code(code) | |
notebook.append_execution(execution) | |
# Update error and warning tracking for next iteration | |
previous_execution_had_error = notebook.has_execution_error(execution) | |
previous_execution_had_warnings = notebook.has_execution_warnings(execution) | |
# Log tool execution in session state | |
tool_response_content = parse_exec_result_llm(execution, max_code_output=max_code_output) | |
session_state_manager.log_tool_execution( | |
session_state, tool_call.id, "edit_and_execute_current_cell", | |
tool_args, tool_response_content, execution | |
) | |
if previous_execution_had_error: | |
logger.warning("Edited code execution resulted in error") | |
elif previous_execution_had_warnings: | |
logger.info("Edited code execution completed with warnings") | |
else: | |
logger.info("Edited code execution completed successfully") | |
except Exception as e: | |
# Handle sandbox timeout/execution errors | |
logger.error(f"Edited code execution failed: {str(e)}") | |
# Add detailed error information for code execution failures | |
error_message = str(e) | |
if "timeout" in error_message.lower(): | |
detailed_error = f"""**Code Execution Timeout** ⏰ | |
The edited code execution took too long and was terminated: | |
- Code may have entered an infinite loop | |
- Processing large datasets can cause timeouts | |
- Complex computations may exceed time limits | |
**What you can try:** | |
- Optimize your code for better performance | |
- Break down complex operations into smaller steps | |
- Increase the timeout limit in settings | |
- Check for infinite loops or blocking operations | |
**Technical details:** | |
``` | |
{error_message} | |
```""" | |
else: | |
detailed_error = f"""**Code Execution Failed** 💥 | |
An error occurred while executing the edited code in the sandbox: | |
**Technical details:** | |
``` | |
{error_message} | |
``` | |
**What you can try:** | |
- Check the code for syntax errors | |
- Verify all required packages are available | |
- Try simplifying the code | |
- Check the sandbox logs for more details""" | |
notebook.add_error(detailed_error) | |
yield notebook.render(mode="error"), notebook.data, messages | |
return | |
# Prepare tool response | |
raw_execution = notebook.parse_exec_result_nb(execution) | |
logger.debug(f"Tool response: {len(tool_response_content)} chars content, {len(raw_execution)} raw outputs") | |
# Add tool response to session state only | |
session_state_manager.add_message( | |
session_state, "tool", tool_response_content, | |
tool_call_id=tool_call.id, raw_execution=raw_execution, | |
metadata={"turn": turns, "execution_successful": not previous_execution_had_error, "action": "edit_cell"} | |
) | |
elif tool_call.function.name == "execute_shell_command": | |
# Update phase to executing shell command | |
session_state_manager.update_execution_state(session_state, current_phase="executing_shell") | |
logger.debug(f"Processing shell command tool call: {tool_call.id}") | |
tool_args = json.loads(tool_call.function.arguments) | |
command = tool_args["command"] | |
logger.debug(f"Shell command to execute: '{command}'") | |
# Add shell command to notebook with special styling | |
notebook.add_shell_command(command) | |
logger.debug("Yielding notebook in 'executing' mode") | |
yield notebook.render(mode="executing"), notebook.data, messages | |
try: | |
# Check for stop event before execution | |
if stop_event and stop_event.is_set(): | |
logger.info("Stop event detected before shell execution") | |
stopped_message = """**Execution Stopped** ⏸️ | |
The execution was stopped by user request before the shell command could run.""" | |
notebook.add_markdown(stopped_message, "assistant") | |
yield notebook.render(mode="stopped"), notebook.data, messages | |
return | |
# Execute shell command in sandbox using raw shell execution | |
logger.info(f"Executing raw shell command in sandbox: {command}") | |
try: | |
# Use the new raw shell execution method | |
if hasattr(sbx, 'run_shell'): | |
shell_execution = sbx.run_shell(command, timeout=60) | |
logger.info("Shell command executed using raw shell method") | |
else: | |
# Fallback: Execute shell command using Python subprocess within sandbox | |
                            shell_code = f"""
import subprocess
import sys

try:
    result = subprocess.run(
        {repr(command)},
        shell=True,
        capture_output=True,
        text=True,
        timeout=60
    )
    if result.stdout:
        print("STDOUT:")
        print(result.stdout)
    if result.stderr:
        print("STDERR:")
        print(result.stderr)
    print(f"Exit code: {{result.returncode}}")
except subprocess.TimeoutExpired:
    print("Error: Command timed out after 60 seconds")
except Exception as e:
    print(f"Error executing command: {{e}}")
"""
shell_execution = sbx.run_code(shell_code) | |
logger.info("Shell command executed via Python subprocess fallback") | |
# Add shell execution results to notebook | |
notebook.append_shell_execution(shell_execution) | |
# Prepare response content for LLM | |
shell_response_content = parse_exec_result_llm(shell_execution, max_code_output=max_code_output) | |
# Log tool execution in session state | |
session_state_manager.log_tool_execution( | |
session_state, tool_call.id, "execute_shell_command", | |
tool_args, shell_response_content, shell_execution | |
) | |
# Check for errors | |
shell_had_error = notebook.has_execution_error(shell_execution) | |
if shell_had_error: | |
logger.warning("Shell command execution resulted in error") | |
else: | |
logger.info("Shell command execution completed successfully") | |
except Exception as shell_error: | |
logger.error(f"Shell command execution failed: {str(shell_error)}") | |
# Create error message | |
detailed_error = f"""**Shell Command Failed** 🔧 | |
An error occurred while executing the shell command: | |
**Command:** `{command}` | |
**Technical details:** | |
``` | |
{str(shell_error)} | |
``` | |
**What you can try:** | |
- Check if the command exists in the sandbox environment | |
- Verify command syntax | |
- Try a simpler version of the command | |
- Check if required tools/packages are installed""" | |
notebook.add_error(detailed_error) | |
# Log failed execution | |
session_state_manager.log_tool_execution( | |
session_state, tool_call.id, "execute_shell_command", | |
tool_args, detailed_error | |
) | |
yield notebook.render(mode="error"), notebook.data, messages | |
return | |
except Exception as e: | |
# Handle general execution errors | |
logger.error(f"Shell command execution failed: {str(e)}") | |
detailed_error = f"""**Shell Execution Error** ⚠️ | |
An unexpected error occurred while executing the shell command: | |
**Command:** `{command}` | |
**Technical details:** | |
``` | |
{str(e)} | |
```""" | |
notebook.add_error(detailed_error) | |
yield notebook.render(mode="error"), notebook.data, messages | |
return | |
# Prepare tool response for LLM and session state | |
raw_execution = notebook.parse_exec_result_nb(shell_execution) | |
logger.debug(f"Shell tool response: {len(shell_response_content)} chars content") | |
# Add tool response to session state | |
session_state_manager.add_message( | |
session_state, "tool", shell_response_content, | |
tool_call_id=tool_call.id, raw_execution=raw_execution, | |
metadata={"turn": turns, "command": command, "execution_successful": not shell_had_error, "action": "shell_command"} | |
) | |
else: | |
logger.warning(f"Unknown tool call function: {tool_call.function.name}") | |
if not tool_calls: | |
logger.info(f"No tool calls on turn {turns}, conversation ending") | |
            if not full_response.strip():
logger.error("Assistant provided no content and no tool calls") | |
notebook.add_error(f"No tool call and empty assistant response:\n{response.model_dump_json(indent=2)}") | |
# Only add the final assistant message if we didn't already add it above | |
# (in the elif full_response.strip() block) | |
if full_response.strip(): | |
# Since we're now only using session state, we can safely add the message | |
# The session state manager will handle any deduplication if needed | |
session_state_manager.add_message( | |
session_state, "assistant", full_response, | |
metadata={"turn": turns, "type": "final_response"} | |
) | |
logger.debug("Added final assistant response to session state") | |
done = True | |
# Update session state after each turn | |
session_state_manager.update_execution_state( | |
session_state, current_turn=turns, last_execution_successful=not previous_execution_had_error | |
) | |
session_state_manager.update_notebook_data(session_state, notebook.data) | |
session_state_manager.save_state(session_state) | |
if done: | |
logger.info(f"Interactive notebook completed after {turns} turns") | |
session_state_manager.update_execution_state( | |
session_state, is_running=False, sandbox_active=True | |
) | |
session_state_manager.save_state(session_state) | |
yield notebook.render(mode="done"), notebook.data, messages | |
else: | |
logger.debug(f"Turn {turns} completed, yielding in 'generating' mode") | |
yield notebook.render(mode="generating"), notebook.data, messages | |
if turns > MAX_TURNS: | |
logger.warning(f"Interactive notebook reached maximum turns ({MAX_TURNS})") | |
error_msg = f"**Maximum Turns Reached** 🔄\n\nThe conversation has reached the maximum number of turns ({MAX_TURNS}). This is a safety limit to prevent infinite loops.\n\n**What you can try:**\n- Start a new conversation\n- Clear the notebook and begin fresh\n- Contact support if you need a higher turn limit" | |
notebook.add_error(error_msg) | |
# Add error to session state | |
session_state_manager.add_message( | |
session_state, "assistant", error_msg, | |
metadata={"type": "error", "error_type": "max_turns_exceeded", "turn": turns} | |
) | |
# Update final state | |
session_state_manager.update_execution_state( | |
session_state, is_running=False, last_execution_successful=False | |
) | |
session_state_manager.update_notebook_data(session_state, notebook.data) | |
session_state_manager.save_state(session_state) | |
yield notebook.render(mode="error"), notebook.data, messages | |
elif stop_event and stop_event.is_set(): | |
logger.info("Interactive notebook stopped by user") | |
# Add a stopped message to the notebook | |
stopped_message = """**Execution Stopped** ⏸️ | |
The execution was stopped by user request. You can resume by clicking Run again.""" | |
notebook.add_markdown(stopped_message, "assistant") | |
# Add stopped message to session state | |
session_state_manager.add_message( | |
session_state, "assistant", stopped_message, | |
metadata={"type": "status", "status_type": "stopped_by_user", "turn": turns} | |
) | |
# Update state to indicate pause | |
session_state_manager.update_execution_state( | |
session_state, is_running=False, is_paused=True | |
) | |
session_state_manager.update_notebook_data(session_state, notebook.data) | |
session_state_manager.save_state(session_state) | |
yield notebook.render(mode="stopped"), notebook.data, messages | |
def run_interactive_notebook(client, model, messages, sbx, stop_event=None, tools=None): | |
"""Backward compatibility wrapper for the new session state system""" | |
logger.warning("Using legacy run_interactive_notebook - this should be replaced with session state version") | |
# Create a temporary session for backward compatibility | |
import uuid | |
temp_session_id = str(uuid.uuid4())[:8] | |
session_manager = SessionStateManager(temp_session_id) | |
# Create basic session state | |
session_state = session_manager.create_initial_state( | |
hardware_config={"gpu_type": "unknown", "cpu_cores": 2, "memory_gb": 8, "timeout_sec": 300}, | |
api_config={"model_name": model, "provider_type": "unknown"}, | |
environment={"variables": "", "files_uploaded": []}, | |
system_prompt=messages[0].get("content", "") if messages and messages[0].get("role") == "system" else "" | |
) | |
# Initialize conversation history with provided messages | |
session_state["conversation_history"] = messages | |
# Use the new session-based function | |
yield from run_interactive_notebook_with_session_state( | |
client, model, session_manager, session_state, sbx, stop_event, tools | |
) |