Spaces:
Sleeping
Sleeping
import ast
import re
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder
from importlib.machinery import ModuleSpec
from importlib.util import spec_from_loader
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, Sequence, TypeVar, Union, cast

from aworld.trace.base import log_trace_error

from .rewrite_ast import compile_source

if TYPE_CHECKING:
    from .context_manager import TraceManager
class AutoTraceModule:
    """A module being imported that may need to be traced automatically.

    Instances are handed to the module filter (see ``TraceImportFinder`` and
    ``install_auto_tracing``) so it can decide, module by module, whether the
    module's source should be rewritten with tracing.
    """

    def __init__(self, module_name: str) -> None:
        # Fully qualified absolute name of the module being imported.
        self._module_name = module_name

    def need_auto_trace(self, prefix: Union[str, Sequence[str]]) -> bool:
        """Return whether this module matches any of the given names/patterns.

        Each entry is converted with ``get_module_pattern`` and matched against
        the start of the module name with ``re.match``.

        Args:
            prefix: A single module name/pattern, or a sequence of them.

        Returns:
            True if at least one entry matches; False otherwise, including when
            ``prefix`` is an empty sequence.
        """
        if isinstance(prefix, str):
            prefix = (prefix,)
        # Match each pattern on its own. Joining them with '|' turns an empty
        # sequence into the empty regex, which matches every module name.
        return any(re.match(get_module_pattern(p), self._module_name) for p in prefix)
class TraceImportFinder(MetaPathFinder):
    """Meta path finder that substitutes an auto-tracing loader for selected modules.

    The actual lookup (finding the source code and filename) is delegated to the
    other finders on ``sys.meta_path``. When the module filter accepts a module,
    a spec backed by ``AutoTraceLoader`` is returned, so the module runs source
    rewritten by ``compile_source``.

    Args:
        trace_manager: Passed through to ``compile_source`` for every traced module.
        module_funcs: Filter that receives an ``AutoTraceModule`` and returns
            whether that module should be traced.
        min_duration_ns: Passed through to ``compile_source`` (an integer number
            of nanoseconds, as computed by ``install_auto_tracing``).
    """

    def __init__(self, trace_manager: "TraceManager",
                 module_funcs: Callable[["AutoTraceModule"], bool],
                 min_duration_ns: int) -> None:
        self._trace_manager = trace_manager
        self._modules_filter = module_funcs
        self._min_duration_ns = min_duration_ns

    def _find_plain_specs(
        self,
        fullname: str,
        path: Optional[Sequence[str]] = None,
        target: Optional[ModuleType] = None,
    ) -> Iterator[ModuleSpec]:
        """Yield the non-empty specs that the other finders on ``sys.meta_path`` return."""
        for finder in sys.meta_path:
            # Skip this finder or any like it to avoid infinite recursion.
            if isinstance(finder, TraceImportFinder):
                continue
            try:
                plain_spec = finder.find_spec(fullname, path, target)
            except Exception:  # pragma: no cover
                # A failing finder (or one without find_spec) must not break the import.
                continue
            if plain_spec:
                yield plain_spec

    def find_spec(
        self,
        fullname: str,
        path: Optional[Sequence[str]],
        target: Optional[ModuleType] = None,
    ) -> Optional[ModuleSpec]:
        """Return an auto-tracing spec for ``fullname``, or None to defer to the next finder.

        Called by the import system. Returns None when no plain spec provides
        parseable source, when the module filter rejects the module, or when
        rewriting the source fails (the failure is logged).
        """
        for plain_spec in self._find_plain_specs(fullname, path, target):
            # Only loaders that can hand back the source code can be rewritten.
            get_source = getattr(plain_spec.loader, 'get_source', None)
            if not callable(get_source):
                continue
            try:
                source = cast(str, get_source(fullname))
            except Exception:
                continue
            if not source:
                continue

            filename = plain_spec.origin
            if not filename:
                try:
                    filename = cast(Optional[str], plain_spec.loader.get_filename(fullname))
                except Exception:
                    pass
            filename = filename or f'<{fullname}>'

            if not self._modules_filter(AutoTraceModule(fullname)):
                # Not selected: let the import system use the next meta path finder.
                return None

            try:
                tree = ast.parse(source)
            except Exception:
                # Invalid source code. Try another one.
                continue

            try:
                execute = compile_source(tree, filename, fullname, self._trace_manager, self._min_duration_ns)
            except Exception:  # pragma: no cover
                # Rewriting failed unexpectedly: log it and fall back to a normal import.
                log_trace_error()
                return None

            loader = AutoTraceLoader(plain_spec, execute)
            return spec_from_loader(fullname, loader)
        return None
