"""
ONNX-based batch image processing for the Image Tagger application.

Updated with proper ImageNet normalization and new metadata format.
"""

import os
import json
import time
import traceback
import numpy as np
import glob
import onnxruntime as ort
from PIL import Image
import torchvision.transforms as transforms
from concurrent.futures import ThreadPoolExecutor


# Per-channel ImageNet statistics used to normalise model inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Letterbox padding colour: the ImageNet mean scaled to 0-255 and rounded
# (0.485*255, 0.456*255, 0.406*255), so padded pixels normalise to ~0.
IMAGENET_PAD_COLOR = (124, 116, 104)

# Extensions (matched case-insensitively) picked up by batch processing.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Square input size used when the metadata does not specify one.
DEFAULT_IMAGE_SIZE = 512

# Built once at import time instead of on every preprocess_image call.
_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


def preprocess_image(image_path, image_size=DEFAULT_IMAGE_SIZE):
    """
    Load an image and turn it into a normalised, letterboxed model input.

    The image is converted to RGB and resized with LANCZOS so that its longer
    side equals ``image_size``, keeping the aspect ratio. It is then centred
    on a square canvas filled with IMAGENET_PAD_COLOR, converted to a tensor
    and normalised with the ImageNet mean/std.

    Args:
        image_path: Path to the image file.
        image_size: Side length, in pixels, of the square output.

    Returns:
        numpy array of shape (3, image_size, image_size) produced by
        torchvision's ToTensor + Normalize.

    Raises:
        ValueError: If ``image_path`` does not exist.
        Exception: If the image cannot be opened or processed. The original
            error is chained as ``__cause__``.
    """
    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at path: {image_path}")

    try:
        with Image.open(image_path) as img:
            # Convert every non-RGB mode (RGBA, P, L, LA, CMYK, ...) so the
            # canvas paste and the tensor always see 3 channels.
            if img.mode != 'RGB':
                img = img.convert('RGB')

            width, height = img.size
            aspect_ratio = width / height

            # Scale the longer side to image_size. Clamp to 1 px so extreme
            # aspect ratios (e.g. 5000x3) don't request a zero-sized resize.
            if aspect_ratio > 1:
                new_width = image_size
                new_height = max(1, int(new_width / aspect_ratio))
            else:
                new_height = image_size
                new_width = max(1, int(new_height * aspect_ratio))

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            canvas = Image.new('RGB', (image_size, image_size), IMAGENET_PAD_COLOR)
            paste_x = (image_size - new_width) // 2
            paste_y = (image_size - new_height) // 2
            canvas.paste(img, (paste_x, paste_y))

            img_tensor = _IMAGENET_TRANSFORM(canvas)
            return img_tensor.numpy()
    except Exception as e:
        raise Exception(f"Error processing {image_path}: {str(e)}") from e


def _get_cached_tagger(model_path, metadata):
    """
    Return the tagger cached on ``process_single_image_onnx``.

    The cache lives in the function attribute ``tagger``. A new tagger is
    built and cached whenever the cached one was created for a different
    model path or different metadata, so callers never silently get
    predictions from a stale model.
    """
    cached = getattr(process_single_image_onnx, 'tagger', None)
    stale = (
        cached is None
        or getattr(cached, 'model_path', None) != model_path
        or (getattr(cached, 'metadata', None) is not metadata
            and getattr(cached, 'metadata', None) != metadata)
    )
    if stale:
        cached = ONNXImageTagger(model_path, metadata)
        process_single_image_onnx.tagger = cached
    return cached


