"""Redis client for caching and task queue management.""" | |
import json | |
from typing import Any, Dict, Optional, TypeVar | |
from datetime import datetime | |
import logging | |
from time import time as _time | |
import redis.asyncio as redis_async | |
import redis as redis_sync # Import synchronous Redis client | |
from pydantic import BaseModel | |
from tenacity import retry, stop_after_attempt, wait_exponential | |
from app.core.config import settings | |
# Type variable for cache | |
T = TypeVar('T') | |
# Configure logging | |
log = logging.getLogger(__name__) | |
# Redis connection pools for reusing connections | |
_redis_pool_async = None | |
_redis_pool_sync = None # Synchronous pool | |
# Default cache expiration (12 hours) | |
DEFAULT_CACHE_EXPIRY = 60 * 60 * 12 | |
async def get_redis_pool() -> redis_async.Redis:
    """Get or create the async Redis connection pool and return a client bound to it."""
    global _redis_pool_async
    if _redis_pool_async is None:
        # Get Redis configuration from settings
        redis_url = settings.REDIS_URL or "redis://localhost:6379/0"
        try:
            # Create connection pool with reasonable defaults
            _redis_pool_async = redis_async.ConnectionPool.from_url(
                redis_url,
                max_connections=10,
                decode_responses=True,
                health_check_interval=5,
                socket_connect_timeout=5,
                socket_keepalive=True,
                retry_on_timeout=True,
            )
            log.info(f"Created async Redis connection pool with URL: {redis_url}")
        except Exception as e:
            log.error(f"Error creating async Redis connection pool: {e}")
            raise
    return redis_async.Redis(connection_pool=_redis_pool_async)
def get_redis_pool_sync() -> redis_sync.Redis:
    """Get or create the synchronous Redis connection pool and return a client bound to it."""
    global _redis_pool_sync
    if _redis_pool_sync is None:
        # Get Redis configuration from settings
        redis_url = settings.REDIS_URL or "redis://localhost:6379/0"
        try:
            # Create connection pool with reasonable defaults
            _redis_pool_sync = redis_sync.ConnectionPool.from_url(
                redis_url,
                max_connections=10,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_keepalive=True,
                retry_on_timeout=True,
            )
            log.info(f"Created sync Redis connection pool with URL: {redis_url}")
        except Exception as e:
            log.error(f"Error creating sync Redis connection pool: {e}")
            raise
    return redis_sync.Redis(connection_pool=_redis_pool_sync)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=0.5, min=1, max=5), reraise=True)
async def get_redis() -> redis_async.Redis:
    """Get Redis client from pool with retry logic."""
    try:
        redis_client = await get_redis_pool()
        return redis_client
    except Exception as e:
        log.error(f"Error getting Redis client: {e}")
        raise
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=0.5, min=1, max=5), reraise=True)
def get_redis_sync() -> redis_sync.Redis:
    """Get synchronous Redis client from pool with retry logic."""
    try:
        return get_redis_pool_sync()
    except Exception as e:
        log.error(f"Error getting synchronous Redis client: {e}")
        raise
# Cache key generation
def generate_cache_key(prefix: str, *args: Any) -> str:
    """Generate cache key with prefix and args."""
    key_parts = [prefix] + [str(arg) for arg in args if arg]
    return ":".join(key_parts)
# JSON serialization helpers
def _json_serialize(obj: Any) -> str:
    """Serialize object to JSON with datetime and pydantic model support."""
    def _serialize_default(o: Any) -> Any:
        # json.dumps calls this for objects it cannot serialize natively; the return
        # value may itself be any JSON-serializable object (e.g. a dict for models).
        if isinstance(o, datetime):
            return o.isoformat()
        if isinstance(o, BaseModel):
            return o.dict()  # pydantic v1 API, matching parse_obj() below
        return str(o)
    return json.dumps(obj, default=_serialize_default)
def _json_deserialize(data: str, model_class: Optional[type] = None) -> Any:
    """Deserialize JSON string to object, optionally into a pydantic model."""
    result = json.loads(data)
    if model_class and issubclass(model_class, BaseModel):
        return model_class.parse_obj(result)
    return result
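# Illustrative round-trip (a sketch; `_ExampleItem` is a hypothetical model, not part of
# this module). Datetimes are written as ISO strings by _json_serialize and parsed back
# by pydantic when a model_class is supplied to _json_deserialize:
#
#     class _ExampleItem(BaseModel):
#         id: int
#         created_at: datetime
#
#     raw = _json_serialize(_ExampleItem(id=1, created_at=datetime.utcnow()))
#     item = _json_deserialize(raw, _ExampleItem)  # -> _ExampleItem with datetime restored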
# Async cache operations
async def cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
    """Set cache value with expiration (async version)."""
    redis_client = await get_redis()
    serialized = _json_serialize(value)
    try:
        await redis_client.set(key, serialized, ex=expire)
        log.debug(f"Cached data at key: {key}, expires in {expire}s")
        return True
    except Exception as e:
        log.error(f"Error caching data at key {key}: {e}")
        return False
async def cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
    """Get cache value with optional model deserialization (async version)."""
    redis_client = await get_redis()
    try:
        data = await redis_client.get(key)
        if not data:
            return None
        log.debug(f"Cache hit for key: {key}")
        return _json_deserialize(data, model_class)
    except Exception as e:
        log.error(f"Error retrieving cache for key {key}: {e}")
        return None
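# Illustrative read-through helper (a sketch, not part of the original module): combines
# cache_get/cache_set so the async `loader` callable is only awaited on a cache miss.
# The name cache_get_or_set is an assumption introduced here for illustration.
async def cache_get_or_set(key: str, loader, expire: int = DEFAULT_CACHE_EXPIRY,
                           model_class: Optional[type] = None) -> Any:
    """Return the cached value for key, computing and caching it via loader() on a miss."""
    cached = await cache_get(key, model_class)
    if cached is not None:
        return cached
    value = await loader()  # caller-supplied async callable that produces the value
    await cache_set(key, value, expire=expire)
    return value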
# Synchronous cache operations for Celery tasks
def sync_cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
    """Set cache value with expiration (synchronous version for Celery tasks). Logs slow operations."""
    redis_client = get_redis_sync()
    serialized = _json_serialize(value)
    start = _time()
    try:
        redis_client.set(key, serialized, ex=expire)
        elapsed = _time() - start
        if elapsed > 2:
            log.warning(f"Slow sync_cache_set for key {key}: {elapsed:.2f}s")
        log.debug(f"Cached data at key: {key}, expires in {expire}s (sync)")
        return True
    except Exception as e:
        log.error(f"Error caching data at key {key}: {e}")
        return False
def sync_cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
    """Get cache value with optional model deserialization (synchronous version for Celery tasks). Logs slow operations."""
    redis_client = get_redis_sync()
    start = _time()
    try:
        data = redis_client.get(key)
        elapsed = _time() - start
        if elapsed > 2:
            log.warning(f"Slow sync_cache_get for key {key}: {elapsed:.2f}s")
        if not data:
            return None
        log.debug(f"Cache hit for key: {key} (sync)")
        return _json_deserialize(data, model_class)
    except Exception as e:
        log.error(f"Error retrieving cache for key {key}: {e}")
        return None
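# Illustrative usage from a Celery task (a sketch; `celery_app` and `compute_report` are
# hypothetical names, not part of this module):
#
#     @celery_app.task
#     def build_report(report_id: str) -> dict:
#         key = generate_cache_key("report", report_id)
#         cached = sync_cache_get(key)
#         if cached is not None:
#             return cached
#         report = compute_report(report_id)
#         sync_cache_set(key, report, expire=3600)  # cache for one hour
#         return report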
async def cache_invalidate(key: str) -> bool:
    """Invalidate cache for key."""
    redis_client = await get_redis()
    try:
        await redis_client.delete(key)
        log.debug(f"Invalidated cache for key: {key}")
        return True
    except Exception as e:
        log.error(f"Error invalidating cache for key {key}: {e}")
        return False
async def cache_invalidate_pattern(pattern: str) -> int:
    """Invalidate all cache keys matching pattern."""
    redis_client = await get_redis()
    try:
        # KEYS scans the whole keyspace in one blocking call; acceptable for small
        # keyspaces, but SCAN-based iteration is preferable on large ones.
        keys = await redis_client.keys(pattern)
        if not keys:
            return 0
        count = await redis_client.delete(*keys)
        log.debug(f"Invalidated {count} keys matching pattern: {pattern}")
        return count
    except Exception as e:
        log.error(f"Error invalidating keys with pattern {pattern}: {e}")
        return 0
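# Illustrative non-blocking variant (a sketch, not part of the original module): the same
# invalidation using SCAN instead of KEYS, so Redis is never stalled by one large scan.
# The function name cache_invalidate_pattern_scan is an assumption made for this example.
async def cache_invalidate_pattern_scan(pattern: str) -> int:
    """Invalidate all cache keys matching pattern using incremental SCAN."""
    redis_client = await get_redis()
    count = 0
    try:
        # scan_iter walks the keyspace in small batches rather than one blocking KEYS call
        async for key in redis_client.scan_iter(match=pattern, count=100):
            count += await redis_client.delete(key)
        log.debug(f"Invalidated {count} keys matching pattern: {pattern}")
        return count
    except Exception as e:
        log.error(f"Error invalidating keys with pattern {pattern}: {e}")
        return count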
# Task queue operations
async def enqueue_task(queue_name: str, task_id: str, payload: Dict[str, Any]) -> bool:
    """Add task to queue."""
    redis_client = await get_redis()
    try:
        serialized = _json_serialize(payload)
        await redis_client.lpush(f"queue:{queue_name}", serialized)
        await redis_client.hset(f"tasks:{queue_name}", task_id, "pending")
        log.info(f"Enqueued task {task_id} to queue {queue_name}")
        return True
    except Exception as e:
        log.error(f"Error enqueueing task {task_id} to {queue_name}: {e}")
        return False
async def mark_task_complete(queue_name: str, task_id: str, result: Optional[Dict[str, Any]] = None) -> bool:
    """Mark task as complete with optional result."""
    redis_client = await get_redis()
    try:
        # Store result if provided
        if result:
            await redis_client.hset(
                f"results:{queue_name}",
                task_id,
                _json_serialize(result)
            )
        # Mark task as complete
        await redis_client.hset(f"tasks:{queue_name}", task_id, "complete")
        await redis_client.expire(f"tasks:{queue_name}", 86400)  # Expire after 24 hours
        log.info(f"Marked task {task_id} as complete in queue {queue_name}")
        return True
    except Exception as e:
        log.error(f"Error marking task {task_id} as complete: {e}")
        return False
async def get_task_status(queue_name: str, task_id: str) -> Optional[str]:
    """Get status of a task."""
    redis_client = await get_redis()
    try:
        status = await redis_client.hget(f"tasks:{queue_name}", task_id)
        return status
    except Exception as e:
        log.error(f"Error getting status for task {task_id}: {e}")
        return None
async def get_task_result(queue_name: str, task_id: str) -> Optional[Dict[str, Any]]:
    """Get result of a completed task."""
    redis_client = await get_redis()
    try:
        data = await redis_client.hget(f"results:{queue_name}", task_id)
        if not data:
            return None
        return _json_deserialize(data)
    except Exception as e:
        log.error(f"Error getting result for task {task_id}: {e}")
        return None
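# Illustrative end-to-end flow (a sketch): enqueue work, then poll for status and fetch the
# result once complete. The queue name "documents" and the payload fields are hypothetical.
#
#     await enqueue_task("documents", task_id, {"doc_id": doc_id, "queued_at": datetime.utcnow()})
#     ...
#     if await get_task_status("documents", task_id) == "complete":
#         result = await get_task_result("documents", task_id)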
# Stream processing for real-time updates
async def add_to_stream(stream: str, data: Dict[str, Any], max_len: int = 1000) -> str:
    """Add event to Redis stream."""
    redis_client = await get_redis()
    try:
        # Convert dict values to strings (Redis streams requirement)
        entry = {k: _json_serialize(v) for k, v in data.items()}
        # Add to stream with automatic ID generation
        event_id = await redis_client.xadd(
            stream,
            entry,
            maxlen=max_len,
            approximate=True
        )
        log.debug(f"Added event {event_id} to stream {stream}")
        return event_id
    except Exception as e:
        log.error(f"Error adding to stream {stream}: {e}")
        raise
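# Illustrative consumer sketch (not part of the original module): read events appended after
# last_id with XREAD, blocking for up to block_ms milliseconds. Field values were
# JSON-encoded by add_to_stream, so they are decoded again here. The function name
# read_stream is an assumption made for this example.
async def read_stream(stream: str, last_id: str = "$", block_ms: int = 5000) -> list:
    """Return (event_id, data) tuples for events added to the stream after last_id."""
    redis_client = await get_redis()
    events = []
    try:
        # xread returns [(stream_name, [(event_id, {field: value, ...}), ...]), ...]
        response = await redis_client.xread({stream: last_id}, block=block_ms)
        for _stream_name, entries in response:
            for event_id, fields in entries:
                events.append((event_id, {k: _json_deserialize(v) for k, v in fields.items()}))
        return events
    except Exception as e:
        log.error(f"Error reading from stream {stream}: {e}")
        return events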