sahilvishwa2108's picture
Deploy optimized lightweight oil spill detection API - Runtime model download, memory efficient
ada534f
"""
Lightweight Oil Spill Detection API for HuggingFace Spaces
Simplified version that loads models on-demand to reduce memory usage
"""
import os
import io
import numpy as np
from PIL import Image
from typing import Optional, List
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import base64
from datetime import datetime
import gc
# Configure environment for HuggingFace Spaces
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow logs
app = FastAPI(
title="Oil Spill Detection API",
description="Lightweight API for oil spill detection using deep learning",
version="1.0.0"
)
# CORS middleware for frontend access
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, specify your frontend domain
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Global variables for lazy loading
model1 = None
model2 = None
class PredictionResponse(BaseModel):
success: bool
prediction: Optional[str] = None
confidence: Optional[float] = None
processing_time: Optional[float] = None
model_used: Optional[str] = None
error: Optional[str] = None
class HealthResponse(BaseModel):
status: str
timestamp: str
models_loaded: dict
def lazy_load_model1():
"""Load model1 only when needed"""
global model1
if model1 is None:
try:
import tensorflow as tf
model_path = "models/unet_final_model.h5"
if os.path.exists(model_path):
model1 = tf.keras.models.load_model(model_path)
print("Model 1 (U-Net) loaded successfully")
else:
print(f"Model 1 not found at {model_path}")
except Exception as e:
print(f"Error loading model 1: {e}")
return model1
def lazy_load_model2():
"""Load model2 only when needed"""
global model2
if model2 is None:
try:
import tensorflow as tf
model_path = "models/deeplab_final_model.h5"
if os.path.exists(model_path):
model2 = tf.keras.models.load_model(model_path)
print("Model 2 (DeepLab) loaded successfully")
else:
print(f"Model 2 not found at {model_path}")
except Exception as e:
print(f"Error loading model 2: {e}")
return model2
def preprocess_image(image: Image.Image, target_size=(256, 256)):
"""Preprocess image for model prediction"""
# Convert to RGB if necessary
if image.mode != 'RGB':
image = image.convert('RGB')
# Resize image
image = image.resize(target_size)
# Convert to numpy array
img_array = np.array(image)
# Normalize pixel values
img_array = img_array.astype(np.float32) / 255.0
# Add batch dimension
img_array = np.expand_dims(img_array, axis=0)
return img_array
def predict_oil_spill(image_array, model):
"""Make prediction using the specified model"""
try:
# Make prediction
prediction = model.predict(image_array)
# Process prediction (assuming binary classification)
confidence = float(np.max(prediction))
predicted_class = "Oil Spill Detected" if confidence > 0.5 else "No Oil Spill"
return predicted_class, confidence
except Exception as e:
raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
return HealthResponse(
status="healthy",
timestamp=datetime.now().isoformat(),
models_loaded={
"model1": model1 is not None,
"model2": model2 is not None
}
)
@app.get("/models/info")
async def get_models_info():
"""Get information about available models"""
return {
"models": {
"model1": {
"name": "U-Net",
"description": "U-Net model for semantic segmentation",
"loaded": model1 is not None
},
"model2": {
"name": "DeepLab V3+",
"description": "DeepLab V3+ model for semantic segmentation",
"loaded": model2 is not None
}
}
}
@app.post("/predict", response_model=PredictionResponse)
async def predict(
file: UploadFile = File(...),
model_choice: str = "model1"
):
"""Predict oil spill in uploaded image"""
start_time = datetime.now()
try:
# Validate file type
if not file.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="File must be an image")
# Read and process image
image_data = await file.read()
image = Image.open(io.BytesIO(image_data))
# Preprocess image
processed_image = preprocess_image(image)
# Load appropriate model
if model_choice == "model1":
model = lazy_load_model1()
model_name = "U-Net"
else:
model = lazy_load_model2()
model_name = "DeepLab V3+"
if model is None:
raise HTTPException(
status_code=503,
detail=f"Model {model_choice} is not available"
)
# Make prediction
predicted_class, confidence = predict_oil_spill(processed_image, model)
# Calculate processing time
processing_time = (datetime.now() - start_time).total_seconds()
# Clean up memory
gc.collect()
return PredictionResponse(
success=True,
prediction=predicted_class,
confidence=round(confidence, 4),
processing_time=round(processing_time, 2),
model_used=model_name
)
except HTTPException:
raise
except Exception as e:
return PredictionResponse(
success=False,
error=str(e)
)
@app.get("/")
async def root():
"""Root endpoint"""
return {
"message": "Oil Spill Detection API",
"status": "running",
"endpoints": {
"health": "/health",
"models": "/models/info",
"predict": "/predict",
"docs": "/docs"
}
}
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860))
uvicorn.run(app, host="0.0.0.0", port=port)