Spaces:
Running
Running
""" | |
ONNX-based batch image processing for the Image Tagger application. | |
Updated with proper ImageNet normalization and new metadata format. | |
""" | |
import os | |
import json | |
import time | |
import traceback | |
import numpy as np | |
import glob | |
import onnxruntime as ort | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from concurrent.futures import ThreadPoolExecutor | |
def preprocess_image(image_path, image_size=512):
    """
    Load an image and prepare it for ImageTagger inference.

    The image is converted to RGB, resized so its longer side equals
    ``image_size`` (aspect ratio preserved), centred on a square canvas filled
    with an approximation of the ImageNet mean colour, and finally converted
    to a tensor and normalised with the ImageNet mean/std.

    Args:
        image_path: Path to the image file.
        image_size: Side length, in pixels, of the square output canvas.

    Returns:
        The result of the torchvision ``ToTensor`` + ``Normalize`` pipeline,
        converted to a NumPy array via ``.numpy()``.

    Raises:
        ValueError: If ``image_path`` does not exist.
        Exception: If the image cannot be opened or processed; the original
            error is attached as ``__cause__``.
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    # ImageNet normalization - the same statistics must be used at inference
    # time as were used when the model was trained.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    try:
        with Image.open(image_path) as img:
            # Normalise every non-RGB mode (RGBA, P, LA, L, 1, CMYK, ...) up
            # front so all inputs go through the same resize path.
            if img.mode != 'RGB':
                img = img.convert('RGB')

            width, height = img.size
            aspect_ratio = width / height

            # Scale the longer side to image_size. The shorter side is
            # clamped to at least 1 pixel: for extreme aspect ratios int()
            # would otherwise truncate it to 0 and resize() would fail.
            if aspect_ratio > 1:
                new_width = image_size
                new_height = max(1, int(new_width / aspect_ratio))
            else:
                new_height = image_size
                new_width = max(1, int(new_height * aspect_ratio))

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Pad with RGB values close to the ImageNet mean:
            # (0.485*255, 0.456*255, 0.406*255)
            pad_color = (124, 116, 104)
            new_image = Image.new('RGB', (image_size, image_size), pad_color)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            new_image.paste(img, (paste_x, paste_y))

            # Apply transforms (including ImageNet normalization)
            img_tensor = transform(new_image)
            return img_tensor.numpy()
    except Exception as e:
        # Keep the original exception type/message for callers, but chain the
        # cause so the underlying traceback is not lost.
        raise Exception(f"Error processing {image_path}: {str(e)}") from e
def process_single_image_onnx(image_path, model_path, metadata, threshold_profile="Overall",
                              active_threshold=0.35, active_category_thresholds=None,
                              min_confidence=0.1):
    """
    Process a single image using ONNX model with new metadata format.

    The tagger is cached on the function object (``process_single_image_onnx.tagger``)
    so repeated calls do not reload the model. The cached tagger is rebuilt
    when it was created for a different ``model_path``. A change of
    ``metadata`` alone does not invalidate the cache.

    Args:
        image_path: Path to the image file
        model_path: Path to the ONNX model file
        metadata: Model metadata dictionary
        threshold_profile: The threshold profile being used (not read by this
            function; accepted for interface compatibility)
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        Dictionary with tags and probabilities. On success it is the first
        result of ``predict_batch`` with ``inference_time`` (seconds, measured
        across preprocessing and prediction) and ``success=True`` added. On
        failure it is a dictionary with ``success=False``, an ``error``
        message and empty ``all_tags``/``all_probs``/``tags``. This function
        does not raise; errors are printed and reported in the result.
    """
    def _failure(message):
        # Single definition of the failure payload shared by both error paths.
        return {
            'success': False,
            'error': message,
            'all_tags': [],
            'all_probs': {},
            'tags': {}
        }

    try:
        # Reuse the cached tagger only if it belongs to the requested model.
        # A cached object without a ``model_path`` attribute (e.g. one that
        # was assigned externally) is trusted and reused as before.
        tagger = getattr(process_single_image_onnx, 'tagger', None)
        if tagger is None or getattr(tagger, 'model_path', model_path) != model_path:
            tagger = ONNXImageTagger(model_path, metadata)
            # Cache it for future calls
            process_single_image_onnx.tagger = tagger

        # Preprocess the image
        start_time = time.time()
        img_array = preprocess_image(image_path)

        # Run inference
        results = tagger.predict_batch(
            [img_array],
            threshold=active_threshold,
            category_thresholds=active_category_thresholds,
            min_confidence=min_confidence
        )
        inference_time = time.time() - start_time

        if not results:
            return _failure('Failed to process image')

        result = results[0]
        result['inference_time'] = inference_time
        result['success'] = True
        return result
    except Exception as e:
        print(f"Error in process_single_image_onnx: {str(e)}")
        traceback.print_exc()
        return _failure(str(e))
def preprocess_images_parallel(image_paths, image_size=512, max_workers=8):
    """
    Preprocess several images concurrently on a thread pool.

    Args:
        image_paths: Iterable of image file paths.
        image_size: Target size forwarded to ``preprocess_image``.
        max_workers: Maximum number of worker threads.

    Returns:
        A ``(processed_images, valid_paths)`` pair of equal-length lists, in
        input order. Images whose preprocessing raised are reported on stdout
        and left out of both lists.
    """
    def _load(path):
        # Never let one bad image abort the whole map(); signal failure with
        # a None payload instead.
        try:
            array = preprocess_image(path, image_size)
        except Exception as e:
            print(f"Error processing {path}: {str(e)}")
            return None, path
        return array, path

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        outcomes = list(pool.map(_load, image_paths))

    # Keep only the successful (array, path) pairs, preserving input order.
    succeeded = [(array, path) for array, path in outcomes if array is not None]
    processed_images = [array for array, _ in succeeded]
    valid_paths = [path for _, path in succeeded]
    return processed_images, valid_paths
def apply_category_limits(result, category_limits):
    """
    Cap the number of tags kept per category in a result dictionary.

    The dictionary (and its ``'tags'`` mapping) is modified in place and also
    returned.

    Args:
        result: Result dictionary containing ``'success'``, ``'tags'``
            (category -> list of ``(tag, probability)`` pairs) and
            ``'all_tags'``.
        category_limits: Dictionary mapping categories to their tag limits
            (0 = exclude category, a positive number = keep that many leading
            tags, anything else such as -1 = no limit). Categories missing
            from the mapping are not limited.

    Returns:
        The same ``result`` object with ``'tags'`` trimmed and ``'all_tags'``
        rebuilt from the remaining tags. It is returned untouched when
        ``category_limits`` is empty or the result is not successful.
    """
    # Nothing to do without limits, or for a failed result.
    if not category_limits:
        return result
    if not result['success']:
        return result

    per_category = result['tags']

    # Walk a snapshot of the keys because categories may be removed below.
    for name in tuple(per_category):
        cap = category_limits.get(name, -1)
        entries = per_category[name]
        if cap == 0:
            # A zero limit drops the category entirely.
            per_category.pop(name)
        elif cap > 0 and len(entries) > cap:
            # Keep only the first `cap` tags of this category.
            per_category[name] = entries[:cap]

    result['tags'] = per_category
    # Rebuild the flat tag list so it matches the trimmed categories.
    result['all_tags'] = [tag for entries in per_category.values() for tag, _ in entries]
    return result
class ONNXImageTagger:
    """ONNX-based image tagger for fast batch inference with updated metadata format"""

    def __init__(self, model_path, metadata):
        """
        Create an inference session and read the tag mappings from metadata.

        Args:
            model_path: Path to the ONNX model file.
            metadata: Metadata dictionary. The new format keeps the mappings
                under ``metadata['dataset_info']``; the older format keeps
                ``idx_to_tag``/``tag_to_category``/``total_tags`` at top level.

        Raises:
            KeyError: If the new-format metadata lacks ``tag_mapping``,
                ``idx_to_tag``, ``tag_to_category`` or ``total_tags``.
            Exception: Whatever onnxruntime raises if the CPU-only session
                cannot be created either.
        """
        # Load model
        self.model_path = model_path
        try:
            self.session = ort.InferenceSession(
                model_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")
        except Exception as e:
            # Any failure of the first attempt triggers a CPU-only retry.
            print(f"CUDA not available, using CPU: {e}")
            self.session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")

        # Store metadata (passed as dict, not loaded from file)
        self.metadata = metadata

        # Extract tag mappings from new metadata structure
        if 'dataset_info' in metadata:
            # New metadata format
            self.tag_mapping = metadata['dataset_info']['tag_mapping']
            self.idx_to_tag = self.tag_mapping['idx_to_tag']
            self.tag_to_category = self.tag_mapping['tag_to_category']
            self.total_tags = metadata['dataset_info']['total_tags']
        else:
            # Fallback for older format
            self.idx_to_tag = metadata.get('idx_to_tag', {})
            self.tag_to_category = metadata.get('tag_to_category', {})
            self.total_tags = metadata.get('total_tags', len(self.idx_to_tag))

        # Get input name
        self.input_name = self.session.get_inputs()[0].name
        print(f"Model loaded successfully. Input name: {self.input_name}")
        print(f"Total tags: {self.total_tags}, Categories: {len(set(self.tag_to_category.values()))}")

    @staticmethod
    def _sigmoid(logits):
        """Return the element-wise logistic sigmoid of ``logits``."""
        # For very negative logits exp(-x) overflows to inf, which still gives
        # the correct limit 1 / (1 + inf) == 0.0; only the warning is silenced.
        with np.errstate(over='ignore'):
            return 1.0 / (1.0 + np.exp(-logits))

    def predict_batch(self, image_arrays, threshold=0.5, category_thresholds=None, min_confidence=0.1):
        """
        Run batch inference on preprocessed image arrays.

        Args:
            image_arrays: Sequence of equally-shaped preprocessed image arrays;
                they are stacked along a new leading batch axis.
            threshold: Default probability threshold for a tag to be selected.
            category_thresholds: Optional mapping of category -> threshold that
                overrides ``threshold`` for that category.
            min_confidence: Tags below this probability are dropped entirely
                (they do not appear in ``all_probs`` either).

        Returns:
            One dictionary per image with keys ``'tags'`` (category -> list of
            ``(tag, prob)`` at or above the category threshold), ``'all_probs'``
            (category -> list of ``(tag, prob)`` at or above
            ``min_confidence``, sorted by descending probability),
            ``'all_tags'`` (flat list of selected tag names) and ``'success'``.
            An empty input yields an empty list.
        """
        # An empty batch cannot be stacked (and would divide by zero in the
        # timing message), so short-circuit it.
        if len(image_arrays) == 0:
            return []

        # Stack arrays into batch
        batch_input = np.stack(image_arrays)

        # Run inference
        start_time = time.time()
        outputs = self.session.run(None, {self.input_name: batch_input})
        inference_time = time.time() - start_time
        print(f"Batch inference completed in {inference_time:.4f} seconds ({inference_time/len(image_arrays):.4f} s/image)")

        # Process outputs - handle both single and multi-output models
        if len(outputs) >= 2:
            # Multi-output model (initial_predictions, refined_predictions,
            # selected_candidates): the refined predictions are the main output.
            main_logits = outputs[1]
            print(f"Using refined predictions (shape: {main_logits.shape})")
        else:
            # Single output model
            main_logits = outputs[0]
            print(f"Using single output (shape: {main_logits.shape})")

        # Apply sigmoid to get probabilities
        main_probs = self._sigmoid(main_logits)

        # Process results for each image in batch
        batch_results = []
        for i in range(main_probs.shape[0]):
            # Compare in float64 so the cut-off matches comparing Python
            # floats one by one, independent of the model's output dtype.
            probs = np.asarray(main_probs[i], dtype=np.float64)

            # Only visit the indices that pass min_confidence instead of
            # looping over every tag in Python.
            all_probs = {}
            for idx in np.nonzero(probs >= min_confidence)[0].tolist():
                tag_name = self.idx_to_tag.get(str(idx), f"unknown-{idx}")
                category = self.tag_to_category.get(tag_name, "general")
                all_probs.setdefault(category, []).append((tag_name, float(probs[idx])))

            # Sort tags by probability within each category
            for category in all_probs:
                all_probs[category] = sorted(
                    all_probs[category],
                    key=lambda x: x[1],
                    reverse=True
                )

            # Get the filtered tags based on the selected threshold
            tags = {}
            for category, cat_tags in all_probs.items():
                # Use category-specific threshold if available
                if category_thresholds and category in category_thresholds:
                    cat_threshold = category_thresholds[category]
                else:
                    cat_threshold = threshold
                tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= cat_threshold]

            # Create a flat list of all tags above threshold
            all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]

            batch_results.append({
                'tags': tags,
                'all_probs': all_probs,
                'all_tags': all_tags,
                'success': True
            })

        return batch_results
def _find_image_files(folder_path):
    """Return the .jpg/.jpeg/.png files directly inside ``folder_path``, without duplicates."""
    # Escape the folder so characters such as '[' or '*' in its name are
    # matched literally instead of being read as glob wildcards.
    safe_folder = glob.escape(folder_path)

    found = []
    for pattern in ('*.jpg', '*.jpeg', '*.png'):
        found.extend(glob.glob(os.path.join(safe_folder, pattern)))
        found.extend(glob.glob(os.path.join(safe_folder, pattern.upper())))

    # On case-insensitive platforms (Windows) both patterns match the same
    # file. normcase() folds case there and is a no-op elsewhere, so this
    # keeps the first occurrence of each file on every platform.
    unique = {}
    for file_path in found:
        unique.setdefault(os.path.normcase(os.path.normpath(file_path)), file_path)
    return list(unique.values())


def _finalize_image_result(result, image_path, category_limits, save_dir, save_tags, log_limits=False):
    """Apply category limits to a successful result and write its tag file."""
    if not result['success']:
        return result

    if category_limits:
        if log_limits:
            print(f"Applying limits to {os.path.basename(image_path)}: {len(result['all_tags'])} → ", end="")
        result = apply_category_limits(result, category_limits)
        if log_limits:
            print(f"{len(result['all_tags'])} tags")

    # A failed save is recorded on the result instead of failing the image.
    try:
        output_path = save_tags(
            image_path=image_path,
            all_tags=result['all_tags'],
            custom_dir=save_dir,
            overwrite=True
        )
        result['output_path'] = str(output_path)
    except Exception as e:
        print(f"Error saving tags for {image_path}: {e}")
        result['save_error'] = str(e)
    return result


def batch_process_images_onnx(folder_path, model_path, metadata_path, threshold_profile,
                              active_threshold, active_category_thresholds, save_dir=None,
                              progress_callback=None, min_confidence=0.1, batch_size=16,
                              category_limits=None):
    """
    Process all images in a folder using the ONNX model with new metadata format.

    Images are preprocessed and tagged in batches. If a batch fails as a
    whole, its images are retried one by one. Tags of every successfully
    processed image are written with ``save_tags_to_file``.

    Args:
        folder_path: Path to folder containing images
        model_path: Path to the ONNX model file
        metadata_path: Path to the model metadata file
        threshold_profile: Selected threshold profile (not read by this
            function; accepted for interface compatibility)
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        save_dir: Directory to save tag files (if None, ``saved_tags`` under
            the parent of this file's directory is used)
        progress_callback: Optional callback for progress updates, called as
            ``progress_callback(done, total, current_path_or_None)``
        min_confidence: Minimum confidence threshold
        batch_size: Number of images to process at once (values below 1 are
            treated as 1)
        category_limits: Dictionary mapping categories to their tag limits

    Returns:
        On success a dictionary with ``success``, ``total``, ``processed``,
        ``results`` (image path -> per-image result, including failed
        images), ``save_dir`` and ``time_elapsed``. If no images are found,
        or the metadata or model cannot be loaded, a dictionary with
        ``success=False``, ``error`` and empty ``results``.
    """
    image_files = _find_image_files(folder_path)
    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

    # Use the provided save directory or create a default one
    if save_dir is None:
        app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        save_dir = os.path.join(app_dir, "saved_tags")

    # Ensure the directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Load metadata. JSON is UTF-8; do not depend on the platform's default
    # encoding, which may not handle non-ASCII tag names.
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except Exception as e:
        return {
            'success': False,
            'error': f"Failed to load metadata: {e}",
            'results': {}
        }

    # Create ONNX tagger
    try:
        tagger = ONNXImageTagger(model_path, metadata)
    except Exception as e:
        return {
            'success': False,
            'error': f"Failed to load model: {e}",
            'results': {}
        }

    # Imported inside the function to avoid circular imports, and only once
    # the inputs above have been validated.
    from utils.file_utils import save_tags_to_file

    # range() rejects a zero step and silently skips everything for a
    # negative one, so clamp to a usable batch size.
    batch_size = max(1, int(batch_size))

    results = {}
    total_images = len(image_files)
    processed = 0
    start_time = time.time()

    # Process in batches
    for i in range(0, total_images, batch_size):
        batch_start = time.time()

        # Get current batch of images
        batch_files = image_files[i:i+batch_size]
        batch_size_actual = len(batch_files)

        # Update progress if callback provided
        if progress_callback:
            progress_callback(processed, total_images, batch_files[0] if batch_files else None)

        print(f"Processing batch {i//batch_size + 1}/{(total_images + batch_size - 1)//batch_size}: {batch_size_actual} images")

        try:
            # Preprocess images in parallel
            processed_images, valid_paths = preprocess_images_parallel(batch_files)

            # Images dropped by preprocessing are reported as failures so
            # they are not silently missing from the results (the one-by-one
            # fallback below reports failures the same way).
            loaded = set(valid_paths)
            for image_path in batch_files:
                if image_path not in loaded:
                    results[image_path] = {
                        'success': False,
                        'error': 'Failed to preprocess image',
                        'all_tags': []
                    }

            if processed_images:
                # Run batch prediction
                batch_results = tagger.predict_batch(
                    processed_images,
                    threshold=active_threshold,
                    category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                # Process results for each image
                for j, (image_path, result) in enumerate(zip(valid_paths, batch_results)):
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    results[image_path] = _finalize_image_result(
                        result, image_path, category_limits, save_dir,
                        save_tags_to_file, log_limits=True
                    )

            processed += batch_size_actual

            # Calculate batch timing
            batch_time = time.time() - batch_start
            print(f"Batch processed in {batch_time:.2f} seconds ({batch_time/batch_size_actual:.2f} seconds per image)")
        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            traceback.print_exc()

            # Process failed images one by one as fallback
            for j, image_path in enumerate(batch_files):
                try:
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    # Preprocess and run inference on the single image
                    img_array = preprocess_image(image_path)
                    single_results = tagger.predict_batch(
                        [img_array],
                        threshold=active_threshold,
                        category_thresholds=active_category_thresholds,
                        min_confidence=min_confidence
                    )

                    if single_results:
                        results[image_path] = _finalize_image_result(
                            single_results[0], image_path, category_limits,
                            save_dir, save_tags_to_file
                        )
                    else:
                        results[image_path] = {
                            'success': False,
                            'error': 'Failed to process image',
                            'all_tags': []
                        }
                except Exception as img_e:
                    print(f"Error processing single image {image_path}: {str(img_e)}")
                    results[image_path] = {
                        'success': False,
                        'error': str(img_e),
                        'all_tags': []
                    }

            processed += batch_size_actual

    # Final progress update
    if progress_callback:
        progress_callback(total_images, total_images, None)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time/total_images:.2f} seconds per image")

    return {
        'success': True,
        'total': total_images,
        'processed': len(results),
        'results': results,
        'save_dir': save_dir,
        'time_elapsed': end_time - start_time
    }
def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=256):
    """
    Test ImageTagger ONNX model with proper handling of all outputs and new metadata format.

    Loads the metadata and the model, runs one image through it and prints
    the predicted tags grouped by category together with timing information.

    Args:
        model_path: Path to ONNX model file
        metadata_path: Path to metadata JSON file
        image_path: Path to test image
        threshold: Confidence threshold for predictions
        top_k: Maximum number of predictions kept and shown per category

    Returns:
        Dictionary mapping category -> list of ``(tag, probability)`` sorted
        by descending probability, or an empty dictionary when no prediction
        reaches ``threshold`` (the top 5 predictions are then only printed).

    Raises:
        ValueError: If the metadata cannot be loaded, lacks required keys, or
            image preprocessing fails.
        RuntimeError: If the ONNX session cannot be created or inference fails.
    """
    from collections import defaultdict

    print(f"Loading ImageTagger ONNX model from {model_path}")

    # Load metadata. JSON is UTF-8; do not rely on the platform default.
    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except Exception as e:
        raise ValueError(f"Failed to load metadata: {e}") from e

    # Extract tag mappings from new metadata structure
    try:
        if 'dataset_info' in metadata:
            # New metadata format
            dataset_info = metadata['dataset_info']
            tag_mapping = dataset_info['tag_mapping']
            idx_to_tag = tag_mapping['idx_to_tag']
            tag_to_category = tag_mapping['tag_to_category']
            total_tags = dataset_info['total_tags']
        else:
            # Fallback for older format
            idx_to_tag = metadata.get('idx_to_tag', {})
            tag_to_category = metadata.get('tag_to_category', {})
            total_tags = metadata.get('total_tags', len(idx_to_tag))
        print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")
    except KeyError as e:
        raise ValueError(f"Invalid metadata structure, missing key: {e}") from e

    # Initialize ONNX session; CUDA is requested only when onnxruntime
    # reports a GPU device, with CPU always available as a fallback.
    providers = []
    if ort.get_device() == 'GPU':
        providers.append('CUDAExecutionProvider')
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, providers=providers)
        active_provider = session.get_providers()[0]
        print(f"Using provider: {active_provider}")

        # Print model info
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        print(f"Model inputs: {len(inputs)}")
        print(f"Model outputs: {len(outputs)}")
        for i, output in enumerate(outputs):
            print(f" Output {i}: {output.name} {output.shape}")
    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}") from e

    # Preprocess image
    print(f"Processing image: {image_path}")
    try:
        # Get image size from metadata
        img_size = metadata.get('model_info', {}).get('img_size', 512)
        img_tensor = preprocess_image(image_path, image_size=img_size)
        img_numpy = img_tensor[np.newaxis, :]  # Add batch dimension
        print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")
    except Exception as e:
        raise ValueError(f"Image preprocessing failed: {e}") from e

    # Run inference
    input_name = session.get_inputs()[0].name
    print("Running inference...")
    start_time = time.time()
    try:
        outputs = session.run(None, {input_name: img_numpy})
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.4f} seconds")
    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}") from e

    # With two or more outputs, the second one (refined predictions) is the
    # main output; otherwise fall back to the single output.
    if len(outputs) >= 2:
        main_logits = outputs[1]
        print(f"Using refined predictions (shape: {main_logits.shape})")
    else:
        main_logits = outputs[0]
        print(f"Using single output (shape: {main_logits.shape})")

    # Apply sigmoid to get probabilities. For very negative logits exp()
    # overflows to inf, which still yields the correct probability 0.0, so
    # only the overflow warning is suppressed.
    with np.errstate(over='ignore'):
        main_probs = 1.0 / (1.0 + np.exp(-main_logits))

    # Apply threshold and get predictions
    predictions_mask = (main_probs >= threshold)
    indices = np.where(predictions_mask[0])[0]

    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")
        # Show top 5 regardless of threshold
        top_indices = np.argsort(main_probs[0])[-5:][::-1]
        print("Top 5 predictions:")
        for idx in top_indices:
            tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
            prob = float(main_probs[0, idx])
            print(f" {tag_name}: {prob:.3f}")
        return {}

    # Group by category
    tags_by_category = defaultdict(list)
    for idx in indices:
        tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        tags_by_category[category].append((tag_name, float(main_probs[0, idx])))

    # Sort by probability within each category and keep at most top_k each
    for category in tags_by_category:
        tags_by_category[category] = sorted(
            tags_by_category[category],
            key=lambda x: x[1],
            reverse=True
        )[:top_k]

    # Print results
    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")

    # Standard categories first in a fixed order, then any others alphabetically
    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']
    display_order = [c for c in category_order if c in tags_by_category]
    display_order += sorted(c for c in tags_by_category if c not in category_order)
    for category in display_order:
        tags = tags_by_category[category]
        print(f"\n{category.upper()} ({len(tags)}):")
        for tag, prob in tags:
            print(f" {tag}: {prob:.3f}")

    # Performance stats
    print("\nPerformance:")
    print(f" Inference time: {inference_time:.4f}s")
    print(f" Provider: {active_provider}")
    print(f" Max confidence: {main_probs.max():.3f}")
    if total_predictions > 0:
        avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
        print(f" Average confidence: {avg_conf:.3f}")

    return dict(tags_by_category)


# The name starts with "test_", so pytest would try to collect this helper as
# a test case and fail on its required arguments. Mark it as not-a-test.
test_onnx_imagetagger.__test__ = False