File size: 6,914 Bytes
3e56848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import ast
import re
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder
from importlib.machinery import ModuleSpec
from importlib.util import spec_from_loader
from types import ModuleType
from typing import TYPE_CHECKING, Sequence, Union, Callable, Iterator, TypeVar, Any, cast

from aworld.trace.base import log_trace_error
from .rewrite_ast import compile_source

if TYPE_CHECKING:
    from .context_manager import TraceManager


class AutoTraceModule:
    """A module being imported that should maybe be traced automatically.

    Instances wrap the module's fully qualified name and are what the module
    filter of `TraceImportFinder` receives.
    """

    def __init__(self, module_name: str) -> None:
        # Fully qualified absolute name of the module being imported.
        self._module_name = module_name

    def need_auto_trace(self, prefix: Union[str, Sequence[str]]) -> bool:
        """Return True if the module name matches any of the given prefixes.

        Each entry is converted with `get_module_pattern`: a plain dotted name
        such as ``'app.api'`` matches ``app.api`` and its submodules
        (``app.api.views``) but not ``app.apix``; any other string is used as a
        regular expression matched from the start of the module name.

        Args:
            prefix: A single module name/pattern, or a sequence of them.

        Returns:
            True if at least one entry matches. An empty sequence matches nothing.
        """
        if isinstance(prefix, str):
            prefix = (prefix,)
        if not prefix:
            # Joining zero patterns yields '' and re.match('', name) succeeds for
            # every name, so an empty list would otherwise trace every module.
            return False
        pattern = '|'.join(get_module_pattern(p) for p in prefix)
        return bool(re.match(pattern, self._module_name))


class TraceImportFinder(MetaPathFinder):
    """Meta path finder that returns auto-tracing module specs for selected modules.

    It relies on the other finders on `sys.meta_path` to locate a module and its
    source code. If the module passes the filter and its source parses, it returns
    a spec whose `AutoTraceLoader` runs the code produced by `compile_source`.
    Returning None lets the import system fall back to the next finder.
    """

    def __init__(self, trace_manager: "TraceManager", module_funcs: Callable[[AutoTraceModule], bool],
                 min_duration_ns: int) -> None:
        # Manager handed to compile_source for the instrumented module code.
        self._trace_manager = trace_manager
        # Predicate deciding, from the module name, whether a module is auto-traced.
        self._modules_filter = module_funcs
        # Forwarded to compile_source; presumably a minimum span duration — TODO confirm in rewrite_ast.
        self._min_duration_ns = min_duration_ns

    def _find_plain_specs(
            self, fullname: str, path: Sequence[str] = None, target: ModuleType = None
    ) -> Iterator[ModuleSpec]:
        """Yield, in order, the specs returned by the other finders on `sys.meta_path`.

        Other `TraceImportFinder` instances are skipped to avoid infinite
        recursion; finders that raise or return nothing are ignored.
        """
        for finder in sys.meta_path:
            # Skip this finder or any like it to avoid infinite recursion.
            if isinstance(finder, TraceImportFinder):
                continue

            try:
                plain_spec = finder.find_spec(fullname, path, target)
            except Exception:  # pragma: no cover
                continue

            if plain_spec:
                yield plain_spec

    def find_spec(self, fullname: str, path: Union[Sequence[str], None],
                  target: Union[ModuleType, None] = None) -> Union[ModuleSpec, None]:
        """Return an auto-tracing spec for ``fullname``, or None to defer to other finders.

        Args:
            fullname: Fully qualified name of the module being imported.
            path: Passed through unchanged to the other finders.
            target: Passed through unchanged to the other finders.

        Returns:
            A spec backed by `AutoTraceLoader`, or None when the module is not
            selected by the filter, no finder supplies parsable source, or
            compiling the traced code fails.
        """
        # The filter only looks at the module name, so decide before consulting any
        # finder; otherwise every import would read its source just to be rejected.
        if not self._modules_filter(AutoTraceModule(fullname)):
            return None

        for plain_spec in self._find_plain_specs(fullname, path, target):
            # Only loaders that can hand back source code can be rewritten.
            get_source = getattr(plain_spec.loader, 'get_source', None)
            if not callable(get_source):
                continue
            try:
                source = cast(str, get_source(fullname))
            except Exception:
                continue

            if not source:
                continue

            # Prefer the spec's origin, then the loader's filename, then a placeholder.
            filename = plain_spec.origin
            if not filename:
                try:
                    filename = cast('str | None', plain_spec.loader.get_filename(fullname))
                except Exception:
                    pass
            filename = filename or f'<{fullname}>'

            try:
                tree = ast.parse(source)
            except Exception:
                # Invalid source code. Try another finder; if none works, returning
                # None leaves the module to the standard finders.
                continue

            try:
                execute = compile_source(tree, filename, fullname, self._trace_manager, self._min_duration_ns)
            except Exception:  # pragma: no cover
                # An instrumentation failure must not break the import itself.
                log_trace_error()
                return None

            loader = AutoTraceLoader(plain_spec, execute)
            return spec_from_loader(fullname, loader)

        return None


