""" | |
Image processing functions for the Image Tagger application. | |
""" | |
import os | |
import traceback | |
import glob | |


def process_image(image_path, model, thresholds, metadata, threshold_profile,
                  active_threshold, active_category_thresholds, min_confidence=0.1):
    """
    Process a single image and return the tags.

    Args:
        image_path: Path to the image
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        Dictionary with tags, all probabilities, and other info
    """
    try:
        # Run inference directly using the model's predict method
        if threshold_profile in ["Category-specific", "High Precision", "High Recall"]:
            results = model.predict(
                image_path=image_path,
                category_thresholds=active_category_thresholds
            )
        else:
            results = model.predict(
                image_path=image_path,
                threshold=active_threshold
            )

        # Extract and organize all probabilities
        all_probs = {}
        probs = results['refined_probabilities'][0]  # Remove batch dimension

        for idx in range(len(probs)):
            prob_value = probs[idx].item()
            if prob_value >= min_confidence:
                tag, category = model.dataset.get_tag_info(idx)
                if category not in all_probs:
                    all_probs[category] = []
                all_probs[category].append((tag, prob_value))

        # Sort tags by probability within each category
        for category in all_probs:
            all_probs[category] = sorted(
                all_probs[category],
                key=lambda x: x[1],
                reverse=True
            )

        # Get the filtered tags based on the selected threshold
        tags = {}
        for category, cat_tags in all_probs.items():
            threshold = (active_category_thresholds.get(category, active_threshold)
                         if active_category_thresholds else active_threshold)
            tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]

        # Create a flat list of all tags above threshold
        all_tags = []
        for category, cat_tags in tags.items():
            for tag, _ in cat_tags:
                all_tags.append(tag)

        return {
            'tags': tags,
            'all_probs': all_probs,
            'all_tags': all_tags,
            'success': True
        }
    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        traceback.print_exc()
        return {
            'tags': {},
            'all_probs': {},
            'all_tags': [],
            'success': False,
            'error': str(e)
        }
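
# Illustrative shape of the dictionary returned by process_image on success
# (the tag names and probabilities below are made-up examples, not real model output):
#
#     {
#         'tags':      {'general': [('outdoors', 0.92), ('sky', 0.78)]},
#         'all_probs': {'general': [('outdoors', 0.92), ('sky', 0.78), ('tree', 0.41)]},
#         'all_tags':  ['outdoors', 'sky'],
#         'success':   True
#     }
#
# 'all_probs' keeps everything at or above min_confidence, 'tags' keeps only entries at or
# above the active (or per-category) threshold, and 'all_tags' flattens 'tags'.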


def apply_category_limits(result, category_limits):
    """
    Apply category limits to a result dictionary.

    Args:
        result: Result dictionary containing tags and all_tags
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Updated result dictionary with limits applied
    """
    if not category_limits or not result['success']:
        return result

    # Get the filtered tags
    filtered_tags = result['tags']

    # Apply limits to each category
    for category, cat_tags in list(filtered_tags.items()):
        # Get limit for this category, default to -1 (no limit)
        limit = category_limits.get(category, -1)

        if limit == 0:
            # Exclude this category entirely
            del filtered_tags[category]
        elif limit > 0 and len(cat_tags) > limit:
            # Limit to top N tags for this category
            filtered_tags[category] = cat_tags[:limit]

    # Regenerate all_tags list after applying limits
    all_tags = []
    for category, cat_tags in filtered_tags.items():
        for tag, _ in cat_tags:
            all_tags.append(tag)

    # Update the result with limited tags
    result['tags'] = filtered_tags
    result['all_tags'] = all_tags

    return result
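
# Example limits mapping (category names are hypothetical):
# {'character': 1, 'meta': 0, 'general': -1} keeps only the single highest-probability
# character tag, drops the meta category entirely, and leaves general untouched.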


def batch_process_images(folder_path, model, thresholds, metadata, threshold_profile,
                         active_threshold, active_category_thresholds, save_dir=None,
                         progress_callback=None, min_confidence=0.1, batch_size=1,
                         category_limits=None):
    """
    Process all images in a folder, with optional batching for improved performance.

    Args:
        folder_path: Path to folder containing images
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        save_dir: Directory to save tag files (if None, a default directory is used)
        progress_callback: Optional callback for progress updates
        min_confidence: Minimum confidence threshold
        batch_size: Number of images to process at once (default: 1)
        category_limits: Dictionary mapping categories to their tag limits
                         (0 = exclude category, -1 = no limit/include all)

    Returns:
        Dictionary with results for each image
    """
    from .file_utils import save_tags_to_file  # Import here to avoid circular imports
    import time

    print(f"Starting batch processing on {folder_path} with batch size {batch_size}")
    start_time = time.time()

    # Find all image files in the folder
    image_extensions = ['*.jpg', '*.jpeg', '*.png']
    image_files = []
    for ext in image_extensions:
        image_files.extend(glob.glob(os.path.join(folder_path, ext)))
        image_files.extend(glob.glob(os.path.join(folder_path, ext.upper())))

    # Remove duplicates: on Windows the filesystem is case-insensitive, so the
    # lowercase and uppercase glob patterns can match the same file twice
    if os.name == 'nt':  # Windows
        unique_paths = set()
        unique_files = []
        for file_path in image_files:
            normalized_path = os.path.normpath(file_path).lower()
            if normalized_path not in unique_paths:
                unique_paths.add(normalized_path)
                unique_files.append(file_path)
        image_files = unique_files

    # Sort files for consistent processing order
    image_files.sort()

    if not image_files:
        return {
            'success': False,
            'error': f"No images found in {folder_path}",
            'results': {}
        }

    print(f"Found {len(image_files)} images to process")

    # Use the provided save directory or create a default one
    if save_dir is None:
        app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        save_dir = os.path.join(app_dir, "saved_tags")

    # Ensure the directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Process images in batches
    results = {}
    total_images = len(image_files)
    processed = 0

    for i in range(0, total_images, batch_size):
        batch_start = time.time()

        # Get the current batch of images
        batch_files = image_files[i:i + batch_size]
        batch_size_actual = len(batch_files)

        print(f"Processing batch {i // batch_size + 1}/{(total_images + batch_size - 1) // batch_size}: "
              f"{batch_size_actual} images")
        if batch_size > 1:
            # True batch processing: run the whole batch through the model at once
            try:
                batch_results = process_image_batch(
                    image_paths=batch_files,
                    model=model,
                    thresholds=thresholds,
                    metadata=metadata,
                    threshold_profile=threshold_profile,
                    active_threshold=active_threshold,
                    active_category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                # Process and save results for each image in the batch
                for j, image_path in enumerate(batch_files):
                    # Update progress if a callback was provided
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    if j < len(batch_results):
                        result = batch_results[j]

                        # Apply category limits if specified
                        if category_limits and result['success']:
                            result = apply_category_limits(result, category_limits)
                            print(f"Applied limits for {os.path.basename(image_path)}, "
                                  f"remaining tags: {len(result['all_tags'])}")

                        # Save the tags to a file
                        if result['success']:
                            output_path = save_tags_to_file(
                                image_path=image_path,
                                all_tags=result['all_tags'],
                                custom_dir=save_dir,
                                overwrite=True
                            )
                            result['output_path'] = str(output_path)

                        # Store the result
                        results[image_path] = result
                    else:
                        # Batch processing returned fewer results than expected
                        results[image_path] = {
                            'success': False,
                            'error': 'Batch processing error: missing result',
                            'all_tags': []
                        }
            except Exception as e:
                print(f"Batch processing error: {str(e)}")
                traceback.print_exc()
                # Fall back to processing this batch's images one by one
                for j, image_path in enumerate(batch_files):
                    if progress_callback:
                        progress_callback(processed + j, total_images, image_path)

                    result = process_image(
                        image_path=image_path,
                        model=model,
                        thresholds=thresholds,
                        metadata=metadata,
                        threshold_profile=threshold_profile,
                        active_threshold=active_threshold,
                        active_category_thresholds=active_category_thresholds,
                        min_confidence=min_confidence
                    )

                    # Apply category limits if specified
                    if category_limits and result['success']:
                        result = apply_category_limits(result, category_limits)

                    if result['success']:
                        output_path = save_tags_to_file(
                            image_path=image_path,
                            all_tags=result['all_tags'],
                            custom_dir=save_dir,
                            overwrite=True
                        )
                        result['output_path'] = str(output_path)

                    results[image_path] = result
        else:
            # Process one by one when batch_size is 1
            for j, image_path in enumerate(batch_files):
                if progress_callback:
                    progress_callback(processed + j, total_images, image_path)

                result = process_image(
                    image_path=image_path,
                    model=model,
                    thresholds=thresholds,
                    metadata=metadata,
                    threshold_profile=threshold_profile,
                    active_threshold=active_threshold,
                    active_category_thresholds=active_category_thresholds,
                    min_confidence=min_confidence
                )

                # Apply category limits if specified
                if category_limits and result['success']:
                    result = apply_category_limits(result, category_limits)

                if result['success']:
                    output_path = save_tags_to_file(
                        image_path=image_path,
                        all_tags=result['all_tags'],
                        custom_dir=save_dir,
                        overwrite=True
                    )
                    result['output_path'] = str(output_path)

                results[image_path] = result

        # Update processed count
        processed += batch_size_actual

        # Report batch timing
        batch_time = time.time() - batch_start
        print(f"Batch processed in {batch_time:.2f} seconds "
              f"({batch_time / batch_size_actual:.2f} seconds per image)")

    # Final progress update
    if progress_callback:
        progress_callback(total_images, total_images, None)

    total_time = time.time() - start_time
    print(f"Batch processing finished. Total time: {total_time:.2f} seconds, "
          f"average: {total_time / total_images:.2f} seconds per image")

    return {
        'success': True,
        'total': total_images,
        'processed': len(results),
        'results': results,
        'save_dir': save_dir,
        'time_elapsed': total_time
    }
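
# Illustrative call (commented out because it needs a loaded model; the folder path,
# profile name, and threshold value below are placeholders, not values defined here):
#
#     def on_progress(done, total, current_path):
#         print(f"{done}/{total}", current_path or "finished")
#
#     summary = batch_process_images(
#         folder_path="/path/to/images",
#         model=model, thresholds=thresholds, metadata=metadata,
#         threshold_profile="Overall",   # any profile outside the category-specific list
#         active_threshold=0.35,
#         active_category_thresholds=None,
#         progress_callback=on_progress,
#         batch_size=4,
#     )
#     print(summary['processed'], "images tagged; files written to", summary['save_dir'])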


def process_image_batch(image_paths, model, thresholds, metadata, threshold_profile,
                        active_threshold, active_category_thresholds, min_confidence=0.1):
    """
    Process a batch of images at once.

    Args:
        image_paths: List of paths to the images
        model: The image tagger model
        thresholds: Thresholds dictionary
        metadata: Metadata dictionary
        threshold_profile: Selected threshold profile
        active_threshold: Overall threshold value
        active_category_thresholds: Category-specific thresholds
        min_confidence: Minimum confidence to include in results

    Returns:
        List of dictionaries with tags, all probabilities, and other info for each image
    """
    try:
        import torch
        from PIL import Image
        import torchvision.transforms as transforms

        # Identify the model type we're using for better error handling
        model_type = model.__class__.__name__
        print(f"Running batch processing with model type: {model_type}")

        # Prepare the transformation for the images
        transform = transforms.Compose([
            transforms.Resize((512, 512)),  # Adjust based on your model's expected input
            transforms.ToTensor(),
        ])

        # Get model information
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        print(f"Model is using device: {device}, dtype: {dtype}")

        # Load and preprocess all images
        batch_tensor = []
        valid_images = []
        for img_path in image_paths:
            try:
                img = Image.open(img_path).convert('RGB')
                img_tensor = transform(img)
                img_tensor = img_tensor.to(device=device, dtype=dtype)
                batch_tensor.append(img_tensor)
                valid_images.append(img_path)
            except Exception as e:
                print(f"Error loading image {img_path}: {str(e)}")

        if not batch_tensor:
            return []

        # Stack all tensors into a single batch
        batch_input = torch.stack(batch_tensor)

        # Process the entire batch at once
        with torch.no_grad():
            try:
                # Forward pass on the whole batch
                output = model(batch_input)

                # Handle tuple output format
                if isinstance(output, tuple):
                    probs_batch = torch.sigmoid(output[1])
                else:
                    probs_batch = torch.sigmoid(output)

                # Process each image's results
                results = []
                for i, img_path in enumerate(valid_images):
                    probs = probs_batch[i].unsqueeze(0)  # Add batch dimension back

                    # Extract and organize all probabilities
                    all_probs = {}
                    for idx in range(probs.size(1)):
                        prob_value = probs[0, idx].item()
                        if prob_value >= min_confidence:
                            tag, category = model.dataset.get_tag_info(idx)
                            if category not in all_probs:
                                all_probs[category] = []
                            all_probs[category].append((tag, prob_value))

                    # Sort tags by probability within each category
                    for category in all_probs:
                        all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)

                    # Get the filtered tags
                    tags = {}
                    for category, cat_tags in all_probs.items():
                        threshold = (active_category_thresholds.get(category, active_threshold)
                                     if active_category_thresholds else active_threshold)
                        tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]

                    # Create a flat list of all tags above threshold
                    all_tags = []
                    for category, cat_tags in tags.items():
                        for tag, _ in cat_tags:
                            all_tags.append(tag)

                    results.append({
                        'tags': tags,
                        'all_probs': all_probs,
                        'all_tags': all_tags,
                        'success': True
                    })

                return results
            except RuntimeError as e:
                # On CUDA out-of-memory or another runtime error, fall back to
                # running the already-preprocessed tensors through the model one by one
                print(f"Error in batch processing: {str(e)}")
                print("Falling back to one-by-one processing...")

                results = []
                for img_tensor, img_path in zip(batch_tensor, valid_images):
                    try:
                        input_tensor = img_tensor.unsqueeze(0)
                        output = model(input_tensor)
                        if isinstance(output, tuple):
                            probs = torch.sigmoid(output[1])
                        else:
                            probs = torch.sigmoid(output)

                        # Same post-processing as the batched path above
                        all_probs = {}
                        for idx in range(probs.size(1)):
                            prob_value = probs[0, idx].item()
                            if prob_value >= min_confidence:
                                tag, category = model.dataset.get_tag_info(idx)
                                all_probs.setdefault(category, []).append((tag, prob_value))
                        for category in all_probs:
                            all_probs[category] = sorted(all_probs[category], key=lambda x: x[1], reverse=True)

                        tags = {}
                        for category, cat_tags in all_probs.items():
                            threshold = (active_category_thresholds.get(category, active_threshold)
                                         if active_category_thresholds else active_threshold)
                            tags[category] = [(tag, prob) for tag, prob in cat_tags if prob >= threshold]
                        all_tags = [tag for cat_tags in tags.values() for tag, _ in cat_tags]

                        results.append({
                            'tags': tags,
                            'all_probs': all_probs,
                            'all_tags': all_tags,
                            'success': True
                        })
                    except Exception as e:
                        print(f"Error processing image {img_path}: {str(e)}")
                        results.append({
                            'tags': {},
                            'all_probs': {},
                            'all_tags': [],
                            'success': False,
                            'error': str(e)
                        })
                return results
    except Exception as e:
        print(f"Error in batch processing: {str(e)}")
        traceback.print_exc()
        # Return an empty list so callers can tell that no per-image results were produced
        return []
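

if __name__ == "__main__":
    # Minimal sanity check with made-up tags (no model required): demonstrates how
    # apply_category_limits trims a result dictionary of the shape produced above.
    demo_result = {
        'tags': {
            'general': [('outdoors', 0.92), ('sky', 0.78), ('tree', 0.41)],
            'meta': [('highres', 0.66)],
        },
        'all_probs': {},
        'all_tags': ['outdoors', 'sky', 'tree', 'highres'],
        'success': True,
    }
    limited = apply_category_limits(demo_result, {'general': 2, 'meta': 0})
    print(limited['tags'])      # expected: {'general': [('outdoors', 0.92), ('sky', 0.78)]}
    print(limited['all_tags'])  # expected: ['outdoors', 'sky']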