File size: 10,282 Bytes
fbea007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253a9b0
 
fbea007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
#!/usr/bin/env python3
"""
Example usage of Docling TableFormer ONNX models for table structure recognition.
"""

import onnxruntime as ort
import cv2
import numpy as np
from typing import Dict, List, Tuple, Optional
import argparse
import os

class TableFormerONNX:
    """Thin ONNX Runtime wrapper around a TableFormer model.

    The wrapper loads the model, records the metadata of its first input and
    offers helpers for dummy-input generation, inference and benchmarking.

    Pre- and post-processing are demonstration placeholders: the tensor fed
    to the model is random data shaped like the model input, and the model
    outputs are only summarised by shape, not decoded into a table.
    """

    # Size substituted for symbolic/unknown ("dynamic") input dimensions when
    # a concrete tensor has to be allocated.
    _DYNAMIC_DIM_SIZE = 1

    def __init__(self, model_path: str, model_type: str = "accurate"):
        """
        Initialize TableFormer ONNX model

        Args:
            model_path: Path to ONNX model file
            model_type: "accurate" or "fast"
        """
        print(f"Loading {model_type} TableFormer model: {model_path}")
        self.session = ort.InferenceSession(model_path)
        self.model_type = model_type

        # Only the first model input is described and fed (see predict()).
        first_input = self.session.get_inputs()[0]
        self.input_name = first_input.name
        self.input_shape = first_input.shape
        self.input_type = first_input.type
        self.output_names = [output.name for output in self.session.get_outputs()]

        print("✓ Model loaded successfully")
        print(f"  Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f"  Outputs: {len(self.output_names)} tensors")

    def _concrete_input_shape(self) -> Tuple[int, ...]:
        """Return the input shape with every dynamic dimension replaced by 1.

        Dimensions that are not non-negative integers (e.g. symbolic names or
        ``None``) cannot be used to allocate an array, so they are replaced by
        ``_DYNAMIC_DIM_SIZE``.  Fully static shapes are returned unchanged.
        """
        return tuple(
            dim if isinstance(dim, int) and dim >= 0 else self._DYNAMIC_DIM_SIZE
            for dim in self.input_shape
        )

    def _shape_matches(self, actual_shape: Tuple[int, ...]) -> bool:
        """Check a tensor shape against the model input shape.

        Non-integer (dynamic) dimensions of the model input accept any size.
        """
        expected = self.input_shape
        if len(actual_shape) != len(expected):
            return False
        return all(
            not isinstance(want, int) or want == got
            for want, got in zip(expected, actual_shape)
        )

    @staticmethod
    def _to_rgb(table_image: np.ndarray) -> np.ndarray:
        """Convert a grayscale or 4-channel image to 3 channels.

        Any other layout (including 3-channel images) is returned unchanged.
        """
        if table_image.ndim == 2:
            return cv2.cvtColor(table_image, cv2.COLOR_GRAY2RGB)
        if table_image.ndim == 3 and table_image.shape[2] == 4:
            return cv2.cvtColor(table_image, cv2.COLOR_RGBA2RGB)
        return table_image

    def create_dummy_input(self) -> np.ndarray:
        """Create a random input tensor matching the model's input.

        Returns:
            An ``int64`` array of values in ``[0, 100)`` when the model input
            type is ``tensor(int64)``, otherwise a ``float32`` array of
            standard-normal samples.  Dynamic dimensions are given size 1.
        """
        shape = self._concrete_input_shape()
        if self.input_type == 'tensor(int64)':
            return np.random.randint(0, 100, shape).astype(np.int64)
        return np.random.randn(*shape).astype(np.float32)

    def preprocess_table_region(self, table_image: np.ndarray) -> np.ndarray:
        """
        Preprocess table region image for TableFormer inference

        Note: This is a simplified preprocessing example.
        The actual TableFormer preprocessing may be more complex and specific
        to the training procedure.

        Args:
            table_image: Grayscale, RGB or RGBA image of the table region

        Returns:
            A random tensor shaped like the model input.  The image content
            is NOT used yet; this is a placeholder for real feature
            extraction.
        """
        # Channel normalisation is where real preprocessing would start; its
        # result is intentionally unused until feature extraction exists.
        _ = self._to_rgb(table_image)

        return self.create_dummy_input()

    def predict(self, input_tensor: np.ndarray) -> Dict[str, np.ndarray]:
        """Run table structure prediction

        Args:
            input_tensor: Tensor fed to the model's first input

        Returns:
            Mapping from output tensor name to the corresponding output.
        """
        # A mismatch is only reported; ONNX Runtime decides whether the
        # tensor is actually acceptable.
        if not self._shape_matches(input_tensor.shape):
            expected_shape = tuple(self.input_shape)
            print(f"Warning: Input shape {input_tensor.shape} != expected {expected_shape}")

        # Passing None requests all model outputs.
        outputs = self.session.run(None, {self.input_name: input_tensor})

        return dict(zip(self.output_names, outputs))

    def extract_table_structure(self, table_image: np.ndarray) -> Dict:
        """
        Extract table structure from table region image

        Args:
            table_image: RGB image of table region

        Returns:
            Dictionary containing table structure information.  Only
            "model_type" and "raw_outputs" (output name -> shape) come from
            the model run; every other field is a placeholder.
        """
        input_tensor = self.preprocess_table_region(table_image)
        raw_outputs = self.predict(input_tensor)

        # Post-process to extract table structure
        # Note: This is a simplified example. The actual post-processing
        # would depend on the specific output format of the TableFormer model
        #
        # In a real implementation, you would:
        # 1. Parse the raw model outputs
        # 2. Extract cell boundaries and classifications
        # 3. Determine row and column structure
        # 4. Generate structured table representation
        return {
            "model_type": self.model_type,
            "raw_outputs": {name: output.shape for name, output in raw_outputs.items()},
            "cells": [],  # Would contain cell boundary and type information
            "rows": [],   # Would contain row definitions
            "columns": [],  # Would contain column definitions
            "confidence": 0.95,  # Placeholder confidence score
            "processing_note": "This is a demonstration output. Real implementation would parse model outputs."
        }

    def benchmark(self, num_iterations: int = 100) -> Dict[str, float]:
        """Benchmark model performance

        Args:
            num_iterations: Number of timed inference runs (must be >= 1)

        Returns:
            Mean, standard deviation, minimum, maximum and median latency in
            milliseconds plus the throughput in inferences per second.

        Raises:
            ValueError: If ``num_iterations`` is smaller than 1.
        """
        import time

        if num_iterations < 1:
            raise ValueError("num_iterations must be at least 1")

        print(f"Running benchmark with {num_iterations} iterations...")

        dummy_input = self.create_dummy_input()

        # Warmup runs are excluded from the statistics.
        for _ in range(5):
            self.predict(dummy_input)

        timings = np.empty(num_iterations, dtype=np.float64)
        for i in range(num_iterations):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            self.predict(dummy_input)
            timings[i] = time.perf_counter() - start

            if (i + 1) % 10 == 0:
                print(f"  Progress: {i + 1}/{num_iterations}")

        mean_seconds = float(np.mean(timings))
        return {
            "mean_time_ms": mean_seconds * 1000,
            "std_time_ms": float(np.std(timings) * 1000),
            "min_time_ms": float(np.min(timings) * 1000),
            "max_time_ms": float(np.max(timings) * 1000),
            "median_time_ms": float(np.median(timings) * 1000),
            # Guard against a zero mean on clocks too coarse for the workload.
            "throughput_fps": 1.0 / mean_seconds if mean_seconds > 0 else float("inf"),
        }


def main() -> int:
    """Command-line entry point for the TableFormer ONNX example.

    Parses the command line, loads the selected model from the current
    directory, optionally runs a benchmark, and then either processes the
    image given with ``--image`` or runs a demo on a random image.

    Returns:
        0 on success, 1 when the model file or the image file is missing or
        the image cannot be decoded.  Invalid arguments terminate the process
        through ``argparse`` (exit status 2).
    """
    import sys  # local import: only needed to report errors on stderr

    parser = argparse.ArgumentParser(description="TableFormer ONNX Example")
    parser.add_argument("--model", type=str,
                        choices=["accurate", "fast"],
                        default="accurate",
                        help="Model variant to use")
    parser.add_argument("--image", type=str,
                        help="Path to table image (optional)")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")

    args = parser.parse_args()

    # Reject a meaningless benchmark up front instead of failing mid-run.
    if args.benchmark and args.iterations < 1:
        parser.error("--iterations must be a positive integer")

    # Model files are looked up relative to the current working directory.
    model_files = {
        "accurate": "tableformer_accurate.onnx",
        "fast": "tableformer_fast.onnx",
    }
    model_path = model_files[args.model]

    if not os.path.isfile(model_path):
        print(f"Error: Model file not found: {model_path}", file=sys.stderr)
        print("Please ensure the ONNX model files are in the current directory.",
              file=sys.stderr)
        return 1

    print("=" * 60)
    print(f"TableFormer ONNX Example - {args.model.title()} Model")
    print("=" * 60)

    tableformer = TableFormerONNX(model_path, args.model)

    if args.benchmark:
        print("\n📊 Running performance benchmark...")
        stats = tableformer.benchmark(args.iterations)

        print(f"\n📈 Benchmark Results ({args.model} model):")
        print(f"  Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
        print(f"  Median inference time: {stats['median_time_ms']:.2f} ms")
        print(f"  Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
        print(f"  Throughput: {stats['throughput_fps']:.1f} FPS")

    if args.image:
        if not os.path.isfile(args.image):
            print(f"Error: Image file not found: {args.image}", file=sys.stderr)
            return 1

        print(f"\n🖼️  Processing image: {args.image}")

        # cv2.imread signals an unreadable/undecodable file by returning None.
        image = cv2.imread(args.image)
        if image is None:
            print(f"Error: Could not load image: {args.image}", file=sys.stderr)
            return 1

        # OpenCV loads images in BGR channel order.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        structure = tableformer.extract_table_structure(image_rgb)

        print("✓ Table structure extracted:")
        print(f"  Model: {structure['model_type']}")
        print(f"  Raw outputs: {structure['raw_outputs']}")
        print(f"  Confidence: {structure['confidence']}")
        print(f"  Note: {structure['processing_note']}")
    else:
        print("\n🔬 Running demo with dummy data...")

        # randint's upper bound is exclusive: 256 covers the full uint8 range.
        dummy_image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)

        structure = tableformer.extract_table_structure(dummy_image)

        print("✓ Demo completed:")
        print(f"  Model: {structure['model_type']}")
        print(f"  Raw outputs: {structure['raw_outputs']}")
        print(f"  Processing: {structure['processing_note']}")

    print("\n✅ Example completed successfully!")
    print(f"\nTo process a real image, use: python example.py --model {args.model} --image your_table.jpg")
    print(f"To run a benchmark, use: python example.py --model {args.model} --benchmark")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so failures are visible to the calling shell.
    raise SystemExit(main())