def process_single_image_onnx(image_path, model_path, metadata, threshold_profile="Overall", active_threshold=0.35,
                              active_category_thresholds=None, min_confidence=0.1):
    """
    Process a single image using the ONNX model with the new metadata format.

    The ONNXImageTagger is cached across calls. It is rebuilt only when
    ``model_path`` or ``metadata`` changes.

    Args:
        image_path: Path to the image file.
        model_path: Path to the ONNX model file.
        metadata: Model metadata dictionary.
        threshold_profile: Name of the threshold profile in use. It is
            accepted for API compatibility but not read by this function.
        active_threshold: Overall threshold value.
        active_category_thresholds: Category-specific thresholds.
        min_confidence: Minimum confidence to include in results.

    Returns:
        On success, the predict_batch result dict with 'inference_time'
        (seconds, covering preprocessing and inference) and 'success' set.
        On failure, a dict with 'success': False, 'error' and empty
        'all_tags' / 'all_probs' / 'tags'.
    """
    try:
        tagger = _get_cached_tagger(model_path, metadata)

        start_time = time.time()
        img_array = preprocess_image(image_path, image_size=tagger.image_size)

        results = tagger.predict_batch(
            [img_array],
            threshold=active_threshold,
            category_thresholds=active_category_thresholds,
            min_confidence=min_confidence
        )
        inference_time = time.time() - start_time

        if results:
            result = results[0]
            result['inference_time'] = inference_time
            result['success'] = True
            return result
        return {
            'success': False,
            'error': 'Failed to process image',
            'all_tags': [],
            'all_probs': {},
            'tags': {}
        }
    except Exception as e:
        print(f"Error in process_single_image_onnx: {str(e)}")
        traceback.print_exc()
        return {
            'success': False,
            'error': str(e),
            'all_tags': [],
            'all_probs': {},
            'tags': {}
        }


def preprocess_images_parallel(image_paths, image_size=DEFAULT_IMAGE_SIZE, max_workers=8):
    """
    Preprocess several images concurrently on a thread pool.

    Images that fail to preprocess are reported on stdout and skipped.

    Returns:
        (processed_images, valid_paths): two parallel lists, in input order,
        holding the arrays and source paths of the images that succeeded.
    """
    def _load(path):
        try:
            return preprocess_image(path, image_size), path
        except Exception as e:
            print(f"Error processing {path}: {str(e)}")
            return None, path

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        outcomes = list(executor.map(_load, image_paths))

    processed_images = [arr for arr, _ in outcomes if arr is not None]
    valid_paths = [path for arr, path in outcomes if arr is not None]
    return processed_images, valid_paths


def apply_category_limits(result, category_limits):
    """
    Apply per-category tag limits to a result dictionary, in place.

    Args:
        result: Result dictionary containing 'tags', 'all_tags' and 'success'.
        category_limits: Dictionary mapping categories to their tag limits.
            A limit of 0 excludes the category. A limit of -1 (the default
            for unlisted categories) means no limit.

    Returns:
        The same dictionary with 'tags' trimmed and 'all_tags' regenerated.
        It is returned unchanged when no limits are given or the result
        is not successful.
    """
    if not category_limits or not result.get('success'):
        return result

    filtered_tags = result['tags']

    for category in list(filtered_tags):
        limit = category_limits.get(category, -1)
        if limit == 0:
            del filtered_tags[category]
        elif limit > 0 and len(filtered_tags[category]) > limit:
            # Tag lists are sorted by descending probability, so this keeps the top N.
            filtered_tags[category] = filtered_tags[category][:limit]

    result['tags'] = filtered_tags
    result['all_tags'] = [tag for cat_tags in filtered_tags.values() for tag, _ in cat_tags]
    return result


def _extract_tag_mappings(metadata):
    """
    Read ``(idx_to_tag, tag_to_category, total_tags)`` from model metadata.

    The new layout keeps everything under ``metadata['dataset_info']``. The
    older flat layout is also supported.

    Raises:
        KeyError: If the new layout is present but incomplete.
    """
    if 'dataset_info' in metadata:
        dataset_info = metadata['dataset_info']
        tag_mapping = dataset_info['tag_mapping']
        return tag_mapping['idx_to_tag'], tag_mapping['tag_to_category'], dataset_info['total_tags']

    idx_to_tag = metadata.get('idx_to_tag', {})
    return (
        idx_to_tag,
        metadata.get('tag_to_category', {}),
        metadata.get('total_tags', len(idx_to_tag)),
    )


def _image_size_from_metadata(metadata):
    """Return ``metadata['model_info']['img_size']``, or DEFAULT_IMAGE_SIZE if absent."""
    return metadata.get('model_info', {}).get('img_size', DEFAULT_IMAGE_SIZE)


