# Scraped page header (Hugging Face Space, status: Sleeping)
# File size: 6,279 bytes; commit 31e8fad
import time
import sys
import os
from logging import StreamHandler
from abc import ABC
from logging import Logger, NOTSET, LogRecord, Filter, Formatter, Handler
from typing import Optional, Union
from aworld.trace.base import get_tracer_provider_silent, Tracer, AttributeValueType
# Log line format that includes the current trace and span ids. The
# ``trace_id`` / ``span_id`` record attributes are filled in by
# TraceLoggingFilter; instrument_logging applies this format to the root
# logger's handlers.
TRACE_LOG_FORMAT = '%(asctime)s - [%(trace_id)s] - [%(span_id)s] - %(name)s - %(levelname)s - %(message)s'
# Same as TRACE_LOG_FORMAT, but renders the trace id as "trace_<id>";
# instrument_logging applies it to the instrumented logger's own handlers.
SPECIAL_TRACE_LOG_FORMAT = '%(asctime)s - [trace_%(trace_id)s] - [%(span_id)s] - %(name)s - %(levelname)s - %(message)s'
class LoggerProvider(ABC):
    """A logger provider is a factory for loggers.

    Abstract base type of the process-wide provider stored in
    ``_GLOBAL_LOG_PROVIDER`` (installed by ``set_log_provider`` and returned
    by ``get_log_provider``). No abstract methods are declared here, so the
    actual factory interface comes from concrete subclasses — presumably
    including ``OTLPLoggerProvider`` (defined elsewhere; confirm).
    """
# Process-wide log provider: assigned by set_log_provider(), read by
# get_log_provider(); None until a provider has been installed.
_GLOBAL_LOG_PROVIDER: Optional[LoggerProvider] = None
def set_log_provider(provider: str = "otlp",
                     backend: str = "logfire",
                     base_url: Optional[str] = None,
                     write_token: Optional[str] = None,
                     **kwargs) -> None:
    """Install the process-wide log provider.

    Args:
        provider: Provider implementation to use. Only ``"otlp"`` is
            supported.
        backend: Backend name forwarded to ``OTLPLoggerProvider``.
        base_url: Endpoint URL forwarded to ``OTLPLoggerProvider``.
        write_token: Write token forwarded to ``OTLPLoggerProvider``.
        **kwargs: Extra keyword arguments forwarded to ``OTLPLoggerProvider``.

    Raises:
        ValueError: If ``provider`` is not a supported provider name.
    """
    global _GLOBAL_LOG_PROVIDER
    if provider != "otlp":
        # Fail fast: an ignored provider name would otherwise only surface
        # later as "No log provider has been set." from get_log_provider().
        raise ValueError(f"Unsupported log provider: {provider!r}")
    # Imported lazily so the OTLP dependencies are only loaded when used.
    from .opentelemetry.otlp_log import OTLPLoggerProvider
    _GLOBAL_LOG_PROVIDER = OTLPLoggerProvider(backend=backend,
                                              base_url=base_url,
                                              write_token=write_token,
                                              **kwargs)
def get_log_provider() -> LoggerProvider:
    """Return the process-wide log provider.

    Raises:
        ValueError: If no provider has been installed via set_log_provider().
    """
    # Read-only access to the module global; no ``global`` statement needed.
    current = _GLOBAL_LOG_PROVIDER
    if current is not None:
        return current
    raise ValueError("No log provider has been set.")
