import linecache
import os
import re
import sys
import time
from abc import ABC
from logging import Filter, Formatter, Handler, Logger, LogRecord, NOTSET, StreamHandler
from typing import Optional, Union

from aworld.trace.base import get_tracer_provider_silent, Tracer, AttributeValueType

# Format applied to the root logger's handlers.
TRACE_LOG_FORMAT = '%(asctime)s - [%(trace_id)s] - [%(span_id)s] - %(name)s - %(levelname)s - %(message)s'
# Format applied to the instrumented logger's own handlers (trace id is prefixed with "trace_").
SPECIAL_TRACE_LOG_FORMAT = '%(asctime)s - [trace_%(trace_id)s] - [%(span_id)s] - %(name)s - %(levelname)s - %(message)s'

# Matches ANSI SGR (color/style) escape sequences; compiled once instead of per call.
_ANSI_COLOR_RE = re.compile(r'\033\[[0-9;]*m')


class LoggerProvider(ABC):
    """A logger provider is a factory for loggers."""


_GLOBAL_LOG_PROVIDER: Optional[LoggerProvider] = None


def set_log_provider(provider: str = "otlp",
                     backend: str = "logfire",
                     base_url: Optional[str] = None,
                     write_token: Optional[str] = None,
                     **kwargs):
    """Set the global log provider.

    Args:
        provider: Name of the provider implementation. Only "otlp" is handled here.
        backend: Backend name forwarded to the provider.
        base_url: Backend endpoint forwarded to the provider.
        write_token: Credential forwarded to the provider.
        **kwargs: Extra keyword arguments forwarded to the provider constructor.
    """
    global _GLOBAL_LOG_PROVIDER
    # NOTE(review): any provider name other than "otlp" is silently ignored and
    # leaves the global provider untouched -- confirm this is intended.
    if provider == "otlp":
        # Imported lazily so the OTLP dependency is only loaded when requested.
        from .opentelemetry.otlp_log import OTLPLoggerProvider
        _GLOBAL_LOG_PROVIDER = OTLPLoggerProvider(backend=backend,
                                                  base_url=base_url,
                                                  write_token=write_token,
                                                  **kwargs)


def get_log_provider() -> LoggerProvider:
    """Get the global log provider.

    Raises:
        ValueError: If no provider has been set via ``set_log_provider``.
    """
    if _GLOBAL_LOG_PROVIDER is None:
        raise ValueError("No log provider has been set.")
    return _GLOBAL_LOG_PROVIDER


def _has_trace_filter(handler: Handler) -> bool:
    """Return True if *handler* already carries a TraceLoggingFilter."""
    return any(isinstance(f, TraceLoggingFilter) for f in handler.filters)


def instrument_logging(logger: Logger, level: Union[int, str] = NOTSET) -> None:
    """Instrument the logger so its records carry trace ids and are exported as spans.

    Root handlers and the logger's own handlers get a trace-aware formatter and a
    ``TraceLoggingFilter`` (once per handler). The logger stops propagating and
    gets a ``TraceLogginHandler``. Calling this repeatedly for the same logger
    does not stack additional ``TraceLogginHandler`` instances.

    Args:
        logger: The logger to instrument.
        level: Level for the span-emitting handler.
    """
    for handler in logger.root.handlers:
        if not _has_trace_filter(handler):
            handler.setFormatter(Formatter(TRACE_LOG_FORMAT))
            handler.addFilter(TraceLoggingFilter())

    if not logger.handlers:
        print("No handlers found, adding a StreamHandler. logger=", logger.name)
        handler = StreamHandler()
        handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT))
        handler.addFilter(TraceLoggingFilter())
        logger.addHandler(handler)
    else:
        for handler in logger.handlers:
            if not _has_trace_filter(handler):
                handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT))
                handler.addFilter(TraceLoggingFilter())

    logger.propagate = False
    # Adding a second span handler would emit every record as two spans.
    if not any(isinstance(h, TraceLogginHandler) for h in logger.handlers):
        logger.addHandler(TraceLogginHandler(level))


class TraceLoggingFilter(Filter):
    """A filter that adds trace information to log records."""

    def filter(self, record: LogRecord) -> bool:
        """Attach ``trace_id`` and ``span_id`` to the record; never drops it.

        Both attributes are always set (to None when there is no tracer provider
        or no current span) because the trace log formats reference them and
        formatting fails on a record that lacks them.
        """
        trace = get_tracer_provider_silent()
        span = trace.get_current_span() if trace else None
        record.trace_id = span.get_trace_id() if span else None
        record.span_id = span.get_span_id() if span else None
        return True