def _select_main_logits(outputs):
    """
    Pick the logits to score from an ONNX session's outputs.

    Multi-output models use outputs[1] (refined predictions). Single-output
    models use outputs[0].
    """
    if len(outputs) >= 2:
        main_logits = outputs[1]
        print(f"Using refined predictions (shape: {main_logits.shape})")
    else:
        main_logits = outputs[0]
        print(f"Using single output (shape: {main_logits.shape})")
    return main_logits


def _sigmoid(logits):
    """
    Compute the element-wise logistic sigmoid.

    For very negative logits exp() overflows to inf, giving the correct
    probability of 0.0, so the overflow warning is suppressed.
    """
    with np.errstate(over='ignore'):
        return 1.0 / (1.0 + np.exp(-logits))


class ONNXImageTagger:
    """ONNX-based image tagger for fast batch inference with updated metadata format."""

    def __init__(self, model_path, metadata):
        """
        Create the inference session and read tag mappings from metadata.

        Args:
            model_path: Path to the ONNX model file.
            metadata: Already-loaded metadata dict, in either the new
                'dataset_info' layout or the older flat layout.
        """
        self.model_path = model_path
        try:
            self.session = ort.InferenceSession(
                model_path,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")
        except Exception as e:
            print(f"CUDA not available, using CPU: {e}")
            self.session = ort.InferenceSession(
                model_path,
                providers=['CPUExecutionProvider']
            )
            print(f"Using providers: {self.session.get_providers()}")

        # Metadata is passed in as a dict rather than loaded from a file here.
        self.metadata = metadata
        if 'dataset_info' in metadata:
            self.tag_mapping = metadata['dataset_info']['tag_mapping']
        self.idx_to_tag, self.tag_to_category, self.total_tags = _extract_tag_mappings(metadata)

        # Square input size used when preprocessing for this model. This is
        # the same metadata lookup test_onnx_imagetagger uses.
        self.image_size = _image_size_from_metadata(metadata)

        self.input_name = self.session.get_inputs()[0].name
        print(f"Model loaded successfully. Input name: {self.input_name}")
        print(f"Total tags: {self.total_tags}, Categories: {len(set(self.tag_to_category.values()))}")

    def predict_batch(self, image_arrays, threshold=0.5, category_thresholds=None, min_confidence=0.1):
        """
        Run batch inference on preprocessed image arrays.

        Args:
            image_arrays: Sequence of arrays from preprocess_image. They are
                stacked along a new leading batch axis.
            threshold: Default probability threshold for the 'tags' output.
            category_thresholds: Optional {category: threshold} that
                overrides ``threshold`` per category.
            min_confidence: Probabilities below this are left out of
                'all_probs' entirely.

        Returns:
            One dict per image with these keys:
              * 'tags': {category: [(tag, prob), ...]}, tags at or above
                the category threshold.
              * 'all_probs': the same structure, with every tag at or above
                min_confidence.
              * 'all_tags': flat list of tag names in 'tags'.
              * 'success': True.
            Lists are sorted by descending probability. An empty input
            returns [].
        """
        if len(image_arrays) == 0:
            return []

        batch_input = np.stack(image_arrays)

        start_time = time.time()
        outputs = self.session.run(None, {self.input_name: batch_input})
        inference_time = time.time() - start_time
        print(f"Batch inference completed in {inference_time:.4f} seconds ({inference_time/len(image_arrays):.4f} s/image)")

        main_probs = _sigmoid(_select_main_logits(outputs))

        return [
            self._build_result(probs, threshold, category_thresholds, min_confidence)
            for probs in main_probs
        ]

    def _build_result(self, probs, threshold, category_thresholds, min_confidence):
        """Turn one image's probability vector into the result dict described in predict_batch."""
        # Compare in float64 so the cutoff matches float(prob) >= min_confidence
        # exactly, independent of NumPy's float32/Python-float promotion rules.
        probs64 = np.asarray(probs, dtype=np.float64)

        all_probs = {}
        for idx in np.flatnonzero(probs64 >= min_confidence):
            # JSON metadata stores indices as string keys.
            tag_name = self.idx_to_tag.get(str(idx), f"unknown-{idx}")
            category = self.tag_to_category.get(tag_name, "general")
            all_probs.setdefault(category, []).append((tag_name, float(probs64[idx])))

        for category in all_probs:
            all_probs[category].sort(key=lambda pair: pair[1], reverse=True)

        tags = {}
        for category, cat_tags in all_probs.items():
            if category_thresholds and category in category_thresholds:
                cat_threshold = category_thresholds[category]
            else:
                cat_threshold = threshold
            tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= cat_threshold]

        all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]

        return {
            'tags': tags,
            'all_probs': all_probs,
            'all_tags': all_tags,
            'success': True
        }


