# superkart_backend / test_api.py
# Last commit 120c870 (verified) by jskswamy: "Uploading files via huggingface api"
"""
Test script for SuperKart Backend API
Simple test script to verify the API endpoints are working correctly.
Run this after starting the Flask application.
"""
import os
from typing import Any, Callable, Dict, List, Tuple

import requests
# Base URL of the SuperKart API under test. Defaults to the local Flask
# server on port 7860; set the SUPERKART_API_URL environment variable to
# target another deployment without editing this file. A trailing slash is
# stripped because every test appends its own path (e.g. "/features").
BASE_URL = os.environ.get(
    "SUPERKART_API_URL", "http://localhost:7860"
).rstrip("/")
def test_health_check(*, timeout: float = 10.0) -> bool:
    """Check that the root endpoint (``GET /``) answers with HTTP 200.

    Prints the status code and the response body. The body is shown as
    parsed JSON when possible and as raw text otherwise, so a non-JSON body
    does not mask the status code.

    Args:
        timeout: Seconds to wait for the server before giving up, so an
            unresponsive server cannot hang the script indefinitely.

    Returns:
        True if the server responded with status 200. False otherwise,
        including connection errors, which are printed rather than raised.
    """
    print("🔍 Testing health check endpoint...")
    try:
        response = requests.get(f"{BASE_URL}/", timeout=timeout)
        print(f"Status: {response.status_code}")
        try:
            body = response.json()
        except ValueError:
            # Non-JSON body (e.g. an HTML error page): show it verbatim
            # instead of letting the decode error turn into a failure.
            body = response.text
        print(f"Response: {body}")
        return response.status_code == 200
    except Exception as e:  # boundary: report and let remaining tests run
        print(f"❌ Health check failed: {e}")
        return False
def test_features_endpoint(*, timeout: float = 10.0) -> bool:
    """Check ``GET /features`` and report how many required features it lists.

    The status code is checked before the body is inspected. An error
    response is therefore reported with its body instead of surfacing as a
    confusing ``KeyError`` on ``required_features``.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the endpoint returned 200 and its JSON body contains a
        ``required_features`` collection. False otherwise.
    """
    print("\n🔍 Testing features endpoint...")
    try:
        response = requests.get(f"{BASE_URL}/features", timeout=timeout)
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            return False
        data = response.json()
        print(f"Required features: {len(data['required_features'])}")
        return True
    except Exception as e:  # boundary: report and let remaining tests run
        print(f"❌ Features endpoint failed: {e}")
        return False
def test_single_prediction(*, timeout: float = 10.0) -> bool:
    """Check ``POST /predict`` with one example product/store record.

    On success, prints the ``predicted_sales`` value from the JSON response
    formatted to two decimals. On an error status, prints the error body:
    parsed JSON if possible, raw text otherwise, so an HTML error page does
    not hide the real failure behind a JSON decode error.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server returned 200 and a formattable
        ``predicted_sales`` value. False otherwise.
    """
    print("\n🔍 Testing single prediction endpoint...")
    # Example input record sent as the JSON request body.
    test_data: Dict[str, Any] = {
        "Product_Weight": 12.66,
        "Product_Sugar_Content": "Low Sugar",
        "Product_Allocated_Area": 0.027,
        "Product_Type": "Frozen Foods",
        "Product_MRP": 117.08,
        "Store_Establishment_Year": 2009,
        "Store_Size": "Medium",
        "Store_Location_City_Type": "Tier 2",
        "Store_Type": "Supermarket Type2",
    }
    try:
        response = requests.post(
            f"{BASE_URL}/predict", json=test_data, timeout=timeout
        )
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            print(f"Error: {detail}")
            return False
        result = response.json()
        print(f"Predicted sales: ${result['predicted_sales']:.2f}")
        return True
    except Exception as e:  # boundary: report and let remaining tests run
        print(f"❌ Single prediction failed: {e}")
        return False
