# gaia-enhanced-agent/tests/test_tool_selection.py
# GAIA Agent Deployment
# Deploy Complete Enhanced GAIA Agent with Phase 1-6 Improvements
# 9a6a4dc
"""
Comprehensive Testing for Phase 4 Tool Selection Optimization
This test suite validates the tool selection optimization implementation
to ensure it addresses the critical evaluation issues identified:
1. Inappropriate tool selection for specific question types
2. Tool usage pattern optimization
3. Dynamic tool selection based on question analysis
4. Tool execution strategy optimization
"""
import pytest
import logging
from typing import List, Dict, Any
from unittest.mock import Mock, patch
# Import the modules to test
from utils.enhanced_question_classifier import (
EnhancedQuestionClassifier,
ClassificationResult,
QuestionType,
ToolType
)
from utils.tool_selector import (
ToolSelector,
ToolSelectionResult,
ToolExecutionPlan,
ToolExecutionStrategy,
ToolPriority
)
from agents.fixed_enhanced_unified_agno_agent import FixedGAIAAgent
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TestEnhancedQuestionClassifier:
    """Exercise ``EnhancedQuestionClassifier.classify_question``.

    Every test submits one representative question and asserts on the
    fields of the returned classification: question type, sub-category,
    recommended tools, confidence and, where checked, the reasoning text.
    """

    def setup_method(self):
        """Build a fresh classifier before each test."""
        self.classifier = EnhancedQuestionClassifier()

    def test_bird_species_classification(self):
        """A bird-species count is a counting fact for Wikipedia and Exa."""
        outcome = self.classifier.classify_question(
            "How many bird species are there in the world?"
        )

        assert outcome.question_type == QuestionType.KNOWLEDGE_FACTS
        assert outcome.sub_category == "counting_facts"
        for expected_tool in (ToolType.WIKIPEDIA, ToolType.EXA):
            assert expected_tool in outcome.recommended_tools
        assert outcome.confidence > 0.8
        assert "bird species" in outcome.reasoning.lower()

    def test_exponentiation_classification(self):
        """A caret-style power expression is mathematical and recommends Python."""
        outcome = self.classifier.classify_question("What is 2^8?")

        assert outcome.question_type == QuestionType.MATHEMATICAL
        assert outcome.sub_category == "exponentiation"
        assert ToolType.PYTHON in outcome.recommended_tools
        assert outcome.confidence > 0.8
        assert "exponentiation" in outcome.reasoning.lower()

    def test_artist_discography_classification(self):
        """An album-release question is web research and recommends Exa."""
        outcome = self.classifier.classify_question(
            "What albums did Mercedes Sosa release between 2000 and 2009?"
        )

        assert outcome.question_type == QuestionType.WEB_RESEARCH
        assert outcome.sub_category == "artist_discography"
        assert ToolType.EXA in outcome.recommended_tools
        assert outcome.confidence > 0.7
        assert "discography" in outcome.reasoning.lower()

    def test_basic_arithmetic_classification(self):
        """A plain multiplication is basic arithmetic for the calculator."""
        outcome = self.classifier.classify_question("What is 25 * 17?")

        assert outcome.question_type == QuestionType.MATHEMATICAL
        assert outcome.sub_category == "basic_arithmetic"
        assert ToolType.CALCULATOR in outcome.recommended_tools
        assert outcome.confidence > 0.9

    def test_youtube_content_classification(self):
        """A question carrying a YouTube URL is video analysis."""
        outcome = self.classifier.classify_question(
            "What is discussed in this YouTube video? https://youtube.com/watch?v=example"
        )

        assert outcome.question_type == QuestionType.VIDEO_ANALYSIS
        assert ToolType.YOUTUBE in outcome.recommended_tools
        assert outcome.confidence > 0.8

    def test_multimodal_image_classification(self):
        """An attached image file makes the question multimodal image analysis."""
        attachments = [{"type": "image", "path": "test.jpg"}]
        outcome = self.classifier.classify_question(
            "What do you see in this image?", attachments
        )

        assert outcome.question_type == QuestionType.MULTIMODAL
        assert outcome.sub_category == "image_analysis"
        assert ToolType.IMAGE_ANALYSIS in outcome.recommended_tools
        assert outcome.confidence > 0.8
