#!/usr/bin/env python3
"""Test script to verify the fixes work correctly."""
import sys
import os
import json
from pathlib import Path
# Add the repo root to the path so the `src` package can be imported
sys.path.insert(0, str(Path(__file__).parent))
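# NOTE: these tests assume the usual leaderboard-template layout, i.e.
# src/display/utils.py exposing AutoEvalColumn, fields, Precision, ModelType
# and WeightType, and src/leaderboard/read_evals.py exposing EvalResult.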


def test_dataclass_creation():
    """Test that the AutoEvalColumn dataclass can be created successfully."""
    print("Testing AutoEvalColumn dataclass creation...")
    try:
        from src.display.utils import AutoEvalColumn, fields

        # Test that we can access the fields
        all_fields = fields(AutoEvalColumn)
        print(f"✓ Successfully created AutoEvalColumn with {len(all_fields)} fields")

        # Test that the average field exists
        assert hasattr(AutoEvalColumn, 'average'), "Missing 'average' field"
        print("✓ 'average' field exists")

        # Test that we can access field names
        field_names = [c.name for c in all_fields]
        assert 'average' in field_names, "Average field not in field names"
        print("✓ Average field accessible in field names")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False


def test_precision_from_str():
    """Test that the Precision.from_str method works correctly."""
    print("Testing Precision.from_str method...")
    try:
        from src.display.utils import Precision

        # Test different precision values
        result1 = Precision.from_str("torch.float16")
        assert result1 == Precision.float16, f"Expected float16, got {result1}"
        print("✓ torch.float16 correctly parsed")

        result2 = Precision.from_str("float16")
        assert result2 == Precision.float16, f"Expected float16, got {result2}"
        print("✓ float16 correctly parsed")

        result3 = Precision.from_str("torch.bfloat16")
        assert result3 == Precision.bfloat16, f"Expected bfloat16, got {result3}"
        print("✓ torch.bfloat16 correctly parsed")

        result4 = Precision.from_str("unknown")
        assert result4 == Precision.Unknown, f"Expected Unknown, got {result4}"
        print("✓ Unknown precision correctly parsed")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False


def test_eval_result_parsing():
    """Test that the EvalResult can parse JSON files correctly."""
    print("Testing EvalResult JSON parsing...")
    try:
        from src.leaderboard.read_evals import EvalResult
        from src.about import Tasks

        # Create a sample result file
        sample_result = {
            "config": {
                "model_name": "test/model",
                "model_dtype": "torch.float16",
                "model_sha": "abc123"
            },
            "results": {
                "emea_ner": {"f1": 0.85},
                "medline_ner": {"f1": 0.82}
            }
        }

        # Write to temp file
        temp_file = "/tmp/test_result.json"
        with open(temp_file, 'w') as f:
            json.dump(sample_result, f)

        # Test parsing
        result = EvalResult.init_from_json_file(temp_file)
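        # init_from_json_file is expected to split "test/model" into org/model
        # and to use config["model_sha"] as the revision (checked below).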
        assert result.full_model == "test/model", f"Expected test/model, got {result.full_model}"
        assert result.org == "test", f"Expected test, got {result.org}"
        assert result.model == "model", f"Expected model, got {result.model}"
        assert result.revision == "abc123", f"Expected abc123, got {result.revision}"
        print("✓ JSON parsing works correctly")

        # Test with missing fields
        sample_result_minimal = {
            "config": {
                "model": "test/model2"
            },
            "results": {
                "emea_ner": {"f1": 0.75}
            }
        }
        temp_file_minimal = "/tmp/test_result_minimal.json"
        with open(temp_file_minimal, 'w') as f:
            json.dump(sample_result_minimal, f)

        result_minimal = EvalResult.init_from_json_file(temp_file_minimal)
        assert result_minimal.full_model == "test/model2", f"Expected test/model2, got {result_minimal.full_model}"
        print("✓ Minimal JSON parsing works correctly")

        # Clean up
        os.remove(temp_file)
        os.remove(temp_file_minimal)

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False


def test_to_dict():
    """Test that EvalResult.to_dict works correctly."""
    print("Testing EvalResult.to_dict method...")
    try:
        from src.leaderboard.read_evals import EvalResult
        from src.display.utils import Precision, ModelType, WeightType

        # Create a test EvalResult
        eval_result = EvalResult(
            eval_name="test_model_float16",
            full_model="test/model",
            org="test",
            model="model",
            revision="abc123",
            results={"emea_ner": 85.0, "medline_ner": 82.0},
            precision=Precision.float16,
            model_type=ModelType.FT,
            weight_type=WeightType.Original,
            architecture="BertForTokenClassification",
            license="MIT",
            likes=10,
            num_params=110,
            date="2023-01-01",
            still_on_hub=True
        )

        # Test to_dict conversion
        result_dict = eval_result.to_dict()

        # Check that all required fields are present
        assert "average" in result_dict, "Missing average field in dict"
assert result_dict["average"] == 83.5, f"Expected average 83.5, got {result_dict['average']}"
print("βœ“ to_dict method works correctly")
print(f" - Average: {result_dict['average']}")
return True
except Exception as e:
print(f"βœ— Error: {e}")
return False


def main():
    """Run all tests."""
    print("Running bug fix tests...\n")

    tests = [
        test_dataclass_creation,
        test_precision_from_str,
        test_eval_result_parsing,
        test_to_dict,
    ]

    results = []
    for test in tests:
        print(f"\n{'='*50}")
        try:
            result = test()
            results.append(result)
        except Exception as e:
            print(f"✗ Test {test.__name__} failed with exception: {e}")
            results.append(False)

    print(f"\n{'='*50}")
    print(f"Test Results: {sum(results)}/{len(results)} tests passed")

    if all(results):
        print("🎉 All tests passed! The fixes are working correctly.")
        return 0
    else:
        print("❌ Some tests failed. Please check the output above.")
        return 1
if __name__ == "__main__":
sys.exit(main())