|
""" |
|
Test script for the consolidated Supabase memory implementation. |
|
|
|
This script tests all memory types to ensure functionality is preserved |
|
after consolidation. |
|
|
|
Usage: |
|
python -m memory.test_consolidated_implementation |
|
""" |
|
|
|
import os |
|
import sys |
|
import logging |
|
from typing import Dict, Any, Optional |
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" |
|
) |
|
logger = logging.getLogger("test_consolidated_memory") |
|
|
|
# Project components under test. Without them the script cannot run, so an
# import failure is logged and the process exits with status 1.
# NOTE(review): this guard runs at import time, so a failed import also aborts
# any tool that merely imports this module (e.g. a test collector).
# NOTE(review): this module is named test_consolidated_implementation and binds
# `test_memory_implementation` plus several `test_*(memory)` functions at module
# level; pytest may try to collect them and error on the unknown `memory`
# fixture. Confirm whether this file is meant to be picked up by pytest.
try:
    from memory import (
        SupabaseMemory,
        ConversationMemory,
        ResultCache,
        WorkingMemory,
        verify_tables_exist,
        test_memory_implementation
    )
    from agent.config import get_memory_config
except ImportError as e:
    logger.error(f"Error importing required components: {str(e)}")
    sys.exit(1)
|
|
|
def run_tests():
    """Run every memory test against one ``SupabaseMemory`` instance.

    Forces ``enabled`` on in the memory config and builds a ``SupabaseMemory``
    from it. It then runs this module's base / working / conversation /
    result-cache checks, followed by the package-level
    ``test_memory_implementation`` check.

    Returns:
        True only if every individual check AND ``test_memory_implementation``
        report success; False if any of them fails or an exception escapes.
    """
    logger.info("Testing consolidated Supabase memory implementation...")

    try:
        memory_config = get_memory_config()
        memory_config["enabled"] = True

        base_memory = SupabaseMemory(config=memory_config)

        if base_memory.initialized:
            if not verify_tables_exist(base_memory):
                logger.warning("Not all required tables exist. Some tests may fail.")
        else:
            logger.warning("Supabase is not initialized. Running in local-only mode.")

        # Each helper returns False on failure. These results used to be
        # discarded, so a failing helper could still end in "All tests passed!".
        # The list comprehension runs every check (no short-circuit), so all
        # failures are logged in a single run.
        checks = (
            test_base_memory,
            test_working_memory,
            test_conversation_memory,
            test_result_cache,
        )
        results = [check(base_memory) for check in checks]

        implementation_ok = test_memory_implementation(base_memory)

        if all(results) and implementation_ok:
            logger.info("✅ All tests passed!")
            return True

        logger.error("❌ Some tests failed.")
        return False

    except Exception as e:
        # Top-level boundary for the script: keep the traceback in the log
        # rather than only the exception's string form.
        logger.exception("Error running tests: %s", e)
        return False
|
|
|
def test_base_memory(memory: SupabaseMemory):
    """Check store/retrieve, list_keys and delete on the base memory.

    Args:
        memory: The memory instance under test.

    Returns:
        True when all three checks succeed; otherwise the first failing check
        is logged and False is returned.
    """
    logger.info("Testing base memory...")

    key = "test_base_key"
    expected = "test_base_value"

    # 1) A stored string must read back unchanged.
    memory.store(key, expected)
    fetched = memory.retrieve(key)
    if fetched != expected:
        logger.error(f"Base memory store/retrieve test failed: expected 'test_base_value', got '{fetched}'")
        return False

    # 2) The stored key must appear in the key listing.
    listed = memory.list_keys()
    if key not in listed:
        logger.error(f"Base memory list_keys test failed: 'test_base_key' not found in {listed}")
        return False

    # 3) After deletion the key must read back as None.
    memory.delete(key)
    fetched = memory.retrieve(key)
    if fetched is not None:
        logger.error(f"Base memory delete test failed: expected None, got '{fetched}'")
        return False

    logger.info("✅ Base memory tests passed")
    return True
