# Source: Final_Assignment_GAIAAgent/src/gaia/utils/logging_integration.py
# Uploaded by JoachimVC: "Upload GAIA agent implementation files for assessment"
# Commit: c922f8b
"""
Logging Integration for Gaia System
This module provides functions to integrate the enhanced logging framework
with the existing Gaia system components, including:
- Agent integration
- Tool registry integration
- API client integration
- Memory integration
- Graph workflow integration
It provides integration functions that monkey-patch existing classes and modules
with logging wrappers, so the original code needs minimal changes.
"""
import functools
import inspect
import json
import time
import traceback
from typing import Dict, Any, Optional, Callable, List, Union, Type
from utils.logging_framework import (
log_timing,
log_error,
log_api_request,
log_api_response,
log_tool_selection,
log_tool_execution,
log_workflow_step,
log_memory_operation,
TimingContext,
get_trace_id,
set_trace_id,
generate_trace_id,
initialize_logging
)
def integrate_agent_logging(agent_class: Type):
    """
    Integrate logging with the GAIAAgent class.

    This function monkey-patches ``agent_class`` in place:

    - ``__init__`` generates a new trace ID, makes it the current trace ID,
      times the original constructor, and stores the ID on the instance as
      ``self._trace_id``.
    - ``process_question`` makes the instance's trace ID current (creating
      one if the instance has none), logs the question and the returned
      answer as workflow steps, times the call, and logs any exception as
      critical before re-raising it.

    Each method is wrapped at most once: calling this function again on the
    same class (or on a subclass that inherits already-patched methods)
    leaves those methods untouched.

    Args:
        agent_class: The GAIAAgent class to patch
    """
    original_init = agent_class.__init__
    original_process_question = agent_class.process_question

    @functools.wraps(original_init)
    def init_with_logging(self, *args, **kwargs):
        # Start a fresh trace for this agent instance.
        trace_id = generate_trace_id()
        set_trace_id(trace_id)

        # Forward the caller's arguments unchanged. Injecting an explicit
        # ``config=None`` breaks constructors without a ``config`` parameter
        # and overrides any non-None default the constructor declares.
        with TimingContext("agent initialization", "agent"):
            original_init(self, *args, **kwargs)

        self._trace_id = trace_id

    @functools.wraps(original_process_question)
    def process_question_with_logging(self, question, *args, **kwargs):
        # Make this agent's trace ID current before anything is logged.
        if not hasattr(self, '_trace_id'):
            self._trace_id = generate_trace_id()
        set_trace_id(self._trace_id)

        log_workflow_step("process_question", f"Processing question: {question}",
                          inputs={"question": question})

        with TimingContext("process_question", "agent"):
            try:
                result = original_process_question(self, question, *args, **kwargs)
            except Exception as e:
                log_error(e, context={"question": question}, critical=True)
                raise
            log_workflow_step("generate_answer", "Generated answer",
                              outputs={"answer": result})
            return result

    # Mark the wrappers so a repeated integration call does not wrap them again.
    init_with_logging._gaia_logging_patched = True
    process_question_with_logging._gaia_logging_patched = True

    if not getattr(original_init, "_gaia_logging_patched", False):
        agent_class.__init__ = init_with_logging
    if not getattr(original_process_question, "_gaia_logging_patched", False):
        agent_class.process_question = process_question_with_logging
def integrate_tool_registry_logging(registry_class: Type):
    """
    Integrate logging with the ToolRegistry class.

    This function monkey-patches ``registry_class.execute_tool`` so that each
    call:

    - logs the tool selection with its keyword inputs;
    - is timed;
    - logs the outcome (result or error message) together with the elapsed
      time in seconds;
    - on failure, also logs the exception and re-raises it.

    Positional arguments after ``name`` are forwarded to the original method.
    Calling this function again on an already-patched class is a no-op.

    Args:
        registry_class: The ToolRegistry class to patch
    """
    original_execute_tool = registry_class.execute_tool
    if getattr(original_execute_tool, "_gaia_logging_patched", False):
        return  # Already wrapped; wrapping again would duplicate every log line.

    @functools.wraps(original_execute_tool)
    def execute_tool_with_logging(self, name, *args, **kwargs):
        log_tool_selection(name, f"Selected tool: {name}", inputs=kwargs)

        # perf_counter is monotonic, so durations are unaffected by clock changes.
        start_time = time.perf_counter()
        try:
            with TimingContext(f"tool_execution_{name}", "tool"):
                result = original_execute_tool(self, name, *args, **kwargs)
        except Exception as e:
            duration = time.perf_counter() - start_time
            log_tool_execution(name, False, error=str(e), duration=duration)
            context = {"tool_name": name, "inputs": kwargs}
            if args:
                context["args"] = str(args)
            log_error(e, context=context)
            raise

        duration = time.perf_counter() - start_time
        log_tool_execution(name, True, result=result, duration=duration)
        return result

    execute_tool_with_logging._gaia_logging_patched = True
    registry_class.execute_tool = execute_tool_with_logging