def _find_image_files(folder_path):
    """
    List image files directly inside ``folder_path``.

    A file is included when its extension is in IMAGE_EXTENSIONS, in any
    letter case (so 'photo.Jpg' is found on case-sensitive filesystems).
    Hidden entries (names starting with '.', which a '*' glob also skips)
    and directories are excluded. Results are sorted by file name. A
    missing or unreadable folder yields [].
    """
    try:
        with os.scandir(folder_path) as entries:
            names = sorted(
                entry.name for entry in entries
                if not entry.name.startswith('.')
                and entry.name.lower().endswith(IMAGE_EXTENSIONS)
                and entry.is_file()
            )
    except OSError:
        return []
    return [os.path.join(folder_path, name) for name in names]


def _finalize_result(image_path, result, category_limits, save_dir, save_tags_to_file, log_limits=False):
    """
    Apply category limits to one prediction and write its tag file.

    A save failure is recorded in result['save_error'] rather than raised.
    On success the written path is stored in result['output_path'].
    Returns the updated result.
    """
    if category_limits and result['success']:
        if log_limits:
            print(f"Applying limits to {os.path.basename(image_path)}: {len(result['all_tags'])} → ", end="")
        result = apply_category_limits(result, category_limits)
        if log_limits:
            print(f"{len(result['all_tags'])} tags")

    if result['success']:
        try:
            output_path = save_tags_to_file(
                image_path=image_path,
                all_tags=result['all_tags'],
                custom_dir=save_dir,
                overwrite=True
            )
            result['output_path'] = str(output_path)
        except Exception as e:
            print(f"Error saving tags for {image_path}: {e}")
            result['save_error'] = str(e)
    return result


