Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Camie-Tagger-V2 Application | |
A Streamlit web app for tagging images using an AI model. | |
""" | |
import streamlit as st | |
import os | |
import sys | |
import traceback | |
import tempfile | |
import time | |
import platform | |
import subprocess | |
import webbrowser | |
import glob | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import io | |
import base64 | |
import json | |
from matplotlib.colors import LinearSegmentedColormap | |
from PIL import Image | |
from pathlib import Path | |
from huggingface_hub import hf_hub_download | |
# Add parent directory to path to allow importing from utils - updated for new structure | |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
# Import utilities | |
from utils.image_processing import process_image, batch_process_images | |
from utils.file_utils import save_tags_to_file, get_default_save_locations | |
from utils.ui_components import display_progress_bar, show_example_images, display_batch_results | |
from utils.onnx_processing import batch_process_images_onnx | |
# Point library cache/config directories at /tmp (HF Spaces permissions) and
# set the Streamlit server options used by this deployment.
# NOTE(review): these values are written after streamlit, matplotlib and
# huggingface_hub have already been imported above; a library that reads its
# variable at import time may not pick the value up -- confirm the ordering.
os.environ.update({
    # Writable cache/config locations
    'MPLCONFIGDIR': '/tmp/matplotlib',
    'HF_HOME': '/tmp/huggingface',
    'TRANSFORMERS_CACHE': '/tmp/transformers',
    # Streamlit permission-related settings
    'STREAMLIT_SERVER_HEADLESS': 'true',
    'STREAMLIT_SERVER_ENABLE_CORS': 'false',
    'STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION': 'false',
    'STREAMLIT_BROWSER_GATHER_USAGE_STATS': 'false',
    'STREAMLIT_GLOBAL_DEVELOPMENT_MODE': 'false',
})
# Hugging Face Hub repository and the names of the files fetched from it.
# The model artifacts share one stem, so their names are derived from it.
_MODEL_STEM = "camie-tagger-v2"
MODEL_REPO = "Camais03/camie-tagger-v2"
ONNX_MODEL_FILE = f"{_MODEL_STEM}.onnx"
SAFETENSORS_MODEL_FILE = f"{_MODEL_STEM}.safetensors"
METADATA_FILE = f"{_MODEL_STEM}-metadata.json"
VALIDATION_FILE = "full_validation_results.json"
def get_model_files():
    """Download the model files from the HF Hub and return their local paths.

    The metadata and ONNX files are required; the SafeTensors weights and the
    validation results are optional and reported as ``None`` when they cannot
    be fetched.

    Returns:
        A dict with the keys ``'onnx_path'``, ``'safetensors_path'``,
        ``'metadata_path'`` and ``'validation_path'``, or ``None`` when a
        required file could not be downloaded.
    """
    try:
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)

        # Download metadata first (small file).
        # The deprecated ``resume_download`` flag is intentionally not passed:
        # recent huggingface_hub releases resume interrupted downloads on
        # their own, and newer ones may reject the argument outright.
        metadata_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=METADATA_FILE,
            cache_dir=cache_dir,
        )

        # Large ONNX file: Hub client first, plain HTTP download as fallback
        try:
            onnx_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=ONNX_MODEL_FILE,
                cache_dir=cache_dir,
                force_download=False,  # Use cached version if available
            )
        except Exception as e:
            print(f"ONNX download failed: {e}")
            # Fallback: try direct URL download with requests
            import requests

            onnx_url = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{ONNX_MODEL_FILE}"
            onnx_path = os.path.join(cache_dir, ONNX_MODEL_FILE)
            # Stream into a side file and rename only on success, so a failed
            # transfer never leaves a truncated model at the final path.
            partial_path = onnx_path + ".part"
            print(f"Trying direct download from: {onnx_url}")
            try:
                # (connect, read) timeouts keep a stalled connection from
                # hanging the app forever; the context manager releases the
                # connection even when the transfer fails.
                with requests.get(onnx_url, stream=True, timeout=(10, 60)) as response:
                    response.raise_for_status()
                    with open(partial_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)
                os.replace(partial_path, onnx_path)
            finally:
                # Only present here if the download or rename failed
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            print(f"Direct download successful: {onnx_path}")

        # Optional files: absence is reported, not treated as a failure
        try:
            safetensors_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=SAFETENSORS_MODEL_FILE,
                cache_dir=cache_dir,
            )
        except Exception as e:
            print(f"SafeTensors model not available: {e}")
            safetensors_path = None

        try:
            validation_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=VALIDATION_FILE,
                cache_dir=cache_dir,
            )
        except Exception as e:
            print(f"Validation results not available: {e}")
            validation_path = None

        return {
            'onnx_path': onnx_path,
            'safetensors_path': safetensors_path,
            'metadata_path': metadata_path,
            'validation_path': validation_path,
        }
    except Exception as e:
        # Top-level boundary: callers treat ``None`` as "download failed"
        print(f"Failed to download model files: {e}")
        return None
# One-line summary per threshold profile, keyed by the profile's UI name.
threshold_profile_descriptions = {
    "Micro Optimized": (
        "Maximizes micro-averaged F1 score (best for dominant classes). "
        "Optimal for overall prediction quality."
    ),
    "Macro Optimized": (
        "Maximizes macro-averaged F1 score (equal weight to all classes). "
        "Better for balanced performance across all tags."
    ),
    "Balanced": (
        "Provides a trade-off between precision and recall with moderate thresholds. "
        "Good general-purpose setting."
    ),
    "Overall": (
        "Uses a single threshold value across all categories. "
        "Simplest approach for consistent behavior."
    ),
    "Category-specific": (
        "Uses different optimal thresholds for each category. "
        "Best for fine-tuning results."
    ),
}
# Long-form Markdown help text per threshold profile, keyed by the profile's
# UI name. The F1 and threshold figures quoted in the text are hard-coded in
# these strings; they are not read from the validation results file.
threshold_profile_explanations = {
    "Micro Optimized": """