def patch_api_client(client_class: Type, api_name: str):
    """
    Patch an API client class to add logging.

    Every public plain function found on ``client_class`` by
    ``inspect.getmembers`` (inherited ones included) is replaced with a
    wrapper that:

    - logs the request; the endpoint is ``kwargs['endpoint']`` if given,
      else the method name, and the HTTP method is ``kwargs['method']`` if
      given, else ``'POST'``;
    - times the call;
    - logs the response with a status code taken from
      ``response.status_code`` or ``response['status_code']`` (default 200);
    - on failure, logs the error (with elapsed seconds) and re-raises it.

    The following are left untouched:

    - names starting with ``_``;
    - static methods, which would otherwise be turned into instance methods;
    - methods this function has already wrapped.

    Args:
        client_class: The API client class to patch
        api_name: The name of the API (e.g., "OpenAI", "Serper")
    """

    def _make_wrapper(method_name, original):
        # Build one wrapper, binding method_name/original at creation time.
        # Closing over loop variables directly would late-bind them, making
        # every patched method call the last method in the loop.
        @functools.wraps(original)
        def api_method_with_logging(self, *args, **kwargs):
            endpoint = kwargs.get('endpoint', method_name)
            http_method = kwargs.get('method', 'POST')
            log_api_request(api_name, endpoint, http_method, params=kwargs)

            start_time = time.perf_counter()
            try:
                with TimingContext(f"{api_name}_{method_name}", "api"):
                    response = original(self, *args, **kwargs)
            except Exception as e:
                duration = time.perf_counter() - start_time
                log_error(e, context={"api_name": api_name, "endpoint": endpoint,
                                      "params": kwargs, "duration": duration})
                raise

            duration = time.perf_counter() - start_time

            status_code = 200
            if hasattr(response, 'status_code'):
                status_code = response.status_code
            elif isinstance(response, dict) and 'status_code' in response:
                status_code = response['status_code']

            log_api_response(api_name, endpoint, status_code, response, duration)
            return response

        api_method_with_logging._gaia_logging_patched = True
        return api_method_with_logging

    # getmembers returns a list, so patching inside the loop is safe.
    for name, method in inspect.getmembers(client_class, inspect.isfunction):
        if name.startswith('_'):
            continue
        if isinstance(inspect.getattr_static(client_class, name, None), staticmethod):
            continue
        if getattr(method, "_gaia_logging_patched", False):
            continue
        setattr(client_class, name, _make_wrapper(name, method))
def integrate_memory_logging(memory_class: Type):
    """
    Integrate logging with the memory class.

    Public plain functions of ``memory_class`` are classified by substrings
    of their name:

    - ``store``/``save``/``add``/``insert`` → ``'store'``
    - ``get``/``retrieve``/``load``/``fetch`` → ``'get'``
    - ``update``/``modify`` → ``'update'``
    - ``delete``/``remove`` → ``'delete'``

    Matching methods are replaced with a wrapper that times the call and
    logs the operation's outcome.

    The logged key is ``str(args[0])`` if any positional argument is given,
    else ``kwargs['key']``. For store/update, the logged value type is the
    type name of ``args[1]`` or ``kwargs['value']``. Failures are also
    logged as errors and re-raised.

    The following are left untouched:

    - methods with unclassified names;
    - names starting with ``_``;
    - static methods;
    - methods already wrapped by this function.

    Args:
        memory_class: The memory class to patch
    """

    def _classify(method_name):
        # Map a method name to an operation type; None means "don't patch".
        if any(word in method_name for word in ('store', 'save', 'add', 'insert')):
            return 'store'
        if any(word in method_name for word in ('get', 'retrieve', 'load', 'fetch')):
            return 'get'
        if any(word in method_name for word in ('update', 'modify')):
            return 'update'
        if any(word in method_name for word in ('delete', 'remove')):
            return 'delete'
        return None

    def _make_wrapper(op_type, original):
        # Bind op_type/original now; closing over loop variables would
        # late-bind them so every wrapper called the last matched method.
        @functools.wraps(original)
        def memory_method_with_logging(self, *args, **kwargs):
            key = None
            if args:
                key = str(args[0])
            elif 'key' in kwargs:
                key = kwargs['key']

            value_type = "unknown"
            if op_type in ('store', 'update'):
                value = args[1] if len(args) > 1 else kwargs.get('value')
                if value is not None:
                    value_type = type(value).__name__

            try:
                with TimingContext(f"memory_{op_type}", "memory"):
                    result = original(self, *args, **kwargs)
            except Exception as e:
                log_memory_operation(op_type, key, value_type, False, str(e))
                log_error(e, context={"operation": op_type, "key": key})
                raise

            log_memory_operation(op_type, key, value_type, True)
            return result

        memory_method_with_logging._gaia_logging_patched = True
        return memory_method_with_logging

    for name, method in inspect.getmembers(memory_class, inspect.isfunction):
        if name.startswith('_'):
            continue
        operation = _classify(name)
        if operation is None:
            continue
        if isinstance(inspect.getattr_static(memory_class, name, None), staticmethod):
            continue
        if getattr(method, "_gaia_logging_patched", False):
            continue
        setattr(memory_class, name, _make_wrapper(operation, method))
