File size: 9,443 Bytes
3e56848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
from __future__ import annotations

import ast
import uuid
import time
from pathlib import Path
from collections import deque
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ContextManager, cast

from aworld.trace.base import AttributeValueType
from aworld.trace.constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY

if TYPE_CHECKING:
    from .context_manager import TraceManager
    from .auto_trace import not_auto_trace


def compile_source(
        tree: ast.AST, filename: str, module_name: str, trace_manager: TraceManager, min_duration_ns: int
) -> Callable[[dict[str, Any]], None]:
    """Rewrite a module's AST for auto-tracing and compile the result.

    Each traceable function body in ``tree`` is wrapped in ``with <factories>[i]():``, where
    ``<factories>`` is a list injected into the module globals under a random
    ``aworld_<uuid hex>`` name and ``i`` is a constant index specific to that function.

    Returns:
        A callable that takes the module's globals dict, injects the factories list into it,
        and then executes the compiled module code in that namespace.
    """
    # Random global name so the injected list effectively cannot clash with module names.
    factories_global = f'aworld_{uuid.uuid4().hex}'
    factories: list[Callable[[], ContextManager[Any]]] = []

    rewritten = rewrite_ast(
        tree,
        filename,
        factories_global,
        module_name,
        trace_manager,
        factories,
        min_duration_ns,
    )
    assert isinstance(rewritten, ast.Module)  # narrows the type for type checkers
    # dont_inherit=True keeps this file's `from __future__` flags out of the compiled module.
    code = compile(rewritten, filename, 'exec', dont_inherit=True)

    def execute(globs: dict[str, Any]):
        globs[factories_global] = factories
        exec(code, globs, globs)

    return execute


def rewrite_ast(
        tree: ast.AST,
        filename: str,
        context_factories_var_name: str,
        module_name: str,
        trace_manager: TraceManager,
        context_factories: list[Callable[[], ContextManager[Any]]],
        min_duration_ns: int,
) -> ast.AST:
    """Run `AutoTraceTransformer` over ``tree`` and return the transformed tree.

    The transformer appends one context factory to ``context_factories`` for every function it
    wraps. The generated code reaches that list through the global named
    ``context_factories_var_name``.
    """
    return AutoTraceTransformer(
        context_factories_var_name,
        filename,
        module_name,
        trace_manager,
        context_factories,
        min_duration_ns,
    ).visit(tree)