def batch_process_images_onnx(folder_path, model_path, metadata_path, threshold_profile,
                              active_threshold, active_category_thresholds, save_dir=None,
                              progress_callback=None, min_confidence=0.1, batch_size=16,
                              category_limits=None):
    """
    Process all images in a folder using the ONNX model with the new metadata format.

    Images are preprocessed in parallel and inferred in batches. If a batch
    raises, its images are retried one at a time. Every discovered image
    gets an entry in 'results', including images that failed, which are
    marked with 'success': False.

    Args:
        folder_path: Path to the folder containing images (not recursive).
        model_path: Path to the ONNX model file.
        metadata_path: Path to the model metadata JSON file.
        threshold_profile: Selected threshold profile. It is accepted for API
            compatibility but not read here.
        active_threshold: Overall threshold value.
        active_category_thresholds: Category-specific thresholds.
        save_dir: Directory to save tag files. If None, '<app_dir>/saved_tags'
            is used.
        progress_callback: Optional callable(done, total, current_path).
        min_confidence: Minimum confidence threshold.
        batch_size: Number of images per inference batch. Values below 1
            are treated as 1.
        category_limits: Dictionary mapping categories to their tag limits.

    Returns:
        Dict with 'success', 'total', 'processed' (number of entries in
        'results'), 'results' ({image_path: result}), 'save_dir' and
        'time_elapsed'. If discovery, metadata loading or model loading
        fails, the dict holds 'success': False, 'error' and an empty
        'results' instead.
    """
    from utils.file_utils import save_tags_to_file  # Import here to avoid circular imports

    image_files = _find_image_files(folder_path)
    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

    if save_dir is None:
        app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        save_dir = os.path.join(app_dir, "saved_tags")
    os.makedirs(save_dir, exist_ok=True)

    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
    except Exception as e:
        return {
            'success': False,
            'error': f"Failed to load metadata: {e}",
            'results': {}
        }

    try:
        tagger = ONNXImageTagger(model_path, metadata)
    except Exception as e:
        return {
            'success': False,
            'error': f"Failed to load model: {e}",
            'results': {}
        }

    # range() rejects a zero step; treat non-positive sizes as one image per batch.
    batch_size = max(1, batch_size)

    results = {}
    total_images = len(image_files)
    num_batches = (total_images + batch_size - 1) // batch_size
    processed = 0
    start_time = time.time()

    for i in range(0, total_images, batch_size):
        batch_start = time.time()
        batch_files = image_files[i:i + batch_size]
        batch_size_actual = len(batch_files)

        if progress_callback:
            progress_callback(processed, total_images, batch_files[0] if batch_files else None)

        print(f"Processing batch {i//batch_size + 1}/{num_batches}: {batch_size_actual} images")

        try:
            processed_images, valid_paths = preprocess_images_parallel(
                batch_files, image_size=tagger.image_size
            )

            if processed_images:
                batch_results = tagger.predict_batch(
                    processed_images,
                    threshold=active_threshold,
                    category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                for j, (image_path, result) in enumerate(zip(valid_paths, batch_results)):
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)
                    results[image_path] = _finalize_result(
                        image_path, result, category_limits, save_dir,
                        save_tags_to_file, log_limits=True
                    )

            # Record images that failed preprocessing (or got no prediction)
            # so they show up in the results, as in the fallback path below.
            for image_path in batch_files:
                if image_path not in results:
                    results[image_path] = {
                        'success': False,
                        'error': 'Failed to process image',
                        'all_tags': []
                    }

            batch_time = time.time() - batch_start
            print(f"Batch processed in {batch_time:.2f} seconds ({batch_time/batch_size_actual:.2f} seconds per image)")

        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            traceback.print_exc()

            # Fallback: retry each image of the failed batch on its own.
            for j, image_path in enumerate(batch_files):
                try:
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    img_array = preprocess_image(image_path, image_size=tagger.image_size)
                    single_results = tagger.predict_batch(
                        [img_array],
                        threshold=active_threshold,
                        category_thresholds=active_category_thresholds,
                        min_confidence=min_confidence
                    )

                    if single_results:
                        results[image_path] = _finalize_result(
                            image_path, single_results[0], category_limits,
                            save_dir, save_tags_to_file
                        )
                    else:
                        results[image_path] = {
                            'success': False,
                            'error': 'Failed to process image',
                            'all_tags': []
                        }
                except Exception as img_e:
                    print(f"Error processing single image {image_path}: {str(img_e)}")
                    results[image_path] = {
                        'success': False,
                        'error': str(img_e),
                        'all_tags': []
                    }

        # Counted once per batch, whichever path handled it.
        processed += batch_size_actual

    if progress_callback:
        progress_callback(total_images, total_images, None)

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Batch processing finished. Total time: {total_time:.2f} seconds, Average: {total_time/total_images:.2f} seconds per image")

    return {
        'success': True,
        'total': total_images,
        'processed': len(results),
        'results': results,
        'save_dir': save_dir,
        'time_elapsed': total_time
    }


