Final_Assignment_GAIAAgent / src /gaia /memory /test_consolidated_implementation.py
JoachimVC's picture
Upload GAIA agent implementation files for assessment
c922f8b
"""
Test script for the consolidated Supabase memory implementation.

This script exercises every memory type (base, working, conversation and
result cache) to ensure functionality is preserved after consolidation.

Usage:
    python -m memory.test_consolidated_implementation
"""
import os
import sys
import logging
from typing import Dict, Any, Optional
# Configure logging once for the whole script: INFO level, with a timestamp,
# logger name and level prefixed to every record.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Dedicated logger used by every check in this script.
logger = logging.getLogger("test_consolidated_memory")

# Import the components under test. A failed import is logged and the script
# exits with status 1, since none of the checks below can run without them.
try:
    from memory import (
        SupabaseMemory,
        ConversationMemory,
        ResultCache,
        WorkingMemory,
        verify_tables_exist,
        test_memory_implementation
    )
    from agent.config import get_memory_config
except ImportError as e:
    logger.error(f"Error importing required components: {str(e)}")
    sys.exit(1)
def run_tests() -> bool:
    """Run tests for the consolidated memory implementation.

    Builds a ``SupabaseMemory`` from the configuration returned by
    ``get_memory_config()`` (with ``"enabled"`` forced to ``True``), runs the
    per-memory-type checks and then the comprehensive
    ``test_memory_implementation`` check.

    Returns:
        True only if every individual check and the comprehensive check
        reported success. False if any check failed or if an exception was
        raised while setting up or running the checks.
    """
    logger.info("Testing consolidated Supabase memory implementation...")
    try:
        # Create memory instance
        memory_config = get_memory_config()
        memory_config["enabled"] = True
        base_memory = SupabaseMemory(config=memory_config)

        if not base_memory.initialized:
            logger.warning("Supabase is not initialized. Running in local-only mode.")
        elif not verify_tables_exist(base_memory):
            # Table verification only makes sense when Supabase is reachable.
            logger.warning("Not all required tables exist. Some tests may fail.")

        # Per-type checks first, comprehensive check last.
        checks = (
            test_base_memory,
            test_working_memory,
            test_conversation_memory,
            test_result_cache,
            test_memory_implementation,
        )

        # Run every check (no short-circuit) so one run reports all failures,
        # and remember which ones failed so the overall result reflects them.
        failed = [
            getattr(check, "__name__", repr(check))
            for check in checks
            if not check(base_memory)
        ]

        if failed:
            logger.error("❌ Some tests failed.")
            logger.error("Failed checks: %s", ", ".join(failed))
            return False

        logger.info("✅ All tests passed!")
        return True
    except Exception as e:
        # Top-level boundary of the script: log with traceback and report failure.
        logger.exception(f"Error running tests: {str(e)}")
        return False
def test_base_memory(memory: SupabaseMemory):
    """Check store, retrieve, list_keys and delete on the base memory.

    Args:
        memory: The memory instance to exercise.

    Returns:
        True if every step behaved as expected, otherwise False (the first
        failing step is logged as an error and the remaining steps are
        skipped).
    """
    key = "test_base_key"
    expected = "test_base_value"

    def fail(message):
        # Log the failure and hand back the failing result in one step.
        logger.error(message)
        return False

    logger.info("Testing base memory...")

    # Round-trip: a stored value must come back unchanged.
    memory.store(key, expected)
    value = memory.retrieve(key)
    if value != expected:
        return fail(f"Base memory store/retrieve test failed: expected 'test_base_value', got '{value}'")

    # The stored key must be reported by list_keys.
    keys = memory.list_keys()
    if key not in keys:
        return fail(f"Base memory list_keys test failed: 'test_base_key' not found in {keys}")

    # After deletion the key must no longer resolve to a value.
    memory.delete(key)
    value = memory.retrieve(key)
    if value is not None:
        return fail(f"Base memory delete test failed: expected None, got '{value}'")

    logger.info("✅ Base memory tests passed")
    return True
