File size: 6,704 Bytes
0d3a3a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import time
import asyncio
from typing import Callable
from functools import wraps
from aworld.metrics.metric import get_metric_provider, MetricType, BaseMetric
from aworld.metrics.template import MetricTemplate, MetricTemplates

# Module-level cache of created metric objects, keyed by template name.
# It is read and filled by MetricContext.get_or_create_metric.
# (The identifier's spelling is kept as-is because other code refers to it.)
_GLOBAL_METIRCS = {}


class MetricContext:
    """
    Class-level facade for configuring a metric provider and recording metrics.

    State lives on the class itself (``_initialized``) and in the module-level
    ``_GLOBAL_METIRCS`` cache, so every method is a class or static method.
    """

    # Provider names that ``configure`` knows how to set up.
    _SUPPORTED_PROVIDERS = ("prometheus", "otlp")

    _initialized = False

    @classmethod
    def configure(cls,
                  provider: str,
                  backend: str,
                  base_url: str = None,
                  write_token: str = None,
                  **kwargs):
        """
        Configure the metric provider.
        Args:
            provider: Name of the metric provider, "prometheus" or "otlp".
            backend: The backend of the metric provider.
            base_url: The base url of the metric provider.
            write_token: The write token of the metric provider.
            **kwargs: Other provider-specific parameters (e.g. ``export_console``);
                they are forwarded unchanged to the provider's configure function.
        Raises:
            ValueError: If ``provider`` is not one of the supported names.
        """
        # Validate first: an unknown provider must neither tear down a working
        # setup nor leave the context flagged as initialized with nothing configured.
        if provider not in cls._SUPPORTED_PROVIDERS:
            raise ValueError(
                f"unsupported metric provider: {provider!r}; "
                f"expected one of {cls._SUPPORTED_PROVIDERS}")

        if cls._initialized:
            cls.shutdown()

        # Adapters are imported lazily so only the selected backend's module is loaded.
        if provider == "prometheus":
            from aworld.metrics.prometheus.prometheus_adapter import configure_prometheus_provider
            configure_prometheus_provider(
                backend, base_url, write_token, **kwargs)
        else:
            from aworld.metrics.opentelemetry.opentelemetry_adapter import configure_otlp_provider
            configure_otlp_provider(backend, base_url, write_token, **kwargs)
        cls._initialized = True

    @classmethod
    def metric_initialized(cls):
        """
        Return True if ``configure`` has completed and ``shutdown`` has not been called since.
        """
        return cls._initialized

    @staticmethod
    def get_or_create_metric(template: MetricTemplate):
        """
        Return the metric for ``template``, creating and caching it on first use.
        Args:
            template: Describes the metric (name, type, description, unit, labels,
                and buckets for histograms).
        Returns:
            The cached or newly created metric, or None when ``template.type`` is
            not one of the supported metric types.
        """
        cached = _GLOBAL_METIRCS.get(template.name)
        if cached is not None:
            return cached

        metric = None
        if template.type == MetricType.COUNTER:
            metric = get_metric_provider().create_counter(template.name, template.description, template.unit,
                                                          template.labels)
        elif template.type == MetricType.UPDOWNCOUNTER:
            metric = get_metric_provider().create_un_down_counter(template.name, template.description, template.unit,
                                                                  template.labels)
        elif template.type == MetricType.GAUGE:
            metric = get_metric_provider().create_gauge(template.name, template.description, template.unit,
                                                        template.labels)
        elif template.type == MetricType.HISTOGRAM:
            metric = get_metric_provider().create_histogram(template.name, template.description, template.unit,
                                                            template.buckets, template.labels)

        # Never cache None: doing so would permanently shadow this name, even for
        # a later template of the same name with a supported type.
        if metric is not None:
            _GLOBAL_METIRCS[template.name] = metric
        return metric

    @classmethod
    def _validate_type(cls, metric: BaseMetric, type: str):
        """
        Ensure ``metric`` exists and is of the expected metric type.
        Raises:
            ValueError: If ``metric`` is None or its ``_type`` differs from ``type``.
        """
        if metric is None:
            # get_or_create_metric returns None for unsupported template types.
            raise ValueError(
                f"metric could not be created; expected a metric of type {type}")
        if type != metric._type:
            raise ValueError(f"metric type {metric._type} is not {type}")

    @classmethod
    def count(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Increment a counter metric.
        Raises:
            ValueError: If the template does not resolve to a counter.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.COUNTER)
        metric.add(value, labels)

    @classmethod
    def inc(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Increment a updowncounter metric.
        Raises:
            ValueError: If the template does not resolve to an updowncounter.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.UPDOWNCOUNTER)
        metric.inc(value, labels)

    @classmethod
    def dec(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Decrement a updowncounter metric.
        Raises:
            ValueError: If the template does not resolve to an updowncounter.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.UPDOWNCOUNTER)
        metric.dec(value, labels)

    @classmethod
    def gauge_set(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Set a value to a gauge metric.
        Raises:
            ValueError: If the template does not resolve to a gauge.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.GAUGE)
        metric.set(value, labels)

    @classmethod
    def histogram_record(cls, template: MetricTemplate, value: float, labels: dict = None):
        """
        Record a value in a histogram metric.
        Raises:
            ValueError: If the template does not resolve to a histogram.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.HISTOGRAM)
        metric.record(value, labels)

    @classmethod
    def shutdown(cls):
        """
        Shutdown the metric provider and mark the context as not initialized.
        """
        provider = get_metric_provider()
        if provider:
            provider.shutdown()
        # NOTE(review): metrics cached in _GLOBAL_METIRCS are not cleared here, so
        # they stay bound to whichever provider created them — confirm this is
        # intended when configure() is called again after a shutdown.
        cls._initialized = False


class ApiMetricTracker:
    """
    Track request count and latency of an API.

    Can be used as a decorator on sync or async functions
    (``@ApiMetricTracker()`` or ``@ApiMetricTracker("name")``) or as a context
    manager (``with ApiMetricTracker("name"):``). On exit it records one
    REQUEST_COUNT increment and one REQUEST_LATENCY sample (milliseconds),
    labelled with the API name and a "success"/"failure" status.
    """

    def __init__(self, api_name: str = None, func: Callable = None):
        """
        Args:
            api_name: Value of the "method" label; defaults to ``func.__name__``
                when omitted and ``func`` is given.
            func: The function being tracked, used to derive the default name.
        """
        self.start_time = None
        self.status = "success"
        self.func = func
        # Remember whether the name was chosen by the caller, so per-call
        # trackers only inherit names that were given explicitly.
        self._explicit_api_name = api_name is not None
        self.api_name = api_name
        if self.api_name is None and self.func is not None:
            self.api_name = self.func.__name__

    def _new_tracker(self, func: Callable):
        """
        Create a fresh tracker for a single call of ``func``.

        A separate instance per call keeps ``start_time``/``status`` from being
        shared between overlapping calls of the decorated function. An explicitly
        configured ``api_name`` is carried over; otherwise the new tracker falls
        back to ``func.__name__``.
        """
        inherited_name = self.api_name if self._explicit_api_name else None
        return self.__class__(api_name=inherited_name, func=func)

    def __enter__(self):
        """
        Start timing and return the tracker so ``with ... as tracker`` works.
        """
        # Wall-clock time in milliseconds; _record_metrics subtracts it from "now".
        self.start_time = time.time() * 1000
        return self

    def __exit__(self, exc_type, value, traceback):
        """
        Record metrics for the finished block; exceptions are not suppressed.
        """
        self.status = "success" if exc_type is None else "failure"
        self._record_metrics(self.api_name, self.start_time, self.status)

    def __call__(self, func: Callable = None) -> Callable:
        """
        Decorate ``func``; when called without a function, return the tracker itself.
        """
        if func is None:
            return self
        return self.decorator(func)

    def _record_metrics(self, api_name: str, start_time: float, status: str) -> None:
        """
        Record metrics for the API.
        Args:
            api_name: Value of the "method" label.
            start_time: Start timestamp in milliseconds, as set by ``__enter__``.
            status: Value of the "status" label ("success" or "failure").
        """
        elapsed_time = time.time() * 1000 - start_time
        labels = {"method": api_name, "status": status}
        MetricContext.count(MetricTemplates.REQUEST_COUNT, 1, labels=labels)
        # Pass a separate dict so the two metric calls never share one mutable object.
        MetricContext.histogram_record(MetricTemplates.REQUEST_LATENCY, elapsed_time,
                                       labels=dict(labels))

    def decorator(self, func):
        """
        Decorator to track API metrics.

        Returns an async wrapper for coroutine functions and a plain wrapper
        otherwise; each invocation is measured by its own tracker instance.
        """

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with self._new_tracker(func):
                return await func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            with self._new_tracker(func):
                return func(*args, **kwargs)

        return async_wrapper if asyncio.iscoroutinefunction(func) else wrapper