|
|
|
""" |
|
Example usage of Docling TableFormer ONNX models for table structure recognition. |
|
""" |
|
|
|
import argparse
import os
import sys
import time
from typing import Dict, List, Tuple, Optional

import cv2
import numpy as np
import onnxruntime as ort
|
|
|
class TableFormerONNX:
    """ONNX Runtime wrapper for a TableFormer model.

    Loads an ONNX graph, records the name/shape/type of its first input and
    the names of all outputs, and offers helpers for inference and latency
    benchmarking. Pre- and post-processing are placeholders: table images are
    not yet converted into real model features, and raw outputs are not yet
    decoded into cells/rows/columns.
    """

    # ONNX tensor type string -> NumPy dtype used when synthesising inputs.
    # Types not listed fall back to float32.
    _DTYPES = {
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }

    def __init__(
        self,
        model_path: str,
        model_type: str = "accurate",
        providers: Optional[List[str]] = None,
    ):
        """Load a TableFormer ONNX model.

        Args:
            model_path: Path to the ONNX model file.
            model_type: Variant label, "accurate" or "fast". It is stored and
                echoed in results but not validated.
            providers: Optional ONNX Runtime execution providers in priority
                order (e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"]).
                None keeps ONNX Runtime's default provider selection.
        """
        print(f"Loading {model_type} TableFormer model: {model_path}")
        self.session = ort.InferenceSession(model_path, providers=providers)
        self.model_type = model_type

        # This wrapper only drives the graph's first input.
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        # ONNX Runtime may report dynamic dims as strings (e.g. "batch") or None.
        self.input_shape = model_input.shape
        self.input_type = model_input.type
        self.output_names = [output.name for output in self.session.get_outputs()]

        print("✓ Model loaded successfully")
        print(f" Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f" Outputs: {len(self.output_names)} tensors")

    @staticmethod
    def _is_static_dim(dim) -> bool:
        """Return True if ``dim`` is a fixed, non-negative integer size."""
        return isinstance(dim, (int, np.integer)) and dim >= 0

    def _concrete_shape(self) -> Tuple[int, ...]:
        """Return the model input shape with every dynamic dimension set to 1."""
        return tuple(
            int(dim) if self._is_static_dim(dim) else 1
            for dim in self.input_shape
        )

    def _shape_matches(self, shape: Tuple[int, ...]) -> bool:
        """Check ``shape`` against the model input; dynamic dims accept any size."""
        if len(shape) != len(self.input_shape):
            return False
        return all(
            not self._is_static_dim(expected) or expected == actual
            for expected, actual in zip(self.input_shape, shape)
        )

    def create_dummy_input(self) -> np.ndarray:
        """Create a random tensor shaped and typed like the model's first input.

        Integer inputs get values in [0, 100); other inputs get standard-normal
        samples. Dynamic dimensions are filled with 1 so NumPy can build the
        array.
        """
        shape = self._concrete_shape()
        dtype = self._DTYPES.get(self.input_type, np.float32)
        if np.issubdtype(dtype, np.integer):
            return np.asarray(np.random.randint(0, 100, size=shape), dtype=dtype)
        return np.asarray(np.random.randn(*shape), dtype=dtype)

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert gray (H x W or H x W x 1) or 4-channel images to 3 channels.

        Any other layout, including 3-channel images, is returned unchanged.
        """
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        return image

    def preprocess_table_region(self, table_image: np.ndarray) -> np.ndarray:
        """Turn a table-region image into a model input tensor.

        Placeholder only. The image's channel layout is normalised, but the
        real TableFormer preprocessing (which must match the training
        procedure) is not implemented. The returned tensor is therefore random
        data with the model input's shape and dtype.

        Args:
            table_image: Table region image (gray, RGB or RGBA array).

        Returns:
            Random input tensor from ``create_dummy_input``.
        """
        # TODO: feed the normalised image through the real preprocessing once
        # it is known; it is computed but not consumed yet.
        _rgb = self._to_rgb(table_image)
        return self.create_dummy_input()

    def predict(self, input_tensor: np.ndarray) -> Dict[str, np.ndarray]:
        """Run the model and return its outputs keyed by output name.

        A warning is printed, but inference still runs, when the tensor's shape
        conflicts with a fixed dimension of the model input.
        """
        if not self._shape_matches(input_tensor.shape):
            expected_shape = tuple(self.input_shape)
            print(f"Warning: Input shape {input_tensor.shape} != expected {expected_shape}")

        outputs = self.session.run(None, {self.input_name: input_tensor})
        return dict(zip(self.output_names, outputs))

    def extract_table_structure(self, table_image: np.ndarray) -> Dict:
        """Extract table structure from a table-region image.

        Args:
            table_image: RGB image of the table region.

        Returns:
            Dictionary holding the model type and the shape of every raw output.
            "cells", "rows" and "columns" are empty lists and "confidence" is a
            fixed placeholder, because output decoding is not implemented.
        """
        input_tensor = self.preprocess_table_region(table_image)
        raw_outputs = self.predict(input_tensor)

        # TODO: decode raw_outputs into cells/rows/columns and a real score.
        return {
            "model_type": self.model_type,
            "raw_outputs": {name: output.shape for name, output in raw_outputs.items()},
            "cells": [],
            "rows": [],
            "columns": [],
            "confidence": 0.95,
            "processing_note": "This is a demonstration output. Real implementation would parse model outputs.",
        }

    def benchmark(self, num_iterations: int = 100, *, warmup: int = 5) -> Dict[str, float]:
        """Measure inference latency on a random input.

        Args:
            num_iterations: Number of timed runs; must be at least 1.
            warmup: Untimed runs executed before timing starts.

        Returns:
            Mean, std, min, max and median latency in milliseconds, plus
            throughput in runs per second.

        Raises:
            ValueError: If ``num_iterations`` is less than 1, since the
                statistics would otherwise be NaN.
        """
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

        print(f"Running benchmark with {num_iterations} iterations...")
        dummy_input = self.create_dummy_input()

        for _ in range(warmup):
            self.predict(dummy_input)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        times = np.empty(num_iterations)
        for i in range(num_iterations):
            start = time.perf_counter()
            self.predict(dummy_input)
            times[i] = time.perf_counter() - start
            if (i + 1) % 10 == 0:
                print(f" Progress: {i + 1}/{num_iterations}")

        mean_s = np.mean(times)
        return {
            "mean_time_ms": float(mean_s * 1000),
            "std_time_ms": float(np.std(times) * 1000),
            "min_time_ms": float(np.min(times) * 1000),
            "max_time_ms": float(np.max(times) * 1000),
            "median_time_ms": float(np.median(times) * 1000),
            "throughput_fps": float(1.0 / mean_s),
        }
|
|
|
|
|
def main():
    """Run the TableFormer ONNX example from the command line.

    Loads the selected model variant from the current directory and, with
    --benchmark, benchmarks it. It then extracts (placeholder) table structure
    from the --image file, or from a random image when no image is given.

    Returns:
        Exit status: 0 on success, 1 if the model or image cannot be found or
        loaded.
    """
    parser = argparse.ArgumentParser(description="TableFormer ONNX Example")
    parser.add_argument("--model", type=str,
                        choices=["accurate", "fast"],
                        default="accurate",
                        help="Model variant to use")
    parser.add_argument("--image", type=str,
                        help="Path to table image (optional)")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")
    args = parser.parse_args()

    if args.benchmark and args.iterations < 1:
        parser.error("--iterations must be at least 1")

    # Model files are looked up relative to the current working directory.
    model_files = {
        "accurate": "tableformer_accurate.onnx",
        "fast": "tableformer_fast.onnx",
    }
    model_path = model_files[args.model]
    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}", file=sys.stderr)
        print("Please ensure the ONNX model files are in the current directory.", file=sys.stderr)
        return 1

    # Validate and load the image before loading the model or benchmarking,
    # so a bad --image path fails fast rather than after a long benchmark.
    image_rgb = None
    if args.image:
        if not os.path.exists(args.image):
            print(f"Error: Image file not found: {args.image}", file=sys.stderr)
            return 1
        image = cv2.imread(args.image)
        if image is None:
            print(f"Error: Could not load image: {args.image}", file=sys.stderr)
            return 1
        # cv2.imread yields BGR; extract_table_structure is documented as RGB.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    print("=" * 60)
    print(f"TableFormer ONNX Example - {args.model.title()} Model")
    print("=" * 60)

    tableformer = TableFormerONNX(model_path, args.model)

    if args.benchmark:
        print("\n🚀 Running performance benchmark...")
        stats = tableformer.benchmark(args.iterations)

        print(f"\n📊 Benchmark Results ({args.model} model):")
        print(f" Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
        print(f" Median inference time: {stats['median_time_ms']:.2f} ms")
        print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
        print(f" Throughput: {stats['throughput_fps']:.1f} FPS")

    if image_rgb is not None:
        print(f"\n🖼️ Processing image: {args.image}")
        structure = tableformer.extract_table_structure(image_rgb)

        print("✓ Table structure extracted:")
        print(f" Model: {structure['model_type']}")
        print(f" Raw outputs: {structure['raw_outputs']}")
        print(f" Confidence: {structure['confidence']}")
        print(f" Note: {structure['processing_note']}")
    else:
        print("\n🔬 Running demo with dummy data...")
        # randint's upper bound is exclusive, so 256 covers the full uint8 range.
        dummy_image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
        structure = tableformer.extract_table_structure(dummy_image)

        print("✓ Demo completed:")
        print(f" Model: {structure['model_type']}")
        print(f" Raw outputs: {structure['raw_outputs']}")
        print(f" Processing: {structure['processing_note']}")

    print("\n✅ Example completed successfully!")
    # parser.prog is the invoked script's file name (argparse default).
    print(f"\nTo process a real image, use: python {parser.prog} --model {args.model} --image your_table.jpg")
    print(f"To run a benchmark, use: python {parser.prog} --model {args.model} --benchmark")
    return 0


if __name__ == "__main__":
    sys.exit(main())