"""
Production Defect Detection Application

Supports three run modes:
  * web - Streamlit interface (default)
  * api - FastAPI server (also exported as module-level ``app`` on Vercel)
  * cli - single-image or batch prediction from the command line
"""
import os
import io
import sys
import base64
import json
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import argparse

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Model imports
from models.vision_transformer import get_model

# Global model cache, populated once by load_model() and shared by every mode.
_model_cache = {"model": None, "device": None, "transform": None}

# HuggingFace Hub location of the trained checkpoint.
HF_MODEL_REPO = "gphua1/rklb-defect-model"
MODEL_FILENAME = "best_model.pth"

# File suffixes picked up by CLI batch mode.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')


def get_transform():
    """Return the inference preprocessing pipeline.

    Resizes to 224x224, normalizes with the standard ImageNet mean/std values
    and converts the HWC numpy array to a CHW torch tensor.
    """
    return A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])


def _download_model_from_hub() -> str:
    """Download the checkpoint from the HuggingFace Hub and return its local path.

    Raises:
        RuntimeError: if the download fails for any reason (missing package,
            missing/private repository, network problems, ...).
    """
    print("📄 Downloading model from HuggingFace Hub...")
    print(" Model size: ~1.1GB (this will only happen once)")

    try:
        from huggingface_hub import hf_hub_download

        on_space = bool(os.environ.get('SPACE_ID'))
        # IMPORTANT: Use /tmp directory on HuggingFace Spaces to avoid repo size limits.
        # /tmp is ephemeral but doesn't count against Space storage.
        cache_dir = Path("/tmp/model_cache") if on_space else Path("models")
        cache_dir.mkdir(parents=True, exist_ok=True)

        # NOTE: `resume_download` and `local_dir_use_symlinks` are deprecated in
        # recent huggingface_hub releases (downloads always resume and local_dir
        # never symlinks there); kept for compatibility with older versions.
        model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=MODEL_FILENAME,
            cache_dir=str(cache_dir),
            local_files_only=False,  # Allow downloading
            resume_download=True,  # Resume if interrupted
            local_dir=str(cache_dir) if on_space else None,
            local_dir_use_symlinks=False,
        )
    except Exception as e:
        error_msg = f"Failed to download model from HuggingFace Hub: {str(e)}"
        print(f"❌ {error_msg}")

        # Provide helpful error messages for the common failure modes.
        if "401" in str(e) or "Repository Not Found" in str(e):
            print("\n⚠️ Model repository not found or not accessible")
            print(" Upload your model manually to: https://huggingface.co/gphua1/rklb-defect-model")
            print(" 1. Go to https://huggingface.co/new")
            print(" 2. Create repo named 'rklb-defect-model'")
            print(" 3. Upload models/best_model.pth file")
        elif "Connection" in str(e):
            print("\n⚠️ Network connection issue. Please check your internet connection.")

        raise RuntimeError(error_msg) from e

    print(f"✅ Model downloaded to {model_path}")
    return model_path