def test_working_memory(memory: SupabaseMemory):
    """Check storing, listing and clearing intermediate results.

    Wraps ``memory`` in a ``WorkingMemory`` for the session ``"test_session"``
    and exercises it with a nested dictionary payload.

    Args:
        memory: The memory instance backing the working memory.

    Returns:
        True if every step behaved as expected, otherwise False (the first
        failing step is logged as an error).
    """
    result_key = "test_working_key"
    test_data = {"key": "value", "nested": {"data": 123}}

    logger.info("Testing working memory...")
    working = WorkingMemory(memory, "test_session")

    # Round-trip: the nested payload must come back equal to what was stored.
    working.store_intermediate_result(result_key, test_data)
    value = working.get_intermediate_result(result_key)
    if value != test_data:
        logger.error(f"Working memory store/retrieve test failed: expected {test_data}, got {value}")
        return False

    # The stored key must show up in the listing.
    results = working.list_intermediate_results()
    if result_key not in results:
        logger.error(f"Working memory list test failed: 'test_working_key' not found in {results}")
        return False

    # Clearing must remove the previously stored result.
    working.clear()
    value = working.get_intermediate_result(result_key)
    if value is not None:
        logger.error(f"Working memory clear test failed: expected None, got {value}")
        return False

    logger.info("✅ Working memory tests passed")
    return True
def test_conversation_memory(memory: SupabaseMemory):
    """Check adding, reading back and clearing conversation messages.

    Wraps ``memory`` in a ``ConversationMemory`` for the conversation
    ``"test_conversation"``, adds one user and one assistant message, and
    verifies count, order, roles and contents before clearing.

    Args:
        memory: The memory instance backing the conversation memory.

    Returns:
        True if every step behaved as expected, otherwise False (the first
        failing step is logged as an error).
    """
    user_text = "Hello, how are you?"
    assistant_text = "I'm fine, thank you!"

    logger.info("Testing conversation memory...")
    conversation = ConversationMemory(memory, "test_conversation")

    # Add the two messages in conversation order.
    for role, content in (("user", user_text), ("assistant", assistant_text)):
        conversation.add_message(role, content)

    messages = conversation.get_messages()
    if len(messages) != 2:
        logger.error(f"Conversation memory add/get test failed: expected 2 messages, got {len(messages)}")
        return False

    # Messages must come back in insertion order with role and content intact.
    first = messages[0]
    if first["role"] != "user" or first["content"] != user_text:
        logger.error(f"Conversation memory message test failed: unexpected first message {messages[0]}")
        return False

    second = messages[1]
    if second["role"] != "assistant" or second["content"] != assistant_text:
        logger.error(f"Conversation memory message test failed: unexpected second message {messages[1]}")
        return False

    # Clearing must leave the conversation empty.
    conversation.clear()
    messages = conversation.get_messages()
    if len(messages) != 0:
        logger.error(f"Conversation memory clear test failed: expected 0 messages, got {len(messages)}")
        return False

    logger.info("✅ Conversation memory tests passed")
    return True
def test_result_cache(memory: SupabaseMemory):
    """Check caching, invalidating and clearing results.

    Wraps ``memory`` in a ``ResultCache`` and verifies that a cached result
    can be read back, that ``invalidate`` removes a single entry, and that
    ``clear`` removes a remaining entry.

    Args:
        memory: The memory instance backing the result cache.

    Returns:
        True if every step behaved as expected, otherwise False (the first
        failing step is logged as an error).
    """
    first_query = "test_query"
    second_query = "test_query2"
    test_result = {"search_results": ["result1", "result2"], "timestamp": "2025-05-08T10:00:00"}

    def fail(message):
        # Log the failure and hand back the failing result in one step.
        logger.error(message)
        return False

    logger.info("Testing result cache...")
    cache = ResultCache(memory)

    # Round-trip: a cached result must come back equal to what was stored.
    cache.cache_result(first_query, test_result)
    value = cache.get_result(first_query)
    if value != test_result:
        return fail(f"Result cache/get test failed: expected {test_result}, got {value}")

    # Invalidating a query must drop its cached result.
    cache.invalidate(first_query)
    value = cache.get_result(first_query)
    if value is not None:
        return fail(f"Result cache invalidate test failed: expected None, got {value}")

    # Clearing the cache must drop a freshly cached entry as well.
    cache.cache_result(second_query, "test_result2")
    cache.clear()
    value = cache.get_result(second_query)
    if value is not None:
        return fail(f"Result cache clear test failed: expected None, got {value}")

    logger.info("✅ Result cache tests passed")
    return True
if __name__ == "__main__":
    # Map the overall test outcome to the process exit status:
    # 0 when run_tests() reports success, 1 otherwise.
    sys.exit(0 if run_tests() else 1)