### Micro Optimized Profile
**Technical definition**: Maximizes micro-averaged F1 score, which calculates metrics globally across all predictions.
**When to use**: When you want the best overall accuracy, especially for common tags and dominant categories.
**Effects**:
- Optimizes performance for the most frequent tags
- Gives more weight to categories with many examples (like 'character' and 'general')
- Provides higher precision in most common use cases
**Performance from validation**:
- Micro F1: ~67.3%
- Macro F1: ~46.3%
- Threshold: ~0.614
""",
    "Macro Optimized": """
### Macro Optimized Profile
**Technical definition**: Maximizes macro-averaged F1 score, which gives equal weight to all categories regardless of size.
**When to use**: When balanced performance across all categories is important, including rare tags.
**Effects**:
- More balanced performance across all tag categories
- Better at detecting rare or unusual tags
- Generally has lower thresholds than micro-optimized
**Performance from validation**:
- Micro F1: ~60.9%
- Macro F1: ~50.6%
- Threshold: ~0.492
""",
    "Balanced": """
### Balanced Profile
**Technical definition**: Same as Micro Optimized but provides a good reference point for manual adjustment.
**When to use**: For general-purpose tagging when you don't have specific recall or precision requirements.
**Effects**:
- Good middle ground between precision and recall
- Works well for most common use cases
- Default choice for most users
**Performance from validation**:
- Micro F1: ~67.3%
- Macro F1: ~46.3%
- Threshold: ~0.614
""",
    "Overall": """
### Overall Profile
**Technical definition**: Uses a single threshold value across all categories.
**When to use**: When you want consistent behavior across all categories and a simple approach.
**Effects**:
- Consistent tagging threshold for all categories
- Simpler to understand than category-specific thresholds
- User-adjustable with a single slider
**Default threshold value**: 0.5 (user-adjustable)
**Note**: The threshold value is user-adjustable with the slider below.
""",
    "Category-specific": """
