Spaces:
Sleeping
Sleeping
File size: 6,914 Bytes
3e56848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import ast
import re
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder
from importlib.machinery import ModuleSpec
from importlib.util import spec_from_loader
from types import ModuleType
from typing import TYPE_CHECKING, Sequence, Union, Callable, Iterator, TypeVar, Any, cast
from aworld.trace.base import log_trace_error
from .rewrite_ast import compile_source
if TYPE_CHECKING:
from .context_manager import TraceManager
class AutoTraceModule:
    """A module being imported that may need to be traced automatically."""

    def __init__(self, module_name: str) -> None:
        # Fully qualified absolute name of the module being imported.
        self._module_name = module_name

    def need_auto_trace(self, prefix: Union[str, Sequence[str]]) -> bool:
        """Return True if this module matches any of the given module specs.

        Args:
            prefix: A single module spec or a sequence of them. Each spec is turned into a
                regex by ``get_module_pattern``. A plain dotted name matches that module and
                its submodules; any other string is used as a regex matched at the start of
                the module name.

        Returns:
            True if at least one spec matches. An empty sequence matches nothing.
        """
        if isinstance(prefix, str):
            prefix = (prefix,)
        patterns = [get_module_pattern(p) for p in prefix]
        if not patterns:
            # Joining zero patterns would give the empty regex, which matches every module.
            return False
        return bool(re.match('|'.join(patterns), self._module_name))
class TraceImportFinder(MetaPathFinder):
    """A `MetaPathFinder` that swaps in an `AutoTraceLoader` for modules selected by a filter.

    The work of locating a module (its source text and filename) is delegated to the
    other finders on `sys.meta_path`. This finder then parses that source and compiles
    an instrumented version of it via `compile_source`.
    """

    def __init__(self, trace_manager: "TraceManager", module_funcs: Callable[[AutoTraceModule], bool],
                 min_duration_ns: int) -> None:
        # Manager passed to `compile_source` for the instrumented code.
        self._trace_manager = trace_manager
        # Predicate deciding, from the module name alone, whether a module is auto-traced.
        self._modules_filter = module_funcs
        # Forwarded to `compile_source`. Presumably a minimum span duration; confirm in rewrite_ast.
        self._min_duration_ns = min_duration_ns

    def _find_plain_specs(
        self,
        fullname: str,
        path: Union[Sequence[str], None] = None,
        target: Union[ModuleType, None] = None,
    ) -> Iterator[ModuleSpec]:
        """Yield module specs returned by the other finders on `sys.meta_path`.

        Finders of this same type are skipped to avoid infinite recursion. Finders that
        raise (including ones lacking `find_spec`) are ignored.
        """
        for finder in sys.meta_path:
            if isinstance(finder, TraceImportFinder):
                continue
            try:
                plain_spec = finder.find_spec(fullname, path, target)
            except Exception:  # pragma: no cover
                continue
            if plain_spec:
                yield plain_spec

    def find_spec(
        self,
        fullname: str,
        path: Union[Sequence[str], None],
        target: Union[ModuleType, None] = None,
    ) -> Union[ModuleSpec, None]:
        """Return a spec that loads an auto-traced version of `fullname`, or None.

        Returning None tells the import system to fall through to the next finder on
        `sys.meta_path`. That happens when the module is not selected by the filter, when
        no other finder can supply its source, or when instrumenting it fails.
        """
        # The filter depends only on the module name. Check it before consulting the other
        # finders, so untraced modules never have their source read.
        if not self._modules_filter(AutoTraceModule(fullname)):
            return None

        for plain_spec in self._find_plain_specs(fullname, path, target):
            # Loaders that cannot provide source text are skipped.
            get_source = getattr(plain_spec.loader, 'get_source', None)
            if not callable(get_source):
                continue
            try:
                source = cast(str, get_source(fullname))
            except Exception:
                continue
            if not source:
                continue

            filename = plain_spec.origin
            if not filename:
                try:
                    filename = cast('str | None', plain_spec.loader.get_filename(fullname))
                except Exception:
                    pass
            filename = filename or f'<{fullname}>'

            try:
                tree = ast.parse(source)
            except Exception:
                # Invalid source code; try the spec from the next finder.
                continue

            try:
                execute = compile_source(tree, filename, fullname, self._trace_manager, self._min_duration_ns)
            except Exception:  # pragma: no cover
                log_trace_error()
                return None

            return spec_from_loader(fullname, AutoTraceLoader(plain_spec, execute))

        return None
