#!/usr/bin/env python3
"""
Example usage of Docling TableFormer ONNX models for table structure recognition.
"""
import onnxruntime as ort
import cv2
import numpy as np
from typing import Dict, List, Tuple, Optional
import argparse
import os
import time


class TableFormerONNX:
    """ONNX wrapper for TableFormer models"""

    def __init__(self, model_path: str, model_type: str = "accurate"):
        """
        Initialize TableFormer ONNX model

        Args:
            model_path: Path to ONNX model file
            model_type: "accurate" or "fast"
        """
        print(f"Loading {model_type} TableFormer model: {model_path}")
        self.session = ort.InferenceSession(model_path)
        self.model_type = model_type

        # Get model input/output information
        self.input_name = self.session.get_inputs()[0].name
        self.input_shape = self.session.get_inputs()[0].shape
        self.input_type = self.session.get_inputs()[0].type
        self.output_names = [output.name for output in self.session.get_outputs()]

        print("✓ Model loaded successfully")
        print(f"  Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f"  Outputs: {len(self.output_names)} tensors")
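
    # Optional: onnxruntime sessions can be created with other execution
    # providers when the installed build supports them. A minimal sketch
    # (these provider names are standard onnxruntime identifiers, but
    # availability depends on the environment):
    #
    #   providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    #   session = ort.InferenceSession(model_path, providers=providers)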

    def create_dummy_input(self) -> np.ndarray:
        """Create dummy input tensor for testing"""
        # ONNX models may declare dynamic axes as None or string names;
        # substitute 1 for those so the dummy tensor has a concrete shape.
        shape = [dim if isinstance(dim, int) else 1 for dim in self.input_shape]
        if self.input_type == 'tensor(int64)':
            # Create dummy integer input
            dummy_input = np.random.randint(0, 100, shape).astype(np.int64)
        else:
            # Create dummy float input
            dummy_input = np.random.randn(*shape).astype(np.float32)
        return dummy_input

    def preprocess_table_region(self, table_image: np.ndarray) -> np.ndarray:
        """
        Preprocess table region image for TableFormer inference

        Note: This is a simplified preprocessing example. The actual
        TableFormer preprocessing may be more complex and specific to the
        training procedure.
        """
        # Convert to RGB if needed
        if len(table_image.shape) == 3 and table_image.shape[2] == 3:
            # Already RGB
            processed = table_image
        elif len(table_image.shape) == 3 and table_image.shape[2] == 4:
            # RGBA to RGB
            processed = cv2.cvtColor(table_image, cv2.COLOR_RGBA2RGB)
        elif len(table_image.shape) == 2:
            # Grayscale to RGB
            processed = cv2.cvtColor(table_image, cv2.COLOR_GRAY2RGB)
        else:
            processed = table_image

        # Resizing and normalizing `processed` to the model's expected input
        # would happen here and depends on the actual model requirements
        # (a hypothetical sketch follows this method). For now, return a
        # dummy tensor matching the model's declared input so the example
        # runs end to end.
        dummy_features = self.create_dummy_input()
        return dummy_features
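
    # A minimal, hypothetical sketch of what real image preprocessing for a
    # vision backbone could look like (resize, scale to [0, 1], normalize,
    # HWC -> NCHW). The target size and normalization constants below are
    # illustrative assumptions, not the actual TableFormer configuration.
    def _example_image_preprocess(self, table_image: np.ndarray,
                                  size: Tuple[int, int] = (448, 448)) -> np.ndarray:
        resized = cv2.resize(table_image, size, interpolation=cv2.INTER_LINEAR)
        tensor = resized.astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # assumed ImageNet mean
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)   # assumed ImageNet std
        tensor = (tensor - mean) / std
        tensor = np.transpose(tensor, (2, 0, 1))  # HWC -> CHW
        return tensor[np.newaxis, ...]  # add batch dimension -> NCHW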

    def predict(self, input_tensor: np.ndarray) -> Dict[str, np.ndarray]:
        """Run table structure prediction"""
        # Validate input shape
        expected_shape = tuple(self.input_shape)
        if input_tensor.shape != expected_shape:
            print(f"Warning: Input shape {input_tensor.shape} != expected {expected_shape}")

        # Run inference
        outputs = self.session.run(None, {self.input_name: input_tensor})

        # Package results by output name
        result = {}
        for i, name in enumerate(self.output_names):
            result[name] = outputs[i]
        return result

    def extract_table_structure(self, table_image: np.ndarray) -> Dict:
        """
        Extract table structure from table region image

        Args:
            table_image: RGB image of table region

        Returns:
            Dictionary containing table structure information
        """
        # Preprocess image
        input_tensor = self.preprocess_table_region(table_image)

        # Get raw predictions
        raw_outputs = self.predict(input_tensor)

        # Post-process to extract table structure
        # Note: This is a simplified example. The actual post-processing
        # would depend on the specific output format of the TableFormer model.
        table_structure = {
            "model_type": self.model_type,
            "raw_outputs": {name: output.shape for name, output in raw_outputs.items()},
            "cells": [],  # Would contain cell boundary and type information
            "rows": [],  # Would contain row definitions
            "columns": [],  # Would contain column definitions
            "confidence": 0.95,  # Placeholder confidence score
            "processing_note": "This is a demonstration output. Real implementation would parse model outputs."
        }

        # In a real implementation, you would:
        # 1. Parse the raw model outputs (an illustrative sketch follows this method)
        # 2. Extract cell boundaries and classifications
        # 3. Determine row and column structure
        # 4. Generate structured table representation
        return table_structure
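
    # A hypothetical sketch of step 1 above (parsing raw outputs). It assumes
    # one output is a [sequence_length, vocab_size] logits tensor over table
    # structure tokens and that the token vocabulary is known; neither is
    # guaranteed by the exported ONNX graph, so treat this as illustrative only.
    def _example_decode_structure_tokens(self, logits: np.ndarray,
                                         vocab: List[str]) -> List[str]:
        token_ids = np.argmax(logits, axis=-1)  # greedy decoding
        tokens = [vocab[int(idx)] for idx in token_ids]
        # Stop at an end-of-sequence marker if the vocabulary defines one
        # ("<end>" is a placeholder name, not a known vocabulary entry).
        if "<end>" in tokens:
            tokens = tokens[:tokens.index("<end>")]
        return tokens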

    def benchmark(self, num_iterations: int = 100) -> Dict[str, float]:
        """Benchmark model performance"""
        print(f"Running benchmark with {num_iterations} iterations...")

        # Create dummy input
        dummy_input = self.create_dummy_input()

        # Warmup
        for _ in range(5):
            _ = self.predict(dummy_input)

        # Benchmark
        times = []
        for i in range(num_iterations):
            start_time = time.time()
            _ = self.predict(dummy_input)
            end_time = time.time()
            times.append(end_time - start_time)

            if (i + 1) % 10 == 0:
                print(f"  Progress: {i + 1}/{num_iterations}")

        # Calculate statistics
        times = np.array(times)
        stats = {
            "mean_time_ms": float(np.mean(times) * 1000),
            "std_time_ms": float(np.std(times) * 1000),
            "min_time_ms": float(np.min(times) * 1000),
            "max_time_ms": float(np.max(times) * 1000),
            "median_time_ms": float(np.median(times) * 1000),
            "throughput_fps": float(1.0 / np.mean(times))
        }
        return stats


def main():
    parser = argparse.ArgumentParser(description="TableFormer ONNX Example")
    parser.add_argument("--model", type=str,
                        choices=["accurate", "fast"],
                        default="accurate",
                        help="Model variant to use")
    parser.add_argument("--image", type=str,
                        help="Path to table image (optional)")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")
    args = parser.parse_args()

    # Model paths
    model_files = {
        "accurate": "tableformer_accurate.onnx",
        "fast": "tableformer_fast.onnx"
    }
    model_path = model_files[args.model]

    # Check if model file exists
    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        print("Please ensure the ONNX model files are in the current directory.")
        return

    # Initialize model
    print("=" * 60)
    print(f"TableFormer ONNX Example - {args.model.title()} Model")
    print("=" * 60)
    tableformer = TableFormerONNX(model_path, args.model)

    # Run benchmark if requested
    if args.benchmark:
        print("\n🚀 Running performance benchmark...")
        stats = tableformer.benchmark(args.iterations)

        print(f"\n📊 Benchmark Results ({args.model} model):")
        print(f"  Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
        print(f"  Median inference time: {stats['median_time_ms']:.2f} ms")
        print(f"  Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
        print(f"  Throughput: {stats['throughput_fps']:.1f} FPS")

    # Process image if provided
    if args.image:
        if not os.path.exists(args.image):
            print(f"Error: Image file not found: {args.image}")
            return

        print(f"\n🖼️ Processing image: {args.image}")

        # Load image
        image = cv2.imread(args.image)
        if image is None:
            print(f"Error: Could not load image: {args.image}")
            return

        # Convert BGR to RGB
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Extract table structure
        structure = tableformer.extract_table_structure(image_rgb)

        print("✓ Table structure extracted:")
        print(f"  Model: {structure['model_type']}")
        print(f"  Raw outputs: {structure['raw_outputs']}")
        print(f"  Confidence: {structure['confidence']}")
        print(f"  Note: {structure['processing_note']}")

    # Demo with dummy data
    if not args.image:
        print("\n🎬 Running demo with dummy data...")

        # Create dummy table image
        dummy_image = np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8)

        # Process dummy image
        structure = tableformer.extract_table_structure(dummy_image)

        print("✓ Demo completed:")
        print(f"  Model: {structure['model_type']}")
        print(f"  Raw outputs: {structure['raw_outputs']}")
        print(f"  Processing: {structure['processing_note']}")

    print("\n✅ Example completed successfully!")
    print(f"\nTo process a real image, use: python example.py --model {args.model} --image your_table.jpg")
    print(f"To run a benchmark, use: python example.py --model {args.model} --benchmark")


if __name__ == "__main__":
    main()