Spaces:
Sleeping
Sleeping
from typing import Any, Dict, Optional, List | |
import copy | |
import uuid | |
from datetime import datetime, timezone | |
from abc import ABC, abstractmethod | |
import asyncio | |
from pydantic import BaseModel, Field | |
class CheckpointMetadata(BaseModel):
    """
    Identifiers describing which session (and optionally which task) a
    checkpoint belongs to.

    Attributes:
        session_id (str): Identifier of the owning session; must be supplied.
        task_id (Optional[str]): Identifier of the owning task; ``None`` when
            the checkpoint is not tied to a specific task.
    """

    # Required: ``...`` marks the field as having no default.
    session_id: str = Field(default=..., description="The session identifier.")
    # Optional: omitted values become None.
    task_id: Optional[str] = Field(default=None, description="The task identifier.")
class Checkpoint(BaseModel):
    """
    Core structure for a state checkpoint.

    Attributes:
        id (str): Unique identifier for the checkpoint.
        ts (str): Timestamp of the checkpoint, stored as a string (the
            factory functions in this module use an ISO-8601 UTC string).
        metadata (CheckpointMetadata): Metadata associated with the checkpoint.
        values (Dict[str, Any]): State values stored in the checkpoint.
        version (int): Integer version number of the checkpoint.
        parent_id (Optional[str]): Parent checkpoint identifier, if any.
        namespace (str): Namespace for the checkpoint, default is 'aworld'.
    """

    id: str = Field(..., description="Unique identifier for the checkpoint.")
    ts: str = Field(..., description="Timestamp of the checkpoint.")
    metadata: CheckpointMetadata = Field(..., description="Metadata associated with the checkpoint.")
    values: Dict[str, Any] = Field(..., description="State values stored in the checkpoint.")
    # Declared as int (the previous docstring wrongly said str); VersionUtils
    # performs integer arithmetic and comparisons on this field.
    version: int = Field(..., description="Version of the checkpoint format.")
    parent_id: Optional[str] = Field(default=None, description="Parent checkpoint identifier, if any.")
    namespace: str = Field(default="aworld", description="Namespace for the checkpoint, default is 'aworld'.")
def empty_checkpoint() -> Checkpoint:
    """
    Build a blank checkpoint that holds no state.

    The result has a fresh UUID4 ``id``, the current UTC time as an ISO-8601
    string, metadata with an empty session id and no task id, an empty
    ``values`` dict, version 1, no parent, and the "aworld" namespace.

    Returns:
        Checkpoint: A newly constructed, empty checkpoint.
    """
    blank_metadata = CheckpointMetadata(session_id="", task_id=None)
    created_at = datetime.now(timezone.utc).isoformat()
    fresh_id = str(uuid.uuid4())
    return Checkpoint(
        id=fresh_id,
        ts=created_at,
        metadata=blank_metadata,
        values={},
        version=1,
        parent_id=None,
        namespace="aworld",
    )
def copy_checkpoint(checkpoint: Checkpoint) -> Checkpoint:
    """
    Return an independent deep copy of ``checkpoint``.

    ``copy.deepcopy`` is used, so nested containers (for example the
    ``values`` dict) are duplicated as well. Mutating the copy therefore
    leaves the original untouched.

    Args:
        checkpoint (Checkpoint): The checkpoint to duplicate.
    Returns:
        Checkpoint: The deep copy.
    """
    duplicate = copy.deepcopy(checkpoint)
    return duplicate
def create_checkpoint(
    values: Dict[str, Any],
    metadata: CheckpointMetadata,
    parent_id: Optional[str] = None,
    version: int = 1,
    namespace: str = 'aworld',
) -> Checkpoint:
    """
    Create a new checkpoint from provided state values and metadata.

    The new checkpoint gets a fresh UUID4 id and the current UTC time as an
    ISO-8601 timestamp string.

    Note: ``version`` is the version to advance *from*. The stored version is
    ``version + 1`` (via ``VersionUtils.get_next_version``), so calling with
    the default ``version=1`` yields a checkpoint whose version is 2.

    Args:
        values (Dict[str, Any]): State values to store in the checkpoint.
        metadata (CheckpointMetadata): Metadata for the checkpoint.
        parent_id (Optional[str]): Parent checkpoint identifier, if any.
        version (int): Current version number; the checkpoint stores its
            successor.
        namespace (str): Namespace for the checkpoint.
    Returns:
        Checkpoint: The newly created checkpoint.
    """
    return Checkpoint(
        id=str(uuid.uuid4()),
        ts=datetime.now(timezone.utc).isoformat(),
        metadata=metadata,
        values=values,
        # Advance the version rather than storing the argument as-is.
        version=VersionUtils.get_next_version(version),
        parent_id=parent_id,
        namespace=namespace,
    )
