| import hashlib |
| from collections.abc import Awaitable |
| from datetime import timedelta |
| from functools import wraps |
| from typing import Any, Callable, Optional, TypeVar, cast |
|
|
| from cashews import Cache |
| from pydantic import BaseModel |
| from pydantic.decorator import ValidatedFunction |
|
|
| from .config import Config |
| from .log import logger |
|
|
# Attribute name under which a function's CacheConfig is attached.
CACHE_CONFIG_KEY = "_cache_config"


# Any callable returning an awaitable, and a TypeVar bound to it so that
# decorators can declare that they return the same callable type they receive.
AsyncFunc = Callable[..., Awaitable[Any]]
T_AsyncFunc = TypeVar("T_AsyncFunc", bound=AsyncFunc)


# Settings read once, at import time, from the "cache" section of Config.
CACHE_ENABLED = Config["cache"]["enabled"].as_bool()
CACHE_DELTA = timedelta(seconds=Config["cache"]["ttl"].as_number())
CACHE_URI = Config["cache"]["uri"].as_str()
CACHE_CONTROLLABLE = Config["cache"]["controllable"].as_bool()


# Shared cache instance used by endpoint_cache.
cache = Cache(name="hibiapi")
try:
    cache.setup(CACHE_URI)
except Exception as e:
    # Best effort: a bad or unreachable cache URI must not prevent start-up.
    logger.warning(
        f"Cache URI <y>{CACHE_URI!r}</y> setup <r><b>failed</b></r>: "
        f"<r>{e!r}</r>, use memory backend instead."
    )
    # Configure the in-memory backend explicitly so the fallback announced in
    # the warning does not depend on the library supplying a default backend.
    cache.setup("mem://")
|
|
|
|
class CacheConfig(BaseModel):
    """Cache settings attached to a single endpoint function."""

    # The async callable these settings belong to.
    endpoint: AsyncFunc
    # Prefix placed in front of the argument hash when cache keys are built.
    namespace: str
    # Whether results of the endpoint may be cached.
    enabled: bool = True
    # Expiry passed to the cache when a result is stored.
    ttl: timedelta = CACHE_DELTA

    @classmethod
    def new(
        cls,
        function: AsyncFunc,
        *,
        enabled: bool = True,
        ttl: timedelta = CACHE_DELTA,
        namespace: Optional[str] = None,
    ) -> "CacheConfig":
        """Build the settings for ``function``.

        Args:
            function: The async callable the settings describe.
            enabled: Whether caching is allowed for the callable.
            ttl: Expiry for stored results.
            namespace: Cache key prefix; when omitted or empty, the
                callable's ``__qualname__`` is used.

        Returns:
            A new instance of the class this constructor is called on.
        """
        # Instantiating ``cls`` instead of a hard-coded class name lets
        # subclasses reuse this constructor and get their own type back.
        return cls(
            endpoint=function,
            enabled=enabled,
            ttl=ttl,
            namespace=namespace or function.__qualname__,
        )
|
|
|
|
def cache_config(
    enabled: bool = True,
    ttl: timedelta = CACHE_DELTA,
    namespace: Optional[str] = None,
):
    """Create a decorator that attaches cache settings to a function.

    The decorator does not wrap the function: it stores a ``CacheConfig``
    on it under ``CACHE_CONFIG_KEY`` and returns the very same object.

    Args:
        enabled: Whether caching is allowed for the decorated function.
        ttl: Expiry for stored results.
        namespace: Cache key prefix forwarded to ``CacheConfig.new``.

    Returns:
        A decorator taking and returning the function unchanged.
    """

    def decorator(function: T_AsyncFunc) -> T_AsyncFunc:
        settings = CacheConfig.new(
            function,
            enabled=enabled,
            ttl=ttl,
            namespace=namespace,
        )
        setattr(function, CACHE_CONFIG_KEY, settings)
        return function

    return decorator


# Ready-made decorator that marks a function as not cacheable.
disable_cache = cache_config(enabled=False)
|
|
|
|
class CachedValidatedFunction(ValidatedFunction):
    """``ValidatedFunction`` that can also expose a call's arguments as a model."""

    def serialize(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> BaseModel:
        """Return the argument model for one call without running the function.

        The positional and keyword arguments are turned into field values by
        ``build_values`` and then used to instantiate ``self.model``.
        NOTE(review): both are presumably provided by ``ValidatedFunction``;
        they are not defined in this class.
        """
        return self.model(**self.build_values(args=args, kwargs=kwargs))
|
|
|
|
def endpoint_cache(function: T_AsyncFunc) -> T_AsyncFunc:
    """Wrap an async endpoint so its results are served from ``cache``.

    Settings come from the ``CacheConfig`` stored on ``function`` under
    ``CACHE_CONFIG_KEY`` (see ``cache_config``), or from
    ``CacheConfig.new(function)`` when none is attached.

    When ``CACHE_CONTROLLABLE`` is set, the request's ``cache-control``
    header is honoured:

    * ``no-store`` - call the endpoint directly; nothing is read or written.
    * ``no-cache`` - drop the stored entry, call the endpoint, store anew.

    The header is treated as a comma-separated, case-insensitive list of
    directives, so ``no-cache, max-age=0`` is recognised as ``no-cache``.

    Args:
        function: The async endpoint to wrap.

    Returns:
        An async wrapper with the metadata of ``function``. A result of
        ``None`` is never treated as a cache hit, so such endpoints are
        executed on every call.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid a circular import with .routing.
    from .routing import request_headers, response_headers

    vf = CachedValidatedFunction(function, config={})
    config = cast(
        CacheConfig,
        getattr(function, CACHE_CONFIG_KEY, None) or CacheConfig.new(function),
    )

    # The global switch can only turn caching off, never force it on.
    config.enabled = CACHE_ENABLED and config.enabled

    @wraps(function)
    async def wrapper(*args, **kwargs):
        cache_policy = "public"

        if CACHE_CONTROLLABLE:
            cache_policy = request_headers.get().get("cache-control", cache_policy)

        # Cache-Control carries a comma-separated list of directives; compare
        # each one on its own instead of the header value as a whole.
        directives = {part.strip().casefold() for part in cache_policy.split(",")}

        if not config.enabled or "no-store" in directives:
            return await vf.call(*args, **kwargs)

        # Key = namespace + MD5 of the validated arguments serialised as JSON
        # with sorted keys; ``self`` is excluded from the serialised fields.
        key = (
            f"{config.namespace}:"
            + hashlib.md5(
                (model := vf.serialize(args=args, kwargs=kwargs))
                .json(exclude={"self"}, sort_keys=True, ensure_ascii=False)
                .encode()
            ).hexdigest()
        )

        response_header = response_headers.get()
        result: Optional[Any] = None

        if "no-cache" in directives:
            await cache.delete(key)
        elif (result := await cache.get(key)) is not None:
            # Compare against None rather than truthiness so that cached
            # falsy values (empty containers, 0, "") still count as hits.
            logger.debug(f"Request hit cache <b><e>{key}</e></b>")
            response_header.setdefault("X-Cache-Hit", key)

        if result is None:
            result = await vf.execute(model)
            await cache.set(key, result, expire=config.ttl)

        # Advertise the remaining lifetime of the entry to the client.
        if (cache_remain := await cache.get_expire(key)) > 0:
            response_header.setdefault("Cache-Control", f"max-age={cache_remain}")

        return result

    return wrapper
|
|