class TestToolSelector:
    """Exercise ``ToolSelector``: optimization rules, fallback and statistics.

    The rule tests call ``select_optimal_tools`` with one question each and
    assert on the primary plan, execution strategy, confidence and
    reasoning text. The last two tests cover performance bookkeeping.
    """

    def setup_method(self):
        """Build a fresh selector before each test."""
        self.selector = ToolSelector()

    def test_bird_species_optimization_rule(self):
        """Bird-species counting picks Wikipedia first with Exa as first fallback."""
        selection = self.selector.select_optimal_tools(
            "How many bird species are there in the world?"
        )

        assert selection.primary_plan.tool_type == ToolType.WIKIPEDIA
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.9
        assert "bird species counting" in selection.optimization_reasoning.lower()
        fallbacks = selection.fallback_plans
        assert len(fallbacks) > 0
        assert fallbacks[0].tool_type == ToolType.EXA

    def test_exponentiation_optimization_rule(self):
        """Exponentiation picks Python and asks for a variable to return."""
        selection = self.selector.select_optimal_tools("What is 2^8?")

        assert selection.primary_plan.tool_type == ToolType.PYTHON
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.8
        assert "exponentiation" in selection.optimization_reasoning.lower()
        assert "variable_to_return" in selection.primary_plan.parameters

    def test_artist_discography_optimization_rule(self):
        """An artist discography question picks Exa."""
        selection = self.selector.select_optimal_tools(
            "What albums did Mercedes Sosa release between 2000 and 2009?"
        )

        assert selection.primary_plan.tool_type == ToolType.EXA
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.8
        assert "discography" in selection.optimization_reasoning.lower()

    def test_basic_arithmetic_optimization_rule(self):
        """Basic arithmetic picks the calculator."""
        selection = self.selector.select_optimal_tools("What is 25 * 17?")

        assert selection.primary_plan.tool_type == ToolType.CALCULATOR
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.9
        assert "arithmetic" in selection.optimization_reasoning.lower()

    def test_youtube_optimization_rule(self):
        """A question containing a YouTube URL picks the YouTube tool."""
        selection = self.selector.select_optimal_tools(
            "What is discussed in https://youtube.com/watch?v=example?"
        )

        assert selection.primary_plan.tool_type == ToolType.YOUTUBE
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.9
        assert "youtube" in selection.optimization_reasoning.lower()

    def test_general_classification_fallback(self):
        """With no specific rule, selection falls back to classification."""
        selection = self.selector.select_optimal_tools(
            "What is the weather like today?"
        )

        # The general path is expected to land on one of the search tools.
        search_tools = [ToolType.EXA, ToolType.WIKIPEDIA]
        assert selection.primary_plan.tool_type in search_tools
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert "Classification-based selection" in selection.optimization_reasoning

    def test_tool_performance_tracking(self):
        """Recording one successful call updates the per-tool statistics."""
        tracked_tool = ToolType.WIKIPEDIA
        self.selector.update_tool_performance(tracked_tool, True, 5.0, 0.9)

        recorded = self.selector.performance_stats[tracked_tool]
        assert recorded['usage_count'] == 1
        assert recorded['failure_count'] == 0
        assert recorded['success_rate'] > 0.8
        assert recorded['avg_response_time'] < 10.0

    def test_performance_report_generation(self):
        """The performance report exposes its three sections and a summary rate."""
        report = self.selector.get_tool_performance_report()

        for section in ('tool_performance', 'optimization_rules', 'performance_summary'):
            assert section in report
        assert len(report['optimization_rules']) > 0
        assert 'avg_success_rate' in report['performance_summary']
