Spaces:
Running
Running
"""
Comprehensive tests for Phase 4 tool selection optimization.

This test suite validates the tool selection optimization implementation
to ensure it addresses the critical issues identified during evaluation:
1. Inappropriate tool selection for specific question types
2. Tool usage pattern optimization
3. Dynamic tool selection based on question analysis
4. Tool execution strategy optimization
"""
import pytest | |
import logging | |
from typing import List, Dict, Any | |
from unittest.mock import Mock, patch | |
# Import the modules to test | |
from utils.enhanced_question_classifier import ( | |
EnhancedQuestionClassifier, | |
ClassificationResult, | |
QuestionType, | |
ToolType | |
) | |
from utils.tool_selector import ( | |
ToolSelector, | |
ToolSelectionResult, | |
ToolExecutionPlan, | |
ToolExecutionStrategy, | |
ToolPriority | |
) | |
from agents.fixed_enhanced_unified_agno_agent import FixedGAIAAgent | |
# Logger named after this test module.
logger = logging.getLogger(name=__name__)
class TestEnhancedQuestionClassifier:
    """Unit tests for ``EnhancedQuestionClassifier.classify_question``.

    Each test classifies one representative question and checks the
    returned question type, sub-category, recommended tools, confidence
    and (where relevant) the reasoning text.
    """

    def setup_method(self):
        """Build a fresh classifier before each test."""
        self.classifier = EnhancedQuestionClassifier()

    def test_bird_species_classification(self):
        """A bird-species count is a knowledge fact recommending Wikipedia and Exa."""
        outcome = self.classifier.classify_question(
            "How many bird species are there in the world?"
        )
        recommended = outcome.recommended_tools

        assert outcome.question_type == QuestionType.KNOWLEDGE_FACTS
        assert outcome.sub_category == "counting_facts"
        assert all(tool in recommended for tool in (ToolType.WIKIPEDIA, ToolType.EXA))
        assert outcome.confidence > 0.8
        assert "bird species" in outcome.reasoning.lower()

    def test_exponentiation_classification(self):
        """``2^8`` is a mathematical exponentiation question recommending Python."""
        outcome = self.classifier.classify_question("What is 2^8?")

        assert outcome.question_type == QuestionType.MATHEMATICAL
        assert outcome.sub_category == "exponentiation"
        assert ToolType.PYTHON in outcome.recommended_tools
        assert outcome.confidence > 0.8
        assert "exponentiation" in outcome.reasoning.lower()

    def test_artist_discography_classification(self):
        """An album-release question is web research in the artist-discography bucket."""
        outcome = self.classifier.classify_question(
            "What albums did Mercedes Sosa release between 2000 and 2009?"
        )

        assert outcome.question_type == QuestionType.WEB_RESEARCH
        assert outcome.sub_category == "artist_discography"
        assert ToolType.EXA in outcome.recommended_tools
        assert outcome.confidence > 0.7
        assert "discography" in outcome.reasoning.lower()

    def test_basic_arithmetic_classification(self):
        """Plain multiplication is basic arithmetic recommending the calculator."""
        outcome = self.classifier.classify_question("What is 25 * 17?")

        assert outcome.question_type == QuestionType.MATHEMATICAL
        assert outcome.sub_category == "basic_arithmetic"
        assert ToolType.CALCULATOR in outcome.recommended_tools
        assert outcome.confidence > 0.9

    def test_youtube_content_classification(self):
        """A question containing a YouTube URL is classified as video analysis."""
        outcome = self.classifier.classify_question(
            "What is discussed in this YouTube video? https://youtube.com/watch?v=example"
        )

        assert outcome.question_type == QuestionType.VIDEO_ANALYSIS
        assert ToolType.YOUTUBE in outcome.recommended_tools
        assert outcome.confidence > 0.8

    def test_multimodal_image_classification(self):
        """An attached image file makes the question multimodal image analysis."""
        attachments = [{"type": "image", "path": "test.jpg"}]
        outcome = self.classifier.classify_question(
            "What do you see in this image?", attachments
        )

        assert outcome.question_type == QuestionType.MULTIMODAL
        assert outcome.sub_category == "image_analysis"
        assert ToolType.IMAGE_ANALYSIS in outcome.recommended_tools
        assert outcome.confidence > 0.8
