File size: 6,279 Bytes
31e8fad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import time
import sys
import os
from logging import StreamHandler
from abc import ABC
from logging import Logger, NOTSET, LogRecord, Filter, Formatter, Handler
from typing import Optional, Union

from aworld.trace.base import get_tracer_provider_silent, Tracer, AttributeValueType

# Layout applied to the root logger's handlers by instrument_logging.
# trace_id / span_id are record attributes populated by TraceLoggingFilter.
TRACE_LOG_FORMAT = '%(asctime)s - [%(trace_id)s] - [%(span_id)s] - %(name)s - %(levelname)s - %(message)s'
# Identical layout, except the trace id is prefixed with "trace_"; applied to the
# handlers of the logger passed to instrument_logging.
SPECIAL_TRACE_LOG_FORMAT = TRACE_LOG_FORMAT.replace('[%(trace_id)s]', '[trace_%(trace_id)s]', 1)


class LoggerProvider(ABC):
    """Base type for logger factories.

    The module keeps a single process-wide instance, installed by
    ``set_log_provider`` and retrieved with ``get_log_provider``.
    """


# The process-wide provider; stays ``None`` until ``set_log_provider`` assigns one.
_GLOBAL_LOG_PROVIDER: Optional[LoggerProvider] = None


def set_log_provider(provider: str = "otlp",
                     backend: str = "logfire",
                     base_url: Optional[str] = None,
                     write_token: Optional[str] = None,
                     **kwargs):
    """Install the process-wide log provider later returned by ``get_log_provider``.

    Args:
        provider: Name of the provider implementation. Only ``"otlp"`` is supported.
        backend: Forwarded to the provider as ``backend``.
        base_url: Forwarded to the provider as ``base_url``.
        write_token: Forwarded to the provider as ``write_token``.
        **kwargs: Additional keyword arguments forwarded to the provider constructor.

    Raises:
        ValueError: If ``provider`` is not a supported provider name.
    """

    global _GLOBAL_LOG_PROVIDER

    if provider == "otlp":
        from .opentelemetry.otlp_log import OTLPLoggerProvider
        _GLOBAL_LOG_PROVIDER = OTLPLoggerProvider(backend=backend,
                                                  base_url=base_url,
                                                  write_token=write_token,
                                                  **kwargs)
        return

    # Fail fast: silently ignoring an unknown name would leave any previous
    # provider installed (or none), and the mistake would only surface later
    # as "No log provider has been set." from get_log_provider.
    raise ValueError(f"Unsupported log provider: {provider!r}. Supported providers: 'otlp'.")


def get_log_provider() -> LoggerProvider:
    """
    Return the provider previously installed with ``set_log_provider``.

    Raises:
        ValueError: If no provider has been installed yet.
    """
    # Read-only access to the module global, so no ``global`` statement is needed.
    provider = _GLOBAL_LOG_PROVIDER
    if provider is not None:
        return provider
    raise ValueError("No log provider has been set.")


def _ensure_trace_filter(handler: Handler, fmt: str) -> None:
    """Give ``handler`` a ``fmt`` formatter and a TraceLoggingFilter, unless it already has that filter."""
    if any(isinstance(existing, TraceLoggingFilter) for existing in handler.filters):
        return
    handler.setFormatter(Formatter(fmt))
    handler.addFilter(TraceLoggingFilter())


