File size: 3,811 Bytes
7a3c0ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import weakref
from typing import Any, Callable

from .asgi import TraceMiddleware
from aworld.trace.instrumentation import Instrumentor
from aworld.trace.base import TraceProvider, get_tracer_provider
from aworld.trace.instrumentation.http_util import (
    get_excluded_urls,
    parse_excluded_urls,
)
from aworld.utils.import_package import import_packages
from aworld.logs.util import logger

import_packages(['fastapi'])  # noqa
import fastapi  # noqa


class _InstrumentedFastAPI(fastapi.FastAPI):
    """FastAPI subclass that attaches ``TraceMiddleware`` to every new app.

    ``FastAPIInstrumentor._instrument`` fills in the class-level settings
    below and then installs this class as ``fastapi.FastAPI``, so apps
    constructed afterwards are traced without any change to user code.
    """
    # Configuration shared by every instance; populated by the instrumentor.
    _tracer_provider: TraceProvider = None
    # NOTE(review): holds whatever get_excluded_urls/parse_excluded_urls
    # return; the list[str] annotation has not been verified.
    _excluded_urls: list[str] = None
    _server_request_hook: Callable = None
    _client_request_hook: Callable = None
    _client_response_hook: Callable = None
    # Live instrumented apps, tracked so uninstrumentation can strip the
    # middleware from them. A WeakSet lets an app be garbage-collected once
    # user code drops it; the previous plain set kept every app alive
    # forever, which also meant a __del__-based removal could never run.
    _instrumented_fastapi_apps = weakref.WeakSet()

    def __init__(self, *args, **kwargs):
        """Create the FastAPI app, then register the tracing middleware.

        All positional and keyword arguments are forwarded unchanged to
        ``fastapi.FastAPI.__init__``.
        """
        super().__init__(*args, **kwargs)

        tracer = self._tracer_provider.get_tracer(
            "aworld.trace.instrumentation.fastapi")

        self.add_middleware(
            TraceMiddleware,
            tracer=tracer,
            excluded_urls=self._excluded_urls,
            server_request_hook=self._server_request_hook,
            client_request_hook=self._client_request_hook,
            client_response_hook=self._client_response_hook,
        )

        # Flag reset to False by FastAPIInstrumentor.uninstrument_app.
        self._is_instrumented_by_trace = True
        _InstrumentedFastAPI._instrumented_fastapi_apps.add(self)


class FastAPIInstrumentor(Instrumentor):
    """Patch ``fastapi.FastAPI`` so that newly created apps are traced.

    Instrumenting replaces ``fastapi.FastAPI`` with ``_InstrumentedFastAPI``.
    Uninstrumenting strips ``TraceMiddleware`` from every tracked app and
    restores the original class.
    """
    # The genuine fastapi.FastAPI captured at instrument time. It lives on
    # the class rather than the instance: instrument_fastapi() instruments
    # through a throwaway instance, so uninstrumenting necessarily happens
    # on a different instance that would otherwise only see None here.
    _original_fastapi = None

    @staticmethod
    def uninstrument_app(app: fastapi.FastAPI):
        """Remove ``TraceMiddleware`` from *app* and rebuild its middleware stack.

        Args:
            app: A FastAPI application. If it has no ``TraceMiddleware``,
                its ``user_middleware`` list is left as it was.
        """
        app.user_middleware = [
            middleware
            for middleware in app.user_middleware
            if middleware.cls is not TraceMiddleware
        ]
        # Rebuild so the filtered middleware list actually takes effect.
        app.middleware_stack = app.build_middleware_stack()
        app._is_instrumented_by_trace = False

    def instrumentation_dependencies(self) -> dict[str, Any]:
        """Return the modules this instrumentation depends on, keyed by name."""
        return {"fastapi": fastapi}

    def _instrument(self, **kwargs):
        """Configure ``_InstrumentedFastAPI`` and install it as ``fastapi.FastAPI``.

        Recognised keyword arguments: ``tracer_provider`` (falls back to
        ``get_tracer_provider()`` when missing or falsy),
        ``server_request_hook``, ``client_request_hook``,
        ``client_response_hook`` and ``excluded_urls`` (when None,
        ``get_excluded_urls("FASTAPI")`` is used; otherwise the value is
        passed through ``parse_excluded_urls``).
        """
        # Capture the real class only once: a repeated call would otherwise
        # record _InstrumentedFastAPI as the "original" and uninstrumenting
        # could never restore the genuine class.
        if fastapi.FastAPI is not _InstrumentedFastAPI:
            FastAPIInstrumentor._original_fastapi = fastapi.FastAPI
        # Without a provider, every _InstrumentedFastAPI() would fail on
        # None.get_tracer(); mirror instrument_fastapi's fallback.
        _InstrumentedFastAPI._tracer_provider = (
            kwargs.get("tracer_provider") or get_tracer_provider()
        )
        _InstrumentedFastAPI._server_request_hook = kwargs.get(
            "server_request_hook"
        )
        _InstrumentedFastAPI._client_request_hook = kwargs.get(
            "client_request_hook"
        )
        _InstrumentedFastAPI._client_response_hook = kwargs.get(
            "client_response_hook"
        )
        excluded_urls = kwargs.get("excluded_urls")
        _InstrumentedFastAPI._excluded_urls = (
            get_excluded_urls("FASTAPI")
            if excluded_urls is None
            else parse_excluded_urls(excluded_urls)
        )
        fastapi.FastAPI = _InstrumentedFastAPI

    def _uninstrument(self, **kwargs):
        """Strip tracing from every tracked app and restore ``fastapi.FastAPI``."""
        # Iterate over a snapshot so the tracking collection is never
        # mutated while it is being walked.
        for app in list(_InstrumentedFastAPI._instrumented_fastapi_apps):
            self.uninstrument_app(app)
        _InstrumentedFastAPI._instrumented_fastapi_apps.clear()
        # Never overwrite fastapi.FastAPI with None when nothing was captured
        # (e.g. uninstrument without a prior instrument).
        original = FastAPIInstrumentor._original_fastapi
        if original is not None:
            fastapi.FastAPI = original
            FastAPIInstrumentor._original_fastapi = None


def instrument_fastapi(excluded_urls: str = None,
                       server_request_hook: Callable = None,
                       client_request_hook: Callable = None,
                       client_response_hook: Callable = None,
                       tracer_provider: TraceProvider = None,
                       **kwargs: Any,
                       ):
    """Enable tracing for FastAPI applications created after this call.

    Args:
        excluded_urls: URL exclusion specification handed to the
            instrumentor; None lets the instrumentor pick its default.
        server_request_hook: Hook handed to ``TraceMiddleware``.
        client_request_hook: Hook handed to ``TraceMiddleware``.
        client_response_hook: Hook handed to ``TraceMiddleware``.
        tracer_provider: Provider used to create the tracer; when falsy,
            ``get_tracer_provider()`` is used instead.
        **kwargs: Additional options passed through to
            ``FastAPIInstrumentor().instrument``. The named arguments above
            take precedence over entries with the same key.
    """
    options = dict(kwargs)
    options["excluded_urls"] = excluded_urls
    options["server_request_hook"] = server_request_hook
    options["client_request_hook"] = client_request_hook
    options["client_response_hook"] = client_response_hook
    options["tracer_provider"] = (
        tracer_provider if tracer_provider else get_tracer_provider()
    )

    instrumentor = FastAPIInstrumentor()
    instrumentor.instrument(**options)
    logger.info("FastAPI instrumented.")