from __future__ import annotations import ast import uuid import time from pathlib import Path from collections import deque from functools import partial from typing import TYPE_CHECKING, Any, Callable, ContextManager, cast from aworld.trace.base import AttributeValueType from aworld.trace.constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY if TYPE_CHECKING: from .context_manager import TraceManager from .auto_trace import not_auto_trace def compile_source( tree: ast.AST, filename: str, module_name: str, trace_manager: TraceManager, min_duration_ns: int ) -> Callable[[dict[str, Any]], None]: """Compile a modified AST of the module's source code in the module's namespace. Returns a function which accepts module globals and executes the compiled code. The modified AST wraps the body of every function definition in `with context_factories[index]():`. `context_factories` is added to the module's namespace as `aworld_`. `index` is a different constant number for each function definition. """ context_factories_var_name = f'aworld_{uuid.uuid4().hex}' # The variable name for storing context_factors in the module's namespace. context_factories: list[Callable[[], ContextManager[Any]]] = [] tree = rewrite_ast(tree, filename, context_factories_var_name, module_name, trace_manager, context_factories, min_duration_ns) assert isinstance(tree, ast.Module) # for type checking # dont_inherit=True is necessary to prevent the module from inheriting the __future__ import from this module. code = compile(tree, filename, 'exec', dont_inherit=True) def execute(globs: dict[str, Any]): globs[context_factories_var_name] = context_factories exec(code, globs, globs) return execute def rewrite_ast( tree: ast.AST, filename: str, context_factories_var_name: str, module_name: str, trace_manager: TraceManager, context_factories: list[Callable[[], ContextManager[Any]]], min_duration_ns: int, ) -> ast.AST: transformer = AutoTraceTransformer( context_factories_var_name, filename, module_name, trace_manager, context_factories, min_duration_ns ) return transformer.visit(tree) class AutoTraceTransformer(ast.NodeTransformer): """Trace all encountered functions except those explicitly marked with `@no_auto_trace`.""" def __init__( self, context_factories_var_name: str, filename: str, module_name: str, trace_manager: TraceManager, context_factories: list[Callable[[], ContextManager[Any]]], min_duration_ns: int, ): self._context_factories_var_name = context_factories_var_name self._filename = filename self._module_name = module_name self._trace_manager = trace_manager self._context_factories = context_factories self._min_duration_ns = min_duration_ns self._qualname_stack: list[str] = [] def visit_ClassDef(self, node: ast.ClassDef): """Visit a class definition and rewrite its methods.""" if self.check_not_auto_trace(node): return node self._qualname_stack.append(node.name) node = cast(ast.ClassDef, self.generic_visit(node)) self._qualname_stack.pop() return node def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: """Visit a function definition and rewrite it.""" if self.check_not_auto_trace(node): return node self._qualname_stack.append(node.name) qualname = '.'.join(self._qualname_stack) self._qualname_stack.append('') self.generic_visit(node) self._qualname_stack.pop() # self._qualname_stack.pop() # node.name return self.rewrite_function(node, qualname) def check_not_auto_trace(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef) -> bool: """Return true if the node has a `@not_auto_trace` decorator.""" return any( ( isinstance(node, ast.Name) and node.id == not_auto_trace.__name__ # or ( # isinstance(node, ast.Attribute) # and node.attr == not_auto_trace.__name__ # and isinstance(node.value, ast.Name) # and node.value.id == xxx.__name__ # ) ) for node in node.decorator_list ) def rewrite_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.AST: """Rewrite a function definition to trace its execution.""" if has_yield(node): return node body = node.body.copy() new_body: list[ast.stmt] = [] if ( body and isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant) and isinstance(body[0].value.value, str) ): new_body.append(body.pop(0)) if not body or ( len(body) == 1 and ( isinstance(body[0], ast.Pass) or (isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant)) ) ): return node span = ast.With( items=[ ast.withitem( context_expr=self.trace_context_method_call_node(node, qualname), ) ], body=body, type_comment=node.type_comment, ) new_body.append(span) return ast.fix_missing_locations( ast.copy_location( type(node)( # type: ignore name=node.name, args=node.args, body=new_body, decorator_list=node.decorator_list, returns=node.returns, type_comment=node.type_comment, ), node, ) ) def trace_context_method_call_node(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.Call: """Return a method call to `context_factories[index]()`.""" index = len(self._context_factories) span_factory = partial( self._trace_manager._create_auto_span, # type: ignore *self.build_create_auto_span_args(qualname, node.lineno), ) if self._min_duration_ns > 0: timer = time.time_ns min_duration = self._min_duration_ns # This needs to be as fast as possible since it's the cost of auto-tracing a function # that never actually gets instrumented because its calls are all faster than `min_duration`. class MeasureTime: __slots__ = 'start' def __enter__(_self): _self.start = timer() def __exit__(_self, *_): # the first call exceeding min_ruration will not be tracked, and subsequent calls will only be tracked if timer() - _self.start >= min_duration: self._context_factories[index] = span_factory self._context_factories.append(MeasureTime) else: self._context_factories.append(span_factory) # This node means: # context_factories[index]() # where `context_factories` is a global variable with the name `self._context_factories_var_name` # pointing to the `self.context_factories` list. return ast.Call( func=ast.Subscript( value=ast.Name(id=self._context_factories_var_name, ctx=ast.Load()), slice=ast.Index(value=ast.Constant(value=index)), # type: ignore ctx=ast.Load(), ), args=[], keywords=[], ) def build_create_auto_span_args(self, qualname: str, lineno: int) -> tuple[str, dict[str, AttributeValueType]]: """Build the arguments for `create_auto_span`.""" stack_info = { 'code.filepath': get_filepath(self._filename), 'code.lineno': lineno, 'code.function': qualname, } attributes: dict[str, AttributeValueType] = {**stack_info} # type: ignore msg_template = f'Calling {self._module_name}.{qualname}' attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] = msg_template span_name = msg_template return span_name, attributes def has_yield(node: ast.AST): """Return true if the node has a yield statement.""" queue = deque([node]) while queue: node = queue.popleft() for child in ast.iter_child_nodes(node): if isinstance(child, (ast.Yield, ast.YieldFrom)): return True if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)): queue.append(child) def get_filepath(file: str): """Return a dict with the filepath attribute.""" path = Path(file) if path.is_absolute(): try: path = path.relative_to(Path('.').resolve()) except ValueError: # pragma: no cover # happens if filename path is not within CWD pass return str(path)