### Category-specific Profile
**Technical definition**: Uses different optimal thresholds for each category, allowing fine-tuning.
**When to use**: When you want to customize tagging sensitivity for different categories.
**Effects**:
- Each category has its own independent threshold
- Full control over category sensitivity
- Best for fine-tuning results when some categories need different treatment
**Default threshold values**: Starts with balanced thresholds for each category
**Note**: Use the category sliders below to adjust thresholds for individual categories.
"""
}
def load_validation_results(results_path):
    """Load validation results from a JSON file.

    Args:
        results_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content, or ``None`` when the file cannot be opened
        or does not contain valid JSON (the error is printed).
    """
    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale
        with open(results_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError, TypeError) as e:
        # OSError: missing/unreadable file; ValueError: malformed JSON or bad
        # encoding; TypeError: a non-path argument such as None.
        print(f"Error loading validation results: {e}")
        return None
def extract_thresholds_from_results(validation_data):
    """Build the threshold lookup table from decoded validation results.

    Each entry of ``validation_data['results']`` is read through its
    ``CATEGORY``, ``PROFILE``, ``THRESHOLD``, ``MICRO-F1`` and ``MACRO-F1``
    keys. Category names are lower-cased; profile names are lower-cased with
    spaces turned into underscores, and the short forms ``micro_opt`` /
    ``macro_opt`` are expanded.

    Returns:
        ``{}`` when ``validation_data`` is empty or has no ``'results'`` key,
        otherwise ``{'overall': {profile: info}, 'categories': {category:
        {profile: info}}}`` where ``info`` holds ``threshold``, ``micro_f1``
        and ``macro_f1``.
    """
    if not validation_data or 'results' not in validation_data:
        return {}

    # Short profile names used in the results file -> internal keys
    long_names = {'micro_opt': 'micro_optimized', 'macro_opt': 'macro_optimized'}

    overall = {}
    by_category = {}
    for row in validation_data['results']:
        category = row['CATEGORY'].lower()
        raw_profile = row['PROFILE'].lower().replace(' ', '_')
        info = {
            'threshold': row['THRESHOLD'],
            'micro_f1': row['MICRO-F1'],
            'macro_f1': row['MACRO-F1'],
        }
        profile = long_names.get(raw_profile, raw_profile)

        # 'overall' rows go to the top-level table, everything else is
        # grouped under its own category.
        if category == 'overall':
            bucket = overall
        else:
            bucket = by_category.setdefault(category, {})
        bucket[profile] = info

    return {'overall': overall, 'categories': by_category}
def load_model_and_metadata():
    """Download the model files and load metadata and thresholds.

    Side effect: stores the downloaded file paths in
    ``st.session_state.model_files`` for later use.

    Returns:
        A ``(model_info, metadata, thresholds)`` tuple. ``model_info`` flags
        which files are available, ``metadata`` is the decoded metadata JSON
        (``None`` if it could not be read) and ``thresholds`` is the table
        from the validation results, with built-in overall defaults when none
        were found. ``(None, None, {})`` is returned when the download fails.
    """
    model_files = get_model_files()
    if not model_files:
        return None, None, {}

    model_info = {
        'safetensors_available': model_files['safetensors_path'] is not None,
        'onnx_available': model_files['onnx_path'] is not None,
        'validation_results_available': model_files['validation_path'] is not None
    }

    # Load metadata (best effort: a read failure leaves metadata as None)
    metadata = None
    if model_files['metadata_path']:
        try:
            with open(model_files['metadata_path'], 'r', encoding='utf-8') as f:
                metadata = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error loading metadata: {e}")

    # Load validation results for thresholds
    thresholds = {}
    if model_files['validation_path']:
        validation_data = load_validation_results(model_files['validation_path'])
        if validation_data:
            thresholds = extract_thresholds_from_results(validation_data)

    # Fall back to default overall profiles when validation supplied none.
    # Checking the 'overall' table (not just the outer dict) matters: a
    # results file without overall rows yields a non-empty dict whose
    # 'overall' entry is empty.
    if not thresholds or not thresholds.get('overall'):
        thresholds = {
            'overall': {
                'balanced': {'threshold': 0.5, 'micro_f1': 0, 'macro_f1': 0},
                'micro_optimized': {'threshold': 0.6, 'micro_f1': 0, 'macro_f1': 0},
                'macro_optimized': {'threshold': 0.4, 'micro_f1': 0, 'macro_f1': 0}
            },
            # Keep any per-category thresholds that were extracted
            'categories': (thresholds.get('categories') or {}) if thresholds else {}
        }

    # Store file paths in session state for later use
    st.session_state.model_files = model_files
    return model_info, metadata, thresholds
def load_safetensors_model(safetensors_path, metadata_path):
    """Rebuild the ImageTagger model and load its SafeTensors weights.

    Args:
        safetensors_path: Path of the ``.safetensors`` weights file.
        metadata_path: Path of the metadata JSON; its ``model_info`` and
            ``dataset_info`` sections supply the constructor arguments.

    Returns:
        A ``(model, metadata)`` tuple with the model switched to eval mode.

    Raises:
        RuntimeError: If anything fails (missing dependency, unreadable file,
            mismatched weights). The original error is chained as the cause.
    """
    try:
        # Heavy optional dependencies are imported lazily so the app can
        # start without them when only the ONNX model is used.
        from safetensors.torch import load_file
        import torch

        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        # Project-local model definition
        from utils.model_loader import ImageTagger

        model_info = metadata['model_info']
        dataset_info = metadata['dataset_info']

        # Recreate the architecture described by the metadata
        model = ImageTagger(
            total_tags=dataset_info['total_tags'],
            dataset=None,
            model_name=model_info['backbone'],
            num_heads=model_info['num_attention_heads'],
            dropout=0.0,
            pretrained=False,
            tag_context_size=model_info['tag_context_size'],
            use_gradient_checkpointing=False,
            img_size=model_info['img_size']
        )

        # Load weights
        state_dict = load_file(safetensors_path)
        model.load_state_dict(state_dict)
        model.eval()
        return model, metadata
    except Exception as e:
        # Chain the cause so the full traceback survives the re-raise
        raise RuntimeError(f"Failed to load SafeTensors model: {e}") from e
def get_profile_metrics(thresholds, profile_name):
    """Return the overall metrics entry for a UI threshold profile.

    Args:
        thresholds: Threshold table with an ``'overall'`` mapping of profile
            key to metrics; may be ``None`` or empty.
        profile_name: UI name of the profile, e.g. ``"Macro Optimized"``.

    Returns:
        The entry stored under ``thresholds['overall']`` for the profile, or
        ``None`` when the profile name is unknown, the table is missing, or
        it has no entry for that profile.
    """
    # Map UI-friendly names to internal keys. "Overall" and
    # "Category-specific" have no entry of their own and use macro as default.
    profile_keys = {
        "Micro Optimized": "micro_optimized",
        "Macro Optimized": "macro_optimized",
        "Balanced": "balanced",
        "Overall": "macro_optimized",
        "Category-specific": "macro_optimized",
    }
    profile_key = profile_keys.get(profile_name)

    # Guard against a table that has not been loaded yet (None) or is empty
    if profile_key is None or not thresholds:
        return None
    if 'overall' in thresholds and profile_key in thresholds['overall']:
        return thresholds['overall'][profile_key]
    return None
def on_threshold_profile_change():
    """React to a new selection in ``st.session_state.threshold_profile``.

    Discards previously stored tagging results, then updates
    ``settings['active_threshold']`` and
    ``settings['active_category_thresholds']`` for the chosen profile. The
    settings are only touched when both ``thresholds`` and ``settings`` exist
    in the session state.
    """
    state = st.session_state
    chosen_profile = state.threshold_profile

    # Clear any existing results to prevent UI duplication
    for stale_key in ('all_probs', 'tags', 'all_tags'):
        if hasattr(state, stale_key):
            delattr(state, stale_key)

    if not (hasattr(state, 'thresholds') and hasattr(state, 'settings')):
        return

    settings = state.settings
    validated = state.thresholds

    # Initialize category thresholds if needed
    if settings['active_category_thresholds'] is None:
        settings['active_category_thresholds'] = {}
    per_category = settings['active_category_thresholds']

    def balanced_overall():
        # Overall 'balanced' threshold from validation, else 0.5
        if 'overall' in validated and 'balanced' in validated['overall']:
            return validated['overall']['balanced']['threshold']
        return 0.5

    def fill_categories(key):
        # Give every known category its validated threshold for `key`,
        # falling back to the active overall threshold.
        if not hasattr(state, 'categories'):
            return
        for category in state.categories:
            known = validated['categories']
            if category in known and key in known[category]:
                per_category[category] = known[category][key]['threshold']
            else:
                per_category[category] = settings['active_threshold']

    # Profiles that map directly onto a validated profile key
    preset_keys = {
        "Micro Optimized": "micro_optimized",
        "Macro Optimized": "macro_optimized",
        "Balanced": "balanced",
    }
    preset = preset_keys.get(chosen_profile)

    if preset and 'overall' in validated and preset in validated['overall']:
        settings['active_threshold'] = validated['overall'][preset]['threshold']
        fill_categories(preset)
    elif chosen_profile == "Overall":
        # Single threshold: start from balanced and drop category overrides
        settings['active_threshold'] = balanced_overall()
        settings['active_category_thresholds'] = {}
    elif chosen_profile == "Category-specific":
        # Start every category from its balanced threshold
        settings['active_threshold'] = balanced_overall()
        fill_categories('balanced')
def apply_thresholds(all_probs, threshold_profile, active_threshold, active_category_thresholds, min_confidence, selected_categories):
    """Apply thresholds to raw probabilities and return the filtered tags.

    Args:
        all_probs: Mapping of category to a sequence of ``(tag, prob)`` pairs.
        threshold_profile: Accepted for interface compatibility; not used by
            the filtering below.
        active_threshold: Threshold used for categories without an override.
        active_category_thresholds: Optional mapping of category to its own
            threshold; ``None`` means no overrides.
        min_confidence: Accepted for interface compatibility; not used by the
            filtering below.
        selected_categories: Optional mapping of category to a bool deciding
            whether its tags go into the combined list. Categories that are
            absent (or a ``None`` mapping) count as selected.

    Returns:
        ``(tags, all_tags)``: ``tags`` maps each category to its ``(tag,
        prob)`` pairs at or above the threshold; ``all_tags`` is the flat list
        of tag names from the selected categories.
    """
    # Treat None the same as "nothing configured" for both optional mappings
    active_category_thresholds = active_category_thresholds or {}
    selected_categories = selected_categories or {}

    tags = {}
    all_tags = []
    for category, cat_probs in all_probs.items():
        # Category override wins over the overall threshold
        threshold = active_category_thresholds.get(category, active_threshold)
        kept = [(tag, prob) for tag, prob in cat_probs if prob >= threshold]
        tags[category] = kept
        if selected_categories.get(category, True):
            all_tags.extend(tag for tag, _ in kept)
    return tags, all_tags
def image_tagger_app(): | |
"""Main Streamlit application for image tagging.""" | |
st.set_page_config(layout="wide", page_title="Camie Tagger", page_icon="🖼️") | |
st.title("Camie-Tagger-v2 Interface") | |
st.markdown("---") | |
# Prevent UI duplication by using container | |
if 'app_container' not in st.session_state: | |
st.session_state.app_container = True | |
# Initialize settings | |
if 'settings' not in st.session_state: | |
st.session_state.settings = { | |
'show_all_tags': False, | |
'compact_view': True, | |
'min_confidence': 0.01, | |
'threshold_profile': "Macro", | |
'active_threshold': 0.5, | |
'active_category_thresholds': {}, # Initialize as empty dict, not None | |
'selected_categories': {}, | |
'replace_underscores': False | |
} | |
st.session_state.show_profile_help = False | |
# Session state initialization for model | |
if 'model_loaded' not in st.session_state: | |
st.session_state.model_loaded = False | |
st.session_state.model = None | |
st.session_state.thresholds = None | |
st.session_state.metadata = None | |
st.session_state.model_type = "onnx" # Default to ONNX | |
# Sidebar for model selection and information | |
with st.sidebar: | |
# Support information | |
st.subheader("💡 Notes") | |
st.markdown(""" | |
This tagger was trained on a subset of the available data due to hardware limitations. | |
A more comprehensive model trained on the full 3+ million image dataset would provide: | |
- More recent characters and tags. | |
- Improved accuracy. | |
If you find this tool useful and would like to support future development: | |
""") | |
# Add Buy Me a Coffee button with Star of the City-like glow effect | |
st.markdown(""" | |
<style> | |
@keyframes coffee-button-glow { | |
0% { box-shadow: 0 0 5px #FFD700; } | |
50% { box-shadow: 0 0 15px #FFD700; } | |
100% { box-shadow: 0 0 5px #FFD700; } | |
} | |
.coffee-button { | |
display: inline-block; | |
animation: coffee-button-glow 2s infinite; | |
border-radius: 5px; | |
transition: transform 0.3s ease; | |
} | |
.coffee-button:hover { | |
transform: scale(1.05); | |
} | |
</style> | |
<a href="https://ko-fi.com/camais" target="_blank" class="coffee-button"> | |
<img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" | |
alt="Buy Me A Coffee" | |
style="height: 45px; width: 162px; border-radius: 5px;" /> | |
</a> | |
""", unsafe_allow_html=True) | |
st.markdown(""" | |
Your support helps with: | |
- GPU costs for training | |
- Storage for larger datasets | |
- Development of new features | |
- Future projects | |
Thank you! 🙏 | |
Full Details: https://huggingface.co/Camais03/camie-tagger-v2 | |
""") | |
st.header("Model Selection") | |
# Load model information | |
try: | |
with st.spinner("Loading model from HF Hub..."): | |
model_info, metadata, thresholds = load_model_and_metadata() | |
except Exception as e: | |
st.error(f"Failed to load model information: {e}") | |
st.stop() | |
# Check if model info loaded successfully | |
if model_info is None: | |
st.error("Could not download model files from Hugging Face Hub") | |
st.info("Please check your internet connection or try again later") | |
st.stop() | |
# Check if model info loaded successfully | |
if model_info is None: | |
st.error("Could not download model files from Hugging Face Hub") | |
st.info("Please check your internet connection or try again later") | |
st.stop() | |
# Determine available model options | |
model_options = [] | |
if model_info['onnx_available']: | |
model_options.append("ONNX (Recommended)") | |
if model_info['safetensors_available']: | |
model_options.append("SafeTensors (PyTorch)") | |
if not model_options: | |
st.error("No model files found!") | |
st.info("Expected files in Camais03/camie-tagger-v2:") | |
st.info("- camie-tagger-v2.onnx") | |
st.info("- camie-tagger-v2.safetensors") | |
st.info("- camie-tagger-v2-metadata.json") | |
st.stop() | |
# Model type selection | |
default_index = 0 if model_info['onnx_available'] else 0 | |
model_type = st.radio( | |
"Select Model Type:", | |
model_options, | |
index=default_index, | |
help="ONNX: Optimized for speed and compatibility\nSafeTensors: Native PyTorch format" | |
) | |
# Convert selection to internal model type | |
if model_type == "ONNX (Recommended)": | |
selected_model_type = "onnx" | |
else: | |
selected_model_type = "safetensors" | |
# If model type changed, reload | |
if selected_model_type != st.session_state.model_type: | |
st.session_state.model_loaded = False | |
st.session_state.model_type = selected_model_type | |
# Reload button | |
if st.button("Reload Model") and st.session_state.model_loaded: | |
st.session_state.model_loaded = False | |
st.info("Reloading model...") | |
# Try to load the model | |
if not st.session_state.model_loaded: | |
try: | |
with st.spinner(f"Loading {st.session_state.model_type.upper()} model..."): | |
if st.session_state.model_type == "onnx": | |
# Load ONNX model - matching your v1 approach exactly | |
import onnxruntime as ort | |
onnx_path = st.session_state.model_files['onnx_path'] | |
# Initialize ONNX Runtime session (like your v1) | |
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) | |
st.session_state.model = session | |
st.session_state.device = "CPU" # Simplified like your v1 | |
st.session_state.param_dtype = "float32" | |
else: | |
# Load SafeTensors model | |
safetensors_path = st.session_state.model_files['safetensors_path'] | |
metadata_path = st.session_state.model_files['metadata_path'] | |
model, loaded_metadata = load_safetensors_model(safetensors_path, metadata_path) | |
st.session_state.model = model | |
device = next(model.parameters()).device | |
param_dtype = next(model.parameters()).dtype | |
st.session_state.device = device | |
st.session_state.param_dtype = param_dtype | |
metadata = loaded_metadata # Use loaded metadata instead | |
# Store common info | |
st.session_state.thresholds = thresholds | |
st.session_state.metadata = metadata | |
st.session_state.model_loaded = True | |
# Get categories | |
if metadata and 'dataset_info' in metadata: | |
tag_mapping = metadata['dataset_info']['tag_mapping'] | |
categories = list(set(tag_mapping['tag_to_category'].values())) | |
st.session_state.categories = categories | |
# Initialize selected categories | |
if not st.session_state.settings['selected_categories']: | |
st.session_state.settings['selected_categories'] = {cat: True for cat in categories} | |
# Set initial threshold from validation results | |
if 'overall' in thresholds and 'macro_optimized' in thresholds['overall']: | |
st.session_state.settings['active_threshold'] = thresholds['overall']['macro_optimized']['threshold'] | |
except Exception as e: | |
st.error(f"Error loading model: {str(e)}") | |
st.code(traceback.format_exc()) | |
st.stop() | |
# Display model information in sidebar | |
with st.sidebar: | |
st.header("Model Information") | |
if st.session_state.model_loaded: | |
if st.session_state.model_type == "onnx": | |
st.success("Using ONNX Model") | |
else: | |
st.success("Using SafeTensors Model") | |
st.write(f"Device: {st.session_state.device}") | |
st.write(f"Precision: {st.session_state.param_dtype}") | |
if st.session_state.metadata: | |
if 'dataset_info' in st.session_state.metadata: | |
total_tags = st.session_state.metadata['dataset_info']['total_tags'] | |
st.write(f"Total tags: {total_tags}") | |
elif 'total_tags' in st.session_state.metadata: | |
st.write(f"Total tags: {st.session_state.metadata['total_tags']}") | |
# Show categories | |
with st.expander("Available Categories"): | |
if hasattr(st.session_state, 'categories'): | |
for category in sorted(st.session_state.categories): | |
st.write(f"- {category.capitalize()}") | |
else: | |
st.write("Categories will be available after model loads") | |
# About section | |
with st.expander("About this app"): | |
st.write(""" | |
This app uses a trained image tagging model to analyze and tag images. | |
**Model Options**: | |
- **ONNX (Recommended)**: Optimized for inference speed with broad compatibility | |
- **SafeTensors**: Native PyTorch format for advanced users | |
**Features**: | |
- Upload or process images in batches | |
- Multiple threshold profiles based on validation results | |
- Category-specific threshold adjustment | |
- Export tags in various formats | |
- Fast inference with GPU acceleration (when available) | |
**Threshold Profiles**: | |
- **Micro Optimized**: Best overall F1 score (67.3% micro F1) | |
- **Macro Optimized**: Balanced across categories (50.6% macro F1) | |
- **Balanced**: Good general-purpose setting | |
- **Overall**: Single adjustable threshold | |
- **Category-specific**: Fine-tune each category individually | |
""") | |
# Main content area - Image upload and processing | |
col1, col2 = st.columns([1, 1.5]) | |
with col1: | |
st.header("Image") | |
upload_tab, batch_tab = st.tabs(["Upload Image", "Batch Processing"]) | |
image_path = None | |
with upload_tab: | |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) | |
if uploaded_file: | |
# Create temporary file | |
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file: | |
tmp_file.write(uploaded_file.getvalue()) | |
image_path = tmp_file.name | |
st.session_state.original_filename = uploaded_file.name | |
# Display image | |
image = Image.open(uploaded_file) | |
st.image(image, use_container_width=True) | |
with batch_tab: | |
st.subheader("Batch Process Images") | |
# Note about batch processing in HF Spaces | |
st.info("Note: Batch processing from local folders is not available in HF Spaces. Use the single image upload instead.") | |
# Folder selection (disabled for HF Spaces) | |
batch_folder = st.text_input("Enter folder path containing images:", "", disabled=True) | |
st.write("For batch processing, please:") | |
st.write("1. Download this code and run locally") | |
st.write("2. Or upload images one by one using the Upload Image tab") | |
# Column 2: Controls and Results | |
with col2: | |
st.header("Tagging Controls") | |
# Only show controls if model is loaded | |
if not st.session_state.model_loaded: | |
st.info("Model loading... Controls will appear once the model is ready.") | |
return | |
# Threshold profile selection | |
all_profiles = [ | |
"Micro Optimized", | |
"Macro Optimized", | |
"Balanced", | |
"Overall", | |
"Category-specific" | |
] | |
profile_col1, profile_col2 = st.columns([3, 1]) | |
with profile_col1: | |
threshold_profile = st.selectbox( | |
"Select threshold profile", | |
options=all_profiles, | |
index=1, # Default to Macro | |
key="threshold_profile", | |
on_change=on_threshold_profile_change | |
) | |
with profile_col2: | |
if st.button("ℹ️ Help", key="profile_help"): | |
st.session_state.show_profile_help = not st.session_state.get('show_profile_help', False) | |
# Show profile help | |
if st.session_state.get('show_profile_help', False): | |
st.markdown(threshold_profile_explanations[threshold_profile]) | |
else: | |
st.info(threshold_profile_descriptions[threshold_profile]) | |
# Show profile metrics if available | |
if st.session_state.model_loaded and hasattr(st.session_state, 'thresholds'): | |
metrics = get_profile_metrics(st.session_state.thresholds, threshold_profile) | |
if metrics: | |
metrics_cols = st.columns(3) | |
with metrics_cols[0]: | |
st.metric("Threshold", f"{metrics['threshold']:.3f}") | |
with metrics_cols[1]: | |
st.metric("Micro F1", f"{metrics['micro_f1']:.1f}%") | |
with metrics_cols[2]: | |
st.metric("Macro F1", f"{metrics['macro_f1']:.1f}%") | |
# Threshold controls based on profile | |
if st.session_state.model_loaded: | |
active_threshold = st.session_state.settings.get('active_threshold', 0.5) | |
active_category_thresholds = st.session_state.settings.get('active_category_thresholds', {}) | |
if threshold_profile in ["Micro Optimized", "Macro Optimized", "Balanced"]: | |
# Show reference threshold (disabled) | |
st.slider( | |
"Threshold (from validation)", | |
min_value=0.01, | |
max_value=1.0, | |
value=float(active_threshold), | |
step=0.01, | |
disabled=True, | |
help="This threshold is optimized from validation results" | |
) | |
elif threshold_profile == "Overall": | |
# Adjustable overall threshold | |
active_threshold = st.slider( | |
"Overall threshold", | |
min_value=0.01, | |
max_value=1.0, | |
value=float(active_threshold), | |
step=0.01 | |
) | |
st.session_state.settings['active_threshold'] = active_threshold | |
elif threshold_profile == "Category-specific": | |
# Show reference overall threshold | |
st.slider( | |
"Overall threshold (reference)", | |
min_value=0.01, | |
max_value=1.0, | |
value=float(active_threshold), | |
step=0.01, | |
disabled=True | |
) | |
st.write("Adjust thresholds for individual categories:") | |
# Category sliders | |
slider_cols = st.columns(2) | |
if not active_category_thresholds: | |
active_category_thresholds = {} | |
if hasattr(st.session_state, 'categories'): | |
for i, category in enumerate(sorted(st.session_state.categories)): | |
col_idx = i % 2 | |
with slider_cols[col_idx]: | |
default_val = active_category_thresholds.get(category, active_threshold) | |
new_threshold = st.slider( | |
f"{category.capitalize()}", | |
min_value=0.01, | |
max_value=1.0, | |
value=float(default_val), | |
step=0.01, | |
key=f"slider_{category}" | |
) | |
active_category_thresholds[category] = new_threshold | |
st.session_state.settings['active_category_thresholds'] = active_category_thresholds | |
# Display options | |
with st.expander("Display Options", expanded=False): | |
col1, col2 = st.columns(2) | |
with col1: | |
show_all_tags = st.checkbox("Show all tags (including below threshold)", | |
value=st.session_state.settings['show_all_tags']) | |
compact_view = st.checkbox("Compact view (hide progress bars)", | |
value=st.session_state.settings['compact_view']) | |
replace_underscores = st.checkbox("Replace underscores with spaces", | |
value=st.session_state.settings.get('replace_underscores', False)) | |
with col2: | |
min_confidence = st.slider("Minimum confidence to display", 0.0, 0.5, | |
st.session_state.settings['min_confidence'], 0.01) | |
# Update settings | |
st.session_state.settings.update({ | |
'show_all_tags': show_all_tags, | |
'compact_view': compact_view, | |
'min_confidence': min_confidence, | |
'replace_underscores': replace_underscores | |
}) | |
# Category selection | |
st.write("Categories to include in 'All Tags' section:") | |
category_cols = st.columns(3) | |
selected_categories = {} | |
if hasattr(st.session_state, 'categories'): | |
for i, category in enumerate(sorted(st.session_state.categories)): | |
col_idx = i % 3 | |
with category_cols[col_idx]: | |
default_val = st.session_state.settings['selected_categories'].get(category, True) | |
selected_categories[category] = st.checkbox( | |
f"{category.capitalize()}", | |
value=default_val, | |
key=f"cat_select_{category}" | |
) | |
st.session_state.settings['selected_categories'] = selected_categories | |
# Run tagging button | |
if image_path and st.button("Run Tagging"): | |
if not st.session_state.model_loaded: | |
st.error("Model not loaded") | |
else: | |
# Create progress indicators | |
progress_bar = st.progress(0) | |
status_text = st.empty() | |
try: | |
status_text.text("Starting image analysis...") | |
progress_bar.progress(10) | |
# Process image based on model type | |
if st.session_state.model_type == "onnx": | |
# Check if we have the necessary modules | |
try: | |
from utils.onnx_processing import process_single_image_onnx | |
progress_bar.progress(20) | |
status_text.text("Module imported successfully...") | |
except ImportError as import_e: | |
st.error(f"Missing required module: {import_e}") | |
st.error("This suggests the utils modules aren't properly configured") | |
return | |
# Update progress before inference | |
status_text.text("Running ONNX inference... This may take 2-5 seconds.") | |
progress_bar.progress(30) | |
# Add timeout warning | |
st.warning("⏳ Model inference in progress. Please wait and don't refresh the page.") | |
result = process_single_image_onnx( | |
image_path=image_path, | |
model_path=st.session_state.model_files['onnx_path'], | |
metadata=st.session_state.metadata, | |
threshold_profile=threshold_profile, | |
active_threshold=st.session_state.settings['active_threshold'], | |
active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}), | |
min_confidence=st.session_state.settings['min_confidence'] | |
) | |
progress_bar.progress(90) | |
status_text.text("Processing results...") | |
else: | |
# SafeTensors processing | |
try: | |
from utils.image_processing import process_image | |
progress_bar.progress(20) | |
except ImportError as import_e: | |
st.error(f"Missing required module: {import_e}") | |
return | |
status_text.text("Running SafeTensors inference...") | |
progress_bar.progress(30) | |
result = process_image( | |
image_path=image_path, | |
model=st.session_state.model, | |
thresholds=st.session_state.thresholds, | |
metadata=st.session_state.metadata, | |
threshold_profile=threshold_profile, | |
active_threshold=st.session_state.settings['active_threshold'], | |
active_category_thresholds=st.session_state.settings.get('active_category_thresholds', {}), | |
min_confidence=st.session_state.settings['min_confidence'] | |
) | |
progress_bar.progress(90) | |
if result and result.get('success'): | |
progress_bar.progress(95) | |
status_text.text("Organizing results...") | |
# Process results in smaller chunks to prevent browser blocking | |
try: | |
# Limit result size to prevent memory issues but allow more tags | |
all_probs = result.get('all_probs', {}) | |
# Count total items | |
total_items = sum(len(cat_items) for cat_items in all_probs.values()) | |
# Increased limits - 256 per category, higher total limit | |
MAX_TAGS_PER_CATEGORY = 256 | |
MAX_TOTAL_TAGS = 1500 # Increased to accommodate more categories | |
limited_all_probs = {} | |
limited_tags = {} | |
total_processed = 0 | |
for category, cat_probs in all_probs.items(): | |
if total_processed >= MAX_TOTAL_TAGS: | |
break | |
# Limit items per category | |
limited_cat_probs = cat_probs[:MAX_TAGS_PER_CATEGORY] | |
limited_all_probs[category] = limited_cat_probs | |
# Get filtered tags for this category | |
filtered_cat_tags = result.get('tags', {}).get(category, []) | |
limited_cat_tags = filtered_cat_tags[:MAX_TAGS_PER_CATEGORY] | |
if limited_cat_tags: | |
limited_tags[category] = limited_cat_tags | |
total_processed += len(limited_cat_probs) | |
# Create limited all_tags list | |
limited_all_tags = [] | |
for category, cat_tags in limited_tags.items(): | |
for tag, _ in cat_tags: | |
limited_all_tags.append(tag) | |
# Store the limited results | |
st.session_state.all_probs = limited_all_probs | |
st.session_state.tags = limited_tags | |
st.session_state.all_tags = limited_all_tags | |
progress_bar.progress(100) | |
status_text.text("Analysis completed!") | |
# Show performance info | |
if 'inference_time' in result: | |
st.success(f"Analysis completed in {result['inference_time']:.2f} seconds! Found {len(limited_all_tags)} tags.") | |
else: | |
st.success(f"Analysis completed! Found {len(limited_all_tags)} tags.") | |
# Show limitation notice if we hit limits | |
if total_items > MAX_TOTAL_TAGS: | |
st.info(f"Note: Showing top {MAX_TOTAL_TAGS} results out of {total_items} total predictions for optimal performance.") | |
except Exception as result_e: | |
st.error(f"Error processing results: {result_e}") | |
# Clear progress indicators | |
progress_bar.empty() | |
status_text.empty() | |
else: | |
error_msg = result.get('error', 'Unknown error') if result else 'No result returned' | |
st.error(f"Analysis failed: {error_msg}") | |
progress_bar.empty() | |
status_text.empty() | |
except Exception as e: | |
st.error(f"Error during analysis: {str(e)}") | |
st.code(traceback.format_exc()) | |
progress_bar.empty() | |
status_text.empty() | |
# Display results | |
if image_path and hasattr(st.session_state, 'all_probs'): | |
st.header("Predictions") | |
# Apply current thresholds | |
filtered_tags, current_all_tags = apply_thresholds( | |
st.session_state.all_probs, | |
threshold_profile, | |
st.session_state.settings['active_threshold'], | |
st.session_state.settings.get('active_category_thresholds', {}), | |
st.session_state.settings['min_confidence'], | |
st.session_state.settings['selected_categories'] | |
) | |
all_tags = [] | |
# Display by category | |
for category in sorted(st.session_state.all_probs.keys()): | |
all_tags_in_category = st.session_state.all_probs.get(category, []) | |
filtered_tags_in_category = filtered_tags.get(category, []) | |
if all_tags_in_category: | |
expander_label = f"{category.capitalize()} ({len(filtered_tags_in_category)} tags)" | |
with st.expander(expander_label, expanded=True): | |
# Get threshold for this category (handle None case) | |
active_category_thresholds = st.session_state.settings.get('active_category_thresholds') or {} | |
threshold = active_category_thresholds.get(category, st.session_state.settings['active_threshold']) | |
# Determine tags to display | |
if st.session_state.settings['show_all_tags']: | |
tags_to_display = all_tags_in_category | |
else: | |
tags_to_display = [(tag, prob) for tag, prob in all_tags_in_category if prob >= threshold] | |
if not tags_to_display: | |
st.info(f"No tags above {st.session_state.settings['min_confidence']:.2f} confidence") | |
continue | |
# Display tags | |
if st.session_state.settings['compact_view']: | |
# Compact view | |
tag_list = [] | |
replace_underscores = st.session_state.settings.get('replace_underscores', False) | |
for tag, prob in tags_to_display: | |
percentage = int(prob * 100) | |
display_tag = tag.replace('_', ' ') if replace_underscores else tag | |
tag_list.append(f"{display_tag} ({percentage}%)") | |
if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True): | |
all_tags.append(tag) | |
st.markdown(", ".join(tag_list)) | |
else: | |
# Expanded view with progress bars | |
for tag, prob in tags_to_display: | |
replace_underscores = st.session_state.settings.get('replace_underscores', False) | |
display_tag = tag.replace('_', ' ') if replace_underscores else tag | |
if prob >= threshold and st.session_state.settings['selected_categories'].get(category, True): | |
all_tags.append(tag) | |
tag_display = f"**{display_tag}**" | |
else: | |
tag_display = display_tag | |
st.write(tag_display) | |
st.markdown(display_progress_bar(prob), unsafe_allow_html=True) | |
# All tags summary | |
st.markdown("---") | |
st.subheader(f"All Tags ({len(all_tags)} total)") | |
if all_tags: | |
replace_underscores = st.session_state.settings.get('replace_underscores', False) | |
if replace_underscores: | |
display_tags = [tag.replace('_', ' ') for tag in all_tags] | |
tags_text = ", ".join(display_tags) | |
else: | |
tags_text = ", ".join(all_tags) | |
st.write(tags_text) | |
# Add download button for tags | |
st.download_button( | |
label="📥 Download Tags", | |
data=tags_text, | |
file_name=f"{st.session_state.get('original_filename', 'image')}_tags.txt", | |
mime="text/plain" | |
) | |
else: | |
st.info("No tags detected above the threshold.") | |
# Script entry point: build and run the tagger UI only when this file is
# executed as the main script, not when it is imported as a module.
if __name__ == "__main__":
    image_tagger_app()