Duibonduil's picture
Upload 2 files
8af6ba8 verified
from typing import Any, Dict, List, Optional
from . import Checkpoint, BaseCheckpointRepository, VersionUtils
class InMemoryCheckpointRepository(BaseCheckpointRepository):
    """
    In-memory implementation of BaseCheckpointRepository.

    Checkpoints are kept in a plain dictionary keyed by checkpoint id, plus a
    per-session index recording the order in which checkpoint ids were stored.
    Nothing is persisted beyond the lifetime of the instance.
    Thread safety is not guaranteed.
    """

    def __init__(self) -> None:
        """
        Initialize the in-memory checkpoint repository.
        """
        # checkpoint id -> checkpoint
        self._checkpoints: Dict[str, Checkpoint] = {}
        # session id -> checkpoint ids in storage order (last entry is the latest)
        self._session_index: Dict[str, List[str]] = {}

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint by its unique identifier.

        Args:
            checkpoint_id (str): The unique identifier of the checkpoint.

        Returns:
            Optional[Checkpoint]: The checkpoint if found, otherwise None.
        """
        return self._checkpoints.get(checkpoint_id)

    def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
        """
        List checkpoints matching the given parameters.

        Every key/value pair in ``params`` must match for a checkpoint to be
        included; an empty ``params`` therefore returns every stored checkpoint.

        Args:
            params (dict): Parameters to filter checkpoints. ``session_id`` and
                ``task_id`` are compared against the checkpoint metadata; any
                other key is compared against ``checkpoint.get(key)``.

        Returns:
            List[Checkpoint]: List of matching checkpoints.
        """
        return [
            checkpoint
            for checkpoint in self._checkpoints.values()
            if self._matches_params(checkpoint, params)
        ]

    @staticmethod
    def _matches_params(checkpoint: Checkpoint, params: Dict[str, Any]) -> bool:
        """Return True when the checkpoint satisfies every filter in params."""
        for key, expected in params.items():
            if key == 'session_id':
                actual = checkpoint.metadata.session_id
            elif key == 'task_id':
                actual = checkpoint.metadata.task_id
            else:
                # NOTE(review): relies on Checkpoint exposing a dict-style
                # get(); confirm against the Checkpoint definition.
                actual = checkpoint.get(key)
            if actual != expected:
                return False
        return True

    def put(self, checkpoint: Checkpoint) -> None:
        """
        Store a checkpoint.

        The checkpoint is first compared with the latest checkpoint stored for
        the same session and rejected if its version is reported as lower.

        Args:
            checkpoint (Checkpoint): The checkpoint to store.

        Raises:
            ValueError: If VersionUtils.is_version_less reports the new
                checkpoint as older than the session's latest checkpoint.
        """
        session_id = checkpoint.metadata.session_id

        # Find last version checkpoint by session_id
        last_checkpoint = self.get_by_session(session_id)
        # Compare versions to ensure optimistic locking
        if last_checkpoint and VersionUtils.is_version_less(checkpoint, last_checkpoint.version):
            raise ValueError(f"New checkpoint version {checkpoint.version} must be greater than last version {last_checkpoint.version}")

        # Store the new checkpoint
        self._checkpoints[checkpoint.id] = checkpoint

        # Update session index
        if session_id:
            session_ids = self._session_index.setdefault(session_id, [])
            # Re-storing an existing id moves it to the end instead of adding a
            # duplicate entry, so the index cannot grow on repeated updates
            # while the most recently stored checkpoint still comes last.
            if checkpoint.id in session_ids:
                session_ids.remove(checkpoint.id)
            session_ids.append(checkpoint.id)

    def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
        """
        Get the latest checkpoint for a session.

        Args:
            session_id (str): The session identifier.

        Returns:
            Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
        """
        session_ids = self._session_index.get(session_id)
        if not session_ids:
            return None
        # The index is kept in storage order, so the last id is the latest.
        return self._checkpoints.get(session_ids[-1])

    def delete_by_session(self, session_id: str) -> None:
        """
        Delete all checkpoints related to a session.

        Unknown session ids are ignored.

        Args:
            session_id (str): The session identifier.
        """
        for checkpoint_id in self._session_index.pop(session_id, []):
            # Tolerate ids that are no longer present in the main store.
            self._checkpoints.pop(checkpoint_id, None)