def test_batch_prediction(*, timeout: float = 10.0) -> bool:
    """Check ``POST /predict/batch`` with two example records.

    Prints the ``successful_predictions`` / ``failed_predictions`` fields of
    the response. If the response includes a ``results`` list, it also
    prints one line per entry. An entry whose ``predicted_sales`` is missing
    or non-numeric is printed verbatim rather than passed to ``:.2f``. That
    format spec raises on strings, which would misreport a 200 response as
    a failed test.

    Args:
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the server returned 200 and the summary fields could be
        read. False otherwise.
    """
    print("\n🔍 Testing batch prediction endpoint...")
    # Example batch: two records sent under the "predictions" key.
    batch_data: Dict[str, List[Dict[str, Any]]] = {
        "predictions": [
            {
                "Product_Weight": 12.66,
                "Product_Sugar_Content": "Low Sugar",
                "Product_Allocated_Area": 0.027,
                "Product_Type": "Frozen Foods",
                "Product_MRP": 117.08,
                "Store_Establishment_Year": 2009,
                "Store_Size": "Medium",
                "Store_Location_City_Type": "Tier 2",
                "Store_Type": "Supermarket Type2",
            },
            {
                "Product_Weight": 16.54,
                "Product_Sugar_Content": "Low Sugar",
                "Product_Allocated_Area": 0.144,
                "Product_Type": "Dairy",
                "Product_MRP": 171.43,
                "Store_Establishment_Year": 1999,
                "Store_Size": "Medium",
                "Store_Location_City_Type": "Tier 1",
                "Store_Type": "Departmental Store",
            },
        ],
    }
    try:
        response = requests.post(
            f"{BASE_URL}/predict/batch", json=batch_data, timeout=timeout
        )
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            try:
                detail = response.json()
            except ValueError:
                detail = response.text
            print(f"Error: {detail}")
            return False
        result = response.json()
        print(f"Successful predictions: {result['successful_predictions']}")
        print(f"Failed predictions: {result['failed_predictions']}")
        # Print each returned prediction, tolerating entries without a
        # numeric value (e.g. rows the server reports as failed).
        for pred in result.get("results", []):
            idx = pred.get("index", "?")
            val = pred.get("predicted_sales")
            if isinstance(val, (int, float)):
                shown = f"${val:.2f}"
            else:
                shown = "?" if val is None else str(val)
            print(f"Prediction {idx}: {shown}")
        return True
    except Exception as e:  # boundary: report and let remaining tests run
        print(f"❌ Batch prediction failed: {e}")
        return False
def run_all_tests() -> bool:
    """Run every API smoke test in order and print a pass/fail summary.

    Each test function prints its own progress and returns a bool. This
    runner prints PASSED/FAILED after each one, then a summary that names
    any failing tests.

    Returns:
        True if every test passed, False otherwise.
    """
    print("🚀 Starting SuperKart API Tests\n")
    tests: List[Tuple[str, Callable[[], bool]]] = [
        ("Health Check", test_health_check),
        ("Features Endpoint", test_features_endpoint),
        ("Single Prediction", test_single_prediction),
        ("Batch Prediction", test_batch_prediction),
    ]
    results: List[Tuple[str, bool]] = []
    for test_name, test_func in tests:
        result = test_func()
        results.append((test_name, result))
        print("✅ PASSED" if result else "❌ FAILED")

    print("\n📊 Test Summary:")
    failed = [name for name, ok in results if not ok]
    total = len(results)
    passed = total - len(failed)
    print(f"Passed: {passed}/{total}")
    if not failed:
        print("🎉 All tests passed!")
        return True
    print(f"Failed: {', '.join(failed)}")
    print("⚠️ Some tests failed. Check the API server.")
    return False


if __name__ == "__main__":
    # Exit non-zero on any failure so shells and CI pipelines can detect it.
    raise SystemExit(0 if run_all_tests() else 1)