class AutoTraceTransformer(ast.NodeTransformer):
    """Trace all encountered functions except those explicitly marked with `@not_auto_trace`.

    Each traced function body is wrapped in ``with <context_factories_var_name>[index]():``
    and a matching context factory is appended to ``context_factories``.
    """

    # Decorator name that opts a function or class out of auto-tracing. It is matched against
    # the source text: `not_auto_trace` itself is only imported under TYPE_CHECKING in this
    # module, so it cannot be referenced at runtime.
    _NOT_AUTO_TRACE_NAME = 'not_auto_trace'

    def __init__(
            self,
            context_factories_var_name: str,
            filename: str,
            module_name: str,
            trace_manager: TraceManager,
            context_factories: list[Callable[[], ContextManager[Any]]],
            min_duration_ns: int,
    ):
        self._context_factories_var_name = context_factories_var_name
        self._filename = filename
        self._module_name = module_name
        self._trace_manager = trace_manager
        self._context_factories = context_factories
        self._min_duration_ns = min_duration_ns
        # Names of the enclosing classes/functions, used to build qualified names.
        self._qualname_stack: list[str] = []

    def visit_ClassDef(self, node: ast.ClassDef):
        """Visit a class definition and rewrite its methods, unless the class opts out."""

        if self.check_not_auto_trace(node):
            return node

        self._qualname_stack.append(node.name)
        node = cast(ast.ClassDef, self.generic_visit(node))
        self._qualname_stack.pop()
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
        """Visit a (sync or async) function definition and rewrite it."""

        if self.check_not_auto_trace(node):
            return node

        self._qualname_stack.append(node.name)
        qualname = '.'.join(self._qualname_stack)
        # Nested definitions are named `<outer>.<locals>.<inner>`, like Python's __qualname__.
        self._qualname_stack.append('<locals>')
        self.generic_visit(node)
        self._qualname_stack.pop()  # <locals>
        self._qualname_stack.pop()  # node.name
        return self.rewrite_function(node, qualname)

    # Without this alias NodeTransformer would visit async functions generically: they would
    # not be traced, and functions nested in them would miss the `<name>.<locals>` qualname part.
    visit_AsyncFunctionDef = visit_FunctionDef

    def check_not_auto_trace(self, node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef) -> bool:
        """Return True if the node is decorated with a bare-name `@not_auto_trace`.

        Only the plain-name form is recognized; attribute forms such as
        `@some_module.not_auto_trace` are not.
        """
        return any(
            isinstance(decorator, ast.Name) and decorator.id == self._NOT_AUTO_TRACE_NAME
            for decorator in node.decorator_list
        )

    def rewrite_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.AST:
        """Wrap the function body in a tracing `with` block and return the function node.

        Generators (any `yield`) and trivial bodies (empty, `pass`, or a lone constant such as
        `...`) are returned unchanged. A leading docstring stays outside the `with` block so it
        remains the function's docstring.
        """

        if has_yield(node):
            return node

        body = node.body.copy()
        new_body: list[ast.stmt] = []
        if (
                body
                and isinstance(body[0], ast.Expr)
                and isinstance(body[0].value, ast.Constant)
                and isinstance(body[0].value.value, str)
        ):
            new_body.append(body.pop(0))

        if not body or (
                len(body) == 1
                and (
                        isinstance(body[0], ast.Pass)
                        or (isinstance(body[0], ast.Expr) and isinstance(body[0].value, ast.Constant))
                )
        ):
            return node

        span = ast.With(
            items=[
                ast.withitem(
                    context_expr=self.trace_context_method_call_node(node, qualname),
                )
            ],
            body=body,
            # The function's own type comment describes its signature, not this `with`.
            type_comment=None,
        )
        new_body.append(span)

        # Replace only the body. Rebuilding the node from a fixed field list would drop
        # fields such as Python 3.12+ `type_params` (e.g. `def f[T](x: T)`).
        node.body = new_body
        return ast.fix_missing_locations(node)

    def trace_context_method_call_node(self, node: ast.FunctionDef | ast.AsyncFunctionDef, qualname: str) -> ast.Call:
        """Register a context factory for `node` and return the AST for `context_factories[index]()`.

        If `min_duration_ns` > 0, the registered factory is a cheap timer. The first call that
        takes at least `min_duration_ns` replaces it with the real span factory, so only later
        calls are traced.
        """

        index = len(self._context_factories)
        span_factory = partial(
            self._trace_manager._create_auto_span,  # type: ignore
            *self.build_create_auto_span_args(qualname, node.lineno),
        )
        if self._min_duration_ns > 0:

            # Monotonic clock: wall-clock time (time.time_ns) can jump when the system clock is
            # adjusted, which would corrupt the duration measurement.
            timer = time.perf_counter_ns
            min_duration = self._min_duration_ns
            context_factories = self._context_factories

            # This needs to be as fast as possible since it's the cost of auto-tracing a function
            # that never actually gets instrumented because its calls are all faster than `min_duration`.
            class MeasureTime:
                __slots__ = ('start',)

                def __enter__(_self):
                    _self.start = timer()

                def __exit__(_self, *_):
                    # The first call reaching min_duration is not traced itself; it swaps in the
                    # span factory so that subsequent calls are traced.
                    if timer() - _self.start >= min_duration:
                        context_factories[index] = span_factory

            context_factories.append(MeasureTime)
        else:
            self._context_factories.append(span_factory)

        # This node means:
        #   context_factories[index]()
        # where `context_factories` is a global variable with the name `self._context_factories_var_name`
        # pointing to the `self._context_factories` list.
        return ast.Call(
            func=ast.Subscript(
                value=ast.Name(id=self._context_factories_var_name, ctx=ast.Load()),
                # ast.Index is needed on Python 3.8; on 3.9+ it just returns the wrapped value.
                slice=ast.Index(value=ast.Constant(value=index)),  # type: ignore
                ctx=ast.Load(),
            ),
            args=[],
            keywords=[],
        )

    def build_create_auto_span_args(self, qualname: str, lineno: int) -> tuple[str, dict[str, AttributeValueType]]:
        """Build the positional arguments `(span_name, attributes)` for `_create_auto_span`."""

        stack_info = {
            'code.filepath': get_filepath(self._filename),
            'code.lineno': lineno,
            'code.function': qualname,
        }
        attributes: dict[str, AttributeValueType] = {**stack_info}  # type: ignore

        msg_template = f'Calling {self._module_name}.{qualname}'
        attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] = msg_template

        span_name = msg_template

        return span_name, attributes


def has_yield(node: ast.AST) -> bool:
    """Return True if `node` contains a `yield` or `yield from` expression, else False.

    Nested function definitions and lambdas are not searched, because a `yield` inside them
    belongs to that inner scope rather than to `node`.
    """

    queue = deque([node])
    while queue:
        current = queue.popleft()
        for child in ast.iter_child_nodes(current):
            if isinstance(child, (ast.Yield, ast.YieldFrom)):
                return True
            if not isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
                queue.append(child)
    return False


def get_filepath(file: str) -> str:
    """Return `file` as a path string, relative to the current working directory when possible.

    An absolute path located under the resolved CWD is made relative to it. Relative paths,
    and absolute paths outside the CWD, are returned as-is (normalized by `pathlib.Path`).
    """

    path = Path(file)
    if path.is_absolute():
        try:
            path = path.relative_to(Path('.').resolve())
        except ValueError:  # pragma: no cover
            # `file` is not located under the CWD; keep it absolute.
            pass
    return str(path)