def integrate_graph_logging(graph_module):
    """
    Integrate logging with the LangGraph workflow.

    Replaces each public plain function *defined in* ``graph_module`` (its
    ``__module__`` equals the module's ``__name__``) with a wrapper that:

    - logs the step's inputs (keyword arguments, or ``str(args)`` when there
      are none);
    - times the call;
    - logs the result's string form, truncated to 200 characters plus
      ``"..."``;
    - on failure, logs the error and re-raises it.

    The following are left untouched:

    - imported names and names starting with ``_``;
    - classes and other callable objects, since replacing them with a plain
      function would break ``isinstance``, subclassing, and attribute access;
    - functions already wrapped by this function.

    NOTE: only the module attribute is replaced. References to a function
    captured elsewhere before patching keep pointing to the unwrapped
    original.

    Args:
        graph_module: The graph module to patch
    """
    module_name = getattr(graph_module, "__name__", None)

    def _make_wrapper(func_name, original):
        # Bind func_name/original now; closing over loop variables would
        # late-bind them so every wrapper ran the last function found.
        @functools.wraps(original)
        def graph_func_with_logging(*args, **kwargs):
            inputs = kwargs if kwargs else {"args": str(args)}
            log_workflow_step(func_name, f"Executing graph step: {func_name}",
                              inputs=inputs)

            with TimingContext(f"graph_{func_name}", "graph"):
                try:
                    result = original(*args, **kwargs)
                except Exception as e:
                    log_error(e, context={"step": func_name, "inputs": inputs})
                    raise

                summary = str(result)
                if len(summary) > 200:
                    summary = summary[:200] + "..."
                log_workflow_step(func_name, f"Completed graph step: {func_name}",
                                  outputs={"result": summary})
                return result

        graph_func_with_logging._gaia_logging_patched = True
        return graph_func_with_logging

    for name in dir(graph_module):
        if name.startswith('_'):
            continue
        attr = getattr(graph_module, name)
        if not inspect.isfunction(attr) or attr.__module__ != module_name:
            continue
        if getattr(attr, "_gaia_logging_patched", False):
            continue
        setattr(graph_module, name, _make_wrapper(name, attr))
def setup_logging_integration():
    """
    Set up logging integration for all Gaia components.

    Initializes the logging framework, then patches:

    - ``GAIAAgent`` and ``ToolRegistry``;
    - the ``agent.graph`` module;
    - the optional API clients (``ChatOpenAI``, ``SerperSearchTool``,
      ``PerplexityTool``);
    - the optional ``SupabaseMemory`` class.

    Optional components whose import fails with ``ImportError`` are skipped.

    Returns:
        True if integration completed.
        False if any required step failed, including a failed import of a
        required component; the error and traceback are printed.
    """
    try:
        initialize_logging(verbose=True)

        # Required components. These imports live inside the try so a missing
        # module is reported through the ``return False`` path instead of
        # escaping to the caller.
        from agent.agent import GAIAAgent
        from agent.tool_registry import ToolRegistry
        from agent import graph

        integrate_agent_logging(GAIAAgent)
        integrate_tool_registry_logging(ToolRegistry)
        integrate_graph_logging(graph)

        # Optional API clients: best-effort. Only the import is guarded, so an
        # ImportError raised while patching is not silently mistaken for
        # "client not installed".
        try:
            from langchain.chat_models import ChatOpenAI
        except ImportError:
            pass
        else:
            patch_api_client(ChatOpenAI, "OpenAI")

        try:
            from tools.web_tools import SerperSearchTool
        except ImportError:
            pass
        else:
            patch_api_client(SerperSearchTool, "Serper")

        try:
            from tools.perplexity_tool import PerplexityTool
        except ImportError:
            pass
        else:
            patch_api_client(PerplexityTool, "Perplexity")

        # Optional memory backend.
        try:
            from memory.supabase_memory import SupabaseMemory
        except ImportError:
            pass
        else:
            integrate_memory_logging(SupabaseMemory)

        return True
    except Exception as e:
        print(f"Error setting up logging integration: {str(e)}")
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Manual smoke test: apply all patches and report the actual outcome
    # (setup_logging_integration returns False on failure rather than raising).
    if setup_logging_integration():
        print("Logging integration set up successfully")
    else:
        print("Logging integration failed")
        raise SystemExit(1)