from unittest.mock import patch

from fastapi.testclient import TestClient
import numpy as np
import pytest

from turing.api.app import app
from turing.api.schemas import PredictionRequest, PredictionResponse


@pytest.fixture
def client():
    """Fixture that provides a test client for the FastAPI app."""
    return TestClient(app)


@pytest.fixture
def mock_inference_engine():
    """Fixture that provides a mocked inference engine."""
    with patch('turing.api.app.inference_engine') as mock:
        yield mock

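
# The patch target above assumes `turing.api.app` exposes a module-level
# `inference_engine` attribute (patching would fail otherwise). A hypothetical
# sanity check for that assumption, not part of the original suite:
def test_app_module_exposes_inference_engine():
    """The mocking strategy only works if app.py defines `inference_engine` at module level."""
    import turing.api.app as app_module

    assert hasattr(app_module, "inference_engine")
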

class TestHealthCheck:
    """Test suite for the health check endpoint."""

    def test_health_check_returns_ok(self, client):
        """Test that the health check endpoint returns status ok."""
        response = client.get("/")

        assert response.status_code == 200
        assert response.json() == {
            "status": "ok",
            "message": "Turing Code Classification API is ready."
        }


class TestPredictEndpoint:
    """Test suite for the predict endpoint."""

    def test_predict_success_java(self, client, mock_inference_engine):
        """Test successful prediction for Java code."""
        # Setup mock
        mock_inference_engine.predict_payload.return_value = (
            np.array([0, 1]),  # raw predictions as numpy array
            ["class", "method"],  # labels
            "run_id_123",  # run_id
            "models:/CodeBERTa_java/Production"  # artifact
        )

        # Make request
        request_data = {
            "texts": ["public class Main", "public void test()"],
            "language": "java"
        }
        response = client.post("/predict", json=request_data)

        # Assertions
        assert response.status_code == 200
        data = response.json()
        assert "predictions" in data
        assert "labels" in data
        assert "model_info" in data
        assert data["labels"] == ["class", "method"]
        assert data["model_info"]["language"] == "java"

    def test_predict_success_python(self, client, mock_inference_engine):
        """Test successful prediction for Python code."""
        # Setup mock
        mock_inference_engine.predict_payload.return_value = (
            np.array([1, 0]),  # raw predictions as numpy array
            ["function", "class"],  # labels
            "run_id_456",  # run_id
            "models:/CodeBERTa_python/Production"  # artifact
        )

        # Make request
        request_data = {
            "texts": ["def main():", "class MyClass:"],
            "language": "python"
        }
        response = client.post("/predict", json=request_data)

        # Assertions
        assert response.status_code == 200
        data = response.json()
        assert data["labels"] == ["function", "class"]
        assert data["model_info"]["language"] == "python"

    def test_predict_success_pharo(self, client, mock_inference_engine):
        """Test successful prediction for Pharo code."""
        # Setup mock
        mock_inference_engine.predict_payload.return_value = (
            np.array([0]),  # raw predictions as numpy array
            ["method"],  # labels
            "run_id_789",  # run_id
            "models:/CodeBERTa_pharo/Production"  # artifact
        )

        # Make request
        request_data = {
            "texts": ["initialize"],
            "language": "pharo"
        }
        response = client.post("/predict", json=request_data)

        # Assertions
        assert response.status_code == 200
        data = response.json()
        assert data["labels"] == ["method"]
        assert data["model_info"]["language"] == "pharo"

    def test_predict_missing_texts(self, client):
        """Test that prediction fails when texts are missing."""
        request_data = {
            "language": "java"
        }
        response = client.post("/predict", json=request_data)

        assert response.status_code == 422  # Validation error

    def test_predict_missing_language(self, client):
        """Test that prediction fails when language is missing."""
        request_data = {
            "texts": ["public class Main"]
        }
        response = client.post("/predict", json=request_data)

        assert response.status_code == 422  # Validation error

    def test_predict_empty_texts(self, client, mock_inference_engine):
        """Test prediction with empty texts list."""
        mock_inference_engine.predict_payload.return_value = (
            np.array([]),  # raw predictions as empty numpy array
            [],  # labels
            "run_id_000",  # run_id
            "models:/CodeBERTa_java/Production"  # artifact
        )
        request_data = {
            "texts": [],
            "language": "java"
        }
        response = client.post("/predict", json=request_data)

        # Should succeed with empty results
        assert response.status_code == 200
        data = response.json()
        assert data["predictions"] == []
        assert data["labels"] == []

    def test_predict_error_handling(self, client, mock_inference_engine):
        """Test that prediction endpoint handles errors gracefully."""
        # Setup mock to raise an exception
        mock_inference_engine.predict_payload.side_effect = Exception("Model loading failed")
        request_data = {
            "texts": ["public class Main"],
            "language": "java"
        }
        response = client.post("/predict", json=request_data)

        # Should return 500 error
        assert response.status_code == 500
        assert "Model loading failed" in response.json()["detail"]

    def test_predict_invalid_language(self, client, mock_inference_engine):
        """Test prediction with invalid language parameter."""
        # The model might raise an error for unsupported language
        mock_inference_engine.predict_payload.side_effect = ValueError("Unsupported language: cobol")
        request_data = {
            "texts": ["IDENTIFICATION DIVISION."],
            "language": "cobol"
        }
        response = client.post("/predict", json=request_data)

        # Should return 500 error
        assert response.status_code == 500
        assert "Unsupported language" in response.json()["detail"]

class TestAPISchemas:
    """Test suite for API schemas validation."""

    def test_prediction_request_valid(self):
        """Test that PredictionRequest validates correct data."""
        request = PredictionRequest(
            texts=["public void main"],
            language="java"
        )

        assert request.texts == ["public void main"]
        assert request.language == "java"

    def test_prediction_response_valid(self):
        """Test that PredictionResponse validates correct data."""
        response = PredictionResponse(
            predictions=[0, 1],
            labels=["class", "method"],
            model_info={"artifact": "models:/CodeBERTa_java/Production", "language": "java"}
        )

        assert response.predictions == [0, 1]
        assert response.labels == ["class", "method"]
        assert response.model_info["language"] == "java"
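
    # Hedged addition: test_predict_missing_language above expects a 422, which implies
    # `language` is a required field on PredictionRequest. Assuming the schemas are
    # Pydantic models (consistent with FastAPI's 422 validation responses), the same
    # constraint can be checked directly on the schema:
    def test_prediction_request_missing_language_raises(self):
        """Sketch: omitting the required `language` field raises a validation error."""
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            PredictionRequest(texts=["public class Main"])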