#!/usr/bin/env python3
"""Test script to verify the fixes work correctly."""

import sys
import os
import json
import tempfile
from pathlib import Path

# Make sure the repository root is on sys.path so the `src.*` imports below resolve
sys.path.insert(0, str(Path(__file__).parent))
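
# Usage note (the file name below is illustrative; run from the repository root so `src` resolves):
#     python test_fixes.py
# main() returns 0 when every check passes and 1 otherwise, which becomes the exit code.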

def test_dataclass_creation():
    """Test that the AutoEvalColumn dataclass can be created successfully."""
    print("Testing AutoEvalColumn dataclass creation...")
    try:
        from src.display.utils import AutoEvalColumn, fields
        
        # Test that we can access the fields
        all_fields = fields(AutoEvalColumn)
        print(f"βœ“ Successfully created AutoEvalColumn with {len(all_fields)} fields")
        
        # Test that the average field exists
        assert hasattr(AutoEvalColumn, 'average'), "Missing 'average' field"
        print("βœ“ 'average' field exists")
        
        # Test that we can access field names
        field_names = [c.name for c in all_fields]
        assert 'average' in field_names, "Average field not in field names"
        print("βœ“ Average field accessible in field names")
        
        return True
    except Exception as e:
        print(f"βœ— Error: {e}")
        return False
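
# Illustrative helper only (hypothetical, not part of the project): shows how a
# leaderboard column-name list could be derived from the same dataclass fields that
# test_dataclass_creation exercises above.
def _example_column_names():
    """Return the column names exposed by the AutoEvalColumn dataclass."""
    from src.display.utils import AutoEvalColumn, fields
    return [c.name for c in fields(AutoEvalColumn)]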

def test_precision_from_str():
    """Test that the Precision.from_str method works correctly."""
    print("Testing Precision.from_str method...")
    try:
        from src.display.utils import Precision
        
        # Test different precision values
        result1 = Precision.from_str("torch.float16")
        assert result1 == Precision.float16, f"Expected float16, got {result1}"
        print("βœ“ torch.float16 correctly parsed")
        
        result2 = Precision.from_str("float16")
        assert result2 == Precision.float16, f"Expected float16, got {result2}"
        print("βœ“ float16 correctly parsed")
        
        result3 = Precision.from_str("torch.bfloat16")
        assert result3 == Precision.bfloat16, f"Expected bfloat16, got {result3}"
        print("βœ“ torch.bfloat16 correctly parsed")
        
        result4 = Precision.from_str("unknown")
        assert result4 == Precision.Unknown, f"Expected Unknown, got {result4}"
        print("βœ“ Unknown precision correctly parsed")
        
        return True
    except Exception as e:
        print(f"βœ— Error: {e}")
        return False
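
# Hypothetical reference sketch (an assumption, not the project's code): the mapping
# that test_precision_from_str above implies Precision.from_str should perform.
def _assumed_precision_mapping(precision_str):
    """Mirror the four cases asserted above; anything unrecognised maps to Unknown."""
    from src.display.utils import Precision

    if precision_str in ("torch.float16", "float16"):
        return Precision.float16
    if precision_str == "torch.bfloat16":
        return Precision.bfloat16
    return Precision.Unknown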

def test_eval_result_parsing():
    """Test that the EvalResult can parse JSON files correctly."""
    print("Testing EvalResult JSON parsing...")
    try:
        from src.leaderboard.read_evals import EvalResult
        from src.about import Tasks
        
        # Create a sample result file
        sample_result = {
            "config": {
                "model_name": "test/model",
                "model_dtype": "torch.float16",
                "model_sha": "abc123"
            },
            "results": {
                "emea_ner": {"f1": 0.85},
                "medline_ner": {"f1": 0.82}
            }
        }
        
        # Write to temp file
        temp_file = "/tmp/test_result.json"
        with open(temp_file, 'w') as f:
            json.dump(sample_result, f)
        
        # Test parsing
        result = EvalResult.init_from_json_file(temp_file)
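        # (The assertions that follow assume full_model comes from config["model_name"],
        # org/model from splitting it on "/", and revision from config["model_sha"].)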
        
        assert result.full_model == "test/model", f"Expected test/model, got {result.full_model}"
        assert result.org == "test", f"Expected test, got {result.org}"
        assert result.model == "model", f"Expected model, got {result.model}"
        assert result.revision == "abc123", f"Expected abc123, got {result.revision}"
        
        print("βœ“ JSON parsing works correctly")
        
        # Test with missing fields
        sample_result_minimal = {
            "config": {
                "model": "test/model2"
            },
            "results": {
                "emea_ner": {"f1": 0.75}
            }
        }
        
        temp_file_minimal = "/tmp/test_result_minimal.json"
        with open(temp_file_minimal, 'w') as f:
            json.dump(sample_result_minimal, f)
        
        result_minimal = EvalResult.init_from_json_file(temp_file_minimal)
        assert result_minimal.full_model == "test/model2", f"Expected test/model2, got {result_minimal.full_model}"
        print("βœ“ Minimal JSON parsing works correctly")
        
        # Clean up
        os.remove(temp_file)
        os.remove(temp_file_minimal)
        
        return True
    except Exception as e:
        print(f"βœ— Error: {e}")
        return False

def test_to_dict():
    """Test that EvalResult.to_dict works correctly."""
    print("Testing EvalResult.to_dict method...")
    try:
        from src.leaderboard.read_evals import EvalResult
        from src.display.utils import Precision, ModelType, WeightType
        
        # Create a test EvalResult
        eval_result = EvalResult(
            eval_name="test_model_float16",
            full_model="test/model",
            org="test",
            model="model",
            revision="abc123",
            results={"emea_ner": 85.0, "medline_ner": 82.0},
            precision=Precision.float16,
            model_type=ModelType.FT,
            weight_type=WeightType.Original,
            architecture="BertForTokenClassification",
            license="MIT",
            likes=10,
            num_params=110,
            date="2023-01-01",
            still_on_hub=True
        )
        
        # Test to_dict conversion
        result_dict = eval_result.to_dict()
        
        # Check that all required fields are present
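        # Expected average = (85.0 + 82.0) / 2 = 83.5, i.e. the unweighted mean of the
        # two task scores supplied above.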
        assert "average" in result_dict, "Missing average field in dict"
        assert result_dict["average"] == 83.5, f"Expected average 83.5, got {result_dict['average']}"
        
        print("βœ“ to_dict method works correctly")
        print(f"  - Average: {result_dict['average']}")
        
        return True
    except Exception as e:
        print(f"βœ— Error: {e}")
        return False

def main():
    """Run all tests."""
    print("Running bug fix tests...\n")
    
    tests = [
        test_dataclass_creation,
        test_precision_from_str,
        test_eval_result_parsing,
        test_to_dict,
    ]
    
    results = []
    for test in tests:
        print(f"\n{'='*50}")
        try:
            result = test()
            results.append(result)
        except Exception as e:
            print(f"βœ— Test {test.__name__} failed with exception: {e}")
            results.append(False)
    
    print(f"\n{'='*50}")
    print(f"Test Results: {sum(results)}/{len(results)} tests passed")
    
    if all(results):
        print("πŸŽ‰ All tests passed! The fixes are working correctly.")
        return 0
    else:
        print("❌ Some tests failed. Please check the output above.")
        return 1

if __name__ == "__main__":
    sys.exit(main())