#!/usr/bin/env python3
"""
Example usage of Docling TableFormer ONNX models for table structure recognition.

The script loads one of two model variants with ONNX Runtime and can run a
latency benchmark and/or a demonstration "structure extraction" pass.

NOTE: pre- and post-processing are placeholders. The tensor fed to the model is
random data shaped like the model input, not features derived from the image,
and the returned structure only reports the shapes of the raw outputs.
"""

import argparse
import os
import sys
import time
from typing import Dict, List, Tuple, Optional

import cv2
import numpy as np
import onnxruntime as ort


class TableFormerONNX:
    """ONNX wrapper for TableFormer models"""

    # Element type string reported by ONNX Runtime for int64 inputs.
    _INT64_TENSOR = "tensor(int64)"

    def __init__(self, model_path: str, model_type: str = "accurate"):
        """
        Initialize TableFormer ONNX model

        Args:
            model_path: Path to ONNX model file
            model_type: "accurate" or "fast"; only used for reporting
        """
        print(f"Loading {model_type} TableFormer model: {model_path}")

        self.session = ort.InferenceSession(model_path)
        self.model_type = model_type

        # Only the first model input is used by this wrapper.
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        self.input_shape = model_input.shape
        self.input_type = model_input.type
        self.output_names = [output.name for output in self.session.get_outputs()]

        print("✓ Model loaded successfully")
        print(f" Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f" Outputs: {len(self.output_names)} tensors")

    def _concrete_input_shape(self) -> Tuple[int, ...]:
        """Return the model input shape with every dynamic axis replaced by 1.

        ONNX Runtime reports dynamic axes as strings (symbolic names) or None,
        which cannot be passed to NumPy as array dimensions.
        """
        return tuple(
            int(dim) if isinstance(dim, (int, np.integer)) and dim >= 0 else 1
            for dim in self.input_shape
        )

    def _shape_matches(self, actual: Tuple[int, ...]) -> bool:
        """Tell whether *actual* is compatible with the declared input shape.

        Dynamic axes (anything that is not a non-negative integer) accept any
        size; concrete axes must match exactly.
        """
        if len(actual) != len(self.input_shape):
            return False
        for expected_dim, actual_dim in zip(self.input_shape, actual):
            is_concrete = isinstance(expected_dim, (int, np.integer)) and expected_dim >= 0
            if is_concrete and expected_dim != actual_dim:
                return False
        return True

    @staticmethod
    def _to_rgb(table_image: np.ndarray) -> np.ndarray:
        """Convert a grayscale or RGBA image to 3-channel RGB; pass others through."""
        if table_image.ndim == 2:
            return cv2.cvtColor(table_image, cv2.COLOR_GRAY2RGB)
        if table_image.ndim == 3 and table_image.shape[2] == 4:
            return cv2.cvtColor(table_image, cv2.COLOR_RGBA2RGB)
        # Already 3-channel, or a layout this example does not handle.
        return table_image

    def create_dummy_input(self) -> np.ndarray:
        """Create dummy input tensor for testing

        Returns:
            A random tensor shaped like the model input (dynamic axes set to 1):
            int64 values in [0, 100) for int64 models, otherwise float32
            samples from a standard normal distribution.
        """
        shape = self._concrete_input_shape()
        if self.input_type == self._INT64_TENSOR:
            return np.random.randint(0, 100, shape).astype(np.int64)
        return np.random.randn(*shape).astype(np.float32)

    def preprocess_table_region(self, table_image: np.ndarray) -> np.ndarray:
        """
        Preprocess table region image for TableFormer inference

        Note: This is a simplified preprocessing example. The actual TableFormer
        preprocessing may be more complex and specific to the training procedure.

        Args:
            table_image: Grayscale, RGB or RGBA image of the table region

        Returns:
            A placeholder tensor shaped like the model input. The pixel data of
            ``table_image`` does not influence the returned values.
        """
        # Normalise the channel layout. The result is not consumed yet because
        # the placeholder below ignores pixel data; a real pipeline would resize
        # and normalise this RGB image into the input tensor.
        self._to_rgb(table_image)

        return self.create_dummy_input()

    def predict(self, input_tensor: np.ndarray) -> Dict[str, np.ndarray]:
        """Run table structure prediction

        Args:
            input_tensor: Tensor fed to the model's first input

        Returns:
            Mapping from output name to the corresponding output array.
        """
        # A mismatch is only reported, not rejected: ONNX Runtime performs its
        # own validation and raises if the tensor is really incompatible.
        if not self._shape_matches(tuple(input_tensor.shape)):
            expected_shape = tuple(self.input_shape)
            print(f"Warning: Input shape {input_tensor.shape} != expected {expected_shape}")

        # Passing None requests every model output, in declaration order.
        outputs = self.session.run(None, {self.input_name: input_tensor})

        return dict(zip(self.output_names, outputs))

    def extract_table_structure(self, table_image: np.ndarray) -> Dict:
        """
        Extract table structure from table region image

        Args:
            table_image: RGB image of table region

        Returns:
            Dictionary containing table structure information. In this example
            only "model_type" and "raw_outputs" (output name -> shape) come from
            the model; the remaining fields are placeholders.
        """
        input_tensor = self.preprocess_table_region(table_image)
        raw_outputs = self.predict(input_tensor)

        # Post-process to extract table structure
        # Note: This is a simplified example. The actual post-processing
        # would depend on the specific output format of the TableFormer model
        #
        # In a real implementation, you would:
        # 1. Parse the raw model outputs
        # 2. Extract cell boundaries and classifications
        # 3. Determine row and column structure
        # 4. Generate structured table representation
        return {
            "model_type": self.model_type,
            "raw_outputs": {name: output.shape for name, output in raw_outputs.items()},
            "cells": [],  # Would contain cell boundary and type information
            "rows": [],  # Would contain row definitions
            "columns": [],  # Would contain column definitions
            "confidence": 0.95,  # Placeholder confidence score
            "processing_note": "This is a demonstration output. Real implementation would parse model outputs.",
        }

    def benchmark(self, num_iterations: int = 100, warmup_iterations: int = 5) -> Dict[str, float]:
        """Benchmark model performance

        Args:
            num_iterations: Number of timed inference runs; must be at least 1
            warmup_iterations: Untimed runs executed first so one-off
                initialisation cost does not skew the statistics

        Returns:
            Latency statistics in milliseconds plus "throughput_fps".

        Raises:
            ValueError: If ``num_iterations`` is smaller than 1.
        """
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be at least 1, got {num_iterations}")

        print(f"Running benchmark with {num_iterations} iterations...")

        dummy_input = self.create_dummy_input()

        for _ in range(warmup_iterations):
            self.predict(dummy_input)

        durations = np.empty(num_iterations, dtype=np.float64)
        for i in range(num_iterations):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            self.predict(dummy_input)
            durations[i] = time.perf_counter() - start

            if (i + 1) % 10 == 0:
                print(f" Progress: {i + 1}/{num_iterations}")

        mean_seconds = float(np.mean(durations))
        return {
            "mean_time_ms": mean_seconds * 1000,
            "std_time_ms": float(np.std(durations) * 1000),
            "min_time_ms": float(np.min(durations) * 1000),
            "max_time_ms": float(np.max(durations) * 1000),
            "median_time_ms": float(np.median(durations) * 1000),
            # Guard against a zero mean when every run is below timer resolution.
            "throughput_fps": 1.0 / mean_seconds if mean_seconds > 0 else float("inf"),
        }


def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]`` when None

    Returns:
        Process exit code: 0 on success, 1 when the model file, the image file
        or the iteration count is invalid.
    """
    parser = argparse.ArgumentParser(description="TableFormer ONNX Example")
    parser.add_argument("--model", type=str, choices=["accurate", "fast"],
                        default="accurate", help="Model variant to use")
    parser.add_argument("--image", type=str, help="Path to table image (optional)")
    parser.add_argument("--benchmark", action="store_true", help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")

    args = parser.parse_args(argv)

    # Model files are looked up relative to the current working directory.
    model_files = {
        "accurate": "tableformer_accurate.onnx",
        "fast": "tableformer_fast.onnx",
    }
    model_path = model_files[args.model]

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        print("Please ensure the ONNX model files are in the current directory.")
        return 1

    # Reject a bad iteration count before paying the cost of loading the model.
    if args.benchmark and args.iterations < 1:
        print(f"Error: --iterations must be at least 1 (got {args.iterations})")
        return 1

    print("=" * 60)
    print(f"TableFormer ONNX Example - {args.model.title()} Model")
    print("=" * 60)

    tableformer = TableFormerONNX(model_path, args.model)

    if args.benchmark:
        print("\n📊 Running performance benchmark...")
        stats = tableformer.benchmark(args.iterations)

        print(f"\n📈 Benchmark Results ({args.model} model):")
        print(f" Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
        print(f" Median inference time: {stats['median_time_ms']:.2f} ms")
        print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
        print(f" Throughput: {stats['throughput_fps']:.1f} FPS")

    if args.image:
        if not os.path.exists(args.image):
            print(f"Error: Image file not found: {args.image}")
            return 1

        print(f"\n🖼️ Processing image: {args.image}")

        # cv2.imread signals failure by returning None rather than raising.
        image = cv2.imread(args.image)
        if image is None:
            print(f"Error: Could not load image: {args.image}")
            return 1

        # OpenCV loads images as BGR; the wrapper expects RGB.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        structure = tableformer.extract_table_structure(image_rgb)

        print("✓ Table structure extracted:")
        print(f" Model: {structure['model_type']}")
        print(f" Raw outputs: {structure['raw_outputs']}")
        print(f" Confidence: {structure['confidence']}")
        print(f" Note: {structure['processing_note']}")
    else:
        print("\n🔬 Running demo with dummy data...")

        # randint's upper bound is exclusive, so 256 covers the full uint8 range.
        dummy_image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)

        structure = tableformer.extract_table_structure(dummy_image)

        print("✓ Demo completed:")
        print(f" Model: {structure['model_type']}")
        print(f" Raw outputs: {structure['raw_outputs']}")
        print(f" Processing: {structure['processing_note']}")

    print("\n✅ Example completed successfully!")
    print(f"\nTo process a real image, use: python example.py --model {args.model} --image your_table.jpg")
    print(f"To run a benchmark, use: python example.py --model {args.model} --benchmark")
    return 0


if __name__ == "__main__":
    sys.exit(main())