Spaces:
Sleeping
Sleeping
import flask | |
import weakref | |
from typing import Any, Callable, Collection | |
from time import time_ns | |
from timeit import default_timer | |
from importlib_metadata import version | |
from packaging import version as package_version | |
from aworld.trace.instrumentation import Instrumentor | |
from aworld.trace.base import Span, TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider | |
from aworld.metrics.metric import MetricType | |
from aworld.metrics.template import MetricTemplate | |
from aworld.logs.util import logger | |
from aworld.trace.instrumentation.http_util import ( | |
collect_request_attributes, | |
url_disabled, | |
get_excluded_urls, | |
parse_excluded_urls, | |
HTTP_ROUTE | |
) | |
from aworld.trace.propagator import get_global_trace_propagator | |
from aworld.metrics.context_manager import MetricContext | |
from aworld.trace.propagator.carrier import ListTupleCarrier, DictCarrier | |
_ENVIRON_STARTTIME_KEY = "aworld-flask.starttime_key" | |
_ENVIRON_SPAN_KEY = "aworld-flask.span_key" | |
_ENVIRON_REQCTX_REF_KEY = "aworld-flask.reqctx_ref_key" | |
flask_version = version("flask") | |
if package_version.parse(flask_version) >= package_version.parse("2.2.0"): | |
def _request_ctx_ref() -> weakref.ReferenceType: | |
return weakref.ref(flask.globals.request_ctx._get_current_object()) | |
else: | |
def _request_ctx_ref() -> weakref.ReferenceType: | |
return weakref.ref(flask._request_ctx_stack.top) | |
def _rewrapped_app( | |
wsgi_app, | |
active_requests_counter, | |
duration_histogram, | |
response_hook=None, | |
excluded_urls=None, | |
): | |
def _wrapped_app(wrapped_app_environ, start_response): | |
# We want to measure the time for route matching, etc. | |
# In theory, we could start the span here and use | |
# update_name later but that API is "highly discouraged" so | |
# we better avoid it. | |
wrapped_app_environ[_ENVIRON_STARTTIME_KEY] = time_ns() | |
start = default_timer() | |
attributes = collect_request_attributes(wrapped_app_environ) | |
if MetricContext.metric_initialized(): | |
MetricContext.inc(active_requests_counter, 1, attributes) | |
request_route = None | |
def _start_response(status, response_headers, *args, **kwargs): | |
if flask.request and ( | |
excluded_urls is None | |
or not url_disabled(flask.request.url, excluded_urls) | |
): | |
nonlocal request_route | |
request_route = flask.request.url_rule | |
span: Span = flask.request.environ.get(_ENVIRON_SPAN_KEY) | |
propagator = get_global_trace_propagator() | |
if propagator and span: | |
trace_context = TraceContext( | |
trace_id=span.get_trace_id(), | |
span_id=span.get_span_id() | |
) | |
propagator.inject( | |
trace_context, ListTupleCarrier(response_headers)) | |
if span and span.is_recording(): | |
status_code_str, _ = status.split(" ", 1) | |
try: | |
status_code = int(status_code_str) | |
except ValueError: | |
status_code = -1 | |
span.set_attribute( | |
"http.response.status_code", status_code) | |
span.set_attributes(attributes) | |
if response_hook is not None: | |
response_hook(span, status, response_headers) | |
return start_response(status, response_headers, *args, **kwargs) | |
result = wsgi_app(wrapped_app_environ, _start_response) | |
duration_s = default_timer() - start | |
if MetricContext.metric_initialized(): | |
MetricContext.histogram_record( | |
duration_histogram, | |
duration_s, | |
attributes | |
) | |
MetricContext.dec(active_requests_counter, 1, attributes) | |
return result | |
return _wrapped_app | |
def _wrapped_before_request( | |
request_hook=None, | |
tracer: Tracer = None, | |
excluded_urls=None | |
): | |
def _before_request(): | |
if excluded_urls and url_disabled(flask.request.url, excluded_urls): | |
return | |
flask_request_environ = flask.request.environ | |
logger.info( | |
f"_wrapped_before_request flask_request_environ={flask_request_environ}") | |
attributes = collect_request_attributes(flask_request_environ) | |
if flask.request.url_rule: | |
# For 404 that result from no route found, etc, we | |
# don't have a url_rule. | |
attributes[HTTP_ROUTE] = flask.request.url_rule.rule | |
span_name = f"HTTP {flask.request.url_rule.rule}" | |
else: | |
span_name = f"HTTP {flask.request.url}" | |
propagator = get_global_trace_propagator() | |
trace_context = None | |
if propagator: | |
trace_context = propagator.extract( | |
DictCarrier(flask_request_environ)) | |
logger.info(f"_wrapped_before_request trace_context={trace_context}") | |
span = tracer.start_span( | |
span_name, | |
SpanType.SERVER, | |
attributes=attributes, | |
start_time=flask_request_environ.get(_ENVIRON_STARTTIME_KEY), | |
trace_context=trace_context | |
) | |
if request_hook: | |
request_hook(span, flask_request_environ) | |
flask_request_environ[_ENVIRON_SPAN_KEY] = span | |
flask_request_environ[_ENVIRON_REQCTX_REF_KEY] = _request_ctx_ref() | |
return _before_request | |
def _wrapped_teardown_request( | |
excluded_urls=None, | |
): | |
def _teardown_request(exc): | |
if excluded_urls and url_disabled(flask.request.url, excluded_urls): | |
return | |
span: Span = flask.request.environ.get(_ENVIRON_SPAN_KEY) | |
original_reqctx_ref = flask.request.environ.get( | |
_ENVIRON_REQCTX_REF_KEY | |
) | |
current_reqctx_ref = _request_ctx_ref() | |
if not span or original_reqctx_ref != current_reqctx_ref: | |
return | |
if exc is None: | |
span.end() | |
else: | |
span.record_exception(exc) | |
span.end() | |
return _teardown_request | |
class _InstrumentedFlask(flask.Flask): | |
_excluded_urls = None | |
_tracer_provider: TraceProvider = None | |
_request_hook = None | |
_response_hook = None | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
tracer = self._tracer_provider.get_tracer( | |
"aworld.trace.instrumentation.flask") | |
duration_histogram = MetricTemplate( | |
type=MetricType.HISTOGRAM, | |
name="flask_request_duration_histogram", | |
description="Duration of flask HTTP server requests." | |
) | |
active_requests_counter = MetricTemplate( | |
type=MetricType.UPDOWNCOUNTER, | |
name="flask_active_request_counter", | |
unit="1", | |
description="Number of active HTTP server requests.", | |
) | |
self.wsgi_app = _rewrapped_app( | |
self.wsgi_app, | |
active_requests_counter, | |
duration_histogram, | |
_InstrumentedFlask._response_hook, | |
excluded_urls=_InstrumentedFlask._excluded_urls | |
) | |
_before_request = _wrapped_before_request( | |
_InstrumentedFlask._request_hook, | |
tracer, | |
excluded_urls=_InstrumentedFlask._excluded_urls | |
) | |
self._before_request = _before_request | |
self.before_request(_before_request) | |
_teardown_request = _wrapped_teardown_request( | |
excluded_urls=_InstrumentedFlask._excluded_urls, | |
) | |
self.teardown_request(_teardown_request) | |
class FlaskInstrumentor(Instrumentor): | |
def instrumentation_dependencies(self) -> Collection[str]: | |
return ("flask >= 1.0",) | |
def _instrument(self, **kwargs: Any): | |
logger.info("Flask _instrument entered.") | |
self._original_flask = flask.Flask | |
request_hook = kwargs.get("request_hook") | |
response_hook = kwargs.get("response_hook") | |
if callable(request_hook): | |
_InstrumentedFlask._request_hook = request_hook | |
if callable(response_hook): | |
_InstrumentedFlask._response_hook = response_hook | |
tracer_provider = kwargs.get("tracer_provider") | |
_InstrumentedFlask._tracer_provider = tracer_provider | |
excluded_urls = kwargs.get("excluded_urls") | |
_InstrumentedFlask._excluded_urls = ( | |
get_excluded_urls("FLASK") | |
if excluded_urls is None | |
else parse_excluded_urls(excluded_urls) | |
) | |
flask.Flask = _InstrumentedFlask | |
logger.info("Flask _instrument exited.") | |
def _uninstrument(self, **kwargs): | |
flask.Flask = self._original_flask | |
def instrument_flask(excluded_urls: str = None, | |
request_hook: Callable = None, | |
response_hook: Callable = None, | |
tracer_provider: TraceProvider = None, | |
**kwargs: Any, | |
): | |
""" | |
Instrument the Flask application. | |
Args: | |
excluded_urls (str): A comma separated list of URLs to be excluded from instrumentation. | |
request_hook (Callable): A function to be called before a request is processed. | |
response_hook (Callable): A function to be called after a request is processed. | |
tracer_provider (TraceProvider): The trace provider to use. | |
""" | |
all_kwargs = { | |
"excluded_urls": excluded_urls, | |
"request_hook": request_hook, | |
"response_hook": response_hook, | |
"tracer_provider": tracer_provider or get_tracer_provider(), | |
**kwargs | |
} | |
FlaskInstrumentor().instrument(**all_kwargs) | |
logger.info("Flask instrumented.") | |