def instrument_logging(logger: Logger, level: Union[int, str] = NOTSET) -> None:
    """Instrument ``logger`` with trace-aware formatting and span export.

    Each handler on the root logger gets TRACE_LOG_FORMAT and a TraceLoggingFilter.
    Each output handler on ``logger`` gets SPECIAL_TRACE_LOG_FORMAT and a
    TraceLoggingFilter. Handlers that already carry a TraceLoggingFilter are left
    untouched. If ``logger`` has no handlers, a StreamHandler configured this way
    is added. Propagation to ancestor loggers is disabled, and a TraceLogginHandler
    at ``level`` is attached so records are also exported as spans.

    Calling this again on the same logger does not attach a second
    TraceLogginHandler, so each record is exported only once. An existing
    TraceLogginHandler is kept as-is, including its level.

    Args:
        logger: The logger to instrument.
        level: Level for the span-exporting TraceLogginHandler.
    """
    for handler in logger.root.handlers:
        _ensure_trace_filter(handler, TRACE_LOG_FORMAT)

    if not logger.handlers:
        print("No handlers found, adding a StreamHandler. logger=", logger.name)
        stream_handler = StreamHandler()
        _ensure_trace_filter(stream_handler, SPECIAL_TRACE_LOG_FORMAT)
        logger.addHandler(stream_handler)
    else:
        for handler in logger.handlers:
            # The span exporter is not an output handler. Reformatting it on a
            # repeat call would change the 'log.message' attribute it records.
            if not isinstance(handler, TraceLogginHandler):
                _ensure_trace_filter(handler, SPECIAL_TRACE_LOG_FORMAT)

    logger.propagate = False
    if not any(isinstance(handler, TraceLogginHandler) for handler in logger.handlers):
        logger.addHandler(TraceLogginHandler(level))


class TraceLoggingFilter(Filter):
    """
    A filter that adds trace information to log records.

    It never rejects a record. It only guarantees that ``record.trace_id`` and
    ``record.span_id`` exist, so formats such as TRACE_LOG_FORMAT can reference them.
    """

    def filter(self, record: LogRecord) -> bool:
        """
        Set ``trace_id``/``span_id`` on ``record`` and let it through.

        When a tracer provider is available, the ids come from its current span
        (``None`` if there is no current span). When no provider is returned,
        both attributes default to ``None`` unless the record already carries
        them (e.g. passed via ``extra=``). Without these defaults, a formatter
        using ``%(trace_id)s`` would fail every record with a KeyError.
        """
        trace = get_tracer_provider_silent()
        if trace:
            span = trace.get_current_span()
            record.trace_id = span.get_trace_id() if span else None
            record.span_id = span.get_span_id() if span else None
        else:
            record.__dict__.setdefault("trace_id", None)
            record.__dict__.setdefault("span_id", None)
        return True