class TraceLogginHandler(Handler):
    """A handler that exports each log record as a span under the current trace.

    Records are only exported while a tracer provider exists and the current
    span is recording; otherwise they are ignored by this handler.
    """

    @staticmethod
    def strip_color(text: str) -> str:
        """Remove ANSI color codes from text"""
        return _ANSI_COLOR_RE.sub('', text)

    def __init__(self, level: Union[int, str] = NOTSET,
                 tracer_name: str = "aworld.log") -> None:
        """Initialize the handler.

        Args:
            level: Minimum level handled.
            tracer_name: Name used when obtaining the tracer from the provider.
        """
        super().__init__(level=level)
        self._tracer_name = tracer_name
        # Resolved lazily on the first emitted record.
        self._tracer: Optional[Tracer] = None

    @staticmethod
    def _find_caller_frame():
        """Return the first stack frame outside ``logging`` and this package, or None."""
        package_dir = os.path.dirname(__file__)
        frame = sys._getframe()
        while frame:
            filename = frame.f_code.co_filename
            # Normalize separators so the logging check also matches on Windows.
            in_logging = 'logging/__init__.py' in filename.replace(os.sep, '/')
            if in_logging or filename.startswith(package_dir):
                frame = frame.f_back
            else:
                break
        return frame

    @staticmethod
    def _read_call_argument(frame) -> Optional[str]:
        """Best-effort: return the source text inside ``logger.<method>(...)`` at *frame*.

        Returns None when the source line is unavailable or does not look like
        a ``logger.`` call.
        """
        line = linecache.getline(frame.f_code.co_filename, frame.f_lineno)
        if 'logger.' not in line:
            return None
        try:
            return line.split('logger.', 1)[1].split('(', 1)[1].split(')', 1)[0].strip()
        except IndexError:
            # No "(" after "logger." on this line (e.g. a multi-line call).
            return None

    def emit(self, record: LogRecord) -> None:
        """Emit a record as a span.

        The span name is the source text of the logging call when it can be
        read, otherwise the color-stripped message; either is capped at 255
        characters plus "...". Errors are routed to ``handleError``.
        """
        trace = get_tracer_provider_silent()
        if not trace:
            return
        current_span = trace.get_current_span()
        if not current_span or not current_span.is_recording():
            return

        if not self._tracer:
            self._tracer = trace.get_tracer(name=self._tracer_name)

        try:
            frame = self._find_caller_frame()
            raw_msg = self._read_call_argument(frame) if frame else None

            origin_msg = record.msg
            # record.msg may be any object (e.g. an exception); stringify before stripping.
            record.msg = self.strip_color(str(origin_msg))
            try:
                msg_template = raw_msg if raw_msg else record.msg
                if len(msg_template) > 255:
                    msg_template = msg_template[:255] + '...'
                attributes = {
                    'code.filepath': frame.f_code.co_filename if frame else record.pathname,
                    'code.lineno': frame.f_lineno if frame else record.lineno,
                    'code.function': frame.f_code.co_name if frame else record.funcName,
                    'log.level': record.levelname,
                    'log.logger': record.name,
                    'log.message': self.format(record),
                }
            finally:
                # Other handlers must see the record as it was originally logged.
                record.msg = origin_msg

            # LogRecord.exc_info is a (type, value, traceback) tuple; the span
            # API takes the exception instance (None when nothing was raised).
            exc_info = record.exc_info
            if isinstance(exc_info, tuple):
                exc_info = exc_info[1]

            self._create_span(
                span_name=msg_template,
                attributes=attributes,
                exc_info=exc_info,
            )
        except RecursionError:  # See issue 36272
            raise
        except Exception:
            self.handleError(record)

    def _create_span(self,
                     span_name: str,
                     attributes: dict[str, AttributeValueType] = None,
                     exc_info: BaseException = None):
        """Start and immediately end a span describing one log record.

        Args:
            span_name: Name of the span.
            attributes: Attributes attached to the span.
            exc_info: Exception to record on the span, if any.
        """
        start_time = time.time_ns()
        span = self._tracer.start_span(
            name=span_name,
            attributes=attributes,
            start_time=start_time,
        )
        if exc_info:
            span.record_exception(exception=exc_info, timestamp=start_time)
        span.end()