class TestToolSelector:
    """Unit tests for ``ToolSelector`` rule matching, fallbacks and performance tracking."""

    def setup_method(self):
        """Build a fresh selector before each test."""
        self.selector = ToolSelector()

    def test_bird_species_optimization_rule(self):
        """Bird-species counting uses Wikipedia first, with Exa as the first fallback."""
        selection = self.selector.select_optimal_tools(
            "How many bird species are there in the world?"
        )
        fallbacks = selection.fallback_plans

        assert selection.primary_plan.tool_type == ToolType.WIKIPEDIA
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.9
        assert "bird species counting" in selection.optimization_reasoning.lower()
        assert len(fallbacks) > 0
        assert fallbacks[0].tool_type == ToolType.EXA

    def test_exponentiation_optimization_rule(self):
        """Exponentiation runs in Python with a ``variable_to_return`` parameter."""
        selection = self.selector.select_optimal_tools("What is 2^8?")
        plan = selection.primary_plan

        assert plan.tool_type == ToolType.PYTHON
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.8
        assert "exponentiation" in selection.optimization_reasoning.lower()
        assert "variable_to_return" in plan.parameters

    def test_artist_discography_optimization_rule(self):
        """Artist discography questions are routed to Exa search."""
        selection = self.selector.select_optimal_tools(
            "What albums did Mercedes Sosa release between 2000 and 2009?"
        )

        assert selection.primary_plan.tool_type == ToolType.EXA
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.8
        assert "discography" in selection.optimization_reasoning.lower()

    def test_basic_arithmetic_optimization_rule(self):
        """Simple arithmetic is routed to the calculator with high confidence."""
        selection = self.selector.select_optimal_tools("What is 25 * 17?")

        assert selection.primary_plan.tool_type == ToolType.CALCULATOR
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.9
        assert "arithmetic" in selection.optimization_reasoning.lower()

    def test_youtube_optimization_rule(self):
        """A YouTube URL selects the YouTube tool."""
        selection = self.selector.select_optimal_tools(
            "What is discussed in https://youtube.com/watch?v=example?"
        )

        assert selection.primary_plan.tool_type == ToolType.YOUTUBE
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert selection.confidence > 0.9
        assert "youtube" in selection.optimization_reasoning.lower()

    def test_general_classification_fallback(self):
        """Questions without a dedicated rule fall back to classification-based selection."""
        selection = self.selector.select_optimal_tools("What is the weather like today?")

        # No specific rule is expected to match, so the generic path must be used.
        assert selection.primary_plan.tool_type in (ToolType.EXA, ToolType.WIKIPEDIA)
        assert selection.execution_strategy == ToolExecutionStrategy.SEQUENTIAL
        assert "Classification-based selection" in selection.optimization_reasoning

    def test_tool_performance_tracking(self):
        """Recording one call updates the per-tool statistics."""
        # NOTE(review): positional args presumably are (tool, success, response_time,
        # quality); confirm against ToolSelector.update_tool_performance.
        self.selector.update_tool_performance(ToolType.WIKIPEDIA, True, 5.0, 0.9)

        wiki_stats = self.selector.performance_stats[ToolType.WIKIPEDIA]
        assert wiki_stats['usage_count'] == 1
        assert wiki_stats['failure_count'] == 0
        assert wiki_stats['success_rate'] > 0.8
        assert wiki_stats['avg_response_time'] < 10.0

    def test_performance_report_generation(self):
        """The performance report exposes its three sections and summary metrics."""
        report = self.selector.get_tool_performance_report()

        for section in ('tool_performance', 'optimization_rules', 'performance_summary'):
            assert section in report
        assert len(report['optimization_rules']) > 0
        assert 'avg_success_rate' in report['performance_summary']