def instrument_logging(logger: Logger, level: Union[int, str] = NOTSET) -> None:
    """Instrument ``logger`` so its records carry trace ids and are exported as spans.

    * Every root-logger handler that lacks a ``TraceLoggingFilter`` gets
      ``TRACE_LOG_FORMAT`` and the filter.
    * If ``logger`` has no handlers, a ``StreamHandler`` using
      ``SPECIAL_TRACE_LOG_FORMAT`` plus the filter is added. Otherwise, each of
      its handlers lacking the filter gets that format and the filter.
    * Propagation to ancestor loggers is disabled.
    * A ``TraceLogginHandler`` at ``level`` is attached unless ``logger``
      already has one, so calling this repeatedly does not duplicate spans.

    Args:
        logger: The logger to instrument.
        level: Level of the attached ``TraceLogginHandler``.
    """
    for handler in logger.root.handlers:
        if isinstance(handler, TraceLogginHandler) or _has_trace_filter(handler):
            continue
        handler.setFormatter(Formatter(TRACE_LOG_FORMAT))
        handler.addFilter(TraceLoggingFilter())

    if not logger.handlers:
        print("No handlers found, adding a StreamHandler. logger=", logger.name)
        stream_handler = StreamHandler()
        stream_handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT))
        stream_handler.addFilter(TraceLoggingFilter())
        logger.addHandler(stream_handler)
    else:
        for handler in logger.handlers:
            # The span exporter from an earlier call keeps its own formatting.
            if isinstance(handler, TraceLogginHandler) or _has_trace_filter(handler):
                continue
            handler.setFormatter(Formatter(SPECIAL_TRACE_LOG_FORMAT))
            handler.addFilter(TraceLoggingFilter())

    # Keep records from also reaching ancestor (e.g. root) handlers.
    logger.propagate = False

    # Idempotence: only one span-exporting handler per logger.
    if not any(isinstance(handler, TraceLogginHandler) for handler in logger.handlers):
        logger.addHandler(TraceLogginHandler(level))


def _has_trace_filter(handler: Handler) -> bool:
    """Return True if ``handler`` already has a ``TraceLoggingFilter`` attached."""
    return any(isinstance(flt, TraceLoggingFilter) for flt in handler.filters)
class TraceLoggingFilter(Filter):
    """
    A filter that adds ``trace_id`` and ``span_id`` attributes to log records.

    The filter never drops a record. The attributes always exist after
    filtering, so formats such as ``TRACE_LOG_FORMAT`` that reference
    ``%(trace_id)s`` / ``%(span_id)s`` cannot fail.
    """

    def filter(self, record: LogRecord) -> bool:
        """Attach the current trace/span ids to ``record`` and let it through."""
        trace = get_tracer_provider_silent()
        if trace:
            span = trace.get_current_span()
            record.trace_id = span.get_trace_id() if span else None
            record.span_id = span.get_span_id() if span else None
        else:
            # No tracer provider: keep any ids already on the record (e.g.
            # passed via ``extra``), but make sure the attributes exist.
            if not hasattr(record, 'trace_id'):
                record.trace_id = None
            if not hasattr(record, 'span_id'):
                record.span_id = None
        return True
