"""
Production Defect Detection Application
Supports both API and Web Interface modes
"""
import os
import io
import sys
import base64
import json
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import argparse
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Model imports
from models.vision_transformer import get_model
# Process-wide cache shared by load_model() and predict_image().
# load_model() fills these slots after the first successful load (and also
# stores an extra "info" entry); predict_image() reads "transform" from here.
_model_cache = dict(model=None, device=None, transform=None)
def get_transform():
    """Build the preprocessing pipeline applied to every image before inference.

    Steps: resize to 224x224, normalize each channel with the widely used
    ImageNet mean/std values, then convert to a torch tensor (ToTensorV2).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        A.Resize(224, 224),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def _download_model_from_hub() -> str:
    """Download ``best_model.pth`` from the Hugging Face Hub and return its local path.

    When the ``SPACE_ID`` environment variable is set (Hugging Face Spaces) the
    file is placed under ``/tmp/model_cache``; otherwise it is cached under
    ``models/``.

    Raises:
        RuntimeError: if importing ``huggingface_hub`` or the download fails.
            Troubleshooting hints are printed first and the original exception
            is chained as ``__cause__``.
    """
    print("🔄 Downloading model from HuggingFace Hub...")
    print("   Model size: ~1.1GB (this will only happen once)")
    model_repo = "gphua1/rklb-defect-model"
    on_spaces = bool(os.environ.get('SPACE_ID'))
    try:
        from huggingface_hub import hf_hub_download

        # On Spaces, /tmp is ephemeral but does not count against Space storage.
        cache_dir = Path("/tmp/model_cache") if on_spaces else Path("models")
        cache_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): `resume_download` and `local_dir_use_symlinks` are
        # deprecated in newer huggingface_hub releases (ignored with a warning).
        # They are kept for older versions; confirm the pinned version before
        # removing them.
        model_path = hf_hub_download(
            repo_id=model_repo,
            filename="best_model.pth",
            cache_dir=str(cache_dir),
            local_files_only=False,  # allow a network download
            resume_download=True,  # resume an interrupted download
            local_dir=str(cache_dir) if on_spaces else None,
            local_dir_use_symlinks=False,
        )
    except Exception as e:
        error_msg = f"Failed to download model from HuggingFace Hub: {e}"
        print(f"❌ {error_msg}")
        # Print actionable hints for the most common failure modes.
        details = str(e)
        if "401" in details or "Repository Not Found" in details:
            print("\n⚠️ Model repository not found or not accessible")
            print("   Upload your model manually to: https://huggingface.co/gphua1/rklb-defect-model")
            print("   1. Go to https://huggingface.co/new")
            print("   2. Create repo named 'rklb-defect-model'")
            print("   3. Upload models/best_model.pth file")
        elif "Connection" in details:
            print("\n⚠️ Network connection issue. Please check your internet connection.")
        raise RuntimeError(error_msg) from e
    print(f"✅ Model downloaded to {model_path}")
    return model_path


def load_model(model_path: Optional[str] = None) -> Tuple[torch.nn.Module, torch.device, dict]:
    """Load the defect-classification model and cache it for the process lifetime.

    Checkpoint resolution order:
      1. ``model_path``, if given and it exists on disk;
      2. ``models/best_model.pth``, when ``model_path`` is None and that file exists;
      3. otherwise, ``best_model.pth`` downloaded from the Hugging Face Hub.

    After the first successful load, the objects are stored in ``_model_cache``.
    Every later call returns the cached objects and ignores ``model_path``.

    Args:
        model_path: Optional path to a ``.pth`` checkpoint.

    Returns:
        ``(model, device, info)``. ``model`` is in eval mode on ``device``.
        ``info`` has the keys ``model_type``, ``accuracy``, ``model_path`` and
        ``device``.

    Raises:
        RuntimeError: if the download or checkpoint loading fails. The
            underlying exception is chained as ``__cause__``.
    """
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["device"], _model_cache.get("info", {})

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Prefer a local checkpoint when one is present (development workflow).
    if model_path is None:
        local_model_path = Path("models/best_model.pth")
        if local_model_path.exists():
            model_path = str(local_model_path)
            print(f"📁 Using local model: {model_path}")

    if not model_path or not Path(model_path).exists():
        if model_path:
            # An explicit path was given but is missing; say so instead of
            # silently falling back to the Hub.
            print(f"⚠️ Model path not found: {model_path}")
        model_path = _download_model_from_hub()

    # Keep the try narrow: only checkpoint I/O and model construction.
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects. Only
        # load checkpoints from trusted sources (a local file or this
        # project's own Hub repository).
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        model_type = checkpoint.get('model_type', 'efficient_vit')
        model = get_model(model_type, num_classes=2, pretrained=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
    except Exception as e:
        raise RuntimeError(f"Failed to load model checkpoint: {e}") from e

    model_info = {
        'model_type': model_type,
        # Older checkpoints may store 'accuracy' instead of 'best_acc'.
        'accuracy': checkpoint.get('best_acc', checkpoint.get('accuracy', 0)),
        'model_path': model_path,
        'device': str(device),
    }

    _model_cache["model"] = model
    _model_cache["device"] = device
    _model_cache["transform"] = get_transform()
    _model_cache["info"] = model_info

    print(f"✅ Model loaded: {model_type} (Accuracy: {model_info['accuracy']:.1f}%)")
    print(f"   Device: {device}")
    return model, device, model_info
@torch.no_grad()
def predict_image(image: np.ndarray, model=None) -> Dict:
    """Classify a single image as ``'DEFECTIVE'`` or ``'NORMAL'``.

    Args:
        image: Image as a NumPy array. ``(H, W, 3)`` input is used as-is.
            Grayscale ``(H, W)`` / ``(H, W, 1)`` input is replicated to three
            channels, and a fourth (alpha) channel is dropped, so the 3-channel
            normalization and model input always apply. Channel order is
            assumed to be RGB (TODO: confirm with callers).
        model: Optional model to use. If omitted, the cached model from
            ``load_model()`` is used and loaded on the first call.

    Returns:
        dict with:
          - ``prediction``: ``'DEFECTIVE'`` if class index 1 wins, else ``'NORMAL'``;
          - ``confidence``: the top-class probability;
          - ``defect_probability``: the class-1 probability;
          - ``normal_probability``: the class-0 probability;
          - ``inference_time``: milliseconds for the forward pass, softmax and argmax.
    """
    if model is None:
        model, device, _ = load_model()
    else:
        device = next(model.parameters()).device
    transform = _model_cache.get("transform") or get_transform()

    # The transform normalizes with 3-channel statistics and the model expects
    # 3 input channels, so coerce grayscale/RGBA arrays to 3-channel HWC.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # Slicing off alpha yields a non-contiguous view; make it contiguous
        # for the OpenCV-backed resize.
        image = np.ascontiguousarray(image[..., :3])

    augmented = transform(image=image)
    image_tensor = augmented['image'].unsqueeze(0).to(device)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    outputs = model(image_tensor)
    probs = F.softmax(outputs, dim=1)
    confidence, predicted = torch.max(probs, 1)
    if device.type == 'cuda':
        # CUDA kernels launch asynchronously. Wait for completion so the timer
        # measures the compute itself, not just the kernel launches.
        torch.cuda.synchronize(device)
    inference_time = (time.perf_counter() - start_time) * 1000

    return {
        'prediction': 'DEFECTIVE' if predicted.item() == 1 else 'NORMAL',
        'confidence': confidence.item(),
        'defect_probability': probs[0][1].item(),
        'normal_probability': probs[0][0].item(),
        'inference_time': inference_time,
    }
def run_streamlit():
"""Run Streamlit web interface with clean Rocket Lab theme"""
import sys
import subprocess
import os
# If not running through streamlit, restart with streamlit
if "streamlit.runtime.scriptrunner" not in sys.modules:
print("š Starting Rocket Lab Defect Detection System...")
print(" Opening browser to http://localhost:8501")
subprocess.run(["streamlit", "run", __file__])
return
import streamlit as st
# Rocket Lab themed configuration - with permanent sidebar
st.set_page_config(
page_title="RKLB Defect Detection",
page_icon="š",
layout="wide",
initial_sidebar_state="expanded" # Make sure sidebar is expanded
)
# Custom CSS with professional theme and permanent sidebar
st.markdown("""
""", unsafe_allow_html=True)
# Load model
try:
model, device, model_info = load_model()
except Exception as e:
st.error(f"Model Error: {e}")
st.stop()
# Header - Professional Rocket Lab style
st.markdown("""
ROCKET LABCOMPONENT DEFECT DETECTION
Made by Gary Phua
""", unsafe_allow_html=True)
# Sidebar with sample images - ensure it's visible
with st.sidebar:
# Professional sidebar header
st.markdown("""
Test Samples
Click to load sample image
""", unsafe_allow_html=True)
# Get example images
examples_dir = Path("examples")
sample_images = []
if examples_dir.exists():
normal_samples = sorted((examples_dir / "normal").glob("*.png"))
defect_samples = sorted((examples_dir / "defective").glob("*.png"))
# Select samples to ensure variety
if len(normal_samples) >= 1:
sample_images.append(normal_samples[0]) # First normal
if len(defect_samples) >= 1:
sample_images.append(defect_samples[0]) # Defective
if len(normal_samples) >= 2:
sample_images.append(normal_samples[-1]) # Last normal
# Display sample images in sidebar
if sample_images:
for idx, sample_path in enumerate(sample_images):
# Load and create small thumbnail
img = Image.open(sample_path)
img_thumbnail = img.resize((120, 120), Image.Resampling.LANCZOS)
# Professional label based on type
if "defect" in str(sample_path).lower():
label = f"Sample {idx + 1}"
else:
label = f"Sample {idx + 1}"
# Display in sidebar with improved layout
col1, col2 = st.columns([1, 2])
with col1:
st.image(img_thumbnail, use_container_width=True)
with col2:
st.markdown(f"
{label}
", unsafe_allow_html=True)
if st.button("Load", key=f"sample_{idx}", use_container_width=True):
st.session_state['selected_image'] = str(sample_path)
st.session_state['image_source'] = 'sample'
st.rerun()
# Main content area
main_container = st.container()
with main_container:
# Upload section with professional header
st.markdown("""
""", unsafe_allow_html=True)
col1, col2 = st.columns([1, 2])
with col1:
st.image(image, use_container_width=True)
with col2:
# Run prediction
with st.spinner("Analyzing component..."):
result = predict_image(image_np, model)
# Result display - professional layout
st.markdown("""
Quality Assessment
""", unsafe_allow_html=True)
if result['prediction'] == 'DEFECTIVE':
st.markdown("""