class TestFixedGAIAAgentIntegration:
    """Integration checks for the tool-selection hooks inside ``FixedGAIAAgent``."""

    def setup_method(self):
        """Construct the agent with ``MistralChat`` and ``Agent`` patched out.

        Patching avoids API key requirements during construction. The inner
        ``agent`` is then replaced by a ``Mock`` and the agent is marked available.
        """
        mistral_patch = patch('agents.fixed_enhanced_unified_agno_agent.MistralChat')
        agent_patch = patch('agents.fixed_enhanced_unified_agno_agent.Agent')
        with mistral_patch, agent_patch:
            self.agent = FixedGAIAAgent()
            self.agent.available = True
            self.agent.agent = Mock()

    def test_tool_optimization_integration(self):
        """The agent owns a question classifier and a tool selector of the right types."""
        for attribute in ('question_classifier', 'tool_selector'):
            assert hasattr(self.agent, attribute)
        assert isinstance(self.agent.question_classifier, EnhancedQuestionClassifier)
        assert isinstance(self.agent.tool_selector, ToolSelector)

    def test_apply_tool_optimizations_method(self):
        """``_apply_tool_optimizations`` embeds guidance and keeps the original question."""
        question = "What is 2^8?"
        python_plan = ToolExecutionPlan(
            tool_type=ToolType.PYTHON,
            priority=ToolPriority.CRITICAL,
            parameters={"variable_to_return": "result"},
            expected_output="Numeric result",
            success_criteria="Output contains: result",
            fallback_tools=[],
            timeout_seconds=30,
            retry_count=1,
        )
        selection = ToolSelectionResult(
            primary_plan=python_plan,
            fallback_plans=[],
            execution_strategy=ToolExecutionStrategy.SEQUENTIAL,
            optimization_reasoning="Exponentiation requires Python",
            confidence=0.9,
            estimated_success_rate=0.85,
        )

        guided = self.agent._apply_tool_optimizations(question, selection)

        assert "TOOL OPTIMIZATION GUIDANCE" in guided
        assert "python" in guided.lower()
        assert "confidence: 0.9" in guided.lower()
        assert question in guided
class TestCriticalEvaluationScenarios:
    """Regression tests for the specific tool-selection mistakes seen in evaluation.

    The test methods take no arguments besides ``self`` so that the
    ``__main__`` section of this module can also call them directly.
    Every assertion carries a message naming the offending question and the
    tool that was actually chosen, so loop-based failures are diagnosable.
    """

    def setup_method(self):
        """Build a fresh selector before each test."""
        self.selector = ToolSelector()

    def test_bird_species_not_calculator(self):
        """Bird species questions must not use the calculator (addresses '468' issue)."""
        question = "How many bird species are there in the world?"
        chosen = self.selector.select_optimal_tools(question).primary_plan.tool_type

        # Should NOT use calculator ...
        assert chosen != ToolType.CALCULATOR, f"Calculator selected for: {question}"
        # ... but Wikipedia or Exa instead.
        assert chosen in (ToolType.WIKIPEDIA, ToolType.EXA), \
            f"Expected Wikipedia or Exa for: {question}, got {chosen}"

    def test_exponentiation_uses_python(self):
        """Every exponentiation phrasing is routed to Python, not the calculator."""
        questions = [
            "What is 2^8?",
            "Calculate 3 to the power of 4",
            "What is 5**3?",
        ]
        for question in questions:
            plan = self.selector.select_optimal_tools(question).primary_plan
            assert plan.tool_type == ToolType.PYTHON, \
                f"Expected Python for: {question}, got {plan.tool_type}"
            # The Python plan must carry 'variable_to_return' (presumably the
            # variable whose value is read back as the answer — confirm in ToolSelector).
            assert "variable_to_return" in plan.parameters, \
                f"Python plan missing 'variable_to_return' for: {question}"

    def test_artist_discography_specific_search(self):
        """Artist discography questions use an Exa search that names the artist."""
        question = "What albums did Mercedes Sosa release between 2000 and 2009?"
        plan = self.selector.select_optimal_tools(question).primary_plan

        assert plan.tool_type == ToolType.EXA, \
            f"Expected Exa for: {question}, got {plan.tool_type}"
        # Should have specific search parameters: the artist name must appear in
        # the stringified parameters once quote characters are removed.
        flattened = str(plan.parameters).replace("'", "").replace('"', '')
        assert "Mercedes Sosa" in flattened, \
            f"Artist name missing from search parameters: {plan.parameters}"

    def test_factual_counting_authoritative_sources(self):
        """Factual counting questions use Wikipedia or Exa, never the calculator."""
        questions = [
            "How many countries are in the world?",
            "How many continents are there?",
            "How many oceans exist?",
        ]
        for question in questions:
            chosen = self.selector.select_optimal_tools(question).primary_plan.tool_type
            # Should use Wikipedia or Exa, not calculator
            assert chosen in (ToolType.WIKIPEDIA, ToolType.EXA), \
                f"Expected Wikipedia or Exa for: {question}, got {chosen}"
            assert chosen != ToolType.CALCULATOR, f"Calculator selected for: {question}"
