# AISqlGeneratorApp / src/test_api_client.py
# Author: Vivek0912 — "added new code" (commit cce43bc)
"""
Unit tests for APIClient class
"""
import pytest
import requests
import json
from unittest.mock import Mock, patch, MagicMock
from api_client import APIClient
class TestAPIClient:
    """Test cases for the APIClient class.

    Each test patches the ``requests`` / ``socket`` function it exercises, so
    the client never opens a real network connection.
    """

    @staticmethod
    def _make_response(status_code=200, json_data=None):
        """Return a Mock standing in for a ``requests.Response``.

        Args:
            status_code: value exposed as ``response.status_code``.
            json_data: value returned by ``response.json()``. When None,
                ``json`` is left as an unconfigured Mock method (same as a
                bare ``Mock()``).
        """
        response = Mock()
        response.status_code = status_code
        if json_data is not None:
            response.json.return_value = json_data
        return response

    def setup_method(self):
        """Create a fresh client before each test method."""
        self.base_url = "http://localhost:8000"
        self.client = APIClient(self.base_url, timeout=30)

    def test_init(self):
        """Test APIClient initialization"""
        client = APIClient("http://localhost:8000/", timeout=45)
        assert client.base_url == "http://localhost:8000"  # Trailing slash removed
        assert client.timeout == 45
        assert client.endpoints['process_text'] == '/api/process-text'
        assert client.endpoints['health'] == '/api/health'

    @patch('requests.post')
    def test_send_query_success_with_data(self, mock_post):
        """Test successful query with data response"""
        mock_post.return_value = self._make_response(200, {
            "sql": "SELECT * FROM employees",
            "rows": [{"id": 1, "name": "John Doe"}],
            "heading": "Employee Details",
            "summary": "List of employees",
            "chart": None,
        })
        result = self.client.send_query("Show me all employees")
        assert result["error"] is False
        assert result["message"] == "Employee Details"
        assert len(result["rows"]) == 1
        assert result["rows"][0]["name"] == "John Doe"
        assert result["sql"] == "SELECT * FROM employees"

    @patch('requests.post')
    def test_send_query_success_with_message_only(self, mock_post):
        """Test successful query with message-only response"""
        mock_post.return_value = self._make_response(200, {
            "message": "I understand your question about banking services.",
            "model_used": "gpt-4",
        })
        result = self.client.send_query("What services do you offer?")
        assert result["error"] is False
        assert result["message"] == "I understand your question about banking services."
        assert result["rows"] == []
        assert result["model_used"] == "gpt-4"
        assert result["status"] == "message"

    @patch('requests.post')
    def test_send_query_with_model_parameter(self, mock_post):
        """Test query with specific model parameter"""
        mock_post.return_value = self._make_response(200, {"message": "Success"})
        self.client.send_query("Test question", model="gpt-4-turbo")
        # Verify the payload (the ``json=`` keyword) includes model_name
        _, kwargs = mock_post.call_args
        payload = kwargs['json']
        assert payload["question"] == "Test question"
        assert payload["model_name"] == "gpt-4-turbo"

    @patch('requests.post')
    def test_send_query_connection_error(self, mock_post):
        """Test handling of connection errors"""
        mock_post.side_effect = requests.exceptions.ConnectionError("Connection failed")
        result = self.client.send_query("Test question")
        assert result["error"] is True
        assert "Cannot connect to the server" in result["message"]
        assert result["rows"] == []

    @patch('requests.post')
    def test_send_query_request_exception(self, mock_post):
        """Test handling of general request exceptions"""
        mock_post.side_effect = requests.exceptions.RequestException("Request failed")
        result = self.client.send_query("Test question")
        assert result["error"] is True
        assert "Request failed" in result["message"]
        assert result["rows"] == []

    @patch('requests.post')
    def test_send_query_unexpected_exception(self, mock_post):
        """Test handling of unexpected exceptions"""
        mock_post.side_effect = ValueError("Unexpected error")
        result = self.client.send_query("Test question")
        assert result["error"] is True
        assert "Exception" in result["message"]
        assert "Unexpected error" in result["message"]

    def test_process_response_success_with_data(self):
        """Test _process_response with successful data response"""
        response = self._make_response(200, {
            "sql": "SELECT * FROM branches",
            "rows": [{"id": 1, "name": "Main Branch"}, {"id": 2, "name": "Downtown"}],
            "heading": "Branch Information",
            "summary": "List of all branches",
        })
        result = self.client._process_response(response)
        assert result["error"] is False
        assert result["message"] == "Branch Information"
        assert len(result["rows"]) == 2
        assert result["heading"] == "Branch Information"

    def test_process_response_with_empty_heading(self):
        """Test _process_response with empty heading"""
        response = self._make_response(200, {
            "rows": [{"id": 1, "name": "Test"}],
            "heading": "",
            "summary": "",
        })
        result = self.client._process_response(response)
        assert result["message"] == "Here are the 1 results I found:"

    def test_process_response_no_data_no_heading(self):
        """Test _process_response with no data and no heading"""
        response = self._make_response(200, {
            "rows": [],
            "heading": "",
            "summary": "",
        })
        result = self.client._process_response(response)
        assert result["message"] == "I could not find matching records for your query."

    def test_process_response_error_status(self):
        """Test _process_response with error status code"""
        response = self._make_response(400, {"detail": "Bad request error"})
        result = self.client._process_response(response)
        assert result["error"] is True
        assert "Error: Bad request error" in result["message"]

    def test_process_response_invalid_json(self):
        """Test _process_response with invalid JSON"""
        response = self._make_response(200)
        response.json.side_effect = ValueError("Invalid JSON")
        response.text = "Invalid response text"
        result = self.client._process_response(response)
        # Should handle the JSON parsing error gracefully.
        # NOTE(review): this only checks that one of the two keys is present;
        # tighten once the fallback shape of _process_response is pinned down.
        assert "detail" in result or "message" in result

    @patch('requests.get')
    def test_check_health_healthy(self, mock_get):
        """Test health check with healthy status"""
        # A real healthy response carries HTTP 200; without it status_code
        # would be an auto-created Mock attribute.
        mock_get.return_value = self._make_response(200, {"status": "healthy"})
        status, message, level = self.client.check_health()
        assert status == "🟢 Active"
        assert message == "Online"
        assert level == "success"

    @patch('requests.get')
    def test_check_health_degraded(self, mock_get):
        """Test health check with degraded status"""
        mock_get.return_value = self._make_response(503, {"status": "degraded"})
        status, message, level = self.client.check_health()
        assert status == "🟡 Degraded"
        assert message == "Some Issues"
        assert level == "warning"

    # Stacked @patch decorators apply bottom-up: the innermost patch
    # (socket.create_connection) is passed first.
    @patch('requests.get')
    @patch('socket.create_connection')
    def test_check_health_connection_error_with_socket_fallback(self, mock_socket, mock_get):
        """Test health check with connection error but socket reachable"""
        mock_get.side_effect = requests.exceptions.RequestException("Connection failed")
        mock_socket.return_value.close.return_value = None
        status, message, level = self.client.check_health()
        assert status == "🟡 Reachable"
        assert message == "Port Open"
        assert level == "warning"

    @patch('requests.get')
    @patch('socket.create_connection')
    def test_check_health_completely_offline(self, mock_socket, mock_get):
        """Test health check when completely offline"""
        mock_get.side_effect = requests.exceptions.RequestException("Connection failed")
        # socket.create_connection signals failure with OSError; using it keeps
        # the test valid whether the client catches OSError or Exception.
        mock_socket.side_effect = OSError("Socket connection failed")
        status, message, level = self.client.check_health()
        assert status == "🔴 Offline"
        assert message == "Connection Failed"
        assert level == "error"

    @patch('requests.get')
    def test_get_detailed_health_success(self, mock_get):
        """Test get_detailed_health with successful response"""
        expected_health = {
            "status": "healthy",
            "checks": {"database": "ok", "api": "ok"},
        }
        mock_get.return_value = self._make_response(200, expected_health)
        result = self.client.get_detailed_health()
        assert result == expected_health

    @patch('requests.get')
    def test_get_detailed_health_error(self, mock_get):
        """Test get_detailed_health with error response"""
        mock_get.return_value = self._make_response(500)
        result = self.client.get_detailed_health()
        assert result["status"] == "error"
        assert "500" in result["message"]

    @patch('requests.get')
    def test_get_method_success(self, mock_get):
        """Test generic GET method with success"""
        expected_data = {"models": ["gpt-4", "gpt-3.5"]}
        mock_get.return_value = self._make_response(200, expected_data)
        result = self.client.get("/models")
        assert result == expected_data

    @patch('requests.get')
    def test_get_method_error(self, mock_get):
        """Test generic GET method with error"""
        mock_get.return_value = self._make_response(404)
        result = self.client.get("/models")
        assert result is None

    @patch('requests.post')
    def test_post_method_success(self, mock_post):
        """Test generic POST method with success"""
        expected_data = {"result": "success"}
        mock_post.return_value = self._make_response(200, expected_data)
        result = self.client.post("/change-model", {"model": "gpt-4"})
        assert result == expected_data

    @patch('requests.post')
    def test_post_method_error(self, mock_post):
        """Test generic POST method with error"""
        mock_post.return_value = self._make_response(400)
        result = self.client.post("/change-model", {"model": "invalid"})
        assert result is None

    def test_url_construction(self):
        """Test URL construction for different endpoints"""
        client = APIClient("http://localhost:8000/")
        # Test that trailing slash is removed
        assert client.base_url == "http://localhost:8000"
        # Test endpoint URLs
        process_url = f"{client.base_url}{client.endpoints['process_text']}"
        assert process_url == "http://localhost:8000/api/process-text"

    @patch('requests.post')
    def test_send_query_with_legacy_agent_parameter(self, mock_post):
        """Test that legacy agent parameter is handled correctly"""
        mock_post.return_value = self._make_response(200, {"message": "Success"})
        # Agent parameter should be ignored in payload
        self.client.send_query("Test", model="gpt-4", agent="legacy_agent")
        _, kwargs = mock_post.call_args
        payload = kwargs['json']
        assert "agent" not in payload  # Agent should not be in payload
        assert payload["model_name"] == "gpt-4"
if __name__ == "__main__":
    # pytest.main() returns the session exit code; pass it on so that running
    # this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))