Duibonduil's picture
Upload 11 files
7a3c0ee verified
from aworld.logs.util import logger
from aworld.trace.instrumentation.http_util import (
collect_attributes_from_request,
url_disabled,
get_excluded_urls,
parse_excluded_urls,
HTTP_RESPONSE_STATUS_CODE,
HTTP_FLAVOR
)
from aworld.metrics.context_manager import MetricContext
from aworld.metrics.template import MetricTemplate
from aworld.metrics.metric import MetricType
from aworld.trace.instrumentation import Instrumentor
from aworld.trace.propagator.carrier import DictCarrier
import functools
from timeit import default_timer
from requests import sessions
from requests.models import PreparedRequest, Response
from requests.structures import CaseInsensitiveDict
from typing import Collection, Any, Callable
from aworld.trace.base import TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider
from aworld.trace.propagator import get_global_trace_propagator
def _wrapped_send(
tracer: Tracer = None,
excluded_urls=None,
request_hook: Callable = None,
response_hook: Callable = None,
duration_histogram: MetricTemplate = None
):
oringinal_send = sessions.Session.send
@functools.wraps(oringinal_send)
def instrumented_send(
self: sessions.Session, request: PreparedRequest, **kwargs: Any
):
if excluded_urls and url_disabled(request.url, excluded_urls):
return oringinal_send(self, request, **kwargs)
def get_or_create_headers():
request.headers = (
request.headers
if request.headers is not None
else CaseInsensitiveDict()
)
return request.headers
method = request.method
if method is None:
method = "HTTP"
span_name = method.upper()
span_attributes = collect_attributes_from_request(request)
with tracer.start_as_current_span(
span_name, span_type=SpanType.CLIENT, attributes=span_attributes
) as span:
exception = None
if callable(request_hook):
request_hook(span, request)
headers = get_or_create_headers()
trace_context = TraceContext(
trace_id=span.get_trace_id(),
span_id=span.get_span_id(),
)
propagator = get_global_trace_propagator()
if propagator:
propagator.inject(trace_context, DictCarrier(headers))
start_time = default_timer()
try:
logger.info("Sending headers: %s", request.headers)
result = oringinal_send(
self, request, **kwargs
) # *** PROCEED
except Exception as exc: # pylint: disable=W0703
exception = exc
result = getattr(exc, "response", None)
finally:
elapsed_time = max(default_timer() - start_time, 0)
if isinstance(result, Response):
span_attributes = {}
span_attributes[HTTP_RESPONSE_STATUS_CODE] = result.status_code
if result.raw is not None:
version = getattr(result.raw, "version", None)
if version:
# Only HTTP/1 is supported by requests
version_text = "1.1" if version == 11 else "1.0"
span_attributes[HTTP_FLAVOR] = version_text
span.set_attributes(span_attributes)
if callable(response_hook):
response_hook(span, request, result)
if exception is not None:
span.record_exception(exception)
if duration_histogram is not None and MetricContext.metric_initialized():
MetricContext.histogram_record(
duration_histogram,
elapsed_time,
span_attributes
)
if exception is not None:
raise exception.with_traceback(exception.__traceback__)
return result
return instrumented_send
class _InstrumentedSession(sessions.Session):
"""
An instrumented requests.Session class.
"""
_excluded_urls = None
_tracer_provider: TraceProvider = None
_request_hook = None
_response_hook = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
tracer = self._tracer_provider.get_tracer(
"aworld.trace.instrumentation.requests")
excluded_urls = kwargs.get("excluded_urls")
duration_histogram = MetricTemplate(
type=MetricType.HISTOGRAM,
name="client_request_duration_histogram",
unit="s",
description="Duration of HTTP client requests."
)
self.send = functools.partial(_wrapped_send(
tracer=tracer,
excluded_urls=excluded_urls,
request_hook=self._request_hook,
response_hook=self._response_hook,
duration_histogram=duration_histogram
), self)
class RequestsInstrumentor(Instrumentor):
"""
An instrumentor for the requests module.
"""
def instrumentation_dependencies(self) -> Collection[str]:
return ["requests"]
def _instrument(self, **kwargs):
"""
Instruments the requests module.
"""
logger.info("requests _instrument entered.")
self._original_session = sessions.Session
request_hook = kwargs.get("request_hook")
response_hook = kwargs.get("response_hook")
if callable(request_hook):
_InstrumentedSession._request_hook = request_hook
if callable(response_hook):
_InstrumentedSession._response_hook = response_hook
tracer_provider = kwargs.get("tracer_provider")
_InstrumentedSession._tracer_provider = tracer_provider
excluded_urls = kwargs.get("excluded_urls")
_InstrumentedSession._excluded_urls = (
get_excluded_urls("FLASK")
if excluded_urls is None
else parse_excluded_urls(excluded_urls)
)
sessions.Session = _InstrumentedSession
logger.info("requests _instrument exited.")
def _uninstrument(self, **kwargs):
"""
Uninstruments the requests module.
"""
sessions.Session = self._original_session
def instrument_requests(excluded_urls: str = None,
request_hook: Callable = None,
response_hook: Callable = None,
tracer_provider: TraceProvider = None,
**kwargs: Any,
):
"""
Instruments the requests module.
Args:
excluded_urls: A comma separated list of URLs to exclude from tracing.
request_hook: A function that will be called before a request is sent.
The function will be called with the span and the request.
response_hook: A function that will be called after a response is received.
The function will be called with the span and the response.
tracer_provider: The tracer provider to use. If not provided, the global
tracer provider will be used.
kwargs: Additional keyword arguments.
"""
all_kwargs = {
"excluded_urls": excluded_urls,
"request_hook": request_hook,
"response_hook": response_hook,
"tracer_provider": tracer_provider or get_tracer_provider(),
**kwargs
}
RequestsInstrumentor().instrument(**all_kwargs)
logger.info("Requests instrumented.")