class TestToolSelectionConfidence:
    """Checks on confidence scores, success-rate estimates and fallback plans."""

    def setup_method(self):
        """Build a fresh selector before each test."""
        self.selector = ToolSelector()

    def test_high_confidence_specific_rules(self):
        """Questions matched by a dedicated optimization rule report confidence above 0.8."""
        for prompt in (
            "How many bird species are there in the world?",
            "What is 2^8?",
            "What is 25 * 17?",
            "https://youtube.com/watch?v=example",
        ):
            confidence = self.selector.select_optimal_tools(prompt).confidence
            assert confidence > 0.8, f"Low confidence for: {prompt}"

    def test_success_rate_estimation(self):
        """The estimated success rate for the bird-species rule lies in (0.7, 1.0]."""
        estimate = self.selector.select_optimal_tools(
            "How many bird species are there in the world?"
        ).estimated_success_rate

        # Should have a reasonable success rate, given the fallbacks.
        assert 0.7 < estimate <= 1.0

    def test_fallback_strategy_quality(self):
        """At least one fallback exists, and none repeats the primary tool."""
        selection = self.selector.select_optimal_tools(
            "How many bird species are there in the world?"
        )

        assert len(selection.fallback_plans) > 0
        backup_tools = [plan.tool_type for plan in selection.fallback_plans]
        assert selection.primary_plan.tool_type not in backup_tools
# Integration test scenarios
class TestEndToEndOptimization:
    """End-to-end checks that previously mis-routed questions get the right primary tool."""

    def test_complete_optimization_pipeline(self):
        """Each problem question picks its expected tool, avoids the calculator, and is confident."""
        # (question, expected primary tool, tool that must not be primary)
        scenarios = (
            ("How many bird species are there in the world?",
             ToolType.WIKIPEDIA, ToolType.CALCULATOR),
            ("What is 2^8?",
             ToolType.PYTHON, ToolType.CALCULATOR),
            ("What albums did Mercedes Sosa release between 2000 and 2009?",
             ToolType.EXA, ToolType.CALCULATOR),
        )
        selector = ToolSelector()

        for question, expected_tool, forbidden_tool in scenarios:
            outcome = selector.select_optimal_tools(question)
            chosen = outcome.primary_plan.tool_type

            assert chosen == expected_tool, f"Wrong tool for: {question}"
            assert chosen != forbidden_tool, \
                f"Should not use {forbidden_tool.value} for: {question}"
            assert outcome.confidence > 0.7, f"Low confidence for: {question}"
if __name__ == "__main__":
    # Configure logging for tests
    logging.basicConfig(level=logging.INFO)

    # Run the critical regression scenarios directly, without pytest.
    # The emoji below were previously mojibake (UTF-8 decoded as ISO-8859-7).
    print("🧪 Running Phase 4 Tool Selection Optimization Tests")
    print("=" * 60)

    test_selector = TestCriticalEvaluationScenarios()
    test_selector.setup_method()

    # (label, test method): each check prints a start line, runs, then prints a
    # pass line. A failing assertion propagates and aborts the run, so the final
    # summary is printed only when every check passed.
    critical_checks = (
        ("Bird species", test_selector.test_bird_species_not_calculator),
        ("Exponentiation", test_selector.test_exponentiation_uses_python),
        ("Artist discography", test_selector.test_artist_discography_specific_search),
        ("Factual counting", test_selector.test_factual_counting_authoritative_sources),
    )
    for label, check in critical_checks:
        print(f"Testing {label.lower()} optimization...")
        check()
        print(f"✅ {label} test passed")

    print("\n🎯 All critical optimization tests passed!")
    print("Phase 4 tool selection optimization is working correctly.")