|
|
|
def test_working_memory(memory: SupabaseMemory):
    """Check intermediate-result storage, listing and clearing on ``WorkingMemory``.

    Args:
        memory: The backing memory handed to ``WorkingMemory``.

    Returns:
        True if every check passes, otherwise False after logging the first
        failure.
    """
    logger.info("Testing working memory...")

    working = WorkingMemory(memory, "test_session")
    key = "test_working_key"
    payload = {"key": "value", "nested": {"data": 123}}

    # A nested dict must survive a store/get round trip intact.
    working.store_intermediate_result(key, payload)
    fetched = working.get_intermediate_result(key)
    if fetched != payload:
        logger.error(f"Working memory store/retrieve test failed: expected {payload}, got {fetched}")
        return False

    # The stored key must be reported by the listing.
    listing = working.list_intermediate_results()
    if key not in listing:
        logger.error(f"Working memory list test failed: 'test_working_key' not found in {listing}")
        return False

    # Clearing the session must drop the stored result.
    working.clear()
    fetched = working.get_intermediate_result(key)
    if fetched is not None:
        logger.error(f"Working memory clear test failed: expected None, got {fetched}")
        return False

    logger.info("✅ Working memory tests passed")
    return True
|
|
|
def test_conversation_memory(memory: SupabaseMemory):
    """Check message ordering and clearing on ``ConversationMemory``.

    Adds one user and one assistant message, verifies exactly those two come
    back in order, then verifies ``clear()`` empties the conversation.

    Args:
        memory: The backing memory handed to ``ConversationMemory``.

    Returns:
        True if every check passes, otherwise False after logging the first
        failure.
    """
    logger.info("Testing conversation memory...")

    conversation_id = "test_conversation"
    conversation = ConversationMemory(memory, conversation_id)

    # The conversation id is fixed, and every failure path below returns
    # before the final clear(). Without this reset, messages left by an
    # earlier failed run would make the exact-count check fail on every
    # subsequent run.
    conversation.clear()

    conversation.add_message("user", "Hello, how are you?")
    conversation.add_message("assistant", "I'm fine, thank you!")

    messages = conversation.get_messages()

    if len(messages) != 2:
        logger.error(f"Conversation memory add/get test failed: expected 2 messages, got {len(messages)}")
        return False

    if messages[0]["role"] != "user" or messages[0]["content"] != "Hello, how are you?":
        logger.error(f"Conversation memory message test failed: unexpected first message {messages[0]}")
        return False

    if messages[1]["role"] != "assistant" or messages[1]["content"] != "I'm fine, thank you!":
        logger.error(f"Conversation memory message test failed: unexpected second message {messages[1]}")
        return False

    # clear() must leave the conversation empty.
    conversation.clear()
    messages = conversation.get_messages()

    if len(messages) != 0:
        logger.error(f"Conversation memory clear test failed: expected 0 messages, got {len(messages)}")
        return False

    logger.info("✅ Conversation memory tests passed")
    return True
|
|
|
def test_result_cache(memory: SupabaseMemory):
    """Check caching, single-entry invalidation and full clearing on ``ResultCache``.

    Args:
        memory: The backing memory handed to ``ResultCache``.

    Returns:
        True if every check passes, otherwise False after logging the first
        failure.
    """
    logger.info("Testing result cache...")

    cache = ResultCache(memory)
    query = "test_query"
    expected = {"search_results": ["result1", "result2"], "timestamp": "2025-05-08T10:00:00"}

    # A cached result must be returned unchanged.
    cache.cache_result(query, expected)
    cached = cache.get_result(query)
    if cached != expected:
        logger.error(f"Result cache/get test failed: expected {expected}, got {cached}")
        return False

    # Invalidating a single query must remove just that entry.
    cache.invalidate(query)
    cached = cache.get_result(query)
    if cached is not None:
        logger.error(f"Result cache invalidate test failed: expected None, got {cached}")
        return False

    # clear() must drop every cached entry.
    other_query = "test_query2"
    cache.cache_result(other_query, "test_result2")
    cache.clear()
    cached = cache.get_result(other_query)
    if cached is not None:
        logger.error(f"Result cache clear test failed: expected None, got {cached}")
        return False

    logger.info("✅ Result cache tests passed")
    return True
|
|
|
if __name__ == "__main__":
    # The process exit status reflects the overall outcome: 0 when every test
    # passed, 1 otherwise.
    sys.exit(0 if run_tests() else 1)