Spaces:
Sleeping
Sleeping
from aworld.logs.util import logger | |
from aworld.trace.instrumentation.http_util import ( | |
collect_attributes_from_request, | |
url_disabled, | |
get_excluded_urls, | |
parse_excluded_urls, | |
HTTP_RESPONSE_STATUS_CODE, | |
HTTP_FLAVOR | |
) | |
from aworld.metrics.context_manager import MetricContext | |
from aworld.metrics.template import MetricTemplate | |
from aworld.metrics.metric import MetricType | |
from aworld.trace.instrumentation import Instrumentor | |
from aworld.trace.propagator.carrier import DictCarrier | |
import functools | |
from timeit import default_timer | |
from requests import sessions | |
from requests.models import PreparedRequest, Response | |
from requests.structures import CaseInsensitiveDict | |
from typing import Collection, Any, Callable | |
from aworld.trace.base import TraceProvider, TraceContext, Tracer, SpanType, get_tracer_provider | |
from aworld.trace.propagator import get_global_trace_propagator | |
def _wrapped_send( | |
tracer: Tracer = None, | |
excluded_urls=None, | |
request_hook: Callable = None, | |
response_hook: Callable = None, | |
duration_histogram: MetricTemplate = None | |
): | |
oringinal_send = sessions.Session.send | |
def instrumented_send( | |
self: sessions.Session, request: PreparedRequest, **kwargs: Any | |
): | |
if excluded_urls and url_disabled(request.url, excluded_urls): | |
return oringinal_send(self, request, **kwargs) | |
def get_or_create_headers(): | |
request.headers = ( | |
request.headers | |
if request.headers is not None | |
else CaseInsensitiveDict() | |
) | |
return request.headers | |
method = request.method | |
if method is None: | |
method = "HTTP" | |
span_name = method.upper() | |
span_attributes = collect_attributes_from_request(request) | |
with tracer.start_as_current_span( | |
span_name, span_type=SpanType.CLIENT, attributes=span_attributes | |
) as span: | |
exception = None | |
if callable(request_hook): | |
request_hook(span, request) | |
headers = get_or_create_headers() | |
trace_context = TraceContext( | |
trace_id=span.get_trace_id(), | |
span_id=span.get_span_id(), | |
) | |
propagator = get_global_trace_propagator() | |
if propagator: | |
propagator.inject(trace_context, DictCarrier(headers)) | |
start_time = default_timer() | |
try: | |
logger.info("Sending headers: %s", request.headers) | |
result = oringinal_send( | |
self, request, **kwargs | |
) # *** PROCEED | |
except Exception as exc: # pylint: disable=W0703 | |
exception = exc | |
result = getattr(exc, "response", None) | |
finally: | |
elapsed_time = max(default_timer() - start_time, 0) | |
if isinstance(result, Response): | |
span_attributes = {} | |
span_attributes[HTTP_RESPONSE_STATUS_CODE] = result.status_code | |
if result.raw is not None: | |
version = getattr(result.raw, "version", None) | |
if version: | |
# Only HTTP/1 is supported by requests | |
version_text = "1.1" if version == 11 else "1.0" | |
span_attributes[HTTP_FLAVOR] = version_text | |
span.set_attributes(span_attributes) | |
if callable(response_hook): | |
response_hook(span, request, result) | |
if exception is not None: | |
span.record_exception(exception) | |
if duration_histogram is not None and MetricContext.metric_initialized(): | |
MetricContext.histogram_record( | |
duration_histogram, | |
elapsed_time, | |
span_attributes | |
) | |
if exception is not None: | |
raise exception.with_traceback(exception.__traceback__) | |
return result | |
return instrumented_send | |
class _InstrumentedSession(sessions.Session): | |
""" | |
An instrumented requests.Session class. | |
""" | |
_excluded_urls = None | |
_tracer_provider: TraceProvider = None | |
_request_hook = None | |
_response_hook = None | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
tracer = self._tracer_provider.get_tracer( | |
"aworld.trace.instrumentation.requests") | |
excluded_urls = kwargs.get("excluded_urls") | |
duration_histogram = MetricTemplate( | |
type=MetricType.HISTOGRAM, | |
name="client_request_duration_histogram", | |
unit="s", | |
description="Duration of HTTP client requests." | |
) | |
self.send = functools.partial(_wrapped_send( | |
tracer=tracer, | |
excluded_urls=excluded_urls, | |
request_hook=self._request_hook, | |
response_hook=self._response_hook, | |
duration_histogram=duration_histogram | |
), self) | |
class RequestsInstrumentor(Instrumentor): | |
""" | |
An instrumentor for the requests module. | |
""" | |
def instrumentation_dependencies(self) -> Collection[str]: | |
return ["requests"] | |
def _instrument(self, **kwargs): | |
""" | |
Instruments the requests module. | |
""" | |
logger.info("requests _instrument entered.") | |
self._original_session = sessions.Session | |
request_hook = kwargs.get("request_hook") | |
response_hook = kwargs.get("response_hook") | |
if callable(request_hook): | |
_InstrumentedSession._request_hook = request_hook | |
if callable(response_hook): | |
_InstrumentedSession._response_hook = response_hook | |
tracer_provider = kwargs.get("tracer_provider") | |
_InstrumentedSession._tracer_provider = tracer_provider | |
excluded_urls = kwargs.get("excluded_urls") | |
_InstrumentedSession._excluded_urls = ( | |
get_excluded_urls("FLASK") | |
if excluded_urls is None | |
else parse_excluded_urls(excluded_urls) | |
) | |
sessions.Session = _InstrumentedSession | |
logger.info("requests _instrument exited.") | |
def _uninstrument(self, **kwargs): | |
""" | |
Uninstruments the requests module. | |
""" | |
sessions.Session = self._original_session | |
def instrument_requests(excluded_urls: str = None, | |
request_hook: Callable = None, | |
response_hook: Callable = None, | |
tracer_provider: TraceProvider = None, | |
**kwargs: Any, | |
): | |
""" | |
Instruments the requests module. | |
Args: | |
excluded_urls: A comma separated list of URLs to exclude from tracing. | |
request_hook: A function that will be called before a request is sent. | |
The function will be called with the span and the request. | |
response_hook: A function that will be called after a response is received. | |
The function will be called with the span and the response. | |
tracer_provider: The tracer provider to use. If not provided, the global | |
tracer provider will be used. | |
kwargs: Additional keyword arguments. | |
""" | |
all_kwargs = { | |
"excluded_urls": excluded_urls, | |
"request_hook": request_hook, | |
"response_hook": response_hook, | |
"tracer_provider": tracer_provider or get_tracer_provider(), | |
**kwargs | |
} | |
RequestsInstrumentor().instrument(**all_kwargs) | |
logger.info("Requests instrumented.") | |