class TraceLogginHandler(Handler):
    """
    A logging handler that exports each log record as a short span.

    A record is exported only when a tracer provider is available and its
    current span is recording; otherwise ``emit`` does nothing.

    The span's name is taken from the message argument as written in the
    caller's source line. If that cannot be read, the color-stripped message
    is used instead. The span carries code-location and log metadata as
    attributes. This handler writes nothing to any stream.
    """

    # Frames from files under this module's directory are treated as
    # internal and skipped when locating the code that issued the log call.
    _INTERNAL_DIR = os.path.dirname(__file__)

    # Span names longer than this are truncated and suffixed with '...'.
    _MAX_SPAN_NAME_LENGTH = 255

    @staticmethod
    def strip_color(text: str) -> str:
        """Remove ANSI color codes from text"""
        import re
        return re.sub(r'\033\[[0-9;]*m', '', text)

    @staticmethod
    def _extract_call_argument(line: str) -> Optional[str]:
        """Return the argument text of a ``logger.<method>(...)`` call in ``line``.

        Best-effort textual heuristic: returns the stripped text between the
        first '(' after 'logger.' and the next ')', or ``None`` when the line
        has no 'logger.' or no '(' after it.
        """
        _, found, after_logger = line.partition('logger.')
        if not found:
            return None
        _, found, call_args = after_logger.partition('(')
        if not found:
            return None
        return call_args.split(')', 1)[0].strip()

    @staticmethod
    def _exception_from_exc_info(exc_info) -> Optional[BaseException]:
        """Return the exception instance carried by ``exc_info``, if any.

        ``LogRecord.exc_info`` is a ``(type, value, traceback)`` tuple or
        ``None``. The tuple may be ``(None, None, None)`` when ``exc_info=True``
        is used outside an ``except`` block. A bare exception instance is
        accepted too.
        """
        if isinstance(exc_info, BaseException):
            return exc_info
        if isinstance(exc_info, tuple) and len(exc_info) >= 2:
            value = exc_info[1]
            if isinstance(value, BaseException):
                return value
        return None

    @classmethod
    def _is_internal_frame(cls, filename: str) -> bool:
        """Whether ``filename`` is the stdlib ``logging`` package or under this package."""
        # Normalize separators so the logging check also matches Windows paths.
        normalized = filename.replace('\\', '/')
        return 'logging/__init__.py' in normalized or filename.startswith(cls._INTERNAL_DIR)

    @classmethod
    def _find_caller_frame(cls):
        """Return the first stack frame outside ``logging`` and this package, or None."""
        frame = sys._getframe()
        while frame is not None and cls._is_internal_frame(frame.f_code.co_filename):
            frame = frame.f_back
        return frame

    @classmethod
    def _read_raw_message(cls, frame) -> Optional[str]:
        """Return the logging call's argument text from ``frame``'s source line, if readable."""
        if frame is None:
            return None
        import linecache
        try:
            line = linecache.getline(frame.f_code.co_filename, frame.f_lineno)
        except Exception:
            # Best effort only: source may be unavailable or f_lineno may be None.
            return None
        return cls._extract_call_argument(line)

    def __init__(self,
                 level: Union[int, str] = NOTSET,
                 tracer_name: str = "aworld.log") -> None:
        """Initialize the handler.

        Args:
            level: Minimum level handled, passed to ``logging.Handler``.
            tracer_name: Name used to obtain the tracer from the provider.
        """
        super().__init__(level=level)
        self._tracer_name = tracer_name
        # Created lazily on the first exported record.
        self._tracer: Optional[Tracer] = None

    def emit(self, record: LogRecord) -> None:
        """Export ``record`` as a span if the current span is recording."""
        trace = get_tracer_provider_silent()
        if not trace:
            return
        current_span = trace.get_current_span()
        if not current_span or not current_span.is_recording():
            return
        if not self._tracer:
            self._tracer = trace.get_tracer(name=self._tracer_name)

        try:
            frame = self._find_caller_frame()
            raw_msg = self._read_raw_message(frame)
            origin_msg = record.msg
            try:
                # Format a color-free message for the span. str() mirrors
                # LogRecord.getMessage(), so non-str messages work too.
                record.msg = self.strip_color(str(origin_msg))
                msg_template = raw_msg if raw_msg else record.msg
                if len(msg_template) > self._MAX_SPAN_NAME_LENGTH:
                    msg_template = msg_template[:self._MAX_SPAN_NAME_LENGTH] + '...'
                attributes = {
                    'code.filepath': frame.f_code.co_filename if frame else record.pathname,
                    'code.lineno': frame.f_lineno if frame else record.lineno,
                    'code.function': frame.f_code.co_name if frame else record.funcName,
                    'log.level': record.levelname,
                    'log.logger': record.name,
                    'log.message': self.format(record),
                }
            finally:
                # Always hand the record back unchanged to other handlers.
                record.msg = origin_msg
            self._create_span(
                span_name=msg_template,
                attributes=attributes,
                exc_info=record.exc_info,
            )
        except RecursionError:  # See issue 36272
            raise
        except Exception:
            self.handleError(record)

    def _create_span(self,
                     span_name: str,
                     attributes: Optional[dict[str, "AttributeValueType"]] = None,
                     exc_info: Union[BaseException, tuple, None] = None) -> None:
        """Start and immediately end a span, recording ``exc_info``'s exception if any.

        ``exc_info`` may be a ``LogRecord.exc_info`` tuple or an exception
        instance; only a real exception object is passed to ``record_exception``.
        """
        start_time = time.time_ns()
        span = self._tracer.start_span(
            name=span_name,
            attributes=attributes,
            start_time=start_time,
        )
        try:
            exception = self._exception_from_exc_info(exc_info)
            if exception is not None:
                span.record_exception(exception=exception, timestamp=start_time)
        finally:
            span.end()
# (end of file)