Spaces:
Sleeping
Sleeping
import os | |
import random | |
import sys | |
from pathlib import Path | |
from aworld.core.task import Task | |
from aworld.core.agent.base import AgentFactory | |
from aworld.core.agent.swarm import Swarm | |
from aworld.runner import Runners | |
from aworld.agents.llm_agent import Agent | |
from aworld.config.conf import AgentConfig, ContextRuleConfig, ModelConfig, OptimizationConfig, LlmCompressionConfig | |
from aworld.core.context.base import Context | |
from aworld.core.event.base import Message | |
from aworld.runners.hook.hooks import PreLLMCallHook, PostLLMCallHook | |
from aworld.runners.hook.hook_factory import HookFactory | |
from aworld.utils.common import convert_to_snake | |
class ContextManagement:
    """Test cases for the Context Management system, based on README examples.

    Every scenario builds agents against a mock LLM endpoint (configured in
    ``__init__``), runs them synchronously through ``Runners`` and prints the
    outcome.
    """

    def __init__(self):
        """Set up test fixtures and mirror them into the LLM_* environment variables."""
        self.mock_model_name = "gpt-4o"
        self.mock_base_url = "http://localhost:34567"
        self.mock_api_key = "lm-studio"
        # NOTE(review): presumably exported so that code reading LLM_* from the
        # environment (instead of an explicit config) targets the same mock endpoint.
        os.environ["LLM_API_KEY"] = self.mock_api_key
        os.environ["LLM_BASE_URL"] = self.mock_base_url
        os.environ["LLM_MODEL_NAME"] = self.mock_model_name

    def _model_config(self) -> ModelConfig:
        """Return a new ModelConfig pointing at the mock LLM endpoint."""
        return ModelConfig(
            llm_model_name=self.mock_model_name,
            llm_base_url=self.mock_base_url,
            llm_api_key=self.mock_api_key,
        )

    def _agent_config(self) -> AgentConfig:
        """Return a new AgentConfig with the mock LLM settings passed as flat fields."""
        return AgentConfig(
            llm_model_name=self.mock_model_name,
            llm_base_url=self.mock_base_url,
            llm_api_key=self.mock_api_key,
        )

    def init_agent(self, config_type: str = "1", context_rule: ContextRuleConfig = None):
        """Create an ``Agent`` wired to the mock LLM endpoint.

        Args:
            config_type: ``"1"`` passes the LLM settings as flat ``AgentConfig``
                fields; any other value nests them in a ``ModelConfig`` passed as
                ``llm_config``. Both configuration styles are exercised.
            context_rule: Optional context rule. ``None`` leaves the agent on the
                system's default context configuration.

        Returns:
            A new ``Agent`` whose name carries a random numeric suffix.
        """
        if config_type == "1":
            conf = self._agent_config()
        else:
            conf = AgentConfig(llm_config=self._model_config())
        return Agent(
            conf=conf,
            # Random suffix, presumably to keep names distinct across repeated calls.
            name="my_agent" + str(random.randint(0, 1000000)),
            system_prompt="You are a helpful assistant.",
            agent_prompt="You are a helpful assistant.",
            context_rule=context_rule
        )

    class _AssertRaisesContext:
        """Context manager backing ``assertRaises``.

        Fails with ``AssertionError`` when the body raises nothing or raises an
        exception that is not a subclass of ``expected_exception``. A matching
        exception is stored on ``self.exception`` and suppressed.
        """

        def __init__(self, expected_exception):
            self.expected_exception = expected_exception
            self.exception = None  # set to the caught exception on success

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            if exc_type is None:
                raise AssertionError(
                    f"Expected {self.expected_exception.__name__} to be raised, but no exception was raised")
            if not issubclass(exc_type, self.expected_exception):
                raise AssertionError(
                    f"Expected {self.expected_exception.__name__} to be raised, but got {exc_type.__name__}: {exc_value}")
            self.exception = exc_value
            return True  # Suppress the expected exception

    def assertRaises(self, expected_exception):
        """Return a context manager asserting its body raises *expected_exception*.

        Mirrors the context-manager form of ``unittest.TestCase.assertRaises``.
        """
        return self._AssertRaisesContext(expected_exception)

    def fail(self, msg=None):
        """Fail immediately with the given message."""
        raise AssertionError(msg or "Test failed")

    def run_agent(self, input, agent: Agent):
        """Run a single agent in a one-step swarm and return the runner's response."""
        swarm = Swarm(agent, max_steps=1)
        print('swarm ', swarm)
        return Runners.sync_run(
            input=input,
            swarm=swarm
        )

    def run_multi_agent(self, input, agent1: Agent, agent2: Agent):
        """Run two agents in a one-step swarm and return the runner's response."""
        swarm = Swarm(agent1, agent2, max_steps=1)
        return Runners.sync_run(
            input=input,
            swarm=swarm
        )

    def run_task(self, context: Context, agent: Agent, input: str = "What is an agent."):
        """Run *agent* as a ``Task`` bound to an explicit *context*.

        Args:
            context: Context passed to the task (e.g. pre-populated state).
            agent: Agent placed in a one-step swarm.
            input: Task input; defaults to the original hard-coded prompt.

        Returns:
            The result of ``Runners.sync_run_task`` (also printed).
        """
        swarm = Swarm(agent, max_steps=1)
        task = Task(input=input, swarm=swarm, context=context)
        result = Runners.sync_run_task(task)
        print("----------------------------------------------------------------------------------------------")
        print(result)
        return result

    def default_context_configuration(self):
        """Run an agent without an explicit context_rule (default configuration)."""
        # No need to explicitly configure context_rule, system automatically uses default configuration
        # Default configuration is equivalent to:
        # context_rule=ContextRuleConfig(
        #     optimization_config=OptimizationConfig(
        #         enabled=True,
        #         max_token_budget_ratio=1.0  # Use 100% of context window
        #     ),
        #     llm_compression_config=LlmCompressionConfig(
        #         enabled=False  # Compression disabled by default
        #     )
        # )
        mock_agent = self.init_agent("1")
        response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
        print(response.answer)

    def custom_context_configuration(self):
        """Test custom context configuration (README Configuration example)."""
        mock_agent = self.init_agent(context_rule=ContextRuleConfig(
            optimization_config=OptimizationConfig(
                enabled=True,
                # NOTE(review): deliberately tiny budget ratio, presumably so the
                # context overflows the budget and optimization kicks in.
                max_token_budget_ratio=0.00015
            ),
            llm_compression_config=LlmCompressionConfig(
                enabled=True,
                # Low threshold, presumably so compression triggers on short inputs.
                trigger_compress_token_length=100,
                compress_model=self._model_config(),
            )
        ))
        response = self.run_agent(input="""describe What is an agent in details""", agent=mock_agent)
        print(response.answer)

    def state_management_and_recovery(self):
        """Check that state written to the context by one agent is seen by another."""

        class StateModifyAgent(Agent):
            async def async_policy(self, observation, info=None, **kwargs):
                result = await super().async_policy(observation, info, **kwargs)
                # Flag read back by StateTrackingAgent below.
                self.context.state['policy_executed'] = True
                return result

        class StateTrackingAgent(Agent):
            async def async_policy(self, observation, info=None, **kwargs):
                result = await super().async_policy(observation, info, **kwargs)
                # Expects the flag written by StateModifyAgent to be visible
                # through this agent's context.
                assert self.context.state['policy_executed'] is True, \
                    "'policy_executed' set by state_modify_agent is not visible in this agent's context"
                return result

        # Create custom agent instance
        custom_agent = StateModifyAgent(
            conf=self._agent_config(),
            name="state_modify_agent",
            system_prompt="You are a Python expert who provides detailed and practical answers.",
            agent_prompt="You are a Python expert who provides detailed and practical answers.",
        )
        # Create a second agent for multi-agent testing
        second_agent = StateTrackingAgent(
            conf=self._agent_config(),
            name="state_tracking_agent",
            system_prompt="You are a helpful assistant.",
            agent_prompt="You are a helpful assistant.",
        )
        response = self.run_multi_agent(
            input="What is an agent. describe within 20 words",
            agent1=custom_agent,
            agent2=second_agent
        )
        print(response.answer)
