Spaces:
Configuration error
Configuration error
""" | |
Production Defect Detection Application | |
Supports both API and Web Interface modes | |
""" | |
import os | |
import io | |
import sys | |
import base64 | |
import json | |
import time | |
from pathlib import Path | |
from typing import Dict, Optional, Tuple | |
import argparse | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from PIL import Image | |
import cv2 | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
# Model imports | |
from models.vision_transformer import get_model | |
# Process-wide model cache. load_model() fills "model", "device" and
# "transform" (and adds an "info" entry) after its first successful load and
# returns the cached entries on every later call; predict_image() reuses the
# cached "transform".
_model_cache = dict(model=None, device=None, transform=None)
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    Steps, in order: resize to 224x224, normalize each channel with the
    mean/std below (the commonly used ImageNet statistics), then convert the
    numpy image to a torch tensor via ``ToTensorV2``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline_steps = [
        A.Resize(224, 224),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(pipeline_steps)
def _download_model_from_hub() -> str:
    """Download ``best_model.pth`` from the HuggingFace Hub and return its local path.

    On HuggingFace Spaces (``SPACE_ID`` set) the file goes to
    ``/tmp/model_cache`` so it does not count against the Space's storage;
    elsewhere it is cached under ``models/``.

    Raises:
        RuntimeError: if ``huggingface_hub`` cannot be imported or the download
            fails; the original exception is chained as ``__cause__``.
    """
    print("π₯ Downloading model from HuggingFace Hub...")
    print(" Model size: ~1.1GB (this will only happen once)")
    try:
        from huggingface_hub import hf_hub_download

        model_repo = "gphua1/rklb-defect-model"
        on_spaces = bool(os.environ.get('SPACE_ID'))
        cache_dir = Path("/tmp/model_cache") if on_spaces else Path("models")
        cache_dir.mkdir(parents=True, exist_ok=True)
        model_path = hf_hub_download(
            repo_id=model_repo,
            filename="best_model.pth",
            cache_dir=str(cache_dir),
            local_files_only=False,  # Allow downloading
            resume_download=True,  # Resume if interrupted
            local_dir=str(cache_dir) if on_spaces else None,
            local_dir_use_symlinks=False,
        )
        print("β Model downloaded to temporary cache")
    except Exception as e:
        error_msg = f"Failed to download model from HuggingFace Hub: {str(e)}"
        print(f"β {error_msg}")
        # Provide helpful hints for the two most common failure modes.
        if "401" in str(e) or "Repository Not Found" in str(e):
            print("\nβ οΈ Model repository not found or not accessible")
            print(" Upload your model manually to: https://huggingface.co/gphua1/rklb-defect-model")
            print(" 1. Go to https://huggingface.co/new")
            print(" 2. Create repo named 'rklb-defect-model'")
            print(" 3. Upload models/best_model.pth file")
        elif "Connection" in str(e):
            print("\nβ οΈ Network connection issue. Please check your internet connection.")
        raise RuntimeError(error_msg) from e
    return model_path


def load_model(model_path: Optional[str] = None) -> Tuple[torch.nn.Module, torch.device, dict]:
    """Load the defect classifier, caching it in ``_model_cache`` for the process.

    Checkpoint resolution order:
      1. ``model_path`` when given and it exists on disk;
      2. ``models/best_model.pth`` when ``model_path`` is None and that file exists;
      3. otherwise ``best_model.pth`` downloaded from the HuggingFace Hub.

    Once a model is cached, later calls return it immediately and ignore
    ``model_path``.

    The checkpoint must contain ``model_state_dict``; ``model_type`` (default
    ``'efficient_vit'``) and ``best_acc``/``accuracy`` (default 0) are optional.

    Args:
        model_path: Optional path to a ``.pth`` checkpoint.

    Returns:
        ``(model, device, info)``. The model is in eval mode on ``device``
        (CUDA if available, else CPU). ``info`` holds ``model_type``,
        ``accuracy``, ``model_path`` and ``device``.

    Raises:
        RuntimeError: if the download or checkpoint loading fails; the
            underlying exception is chained as ``__cause__``.
    """
    if _model_cache["model"] is not None:
        return _model_cache["model"], _model_cache["device"], _model_cache.get("info", {})
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if model_path is None:
        # Prefer a local development checkpoint when no path was requested.
        local_model_path = Path("models/best_model.pth")
        if local_model_path.exists():
            model_path = str(local_model_path)
            print(f"π Using local model: {model_path}")
    elif model_path and not Path(model_path).exists():
        # An explicit path that is missing (e.g. a typo) would otherwise
        # silently trigger a large download; make that visible.
        print(f"Model path not found: {model_path} - falling back to HuggingFace Hub download")

    if not model_path or not Path(model_path).exists():
        model_path = _download_model_from_hub()

    try:
        # NOTE(security): weights_only=False lets the checkpoint unpickle
        # arbitrary Python objects (i.e. run code). Only load checkpoints from
        # sources you trust, including the Hub repository above.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        model_type = checkpoint.get('model_type', 'efficient_vit')

        model = get_model(model_type, num_classes=2, pretrained=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        model_info = {
            'model_type': model_type,
            'accuracy': checkpoint.get('best_acc', checkpoint.get('accuracy', 0)),
            'model_path': model_path,
            'device': str(device),
        }

        _model_cache["model"] = model
        _model_cache["device"] = device
        _model_cache["transform"] = get_transform()
        _model_cache["info"] = model_info

        print(f"β Model loaded: {model_type} (Accuracy: {model_info['accuracy']:.1f}%)")
        print(f" Device: {device}")
        return model, device, model_info
    except Exception as e:
        raise RuntimeError(f"Failed to load model checkpoint: {str(e)}") from e
def predict_image(image: np.ndarray, model=None) -> Dict:
    """Classify one image as ``'NORMAL'`` or ``'DEFECTIVE'``.

    Args:
        image: Image array handed to the preprocessing transform. Callers in
            this module pass an RGB ``uint8`` array of shape (H, W, 3).
        model: Optional model to use. When omitted, the cached model from
            ``load_model()`` is used; otherwise the device is taken from the
            model's first parameter.

    Returns:
        Dict with ``prediction`` (class 1 -> ``'DEFECTIVE'``, else
        ``'NORMAL'``), ``confidence`` (the winning softmax probability),
        ``defect_probability``, ``normal_probability`` and ``inference_time``
        (forward pass + softmax, in milliseconds).
    """
    if model is None:
        model, device, _ = load_model()
    else:
        device = next(model.parameters()).device
    transform = _model_cache.get("transform") or get_transform()

    augmented = transform(image=image)
    # Add a leading batch dimension of 1.
    image_tensor = augmented['image'].unsqueeze(0).to(device)

    start_time = time.perf_counter()  # monotonic, high-resolution clock
    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)
    if device.type == "cuda":
        # CUDA kernels run asynchronously; wait so the timing is meaningful.
        torch.cuda.synchronize(device)
    inference_time = (time.perf_counter() - start_time) * 1000

    return {
        'prediction': 'DEFECTIVE' if predicted.item() == 1 else 'NORMAL',
        'confidence': confidence.item(),
        'defect_probability': probs[0][1].item(),
        'normal_probability': probs[0][0].item(),
        'inference_time': inference_time,
    }
# Global CSS for the Streamlit UI (dark theme, permanent sidebar, hidden
# Streamlit chrome). Injected once per script run by run_streamlit().
_STREAMLIT_CSS = """
<style>
/* Force sidebar to always be visible and expanded */
section[data-testid="stSidebar"] {
background: #0f0f0f !important;
border-right: 2px solid #333;
width: 21rem !important;
min-width: 21rem !important;
max-width: 21rem !important;
display: block !important;
position: relative !important;
left: 0 !important;
visibility: visible !important;
opacity: 1 !important;
transform: none !important;
}
/* Hide sidebar collapse button completely */
[data-testid="collapsedControl"] {
display: none !important;
}
button[kind="header"] {
display: none !important;
}
/* Hide the hamburger menu button */
[data-testid="baseButton-header"] {
display: none !important;
}
/* Ensure sidebar content is always visible */
section[data-testid="stSidebar"] > div {
display: block !important;
visibility: visible !important;
opacity: 1 !important;
}
section[data-testid="stSidebar"] > div:first-child {
padding-top: 2rem;
}
/* Clean dark background */
.stApp {
background: #0a0a0a;
}
/* Hide the link icon buttons next to headers */
[data-testid="StyledLinkIconContainer"] {
display: none !important;
}
/* Hide anchor links in headers */
.stMarkdown h1 a, .stMarkdown h2 a, .stMarkdown h3 a {
display: none !important;
}
/* Hide buttons that appear on hover for headers */
.element-container:has(.stMarkdown h1, .stMarkdown h2, .stMarkdown h3) button[kind="headerLink"] {
display: none !important;
}
/* Hide all header link anchors */
[data-testid="stHeaderActionElements"] {
display: none !important;
}
/* Hide copy buttons and link buttons */
.stMarkdown [data-testid="stCopyButton"],
.stMarkdown button[title*="link"] {
display: none !important;
}
/* Main header - Professional Rocket Lab style */
.main-header {
padding: 15px 0 20px 0;
border-bottom: 3px solid #dc2626;
margin-bottom: 30px;
background: linear-gradient(135deg, #1a1a1a 0%, #0f0f0f 100%);
margin-top: -3.5rem; /* Position header higher */
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.1);
}
.main-header h1 {
color: #ffffff;
font-size: 1.9rem;
letter-spacing: 2px;
margin: 0;
font-weight: 700; /* Bold font weight */
text-align: left;
padding-left: 30px;
text-transform: uppercase;
}
.main-header .rocket-white {
color: #ffffff;
font-weight: 700; /* Bold for ROCKET LAB */
font-size: inherit;
letter-spacing: inherit;
}
.main-header .rocket-red {
color: #dc2626;
font-weight: 800; /* Extra bold for emphasis */
font-size: inherit;
letter-spacing: inherit;
}
.subtitle {
color: #aaa;
font-size: 0.8rem;
letter-spacing: 1.5px;
margin-top: 10px;
padding-left: 30px;
font-weight: 400;
text-transform: uppercase;
}
/* Sidebar styling */
.sidebar-header {
font-size: 1.1rem;
font-weight: 600;
color: #dc2626;
letter-spacing: 1.5px;
text-transform: uppercase;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 2px solid #333;
}
/* Sidebar sample images - professional and compact */
.sidebar-sample {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 6px;
padding: 8px;
margin-bottom: 10px;
cursor: pointer;
transition: all 0.3s;
}
.sidebar-sample:hover {
border-color: #dc2626;
background: #1f1f1f;
transform: translateX(3px);
}
.sample-label {
color: #bbb;
font-size: 0.75rem;
text-align: center;
margin-top: 8px;
margin-bottom: 5px;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: 500;
}
/* Instruction button */
.instruction-btn {
background: transparent;
border: 1px solid #dc2626;
color: #dc2626;
padding: 8px 16px;
font-size: 0.85rem;
letter-spacing: 1px;
text-transform: uppercase;
border-radius: 4px;
cursor: pointer;
transition: all 0.3s;
margin-bottom: 15px;
}
.instruction-btn:hover {
background: #dc2626;
color: white;
}
/* Result box */
.result-box {
background: #1a1a1a;
border-radius: 8px;
padding: 30px;
margin: 20px 0;
text-align: center;
}
.result-pass {
border: 2px solid #10b981;
}
.result-fail {
border: 2px solid #dc2626;
}
.result-title {
font-size: 1.8rem;
margin: 0;
font-weight: 300;
}
.result-confidence {
font-size: 2.5rem;
margin: 15px 0;
font-weight: bold;
}
/* Metrics row */
.metrics-row {
display: flex;
justify-content: center;
gap: 40px;
margin: 20px 0;
}
.metric {
text-align: center;
}
.metric-label {
color: #888;
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 1px;
}
.metric-value {
color: #ffffff;
font-size: 1.2rem;
font-weight: bold;
margin-top: 5px;
}
/* Upload area - more subtle and professional */
.upload-section {
background: #141414;
border: 1px solid #2a2a2a;
border-radius: 8px;
padding: 20px;
text-align: center;
margin: 15px 0;
}
/* Upload section header - smaller and professional */
.section-header {
font-size: 0.95rem;
font-weight: 600;
color: #ffffff;
text-transform: uppercase;
letter-spacing: 1.5px;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 1px solid #333;
}
/* Buttons - professional style */
.stButton > button {
background: linear-gradient(135deg, #dc2626, #b91c1c);
color: white;
border: none;
padding: 10px 24px;
font-size: 0.85rem;
font-weight: 600;
letter-spacing: 1.2px;
text-transform: uppercase;
border-radius: 4px;
width: 100%;
transition: all 0.3s;
box-shadow: 0 2px 8px rgba(220, 38, 38, 0.2);
}
.stButton > button:hover {
background: linear-gradient(135deg, #ef4444, #dc2626);
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.3);
transform: translateY(-1px);
}
/* Hide Streamlit branding */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Clean file uploader - professional styling */
[data-testid="stFileUploader"] {
background: transparent;
border: none;
}
[data-testid="stFileUploader"] label {
font-size: 0.85rem !important;
font-weight: 500 !important;
color: #999 !important;
text-transform: uppercase;
letter-spacing: 1px;
}
.uploadedFile {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 4px;
padding: 8px;
}
/* Text colors and typography */
p, span, div {
color: #ffffff;
}
label {
color: #bbb;
font-weight: 500;
}
/* Streamlit section headers */
.stMarkdown h3 {
font-size: 0.95rem !important;
font-weight: 600 !important;
color: #ffffff !important;
text-transform: uppercase;
letter-spacing: 1.5px;
margin-bottom: 15px !important;
padding-bottom: 10px;
border-bottom: 1px solid #333;
}
/* Progress bars minimal */
.stProgress > div > div > div > div {
background: #dc2626;
height: 4px;
}
/* Status badge */
.status-badge {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
letter-spacing: 1px;
text-transform: uppercase;
font-weight: bold;
}
.status-pass {
background: #10b981;
color: white;
}
.status-fail {
background: #dc2626;
color: white;
}
</style>
"""


def run_streamlit():
    """Render the Streamlit web UI, relaunching under ``streamlit run`` if needed.

    When the Streamlit script-runner module is not yet imported (plain
    ``python`` invocation), this starts ``streamlit run`` on this file in a
    subprocess and returns when that process exits.

    Inside Streamlit it renders:
      - the themed CSS and the header;
      - a sidebar with up to three sample images from ``examples/normal`` and
        ``examples/defective``;
      - an image uploader;
      - the prediction for the uploaded or selected sample image.
    """
    import subprocess

    if "streamlit.runtime.scriptrunner" not in sys.modules:
        print("π Starting Rocket Lab Defect Detection System...")
        print(" Opening browser to http://localhost:8501")
        # ``python -m streamlit`` uses the Streamlit installed for this
        # interpreter instead of depending on a ``streamlit`` script on PATH.
        subprocess.run([sys.executable, "-m", "streamlit", "run", __file__])
        return

    import streamlit as st

    # Rocket Lab themed configuration - with permanent sidebar
    st.set_page_config(
        page_title="RKLB Defect Detection",
        page_icon="π",
        layout="wide",
        initial_sidebar_state="expanded",  # Make sure sidebar is expanded
    )
    st.markdown(_STREAMLIT_CSS, unsafe_allow_html=True)

    try:
        model, _, _ = load_model()
    except Exception as e:
        st.error(f"Model Error: {e}")
        st.stop()

    # Incremented whenever a sample is loaded. The uploader's widget key
    # embeds it, so a new key discards any file still held by the uploader.
    # Otherwise the upload branch below would win on the rerun and the chosen
    # sample would never be shown.
    if 'uploader_key' not in st.session_state:
        st.session_state['uploader_key'] = 0

    # Header - Professional Rocket Lab style
    st.markdown("""
<div class="main-header">
<h1><span class="rocket-white">ROCKET LAB</span> <span class="rocket-red">COMPONENT DEFECT DETECTION</span></h1>
<div class="subtitle">Made by Gary Phua</div>
</div>
""", unsafe_allow_html=True)

    with st.sidebar:
        st.markdown("""
<div class="sidebar-header">Test Samples</div>
<div style='color: #999; font-size: 0.75rem; margin-bottom: 20px; text-transform: uppercase; letter-spacing: 1px;'>
Click to load sample image
</div>
""", unsafe_allow_html=True)

        # Up to three samples: first normal, first defective, last normal.
        sample_images = []
        examples_dir = Path("examples")
        if examples_dir.exists():
            normal_samples = sorted((examples_dir / "normal").glob("*.png"))
            defect_samples = sorted((examples_dir / "defective").glob("*.png"))
            if normal_samples:
                sample_images.append(normal_samples[0])
            if defect_samples:
                sample_images.append(defect_samples[0])
            if len(normal_samples) >= 2:
                sample_images.append(normal_samples[-1])

        for idx, sample_path in enumerate(sample_images):
            # The with-block closes the file handle once the thumbnail exists.
            with Image.open(sample_path) as img:
                img_thumbnail = img.resize((120, 120), Image.Resampling.LANCZOS)
            label = f"Sample {idx + 1}"
            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(img_thumbnail, use_container_width=True)
            with col2:
                st.markdown(f"<div class='sample-label'>{label}</div>", unsafe_allow_html=True)
                if st.button("Load", key=f"sample_{idx}", use_container_width=True):
                    st.session_state['selected_image'] = str(sample_path)
                    st.session_state['image_source'] = 'sample'
                    st.session_state['uploader_key'] += 1
                    st.rerun()

    main_container = st.container()
    with main_container:
        st.markdown("""
<div class="section-header">Upload Component Image</div>
""", unsafe_allow_html=True)
        uploaded_file = st.file_uploader(
            "Select image file (PNG, JPG, JPEG, BMP)",
            type=['png', 'jpg', 'jpeg', 'bmp'],
            label_visibility="visible",
            key=f"uploader_{st.session_state['uploader_key']}",
        )

        # An upload takes precedence over a previously selected sample.
        image = None
        try:
            if uploaded_file is not None:
                with Image.open(uploaded_file) as opened:
                    image = opened.copy()
                st.session_state['image_source'] = 'upload'
                st.session_state['selected_image'] = None
            elif st.session_state.get('selected_image'):
                with Image.open(st.session_state['selected_image']) as opened:
                    image = opened.copy()
        except (OSError, ValueError) as exc:
            # Unreadable or corrupt file, or a sample path that no longer exists.
            st.error(f"Could not read image: {exc}")
            image = None

        if image is not None:
            image_np = np.array(image.convert('RGB'))
            st.markdown("""
<div class="section-header">Analysis Results</div>
""", unsafe_allow_html=True)
            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(image, use_container_width=True)
            with col2:
                with st.spinner("Analyzing component..."):
                    result = predict_image(image_np, model)
                st.markdown("""
<div style="font-size: 0.9rem; font-weight: 600; color: #999; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 15px;">Quality Assessment</div>
""", unsafe_allow_html=True)
                if result['prediction'] == 'DEFECTIVE':
                    st.markdown("""
<div class="result-box result-fail">
<div class="status-badge status-fail">DEFECT DETECTED</div>
<div style="margin-top: 20px; color: #dc2626; font-size: 1.2rem;">Component Failed Quality Check</div>
</div>
""", unsafe_allow_html=True)
                else:
                    st.markdown("""
<div class="result-box result-pass">
<div class="status-badge status-pass">PASSED</div>
<div style="margin-top: 20px; color: #10b981; font-size: 1.2rem;">Component Passed Quality Check</div>
</div>
""", unsafe_allow_html=True)
                st.markdown("""
<div class="metrics-row">
<div class="metric">
<div class="metric-label">Confidence</div>
<div class="metric-value">{:.1f}%</div>
</div>
<div class="metric">
<div class="metric-label">Processing Time</div>
<div class="metric-value">{:.0f}ms</div>
</div>
</div>
""".format(result['confidence'] * 100, result['inference_time']), unsafe_allow_html=True)
        else:
            # Empty state - professional
            st.markdown("""
<div style="text-align: center; padding: 80px 40px; background: #141414; border: 1px solid #2a2a2a; border-radius: 8px; margin-top: 40px;">
<div style="font-size: 1.1rem; margin-bottom: 15px; color: #999; font-weight: 600; text-transform: uppercase; letter-spacing: 1.5px;">Ready for Analysis</div>
<div style="font-size: 0.85rem; color: #666; line-height: 1.6;">Upload a component image or select a sample from the sidebar to begin quality inspection</div>
</div>
""", unsafe_allow_html=True)
# Self-contained HTML upload page served at GET /interface. It POSTs the
# selected image as base64 JSON to /predict and renders the result, or the
# server's error message via textContent (not innerHTML) so the text cannot
# inject markup.
_INTERFACE_HTML = """
<!DOCTYPE html>
<html>
<head>
<title>RKLB Defect Detection</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0a0a0a;
color: white;
min-height: 100vh;
padding: 40px 20px;
}
.container {
max-width: 900px;
margin: 0 auto;
}
h1 {
text-align: center;
font-size: 1.5rem;
font-weight: 300;
letter-spacing: 3px;
margin-bottom: 10px;
padding-bottom: 20px;
border-bottom: 2px solid #dc2626;
}
.subtitle {
text-align: center;
color: #999;
font-size: 0.7rem;
letter-spacing: 1px;
font-style: italic;
margin-bottom: 40px;
}
.upload-area {
border: 2px dashed #333;
padding: 40px;
text-align: center;
background: #1a1a1a;
border-radius: 8px;
margin: 30px 0;
}
.result {
margin: 30px 0;
padding: 30px;
border-radius: 8px;
background: #1a1a1a;
text-align: center;
}
.result-pass { border: 2px solid #10b981; }
.result-fail { border: 2px solid #dc2626; }
button {
background: #dc2626;
color: white;
padding: 10px 30px;
border: none;
border-radius: 4px;
cursor: pointer;
text-transform: uppercase;
letter-spacing: 1px;
font-size: 0.9rem;
}
button:hover { background: #b91c1c; }
#preview img {
max-width: 400px;
max-height: 400px;
margin: 20px auto;
display: block;
border: 1px solid #333;
border-radius: 8px;
}
.confidence {
font-size: 2.5rem;
font-weight: bold;
margin: 20px 0;
}
.status {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: bold;
margin-bottom: 10px;
}
.status-pass { background: #10b981; }
.status-fail { background: #dc2626; }
</style>
</head>
<body>
<div class="container">
<h1>ROCKET LAB <span style="color: #dc2626;">COMPONENT DEFECT DETECTION SYSTEM</span></h1>
<p class="subtitle">Made by Gary Phua</p>
<div class="upload-area">
<input type="file" id="imageInput" accept="image/*" style="margin-bottom: 20px;">
<br>
<button onclick="analyze()">Analyze Component</button>
</div>
<div id="preview"></div>
<div id="result"></div>
</div>
<script>
function showError(message) {
const box = document.createElement('div');
box.className = 'result result-fail';
box.textContent = 'Error: ' + message;
const target = document.getElementById('result');
target.innerHTML = '';
target.appendChild(box);
}
function analyze() {
const input = document.getElementById('imageInput');
const file = input.files[0];
if (!file) return alert('Select an image');
const reader = new FileReader();
reader.onload = e => {
document.getElementById('preview').innerHTML = '<img src="' + e.target.result + '">';
const base64 = e.target.result.split(',')[1];
document.getElementById('result').innerHTML = '<div class="result">Analyzing...</div>';
fetch('/predict', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({image: base64})
})
.then(r => r.json().then(data => {
if (!r.ok) {
const detail = typeof data.detail === 'string' ? data.detail : JSON.stringify(data.detail);
throw new Error(detail || ('Request failed with status ' + r.status));
}
return data;
}))
.then(data => {
const passClass = data.prediction === 'DEFECTIVE' ? 'result-fail' : 'result-pass';
const statusClass = data.prediction === 'DEFECTIVE' ? 'status-fail' : 'status-pass';
const statusText = data.prediction === 'DEFECTIVE' ? 'DEFECT DETECTED' : 'PASSED';
document.getElementById('result').innerHTML =
'<div class="result ' + passClass + '">' +
'<div class="status ' + statusClass + '">' + statusText + '</div>' +
'<div class="confidence">' + (data.confidence * 100).toFixed(1) + '%</div>' +
'<div style="color: #888;">CONFIDENCE</div>' +
'<div style="margin-top: 20px; color: #888; font-size: 0.9rem;">' +
'Time: ' + data.inference_time.toFixed(0) + 'ms</div>' +
'</div>';
})
.catch(err => showError(err.message));
};
reader.readAsDataURL(file);
}
</script>
</body>
</html>
"""


def run_api():
    """Build the FastAPI app; return it on Vercel, otherwise serve it with uvicorn.

    Routes:
        GET  /           endpoint index
        GET  /health     ``{"status": "healthy", "model_loaded": bool}``
        POST /predict    ``predict_image`` result for ``{"image": "<base64>"}``;
                         400 for undecodable input, 500 for inference errors,
                         503 while no model is loaded
        GET  /interface  minimal HTML upload page that calls /predict

    The model is loaded once during application startup (lifespan hook). A
    load failure is printed and leaves /predict answering 503.

    Returns:
        The FastAPI app when the ``VERCEL`` environment variable is set.
        Otherwise ``uvicorn.run`` serves on 0.0.0.0:8000 and blocks, and the
        function returns None after the server stops.
    """
    from contextlib import asynccontextmanager

    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import HTMLResponse
    from pydantic import BaseModel
    import uvicorn

    async def startup():
        try:
            load_model()
            print("β Model loaded successfully")
        except Exception as e:
            print(f"β Model loading failed: {e}")

    @asynccontextmanager
    async def lifespan(_app):
        # Runs once before the server starts accepting requests.
        await startup()
        yield

    app = FastAPI(
        title="Defect Detection API",
        version="1.0.0",
        lifespan=lifespan,
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    class PredictionRequest(BaseModel):
        image: str  # base64 encoded

    @app.get("/")
    async def root():
        return {
            "message": "Defect Detection API",
            "endpoints": {
                "health": "/health",
                "predict": "/predict",
                "interface": "/interface",
            },
        }

    @app.get("/health")
    async def health():
        return {"status": "healthy", "model_loaded": _model_cache["model"] is not None}

    # Declared with plain ``def`` so FastAPI runs it in its worker threadpool;
    # model inference blocks and would otherwise stall the event loop.
    @app.post("/predict")
    def predict(request: PredictionRequest):
        if _model_cache["model"] is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        try:
            image_bytes = base64.b64decode(request.image)
            with Image.open(io.BytesIO(image_bytes)) as image:
                image_np = np.array(image.convert('RGB'))
        except (ValueError, OSError) as e:
            # binascii.Error (bad base64) subclasses ValueError; PIL's
            # UnidentifiedImageError subclasses OSError: client errors.
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
        try:
            return predict_image(image_np, _model_cache["model"])
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/interface", response_class=HTMLResponse)
    async def interface():
        return HTMLResponse(content=_INTERFACE_HTML)

    # For Vercel deployment the platform serves the returned app itself.
    if os.environ.get('VERCEL'):
        return app
    # Local server
    uvicorn.run(app, host="0.0.0.0", port=8000)
def run_cli(args):
    """Run the command-line interface.

    Args:
        args: Parsed ``argparse`` namespace. Reads:
            - ``model``: checkpoint path or None;
            - ``image``: path of a single image to classify;
            - ``directory``: folder scanned recursively for
              .png/.jpg/.jpeg/.bmp files;
            - ``output``: optional JSON path for directory results.

    ``image`` takes precedence over ``directory``. Images are read with OpenCV
    (BGR) and converted to RGB before prediction. Unreadable files found while
    scanning a directory are skipped.
    """
    model, _, info = load_model(args.model)
    print(f"β Model loaded: {info['model_type']} (Acc: {info['accuracy']:.1f}%)")

    if args.image:
        image = cv2.imread(args.image)
        if image is None:
            print(f"β Cannot load image: {args.image}")
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        result = predict_image(image, model)
        print(f"\nπ· Image: {args.image}")
        print(f"π― Prediction: {result['prediction']}")
        print(f"π Confidence: {result['confidence']:.2%}")
        print(f"β±οΈ Inference: {result['inference_time']:.1f}ms")
    elif args.directory:
        results = []
        # sorted() makes processing and report order reproducible.
        for img_path in sorted(Path(args.directory).glob("**/*")):
            if img_path.suffix.lower() not in ('.png', '.jpg', '.jpeg', '.bmp'):
                continue
            image = cv2.imread(str(img_path))
            if image is None:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result = predict_image(image, model)
            result['path'] = str(img_path)
            results.append(result)
            print(f"{'π΄' if result['prediction'] == 'DEFECTIVE' else 'π’'} {img_path.name}: {result['prediction']} ({result['confidence']:.1%})")

        defective = sum(1 for r in results if r['prediction'] == 'DEFECTIVE')
        if results:
            print(f"\nπ Results: {defective}/{len(results)} defective ({defective/len(results)*100:.1f}%)")
        else:
            # Without this guard the percentage above divides by zero.
            print(f"\nNo readable images found in {args.directory}")

        if args.output:
            with open(args.output, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
            print(f"πΎ Saved to {args.output}")
    else:
        print("Nothing to predict: pass --image or --directory.")
def main():
    """Program entry point: run the Streamlit UI, the FastAPI server, or the CLI.

    If the Streamlit script-runner module is already imported, the file is
    being executed by ``streamlit run``. In that case the UI is rendered
    directly and no command-line arguments are parsed. Otherwise ``--mode``
    selects the front end (default ``web``).
    """
    if "streamlit.runtime.scriptrunner" in sys.modules:
        run_streamlit()
        return

    parser = argparse.ArgumentParser(description='Defect Detection Application')
    parser.add_argument('--mode', choices=['web', 'api', 'cli'], default='web',
                        help='Run mode: web (Streamlit), api (FastAPI), or cli')
    parser.add_argument('--model', type=str, help='Model path')
    parser.add_argument('--image', type=str, help='Single image path (CLI mode)')
    parser.add_argument('--directory', type=str, help='Directory of images (CLI mode)')
    parser.add_argument('--output', type=str, help='Save results to JSON (CLI mode)')
    parsed = parser.parse_args()

    no_arg_launchers = {'web': run_streamlit, 'api': run_api}
    launcher = no_arg_launchers.get(parsed.mode)
    if launcher is None:
        run_cli(parsed)
    else:
        launcher()
# For Vercel deployment: when the VERCEL environment variable is set,
# run_api() returns the FastAPI app instead of starting uvicorn. The app is
# exposed as module-level ``app``, presumably for Vercel's Python runtime to
# pick up (confirm against the deployment config). Elsewhere ``app`` stays None.
app = None
if os.environ.get('VERCEL'):
    app = run_api()

if __name__ == "__main__":
    main()