from typing import Any, Dict, Optional, List import copy import uuid from datetime import datetime, timezone from abc import ABC, abstractmethod import asyncio from pydantic import BaseModel, Field class CheckpointMetadata(BaseModel): """ Metadata for a checkpoint, including session and task identifiers. Attributes: session_id (str): The session identifier (required). task_id (Optional[str]): The task identifier (optional). """ session_id: str = Field(..., description="The session identifier.") task_id: Optional[str] = Field(None, description="The task identifier.") class Checkpoint(BaseModel): """ Core structure for a state checkpoint. Attributes: id (str): Unique identifier for the checkpoint. ts (str): Timestamp of the checkpoint. metadata (CheckpointMetadata): Metadata associated with the checkpoint. values (dict[str, Any]): State values stored in the checkpoint. version (str): Version of the checkpoint format. parent_id (Optional[str]): Parent checkpoint identifier, if any. namespace (str): Namespace for the checkpoint, default is 'aworld'. """ id: str = Field(..., description="Unique identifier for the checkpoint.") ts: str = Field(..., description="Timestamp of the checkpoint.") metadata: CheckpointMetadata = Field(..., description="Metadata associated with the checkpoint.") values: Dict[str, Any] = Field(..., description="State values stored in the checkpoint.") version: int = Field(..., description="Version of the checkpoint format.") parent_id: Optional[str] = Field(default=None, description="Parent checkpoint identifier, if any.") namespace: str = Field(default="aworld", description="Namespace for the checkpoint, default is 'aworld'.") def empty_checkpoint() -> Checkpoint: """ Create an empty checkpoint with default values. Returns: Checkpoint: An empty checkpoint structure. """ return Checkpoint( id=str(uuid.uuid4()), ts=datetime.now(timezone.utc).isoformat(), metadata=CheckpointMetadata(session_id="", task_id=None), values={}, version=1, parent_id=None, namespace="aworld", ) def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint: """ Create a deep copy of a checkpoint. Args: checkpoint (Checkpoint): The checkpoint to copy. Returns: Checkpoint: A deep copy of the provided checkpoint. """ return copy.deepcopy(checkpoint) def create_checkpoint( values: Dict[str, Any], metadata: CheckpointMetadata, parent_id: Optional[str] = None, version: int = 1, namespace: str = 'aworld', ) -> Checkpoint: """ Create a new checkpoint from provided state values and metadata. Args: values (dict[str, Any]): State values to store in the checkpoint. metadata (CheckpointMetadata): Metadata for the checkpoint. parent_id (Optional[str]): Parent checkpoint identifier, if any. version (str): Version of the checkpoint format. namespace (str): Namespace for the checkpoint. Returns: Checkpoint: The newly created checkpoint. """ return Checkpoint( id=str(uuid.uuid4()), ts=datetime.now(timezone.utc).isoformat(), metadata=metadata, values=values, version=VersionUtils.get_next_version(version), parent_id=parent_id, namespace=namespace, ) class BaseCheckpointRepository(ABC): """ Abstract base class for a checkpoint repository. Provides synchronous and asynchronous methods for checkpoint management. """ @abstractmethod def get(self, checkpoint_id: str) -> Optional[Checkpoint]: """ Retrieve a checkpoint by its unique identifier. Args: checkpoint_id (str): The unique identifier of the checkpoint. Returns: Optional[Checkpoint]: The checkpoint if found, otherwise None. """ pass @abstractmethod def list(self, params: Dict[str, Any]) -> List[Checkpoint]: """ List checkpoints matching the given parameters. Args: params (dict): Parameters to filter checkpoints. Returns: List[Checkpoint]: List of matching checkpoints. """ pass @abstractmethod def put(self, checkpoint: Checkpoint) -> None: """ Store a checkpoint. Args: checkpoint (Checkpoint): The checkpoint to store. """ pass @abstractmethod def get_by_session(self, session_id: str) -> Optional[Checkpoint]: """ Get the latest checkpoint for a session. Args: session_id (str): The session identifier. Returns: Optional[Checkpoint]: The latest checkpoint if found, otherwise None. """ pass @abstractmethod def delete_by_session(self, session_id: str) -> None: """ Delete all checkpoints related to a session. Args: session_id (str): The session identifier. """ pass # Async methods async def aget(self, checkpoint_id: str) -> Optional[Checkpoint]: """ Asynchronously retrieve a checkpoint by its unique identifier. Args: checkpoint_id (str): The unique identifier of the checkpoint. Returns: Optional[Checkpoint]: The checkpoint if found, otherwise None. """ return await asyncio.to_thread(self.get, checkpoint_id) async def alist(self, params: Dict[str, Any]) -> List[Checkpoint]: """ Asynchronously list checkpoints matching the given parameters. Args: params (dict): Parameters to filter checkpoints. Returns: List[Checkpoint]: List of matching checkpoints. """ return await asyncio.to_thread(self.list, params) async def aput(self, checkpoint: Checkpoint) -> None: """ Asynchronously store a checkpoint. Args: checkpoint (Checkpoint): The checkpoint to store. """ await asyncio.to_thread(self.put, checkpoint) async def aget_by_session(self, session_id: str) -> Optional[Checkpoint]: """ Asynchronously get the latest checkpoint for a session. Args: session_id (str): The session identifier. Returns: Optional[Checkpoint]: The latest checkpoint if found, otherwise None. """ return await asyncio.to_thread(self.get_by_session, session_id) async def adelete_by_session(self, session_id: str) -> None: """ Asynchronously delete all checkpoints related to a session. Args: session_id (str): The session identifier. """ await asyncio.to_thread(self.delete_by_session, session_id) class VersionUtils: @staticmethod def get_next_version(version: int) -> int: """ Get the next version of the checkpoint. """ return version + 1 @staticmethod def get_previous_version(version: int) -> int: """ Get the previous version of the checkpoint. """ return version - 1 @staticmethod def is_version_greater(checkpoint: Checkpoint, version: int) -> bool: """ Check if the checkpoint version is greater than the given version. """ return checkpoint.version > version @staticmethod def is_version_less(checkpoint: Checkpoint, version: int) -> bool: """ Check if the checkpoint version is less than the given version. """ return checkpoint.version < version