Duibonduil's picture
Upload 11 files
7a3c0ee verified
import flask
import weakref
from typing import Any, Callable, Collection
from time import time_ns
from timeit import default_timer
from importlib_metadata import version
from packaging import version as package_version
from aworld.trace.instrumentation import Instrumentor
from aworld.trace.base import Span, TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider
from aworld.metrics.metric import MetricType
from aworld.metrics.template import MetricTemplate
from aworld.logs.util import logger
from aworld.trace.instrumentation.http_util import (
collect_request_attributes,
url_disabled,
get_excluded_urls,
parse_excluded_urls,
HTTP_ROUTE
)
from aworld.trace.propagator import get_global_trace_propagator
from aworld.metrics.context_manager import MetricContext
from aworld.trace.propagator.carrier import ListTupleCarrier, DictCarrier
# Keys under which per-request tracing state is stashed in the WSGI environ.

# Request start timestamp (time_ns()) recorded by the WSGI wrapper; it is
# later handed to the tracer as the span's start_time.
_ENVIRON_STARTTIME_KEY = "aworld-flask.starttime_key"
# The server span opened for the request by the before_request handler.
_ENVIRON_SPAN_KEY = "aworld-flask.span_key"
# Weak reference to the request context that was current when the span was
# opened; the teardown handler compares against it before ending the span.
_ENVIRON_REQCTX_REF_KEY = "aworld-flask.reqctx_ref_key"
# Installed Flask version string, used below to pick the request-context
# accessor that matches the Flask release.
flask_version = version("flask")
if package_version.parse(flask_version) < package_version.parse("2.2.0"):
    def _request_ctx_ref() -> weakref.ReferenceType:
        """Return a weak reference to the current request context.

        Variant for Flask releases older than 2.2.0, which reads the top of
        ``flask._request_ctx_stack``.
        """
        top_ctx = flask._request_ctx_stack.top
        return weakref.ref(top_ctx)
else:
    def _request_ctx_ref() -> weakref.ReferenceType:
        """Return a weak reference to the current request context.

        Variant for Flask 2.2.0 and newer, which unwraps the
        ``flask.globals.request_ctx`` proxy to reach the real object.
        """
        current_ctx = flask.globals.request_ctx._get_current_object()
        return weakref.ref(current_ctx)
def _rewrapped_app(
    wsgi_app,
    active_requests_counter,
    duration_histogram,
    response_hook=None,
    excluded_urls=None,
):
    """Wrap ``wsgi_app`` so every request is timed, counted and annotated.

    Args:
        wsgi_app: The WSGI callable to wrap.
        active_requests_counter: Metric that is incremented when a request
            starts and decremented when ``wsgi_app`` returns or raises.
        duration_histogram: Metric that receives the time spent inside
            ``wsgi_app``, measured with ``default_timer`` (seconds).
        response_hook: Optional callable invoked as
            ``response_hook(span, status, response_headers)`` just before
            the real ``start_response`` is called. ``span`` may be ``None``
            when no span was stored for the request.
        excluded_urls: Optional exclusion list passed to ``url_disabled``;
            matching requests skip the span/response handling.

    Returns:
        A WSGI callable with the usual ``(environ, start_response)``
        signature.
    """

    def _wrapped_app(wrapped_app_environ, start_response):
        # We want to measure the time for route matching, etc.
        # In theory, we could start the span here and use
        # update_name later but that API is "highly discouraged" so
        # we better avoid it.
        wrapped_app_environ[_ENVIRON_STARTTIME_KEY] = time_ns()
        start = default_timer()
        attributes = collect_request_attributes(wrapped_app_environ)

        # Remember whether the increment really happened so the decrement
        # below is always paired with it.
        counted_active = False
        if MetricContext.metric_initialized():
            MetricContext.inc(active_requests_counter, 1, attributes)
            counted_active = True

        def _start_response(status, response_headers, *args, **kwargs):
            """Annotate the request span, then delegate to ``start_response``."""
            if flask.request and (
                excluded_urls is None
                or not url_disabled(flask.request.url, excluded_urls)
            ):
                span: Span = flask.request.environ.get(_ENVIRON_SPAN_KEY)

                # Expose the trace/span ids to the client via response headers.
                propagator = get_global_trace_propagator()
                if propagator and span:
                    trace_context = TraceContext(
                        trace_id=span.get_trace_id(),
                        span_id=span.get_span_id()
                    )
                    propagator.inject(
                        trace_context, ListTupleCarrier(response_headers))

                if span and span.is_recording():
                    # ``partition`` tolerates a status line without a reason
                    # phrase (e.g. "200"); unpacking ``split`` would raise.
                    status_code_str = status.partition(" ")[0]
                    try:
                        status_code = int(status_code_str)
                    except ValueError:
                        status_code = -1
                    span.set_attribute(
                        "http.response.status_code", status_code)
                    span.set_attributes(attributes)

                if response_hook is not None:
                    response_hook(span, status, response_headers)
            return start_response(status, response_headers, *args, **kwargs)

        try:
            result = wsgi_app(wrapped_app_environ, _start_response)
            duration_s = default_timer() - start
            if MetricContext.metric_initialized():
                MetricContext.histogram_record(
                    duration_histogram,
                    duration_s,
                    attributes
                )
            return result
        finally:
            # Runs on the error path too, so a request that raises does not
            # stay counted as active forever.
            if counted_active and MetricContext.metric_initialized():
                MetricContext.dec(active_requests_counter, 1, attributes)

    return _wrapped_app
