File size: 13,876 Bytes
c922f8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""
Logging Integration for Gaia System

This module provides functions to integrate the enhanced logging framework
with the existing Gaia system components, including:

- Agent integration
- Tool registry integration
- API client integration
- Memory integration
- Graph workflow integration

It monkey-patches existing classes and modules with logging wrappers, so logging
can be added without modifying the original components' source code.
"""

import functools
import inspect
import json
import time
import traceback
from typing import Dict, Any, Optional, Callable, List, Union, Type

from utils.logging_framework import (
    log_timing, 
    log_error, 
    log_api_request, 
    log_api_response,
    log_tool_selection, 
    log_tool_execution,
    log_workflow_step,
    log_memory_operation,
    TimingContext,
    get_trace_id,
    set_trace_id,
    generate_trace_id,
    initialize_logging
)

def integrate_agent_logging(agent_class: Type):
    """
    Integrate logging with the GAIAAgent class.

    Monkey-patches ``agent_class`` in place, replacing ``__init__`` and
    ``process_question`` with wrappers that time the call, log workflow
    steps and errors, and tag them with a per-agent trace ID.

    Args:
        agent_class: The GAIAAgent class to patch. It must define
            ``process_question``; otherwise ``AttributeError`` is raised
            before anything is patched.
    """
    original_init = agent_class.__init__
    original_process_question = agent_class.process_question

    @functools.wraps(original_init)
    def init_with_logging(self, *args, **kwargs):
        # Install a fresh trace ID before the original __init__ runs so that
        # anything logged during construction is already tagged with it.
        trace_id = generate_trace_id()
        set_trace_id(trace_id)

        # Forward the caller's arguments untouched. The previous signature
        # `(self, config=None, *args, **kwargs)` always passed `config`
        # positionally, which breaks __init__ methods that take no config
        # argument or accept it keyword-only.
        with TimingContext("agent initialization", "agent"):
            original_init(self, *args, **kwargs)

        # Remember the ID so process_question can re-install it later.
        self._trace_id = trace_id

    @functools.wraps(original_process_question)
    def process_question_with_logging(self, question, *args, **kwargs):
        # Reuse the agent's trace ID; create one if the instance has none
        # (for example, it was constructed before the class was patched).
        if not hasattr(self, '_trace_id'):
            self._trace_id = generate_trace_id()
        set_trace_id(self._trace_id)

        log_workflow_step("process_question", f"Processing question: {question}",
                          inputs={"question": question})

        with TimingContext("process_question", "agent"):
            try:
                result = original_process_question(self, question, *args, **kwargs)

                log_workflow_step("generate_answer", "Generated answer",
                                  outputs={"answer": result})

                return result
            except Exception as e:
                # Record the failure as critical, then let it propagate.
                log_error(e, context={"question": question}, critical=True)
                raise

    # Apply the patches
    agent_class.__init__ = init_with_logging
    agent_class.process_question = process_question_with_logging

def integrate_tool_registry_logging(registry_class: Type):
    """
    Integrate logging with the ToolRegistry class.

    Monkey-patches ``registry_class.execute_tool`` in place so every tool
    call logs the tool selection, the outcome (result or error) and the
    elapsed duration in seconds.

    Args:
        registry_class: The ToolRegistry class to patch. It must define
            ``execute_tool(self, name, ...)``.
    """
    original_execute_tool = registry_class.execute_tool

    @functools.wraps(original_execute_tool)
    def execute_tool_with_logging(self, name, *args, **kwargs):
        # Positional tool arguments are forwarded as well (the previous
        # wrapper accepted keywords only) and recorded with the keywords.
        inputs = {"args": list(args), **kwargs} if args else kwargs

        log_tool_selection(name, f"Selected tool: {name}", inputs=inputs)

        # perf_counter is monotonic: unlike time.time(), the measured duration
        # cannot go wrong or negative if the system clock is adjusted.
        start_time = time.perf_counter()
        try:
            with TimingContext(f"tool_execution_{name}", "tool"):
                result = original_execute_tool(self, name, *args, **kwargs)
        except Exception as e:
            duration = time.perf_counter() - start_time
            log_tool_execution(name, False, error=str(e), duration=duration)
            log_error(e, context={"tool_name": name, "inputs": inputs})
            raise

        # Success logging is outside the try so that a failure inside the
        # logger is not misreported as a tool failure.
        duration = time.perf_counter() - start_time
        log_tool_execution(name, True, result=result, duration=duration)
        return result

    # Apply the patch
    registry_class.execute_tool = execute_tool_with_logging

def patch_api_client(client_class: Type, api_name: str):
    """
    Patch an API client class to add logging.

    Every public (non-underscore) plain function reachable on
    ``client_class``, including inherited ones, is replaced on the class with
    a wrapper that logs the request, times the call, and logs the response or
    the error. Static methods are left untouched, because the wrapper is an
    instance method and would otherwise pass ``self`` to them.

    Args:
        client_class: The API client class to patch
        api_name: The name of the API (e.g., "OpenAI", "Serper")
    """

    def _make_wrapper(original_method: Callable, method_name: str) -> Callable:
        # Factory so that each wrapper captures its OWN original_method and
        # method_name. The previous version closed over the loop variable, so
        # every patched method ended up calling the last method in the loop.
        @functools.wraps(original_method)
        def api_method_with_logging(self, *args, **kwargs):
            # Endpoint / HTTP verb come from the call's kwargs when present,
            # otherwise fall back to the method name and POST.
            endpoint = kwargs.get('endpoint', method_name)
            http_method = kwargs.get('method', 'POST')

            log_api_request(api_name, endpoint, http_method, params=kwargs)

            start_time = time.perf_counter()
            try:
                with TimingContext(f"{api_name}_{method_name}", "api"):
                    response = original_method(self, *args, **kwargs)
            except Exception as e:
                log_error(e, context={"api_name": api_name, "endpoint": endpoint, "params": kwargs})
                raise
            duration = time.perf_counter() - start_time

            # Assume 200 unless the response exposes its own status code.
            status_code = 200
            if hasattr(response, 'status_code'):
                status_code = response.status_code
            elif isinstance(response, dict) and 'status_code' in response:
                status_code = response['status_code']

            log_api_response(api_name, endpoint, status_code, response, duration)
            return response

        return api_method_with_logging

    # getmembers returns a fully built list, so patching inside the loop is safe.
    for name, method in inspect.getmembers(client_class, inspect.isfunction):
        if name.startswith('_'):
            continue
        if isinstance(inspect.getattr_static(client_class, name), staticmethod):
            continue
        setattr(client_class, name, _make_wrapper(method, name))

def integrate_memory_logging(memory_class: Type):
    """
    Integrate logging with the memory class.

    Monkey-patches ``memory_class`` in place. Public plain functions whose
    names contain a known keyword are wrapped so that each call is timed and
    logged as a memory operation ('store', 'get', 'update' or 'delete').
    Other methods and static methods are left untouched.

    Args:
        memory_class: The memory class to patch
    """
    # Checked in order; the first group with a matching substring wins
    # (e.g. a name containing both "save" and "load" is classified 'store').
    operation_keywords = (
        ('store', ('store', 'save', 'add', 'insert')),
        ('get', ('get', 'retrieve', 'load', 'fetch')),
        ('update', ('update', 'modify')),
        ('delete', ('delete', 'remove')),
    )

    def _make_wrapper(original_method: Callable, op_type: str) -> Callable:
        # Factory so that each wrapper captures its OWN original_method. The
        # previous version closed over the loop variable, so every patched
        # method ended up calling the last matching method in the loop.
        @functools.wraps(original_method)
        def memory_method_with_logging(self, *args, **kwargs):
            # Treat the first positional argument (or a `key` kwarg) as the key.
            key = None
            if args:
                key = str(args[0])
            elif 'key' in kwargs:
                key = kwargs['key']

            # For writes, record the value's type name (second positional
            # argument or `value` kwarg).
            value_type = "unknown"
            if op_type in ('store', 'update'):
                value = args[1] if len(args) > 1 else kwargs.get('value')
                if value is not None:
                    value_type = type(value).__name__

            try:
                with TimingContext(f"memory_{op_type}", "memory"):
                    result = original_method(self, *args, **kwargs)
            except Exception as e:
                log_memory_operation(op_type, key, value_type, False, str(e))
                log_error(e, context={"operation": op_type, "key": key})
                raise

            log_memory_operation(op_type, key, value_type, True)
            return result

        return memory_method_with_logging

    for name, method in inspect.getmembers(memory_class, inspect.isfunction):
        if name.startswith('_'):
            continue

        operation = next(
            (op for op, words in operation_keywords if any(w in name for w in words)),
            None,
        )
        if operation is None:
            continue  # Skip methods that don't match known operations

        # Wrapping a staticmethod as an instance method would inject `self`.
        if isinstance(inspect.getattr_static(memory_class, name), staticmethod):
            continue

        setattr(memory_class, name, _make_wrapper(method, operation))

def integrate_graph_logging(graph_module, max_result_chars: int = 200):
    """
    Integrate logging with the LangGraph workflow.

    Monkey-patches the module's public plain functions (those defined in
    ``graph_module`` itself) so each call logs a start step, a completion step
    with a truncated result, and any error. Classes, typing constructs and
    functions imported from other modules are left untouched. Replacing them
    with plain functions would break ``isinstance`` checks and subclassing.

    Args:
        graph_module: The graph module to patch
        max_result_chars: Maximum number of characters of ``str(result)``
            recorded in the completion log; longer results are cut and
            suffixed with "...". Applies to every result type, not only
            strings.
    """
    module_name = getattr(graph_module, "__name__", None)

    def _summarize(value) -> str:
        # Bound log size for any result type (dict states included).
        text = str(value)
        if len(text) > max_result_chars:
            return text[:max_result_chars] + "..."
        return text

    def _make_wrapper(original_func, func_name):
        # Factory so that each wrapper captures its OWN original_func. The
        # previous version closed over the loop variable, so every patched
        # function ended up calling the last callable in the loop.
        @functools.wraps(original_func)
        def graph_func_with_logging(*args, **kwargs):
            step_inputs = kwargs if kwargs else {"args": str(args)}
            log_workflow_step(func_name, f"Executing graph step: {func_name}",
                              inputs=step_inputs)

            with TimingContext(f"graph_{func_name}", "graph"):
                try:
                    result = original_func(*args, **kwargs)
                except Exception as e:
                    log_error(e, context={"step": func_name, "inputs": step_inputs})
                    raise

                log_workflow_step(func_name, f"Completed graph step: {func_name}",
                                  outputs={"result": _summarize(result)})
            return result

        return graph_func_with_logging

    for name in dir(graph_module):
        if name.startswith('_'):
            continue

        attr = getattr(graph_module, name)
        if not inspect.isfunction(attr) or attr.__module__ != module_name:
            continue

        setattr(graph_module, name, _make_wrapper(attr, name))

def setup_logging_integration(verbose: bool = True):
    """
    Set up logging integration for all Gaia components.

    Initializes the logging framework, then patches the agent, tool registry
    and graph workflow (required). When they can be imported, it also patches
    the OpenAI, Serper and Perplexity API clients and the Supabase memory class.

    Args:
        verbose: Forwarded to ``initialize_logging``. Defaults to True, the
            value previously hard-coded here.

    Returns:
        True if integration succeeded. False if a required component could
        not be imported or patched; the error and its traceback are printed.
    """
    initialize_logging(verbose=verbose)

    try:
        # Required components are imported inside the try, so a missing module
        # is reported through the False return value like any other failure
        # instead of escaping as an uncaught ImportError.
        from agent.agent import GAIAAgent
        from agent.tool_registry import ToolRegistry
        from agent import graph

        integrate_agent_logging(GAIAAgent)
        integrate_tool_registry_logging(ToolRegistry)
        integrate_graph_logging(graph)

        # Optional API clients: silently skipped when not importable.
        try:
            from langchain.chat_models import ChatOpenAI
            patch_api_client(ChatOpenAI, "OpenAI")
        except ImportError:
            pass

        try:
            from tools.web_tools import SerperSearchTool
            patch_api_client(SerperSearchTool, "Serper")
        except ImportError:
            pass

        try:
            from tools.perplexity_tool import PerplexityTool
            patch_api_client(PerplexityTool, "Perplexity")
        except ImportError:
            pass

        # Optional memory backend.
        try:
            from memory.supabase_memory import SupabaseMemory
            integrate_memory_logging(SupabaseMemory)
        except ImportError:
            pass

        return True
    except Exception as e:
        # print rather than the logging framework, which may be what failed.
        print(f"Error setting up logging integration: {str(e)}")
        traceback.print_exc()
        return False

if __name__ == "__main__":
    # Smoke-test the logging integration and report the actual outcome:
    # setup_logging_integration returns False (after printing the error)
    # when it fails, so success must not be announced unconditionally.
    if setup_logging_integration():
        print("Logging integration set up successfully")
    else:
        print("Logging integration setup failed")
        raise SystemExit(1)