# Spaces: Sleeping / Sleeping  (appears to be hosting-page status text captured with the source; not part of the module)
import asyncio
import inspect
import time
from functools import wraps
from typing import Callable

from aworld.metrics.metric import get_metric_provider, MetricType, BaseMetric
from aworld.metrics.template import MetricTemplate, MetricTemplates
# Module-level cache of created metric objects, keyed by template name.
# Filled lazily by MetricContext.get_or_create_metric.
_GLOBAL_METIRCS = {}
class MetricContext:
    """
    Class-level entry point for configuring a metric provider and recording
    metrics described by ``MetricTemplate`` objects.

    Every method is a class/static method; the class is called directly
    (e.g. ``MetricContext.count(...)``) rather than instantiated.
    """

    # True after configure() has completed; reset to False by shutdown().
    _initialized = False

    @classmethod
    def configure(cls,
                  provider: str,
                  backend: str,
                  base_url: str = None,
                  write_token: str = None,
                  **kwargs):
        """
        Configure the metric provider.

        If a provider was configured before, it is shut down first.

        Args:
            provider: The provider of the metric provider; "prometheus" and
                "otlp" are supported.
            backend: The backend of the metric provider.
            base_url: The base url of the metric provider.
            write_token: The write token of the metric provider.
            **kwargs: The other parameters of the metric provider, forwarded
                unchanged to the adapter's configure function (e.g.
                ``export_console`` -- NOTE(review): confirm the adapters
                accept it).

        Raises:
            ValueError: If ``provider`` is not a supported provider name.
        """
        # Reject unknown providers up front: otherwise the running provider
        # would be shut down and the context marked initialized although
        # nothing was configured.
        if provider not in ("prometheus", "otlp"):
            raise ValueError(f"unsupported metric provider: {provider}")

        if cls._initialized:
            cls.shutdown()

        # Adapters are imported lazily so only the selected backend's
        # module is loaded.
        if provider == "prometheus":
            from aworld.metrics.prometheus.prometheus_adapter import configure_prometheus_provider
            configure_prometheus_provider(
                backend, base_url, write_token, **kwargs)
        else:
            from aworld.metrics.opentelemetry.opentelemetry_adapter import configure_otlp_provider
            configure_otlp_provider(backend, base_url, write_token, **kwargs)
        cls._initialized = True

    @classmethod
    def metric_initialized(cls):
        """Return True if configure() has completed and shutdown() has not run since."""
        return cls._initialized

    @staticmethod
    def get_or_create_metric(template: MetricTemplate):
        """
        Return the metric for ``template``, creating it on first use.

        Created metrics are cached in the module-level registry under
        ``template.name``.

        Raises:
            ValueError: If ``template.type`` is not one of the supported
                metric types. Nothing is cached in that case, so the name
                stays available for a later, valid template.
        """
        if template.name in _GLOBAL_METIRCS:
            return _GLOBAL_METIRCS[template.name]

        if template.type == MetricType.COUNTER:
            metric = get_metric_provider().create_counter(template.name, template.description, template.unit,
                                                          template.labels)
        elif template.type == MetricType.UPDOWNCOUNTER:
            metric = get_metric_provider().create_un_down_counter(template.name, template.description, template.unit,
                                                                  template.labels)
        elif template.type == MetricType.GAUGE:
            metric = get_metric_provider().create_gauge(template.name, template.description, template.unit,
                                                        template.labels)
        elif template.type == MetricType.HISTOGRAM:
            metric = get_metric_provider().create_histogram(template.name, template.description, template.unit,
                                                            template.buckets, template.labels)
        else:
            # Previously None was cached here and every later use failed with
            # an AttributeError on ``metric._type``.
            raise ValueError(
                f"unsupported metric type {template.type} for metric {template.name}")

        _GLOBAL_METIRCS[template.name] = metric
        return metric

    @classmethod
    def _validate_type(cls, metric: BaseMetric, type: str):
        """Raise ValueError if ``metric`` is not of the expected metric ``type``."""
        if type != metric._type:
            raise ValueError(f"metric type {metric._type} is not {type}")

    @classmethod
    def count(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Increment a counter metric.

        Raises:
            ValueError: If the metric for ``template`` is not a counter.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.COUNTER)
        metric.add(value, labels)

    @classmethod
    def inc(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Increment an updowncounter metric.

        Raises:
            ValueError: If the metric for ``template`` is not an updowncounter.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.UPDOWNCOUNTER)
        metric.inc(value, labels)

    @classmethod
    def dec(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Decrement an updowncounter metric.

        Raises:
            ValueError: If the metric for ``template`` is not an updowncounter.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.UPDOWNCOUNTER)
        metric.dec(value, labels)

    @classmethod
    def gauge_set(cls, template: MetricTemplate, value: int, labels: dict = None):
        """
        Set a value to a gauge metric.

        Raises:
            ValueError: If the metric for ``template`` is not a gauge.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.GAUGE)
        metric.set(value, labels)

    @classmethod
    def histogram_record(cls, template: MetricTemplate, value: float, labels: dict = None):
        """
        Record a value in a histogram metric.

        Raises:
            ValueError: If the metric for ``template`` is not a histogram.
        """
        metric = cls.get_or_create_metric(template)
        cls._validate_type(metric, MetricType.HISTOGRAM)
        metric.record(value, labels)

    @classmethod
    def shutdown(cls):
        """
        Shutdown the metric provider.
        """
        provider = get_metric_provider()
        if provider:
            provider.shutdown()
        # NOTE(review): metrics cached in the module-level registry are not
        # cleared here, so objects created by the old provider are presumably
        # reused after a re-configure -- confirm whether that is intended.
        cls._initialized = False
class ApiMetricTracker:
    """
    Track the request count and latency of an API.

    Usable as a decorator for sync and async functions
    (``@ApiMetricTracker()`` or ``@ApiMetricTracker(api_name="name")``) and
    as a context manager. Each tracked call records
    ``MetricTemplates.REQUEST_COUNT`` and ``MetricTemplates.REQUEST_LATENCY``
    with the labels ``method`` (the API name) and ``status``
    ("success" or "failure").
    """

    def __init__(self, api_name: str = None, func: Callable = None):
        """
        Args:
            api_name: Label value used for ``method``. Defaults to
                ``func.__name__`` when ``func`` is given.
            func: The function being tracked, if known.
        """
        self.start_time = None
        self.status = "success"
        self.func = func
        # Keep the name exactly as the caller passed it, so the per-call
        # trackers created by _new_tracker can tell an explicit name from
        # one derived from a function.
        self._explicit_api_name = api_name
        self.api_name = api_name
        if self.api_name is None and self.func is not None:
            self.api_name = self.func.__name__

    def _new_tracker(self, func: Callable):
        """
        Create a fresh tracker for one call of ``func``.

        A new instance per call keeps ``start_time``/``status`` from being
        shared between overlapping calls. An explicitly configured
        ``api_name`` is carried over; otherwise the new tracker derives the
        name from ``func``.
        """
        explicit_name = getattr(self, "_explicit_api_name", None)
        if explicit_name is None:
            return self.__class__(func=func)
        return self.__class__(api_name=explicit_name, func=func)

    def __enter__(self):
        """Start timing (milliseconds since the epoch) and return the tracker."""
        self.start_time = time.time() * 1000
        return self

    def __exit__(self, exc_type, value, traceback):
        """Record metrics; the status is "failure" if the body raised. The exception is not suppressed."""
        self.status = "success" if exc_type is None else "failure"
        self._record_metrics(self.api_name, self.start_time, self.status)

    def __call__(self, func: Callable = None) -> Callable:
        """Decorate ``func``; when called without a function, return the tracker itself."""
        if func is None:
            return self
        return self.decorator(func)

    def _record_metrics(self, api_name: str, start_time: float, status: str) -> None:
        """
        Record metrics for the API.

        Args:
            api_name: Value for the ``method`` label.
            start_time: Start timestamp in milliseconds, as set by ``__enter__``.
            status: Value for the ``status`` label.
        """
        elapsed_time = time.time() * 1000 - start_time
        MetricContext.count(MetricTemplates.REQUEST_COUNT, 1,
                            labels={"method": api_name, "status": status})
        MetricContext.histogram_record(MetricTemplates.REQUEST_LATENCY, elapsed_time,
                                       labels={"method": api_name, "status": status})

    def decorator(self, func):
        """
        Decorator to track API metrics.

        Returns an async wrapper for coroutine functions and a plain wrapper
        otherwise; both preserve the metadata of ``func``.
        """
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            with self._new_tracker(func):
                return await func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            with self._new_tracker(func):
                return func(*args, **kwargs)

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated since Python 3.14.
        return async_wrapper if inspect.iscoroutinefunction(func) else wrapper