Spaces:
Sleeping
Sleeping
from typing import Any, Dict, List, Optional | |
from . import Checkpoint, BaseCheckpointRepository, VersionUtils | |
class InMemoryCheckpointRepository(BaseCheckpointRepository):
    """
    In-memory implementation of BaseCheckpointRepository.

    Checkpoints are held in a plain dictionary keyed by checkpoint id, with a
    secondary index mapping each session id to the ids of its checkpoints in
    the order they were stored. Nothing is persisted, and thread safety is
    not guaranteed.
    """

    def __init__(self) -> None:
        """
        Initialize the in-memory checkpoint repository.
        """
        # checkpoint id -> checkpoint
        self._checkpoints: Dict[str, Checkpoint] = {}
        # session id -> checkpoint ids, oldest first / most recently stored last
        self._session_index: Dict[str, List[str]] = {}

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint by its unique identifier.

        Args:
            checkpoint_id (str): The unique identifier of the checkpoint.

        Returns:
            Optional[Checkpoint]: The checkpoint if found, otherwise None.
        """
        return self._checkpoints.get(checkpoint_id)

    @staticmethod
    def _matches(checkpoint: Checkpoint, params: Dict[str, Any]) -> bool:
        """Return True when the checkpoint satisfies every filter in params."""
        for key, expected in params.items():
            if key in ('session_id', 'task_id'):
                # These two filters are read from the checkpoint's metadata.
                actual = getattr(checkpoint.metadata, key)
            else:
                # NOTE(review): assumes Checkpoint exposes a dict-style
                # get(); confirm against the Checkpoint definition.
                actual = checkpoint.get(key)
            if actual != expected:
                return False
        return True

    def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
        """
        List checkpoints matching the given parameters.

        A checkpoint matches when every key/value pair in ``params`` equals
        the corresponding value on the checkpoint. ``session_id`` and
        ``task_id`` are compared against the checkpoint's metadata; any other
        key is looked up on the checkpoint itself. An empty ``params``
        matches every stored checkpoint.

        Args:
            params (dict): Parameters to filter checkpoints.

        Returns:
            List[Checkpoint]: List of matching checkpoints.
        """
        return [
            checkpoint
            for checkpoint in self._checkpoints.values()
            if self._matches(checkpoint, params)
        ]

    def put(self, checkpoint: Checkpoint) -> None:
        """
        Store a checkpoint.

        The checkpoint is checked against the latest checkpoint of the same
        session before being stored, and the session index is updated so the
        stored checkpoint becomes the latest one for its session. Storing a
        checkpoint whose id is already indexed for the session replaces the
        old index entry instead of adding a duplicate.

        Args:
            checkpoint (Checkpoint): The checkpoint to store.

        Raises:
            ValueError: If the version check against the session's latest
                checkpoint fails.
        """
        session_id = checkpoint.metadata.session_id

        # Optimistic locking: compare against the session's latest checkpoint.
        last_checkpoint = self.get_by_session(session_id)
        if last_checkpoint:
            # NOTE(review): judging by the helper's name, an equal version is
            # accepted even though the message says "must be greater";
            # confirm the intended rule against VersionUtils.
            if VersionUtils.is_version_less(checkpoint, last_checkpoint.version):
                raise ValueError(f"New checkpoint version {checkpoint.version} must be greater than last version {last_checkpoint.version}")

        # Store the new checkpoint (overwrites any checkpoint with the same id).
        self._checkpoints[checkpoint.id] = checkpoint

        # Checkpoints without a session id are stored but not indexed.
        if session_id:
            session_ids = self._session_index.setdefault(session_id, [])
            # Drop a previous entry for this id so re-storing a checkpoint
            # does not leave duplicates; it is re-added last, as the latest.
            if checkpoint.id in session_ids:
                session_ids.remove(checkpoint.id)
            session_ids.append(checkpoint.id)

    def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
        """
        Get the latest checkpoint for a session.

        "Latest" means the checkpoint most recently stored for the session,
        not necessarily the one with the highest version.

        Args:
            session_id (str): The session identifier.

        Returns:
            Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
        """
        ids = self._session_index.get(session_id, [])
        if not ids:
            return None
        # The index is kept in storage order, so the last id is the latest.
        return self._checkpoints.get(ids[-1])

    def delete_by_session(self, session_id: str) -> None:
        """
        Delete all checkpoints related to a session.

        Unknown session ids are ignored.

        Args:
            session_id (str): The session identifier.
        """
        for checkpoint_id in self._session_index.pop(session_id, []):
            # Tolerate ids whose checkpoint is no longer in the store.
            self._checkpoints.pop(checkpoint_id, None)