Deploy optimized lightweight oil spill detection API - Runtime model download, memory efficient
ada534f
| """ | |
| Lightweight Oil Spill Detection API for HuggingFace Spaces | |
| Simplified version that loads models on-demand to reduce memory usage | |
| """ | |
| import os | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| from typing import Optional, List | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| import base64 | |
| from datetime import datetime | |
| import gc | |
| # Configure environment for HuggingFace Spaces | |
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Suppress TensorFlow logs | |
app = FastAPI(
    title="Oil Spill Detection API",
    description="Lightweight API for oil spill detection using deep learning",
    version="1.0.0"
)

# CORS middleware for frontend access
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for lazy loading
model1 = None
model2 = None


class PredictionResponse(BaseModel):
    success: bool
    prediction: Optional[str] = None
    confidence: Optional[float] = None
    processing_time: Optional[float] = None
    model_used: Optional[str] = None
    error: Optional[str] = None


class HealthResponse(BaseModel):
    status: str
    timestamp: str
    models_loaded: dict
def lazy_load_model1():
    """Load model1 only when needed"""
    global model1
    if model1 is None:
        try:
            import tensorflow as tf
            model_path = "models/unet_final_model.h5"
            if os.path.exists(model_path):
                model1 = tf.keras.models.load_model(model_path)
                print("Model 1 (U-Net) loaded successfully")
            else:
                print(f"Model 1 not found at {model_path}")
        except Exception as e:
            print(f"Error loading model 1: {e}")
    return model1
def lazy_load_model2():
    """Load model2 only when needed"""
    global model2
    if model2 is None:
        try:
            import tensorflow as tf
            model_path = "models/deeplab_final_model.h5"
            if os.path.exists(model_path):
                model2 = tf.keras.models.load_model(model_path)
                print("Model 2 (DeepLab) loaded successfully")
            else:
                print(f"Model 2 not found at {model_path}")
        except Exception as e:
            print(f"Error loading model 2: {e}")
    return model2
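
# The commit title mentions runtime model download, while the loaders above
# only read from the local "models/" directory. A minimal sketch of fetching
# the weights at runtime instead, assuming the .h5 files live in a
# (hypothetical) Hugging Face Hub repo "your-username/oil-spill-models":
def ensure_model_file(filename: str) -> str:
    """Download a weights file on first use and return its cached local path."""
    from huggingface_hub import hf_hub_download
    return hf_hub_download(
        repo_id="your-username/oil-spill-models",  # hypothetical repo id
        filename=filename,
    )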
def preprocess_image(image: Image.Image, target_size=(256, 256)):
    """Preprocess image for model prediction"""
    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Resize image
    image = image.resize(target_size)
    # Convert to numpy array
    img_array = np.array(image)
    # Normalize pixel values
    img_array = img_array.astype(np.float32) / 255.0
    # Add batch dimension
    img_array = np.expand_dims(img_array, axis=0)
    return img_array
def predict_oil_spill(image_array, model):
    """Make prediction using the specified model"""
    try:
        # Make prediction
        prediction = model.predict(image_array)
        # Both models output a per-pixel segmentation mask; the maximum pixel
        # probability is used as the overall detection confidence
        confidence = float(np.max(prediction))
        predicted_class = "Oil Spill Detected" if confidence > 0.5 else "No Oil Spill"
        return predicted_class, confidence
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
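
# Because the masks are per-pixel, an area-based decision can be more robust
# than a single max-probability pixel. A sketch, assuming a (1, H, W, 1)
# sigmoid mask output (the 0.5 threshold mirrors the check above):
def spill_fraction(prediction: np.ndarray, threshold: float = 0.5) -> float:
    """Return the fraction of pixels classified as oil spill."""
    mask = prediction.squeeze() > threshold
    return float(np.mean(mask))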
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        timestamp=datetime.now().isoformat(),
        models_loaded={
            "model1": model1 is not None,
            "model2": model2 is not None
        }
    )
@app.get("/models/info")
async def get_models_info():
    """Get information about available models"""
    return {
        "models": {
            "model1": {
                "name": "U-Net",
                "description": "U-Net model for semantic segmentation",
                "loaded": model1 is not None
            },
            "model2": {
                "name": "DeepLab V3+",
                "description": "DeepLab V3+ model for semantic segmentation",
                "loaded": model2 is not None
            }
        }
    }
@app.post("/predict", response_model=PredictionResponse)
async def predict(
    file: UploadFile = File(...),
    model_choice: str = "model1"
):
    """Predict oil spill in uploaded image"""
    start_time = datetime.now()
    try:
        # Validate file type (content_type may be None for some clients)
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")
        # Read and process image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data))
        # Preprocess image
        processed_image = preprocess_image(image)
        # Load appropriate model
        if model_choice == "model1":
            model = lazy_load_model1()
            model_name = "U-Net"
        else:
            model = lazy_load_model2()
            model_name = "DeepLab V3+"
        if model is None:
            raise HTTPException(
                status_code=503,
                detail=f"Model {model_choice} is not available"
            )
        # Make prediction
        predicted_class, confidence = predict_oil_spill(processed_image, model)
        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()
        # Clean up memory
        gc.collect()
        return PredictionResponse(
            success=True,
            prediction=predicted_class,
            confidence=round(confidence, 4),
            processing_time=round(processing_time, 2),
            model_used=model_name
        )
    except HTTPException:
        raise
    except Exception as e:
        return PredictionResponse(
            success=False,
            error=str(e)
        )
@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "message": "Oil Spill Detection API",
        "status": "running",
        "endpoints": {
            "health": "/health",
            "models": "/models/info",
            "predict": "/predict",
            "docs": "/docs"
        }
    }
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
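
For reference, a minimal client call against a running instance; a sketch assuming the server is reachable at http://localhost:7860, the `requests` package is installed, and "sample.jpg" is a placeholder path:

    import requests

    with open("sample.jpg", "rb") as f:  # hypothetical test image
        resp = requests.post(
            "http://localhost:7860/predict",
            files={"file": ("sample.jpg", f, "image/jpeg")},
            params={"model_choice": "model1"},  # "model1" (U-Net) or "model2" (DeepLab V3+)
        )
    print(resp.json())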