def _wrapped_before_request(
    request_hook=None,
    tracer: Tracer = None,
    excluded_urls=None
):
    """Build the ``before_request`` handler that opens the server span.

    Args:
        request_hook: Optional callable invoked as
            ``request_hook(span, environ)`` once the span has been started.
        tracer: Tracer used to start the span. The returned handler calls
            ``tracer.start_span`` unconditionally, so a real tracer must be
            supplied despite the ``None`` default.
        excluded_urls: Optional exclusion list passed to ``url_disabled``;
            matching requests are not traced.

    Returns:
        A zero-argument callable suitable for ``Flask.before_request``.
    """

    def _before_request():
        if excluded_urls and url_disabled(flask.request.url, excluded_urls):
            return

        flask_request_environ = flask.request.environ
        # The raw WSGI environ contains every request header (including
        # Authorization and Cookie), so it is dumped at debug level only
        # instead of being written to INFO logs on each request.
        logger.debug(
            f"_wrapped_before_request flask_request_environ={flask_request_environ}")

        attributes = collect_request_attributes(flask_request_environ)
        url_rule = flask.request.url_rule
        if url_rule:
            # For 404 that result from no route found, etc, we
            # don't have a url_rule.
            attributes[HTTP_ROUTE] = url_rule.rule
            span_name = f"HTTP {url_rule.rule}"
        else:
            span_name = f"HTTP {flask.request.url}"

        # Continue an incoming trace when the propagator finds one in the
        # request environ.
        propagator = get_global_trace_propagator()
        trace_context = None
        if propagator:
            trace_context = propagator.extract(
                DictCarrier(flask_request_environ))
        logger.info(f"_wrapped_before_request trace_context={trace_context}")

        span = tracer.start_span(
            span_name,
            SpanType.SERVER,
            attributes=attributes,
            start_time=flask_request_environ.get(_ENVIRON_STARTTIME_KEY),
            trace_context=trace_context
        )

        # Register the span before running user code: if the hook raises,
        # the teardown handler can still find the span and end it.
        flask_request_environ[_ENVIRON_SPAN_KEY] = span
        flask_request_environ[_ENVIRON_REQCTX_REF_KEY] = _request_ctx_ref()

        if request_hook:
            request_hook(span, flask_request_environ)

    return _before_request
def _wrapped_teardown_request(
    excluded_urls=None,
):
    """Build the ``teardown_request`` handler that closes the request span.

    Args:
        excluded_urls: Optional exclusion list passed to ``url_disabled``;
            matching requests are ignored.

    Returns:
        A callable taking the teardown exception (or ``None``).
    """

    def _teardown_request(exc):
        if excluded_urls and url_disabled(flask.request.url, excluded_urls):
            return

        environ = flask.request.environ
        span: Span = environ.get(_ENVIRON_SPAN_KEY)
        opening_ctx_ref = environ.get(_ENVIRON_REQCTX_REF_KEY)
        active_ctx_ref = _request_ctx_ref()

        # Nothing to close when no span was stored for this request.
        if not span:
            return
        # Only the request context that opened the span may end it.
        if opening_ctx_ref != active_ctx_ref:
            return

        if exc is not None:
            span.record_exception(exc)
        span.end()

    return _teardown_request