def load_model(model_path: Optional[str] = None) -> Tuple[torch.nn.Module, torch.device, dict]:
    """Load the defect classifier once and cache it for later calls.

    Checkpoint resolution order:
      1. ``model_path`` when given and the file exists;
      2. ``models/best_model.pth`` when ``model_path`` is None and that file exists;
      3. otherwise the checkpoint is downloaded from the HuggingFace Hub.

    Once a model is cached, subsequent calls return it and ignore ``model_path``.

    Returns:
        ``(model, device, info)`` where ``info`` holds ``model_type``,
        ``accuracy``, ``model_path`` and ``device``.

    Raises:
        RuntimeError: if the download or checkpoint loading fails.
    """
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["device"], _model_cache.get("info", {})

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Check for local model first (for development)
    if model_path is None:
        local_model_path = Path("models") / MODEL_FILENAME
        if local_model_path.exists():
            model_path = str(local_model_path)
            print(f"📂 Using local model: {model_path}")
    elif not Path(model_path).exists():
        print(f"⚠️ Model path not found: {model_path}; falling back to HuggingFace Hub")

    # If no usable local model, download from HuggingFace Hub
    if not model_path or not Path(model_path).exists():
        model_path = _download_model_from_hub()

    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from a trusted source (this one may come from the Hub).
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        model_type = checkpoint.get('model_type', 'efficient_vit')

        # Create and load model
        model = get_model(model_type, num_classes=2, pretrained=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        model_info = {
            'model_type': model_type,
            'accuracy': checkpoint.get('best_acc', checkpoint.get('accuracy', 0)),
            'model_path': model_path,
            'device': str(device),
        }

        _model_cache.update(
            model=model,
            device=device,
            transform=get_transform(),
            info=model_info,
        )

        print(f"✅ Model loaded: {model_type} (Accuracy: {model_info['accuracy']:.1f}%)")
        print(f" Device: {device}")
        return model, device, model_info
    except Exception as e:
        raise RuntimeError(f"Failed to load model checkpoint: {str(e)}") from e


def _sync_if_cuda(device: torch.device) -> None:
    """Block until queued CUDA work on ``device`` finishes (no-op on CPU)."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


@torch.no_grad()
def predict_image(image: np.ndarray, model=None) -> Dict:
    """Classify one RGB image as NORMAL or DEFECTIVE.

    Args:
        image: HxWx3 RGB array (callers in this file pass PIL ``convert('RGB')``
            output or cv2 images converted BGR->RGB).
        model: model to run; defaults to the cached model from ``load_model()``.

    Returns:
        dict with ``prediction``, ``confidence``, ``defect_probability``,
        ``normal_probability`` and ``inference_time`` (milliseconds).
    """
    if model is None:
        model, device, _ = load_model()
    else:
        device = next(model.parameters()).device

    transform = _model_cache.get("transform") or get_transform()

    # Preprocess
    image_tensor = transform(image=image)['image'].unsqueeze(0).to(device)

    # Inference. CUDA kernels run asynchronously, so synchronize around the
    # forward pass to time the actual compute rather than the kernel launch.
    _sync_if_cuda(device)
    start_time = time.perf_counter()
    outputs = model(image_tensor)
    probs = F.softmax(outputs, dim=1)
    confidence, predicted = torch.max(probs, 1)
    _sync_if_cuda(device)
    inference_time = (time.perf_counter() - start_time) * 1000

    return {
        'prediction': 'DEFECTIVE' if predicted.item() == 1 else 'NORMAL',
        'confidence': confidence.item(),
        'defect_probability': probs[0][1].item(),
        'normal_probability': probs[0][0].item(),
        'inference_time': inference_time,
    }


def run_streamlit():
    """Run the Streamlit web interface (Rocket Lab theme).

    When called outside a Streamlit script run (e.g. ``--mode web``), this
    relaunches the current file with ``streamlit run`` using the current
    interpreter and returns when that process exits.
    """
    import subprocess

    # If not running through streamlit, restart with streamlit
    if "streamlit.runtime.scriptrunner" not in sys.modules:
        print("🚀 Starting Rocket Lab Defect Detection System...")
        print(" Opening browser to http://localhost:8501")
        # `python -m streamlit` uses this interpreter's environment and does not
        # depend on a `streamlit` executable being on PATH.
        subprocess.run([sys.executable, "-m", "streamlit", "run", __file__])
        return

    import streamlit as st

    # Rocket Lab themed configuration - with permanent sidebar
    st.set_page_config(
        page_title="RKLB Defect Detection",
        page_icon="🚀",
        layout="wide",
        initial_sidebar_state="expanded",  # Make sure sidebar is expanded
    )

    # Custom CSS with professional theme and permanent sidebar.
    # NOTE(review): the stylesheet content is empty here; restore it or drop this call.
    st.markdown(""" """, unsafe_allow_html=True)

    # Load model
    try:
        model = load_model()[0]
    except Exception as e:
        st.error(f"Model Error: {e}")
        st.stop()

    # Header - Professional Rocket Lab style
    st.markdown("""

ROCKET LAB COMPONENT DEFECT DETECTION

Made by Gary Phua
""", unsafe_allow_html=True)

    # Sidebar with sample images
    with st.sidebar:
        st.markdown("""
Click to load sample image
""", unsafe_allow_html=True)

        examples_dir = Path("examples")
        sample_images = []
        if examples_dir.exists():
            normal_samples = sorted((examples_dir / "normal").glob("*.png"))
            defect_samples = sorted((examples_dir / "defective").glob("*.png"))

            # Select samples to ensure variety: first normal, first defective, last normal
            if normal_samples:
                sample_images.append(normal_samples[0])
            if defect_samples:
                sample_images.append(defect_samples[0])
            if len(normal_samples) >= 2:
                sample_images.append(normal_samples[-1])

        for idx, sample_path in enumerate(sample_images):
            # resize() returns a fully loaded copy, so the file can be closed here.
            with Image.open(sample_path) as img:
                img_thumbnail = img.resize((120, 120), Image.Resampling.LANCZOS)

            label = f"Sample {idx + 1}"

            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(img_thumbnail, use_container_width=True)
            with col2:
                st.markdown(label, unsafe_allow_html=True)
                if st.button("Load", key=f"sample_{idx}", use_container_width=True):
                    st.session_state['selected_image'] = str(sample_path)
                    st.session_state['image_source'] = 'sample'
                    st.rerun()

    # Main content area
    with st.container():
        # Upload section with professional header
        st.markdown("""
Upload Component Image
""", unsafe_allow_html=True)

        uploaded_file = st.file_uploader(
            "Select image file (PNG, JPG, JPEG, BMP)",
            type=['png', 'jpg', 'jpeg', 'bmp'],
            label_visibility="visible",
        )

        # An upload takes precedence over a previously selected sample.
        image = None
        if uploaded_file:
            image = Image.open(uploaded_file)
            st.session_state['image_source'] = 'upload'
            st.session_state['selected_image'] = None
        elif st.session_state.get('selected_image'):
            image = Image.open(st.session_state['selected_image'])

        if image is not None:
            image_np = np.array(image.convert('RGB'))

            st.markdown("""
Analysis Results
""", unsafe_allow_html=True)

            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(image, use_container_width=True)

            with col2:
                with st.spinner("Analyzing component..."):
                    result = predict_image(image_np, model)

                st.markdown("""
Quality Assessment
""", unsafe_allow_html=True)

                if result['prediction'] == 'DEFECTIVE':
                    st.markdown("""
DEFECT DETECTED
Component Failed Quality Check
""", unsafe_allow_html=True)
                else:
                    st.markdown("""
PASSED
Component Passed Quality Check
""", unsafe_allow_html=True)

                # Metrics
                st.markdown("""
Confidence
{:.1f}%
Processing Time
{:.0f}ms
""".format(result['confidence'] * 100, result['inference_time']), unsafe_allow_html=True)
        else:
            # Empty state
            st.markdown("""
Ready for Analysis
Upload a component image or select a sample from the sidebar to begin quality inspection
""", unsafe_allow_html=True)


def _decode_base64_image(data: str) -> np.ndarray:
    """Decode a base64 string (optionally a ``data:...;base64,`` URL) to an RGB array.

    Raises:
        ValueError: malformed base64 (``binascii.Error`` subclasses ValueError).
        OSError: bytes are not a readable image (PIL ``UnidentifiedImageError``).
    """
    if data.startswith("data:"):
        # Strip a browser-style "data:image/png;base64," prefix.
        _, _, data = data.partition(",")
    image_bytes = base64.b64decode(data)
    with Image.open(io.BytesIO(image_bytes)) as image:
        return np.array(image.convert('RGB'))


def run_api():
    """Build the FastAPI app and serve it with uvicorn (or return it on Vercel).

    Endpoints: ``GET /``, ``GET /health``, ``POST /predict`` (JSON body
    ``{"image": "<base64>"}``) and ``GET /interface``.

    Returns:
        The FastAPI app when the ``VERCEL`` environment variable is set;
        otherwise None once ``uvicorn.run`` returns.
    """
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import HTMLResponse
    from pydantic import BaseModel
    import uvicorn

    app = FastAPI(title="Defect Detection API", version="1.0.0")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    class PredictionRequest(BaseModel):
        image: str  # base64 encoded

    @app.on_event("startup")
    async def startup():
        try:
            load_model()
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"❌ Model loading failed: {e}")

    @app.get("/")
    async def root():
        return {
            "message": "Defect Detection API",
            "endpoints": {
                "health": "/health",
                "predict": "/predict",
                "interface": "/interface",
            },
        }

    @app.get("/health")
    async def health():
        return {"status": "healthy", "model_loaded": _model_cache["model"] is not None}

    @app.post("/predict")
    async def predict(request: PredictionRequest):
        if _model_cache["model"] is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Undecodable client input is a 400, not a server error.
        try:
            image_np = _decode_base64_image(request.image)
        except (ValueError, OSError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

        try:
            return predict_image(image_np, _model_cache["model"])
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/interface")
    async def interface():
        # NOTE(review): this page is plain text only (no markup or script
        # calling /predict) - restore the original HTML if it was lost.
        html = """ RKLB Defect Detection

ROCKET LAB COMPONENT DEFECT DETECTION SYSTEM

Made by Gary Phua


"""
        return HTMLResponse(content=html)

    # For Vercel deployment the platform serves the returned app itself.
    if os.environ.get('VERCEL'):
        return app

    # Local server
    uvicorn.run(app, host="0.0.0.0", port=8000)


def _read_rgb(path: str) -> Optional[np.ndarray]:
    """Read an image with OpenCV and convert BGR->RGB; None if it can't be read."""
    image = cv2.imread(path)
    if image is None:
        return None
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def run_cli(args):
    """Run the command-line interface.

    Uses ``args.image`` (single prediction) or ``args.directory`` (recursive
    batch prediction, optionally saved to ``args.output`` as JSON).
    """
    if not args.image and not args.directory:
        # Avoid loading (or downloading ~1.1GB) the model when there is nothing to do.
        print("⚠️ CLI mode needs --image or --directory")
        return

    model, device, info = load_model(args.model)
    print(f"✅ Model loaded: {info['model_type']} (Acc: {info['accuracy']:.1f}%)")

    if args.image:
        # Single image prediction
        image = _read_rgb(args.image)
        if image is None:
            print(f"❌ Cannot load image: {args.image}")
            return

        result = predict_image(image, model)
        print(f"\n📷 Image: {args.image}")
        print(f"🎯 Prediction: {result['prediction']}")
        print(f"📊 Confidence: {result['confidence']:.2%}")
        print(f"⏱️ Inference: {result['inference_time']:.1f}ms")
        return

    # Batch prediction (sorted for deterministic output order)
    results = []
    for img_path in sorted(Path(args.directory).glob("**/*")):
        if img_path.suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        image = _read_rgb(str(img_path))
        if image is None:
            continue

        result = predict_image(image, model)
        result['path'] = str(img_path)
        results.append(result)
        marker = '🔴' if result['prediction'] == 'DEFECTIVE' else '🟢'
        print(f"{marker} {img_path.name}: {result['prediction']} ({result['confidence']:.1%})")

    # Summary (guarded: an empty or missing directory yields no results)
    if results:
        defective = sum(1 for r in results if r['prediction'] == 'DEFECTIVE')
        print(f"\n📊 Results: {defective}/{len(results)} defective ({defective / len(results) * 100:.1f}%)")
    else:
        print(f"\n⚠️ No readable images found in {args.directory}")

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"💾 Saved to {args.output}")


def main():
    """Entry point: dispatch to the Streamlit, FastAPI or CLI mode."""
    # When Streamlit is executing this file, go straight to the web UI.
    if "streamlit.runtime.scriptrunner" in sys.modules:
        run_streamlit()
        return

    parser = argparse.ArgumentParser(description='Defect Detection Application')
    parser.add_argument('--mode', choices=['web', 'api', 'cli'], default='web',
                        help='Run mode: web (Streamlit), api (FastAPI), or cli')
    parser.add_argument('--model', type=str, help='Model path')
    parser.add_argument('--image', type=str, help='Single image path (CLI mode)')
    parser.add_argument('--directory', type=str, help='Directory of images (CLI mode)')
    parser.add_argument('--output', type=str, help='Save results to JSON (CLI mode)')
    args = parser.parse_args()

    if args.mode == 'web':
        run_streamlit()
    elif args.mode == 'api':
        run_api()
    else:
        run_cli(args)


# For Vercel deployment: expose the FastAPI app at module level.
app = None
if os.environ.get('VERCEL'):
    from fastapi import FastAPI  # noqa: F401
    # Return the FastAPI app for Vercel
    app = run_api()

if __name__ == "__main__":
    main()