asmud's picture
Rename model name...
253a9b0
#!/usr/bin/env python3
"""
Example usage of Docling TableFormer ONNX models for table structure recognition.
"""
import onnxruntime as ort
import cv2
import numpy as np
from typing import Dict, List, Tuple, Optional
import argparse
import os
class TableFormerONNX:
    """ONNX Runtime wrapper for Docling TableFormer models.

    This is a demonstration wrapper: it loads an ONNX model, records the
    name/shape/type of its first input, and can run inference and a latency
    benchmark. Preprocessing and post-processing are placeholders (see
    ``preprocess_table_region`` and ``extract_table_structure``).
    """

    def __init__(self, model_path: str, model_type: str = "accurate",
                 providers: Optional[List[str]] = None):
        """Load an ONNX TableFormer model.

        Args:
            model_path: Path to the ONNX model file.
            model_type: Variant label, "accurate" or "fast". It is only used
                in log output and in the dict returned by
                ``extract_table_structure``.
            providers: Optional ONNX Runtime execution providers, e.g.
                ``["CPUExecutionProvider"]``. ``None`` keeps ONNX Runtime's
                default selection.
        """
        print(f"Loading {model_type} TableFormer model: {model_path}")
        self.session = ort.InferenceSession(model_path, providers=providers)
        self.model_type = model_type

        # Only the first model input is used by this wrapper.
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        # May contain non-integer entries (ONNX Runtime reports dynamic axes
        # as symbolic strings or None).
        self.input_shape = model_input.shape
        self.input_type = model_input.type
        self.output_names = [output.name for output in self.session.get_outputs()]

        print("✓ Model loaded successfully")
        print(f" Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f" Outputs: {len(self.output_names)} tensors")

    @staticmethod
    def _is_fixed_dim(dim) -> bool:
        """Return True if *dim* is a concrete, positive integer dimension."""
        return isinstance(dim, int) and dim > 0

    def _concrete_input_shape(self) -> Tuple[int, ...]:
        """Return the model input shape with every dynamic axis replaced by 1."""
        return tuple(dim if self._is_fixed_dim(dim) else 1 for dim in self.input_shape)

    def _shape_matches(self, shape: Tuple[int, ...]) -> bool:
        """Return True if *shape* is compatible with the model's input shape.

        The rank must match exactly. Dynamic (non-integer) axes accept any
        size, and fixed axes must be equal.
        """
        expected = self.input_shape
        if len(shape) != len(expected):
            return False
        return all(not self._is_fixed_dim(want) or want == got
                   for want, got in zip(expected, shape))

    def create_dummy_input(self) -> np.ndarray:
        """Create a random input tensor matching the model's first input.

        Dynamic axes are sized 1. ``tensor(int64)`` inputs get random integers
        in [0, 100). Every other input type gets standard-normal float32 values.
        """
        shape = self._concrete_input_shape()
        if self.input_type == 'tensor(int64)':
            return np.asarray(np.random.randint(0, 100, shape), dtype=np.int64)
        # np.asarray keeps the result an ndarray even for a 0-d shape.
        return np.asarray(np.random.randn(*shape), dtype=np.float32)

    @staticmethod
    def _to_rgb(table_image: np.ndarray) -> np.ndarray:
        """Return a 3-channel version of *table_image*.

        4-channel input is converted with ``COLOR_RGBA2RGB`` and 2-D input with
        ``COLOR_GRAY2RGB``. Any other layout, including 3-channel input
        (assumed RGB already), is returned unchanged.
        """
        if table_image.ndim == 3 and table_image.shape[2] == 4:
            return cv2.cvtColor(table_image, cv2.COLOR_RGBA2RGB)
        if table_image.ndim == 2:
            return cv2.cvtColor(table_image, cv2.COLOR_GRAY2RGB)
        return table_image

    def preprocess_table_region(self, table_image: np.ndarray) -> np.ndarray:
        """Preprocess a table region image for TableFormer inference.

        Note: this is a placeholder. The image is normalized to RGB, but the
        tensor returned is random dummy data with the model's input shape and
        dtype (see ``create_dummy_input``). Real TableFormer preprocessing
        (resizing, normalization, etc.) is not implemented here.
        """
        # The RGB image is computed but not yet consumed; real preprocessing
        # would resize/normalize it into the model's input tensor.
        self._to_rgb(table_image)
        return self.create_dummy_input()

    def predict(self, input_tensor: np.ndarray) -> Dict[str, np.ndarray]:
        """Run the model on *input_tensor* and return outputs keyed by name.

        A warning is printed (inference still runs) when the tensor's shape is
        incompatible with the model's declared input shape.
        """
        if not self._shape_matches(input_tensor.shape):
            print(f"Warning: Input shape {input_tensor.shape} != "
                  f"expected {tuple(self.input_shape)}")
        outputs = self.session.run(None, {self.input_name: input_tensor})
        return dict(zip(self.output_names, outputs))

    def extract_table_structure(self, table_image: np.ndarray) -> Dict:
        """Extract table structure from a table region image.

        Args:
            table_image: RGB image of the table region.

        Returns:
            A demonstration dict. ``raw_outputs`` maps each output name to its
            shape. ``cells``/``rows``/``columns`` are empty, and ``confidence``
            is a fixed placeholder (0.95), not a model score.
        """
        input_tensor = self.preprocess_table_region(table_image)
        raw_outputs = self.predict(input_tensor)

        # A real implementation would parse raw_outputs into cell boundaries
        # and classifications, then derive row/column structure.
        return {
            "model_type": self.model_type,
            "raw_outputs": {name: output.shape for name, output in raw_outputs.items()},
            "cells": [],
            "rows": [],
            "columns": [],
            "confidence": 0.95,  # placeholder, not computed from the model
            "processing_note": "This is a demonstration output. Real implementation would parse model outputs.",
        }

    def benchmark(self, num_iterations: int = 100, warmup: int = 5) -> Dict[str, float]:
        """Measure inference latency on a random dummy input.

        Args:
            num_iterations: Number of timed inference runs (must be >= 1).
            warmup: Number of untimed runs executed first.

        Returns:
            Mean/std/min/max/median latency in milliseconds and throughput in
            inferences per second.

        Raises:
            ValueError: If ``num_iterations`` is less than 1.
        """
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        import time

        print(f"Running benchmark with {num_iterations} iterations...")
        dummy_input = self.create_dummy_input()

        for _ in range(warmup):
            self.predict(dummy_input)

        times = np.empty(num_iterations, dtype=np.float64)
        for i in range(num_iterations):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            self.predict(dummy_input)
            times[i] = time.perf_counter() - start
            if (i + 1) % 10 == 0:
                print(f" Progress: {i + 1}/{num_iterations}")

        mean = float(np.mean(times))
        return {
            "mean_time_ms": mean * 1000,
            "std_time_ms": float(np.std(times) * 1000),
            "min_time_ms": float(np.min(times) * 1000),
            "max_time_ms": float(np.max(times) * 1000),
            "median_time_ms": float(np.median(times) * 1000),
            "throughput_fps": 1.0 / mean if mean > 0 else float("inf"),
        }
def main():
    """Run the TableFormer ONNX example from the command line.

    Parses CLI arguments and loads the selected model. It then optionally
    benchmarks the model and runs structure extraction on either a
    user-supplied image or a random dummy image.

    Returns:
        0 on success, 1 when the model file or image cannot be found or read.
        argparse itself exits with status 2 on invalid arguments.
    """
    import sys  # local import: only needed for stderr error reporting

    parser = argparse.ArgumentParser(description="TableFormer ONNX Example")
    parser.add_argument("--model", type=str,
                        choices=["accurate", "fast"],
                        default="accurate",
                        help="Model variant to use")
    parser.add_argument("--model-path", type=str, default=None,
                        help="Explicit path to the ONNX file "
                             "(default: tableformer_<model>.onnx in the current directory)")
    parser.add_argument("--image", type=str,
                        help="Path to table image (optional)")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")
    args = parser.parse_args()

    if args.benchmark and args.iterations < 1:
        parser.error("--iterations must be at least 1")

    model_files = {
        "accurate": "tableformer_accurate.onnx",
        "fast": "tableformer_fast.onnx",
    }
    model_path = args.model_path or model_files[args.model]

    # Validate input paths before the (potentially slow) model load.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}", file=sys.stderr)
        print("Please ensure the ONNX model files are in the current directory "
              "or pass --model-path.", file=sys.stderr)
        return 1
    if args.image and not os.path.exists(args.image):
        print(f"Error: Image file not found: {args.image}", file=sys.stderr)
        return 1

    print("=" * 60)
    print(f"TableFormer ONNX Example - {args.model.title()} Model")
    print("=" * 60)
    tableformer = TableFormerONNX(model_path, args.model)

    if args.benchmark:
        print("\n📊 Running performance benchmark...")
        stats = tableformer.benchmark(args.iterations)
        print(f"\n📈 Benchmark Results ({args.model} model):")
        print(f" Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
        print(f" Median inference time: {stats['median_time_ms']:.2f} ms")
        print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
        print(f" Throughput: {stats['throughput_fps']:.1f} FPS")

    if args.image:
        print(f"\n🖼️ Processing image: {args.image}")
        image = cv2.imread(args.image)  # returns None on unreadable files
        if image is None:
            print(f"Error: Could not load image: {args.image}", file=sys.stderr)
            return 1
        # OpenCV loads images as BGR; the model wrapper expects RGB.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        structure = tableformer.extract_table_structure(image_rgb)
        print("✓ Table structure extracted:")
        print(f" Model: {structure['model_type']}")
        print(f" Raw outputs: {structure['raw_outputs']}")
        print(f" Confidence: {structure['confidence']}")
        print(f" Note: {structure['processing_note']}")
    else:
        print("\n🔬 Running demo with dummy data...")
        # randint's upper bound is exclusive, so 256 covers the full uint8 range.
        dummy_image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
        structure = tableformer.extract_table_structure(dummy_image)
        print("✓ Demo completed:")
        print(f" Model: {structure['model_type']}")
        print(f" Raw outputs: {structure['raw_outputs']}")
        print(f" Processing: {structure['processing_note']}")

    print("\n✅ Example completed successfully!")
    print(f"\nTo process a real image, use: python {parser.prog} --model {args.model} --image your_table.jpg")
    print(f"To run a benchmark, use: python {parser.prog} --model {args.model} --benchmark")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None -> 0).
    raise SystemExit(main())