class _InstrumentedFlask(flask.Flask):
    """``flask.Flask`` subclass that wires tracing and metrics into itself.

    On construction it wraps ``wsgi_app`` and registers the before-request
    and teardown-request handlers, using the class-level settings below.
    """

    # Class-level configuration read by __init__.
    _excluded_urls = None
    _tracer_provider: TraceProvider = None
    _request_hook = None
    _response_hook = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        tracer = self._tracer_provider.get_tracer(
            "aworld.trace.instrumentation.flask")

        # Hooks are read from the class rather than through ``self``.
        config = _InstrumentedFlask
        excluded = config._excluded_urls
        on_request = config._request_hook
        on_response = config._response_hook

        request_duration = MetricTemplate(
            type=MetricType.HISTOGRAM,
            name="flask_request_duration_histogram",
            description="Duration of flask HTTP server requests."
        )
        in_flight_requests = MetricTemplate(
            type=MetricType.UPDOWNCOUNTER,
            name="flask_active_request_counter",
            unit="1",
            description="Number of active HTTP server requests.",
        )

        # Metrics and response handling live in the WSGI layer.
        self.wsgi_app = _rewrapped_app(
            self.wsgi_app,
            in_flight_requests,
            request_duration,
            response_hook=on_response,
            excluded_urls=excluded,
        )

        # Span start happens in before_request, span end in teardown_request.
        before_handler = _wrapped_before_request(
            request_hook=on_request,
            tracer=tracer,
            excluded_urls=excluded,
        )
        self._before_request = before_handler
        self.before_request(before_handler)

        teardown_handler = _wrapped_teardown_request(excluded_urls=excluded)
        self.teardown_request(teardown_handler)
class FlaskInstrumentor(Instrumentor):
    """Instrumentor that swaps ``flask.Flask`` for ``_InstrumentedFlask``."""

    # The genuine Flask class captured at instrumentation time. It is kept on
    # the class so that any instrumentor instance can restore it, and it is
    # None until instrumentation has happened.
    _original_flask = None

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirement this instrumentor applies to."""
        return ("flask >= 1.0",)

    def _instrument(self, **kwargs: Any):
        """Configure ``_InstrumentedFlask`` and install it as ``flask.Flask``.

        Recognised keyword arguments: ``request_hook``, ``response_hook``,
        ``tracer_provider`` and ``excluded_urls``.
        """
        logger.info("Flask _instrument entered.")

        # Never record the patched class as the "original": a repeated call
        # would otherwise make _uninstrument restore the patch itself.
        if flask.Flask is not _InstrumentedFlask:
            FlaskInstrumentor._original_flask = flask.Flask

        # Always overwrite the hooks so that values from an earlier
        # instrument() call do not linger when none is supplied now.
        request_hook = kwargs.get("request_hook")
        response_hook = kwargs.get("response_hook")
        _InstrumentedFlask._request_hook = (
            request_hook if callable(request_hook) else None
        )
        _InstrumentedFlask._response_hook = (
            response_hook if callable(response_hook) else None
        )

        # _InstrumentedFlask.__init__ calls get_tracer() on the provider, so
        # fall back to the global provider instead of storing None.
        tracer_provider = kwargs.get("tracer_provider") or get_tracer_provider()
        _InstrumentedFlask._tracer_provider = tracer_provider

        excluded_urls = kwargs.get("excluded_urls")
        _InstrumentedFlask._excluded_urls = (
            get_excluded_urls("FLASK")
            if excluded_urls is None
            else parse_excluded_urls(excluded_urls)
        )

        flask.Flask = _InstrumentedFlask
        logger.info("Flask _instrument exited.")

    def _uninstrument(self, **kwargs):
        """Restore the original ``flask.Flask`` class, if one was captured."""
        original_flask = self._original_flask
        if original_flask is not None:
            flask.Flask = original_flask
def instrument_flask(excluded_urls: str = None,
                     request_hook: Callable = None,
                     response_hook: Callable = None,
                     tracer_provider: TraceProvider = None,
                     **kwargs: Any,
                     ):
    """
    Instrument Flask through a ``FlaskInstrumentor``.

    Args:
        excluded_urls (str): A comma separated list of URLs to be excluded from instrumentation.
        request_hook (Callable): A function to be called before a request is processed.
        response_hook (Callable): A function to be called after a request is processed.
        tracer_provider (TraceProvider): The trace provider to use; when falsy,
            the result of ``get_tracer_provider()`` is used instead.
        **kwargs: Extra options forwarded unchanged to ``FlaskInstrumentor.instrument``.
    """
    # Resolve the provider first, before the instrumentor is created.
    resolved_provider = tracer_provider or get_tracer_provider()
    instrumentor = FlaskInstrumentor()
    instrumentor.instrument(
        excluded_urls=excluded_urls,
        request_hook=request_hook,
        response_hook=response_hook,
        tracer_provider=resolved_provider,
        **kwargs,
    )
    logger.info("Flask instrumented.")