class AutoTraceLoader(Loader):
    """Loader produced by ``TraceImportFinder`` that runs a module's rewritten code.

    ``get_filename`` and ``is_package`` are not implemented here. They are
    forwarded to the loader of the original ("plain") spec, so that
    ``spec_from_loader`` can fill in the origin and package search locations.

    Args:
        plain_spec: Spec returned by another meta path finder for the same module.
        execute: Callable that executes the rewritten module code in the given
            globals dict.
    """

    def __init__(self, plain_spec: ModuleSpec, execute: Callable[[dict[str, Any]], None]) -> None:
        self._plain_spec = plain_spec
        self._execute = execute

    def exec_module(self, module: ModuleType) -> None:
        """Execute the rewritten module code in the module's namespace."""
        self._execute(module.__dict__)

    def create_module(self, spec: ModuleSpec) -> None:
        # Returning None tells the import system to use its default module creation.
        return None

    def get_code(self, _name: str):
        """Return a code object that runs the rewritten module in the current globals.

        ``python -m`` uses ``runpy``, which calls this method instead of going
        through ``exec_module``. Inside the returned code ``__loader__`` is this
        loader, so the code just delegates to the stored execute callable.
        """
        # The callable is stored as `_execute`; a bare `execute` attribute does not exist.
        source = '__loader__._execute(globals())'
        return compile(source, '<string>', 'exec', dont_inherit=True)

    def __getattr__(self, item: str):
        """Forward ``get_filename``/``is_package`` to the plain spec's loader, if it has them."""
        if item in {'get_filename', 'is_package'}:
            # Must use the real attribute name (`_plain_spec`). Otherwise these lookups
            # raise AttributeError and `spec_from_loader` loses origin/package info.
            return getattr(self._plain_spec.loader, item)
        raise AttributeError(item)
def convert_to_modules_func(modules: Sequence[str]) -> Callable[[AutoTraceModule], bool]:
    """Build a module filter from a sequence of module names or patterns.

    The returned callable takes an ``AutoTraceModule`` and reports whether it
    matches any entry of ``modules``, by delegating to its ``need_auto_trace``.
    """

    def _matches(module):
        return module.need_auto_trace(modules)

    return _matches
def get_module_pattern(module: str):
    """Turn a module name (or a user regex) into the regex used for matching.

    A plain dotted name made only of word characters and dots is escaped and
    matches that module or any of its submodules. Any other string is taken to
    be a regular expression already and is returned unchanged.
    """
    is_plain_name = re.match(r'[\w.]+$', module, re.UNICODE) is not None
    if is_plain_name:
        # Match the exact module, or the module followed by a submodule separator.
        return rf'{re.escape(module)}($|\.)'
    return module
def install_auto_tracing(trace_manager: "TraceManager",
                         modules: Union[Sequence[str],
                                        Callable[["AutoTraceModule"], bool]],
                         min_duration_seconds: float
                         ) -> None:
    """Install an import hook that auto-traces matching modules imported from now on.

    A ``TraceImportFinder`` is inserted at the front of ``sys.meta_path``, so
    modules imported afterwards that pass the filter are loaded from source
    rewritten by ``compile_source``. Modules that are already imported are not
    rewritten, so a warning is emitted for each one that matches the filter.

    Args:
        trace_manager: The finder uses the manager returned by
            ``trace_manager.new_manager('auto_tracing')``.
        modules: Either a sequence of module names/patterns (see
            ``get_module_pattern``), or a callable that takes an
            ``AutoTraceModule`` and returns whether it should be traced.
        min_duration_seconds: Truncated to integer nanoseconds and passed to
            the finder, which forwards it to ``compile_source``.

    Raises:
        TypeError: If ``modules`` is neither a sequence nor a callable.
    """
    if isinstance(modules, Sequence):
        module_funcs = convert_to_modules_func(modules)
    else:
        module_funcs = modules
    if not callable(module_funcs):
        raise TypeError('modules must be a list of strings or a callable')

    # Already-imported modules bypass the import hook, so they cannot be traced.
    for module in list(sys.modules.values()):
        try:
            auto_trace_module = AutoTraceModule(module.__name__)
        except Exception:
            # e.g. a sys.modules entry without a __name__ attribute.
            continue
        if module_funcs(auto_trace_module):
            warnings.warn(f'The module {module.__name__!r} matches modules to trace, but it has already been imported. '
                          f'Call `auto_tracing` earlier',
                          stacklevel=2,
                          )

    min_duration_ns = int(min_duration_seconds * 1_000_000_000)
    trace_manager = trace_manager.new_manager('auto_tracing')
    finder = TraceImportFinder(trace_manager, module_funcs, min_duration_ns)
    sys.meta_path.insert(0, finder)
T = TypeVar('T')


def not_auto_trace(x: T) -> T:
    """Mark a function or class so that `auto_tracing` does not trace it.

    At runtime this is a pure marker and hands ``x`` back unchanged.
    NOTE(review): the opt-out presumably works because the rewriter in
    ``rewrite_ast`` recognizes this decorator by name; confirm there.
    """
    unchanged = x
    return unchanged