Spaces:
Sleeping
Sleeping
import time | |
import sys | |
import os | |
from logging import StreamHandler | |
from abc import ABC | |
from logging import Logger, NOTSET, LogRecord, Filter, Formatter, Handler | |
from typing import Optional, Union | |
from aworld.trace.base import get_tracer_provider_silent, Tracer, AttributeValueType | |
TRACE_LOG_FORMAT = '%(asctime)s - [%(trace_id)s] - [%(span_id)s] - %(name)s - %(levelname)s - %(message)s' | |
SPECIAL_TRACE_LOG_FORMAT = '%(asctime)s - [trace_%(trace_id)s] - [%(span_id)s] - %(name)s - %(levelname)s - %(message)s' | |
class LoggerProvider(ABC): | |
"""A logger provider is a factory for loggers.""" | |
_GLOBAL_LOG_PROVIDER: Optional[LoggerProvider] = None | |
def set_log_provider(provider: str = "otlp", | |
backend: str = "logfire", | |
base_url: str = None, | |
write_token: str = None, | |
**kwargs): | |
"""Set the global log provider.""" | |
global _GLOBAL_LOG_PROVIDER | |
if provider == "otlp": | |
from .opentelemetry.otlp_log import OTLPLoggerProvider | |
_GLOBAL_LOG_PROVIDER = OTLPLoggerProvider(backend=backend, | |
base_url=base_url, | |
write_token=write_token, | |
**kwargs) | |
def get_log_provider() -> LoggerProvider: | |
""" | |
Get the global log provider. | |
""" | |
global _GLOBAL_LOG_PROVIDER | |
if _GLOBAL_LOG_PROVIDER is None: | |
raise ValueError("No log provider has been set.") | |
return _GLOBAL_LOG_PROVIDER | |
def instrument_logging(logger: Logger, level: Union[int, str] = NOTSET) -> None: | |
"""Instrument the logger.""" | |
for handler in logger.root.handlers: | |
if not any(isinstance(filter, TraceLoggingFilter) for filter in handler.filters): | |
handler.setFormatter(Formatter(TRACE_LOG_FORMAT)) | |
handler.addFilter(TraceLoggingFilter()) | |
if not logger.handlers: | |
print("No handlers found, adding a StreamHandler. logger=", logger.name) | |
handler = StreamHandler() | |
handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT)) | |
handler.addFilter(TraceLoggingFilter()) | |
logger.addHandler(handler) | |
else: | |
for handler in logger.handlers: | |
if not any(isinstance(filter, TraceLoggingFilter) for filter in handler.filters): | |
handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT)) | |
handler.addFilter(TraceLoggingFilter()) | |
logger.propagate = False | |
logger.addHandler(TraceLogginHandler(level)) | |
class TraceLoggingFilter(Filter): | |
""" | |
A filter that adds trace information to log records. | |
""" | |
def filter(self, record: LogRecord) -> bool: | |
""" | |
Add trace information to the log record. | |
""" | |
trace = get_tracer_provider_silent() | |
if trace: | |
span = trace.get_current_span() | |
record.trace_id = span.get_trace_id() if span else None | |
record.span_id = span.get_span_id() if span else None | |
return True | |
class TraceLogginHandler(Handler): | |
""" | |
A handler class which writes logging records, appropriately formatted, | |
to a stream. Note that this class does not close the stream, as | |
sys.stdout or sys.stderr may be used. | |
""" | |
def strip_color(text: str) -> str: | |
"""Remove ANSI color codes from text""" | |
import re | |
return re.sub(r'\033\[[0-9;]*m', '', text) | |
def __init__(self, | |
level: Union[int, str] = NOTSET, | |
tracer_name: str = "aworld.log") -> None: | |
"""Initialize the handler.""" | |
super().__init__(level=level) | |
self._tracer_name = tracer_name | |
self._tracer: Tracer = None | |
def emit(self, record: LogRecord) -> None: | |
"""Emit a record.""" | |
trace = get_tracer_provider_silent() | |
if not trace or not trace.get_current_span() or not trace.get_current_span().is_recording(): | |
return | |
if not self._tracer: | |
self._tracer = trace.get_tracer(name=self._tracer_name) | |
try: | |
f = sys._getframe() | |
while f: | |
if 'logging/__init__.py' in f.f_code.co_filename or \ | |
f.f_code.co_filename.startswith(os.path.dirname(__file__)): | |
f = f.f_back | |
else: | |
break | |
origin_msg = record.msg | |
raw_msg = None | |
if f: | |
try: | |
import linecache | |
line = linecache.getline(f.f_code.co_filename, f.f_lineno) | |
if 'logger.' in line: | |
raw_msg = line.split('logger.', 1)[1].split( | |
'(', 1)[1].split(')', 1)[0].strip() | |
except: | |
pass | |
record.msg = self.strip_color(record.msg) | |
msg_template = raw_msg if raw_msg else record.msg | |
if len(msg_template) > 255: | |
msg_template = msg_template[:255] + '...' | |
attributes = { | |
'code.filepath': f.f_code.co_filename if f else record.pathname, | |
'code.lineno': f.f_lineno if f else record.lineno, | |
'code.function': f.f_code.co_name if f else record.funcName, | |
'log.level': record.levelname, | |
'log.logger': record.name, | |
'log.message': self.format(record), | |
} | |
record.msg = origin_msg | |
self._create_span( | |
span_name=msg_template, | |
attributes=attributes, | |
exc_info=record.exc_info, | |
) | |
except RecursionError: # See issue 36272 | |
raise | |
except Exception: | |
self.handleError(record) | |
def _create_span(self, | |
span_name: str, | |
attributes: dict[str, AttributeValueType] = None, | |
exc_info: BaseException = None): | |
start_time = time.time_ns() | |
span = self._tracer.start_span( | |
name=span_name, | |
attributes=attributes, | |
start_time=start_time, | |
) | |
if exc_info: | |
span.record_exception(exception=exc_info, timestamp=start_time) | |
span.end() | |