class BaseCheckpointRepository(ABC):
    """
    Base interface for storing and retrieving checkpoints.

    Subclasses provide the synchronous operations. Each ``a``-prefixed
    coroutine defined here runs the matching synchronous method in a worker
    thread via ``asyncio.to_thread``.

    NOTE(review): no method is decorated with ``@abstractmethod``, so this
    class can be instantiated directly. Its synchronous defaults do nothing
    and return ``None`` (including ``list``, despite its ``List`` annotation).
    """

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Look up a single checkpoint by id.

        Args:
            checkpoint_id (str): Identifier of the wanted checkpoint.
        Returns:
            Optional[Checkpoint]: The checkpoint, or ``None`` when absent.
        """
        return None

    def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
        """
        Find every checkpoint that matches a set of filter parameters.

        Args:
            params (dict): Filter criteria; their interpretation is left to
                the subclass.
        Returns:
            List[Checkpoint]: The matching checkpoints.
        """
        return None

    def put(self, checkpoint: Checkpoint) -> None:
        """
        Persist a checkpoint.

        Args:
            checkpoint (Checkpoint): The checkpoint to save.
        """
        return None

    def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
        """
        Fetch the most recent checkpoint belonging to a session.

        Args:
            session_id (str): Identifier of the session.
        Returns:
            Optional[Checkpoint]: The latest checkpoint, or ``None`` when the
            session has none.
        """
        return None

    def delete_by_session(self, session_id: str) -> None:
        """
        Remove every checkpoint belonging to a session.

        Args:
            session_id (str): Identifier of the session.
        """
        return None

    # ---- Async wrappers: each offloads its sync counterpart to a thread ----

    async def aget(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Coroutine form of :meth:`get`, executed in a worker thread.

        Args:
            checkpoint_id (str): Identifier of the wanted checkpoint.
        Returns:
            Optional[Checkpoint]: Whatever :meth:`get` returns.
        """
        found = await asyncio.to_thread(self.get, checkpoint_id)
        return found

    async def alist(self, params: Dict[str, Any]) -> List[Checkpoint]:
        """
        Coroutine form of :meth:`list`, executed in a worker thread.

        Args:
            params (dict): Filter criteria forwarded to :meth:`list`.
        Returns:
            List[Checkpoint]: Whatever :meth:`list` returns.
        """
        matches = await asyncio.to_thread(self.list, params)
        return matches

    async def aput(self, checkpoint: Checkpoint) -> None:
        """
        Coroutine form of :meth:`put`, executed in a worker thread.

        Args:
            checkpoint (Checkpoint): The checkpoint to save.
        """
        await asyncio.to_thread(self.put, checkpoint)

    async def aget_by_session(self, session_id: str) -> Optional[Checkpoint]:
        """
        Coroutine form of :meth:`get_by_session`, executed in a worker thread.

        Args:
            session_id (str): Identifier of the session.
        Returns:
            Optional[Checkpoint]: Whatever :meth:`get_by_session` returns.
        """
        latest = await asyncio.to_thread(self.get_by_session, session_id)
        return latest

    async def adelete_by_session(self, session_id: str) -> None:
        """
        Coroutine form of :meth:`delete_by_session`, executed in a worker
        thread.

        Args:
            session_id (str): Identifier of the session.
        """
        await asyncio.to_thread(self.delete_by_session, session_id)
class VersionUtils:
    """
    Stateless helpers for integer checkpoint-version arithmetic and comparison.

    Every helper is a ``@staticmethod``, so it can be called either on the
    class (``VersionUtils.get_next_version(1)``) or on an instance. Without
    the decorator, instance calls raise ``TypeError`` because the instance is
    bound to the first parameter.
    """

    @staticmethod
    def get_next_version(version: int) -> int:
        """
        Get the next version of the checkpoint.

        Args:
            version (int): The current version number.
        Returns:
            int: ``version + 1``.
        """
        return version + 1

    @staticmethod
    def get_previous_version(version: int) -> int:
        """
        Get the previous version of the checkpoint.

        Args:
            version (int): The current version number.
        Returns:
            int: ``version - 1`` (no lower bound is enforced).
        """
        return version - 1

    @staticmethod
    def is_version_greater(checkpoint: "Checkpoint", version: int) -> bool:
        """
        Check if the checkpoint version is greater than the given version.

        Args:
            checkpoint (Checkpoint): Checkpoint whose ``version`` is compared.
            version (int): Version to compare against.
        Returns:
            bool: ``True`` if ``checkpoint.version > version``.
        """
        return checkpoint.version > version

    @staticmethod
    def is_version_less(checkpoint: "Checkpoint", version: int) -> bool:
        """
        Check if the checkpoint version is less than the given version.

        Args:
            checkpoint (Checkpoint): Checkpoint whose ``version`` is compared.
            version (int): Version to compare against.
        Returns:
            bool: ``True`` if ``checkpoint.version < version``.
        """
        return checkpoint.version < version