class AutoTraceLoader(Loader):
    """A `Loader` that executes the auto-traced (AST-rewritten) code of a module.

    The spec originally found by the regular finders is kept so that `get_filename` and
    `is_package` can be forwarded to its loader.
    """

    def __init__(self, plain_spec: ModuleSpec, execute: Callable[[dict[str, Any]], None]) -> None:
        # Spec produced by the regular finders; its loader backs `get_filename`/`is_package`.
        self._plain_spec = plain_spec
        # Runs the rewritten module code inside the namespace dict it is given.
        self._execute = execute

    def exec_module(self, module: ModuleType):
        """Execute the modified AST of the module's source in the module's namespace."""
        self._execute(module.__dict__)

    def create_module(self, spec: ModuleSpec):
        """Return None so the import system creates the module object itself."""
        return None

    def get_code(self, _name: str):
        """Return code that runs the traced module in the namespace it is executed in.

        `python -m` uses the `runpy` module, which calls this method instead of going
        through `exec_module`. runpy runs the returned code with `__loader__` set to the
        spec's loader (this object), so the code just delegates to `self._execute`.
        """
        source = '__loader__._execute(globals())'
        return compile(source, '<string>', 'exec', dont_inherit=True)

    def __getattr__(self, item: str):
        """Forward `get_filename` and `is_package` to the plain spec's loader if it has them.

        `importlib.util.spec_from_loader` probes for these to fill in the module's origin
        and package search locations.
        """
        if item in {'get_filename', 'is_package'}:
            return getattr(self._plain_spec.loader, item)
        raise AttributeError(item)
def convert_to_modules_func(modules: Sequence[str]) -> Callable[[AutoTraceModule], bool]:
    """Build a module filter from a sequence of module names.

    The returned predicate takes an `AutoTraceModule` and returns the result of calling
    its `need_auto_trace` method with ``modules``.
    """
    def matches(module):
        return module.need_auto_trace(modules)

    return matches
def get_module_pattern(module: str):
    """Translate a module spec into a regular expression for use with `re.match`.

    A spec made only of word characters and dots is treated as a literal dotted module
    name. It is escaped and matches that module or any of its submodules, so ``'app'``
    matches ``'app'`` and ``'app.models'`` but not ``'apple'``. Any other spec is assumed
    to already be a regular expression and is returned unchanged.
    """
    is_plain_name = re.match(r'[\w.]+$', module, re.UNICODE) is not None
    if is_plain_name:
        escaped = re.escape(module)
        return rf'{escaped}($|\.)'
    return module
def install_auto_tracing(trace_manager: "TraceManager",
                         modules: Union[Sequence[str],
                                        Callable[[AutoTraceModule], bool]],
                         min_duration_seconds: float
                         ) -> None:
    """Install an import hook that auto-traces matching modules imported from now on.

    Args:
        trace_manager: Manager from which a dedicated ``'auto_tracing'`` manager is derived
            via ``new_manager``.
        modules: Either a sequence of module names or patterns, or a predicate that
            receives an ``AutoTraceModule``.
        min_duration_seconds: Threshold, converted to whole nanoseconds and passed to the
            installed finder.

    Raises:
        TypeError: If ``modules`` is neither a sequence nor a callable.

    Matching modules that were imported before this call are not instrumented; a warning
    is emitted for each of them.
    """
    module_funcs = convert_to_modules_func(modules) if isinstance(modules, Sequence) else modules
    if not callable(module_funcs):
        raise TypeError('modules must be a list of strings or a callable')

    # Modules already in sys.modules were loaded without the hook and will stay untraced.
    for loaded in list(sys.modules.values()):
        try:
            candidate = AutoTraceModule(loaded.__name__)
        except Exception:
            continue  # e.g. `None` placeholder entries, which have no __name__
        if not module_funcs(candidate):
            continue
        warnings.warn(
            f'The module {loaded.__name__!r} matches modules to trace, but it has already been imported. '
            f'Call `auto_tracing` earlier',
            stacklevel=2,
        )

    threshold_ns = int(min_duration_seconds * 1_000_000_000)
    scoped_manager = trace_manager.new_manager('auto_tracing')
    # Insert at the front so this finder is consulted before the regular finders.
    sys.meta_path.insert(0, TraceImportFinder(scoped_manager, module_funcs, threshold_ns))
T = TypeVar('T')


def not_auto_trace(x: T) -> T:
    """Mark a function or class so that `auto_tracing` leaves it uninstrumented.

    At runtime this is an identity decorator: the argument is handed back unchanged.
    NOTE(review): the exclusion itself is presumably done by the AST rewriter
    recognising this decorator; confirm in `.rewrite_ast`.
    """
    unchanged = x
    return unchanged
|