# Duibonduil's picture
# Upload 2 files
# bf4ab00 verified
import re
from typing import Tuple, List
from aworld.logs.util import logger
from aworld.trace.base import Propagator, Carrier, TraceContext
class W3CTraceContextPropagator(Propagator):
    """Propagator that extracts and injects W3C Trace Context headers.

    Reads and writes the ``traceparent`` and ``tracestate`` entries of a
    carrier. On extraction, the ``HTTP_TRACEPARENT`` / ``HTTP_TRACESTATE``
    keys are tried as a fallback when the lowercase names yield nothing.

    Example carrier::

        carrier = {
            "traceparent": "00-0af7651916cd43dd8448eb211c80319c-00f067aa0ba902b7-01",
            "tracestate": "congo=t61rcWkgMzE",
            "baggage": "key1=value1,key2=value2"
        }

    The ``baggage`` entry is neither read nor written by this class.
    """

    # Grammar of a single tracestate key: either a simple key or a
    # "tenant@system" style key.
    _STATE_KEY_FORMAT = (
        r"[a-z][_0-9a-z\-\*\/]{0,255}|"
        r"[a-z0-9][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}"
    )
    # Grammar of a single tracestate value: printable ASCII except "," and
    # "=", and not ending with a space.
    _STATE_VALUE_FORMAT = (
        r"[\x20-\x2b\x2d-\x3c\x3e-\x7e]{0,255}[\x21-\x2b\x2d-\x3c\x3e-\x7e]"
    )

    # Splits a tracestate header into members on commas, swallowing the
    # optional whitespace around each comma.
    _state_delimiter_pattern = re.compile(r"[ \t]*,[ \t]*")
    # Matches one "key=value" member; groups are (key, "=", value).
    _state_member_pattern = re.compile(
        rf"({_STATE_KEY_FORMAT})(=)({_STATE_VALUE_FORMAT})[ \t]*")

    _TRACEPARENT_HEADER_NAME = "traceparent"
    _TRACESTATE_HEADER_NAME = "tracestate"

    # version-traceid-spanid-flags, with an optional trailing "-..." part
    # captured as group 5 so that it can be rejected for version "00".
    _TRACEPARENT_HEADER_FORMAT = (
        r"^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})"
        r"(-.*)?[ \t]*$"
    )
    _TRACEPARENT_HEADER_FORMAT_RE = re.compile(_TRACEPARENT_HEADER_FORMAT)

    def _get_header(self, carrier: Carrier, name: str):
        """Return the carrier value for *name*, falling back to ``HTTP_<NAME>``.

        Args:
            carrier: The carrier to read from.
            name: The lowercase header name.
        Returns:
            The first truthy value found under ``name`` or under
            ``"HTTP_" + name.upper()``; otherwise whatever the second
            lookup returned.
        """
        return carrier.get(name) or carrier.get("HTTP_" + name.upper())

    def extract(self, carrier: Carrier) -> TraceContext:
        """Extract trace context from carrier.

        Args:
            carrier: The carrier to extract trace context from.
        Returns:
            A TraceContext built from the ``traceparent`` and ``tracestate``
            headers, or None when ``traceparent`` is missing, does not match
            the expected format, carries an all-zero trace id or span id,
            uses version "ff", or uses version "00" with trailing data.
        """
        header = self._get_header(carrier, self._TRACEPARENT_HEADER_NAME)
        if header is None:
            return None

        # The pattern is anchored with ^...$, so match() covers the whole header.
        match = self._TRACEPARENT_HEADER_FORMAT_RE.match(header)
        if not match:
            return None

        version, trace_id, span_id, trace_flags = match.group(1, 2, 3, 4)
        logger.info(
            f"extract trace_id: {trace_id}, span_id: {span_id}, trace_flags: {trace_flags}, version: {version}")

        # All-zero identifiers are treated as invalid.
        if trace_id == "0" * 32 or span_id == "0" * 16:
            return None
        # Version "00" must not carry anything after the flags field.
        if version == "00" and match.group(5):
            return None
        # Version "ff" is rejected.
        if version == "ff":
            return None

        state_header = self._get_header(carrier, self._TRACESTATE_HEADER_NAME)
        return TraceContext(
            trace_id=trace_id,
            span_id=span_id,
            trace_flags=trace_flags,
            version=version,
            attributes=self._extract_state_from_header(state_header),
        )

    def inject(self, trace_context: TraceContext, carrier: Carrier) -> None:
        """Inject trace context into carrier.

        Nothing is written when the trace id or span id is empty or all
        zeros. Integer ids are rendered as zero-padded lowercase hex
        (32 and 16 digits respectively). ``tracestate`` is only set when
        the context has at least one attribute.

        Args:
            trace_context: The trace context to inject.
            carrier: The carrier to inject trace context into.
        """
        version: str = trace_context.version
        trace_flags: str = trace_context.trace_flags
        trace_id = trace_context.trace_id
        span_id = trace_context.span_id
        logger.info(
            f"inject trace_id: {trace_id}, span_id: {span_id}, trace_flags: {trace_flags}, version: {version}")

        if (not trace_id or trace_id == "0" * 32
                or not span_id or span_id == "0" * 16):
            return

        if isinstance(trace_id, int):
            trace_id = format(trace_id, "032x")
        if isinstance(span_id, int):
            span_id = format(span_id, "016x")

        carrier.set(
            self._TRACEPARENT_HEADER_NAME,
            f"{version}-{trace_id}-{span_id}-{trace_flags}",
        )

        # Tolerate a context whose attributes are None instead of raising.
        attributes = trace_context.attributes or {}
        tracestate_string = ",".join(
            f"{key}={value}" for key, value in attributes.items())
        if tracestate_string:
            carrier.set(self._TRACESTATE_HEADER_NAME, tracestate_string)

    def _extract_state_from_header(self, header: str) -> dict:
        """Parse a ``tracestate`` header into a dict.

        Parsing stops at the first member that does not match the member
        format or that repeats an earlier key; the members collected up to
        that point are returned.

        Args:
            header: The header to extract state from; may be None.
        Returns:
            A dict mapping each parsed key to its value, empty when the
            header is None or empty.
        """
        if not header:
            return {}

        state = {}
        members: List[str] = self._state_delimiter_pattern.split(header)
        for member in members:
            # empty members are valid, but no need to process further.
            if not member:
                continue
            match = self._state_member_pattern.fullmatch(member)
            if not match:
                logger.warning(
                    f"Member doesn't match the w3c identifiers format {member}")
                return state
            groups: Tuple[str, ...] = match.groups()
            key, _eq, value = groups
            # duplicate keys are not legal in header
            if key in state:
                return state
            state[key] = value
        return state