class TestHookSystem(ContextManagement):
    """Manual checks for the hook system, reusing the ContextManagement fixtures."""

    def hook_registration(self):
        """Check that the named test hooks can be created through ``HookFactory``.

        Returns:
            ``(pre_hook, post_hook)`` so callers can inspect the created hooks.
        """
        # NOTE(review): "TestPreLLMHook" / "TestPostLLMHook" are not defined in the
        # visible code; they must be registered with HookFactory before this runs.
        pre_hook = HookFactory("TestPreLLMHook")
        post_hook = HookFactory("TestPostLLMHook")
        assert pre_hook is not None, "HookFactory returned None for 'TestPreLLMHook'"
        assert post_hook is not None, "HookFactory returned None for 'TestPostLLMHook'"
        return pre_hook, post_hook

    def hook_execution(self):
        """Run a single agent end to end; registered hooks presumably fire during the run."""
        mock_agent = self.init_agent("1")
        response = self.run_agent(input="""What is an agent. describe within 20 words""", agent=mock_agent)
        print(response.answer)

    def task_context_transfer(self):
        """Run a Task with a Context whose state was populated beforehand."""
        mock_agent = self.init_agent("1")
        context = Context.instance()
        context.state.update({"task": "What is an agent."})
        self.run_task(context=context, agent=mock_agent)
if __name__ == '__main__':
    # Run the context-management scenarios in order, sharing one fixture instance.
    demo = ContextManagement()
    for scenario in (
        demo.default_context_configuration,
        demo.custom_context_configuration,
        demo.state_management_and_recovery,
    ):
        scenario()

    # Hook-system scenarios are disabled; uncomment any line below to run it:
    # TestHookSystem().hook_registration()
    # TestHookSystem().hook_execution()
    # TestHookSystem().task_context_transfer()