Spaces:
Sleeping
Sleeping
import weakref
from typing import Any, Callable

from .asgi import TraceMiddleware
from aworld.trace.instrumentation import Instrumentor
from aworld.trace.base import TraceProvider, get_tracer_provider
from aworld.trace.instrumentation.http_util import (
    get_excluded_urls,
    parse_excluded_urls,
)
from aworld.utils.import_package import import_packages
from aworld.logs.util import logger

import_packages(['fastapi'])  # noqa
import fastapi  # noqa
class _InstrumentedFastAPI(fastapi.FastAPI):
    """FastAPI subclass that adds ``TraceMiddleware`` to every new application.

    ``FastAPIInstrumentor._instrument`` assigns the class-level settings below
    and then installs this class as ``fastapi.FastAPI``, so each application
    constructed afterwards registers the tracing middleware in ``__init__``.
    """

    # Settings shared by all instances; assigned by
    # FastAPIInstrumentor._instrument before the class is swapped in.
    _tracer_provider: TraceProvider = None
    _excluded_urls: list[str] = None
    _server_request_hook: Callable = None
    _client_request_hook: Callable = None
    _client_response_hook: Callable = None

    # Registry of live instrumented apps, walked by
    # FastAPIInstrumentor._uninstrument. It is a WeakSet on purpose: a plain
    # set would hold a strong reference to every app, so no app could ever be
    # garbage-collected (and __del__ below could never run).
    _instrumented_fastapi_apps = weakref.WeakSet()

    def __init__(self, *args, **kwargs):
        """Build the app, add ``TraceMiddleware`` and register the instance.

        All positional and keyword arguments are forwarded unchanged to
        ``fastapi.FastAPI.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Fall back to the global provider when the class has not been
        # configured, instead of failing with AttributeError on None.
        provider = self._tracer_provider
        if provider is None:
            provider = get_tracer_provider()
        tracer = provider.get_tracer("aworld.trace.instrumentation.fastapi")

        self.add_middleware(
            TraceMiddleware,
            tracer=tracer,
            excluded_urls=self._excluded_urls,
            server_request_hook=self._server_request_hook,
            client_request_hook=self._client_request_hook,
            client_response_hook=self._client_response_hook,
        )
        self._is_instrumented_by_trace = True
        self._instrumented_fastapi_apps.add(self)

    def __del__(self):
        """Drop this app from the registry when it is finalized.

        The WeakSet already forgets dead instances on its own; ``discard`` is
        a harmless no-op when the entry is gone.
        """
        self._instrumented_fastapi_apps.discard(self)
class FastAPIInstrumentor(Instrumentor):
    """Instrumentor that swaps ``fastapi.FastAPI`` for ``_InstrumentedFastAPI``.

    While instrumented, every newly constructed FastAPI application gets
    ``TraceMiddleware`` added. Uninstrumenting removes that middleware from
    the registered applications and restores the original class.
    """

    # Original ``fastapi.FastAPI`` class, remembered so it can be restored.
    # Kept at class level so it survives across instrumentor instances.
    _original_fastapi = None

    @staticmethod
    def uninstrument_app(app: fastapi.FastAPI):
        """Remove ``TraceMiddleware`` from a single application.

        Declared static so that it works both as
        ``FastAPIInstrumentor.uninstrument_app(app)`` and as
        ``self.uninstrument_app(app)``; without the decorator the bound call
        would pass ``self`` as ``app`` and fail with a ``TypeError``.

        Args:
            app: The application to strip the tracing middleware from.
        """
        app.user_middleware = [
            middleware
            for middleware in app.user_middleware
            if middleware.cls is not TraceMiddleware
        ]
        # Rebuild the stack so the removal takes effect on subsequent requests.
        app.middleware_stack = app.build_middleware_stack()
        app._is_instrumented_by_trace = False

    def instrumentation_dependencies(self) -> dict[str, Any]:
        """Return the modules this instrumentor depends on, keyed by name."""
        return {"fastapi": fastapi}

    def _instrument(self, **kwargs):
        """Configure ``_InstrumentedFastAPI`` and install it as ``fastapi.FastAPI``.

        Keyword Args:
            tracer_provider: Provider used to obtain the tracer; the global
                provider from ``get_tracer_provider()`` is used when omitted.
            server_request_hook: Passed through to ``TraceMiddleware``.
            client_request_hook: Passed through to ``TraceMiddleware``.
            client_response_hook: Passed through to ``TraceMiddleware``.
            excluded_urls: Passed to ``parse_excluded_urls``; when ``None`` the
                value of ``get_excluded_urls("FASTAPI")`` is used instead.
        """
        # Only remember the genuine class: on a repeated call fastapi.FastAPI
        # is already the instrumented subclass and must not overwrite it.
        if fastapi.FastAPI is not _InstrumentedFastAPI:
            FastAPIInstrumentor._original_fastapi = fastapi.FastAPI

        tracer_provider = kwargs.get("tracer_provider")
        if tracer_provider is None:
            tracer_provider = get_tracer_provider()
        _InstrumentedFastAPI._tracer_provider = tracer_provider

        _InstrumentedFastAPI._server_request_hook = kwargs.get(
            "server_request_hook"
        )
        _InstrumentedFastAPI._client_request_hook = kwargs.get(
            "client_request_hook"
        )
        _InstrumentedFastAPI._client_response_hook = kwargs.get(
            "client_response_hook"
        )

        excluded_urls = kwargs.get("excluded_urls")
        _InstrumentedFastAPI._excluded_urls = (
            get_excluded_urls("FASTAPI")
            if excluded_urls is None
            else parse_excluded_urls(excluded_urls)
        )

        fastapi.FastAPI = _InstrumentedFastAPI

    def _uninstrument(self, **kwargs):
        """Strip the middleware from registered apps and restore ``fastapi.FastAPI``."""
        for app in _InstrumentedFastAPI._instrumented_fastapi_apps:
            self.uninstrument_app(app)
        _InstrumentedFastAPI._instrumented_fastapi_apps.clear()

        # Never assign None to fastapi.FastAPI when nothing was instrumented.
        if self._original_fastapi is not None:
            fastapi.FastAPI = self._original_fastapi
def instrument_fastapi(excluded_urls: str = None,
                       server_request_hook: Callable = None,
                       client_request_hook: Callable = None,
                       client_response_hook: Callable = None,
                       tracer_provider: TraceProvider = None,
                       **kwargs: Any,
                       ):
    """Instrument FastAPI through a new ``FastAPIInstrumentor``.

    Args:
        excluded_urls: Forwarded to the instrumentor as ``excluded_urls``.
        server_request_hook: Forwarded as ``server_request_hook``.
        client_request_hook: Forwarded as ``client_request_hook``.
        client_response_hook: Forwarded as ``client_response_hook``.
        tracer_provider: Provider to trace with; when falsy, the result of
            ``get_tracer_provider()`` is used instead.
        **kwargs: Extra options passed along to ``instrument`` untouched.
    """
    provider = tracer_provider or get_tracer_provider()

    # Explicit parameters are layered over any additional keyword options.
    options = {
        **kwargs,
        "excluded_urls": excluded_urls,
        "server_request_hook": server_request_hook,
        "client_request_hook": client_request_hook,
        "client_response_hook": client_response_hook,
        "tracer_provider": provider,
    }

    instrumentor = FastAPIInstrumentor()
    instrumentor.instrument(**options)
    logger.info("FastAPI instrumented.")