class TestFixedGAIAAgentIntegration:
    """Check that ``FixedGAIAAgent`` is wired to the optimization components."""

    def setup_method(self):
        """Construct an agent with its model and agent classes patched out.

        ``MistralChat`` and ``Agent`` are replaced inside the agent module
        while the agent is built, to avoid API key requirements.
        """
        with patch('agents.fixed_enhanced_unified_agno_agent.MistralChat'):
            with patch('agents.fixed_enhanced_unified_agno_agent.Agent'):
                self.agent = FixedGAIAAgent()
                self.agent.available = True
                self.agent.agent = Mock()

    def test_tool_optimization_integration(self):
        """The agent exposes a classifier and a selector of the expected types."""
        agent = self.agent

        assert hasattr(agent, 'question_classifier')
        assert hasattr(agent, 'tool_selector')
        assert isinstance(agent.question_classifier, EnhancedQuestionClassifier)
        assert isinstance(agent.tool_selector, ToolSelector)

    def test_apply_tool_optimizations_method(self):
        """``_apply_tool_optimizations`` embeds guidance and keeps the question."""
        question = "What is 2^8?"

        # Hand-built selection result so the test does not depend on the selector.
        python_plan = ToolExecutionPlan(
            tool_type=ToolType.PYTHON,
            priority=ToolPriority.CRITICAL,
            parameters={"variable_to_return": "result"},
            expected_output="Numeric result",
            success_criteria="Output contains: result",
            fallback_tools=[],
            timeout_seconds=30,
            retry_count=1,
        )
        mock_selection = ToolSelectionResult(
            primary_plan=python_plan,
            fallback_plans=[],
            execution_strategy=ToolExecutionStrategy.SEQUENTIAL,
            optimization_reasoning="Exponentiation requires Python",
            confidence=0.9,
            estimated_success_rate=0.85,
        )

        optimized_question = self.agent._apply_tool_optimizations(question, mock_selection)
        lowered = optimized_question.lower()

        assert "TOOL OPTIMIZATION GUIDANCE" in optimized_question
        assert "python" in lowered
        assert "confidence: 0.9" in lowered
        assert question in optimized_question
class TestCriticalEvaluationScenarios:
    """Regression tests for the tool-selection problems found in evaluation.

    Each test asserts that ``ToolSelector.select_optimal_tools`` picks an
    appropriate primary tool for a class of question. Tests that loop over
    several questions attach the failing question to the assertion message
    so that a failure identifies the offending input.
    """

    def setup_method(self):
        """Build a fresh selector before each test."""
        self.selector = ToolSelector()

    def test_bird_species_not_calculator(self):
        """Test that bird species questions don't use calculator (addresses '468' issue)."""
        question = "How many bird species are there in the world?"
        result = self.selector.select_optimal_tools(question)
        primary_tool = result.primary_plan.tool_type

        # Should NOT use calculator
        assert primary_tool != ToolType.CALCULATOR
        # Should use Wikipedia or Exa
        assert primary_tool in [ToolType.WIKIPEDIA, ToolType.EXA]

    def test_exponentiation_uses_python(self):
        """Test that exponentiation uses Python, not calculator."""
        questions = [
            "What is 2^8?",
            "Calculate 3 to the power of 4",
            "What is 5**3?",
        ]
        for question in questions:
            result = self.selector.select_optimal_tools(question)
            primary_plan = result.primary_plan
            assert primary_plan.tool_type == ToolType.PYTHON, \
                f"Expected Python tool for: {question} (got {primary_plan.tool_type})"
            assert "variable_to_return" in primary_plan.parameters, \
                f"Missing 'variable_to_return' parameter for: {question}"

    def test_artist_discography_specific_search(self):
        """Test that artist discography uses targeted search."""
        question = "What albums did Mercedes Sosa release between 2000 and 2009?"
        result = self.selector.select_optimal_tools(question)

        assert result.primary_plan.tool_type == ToolType.EXA
        # Should have specific search parameters. Quotes are stripped from the
        # stringified parameters before looking for the artist name.
        flattened_parameters = (
            str(result.primary_plan.parameters).replace("'", "").replace('"', '')
        )
        assert "Mercedes Sosa" in flattened_parameters

    def test_factual_counting_authoritative_sources(self):
        """Test that factual counting uses authoritative sources."""
        questions = [
            "How many countries are in the world?",
            "How many continents are there?",
            "How many oceans exist?",
        ]
        for question in questions:
            result = self.selector.select_optimal_tools(question)
            primary_tool = result.primary_plan.tool_type
            # Should use Wikipedia or Exa, not calculator
            assert primary_tool in [ToolType.WIKIPEDIA, ToolType.EXA], \
                f"Expected Wikipedia or Exa for: {question} (got {primary_tool})"
            assert primary_tool != ToolType.CALCULATOR, \
                f"Should not use calculator for: {question}"