class TraceLogginHandler(Handler):
    """
    A handler that exports each log record as a zero-duration span.

    Records are exported only while the tracer provider reports a current span
    that is recording; otherwise they are ignored.

    The span name is, when possible, the source text inside the parentheses of
    the ``logger.xxx(...)`` call on the caller's line. Otherwise it is the
    message template with ANSI color codes removed. Names are truncated to 255
    characters plus ``'...'``.

    Code location, level, logger name and the formatted message are attached as
    span attributes, and the record's exception (if any) is recorded on the span.
    """

    # Substrings identifying frames inside the stdlib logging package. The
    # os.path.join form also matches Windows-style path separators.
    _LOGGING_INIT_MARKERS = ('logging/__init__.py', os.path.join('logging', '__init__.py'))
    # Span names longer than this are truncated and suffixed with '...'.
    _MAX_SPAN_NAME_LENGTH = 255

    @staticmethod
    def strip_color(text: str) -> str:
        """Remove ANSI color codes from text"""
        import re
        return re.sub(r'\033\[[0-9;]*m', '', text)

    def __init__(self,
                 level: Union[int, str] = NOTSET,
                 tracer_name: str = "aworld.log") -> None:
        """Initialize the handler.

        Args:
            level: Handler level, as for ``logging.Handler``.
            tracer_name: Name used to obtain the tracer on the first emitted record.
        """
        super().__init__(level=level)
        self._tracer_name = tracer_name
        # Created lazily in emit(), once a tracer provider is available.
        self._tracer: Optional[Tracer] = None

    def emit(self, record: LogRecord) -> None:
        """Export ``record`` as a span on the current trace (see class docstring)."""
        trace = get_tracer_provider_silent()
        if not trace:
            return
        current_span = trace.get_current_span()
        if not current_span or not current_span.is_recording():
            return

        if not self._tracer:
            self._tracer = trace.get_tracer(name=self._tracer_name)

        try:
            caller = self._find_caller_frame()
            raw_msg = self._call_arguments_source(caller) if caller else None

            origin_msg = record.msg
            try:
                # Colors are stripped for the span only. Non-string messages
                # (e.g. logger.info(some_obj)) cannot be passed to re.sub and
                # are left unchanged.
                if isinstance(origin_msg, str):
                    record.msg = self.strip_color(origin_msg)
                msg_template = raw_msg if raw_msg else str(record.msg)
                attributes = {
                    'code.filepath': caller.f_code.co_filename if caller else record.pathname,
                    'code.lineno': caller.f_lineno if caller else record.lineno,
                    'code.function': caller.f_code.co_name if caller else record.funcName,
                    'log.level': record.levelname,
                    'log.logger': record.name,
                    'log.message': self.format(record),
                }
            finally:
                # Always restore the message so handlers running after this one
                # see the record unchanged, even if formatting failed.
                record.msg = origin_msg

            if len(msg_template) > self._MAX_SPAN_NAME_LENGTH:
                msg_template = msg_template[:self._MAX_SPAN_NAME_LENGTH] + '...'

            self._create_span(
                span_name=msg_template,
                attributes=attributes,
                exc_info=record.exc_info,
            )
        except RecursionError:  # See issue 36272
            raise
        except Exception:
            self.handleError(record)

    def _find_caller_frame(self):
        """Return the first stack frame outside the stdlib logging module and this file's directory, or None."""
        own_dir = os.path.dirname(__file__)
        frame = sys._getframe()
        while frame:
            filename = frame.f_code.co_filename
            in_logging = any(marker in filename for marker in self._LOGGING_INIT_MARKERS)
            if not in_logging and not filename.startswith(own_dir):
                break
            frame = frame.f_back
        return frame

    @staticmethod
    def _call_arguments_source(frame) -> Optional[str]:
        """Best-effort: return the text inside ``logger.<method>(...)`` on ``frame``'s current source line.

        Returns None when the line cannot be read, contains no ``logger.``, or
        has no ``(`` after it. Only the text up to the first ``)`` is kept, so
        arguments containing nested calls are cut short.
        """
        import linecache
        try:
            line = linecache.getline(frame.f_code.co_filename, frame.f_lineno)
        except Exception:
            return None
        if 'logger.' not in line:
            return None
        _, has_paren, arguments = line.split('logger.', 1)[1].partition('(')
        if not has_paren:
            return None
        return arguments.split(')', 1)[0].strip()

    def _create_span(self,
                     span_name: str,
                     attributes: Optional[dict[str, AttributeValueType]] = None,
                     exc_info: Union[BaseException, tuple, None] = None):
        """Start and immediately end a span named ``span_name``.

        Args:
            span_name: Name of the span.
            attributes: Attributes set on the span at start.
            exc_info: An exception instance, or a ``LogRecord.exc_info``-style
                ``(type, value, traceback)`` tuple. When it resolves to an
                exception, that exception is recorded on the span.
        """
        start_time = time.time_ns()
        span = self._tracer.start_span(
            name=span_name,
            attributes=attributes,
            start_time=start_time,
        )
        try:
            # LogRecord.exc_info is a (type, value, traceback) tuple. Pass the
            # exception instance itself, as the BaseException annotation intends;
            # an all-None tuple (no active exception) records nothing.
            exception = exc_info[1] if isinstance(exc_info, tuple) else exc_info
            if exception is not None:
                span.record_exception(exception=exception, timestamp=start_time)
        finally:
            # End the span even if recording the exception fails.
            span.end()