File size: 7,245 Bytes
5fc6e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
from unittest.mock import patch

from fastapi.testclient import TestClient
import numpy as np
import pytest

from turing.api.app import app
from turing.api.schemas import PredictionRequest, PredictionResponse


@pytest.fixture
def client():
    """Build a ``TestClient`` bound to the FastAPI ``app`` for use in tests."""
    test_client = TestClient(app)
    return test_client


@pytest.fixture
def mock_inference_engine():
    """Patch ``turing.api.app.inference_engine`` and yield the replacement mock.

    The patch stays active while the requesting test runs and is undone when
    the fixture is torn down.
    """
    engine_patch = patch('turing.api.app.inference_engine')
    with engine_patch as engine_double:
        yield engine_double


class TestHealthCheck:
    """Tests covering the root (``/``) health check route."""

    def test_health_check_returns_ok(self, client):
        """GET ``/`` answers 200 with the expected status payload."""
        expected_body = {
            "status": "ok",
            "message": "Turing Code Classification API is ready.",
        }

        reply = client.get("/")

        assert reply.status_code == 200
        assert reply.json() == expected_body


class TestPredictEndpoint:
    """Tests covering the ``/predict`` route.

    The inference engine is mocked wherever a test needs the endpoint to get
    past request validation; its ``predict_payload`` result is a 4-tuple of
    (raw predictions, labels, run id, model artifact).
    """

    def test_predict_success_java(self, client, mock_inference_engine):
        """A Java request returns 200 with all response fields populated."""
        # Arrange: canned engine output for two Java snippets.
        raw_predictions = np.array([0, 1])
        mock_inference_engine.predict_payload.return_value = (
            raw_predictions,
            ["class", "method"],
            "run_id_123",
            "models:/CodeBERTa_java/Production",
        )
        payload = {
            "texts": ["public class Main", "public void test()"],
            "language": "java",
        }

        # Act
        reply = client.post("/predict", json=payload)

        # Assert: every top-level field is present and carries the mocked data.
        assert reply.status_code == 200
        body = reply.json()
        for expected_key in ("predictions", "labels", "model_info"):
            assert expected_key in body
        assert body["labels"] == ["class", "method"]
        assert body["model_info"]["language"] == "java"

    def test_predict_success_python(self, client, mock_inference_engine):
        """A Python request returns 200 with the mocked labels."""
        # Arrange: canned engine output for two Python snippets.
        raw_predictions = np.array([1, 0])
        mock_inference_engine.predict_payload.return_value = (
            raw_predictions,
            ["function", "class"],
            "run_id_456",
            "models:/CodeBERTa_python/Production",
        )
        payload = {
            "texts": ["def main():", "class MyClass:"],
            "language": "python",
        }

        # Act
        reply = client.post("/predict", json=payload)

        # Assert
        assert reply.status_code == 200
        body = reply.json()
        assert body["labels"] == ["function", "class"]
        assert body["model_info"]["language"] == "python"

    def test_predict_success_pharo(self, client, mock_inference_engine):
        """A Pharo request returns 200 with the mocked labels."""
        # Arrange: canned engine output for a single Pharo snippet.
        raw_predictions = np.array([0])
        mock_inference_engine.predict_payload.return_value = (
            raw_predictions,
            ["method"],
            "run_id_789",
            "models:/CodeBERTa_pharo/Production",
        )
        payload = {
            "texts": ["initialize"],
            "language": "pharo",
        }

        # Act
        reply = client.post("/predict", json=payload)

        # Assert
        assert reply.status_code == 200
        body = reply.json()
        assert body["labels"] == ["method"]
        assert body["model_info"]["language"] == "pharo"

    def test_predict_missing_texts(self, client):
        """Omitting ``texts`` is rejected with a 422 validation error."""
        payload = {"language": "java"}

        reply = client.post("/predict", json=payload)

        assert reply.status_code == 422

    def test_predict_missing_language(self, client):
        """Omitting ``language`` is rejected with a 422 validation error."""
        payload = {"texts": ["public class Main"]}

        reply = client.post("/predict", json=payload)

        assert reply.status_code == 422

    def test_predict_empty_texts(self, client, mock_inference_engine):
        """An empty ``texts`` list is accepted and yields empty results."""
        # Arrange: the engine reports nothing for an empty batch.
        no_predictions = np.array([])
        mock_inference_engine.predict_payload.return_value = (
            no_predictions,
            [],
            "run_id_000",
            "models:/CodeBERTa_java/Production",
        )
        payload = {"texts": [], "language": "java"}

        # Act
        reply = client.post("/predict", json=payload)

        # Assert: the request succeeds and both result lists are empty.
        assert reply.status_code == 200
        body = reply.json()
        assert body["predictions"] == []
        assert body["labels"] == []

    def test_predict_error_handling(self, client, mock_inference_engine):
        """An engine exception surfaces as a 500 whose detail has the message."""
        # Arrange: make the engine blow up when asked to predict.
        failure = Exception("Model loading failed")
        mock_inference_engine.predict_payload.side_effect = failure
        payload = {
            "texts": ["public class Main"],
            "language": "java",
        }

        # Act
        reply = client.post("/predict", json=payload)

        # Assert
        assert reply.status_code == 500
        assert "Model loading failed" in reply.json()["detail"]

    def test_predict_invalid_language(self, client, mock_inference_engine):
        """An unsupported language reported by the engine surfaces as a 500."""
        # Arrange: simulate the engine rejecting the requested language.
        failure = ValueError("Unsupported language: cobol")
        mock_inference_engine.predict_payload.side_effect = failure
        payload = {
            "texts": ["IDENTIFICATION DIVISION."],
            "language": "cobol",
        }

        # Act
        reply = client.post("/predict", json=payload)

        # Assert
        assert reply.status_code == 500
        assert "Unsupported language" in reply.json()["detail"]


class TestAPISchemas:
    """Tests covering construction of the API request and response schemas."""

    def test_prediction_request_valid(self):
        """``PredictionRequest`` keeps the texts and language it was given."""
        fields = {
            "texts": ["public void main"],
            "language": "java",
        }

        request = PredictionRequest(**fields)

        assert request.texts == ["public void main"]
        assert request.language == "java"

    def test_prediction_response_valid(self):
        """``PredictionResponse`` keeps predictions, labels and model info."""
        fields = {
            "predictions": [0, 1],
            "labels": ["class", "method"],
            "model_info": {
                "artifact": "models:/CodeBERTa_java/Production",
                "language": "java",
            },
        }

        response = PredictionResponse(**fields)

        assert response.predictions == [0, 1]
        assert response.labels == ["class", "method"]
        assert response.model_info["language"] == "java"