import ast
from abc import ABC, abstractmethod

from app.config import config
from app.models import const


def _clamp_progress(progress) -> int:
    """Coerce *progress* to an int percentage clamped to the range [0, 100].

    Raises whatever ``int()`` raises if *progress* cannot be converted.
    """
    return max(0, min(100, int(progress)))


# Base class for state management
class BaseState(ABC):
    """Interface shared by the task-state backends.

    A task is identified by a string id and carries an integer ``state``, an
    integer ``progress`` percentage, and any extra fields passed as keyword
    arguments to ``update_task``.
    """

    @abstractmethod
    def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
        """Record *state*, *progress* and any extra fields for *task_id*."""

    @abstractmethod
    def get_task(self, task_id: str):
        """Return the stored fields for *task_id* as a dict, or ``None``."""

    def delete_task(self, task_id: str):
        """Remove *task_id* from the store.

        Both backends in this module override this. The default exists so the
        interface documents the method without forcing other subclasses to
        implement it.
        """
        raise NotImplementedError


# Memory state management
class MemoryState(BaseState):
    """Backend that keeps tasks in a plain dict owned by this instance."""

    def __init__(self):
        self._tasks = {}

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Replace the stored record for *task_id*.

        *progress* is converted to int and clamped to [0, 100].

        NOTE(review): unlike ``RedisState.update_task``, this overwrites the
        whole record, so fields given in earlier calls but omitted here are
        dropped. Confirm which semantics callers expect.
        """
        self._tasks[task_id] = {
            "state": state,
            "progress": _clamp_progress(progress),
            **kwargs,
        }

    def get_task(self, task_id: str):
        """Return the stored dict for *task_id*, or ``None`` if unknown.

        The returned dict is the stored object itself, not a copy.
        """
        return self._tasks.get(task_id)

    def delete_task(self, task_id: str):
        """Forget *task_id*; unknown ids are silently ignored."""
        self._tasks.pop(task_id, None)


# Redis state management
class RedisState(BaseState):
    """Backend that stores each task as a Redis hash keyed by the task id.

    Values are written with ``str()`` and parsed back on read by
    ``_convert_to_original_type``.
    """

    def __init__(self, host="localhost", port=6379, db=0, password=None):
        # Imported here so the redis package is only required when this
        # backend is actually constructed.
        import redis

        self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Merge *state*, *progress* and extra fields into the hash for *task_id*.

        Hash fields not named in this call keep their previous values.
        *progress* is converted to int and clamped to [0, 100].
        """
        fields = {
            "state": state,
            "progress": _clamp_progress(progress),
            **kwargs,
        }
        # A transactional pipeline (MULTI/EXEC) applies every field at once,
        # so a concurrent get_task never sees a half-written update, and the
        # whole update costs a single round trip.
        with self._redis.pipeline() as pipe:
            for field, value in fields.items():
                pipe.hset(task_id, field, str(value))
            pipe.execute()

    def get_task(self, task_id: str):
        """Return the hash for *task_id* as a dict, or ``None`` if it is absent.

        Keys are decoded as UTF-8; values go through
        ``_convert_to_original_type``.
        """
        task_data = self._redis.hgetall(task_id)
        if not task_data:
            return None
        return {
            key.decode("utf-8"): self._convert_to_original_type(value)
            for key, value in task_data.items()
        }

    def delete_task(self, task_id: str):
        """Delete the Redis key holding *task_id*."""
        self._redis.delete(task_id)

    @staticmethod
    def _convert_to_original_type(value):
        """
        Convert the value from byte string to its original data type.

        Text that parses as a Python literal (numbers, lists, dicts, tuples,
        ``None``, booleans, ...) is returned as that literal via
        ``ast.literal_eval``, which evaluates literals only and never runs
        code. The round-trip is lossy: a string that merely looks like a
        literal (e.g. ``"123"``) comes back as the literal's type. Anything
        else is returned as the decoded str.
        You can extend this method to handle other data types as needed.
        """
        value_str = value.decode("utf-8")
        try:
            # try to convert byte string array to list
            return ast.literal_eval(value_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval is documented to raise all of these on malformed
            # input, e.g. TypeError for "{[]: 1}"; treat them all as
            # "not a literal".
            pass

        # isdecimal(), not isdigit(): "²".isdigit() is True but int("²")
        # raises. Reached for digit strings literal_eval rejects, e.g. "0123".
        if value_str.isdecimal():
            return int(value_str)

        # Add more conversions here if needed
        return value_str


# Global state
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)

state = (
    RedisState(
        host=_redis_host, port=_redis_port, db=_redis_db, password=_redis_password
    )
    if _enable_redis
    else MemoryState()
)