#!/usr/bin/env python3
"""Test script to verify the fixes work correctly."""

import json
import os
import sys
from pathlib import Path

# Add the project root to the path so that the `src.*` imports below resolve.
sys.path.insert(0, str(Path(__file__).parent))


def test_dataclass_creation():
    """Test that the AutoEvalColumn dataclass can be created successfully."""
    print("Testing AutoEvalColumn dataclass creation...")
    try:
        from src.display.utils import AutoEvalColumn, fields

        # Test that we can access the fields
        all_fields = fields(AutoEvalColumn)
        print(f"✓ Successfully created AutoEvalColumn with {len(all_fields)} fields")

        # Test that the average field exists
        assert hasattr(AutoEvalColumn, "average"), "Missing 'average' field"
        print("✓ 'average' field exists")

        # Test that we can access field names
        field_names = [c.name for c in all_fields]
        assert "average" in field_names, "Average field not in field names"
        print("✓ Average field accessible in field names")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False


def test_precision_from_str():
    """Test that the Precision.from_str method works correctly."""
    print("Testing Precision.from_str method...")
    try:
        from src.display.utils import Precision

        # Test different precision values
        result1 = Precision.from_str("torch.float16")
        assert result1 == Precision.float16, f"Expected float16, got {result1}"
        print("✓ torch.float16 correctly parsed")

        result2 = Precision.from_str("float16")
        assert result2 == Precision.float16, f"Expected float16, got {result2}"
        print("✓ float16 correctly parsed")

        result3 = Precision.from_str("torch.bfloat16")
        assert result3 == Precision.bfloat16, f"Expected bfloat16, got {result3}"
        print("✓ torch.bfloat16 correctly parsed")

        result4 = Precision.from_str("unknown")
        assert result4 == Precision.Unknown, f"Expected Unknown, got {result4}"
        print("✓ Unknown precision correctly parsed")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False


def test_eval_result_parsing():
    """Test that EvalResult can parse JSON files correctly."""
    print("Testing EvalResult JSON parsing...")
    try:
        from src.leaderboard.read_evals import EvalResult

        # Create a sample result file
        sample_result = {
            "config": {
                "model_name": "test/model",
                "model_dtype": "torch.float16",
                "model_sha": "abc123",
            },
            "results": {
                "emea_ner": {"f1": 0.85},
                "medline_ner": {"f1": 0.82},
            },
        }

        # Write to temp file
        temp_file = "/tmp/test_result.json"
        with open(temp_file, "w") as f:
            json.dump(sample_result, f)

        # Test parsing
        result = EvalResult.init_from_json_file(temp_file)
        assert result.full_model == "test/model", f"Expected test/model, got {result.full_model}"
        assert result.org == "test", f"Expected test, got {result.org}"
        assert result.model == "model", f"Expected model, got {result.model}"
        assert result.revision == "abc123", f"Expected abc123, got {result.revision}"
        print("✓ JSON parsing works correctly")

        # Test with missing fields
        sample_result_minimal = {
            "config": {"model": "test/model2"},
            "results": {"emea_ner": {"f1": 0.75}},
        }

        temp_file_minimal = "/tmp/test_result_minimal.json"
        with open(temp_file_minimal, "w") as f:
            json.dump(sample_result_minimal, f)

        result_minimal = EvalResult.init_from_json_file(temp_file_minimal)
        assert result_minimal.full_model == "test/model2", f"Expected test/model2, got {result_minimal.full_model}"
        print("✓ Minimal JSON parsing works correctly")

        # Clean up
        os.remove(temp_file)
        os.remove(temp_file_minimal)

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False


def test_to_dict():
    """Test that EvalResult.to_dict works correctly."""
    print("Testing EvalResult.to_dict method...")
    try:
        from src.display.utils import ModelType, Precision, WeightType
        from src.leaderboard.read_evals import EvalResult

        # Create a test EvalResult
        eval_result = EvalResult(
            eval_name="test_model_float16",
            full_model="test/model",
            org="test",
            model="model",
            revision="abc123",
            results={"emea_ner": 85.0, "medline_ner": 82.0},
            precision=Precision.float16,
            model_type=ModelType.FT,
            weight_type=WeightType.Original,
            architecture="BertForTokenClassification",
            license="MIT",
            likes=10,
            num_params=110,
            date="2023-01-01",
            still_on_hub=True,
        )

        # Test to_dict conversion
        result_dict = eval_result.to_dict()

        # Check that all required fields are present
        assert "average" in result_dict, "Missing average field in dict"
        assert result_dict["average"] == 83.5, f"Expected average 83.5, got {result_dict['average']}"
        print("✓ to_dict method works correctly")
        print(f"  - Average: {result_dict['average']}")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False


def main():
    """Run all tests."""
    print("Running bug fix tests...\n")

    tests = [
        test_dataclass_creation,
        test_precision_from_str,
        test_eval_result_parsing,
        test_to_dict,
    ]

    results = []
    for test in tests:
        print(f"\n{'=' * 50}")
        try:
            result = test()
            results.append(result)
        except Exception as e:
            print(f"✗ Test {test.__name__} failed with exception: {e}")
            results.append(False)

    print(f"\n{'=' * 50}")
    print(f"Test Results: {sum(results)}/{len(results)} tests passed")

    if all(results):
        print("🎉 All tests passed! The fixes are working correctly.")
        return 0
    else:
        print("❌ Some tests failed. Please check the output above.")
        return 1


if __name__ == "__main__":
    sys.exit(main())
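
# Usage note (the filename below is an assumption; the original script name is
# not given). Run from the repository root so the `src` package is importable:
#
#   python test_fixes.py
#
# main() returns 0 when every test passes and 1 otherwise, so the shell exit
# code reflects the overall result.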