#!/usr/bin/env python3
"""Test script to verify cost tracking is working properly."""

import json
import logging
from unittest.mock import Mock, patch

from src.services.cost_tracker import CostTracker
from src.agents.unique_indices_combinator import UniqueIndicesCombinator
from src.agents.unique_indices_loop_agent import UniqueIndicesLoopAgent
# NOTE(review): `settings` is not referenced in this script; the import is
# kept as-is in case it is relied on for import-time configuration — confirm.
from src.config.settings import settings

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _mock_llm_response(payload, input_tokens: int, output_tokens: int) -> Mock:
    """Build a mock LLM response carrying *payload* serialized as JSON.

    The returned mock exposes the JSON text at ``output[0].content[0].text``
    and the token counts at ``usage.input_tokens`` / ``usage.output_tokens``.
    """
    return Mock(
        output=[Mock(content=[Mock(text=json.dumps(payload))])],
        usage=Mock(input_tokens=input_tokens, output_tokens=output_tokens),
    )


def _log_tracker_state(cost_tracker, stage: str) -> None:
    """Log the token totals and the number of LLM calls recorded so far."""
    logger.info("Cost tracker after %s:", stage)
    logger.info(" Input tokens: %s", cost_tracker.llm_input_tokens)
    logger.info(" Output tokens: %s", cost_tracker.llm_output_tokens)
    logger.info(" LLM calls: %s", len(cost_tracker.llm_calls))


def test_cost_tracking() -> None:
    """Test that cost tracking works properly with the new agents.

    Runs ``UniqueIndicesCombinator`` and then ``UniqueIndicesLoopAgent``
    against a shared context with ``openai.responses.create`` patched, and
    checks that the shared ``CostTracker`` accumulates the mocked token usage
    and that the detailed costs table has the expected rows and columns.
    """
    # Create a cost tracker
    cost_tracker = CostTracker()

    # Create mock context
    ctx = {
        "text": "This is a test document with some content.",
        "unique_indices": ["Protein Lot", "Peptide", "Timepoint", "Modification"],
        "unique_indices_descriptions": {
            "Protein Lot": {
                "description": "Protein lot identifier",
                "format": "String",
                "examples": "P066_L14_H31_0-hulgG-LALAPG-FJB",
                "possible_values": "",
            },
            "Peptide": {
                "description": "Peptide sequence",
                "format": "String",
                "examples": "QVQLQQSGPGLVQPSQSLSITCTVSDFSLAR",
                "possible_values": "",
            },
        },
        "fields": ["Chain", "Percentage", "Seq Loc"],
        "field_descriptions": {
            "Chain": {
                "description": "Heavy or Light chain",
                "format": "String",
                "examples": "Heavy",
                "possible_values": "Heavy, Light",
            },
        },
        "document_context": "Biotech document",
        "cost_tracker": cost_tracker,
    }

    # Mock LLM responses
    mock_combinations = [
        {
            "Protein Lot": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "Peptide": "PLTFGAGTK",
            "Timepoint": "0w",
            "Modification": "Clipping",
        },
        {
            "Protein Lot": "P066_L14_H31_0-hulgG-LALAPG-FJB",
            "Peptide": "PLTFGAGTK",
            "Timepoint": "4w",
            "Modification": "Clipping",
        },
    ]
    mock_additional_fields = {
        "Chain": "Heavy",
        "Percentage": "90.0",
        "Seq Loc": "HC(1-31)",
    }

    # Test UniqueIndicesCombinator
    logger.info("Testing UniqueIndicesCombinator cost tracking...")
    with patch('openai.responses.create') as mock_create:
        # Mock the LLM response for combinations
        mock_create.return_value = _mock_llm_response(
            mock_combinations, input_tokens=1500, output_tokens=300
        )

        combinator = UniqueIndicesCombinator()
        result = combinator.execute(ctx)

        logger.info("Combinator result: %s", result)
        _log_tracker_state(cost_tracker, "combinator")

        # Verify cost tracking worked
        assert cost_tracker.llm_input_tokens == 1500
        assert cost_tracker.llm_output_tokens == 300
        assert len(cost_tracker.llm_calls) == 1
        assert cost_tracker.llm_calls[0].description == "Unique Indices Combination Extraction"

    # Test UniqueIndicesLoopAgent
    logger.info("Testing UniqueIndicesLoopAgent cost tracking...")

    # Set the results from combinator
    ctx["results"] = mock_combinations

    with patch('openai.responses.create') as mock_create:
        # Mock the LLM response for additional fields
        # (will be called twice, once for each combination)
        mock_create.return_value = _mock_llm_response(
            mock_additional_fields, input_tokens=800, output_tokens=150
        )

        loop_agent = UniqueIndicesLoopAgent()
        result = loop_agent.execute(ctx)

        logger.info("Loop agent result: %s", result)
        _log_tracker_state(cost_tracker, "loop agent")

        # Verify cost tracking worked for both calls
        assert cost_tracker.llm_input_tokens == 1500 + (800 * 2)  # Combinator + 2 loop iterations
        assert cost_tracker.llm_output_tokens == 300 + (150 * 2)  # Combinator + 2 loop iterations
        assert len(cost_tracker.llm_calls) == 3  # 1 combinator + 2 loop iterations

    # Test detailed costs table
    logger.info("Testing detailed costs table...")
    costs_df = cost_tracker.get_detailed_costs_table()
    logger.info("Costs table:\n%s", costs_df)

    # Verify the table has the expected structure
    assert len(costs_df) == 4  # 3 calls + 1 total row
    assert "Description" in costs_df.columns
    assert "Input Tokens" in costs_df.columns
    assert "Output Tokens" in costs_df.columns
    assert "Total Cost" in costs_df.columns

    logger.info("All cost tracking tests passed!")


if __name__ == "__main__":
    test_cost_tracking()