# NOTE(review): removed non-Python scrape residue (page header, file-size line,
# commit hash, and line-number gutter) that made this module unparseable.
import unittest
import sys
import os
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
# Prepend the repository root to sys.path so the `tests` and `aworld`
# packages resolve when this file is executed directly.
project_root = Path(__file__).parent.parent
sys.path[:0] = [str(project_root)]
from tests.base_test import BaseTest
from aworld.config.conf import AgentConfig, ModelConfig, ContextRuleConfig, OptimizationConfig, LlmCompressionConfig
from aworld.core.context.processor import CompressionResult, CompressionType
from aworld.core.context.processor.llm_compressor import LLMCompressor
from aworld.core.context.processor.prompt_processor import PromptProcessor
from aworld.core.context.base import AgentContext, ContextUsage
class TestPromptCompressor(BaseTest):
    """Tests for LLMCompressor.compress_batch and PromptProcessor.compress_messages.

    Requires a reachable LLM endpoint (an LM Studio server by default).
    The connection settings are supplied both via ModelConfig and via
    environment variables because different code paths read them from
    different places.
    """

    def __init__(self, *args, **kwargs):
        # BUG FIX: the original defined `__init__(self)` (docstring said
        # "Set up test fixtures") instead of `setUp`. It neither accepted the
        # `methodName` argument the unittest loader passes nor called
        # super().__init__(), so the class could not be run by a unittest
        # runner. Forward all arguments, then build the fixtures so the
        # manual invocation under the __main__ guard keeps working; unittest
        # calls setUp() again before each test, which is harmless because it
        # only (re)assigns values.
        super().__init__(*args, **kwargs)
        self.setUp()

    def setUp(self):
        """Create the shared mock model configuration and export it to the env."""
        self.mock_model_name = "qwen/qwen3-1.7b"
        self.mock_base_url = "http://localhost:1234/v1"
        self.mock_api_key = "lm-studio"
        self.mock_llm_config = ModelConfig(
            llm_model_name=self.mock_model_name,
            llm_base_url=self.mock_base_url,
            llm_api_key=self.mock_api_key
        )
        # NOTE(review): presumably some components read the connection settings
        # from the environment rather than from the ModelConfig instance.
        os.environ["LLM_API_KEY"] = self.mock_api_key
        os.environ["LLM_BASE_URL"] = self.mock_base_url
        os.environ["LLM_MODEL_NAME"] = self.mock_model_name

    def test_compress_batch_basic(self):
        """compress_batch should return LLM-based results without the repeated text."""
        compressor = LLMCompressor(
            llm_config=self.mock_llm_config
        )
        # A single item whose payload repeats verbatim, so it is clearly compressible.
        contents = [
            "[SYSTEM]You are a helpful assistant.\n[USER]This is the first long text content that needs compression. This is the first long text content that needs compression.",
        ]
        # Execute compress_batch
        results = compressor.compress_batch(contents)
        # Each result must be an LLM-based CompressionResult, and the verbatim
        # repeated sentence must not survive compression.
        for result in results:
            self.assertIsInstance(result, CompressionResult)
            self.assertEqual(result.compression_type, CompressionType.LLM_BASED)
            self.assertNotIn('This is the first long text content that needs compression. This is the first long text content that needs compression.', result.compressed_content)

    def test_compress_messages(self):
        """Test compress_messages function from PromptProcessor."""
        # Compression enabled with a deliberately low trigger threshold so the
        # short test messages still get compressed.
        context_rule = ContextRuleConfig(
            optimization_config=OptimizationConfig(
                enabled=True,
                max_token_budget_ratio=0.8
            ),
            llm_compression_config=LlmCompressionConfig(
                enabled=True,
                trigger_compress_token_length=10,  # Low threshold to trigger compression
                compress_model=self.mock_llm_config
            )
        )
        # Agent context wiring the compression rule to the mock model.
        agent_context = AgentContext(
            agent_id="test_agent",
            agent_name="test_agent",
            agent_desc="Test agent for compression",
            system_prompt="You are a helpful assistant.",
            agent_prompt="You are a helpful assistant.",
            model_config=self.mock_llm_config,
            context_rule=context_rule,
            context_usage=ContextUsage(total_context_length=4096)
        )
        processor = PromptProcessor(agent_context)
        # Messages with repeated content that should trigger compression.
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": "This is the first long text content that needs compression. This is the first long text content that needs compression."
            },
            {
                "role": "assistant",
                "content": "I understand you want me to help with compression."
            }
        ]
        # Execute compress_messages
        compressed_messages = processor.compress_messages(messages)
        # Structure is preserved: still a list of the same length.
        self.assertIsInstance(compressed_messages, list)
        self.assertEqual(len(compressed_messages), len(messages))
        # Locate the (single) user message in the processed output.
        user_message = next(
            (msg for msg in compressed_messages if msg.get("role") == "user"),
            None,
        )
        self.assertIsNotNone(user_message)
        # The repeated text should have been rewritten and shortened.
        original_content = "This is the first long text content that needs compression. This is the first long text content that needs compression."
        self.assertNotEqual(user_message["content"], original_content)
        self.assertLess(len(user_message["content"]), len(original_content))
if __name__ == '__main__':
    # Manual smoke-run: execute each test on its own fresh instance, which
    # mirrors unittest's one-instance-per-test isolation.
    TestPromptCompressor().test_compress_batch_basic()
    TestPromptCompressor().test_compress_messages()