class TestToolSelectionConfidence:
    """Check confidence scores, success-rate estimates and fallback plans."""

    def setup_method(self):
        """Build a fresh selector before each test."""
        self.selector = ToolSelector()

    def test_high_confidence_specific_rules(self):
        """Questions covered by a specific rule score above 0.8 confidence."""
        high_confidence_questions = (
            "How many bird species are there in the world?",
            "What is 2^8?",
            "What is 25 * 17?",
            "https://youtube.com/watch?v=example",
        )
        for question in high_confidence_questions:
            selection = self.selector.select_optimal_tools(question)
            assert selection.confidence > 0.8, f"Low confidence for: {question}"

    def test_success_rate_estimation(self):
        """The estimated success rate lies in the interval (0.7, 1.0]."""
        selection = self.selector.select_optimal_tools(
            "How many bird species are there in the world?"
        )
        estimate = selection.estimated_success_rate

        # Should have reasonable success rate with fallbacks
        assert estimate > 0.7
        assert estimate <= 1.0

    def test_fallback_strategy_quality(self):
        """At least one fallback exists and none repeats the primary tool."""
        selection = self.selector.select_optimal_tools(
            "How many bird species are there in the world?"
        )
        fallback_plans = selection.fallback_plans

        # Should have at least one fallback
        assert len(fallback_plans) > 0

        # Fallback should be different from primary
        fallback_tools = [fallback.tool_type for fallback in fallback_plans]
        assert selection.primary_plan.tool_type not in fallback_tools
# Integration test scenarios
@pytest.mark.integration
class TestEndToEndOptimization:
    """Run the selector over the questions that previously caused issues."""

    def test_complete_optimization_pipeline(self):
        """Each question gets its expected tool, not the calculator, with confidence > 0.7."""
        # (question, expected primary tool, tool that must not be chosen)
        scenarios = (
            (
                "How many bird species are there in the world?",
                ToolType.WIKIPEDIA,
                ToolType.CALCULATOR,
            ),
            (
                "What is 2^8?",
                ToolType.PYTHON,
                ToolType.CALCULATOR,
            ),
            (
                "What albums did Mercedes Sosa release between 2000 and 2009?",
                ToolType.EXA,
                ToolType.CALCULATOR,
            ),
        )

        selector = ToolSelector()
        for question, expected_tool, forbidden_tool in scenarios:
            selection = selector.select_optimal_tools(question)
            chosen_tool = selection.primary_plan.tool_type

            # Check expected tool is selected
            assert chosen_tool == expected_tool, \
                f"Wrong tool for: {question}"
            # Check problematic tool is not used
            assert chosen_tool != forbidden_tool, \
                f"Should not use {forbidden_tool.value} for: {question}"
            # Check confidence is reasonable
            assert selection.confidence > 0.7, \
                f"Low confidence for: {question}"
if __name__ == "__main__":
    # Manual runner: executes the critical-scenario tests without pytest and
    # prints a progress line around each one. A failing assertion propagates
    # and stops the run.
    logging.basicConfig(level=logging.INFO)

    print("🧪 Running Phase 4 Tool Selection Optimization Tests")
    print("=" * 60)

    # setup_method is called once; every check below shares the same selector.
    test_selector = TestCriticalEvaluationScenarios()
    test_selector.setup_method()

    critical_checks = [
        ("bird species", test_selector.test_bird_species_not_calculator),
        ("exponentiation", test_selector.test_exponentiation_uses_python),
        ("artist discography", test_selector.test_artist_discography_specific_search),
        ("factual counting", test_selector.test_factual_counting_authoritative_sources),
    ]
    for label, run_check in critical_checks:
        print(f"Testing {label} optimization...")
        run_check()
        # capitalize() upper-cases only the first letter, e.g. "Bird species".
        print(f"✅ {label.capitalize()} test passed")

    print("\n🎯 All critical optimization tests passed!")
    print("Phase 4 tool selection optimization is working correctly.")