"""
Unit tests for the APIClient class.
"""
import json
from unittest.mock import MagicMock, Mock, patch

import pytest
import requests

from api_client import APIClient
class TestAPIClient:
    """Test cases for the APIClient class.

    Outbound HTTP and socket calls are replaced with mocks via
    ``unittest.mock.patch`` so no test contacts a real server.
    """

    # NOTE(review): the tests below receive mock arguments (mock_post,
    # mock_get, mock_socket) but had no @patch decorators, so pytest would
    # try to resolve them as fixtures and fail with "fixture not found".
    # These targets assume api_client calls requests.post / requests.get and
    # socket.create_connection through module attribute access -- confirm
    # against api_client.py and adjust these three strings if it differs.
    _REQUESTS_POST = "api_client.requests.post"
    _REQUESTS_GET = "api_client.requests.get"
    _SOCKET_CONNECT = "api_client.socket.create_connection"

    def setup_method(self):
        """Create a fresh client before each test method."""
        self.base_url = "http://localhost:8000"
        self.client = APIClient(self.base_url, timeout=30)

    def test_init(self):
        """Test APIClient initialization"""
        client = APIClient("http://localhost:8000/", timeout=45)
        assert client.base_url == "http://localhost:8000"  # Trailing slash removed
        assert client.timeout == 45
        assert client.endpoints['process_text'] == '/api/process-text'
        assert client.endpoints['health'] == '/api/health'

    @patch(_REQUESTS_POST)
    def test_send_query_success_with_data(self, mock_post):
        """Test successful query with data response"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "sql": "SELECT * FROM employees",
            "rows": [{"id": 1, "name": "John Doe"}],
            "heading": "Employee Details",
            "summary": "List of employees",
            "chart": None,
        }
        mock_post.return_value = mock_response

        result = self.client.send_query("Show me all employees")

        assert result["error"] is False
        assert result["message"] == "Employee Details"
        assert len(result["rows"]) == 1
        assert result["rows"][0]["name"] == "John Doe"
        assert result["sql"] == "SELECT * FROM employees"

    @patch(_REQUESTS_POST)
    def test_send_query_success_with_message_only(self, mock_post):
        """Test successful query with message-only response"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "message": "I understand your question about banking services.",
            "model_used": "gpt-4",
        }
        mock_post.return_value = mock_response

        result = self.client.send_query("What services do you offer?")

        assert result["error"] is False
        assert result["message"] == "I understand your question about banking services."
        assert result["rows"] == []
        assert result["model_used"] == "gpt-4"
        assert result["status"] == "message"

    @patch(_REQUESTS_POST)
    def test_send_query_with_model_parameter(self, mock_post):
        """Test query with specific model parameter"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"message": "Success"}
        mock_post.return_value = mock_response

        self.client.send_query("Test question", model="gpt-4-turbo")

        # Verify the JSON payload sent to the server carries model_name.
        _, kwargs = mock_post.call_args
        payload = kwargs['json']
        assert payload["question"] == "Test question"
        assert payload["model_name"] == "gpt-4-turbo"

    @patch(_REQUESTS_POST)
    def test_send_query_connection_error(self, mock_post):
        """Test handling of connection errors"""
        mock_post.side_effect = requests.exceptions.ConnectionError("Connection failed")

        result = self.client.send_query("Test question")

        assert result["error"] is True
        assert "Cannot connect to the server" in result["message"]
        assert result["rows"] == []

    @patch(_REQUESTS_POST)
    def test_send_query_request_exception(self, mock_post):
        """Test handling of general request exceptions"""
        mock_post.side_effect = requests.exceptions.RequestException("Request failed")

        result = self.client.send_query("Test question")

        assert result["error"] is True
        assert "Request failed" in result["message"]
        assert result["rows"] == []

    @patch(_REQUESTS_POST)
    def test_send_query_unexpected_exception(self, mock_post):
        """Test handling of unexpected exceptions"""
        mock_post.side_effect = ValueError("Unexpected error")

        result = self.client.send_query("Test question")

        assert result["error"] is True
        assert "Exception" in result["message"]
        assert "Unexpected error" in result["message"]

    def test_process_response_success_with_data(self):
        """Test _process_response with successful data response"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "sql": "SELECT * FROM branches",
            "rows": [{"id": 1, "name": "Main Branch"}, {"id": 2, "name": "Downtown"}],
            "heading": "Branch Information",
            "summary": "List of all branches",
        }

        result = self.client._process_response(mock_response)

        assert result["error"] is False
        assert result["message"] == "Branch Information"
        assert len(result["rows"]) == 2
        assert result["heading"] == "Branch Information"

    def test_process_response_with_empty_heading(self):
        """Test _process_response with empty heading"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "rows": [{"id": 1, "name": "Test"}],
            "heading": "",
            "summary": "",
        }

        result = self.client._process_response(mock_response)

        assert result["message"] == "Here are the 1 results I found:"

    def test_process_response_no_data_no_heading(self):
        """Test _process_response with no data and no heading"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "rows": [],
            "heading": "",
            "summary": "",
        }

        result = self.client._process_response(mock_response)

        assert result["message"] == "I could not find matching records for your query."

    def test_process_response_error_status(self):
        """Test _process_response with error status code"""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_response.json.return_value = {"detail": "Bad request error"}

        result = self.client._process_response(mock_response)

        assert result["error"] is True
        assert "Error: Bad request error" in result["message"]

    def test_process_response_invalid_json(self):
        """Test _process_response with invalid JSON"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.side_effect = ValueError("Invalid JSON")
        mock_response.text = "Invalid response text"

        result = self.client._process_response(mock_response)

        # Should handle the JSON parsing error gracefully (no exception raised).
        assert "detail" in result or "message" in result

    @patch(_REQUESTS_GET)
    def test_check_health_healthy(self, mock_get):
        """Test health check with healthy status"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"status": "healthy"}
        mock_get.return_value = mock_response

        status, message, level = self.client.check_health()

        assert status == "🟢 Active"
        assert message == "Online"
        assert level == "success"

    @patch(_REQUESTS_GET)
    def test_check_health_degraded(self, mock_get):
        """Test health check with degraded status"""
        mock_response = Mock()
        mock_response.status_code = 503
        mock_response.json.return_value = {"status": "degraded"}
        mock_get.return_value = mock_response

        status, message, level = self.client.check_health()

        assert status == "🟡 Degraded"
        assert message == "Some Issues"
        assert level == "warning"

    # Stacked patches apply bottom-up: the socket mock is the first argument.
    @patch(_REQUESTS_GET)
    @patch(_SOCKET_CONNECT)
    def test_check_health_connection_error_with_socket_fallback(self, mock_socket, mock_get):
        """Test health check with connection error but socket reachable"""
        mock_get.side_effect = requests.exceptions.RequestException("Connection failed")
        mock_socket.return_value.close.return_value = None

        status, message, level = self.client.check_health()

        assert status == "🟡 Reachable"
        assert message == "Port Open"
        assert level == "warning"

    # Stacked patches apply bottom-up: the socket mock is the first argument.
    @patch(_REQUESTS_GET)
    @patch(_SOCKET_CONNECT)
    def test_check_health_completely_offline(self, mock_socket, mock_get):
        """Test health check when completely offline"""
        mock_get.side_effect = requests.exceptions.RequestException("Connection failed")
        mock_socket.side_effect = Exception("Socket connection failed")

        status, message, level = self.client.check_health()

        assert status == "🔴 Offline"
        assert message == "Connection Failed"
        assert level == "error"

    @patch(_REQUESTS_GET)
    def test_get_detailed_health_success(self, mock_get):
        """Test get_detailed_health with successful response"""
        expected_health = {
            "status": "healthy",
            "checks": {"database": "ok", "api": "ok"},
        }
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = expected_health
        mock_get.return_value = mock_response

        result = self.client.get_detailed_health()

        assert result == expected_health

    @patch(_REQUESTS_GET)
    def test_get_detailed_health_error(self, mock_get):
        """Test get_detailed_health with error response"""
        mock_response = Mock()
        mock_response.status_code = 500
        mock_get.return_value = mock_response

        result = self.client.get_detailed_health()

        assert result["status"] == "error"
        assert "500" in result["message"]

    @patch(_REQUESTS_GET)
    def test_get_method_success(self, mock_get):
        """Test generic GET method with success"""
        expected_data = {"models": ["gpt-4", "gpt-3.5"]}
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = expected_data
        mock_get.return_value = mock_response

        result = self.client.get("/models")

        assert result == expected_data

    @patch(_REQUESTS_GET)
    def test_get_method_error(self, mock_get):
        """Test generic GET method with error"""
        mock_response = Mock()
        mock_response.status_code = 404
        mock_get.return_value = mock_response

        result = self.client.get("/models")

        assert result is None

    @patch(_REQUESTS_POST)
    def test_post_method_success(self, mock_post):
        """Test generic POST method with success"""
        expected_data = {"result": "success"}
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = expected_data
        mock_post.return_value = mock_response

        result = self.client.post("/change-model", {"model": "gpt-4"})

        assert result == expected_data

    @patch(_REQUESTS_POST)
    def test_post_method_error(self, mock_post):
        """Test generic POST method with error"""
        mock_response = Mock()
        mock_response.status_code = 400
        mock_post.return_value = mock_response

        result = self.client.post("/change-model", {"model": "invalid"})

        assert result is None

    def test_url_construction(self):
        """Test URL construction for different endpoints"""
        client = APIClient("http://localhost:8000/")

        # Trailing slash is removed from the base URL.
        assert client.base_url == "http://localhost:8000"

        # Endpoint paths join cleanly onto the normalised base URL.
        process_url = f"{client.base_url}{client.endpoints['process_text']}"
        assert process_url == "http://localhost:8000/api/process-text"

    @patch(_REQUESTS_POST)
    def test_send_query_with_legacy_agent_parameter(self, mock_post):
        """Test that legacy agent parameter is handled correctly"""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"message": "Success"}
        mock_post.return_value = mock_response

        # The agent argument is accepted but must not be forwarded in the payload.
        self.client.send_query("Test", model="gpt-4", agent="legacy_agent")

        _, kwargs = mock_post.call_args
        payload = kwargs['json']
        assert "agent" not in payload
        assert payload["model_name"] == "gpt-4"
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file directly exits
    # non-zero when any test fails; the bare pytest.main(...) call discarded it.
    raise SystemExit(pytest.main([__file__, "-v"]))