def test_onnx_imagetagger(model_path, metadata_path, image_path, threshold=0.5, top_k=256):
    """
    Run one image through an ONNX ImageTagger model and print its predictions.

    Handles both multi-output models (refined predictions are used) and
    single-output models, with new or old metadata layouts.

    Args:
        model_path: Path to the ONNX model file.
        metadata_path: Path to the metadata JSON file.
        image_path: Path to the test image.
        threshold: Confidence threshold for predictions.
        top_k: Maximum number of predictions shown per category.

    Returns:
        {category: [(tag, prob), ...]} for predictions at or above
        ``threshold``. Each list is sorted by descending probability and
        truncated to ``top_k``. Returns {} when nothing passes the threshold.

    Raises:
        ValueError: The metadata cannot be loaded or is malformed, or
            image preprocessing fails.
        RuntimeError: The ONNX session cannot be created or inference fails.
    """
    from collections import defaultdict

    print(f"Loading ImageTagger ONNX model from {model_path}")

    try:
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
    except Exception as e:
        raise ValueError(f"Failed to load metadata: {e}") from e

    try:
        idx_to_tag, tag_to_category, total_tags = _extract_tag_mappings(metadata)
    except KeyError as e:
        raise ValueError(f"Invalid metadata structure, missing key: {e}") from e
    print(f"Model info: {total_tags} tags, {len(set(tag_to_category.values()))} categories")

    providers = []
    if ort.get_device() == 'GPU':
        providers.append('CUDAExecutionProvider')
    providers.append('CPUExecutionProvider')

    try:
        session = ort.InferenceSession(model_path, providers=providers)
        active_provider = session.get_providers()[0]
        print(f"Using provider: {active_provider}")

        inputs = session.get_inputs()
        model_outputs = session.get_outputs()
        print(f"Model inputs: {len(inputs)}")
        print(f"Model outputs: {len(model_outputs)}")
        for i, output in enumerate(model_outputs):
            print(f" Output {i}: {output.name} {output.shape}")
    except Exception as e:
        raise RuntimeError(f"Failed to create ONNX session: {e}") from e

    print(f"Processing image: {image_path}")
    try:
        img_size = _image_size_from_metadata(metadata)
        img_tensor = preprocess_image(image_path, image_size=img_size)
        img_numpy = img_tensor[np.newaxis, :]  # Add batch dimension
        print(f"Input shape: {img_numpy.shape}, dtype: {img_numpy.dtype}")
    except Exception as e:
        raise ValueError(f"Image preprocessing failed: {e}") from e

    input_name = session.get_inputs()[0].name
    print("Running inference...")
    start_time = time.time()
    try:
        outputs = session.run(None, {input_name: img_numpy})
        inference_time = time.time() - start_time
        print(f"Inference completed in {inference_time:.4f} seconds")
    except Exception as e:
        raise RuntimeError(f"Inference failed: {e}") from e

    main_probs = _sigmoid(_select_main_logits(outputs))
    scores = main_probs[0]

    indices = np.where(scores >= threshold)[0]

    if len(indices) == 0:
        print(f"No predictions above threshold {threshold}")
        print("Top 5 predictions:")
        for idx in np.argsort(scores)[-5:][::-1]:
            tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
            print(f" {tag_name}: {float(scores[idx]):.3f}")
        return {}

    tags_by_category = defaultdict(list)
    for idx in indices:
        tag_name = idx_to_tag.get(str(idx), f"unknown-{idx}")
        category = tag_to_category.get(tag_name, "general")
        tags_by_category[category].append((tag_name, float(scores[idx])))

    for category in tags_by_category:
        tags_by_category[category] = sorted(
            tags_by_category[category], key=lambda x: x[1], reverse=True
        )[:top_k]  # Limit per category

    total_predictions = sum(len(tags) for tags in tags_by_category.values())
    print(f"\nPredicted tags (threshold: {threshold}): {total_predictions} total")

    # Known categories first in a fixed order, then any others alphabetically.
    category_order = ['general', 'character', 'copyright', 'artist', 'meta', 'year', 'rating']
    display_order = [c for c in category_order if c in tags_by_category]
    display_order += sorted(c for c in tags_by_category if c not in category_order)

    for category in display_order:
        tags = tags_by_category[category]
        print(f"\n{category.upper()} ({len(tags)}):")
        for tag, prob in tags:
            print(f" {tag}: {prob:.3f}")

    print("\nPerformance:")
    print(f" Inference time: {inference_time:.4f}s")
    print(f" Provider: {active_provider}")
    print(f" Max confidence: {main_probs.max():.3f}")
    if total_predictions > 0:
        avg_conf = np.mean([prob for tags in tags_by_category.values() for _, prob in tags])
        print(f" Average confidence: {avg_conf:.3f}")

    return dict(tags_by_category)