# Duibonduil's picture
# Upload 11 files
# 7a3c0ee verified
from functools import wraps
from timeit import default_timer
from typing import Any, Awaitable, Callable, Optional

from aworld.metrics.context_manager import MetricContext
from aworld.trace.instrumentation.http_util import (
    collect_request_attributes_asgi,
    url_disabled,
    parser_host_port_url_from_asgi,
)
from aworld.trace.base import Span, TraceProvider, TraceContext, Tracer, SpanType
from aworld.trace.propagator import get_global_trace_propagator
from aworld.trace.propagator.carrier import DictCarrier, ListTupleCarrier
from aworld.metrics.metric import MetricType
from aworld.metrics.template import MetricTemplate
from aworld.logs.util import logger
def _wrapped_receive(
server_span: Span,
server_span_name: str,
scope: dict[str, Any],
receive: Callable[[], Awaitable[dict[str, Any]]],
attributes: dict[str],
client_request_hook: Callable = None
):
@wraps(receive)
async def otel_receive():
message = await receive()
if client_request_hook and callable(client_request_hook):
client_request_hook(scope, message)
server_span.set_attribute("asgi.event.type", message.get("type", ""))
return message
return otel_receive
def _wrapped_send(
server_span: Span,
server_span_name: str,
scope: dict[str, Any],
send: Callable[[dict[str, Any]], Awaitable[None]],
attributes: dict[str],
client_response_hook: Callable = None
):
expecting_trailers = False
@wraps(send)
async def otel_send(message: dict[str, Any]):
nonlocal expecting_trailers
status_code = None
if message["type"] == "http.response.start":
status_code = message["status"]
elif message["type"] == "websocket.send":
status_code = 200
# raw_headers = message.get("headers")
# if raw_headers:
if status_code:
server_span.set_attribute(
"http.response.status_code", status_code)
if callable(client_response_hook):
client_response_hook(scope, message)
if message["type"] == "http.response.start":
expecting_trailers = message.get("trailers", False)
propagator = get_global_trace_propagator()
if propagator:
trace_context = TraceContext(
trace_id=server_span.get_trace_id(),
span_id=server_span.get_span_id()
)
propagator.inject(
trace_context, DictCarrier(message))
await send(message)
if (
not expecting_trailers
and message["type"] == "http.response.body"
and not message.get("more_body", False)
) or (
expecting_trailers
and message["type"] == "http.response.trailers"
and not message.get("more_trailers", False)
):
server_span.end()
return otel_send
class TraceMiddleware:
    """An ASGI middleware that traces requests and responses.

    For every ``http`` and ``websocket`` scope whose URL is not excluded, the
    wrapped app runs inside a ``SpanType.SERVER`` span named
    ``"<METHOD>_<path>"`` (``"HTTP_<path>"`` when the scope has no method).
    The ``receive``/``send`` callables handed to the app are instrumented by
    ``_wrapped_receive`` / ``_wrapped_send``. For ``http`` scopes, an
    active-request up/down counter and a request-duration histogram are
    recorded while ``MetricContext.metric_initialized()`` is true. All other
    scope types (e.g. ``lifespan``) are forwarded to the app untouched.
    """

    def __init__(
        self,
        app,
        excluded_urls=None,
        tracer_provider: Optional[TraceProvider] = None,
        tracer: Optional[Tracer] = None,
        server_request_hook: Optional[Callable] = None,
        client_request_hook: Optional[Callable] = None,
        client_response_hook: Optional[Callable] = None,
    ):
        """Create the middleware.

        Args:
            app: the downstream ASGI application.
            excluded_urls: passed to ``url_disabled``; matching requests are
                not traced. ``None``/empty disables exclusion.
            tracer_provider: used to obtain a tracer when ``tracer`` is None.
            tracer: tracer for server spans; takes precedence over
                ``tracer_provider``.
            server_request_hook: optional ``hook(scope)`` called once the
                server span has been started.
            client_request_hook: optional ``hook(scope, message)`` called for
                every message returned by ``receive``.
            client_response_hook: optional ``hook(scope, message)`` called for
                every message passed to ``send``.

        Raises:
            ValueError: if neither ``tracer`` nor ``tracer_provider`` is given.
        """
        if tracer is None and tracer_provider is None:
            # Without this check the failure is an opaque AttributeError on None.
            raise ValueError(
                "TraceMiddleware requires either `tracer` or `tracer_provider`"
            )
        self.app = app
        self.excluded_urls = excluded_urls
        self.tracer_provider = tracer_provider
        self.server_request_hook = server_request_hook
        self.client_request_hook = client_request_hook
        self.client_response_hook = client_response_hook
        self.tracer: Tracer = (
            tracer
            if tracer is not None
            else tracer_provider.get_tracer("aworld.trace.instrumentation.asgi")
        )
        self.duration_histogram = MetricTemplate(
            type=MetricType.HISTOGRAM,
            name="asgi_request_duration_histogram",
            description="Duration of ASGI HTTP server requests.",
        )
        self.active_requests_counter = MetricTemplate(
            type=MetricType.UPDOWNCOUNTER,
            name="asgi_active_request_counter",
            unit="1",
            description="Number of active HTTP server requests.",
        )

    async def __call__(
        self,
        scope: dict[str, Any],
        receive: Callable[[], Awaitable[dict[str, Any]]],
        send: Callable[[dict[str, Any]], Awaitable[None]],
    ):
        """Handle one ASGI connection, tracing it when applicable.

        Returns whatever the wrapped app returns.
        """
        if scope["type"] not in ("http", "websocket"):
            return await self.app(scope, receive, send)

        _, _, url = parser_host_port_url_from_asgi(scope)
        if self.excluded_urls and url_disabled(url, self.excluded_urls):
            return await self.app(scope, receive, send)

        start = default_timer()
        is_http = scope["type"] == "http"
        span_name = (
            scope.get("method", "HTTP").strip().upper()
            + "_"
            + scope.get("path", "").strip()
        )
        attributes = collect_request_attributes_asgi(scope)

        # Remember whether the +1 was recorded so the matching -1 is emitted
        # only in that case and the up/down counter stays balanced.
        counted_active = is_http and MetricContext.metric_initialized()
        if counted_active:
            MetricContext.inc(self.active_requests_counter, 1, attributes)

        # Bound before ``try`` so ``finally`` never hits an unbound name when
        # context extraction or span creation raises.
        span = None
        try:
            trace_context = None
            propagator = get_global_trace_propagator()
            if propagator:
                trace_context = propagator.extract(
                    ListTupleCarrier(scope.get("headers", [])))
                # Log only the path: the full scope carries request headers
                # (cookies, authorization tokens) that must not reach the logs.
                logger.debug(
                    f"asgi extract trace_context: {trace_context}, "
                    f"path: {scope.get('path', '')}")
            with self.tracer.start_as_current_span(
                span_name,
                span_type=SpanType.SERVER,
                trace_context=trace_context,
                attributes=attributes,
            ) as span:
                if callable(self.server_request_hook):
                    self.server_request_hook(scope)
                wrapped_receive = _wrapped_receive(
                    span,
                    span_name,
                    scope,
                    receive,
                    attributes,
                    self.client_request_hook,
                )
                wrapped_send = _wrapped_send(
                    span,
                    span_name,
                    scope,
                    send,
                    attributes,
                    self.client_response_hook,
                )
                await self.app(scope, wrapped_receive, wrapped_send)
        finally:
            if is_http and MetricContext.metric_initialized():
                MetricContext.histogram_record(
                    self.duration_histogram,
                    default_timer() - start,
                    attributes,
                )
                if counted_active:
                    MetricContext.inc(
                        self.active_requests_counter, -1, attributes)
            # The wrapped ``send`` normally ends the span after the final
            # response part; end it here only if it is still recording.
            if span is not None and span.is_recording():
                span.end()