# rklb_materials / app.py
# Author: gphua1
# Commit 5473ddb: Deploy app without large model - will download from HF Hub
"""
Production Defect Detection Application
Supports both API and Web Interface modes
"""
import os
import io
import sys
import base64
import json
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import argparse
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Model imports
from models.vision_transformer import get_model
# Global model cache
_model_cache = {"model": None, "device": None, "transform": None}
def get_transform():
    """Return the albumentations pipeline applied to images before inference.

    Steps: resize to 224x224, normalize each channel with the listed mean/std
    (the standard ImageNet statistics), then convert with ``ToTensorV2``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        A.Resize(224, 224),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)
def _download_model_from_hub() -> str:
    """Download ``best_model.pth`` from the ``gphua1/rklb-defect-model`` Hub repo.

    On HuggingFace Spaces (``SPACE_ID`` set) the file is stored in
    ``/tmp/model_cache``; elsewhere it goes into the ``models`` cache dir.

    Returns:
        Local filesystem path of the downloaded checkpoint.

    Raises:
        RuntimeError: if the import or download fails (original error chained).
    """
    print("📥 Downloading model from HuggingFace Hub...")
    print(" Model size: ~1.1GB (this will only happen once)")
    try:
        from huggingface_hub import hf_hub_download

        model_repo = "gphua1/rklb-defect-model"
        # On Spaces use /tmp: it is ephemeral but does not count against the
        # Space's storage limit.
        on_spaces = bool(os.environ.get('SPACE_ID'))
        cache_dir = Path("/tmp/model_cache") if on_spaces else Path("models")
        cache_dir.mkdir(parents=True, exist_ok=True)
        # NOTE(review): newer huggingface_hub releases deprecate
        # resume_download / local_dir_use_symlinks - confirm against the
        # pinned version before removing them.
        model_path = hf_hub_download(
            repo_id=model_repo,
            filename="best_model.pth",
            cache_dir=str(cache_dir),
            local_files_only=False,  # allow downloading
            resume_download=True,  # resume if interrupted
            local_dir=str(cache_dir) if on_spaces else None,
            local_dir_use_symlinks=False,
        )
        print("✅ Model downloaded to temporary cache")
    except Exception as e:
        error_msg = f"Failed to download model from HuggingFace Hub: {str(e)}"
        print(f"❌ {error_msg}")
        # Provide actionable hints for the two most common failure modes.
        if "401" in str(e) or "Repository Not Found" in str(e):
            print("\n⚠️ Model repository not found or not accessible")
            print(" Upload your model manually to: https://huggingface.co/gphua1/rklb-defect-model")
            print(" 1. Go to https://huggingface.co/new")
            print(" 2. Create repo named 'rklb-defect-model'")
            print(" 3. Upload models/best_model.pth file")
        elif "Connection" in str(e):
            print("\n⚠️ Network connection issue. Please check your internet connection.")
        raise RuntimeError(error_msg) from e
    return model_path


def _resolve_model_path(model_path: Optional[str]) -> str:
    """Return a local checkpoint path, downloading from the Hub when needed.

    With no explicit path, ``models/best_model.pth`` is used if it exists.
    A missing or non-existent path falls back to the Hub download.
    """
    if model_path is None:
        local_model_path = Path("models/best_model.pth")
        if local_model_path.exists():
            model_path = str(local_model_path)
            print(f"📂 Using local model: {model_path}")
    if model_path and Path(model_path).exists():
        return model_path
    return _download_model_from_hub()


def load_model(model_path: Optional[str] = None) -> Tuple[torch.nn.Module, torch.device, dict]:
    """Load the defect classifier, caching it in ``_model_cache``.

    Checkpoint resolution order:
      1. ``model_path`` when given and present on disk;
      2. ``models/best_model.pth`` when ``model_path`` is None and it exists;
      3. otherwise download ``best_model.pth`` from the HuggingFace Hub.

    A cached model is reused when ``model_path`` is None or names the same
    checkpoint that was loaded; a different path triggers a reload.

    Args:
        model_path: Optional path to a ``.pth`` checkpoint.

    Returns:
        ``(model, device, info)``; ``info`` holds ``model_type``, ``accuracy``,
        ``model_path`` and ``device``.

    Raises:
        RuntimeError: if downloading or loading the checkpoint fails.
    """
    cached_model = _model_cache["model"]
    if cached_model is not None:
        cached_info = _model_cache.get("info") or {}
        # Only reuse the cache when the caller did not ask for another file.
        if model_path is None or model_path == cached_info.get("model_path"):
            return cached_model, _model_cache["device"], cached_info

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = _resolve_model_path(model_path)

    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects - only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        model_type = checkpoint.get('model_type', 'efficient_vit')

        model = get_model(model_type, num_classes=2, pretrained=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()

        model_info = {
            'model_type': model_type,
            'accuracy': checkpoint.get('best_acc', checkpoint.get('accuracy', 0)),
            'model_path': model_path,
            'device': str(device)
        }

        _model_cache["model"] = model
        _model_cache["device"] = device
        _model_cache["transform"] = get_transform()
        _model_cache["info"] = model_info

        print(f"✅ Model loaded: {model_type} (Accuracy: {model_info['accuracy']:.1f}%)")
        print(f" Device: {device}")
        return model, device, model_info
    except Exception as e:
        raise RuntimeError(f"Failed to load model checkpoint: {str(e)}") from e
@torch.no_grad()
def predict_image(image: np.ndarray, model=None) -> Dict:
    """Classify one image as DEFECTIVE (class 1) or NORMAL (class 0).

    Args:
        image: array fed to the preprocessing transform; callers in this file
            pass an RGB array built from PIL or from OpenCV after BGR->RGB.
        model: model to run; when None, the cached model from ``load_model()``.

    Returns:
        dict with ``prediction``, ``confidence`` (winning-class probability),
        ``defect_probability``, ``normal_probability`` and ``inference_time``
        (milliseconds spent in the forward pass + softmax).
    """
    if model is None:
        model, device, _ = load_model()
    else:
        device = next(model.parameters()).device
    transform = _model_cache.get("transform") or get_transform()

    augmented = transform(image=image)
    image_tensor = augmented['image'].unsqueeze(0).to(device)

    start_time = time.perf_counter()
    outputs = model(image_tensor)
    # The blocking copy to CPU waits for any asynchronous CUDA kernels to
    # finish, so the timer covers the real compute, not just kernel launch.
    probs = F.softmax(outputs, dim=1).cpu()
    inference_time = (time.perf_counter() - start_time) * 1000

    confidence, predicted = torch.max(probs, 1)
    return {
        'prediction': 'DEFECTIVE' if predicted.item() == 1 else 'NORMAL',
        'confidence': confidence.item(),
        'defect_probability': probs[0][1].item(),
        'normal_probability': probs[0][0].item(),
        'inference_time': inference_time
    }
# Theme CSS injected into the Streamlit page (dark Rocket Lab style, pinned sidebar).
_STREAMLIT_CSS = """
<style>
/* Force sidebar to always be visible and expanded */
section[data-testid="stSidebar"] {
background: #0f0f0f !important;
border-right: 2px solid #333;
width: 21rem !important;
min-width: 21rem !important;
max-width: 21rem !important;
display: block !important;
position: relative !important;
left: 0 !important;
visibility: visible !important;
opacity: 1 !important;
transform: none !important;
}
/* Hide sidebar collapse button completely */
[data-testid="collapsedControl"] {
display: none !important;
}
button[kind="header"] {
display: none !important;
}
/* Hide the hamburger menu button */
[data-testid="baseButton-header"] {
display: none !important;
}
/* Ensure sidebar content is always visible */
section[data-testid="stSidebar"] > div {
display: block !important;
visibility: visible !important;
opacity: 1 !important;
}
section[data-testid="stSidebar"] > div:first-child {
padding-top: 2rem;
}
/* Clean dark background */
.stApp {
background: #0a0a0a;
}
/* Hide the link icon buttons next to headers */
[data-testid="StyledLinkIconContainer"] {
display: none !important;
}
/* Hide anchor links in headers */
.stMarkdown h1 a, .stMarkdown h2 a, .stMarkdown h3 a {
display: none !important;
}
/* Hide buttons that appear on hover for headers */
.element-container:has(.stMarkdown h1, .stMarkdown h2, .stMarkdown h3) button[kind="headerLink"] {
display: none !important;
}
/* Hide all header link anchors */
[data-testid="stHeaderActionElements"] {
display: none !important;
}
/* Hide copy buttons and link buttons */
.stMarkdown [data-testid="stCopyButton"],
.stMarkdown button[title*="link"] {
display: none !important;
}
/* Main header - Professional Rocket Lab style */
.main-header {
padding: 15px 0 20px 0;
border-bottom: 3px solid #dc2626;
margin-bottom: 30px;
background: linear-gradient(135deg, #1a1a1a 0%, #0f0f0f 100%);
margin-top: -3.5rem; /* Position header higher */
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.1);
}
.main-header h1 {
color: #ffffff;
font-size: 1.9rem;
letter-spacing: 2px;
margin: 0;
font-weight: 700; /* Bold font weight */
text-align: left;
padding-left: 30px;
text-transform: uppercase;
}
.main-header .rocket-white {
color: #ffffff;
font-weight: 700; /* Bold for ROCKET LAB */
font-size: inherit;
letter-spacing: inherit;
}
.main-header .rocket-red {
color: #dc2626;
font-weight: 800; /* Extra bold for emphasis */
font-size: inherit;
letter-spacing: inherit;
}
.subtitle {
color: #aaa;
font-size: 0.8rem;
letter-spacing: 1.5px;
margin-top: 10px;
padding-left: 30px;
font-weight: 400;
text-transform: uppercase;
}
/* Sidebar styling */
.sidebar-header {
font-size: 1.1rem;
font-weight: 600;
color: #dc2626;
letter-spacing: 1.5px;
text-transform: uppercase;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 2px solid #333;
}
/* Sidebar sample images - professional and compact */
.sidebar-sample {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 6px;
padding: 8px;
margin-bottom: 10px;
cursor: pointer;
transition: all 0.3s;
}
.sidebar-sample:hover {
border-color: #dc2626;
background: #1f1f1f;
transform: translateX(3px);
}
.sample-label {
color: #bbb;
font-size: 0.75rem;
text-align: center;
margin-top: 8px;
margin-bottom: 5px;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: 500;
}
/* Instruction button */
.instruction-btn {
background: transparent;
border: 1px solid #dc2626;
color: #dc2626;
padding: 8px 16px;
font-size: 0.85rem;
letter-spacing: 1px;
text-transform: uppercase;
border-radius: 4px;
cursor: pointer;
transition: all 0.3s;
margin-bottom: 15px;
}
.instruction-btn:hover {
background: #dc2626;
color: white;
}
/* Result box */
.result-box {
background: #1a1a1a;
border-radius: 8px;
padding: 30px;
margin: 20px 0;
text-align: center;
}
.result-pass {
border: 2px solid #10b981;
}
.result-fail {
border: 2px solid #dc2626;
}
.result-title {
font-size: 1.8rem;
margin: 0;
font-weight: 300;
}
.result-confidence {
font-size: 2.5rem;
margin: 15px 0;
font-weight: bold;
}
/* Metrics row */
.metrics-row {
display: flex;
justify-content: center;
gap: 40px;
margin: 20px 0;
}
.metric {
text-align: center;
}
.metric-label {
color: #888;
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 1px;
}
.metric-value {
color: #ffffff;
font-size: 1.2rem;
font-weight: bold;
margin-top: 5px;
}
/* Upload area - more subtle and professional */
.upload-section {
background: #141414;
border: 1px solid #2a2a2a;
border-radius: 8px;
padding: 20px;
text-align: center;
margin: 15px 0;
}
/* Upload section header - smaller and professional */
.section-header {
font-size: 0.95rem;
font-weight: 600;
color: #ffffff;
text-transform: uppercase;
letter-spacing: 1.5px;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 1px solid #333;
}
/* Buttons - professional style */
.stButton > button {
background: linear-gradient(135deg, #dc2626, #b91c1c);
color: white;
border: none;
padding: 10px 24px;
font-size: 0.85rem;
font-weight: 600;
letter-spacing: 1.2px;
text-transform: uppercase;
border-radius: 4px;
width: 100%;
transition: all 0.3s;
box-shadow: 0 2px 8px rgba(220, 38, 38, 0.2);
}
.stButton > button:hover {
background: linear-gradient(135deg, #ef4444, #dc2626);
box-shadow: 0 4px 12px rgba(220, 38, 38, 0.3);
transform: translateY(-1px);
}
/* Hide Streamlit branding */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* Clean file uploader - professional styling */
[data-testid="stFileUploader"] {
background: transparent;
border: none;
}
[data-testid="stFileUploader"] label {
font-size: 0.85rem !important;
font-weight: 500 !important;
color: #999 !important;
text-transform: uppercase;
letter-spacing: 1px;
}
.uploadedFile {
background: #1a1a1a;
border: 1px solid #333;
border-radius: 4px;
padding: 8px;
}
/* Text colors and typography */
p, span, div {
color: #ffffff;
}
label {
color: #bbb;
font-weight: 500;
}
/* Streamlit section headers */
.stMarkdown h3 {
font-size: 0.95rem !important;
font-weight: 600 !important;
color: #ffffff !important;
text-transform: uppercase;
letter-spacing: 1.5px;
margin-bottom: 15px !important;
padding-bottom: 10px;
border-bottom: 1px solid #333;
}
/* Progress bars minimal */
.stProgress > div > div > div > div {
background: #dc2626;
height: 4px;
}
/* Status badge */
.status-badge {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
letter-spacing: 1px;
text-transform: uppercase;
font-weight: bold;
}
.status-pass {
background: #10b981;
color: white;
}
.status-fail {
background: #dc2626;
color: white;
}
</style>
"""


def _render_sample_sidebar(st) -> None:
    """Render the sidebar of bundled sample images, each with a "Load" button.

    Shows up to three images from ``examples/``: the first ``normal`` image,
    the first ``defective`` image and the last ``normal`` image. Clicking
    "Load" stores the path in ``st.session_state`` and reruns the script.
    """
    with st.sidebar:
        st.markdown("""
<div class="sidebar-header">Test Samples</div>
<div style='color: #999; font-size: 0.75rem; margin-bottom: 20px; text-transform: uppercase; letter-spacing: 1px;'>
Click to load sample image
</div>
""", unsafe_allow_html=True)

        examples_dir = Path("examples")
        sample_images = []
        if examples_dir.exists():
            normal_samples = sorted((examples_dir / "normal").glob("*.png"))
            defect_samples = sorted((examples_dir / "defective").glob("*.png"))
            # Pick a small, varied set of samples.
            if len(normal_samples) >= 1:
                sample_images.append(normal_samples[0])  # First normal
            if len(defect_samples) >= 1:
                sample_images.append(defect_samples[0])  # Defective
            if len(normal_samples) >= 2:
                sample_images.append(normal_samples[-1])  # Last normal

        for idx, sample_path in enumerate(sample_images):
            # Close the source file right away; resize() returns a new image.
            with Image.open(sample_path) as img:
                img_thumbnail = img.resize((120, 120), Image.Resampling.LANCZOS)
            # Labels are deliberately neutral so samples don't reveal the answer.
            label = f"Sample {idx + 1}"

            col1, col2 = st.columns([1, 2])
            with col1:
                st.image(img_thumbnail, use_container_width=True)
            with col2:
                st.markdown(f"<div class='sample-label'>{label}</div>", unsafe_allow_html=True)
                if st.button("Load", key=f"sample_{idx}", use_container_width=True):
                    st.session_state['selected_image'] = str(sample_path)
                    st.session_state['image_source'] = 'sample'
                    st.rerun()


def _render_analysis(st, image, model) -> None:
    """Run inference on a PIL ``image`` and render the verdict card and metrics."""
    image_np = np.array(image.convert('RGB'))

    st.markdown("""
<div class="section-header">Analysis Results</div>
""", unsafe_allow_html=True)

    col1, col2 = st.columns([1, 2])
    with col1:
        st.image(image, use_container_width=True)
    with col2:
        with st.spinner("Analyzing component..."):
            result = predict_image(image_np, model)

        st.markdown("""
<div style="font-size: 0.9rem; font-weight: 600; color: #999; text-transform: uppercase; letter-spacing: 1px; margin-bottom: 15px;">Quality Assessment</div>
""", unsafe_allow_html=True)

        if result['prediction'] == 'DEFECTIVE':
            st.markdown("""
<div class="result-box result-fail">
<div class="status-badge status-fail">DEFECT DETECTED</div>
<div style="margin-top: 20px; color: #dc2626; font-size: 1.2rem;">Component Failed Quality Check</div>
</div>
""", unsafe_allow_html=True)
        else:
            st.markdown("""
<div class="result-box result-pass">
<div class="status-badge status-pass">PASSED</div>
<div style="margin-top: 20px; color: #10b981; font-size: 1.2rem;">Component Passed Quality Check</div>
</div>
""", unsafe_allow_html=True)

        st.markdown("""
<div class="metrics-row">
<div class="metric">
<div class="metric-label">Confidence</div>
<div class="metric-value">{:.1f}%</div>
</div>
<div class="metric">
<div class="metric-label">Processing Time</div>
<div class="metric-value">{:.0f}ms</div>
</div>
</div>
""".format(result['confidence'] * 100, result['inference_time']), unsafe_allow_html=True)


def run_streamlit():
    """Run the Streamlit web interface (Rocket Lab dark theme).

    Outside Streamlit's script runner (e.g. ``python app.py --mode web``) this
    relaunches the file with ``streamlit run`` in a subprocess and returns when
    that process exits. Inside the runner it renders the page: theme CSS,
    header, sample sidebar, uploader and analysis results.
    """
    import subprocess

    if "streamlit.runtime.scriptrunner" not in sys.modules:
        print("🚀 Starting Rocket Lab Defect Detection System...")
        print(" Opening browser to http://localhost:8501")
        # Use the current interpreter so the same environment runs Streamlit,
        # even if the `streamlit` executable is not on PATH.
        subprocess.run([sys.executable, "-m", "streamlit", "run", __file__])
        return

    import streamlit as st

    st.set_page_config(
        page_title="RKLB Defect Detection",
        page_icon="🚀",
        layout="wide",
        initial_sidebar_state="expanded"  # Make sure sidebar is expanded
    )
    st.markdown(_STREAMLIT_CSS, unsafe_allow_html=True)

    # Streamlit re-executes this script on every interaction, so module
    # globals like _model_cache start empty each run. cache_resource keeps a
    # single loaded model for the whole server process instead of reloading
    # the checkpoint on every click. Failures are not cached, so a failed
    # load is retried on the next rerun.
    @st.cache_resource
    def _load_cached_model():
        return load_model()

    try:
        model, _, _ = _load_cached_model()
    except Exception as e:
        st.error(f"Model Error: {e}")
        st.stop()

    st.markdown("""
<div class="main-header">
<h1><span class="rocket-white">ROCKET LAB</span> <span class="rocket-red">COMPONENT DEFECT DETECTION</span></h1>
<div class="subtitle">Made by Gary Phua</div>
</div>
""", unsafe_allow_html=True)

    _render_sample_sidebar(st)

    with st.container():
        st.markdown("""
<div class="section-header">Upload Component Image</div>
""", unsafe_allow_html=True)
        uploaded_file = st.file_uploader(
            "Select image file (PNG, JPG, JPEG, BMP)",
            type=['png', 'jpg', 'jpeg', 'bmp'],
            label_visibility="visible"
        )

        # An uploaded file takes priority over a previously loaded sample.
        image = None
        if uploaded_file:
            image = Image.open(uploaded_file)
            st.session_state['image_source'] = 'upload'
            st.session_state['selected_image'] = None
        elif st.session_state.get('selected_image'):
            image = Image.open(st.session_state['selected_image'])

        if image is not None:
            _render_analysis(st, image, model)
        else:
            st.markdown("""
<div style="text-align: center; padding: 80px 40px; background: #141414; border: 1px solid #2a2a2a; border-radius: 8px; margin-top: 40px;">
<div style="font-size: 1.1rem; margin-bottom: 15px; color: #999; font-weight: 600; text-transform: uppercase; letter-spacing: 1.5px;">Ready for Analysis</div>
<div style="font-size: 0.85rem; color: #666; line-height: 1.6;">Upload a component image or select a sample from the sidebar to begin quality inspection</div>
</div>
""", unsafe_allow_html=True)
def run_api():
    """Create the FastAPI application and serve it.

    Routes:
        GET  /           - service description and endpoint list
        GET  /health     - status plus whether the model is loaded
        POST /predict    - JSON ``{"image": <base64>}`` -> ``predict_image`` dict
        GET  /interface  - minimal HTML page that posts an image to /predict

    Returns:
        The FastAPI ``app`` when ``VERCEL`` is set (the platform serves it);
        otherwise runs uvicorn on 0.0.0.0:8000 and returns None when it stops.
    """
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import HTMLResponse
    from pydantic import BaseModel
    import uvicorn

    app = FastAPI(
        title="Defect Detection API",
        version="1.0.0"
    )
    # NOTE(review): a wildcard origin with allow_credentials=True lets any
    # site make credentialed requests; the API uses no cookies/auth, so
    # allow_credentials=False would be safer - confirm no client needs it.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    class PredictionRequest(BaseModel):
        image: str  # base64 encoded; a "data:<mime>;base64," prefix is accepted

    @app.on_event("startup")
    async def startup():
        # On failure the server still starts; /predict then answers 503.
        try:
            load_model()
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"❌ Model loading failed: {e}")

    @app.get("/")
    async def root():
        return {
            "message": "Defect Detection API",
            "endpoints": {
                "health": "/health",
                "predict": "/predict",
                "interface": "/interface"
            }
        }

    @app.get("/health")
    async def health():
        return {"status": "healthy", "model_loaded": _model_cache["model"] is not None}

    # Plain ``def`` on purpose: FastAPI runs sync handlers in a worker thread,
    # so the blocking decode + inference does not stall the event loop (and
    # every other request) the way an ``async def`` handler would.
    @app.post("/predict")
    def predict(request: PredictionRequest):
        model = _model_cache["model"]
        if model is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        payload = request.image
        if payload.startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]

        # Bad base64 (binascii.Error is a ValueError) or unreadable image data
        # (PIL raises OSError subclasses) is a client error, not a server one.
        try:
            image_bytes = base64.b64decode(payload)
            image = Image.open(io.BytesIO(image_bytes))
            image_np = np.array(image.convert('RGB'))
        except (ValueError, OSError, Image.DecompressionBombError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

        try:
            return predict_image(image_np, model)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/interface")
    async def interface():
        html = """
<!DOCTYPE html>
<html>
<head>
<title>RKLB Defect Detection</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0a0a0a;
color: white;
min-height: 100vh;
padding: 40px 20px;
}
.container {
max-width: 900px;
margin: 0 auto;
}
h1 {
text-align: center;
font-size: 1.5rem;
font-weight: 300;
letter-spacing: 3px;
margin-bottom: 10px;
padding-bottom: 20px;
border-bottom: 2px solid #dc2626;
}
.subtitle {
text-align: center;
color: #999;
font-size: 0.7rem;
letter-spacing: 1px;
font-style: italic;
margin-bottom: 40px;
}
.upload-area {
border: 2px dashed #333;
padding: 40px;
text-align: center;
background: #1a1a1a;
border-radius: 8px;
margin: 30px 0;
}
.result {
margin: 30px 0;
padding: 30px;
border-radius: 8px;
background: #1a1a1a;
text-align: center;
}
.result-pass { border: 2px solid #10b981; }
.result-fail { border: 2px solid #dc2626; }
button {
background: #dc2626;
color: white;
padding: 10px 30px;
border: none;
border-radius: 4px;
cursor: pointer;
text-transform: uppercase;
letter-spacing: 1px;
font-size: 0.9rem;
}
button:hover { background: #b91c1c; }
#preview img {
max-width: 400px;
max-height: 400px;
margin: 20px auto;
display: block;
border: 1px solid #333;
border-radius: 8px;
}
.confidence {
font-size: 2.5rem;
font-weight: bold;
margin: 20px 0;
}
.status {
display: inline-block;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.8rem;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: bold;
margin-bottom: 10px;
}
.status-pass { background: #10b981; }
.status-fail { background: #dc2626; }
</style>
</head>
<body>
<div class="container">
<h1>ROCKET LAB <span style="color: #dc2626;">COMPONENT DEFECT DETECTION SYSTEM</span></h1>
<p class="subtitle">Made by Gary Phua</p>
<div class="upload-area">
<input type="file" id="imageInput" accept="image/*" style="margin-bottom: 20px;">
<br>
<button onclick="analyze()">Analyze Component</button>
</div>
<div id="preview"></div>
<div id="result"></div>
</div>
<script>
function analyze() {
const input = document.getElementById('imageInput');
const file = input.files[0];
if (!file) return alert('Select an image');
const reader = new FileReader();
reader.onload = e => {
document.getElementById('preview').innerHTML = '<img src="' + e.target.result + '">';
const base64 = e.target.result.split(',')[1];
document.getElementById('result').innerHTML = '<div class="result">Analyzing...</div>';
fetch('/predict', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({image: base64})
})
.then(r => r.json().catch(() => ({})).then(data => {
if (!r.ok) {
throw new Error(typeof data.detail === 'string' ? data.detail : 'HTTP ' + r.status);
}
return data;
}))
.then(data => {
const passClass = data.prediction === 'DEFECTIVE' ? 'result-fail' : 'result-pass';
const statusClass = data.prediction === 'DEFECTIVE' ? 'status-fail' : 'status-pass';
const statusText = data.prediction === 'DEFECTIVE' ? 'DEFECT DETECTED' : 'PASSED';
document.getElementById('result').innerHTML =
'<div class="result ' + passClass + '">' +
'<div class="status ' + statusClass + '">' + statusText + '</div>' +
'<div class="confidence">' + (data.confidence * 100).toFixed(1) + '%</div>' +
'<div style="color: #888;">CONFIDENCE</div>' +
'<div style="margin-top: 20px; color: #888; font-size: 0.9rem;">' +
'Time: ' + data.inference_time.toFixed(0) + 'ms</div>' +
'</div>';
})
.catch(err => {
const box = document.createElement('div');
box.className = 'result result-fail';
box.textContent = 'Analysis failed: ' + err.message;
const target = document.getElementById('result');
target.innerHTML = '';
target.appendChild(box);
});
};
reader.readAsDataURL(file);
}
</script>
</body>
</html>
"""
        return HTMLResponse(content=html)

    # On Vercel the platform serves the app object; don't start uvicorn.
    if os.environ.get('VERCEL'):
        return app

    uvicorn.run(app, host="0.0.0.0", port=8000)
def run_cli(args):
    """Run the command-line interface.

    ``args`` is the namespace built by ``main``: ``model`` (optional checkpoint
    path), ``image`` (one file), ``directory`` (searched recursively for
    .png/.jpg/.jpeg/.bmp files) and ``output`` (optional JSON path for the
    collected results). ``image`` takes precedence over ``directory``.
    """
    model, _, info = load_model(args.model)
    print(f"✅ Model loaded: {info['model_type']} (Acc: {info['accuracy']:.1f}%)")

    if args.image:
        # OpenCV loads BGR; the model pipeline expects RGB.
        image = cv2.imread(args.image)
        if image is None:
            print(f"❌ Cannot load image: {args.image}")
            return
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        result = predict_image(image, model)
        print(f"\n📷 Image: {args.image}")
        print(f"🎯 Prediction: {result['prediction']}")
        print(f"📊 Confidence: {result['confidence']:.2%}")
        print(f"⏱️ Inference: {result['inference_time']:.1f}ms")
        results = [{**result, 'path': args.image}]
    elif args.directory:
        image_suffixes = {'.png', '.jpg', '.jpeg', '.bmp'}
        results = []
        # Sorted so console output and the saved JSON are deterministic.
        for img_path in sorted(Path(args.directory).glob("**/*")):
            if img_path.suffix.lower() not in image_suffixes:
                continue
            image = cv2.imread(str(img_path))
            if image is None:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            result = predict_image(image, model)
            result['path'] = str(img_path)
            results.append(result)
            marker = '🔴' if result['prediction'] == 'DEFECTIVE' else '🟢'
            print(f"{marker} {img_path.name}: {result['prediction']} ({result['confidence']:.1%})")

        # Guard the summary: with no images the percentage would divide by zero.
        if results:
            defective = sum(1 for r in results if r['prediction'] == 'DEFECTIVE')
            print(f"\n📊 Results: {defective}/{len(results)} defective ({defective/len(results)*100:.1f}%)")
        else:
            print(f"\n⚠️ No images found in {args.directory}")
    else:
        print("⚠️ Nothing to predict: pass --image or --directory")
        return

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"💾 Saved to {args.output}")
def main():
    """Parse command-line flags and launch the chosen mode.

    Under Streamlit's script runner, argument parsing is skipped and the web
    UI is rendered directly.
    """
    launched_by_streamlit = "streamlit.runtime.scriptrunner" in sys.modules
    if launched_by_streamlit:
        run_streamlit()
        return

    parser = argparse.ArgumentParser(description='Defect Detection Application')
    parser.add_argument('--mode', choices=['web', 'api', 'cli'], default='web',
                        help='Run mode: web (Streamlit), api (FastAPI), or cli')
    optional_flags = (
        ('--model', 'Model path'),
        ('--image', 'Single image path (CLI mode)'),
        ('--directory', 'Directory of images (CLI mode)'),
        ('--output', 'Save results to JSON (CLI mode)'),
    )
    for flag, help_text in optional_flags:
        parser.add_argument(flag, type=str, help=help_text)
    args = parser.parse_args()

    # `choices` restricts mode to web/api/cli, so the fallback branch is web.
    mode = args.mode
    if mode == 'api':
        run_api()
    elif mode == 'cli':
        run_cli(args)
    else:
        run_streamlit()
# Module-level ``app`` for serverless (Vercel) deployments: with VERCEL set,
# run_api() returns the FastAPI application instead of starting uvicorn.
app = None
if os.environ.get('VERCEL'):
    app = run_api()

if __name__ == "__main__":
    main()