class AutoTraceLoader(Loader):
    """Loader that executes a module's rewritten (auto-traced) code.

    `get_filename` and `is_package` are forwarded to the loader of the original
    spec. As a result, `spec_from_loader` can still fill in the module's origin
    and, for packages, its submodule search locations.
    """

    def __init__(self, plain_spec: ModuleSpec, execute: Callable[[dict[str, Any]], None]) -> None:
        # Spec produced by the regular finder; its loader backs the forwarded methods.
        self._plain_spec = plain_spec
        # Runs the rewritten module code inside the namespace dict it is given.
        self._execute = execute

    def exec_module(self, module: ModuleType):
        """Execute the modified AST of the module's source in the module's namespace."""
        self._execute(module.__dict__)

    def create_module(self, spec: ModuleSpec):
        """Return None so the import system creates the module object itself."""
        return None

    def get_code(self, _name: str):
        """Return a code object that runs the traced module in the current namespace.

        `python -m` goes through `runpy`, which calls this method instead of the
        normal `exec_module` protocol. runpy executes the code with ``__loader__``
        bound to this object, so the code just hands its globals to ``_execute``.
        """
        source = '__loader__._execute(globals())'
        return compile(source, '<string>', 'exec', dont_inherit=True)

    def __getattr__(self, item: str):
        """Forward `get_filename` / `is_package` to the plain spec's loader (likely a `SourceFileLoader`)."""
        if item in {'get_filename', 'is_package'}:
            # Must use the real attribute name: a misspelled attribute would
            # re-enter __getattr__ and always raise AttributeError.
            return getattr(self._plain_spec.loader, item)
        raise AttributeError(item)


def convert_to_modules_func(modules: Sequence[str]) -> Callable[[AutoTraceModule], bool]:
    """Build a module filter out of a sequence of module names/patterns.

    The returned predicate takes an `AutoTraceModule` and reports whether it
    matches any entry of ``modules`` via `AutoTraceModule.need_auto_trace`.
    ``modules`` is captured by reference, not copied.
    """

    def _matches(module: "AutoTraceModule") -> bool:
        return module.need_auto_trace(modules)

    return _matches


def get_module_pattern(module: str):
    """Return a regex that matches ``module`` and any of its submodules.

    A plain name made only of word characters and dots is escaped and followed
    by "end of string or a dot", so ``foo`` matches ``foo`` and ``foo.bar`` but
    not ``foobar``. Any other string is assumed to already be a regex and is
    returned as-is.
    """
    looks_like_plain_name = re.match(r'[\w.]+$', module, re.UNICODE) is not None
    if looks_like_plain_name:
        return rf'{re.escape(module)}($|\.)'
    return module


def install_auto_tracing(trace_manager: "TraceManager",
                         modules: Union[Sequence[str],
                         Callable[[AutoTraceModule], bool]],
                         min_duration_seconds: float
                         ) -> None:
    """Install an import hook that auto-traces matching modules imported from now on.

    Args:
        trace_manager: Manager from which a dedicated ``'auto_tracing'`` manager
            is derived with ``new_manager``.
        modules: Either a sequence of module names/patterns (see
            `get_module_pattern`), or a predicate taking an `AutoTraceModule`
            and returning whether that module should be traced.
        min_duration_seconds: Converted to integer nanoseconds (truncated by
            ``int``) and passed to the finder.

    Raises:
        TypeError: If ``modules`` is neither a sequence nor a callable.

    A warning is emitted for every already-imported module that matches. Such
    modules are not imported again, so the hook installed here never sees them.
    """
    module_funcs = convert_to_modules_func(modules) if isinstance(modules, Sequence) else modules
    if not callable(module_funcs):
        raise TypeError('modules must be a list of strings or a callable')

    for loaded in list(sys.modules.values()):
        try:
            candidate = AutoTraceModule(loaded.__name__)
        except Exception:
            continue
        if not module_funcs(candidate):
            continue
        warnings.warn(f'The module {loaded.__name__!r} matches modules to trace, but it has already been imported. '
                      f'Call `auto_tracing` earlier',
                      stacklevel=2,
                      )

    min_duration_ns = int(min_duration_seconds * 1_000_000_000)
    scoped_manager = trace_manager.new_manager('auto_tracing')
    # Front of the list, so this finder is consulted before the standard ones.
    sys.meta_path.insert(0, TraceImportFinder(scoped_manager, module_funcs, min_duration_ns))


# Generic type: not_auto_trace hands back exactly the object it receives.
T = TypeVar('T')


def not_auto_trace(x: T) -> T:
    """Decorator marking a function/class that `auto_tracing` should not trace.

    At runtime this is the identity function: ``x`` is returned unchanged.
    NOTE(review): the exclusion presumably happens in `rewrite_ast` by
    recognising this decorator — TODO confirm.
    """
    unchanged = x
    return unchanged