#!/usr/bin/env python3 """ Example usage of DocumentClassifier ONNX model for document classification. """ import onnxruntime as ort import numpy as np import cv2 from typing import Dict, List, Union, Optional import argparse import os from PIL import Image import time class DocumentClassifierONNX: """ONNX wrapper for DocumentClassifier model""" def __init__(self, model_path: str = "DocumentClassifier.onnx"): """ Initialize DocumentClassifier ONNX model Args: model_path: Path to ONNX model file """ print(f"Loading DocumentClassifier model: {model_path}") self.session = ort.InferenceSession(model_path) # Get model input/output information self.input_name = self.session.get_inputs()[0].name self.input_shape = self.session.get_inputs()[0].shape self.input_type = self.session.get_inputs()[0].type self.output_names = [output.name for output in self.session.get_outputs()] self.output_shape = self.session.get_outputs()[0].shape # Common document categories (typical for document classification) self.categories = [ "article", "form", "letter", "memo", "news", "presentation", "resume", "scientific", "specification", "table", "other" ] print(f"āœ“ Model loaded successfully") print(f" Input: {self.input_name} {self.input_shape} ({self.input_type})") print(f" Output: {self.output_shape}") print(f" Categories: {len(self.categories)}") def create_dummy_input(self) -> np.ndarray: """Create dummy input tensor for testing""" if 'float' in self.input_type: # Create dummy image tensor dummy_input = np.random.randn(*self.input_shape).astype(np.float32) else: # Create dummy integer input dummy_input = np.random.randint(0, 255, self.input_shape).astype(np.int64) return dummy_input def preprocess_image(self, image: Union[str, np.ndarray], target_size: tuple = (224, 224)) -> np.ndarray: """ Preprocess image for DocumentClassifier inference Args: image: Image path or numpy array target_size: Target image size (height, width) """ if isinstance(image, str): # Load image from path pil_image = Image.open(image).convert('RGB') image_array = np.array(pil_image) else: image_array = image.copy() print(f" Processing image: {image_array.shape}") # Resize image to target size if len(image_array.shape) == 3: resized = cv2.resize(image_array, target_size[::-1], interpolation=cv2.INTER_CUBIC) else: # Convert grayscale to RGB if needed gray = image_array if len(image_array.shape) == 2 else cv2.cvtColor(image_array, cv2.COLOR_BGR2GRAY) rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB) resized = cv2.resize(rgb, target_size[::-1], interpolation=cv2.INTER_CUBIC) # Normalize to [0, 1] range normalized = resized.astype(np.float32) / 255.0 # Convert to CHW format (channels first) if len(normalized.shape) == 3: chw = np.transpose(normalized, (2, 0, 1)) else: chw = normalized # Add batch dimension if needed if len(self.input_shape) == 4 and len(chw.shape) == 3: batched = np.expand_dims(chw, axis=0) else: batched = chw # Ensure correct shape expected_shape = tuple(self.input_shape) if batched.shape != expected_shape: # Try to reshape or create dummy input print(f" Warning: Shape mismatch {batched.shape} != {expected_shape}") batched = self.create_dummy_input() print(f" Preprocessed: {batched.shape}") return batched def predict(self, input_tensor: np.ndarray) -> np.ndarray: """Run DocumentClassifier prediction""" # Validate input shape expected_shape = tuple(self.input_shape) if input_tensor.shape != expected_shape: print(f"Warning: Input shape {input_tensor.shape} != expected {expected_shape}") # Run inference outputs = self.session.run(None, {self.input_name: input_tensor}) return outputs[0] # Return classification logits def decode_output(self, logits: np.ndarray, top_k: int = 3) -> Dict: """ Decode model output logits to document categories Args: logits: Model output logits top_k: Number of top predictions to return Returns: Dictionary with classification results """ # Handle different output shapes - this model outputs features [1, 1280, 7, 7] if len(logits.shape) > 2: # Global average pooling for feature maps logits = np.mean(logits, axis=(2, 3)) # Average over spatial dimensions if len(logits.shape) > 1: logits = logits.flatten() # Truncate to match number of categories if len(logits) > len(self.categories): logits = logits[:len(self.categories)] elif len(logits) < len(self.categories): # Pad with zeros if needed padded = np.zeros(len(self.categories)) padded[:len(logits)] = logits logits = padded # Apply softmax to get probabilities probabilities = self._softmax(logits) # Get top-k predictions top_k_indices = np.argsort(probabilities)[-top_k:][::-1] top_k_probs = probabilities[top_k_indices] # Map indices to category names predictions = [] for i, (idx, prob) in enumerate(zip(top_k_indices, top_k_probs)): category = self.categories[idx] if idx < len(self.categories) else f"category_{idx}" predictions.append({ "rank": i + 1, "category": category, "confidence": float(prob), "index": int(idx) }) result = { "predicted_category": predictions[0]["category"], "confidence": predictions[0]["confidence"], "top_predictions": predictions, "all_probabilities": probabilities.tolist() } return result def _softmax(self, x: np.ndarray) -> np.ndarray: """Apply softmax to convert logits to probabilities""" exp_x = np.exp(x - np.max(x)) return exp_x / np.sum(exp_x) def classify(self, image: Union[str, np.ndarray]) -> Dict: """ Classify document type from image Args: image: Image path or numpy array Returns: Dictionary with classification results """ print("šŸ” Processing document image...") # Preprocess image input_tensor = self.preprocess_image(image) print("šŸš€ Running classification...") # Run inference logits = self.predict(input_tensor) print("šŸ“Š Decoding results...") # Decode output result = self.decode_output(logits) # Add metadata result["processing_info"] = { "input_shape": input_tensor.shape, "output_shape": logits.shape, "inference_successful": True } return result def benchmark(self, num_iterations: int = 100) -> Dict[str, float]: """Benchmark model performance""" print(f"šŸƒ Running benchmark with {num_iterations} iterations...") # Create dummy input dummy_input = self.create_dummy_input() # Warmup for _ in range(5): _ = self.predict(dummy_input) # Benchmark times = [] for i in range(num_iterations): start_time = time.time() _ = self.predict(dummy_input) end_time = time.time() times.append(end_time - start_time) if (i + 1) % 10 == 0: print(f" Progress: {i + 1}/{num_iterations}") # Calculate statistics times = np.array(times) stats = { "mean_time_ms": float(np.mean(times) * 1000), "std_time_ms": float(np.std(times) * 1000), "min_time_ms": float(np.min(times) * 1000), "max_time_ms": float(np.max(times) * 1000), "median_time_ms": float(np.median(times) * 1000), "throughput_fps": float(1.0 / np.mean(times)), "total_iterations": num_iterations } return stats def main(): parser = argparse.ArgumentParser(description="DocumentClassifier ONNX Example") parser.add_argument("--model", type=str, default="DocumentClassifier.onnx", help="Path to DocumentClassifier ONNX model") parser.add_argument("--image", type=str, help="Path to document image file") parser.add_argument("--benchmark", action="store_true", help="Run performance benchmark") parser.add_argument("--iterations", type=int, default=100, help="Number of benchmark iterations") args = parser.parse_args() # Check if model file exists if not os.path.exists(args.model): print(f"āŒ Error: Model file not found: {args.model}") print("Please ensure the ONNX model file is in the current directory.") return # Initialize model print("=" * 60) print("DocumentClassifier ONNX Example") print("=" * 60) try: classifier = DocumentClassifierONNX(args.model) except Exception as e: print(f"āŒ Error loading model: {e}") return # Run benchmark if requested if args.benchmark: print(f"\nšŸ“Š Running performance benchmark...") try: stats = classifier.benchmark(args.iterations) print(f"\nšŸ“ˆ Benchmark Results:") print(f" Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms") print(f" Median inference time: {stats['median_time_ms']:.2f} ms") print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms") print(f" Throughput: {stats['throughput_fps']:.1f} FPS") except Exception as e: print(f"āŒ Benchmark failed: {e}") # Process image if provided if args.image: if not os.path.exists(args.image): print(f"āŒ Error: Image file not found: {args.image}") return print(f"\nšŸ“„ Classifying document: {args.image}") try: # Classify document result = classifier.classify(args.image) print(f"\nāœ… Classification completed:") print(f" Document type: {result['predicted_category']}") print(f" Confidence: {result['confidence']:.3f}") print(f"\nšŸ† Top predictions:") for pred in result['top_predictions']: print(f" {pred['rank']}. {pred['category']}: {pred['confidence']:.3f}") except Exception as e: print(f"āŒ Error classifying document: {e}") import traceback traceback.print_exc() # Demo with dummy data if no image provided if not args.image and not args.benchmark: print(f"\nšŸ”¬ Running demo with dummy data...") try: # Create dummy document image dummy_image = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8) # Classify dummy image result = classifier.classify(dummy_image) print(f"āœ… Demo completed:") print(f" Predicted type: {result['predicted_category']}") print(f" Confidence: {result['confidence']:.3f}") print(f" Processing info: {result['processing_info']}") print(f"\nšŸ“ Note: This was a demonstration with random data.") except Exception as e: print(f"āŒ Demo failed: {e}") print(f"\nāœ… Example completed successfully!") print(f"\nUsage examples:") print(f" Classify document: python example.py --image document.jpg") print(f" Run benchmark: python example.py --benchmark --iterations 50") print(f" Both: python example.py --image document.pdf --benchmark") if __name__ == "__main__": main()