|
|
|
""" |
|
Example usage of the DocumentClassifier ONNX model for document-type classification.
|
""" |
|
|
|
import onnxruntime as ort |
|
import numpy as np |
|
import cv2 |
|
from typing import Dict, List, Union, Optional |
|
import argparse |
|
import os |
|
from PIL import Image |
|
import time |
|
|
|
class DocumentClassifierONNX:
    """ONNX Runtime wrapper for the DocumentClassifier model.

    Handles image preprocessing (resize, scale to [0, 1], HWC -> CHW, add a
    batch axis), inference, and decoding of the model's logits into a ranked
    list of document categories.
    """

    # Fallback (height, width) used when the model's spatial dims are dynamic.
    DEFAULT_IMAGE_SIZE = (224, 224)

    # ONNX Runtime reports element types as strings such as "tensor(float)";
    # map the common ones to numpy dtypes so input tensors match the model.
    _ORT_DTYPES = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.float64,
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
        "tensor(uint8)": np.uint8,
    }

    def __init__(self, model_path: str = "DocumentClassifier.onnx"):
        """
        Initialize DocumentClassifier ONNX model

        Args:
            model_path: Path to ONNX model file
        """
        print(f"Loading DocumentClassifier model: {model_path}")
        self.session = ort.InferenceSession(model_path)

        # Cache first-input / output metadata. Dims may be ints, or symbolic
        # names / None for dynamic axes.
        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        self.input_shape = model_input.shape
        self.input_type = model_input.type
        self.output_names = [output.name for output in self.session.get_outputs()]
        self.output_shape = self.session.get_outputs()[0].shape

        # NOTE(review): assumed to match the order of the model's output
        # logits -- confirm against the training label map.
        self.categories = [
            "article", "form", "letter", "memo", "news", "presentation",
            "resume", "scientific", "specification", "table", "other"
        ]

        print("✓ Model loaded successfully")
        print(f" Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f" Output: {self.output_shape}")
        print(f" Categories: {len(self.categories)}")

    @staticmethod
    def _is_static(dim) -> bool:
        """Return True if *dim* is a fixed, positive integer dimension."""
        return isinstance(dim, (int, np.integer)) and dim > 0

    def _input_dtype(self):
        """Return the numpy dtype matching the model's input element type."""
        if self.input_type in self._ORT_DTYPES:
            return self._ORT_DTYPES[self.input_type]
        # Unknown type string: keep the original float/int heuristic.
        return np.float32 if 'float' in self.input_type else np.int64

    def _model_spatial_size(self) -> Optional[tuple]:
        """Return the model's static (height, width), or None if not fixed.

        Assumes an NCHW layout, matching the CHW transpose in preprocess_image.
        """
        shape = self.input_shape
        if len(shape) == 4 and self._is_static(shape[2]) and self._is_static(shape[3]):
            return int(shape[2]), int(shape[3])
        return None

    def _concrete_input_shape(self) -> tuple:
        """Return the input shape with every dynamic dim replaced by a concrete size.

        Dynamic spatial dims of a rank-4 input become DEFAULT_IMAGE_SIZE; any
        other dynamic dim (e.g. batch) becomes 1.
        """
        rank = len(self.input_shape)
        concrete = []
        for axis, dim in enumerate(self.input_shape):
            if self._is_static(dim):
                concrete.append(int(dim))
            elif rank == 4 and axis in (2, 3):
                concrete.append(self.DEFAULT_IMAGE_SIZE[axis - 2])
            else:
                concrete.append(1)
        return tuple(concrete)

    def _shape_matches(self, shape) -> bool:
        """Return True if *shape* agrees with every static dim of the model input."""
        expected = self.input_shape
        if len(shape) != len(expected):
            return False
        return all(not self._is_static(e) or int(e) == int(s)
                   for e, s in zip(expected, shape))

    def create_dummy_input(self) -> np.ndarray:
        """Create a random tensor matching the model's input signature.

        Dynamic dimensions are replaced with concrete sizes, so models exported
        with symbolic batch/size axes still get a valid tensor.
        """
        shape = self._concrete_input_shape()
        dtype = self._input_dtype()
        if np.issubdtype(dtype, np.floating):
            return np.random.randn(*shape).astype(dtype)
        return np.random.randint(0, 255, shape).astype(dtype)

    @staticmethod
    def _to_rgb(image_array: np.ndarray) -> np.ndarray:
        """Convert a grayscale, single-channel, RGB or RGBA array to (H, W, 3) RGB.

        Raises:
            ValueError: if the array has any other shape.
        """
        if image_array.ndim == 3 and image_array.shape[2] == 1:
            image_array = np.ascontiguousarray(image_array[:, :, 0])
        if image_array.ndim == 2:
            return cv2.cvtColor(image_array, cv2.COLOR_GRAY2RGB)
        if image_array.ndim == 3 and image_array.shape[2] == 3:
            # 3-channel arrays are taken as already RGB (the file-path branch
            # produces RGB via PIL).
            return image_array
        if image_array.ndim == 3 and image_array.shape[2] == 4:
            return cv2.cvtColor(image_array, cv2.COLOR_RGBA2RGB)
        raise ValueError(
            f"Unsupported image shape {image_array.shape}; expected (H, W), "
            f"(H, W, 1), (H, W, 3) or (H, W, 4)"
        )

    def preprocess_image(self, image: Union[str, np.ndarray],
                         target_size: Optional[tuple] = None) -> np.ndarray:
        """
        Preprocess image for DocumentClassifier inference

        Args:
            image: Image path (loaded as RGB via PIL) or numpy array
                (grayscale, RGB or RGBA; 3-channel arrays are assumed RGB)
            target_size: Target image size (height, width). Defaults to the
                model's static spatial size, else DEFAULT_IMAGE_SIZE.

        Returns:
            Float tensor scaled to [0, 1] in CHW layout, with a batch axis
            added when the model input is rank 4.

        Raises:
            ValueError: if the image shape is unsupported or the result does
                not fit the model's static input dims.
        """
        if target_size is None:
            target_size = self._model_spatial_size() or self.DEFAULT_IMAGE_SIZE

        if isinstance(image, str):
            # Context manager closes the underlying file handle promptly.
            with Image.open(image) as pil_image:
                image_array = np.array(pil_image.convert('RGB'))
        else:
            image_array = np.asarray(image)

        print(f" Processing image: {image_array.shape}")

        rgb = self._to_rgb(image_array)
        # cv2.resize expects (width, height); target_size is (height, width).
        resized = cv2.resize(rgb, (int(target_size[1]), int(target_size[0])),
                             interpolation=cv2.INTER_CUBIC)

        normalized = resized.astype(np.float32) / 255.0
        model_dtype = self._input_dtype()
        if np.issubdtype(model_dtype, np.floating):
            normalized = normalized.astype(model_dtype, copy=False)

        chw = np.transpose(normalized, (2, 0, 1))
        batched = np.expand_dims(chw, axis=0) if len(self.input_shape) == 4 else chw

        # Fail loudly instead of silently classifying random data.
        if not self._shape_matches(batched.shape):
            raise ValueError(
                f"Preprocessed tensor shape {batched.shape} is incompatible with "
                f"model input shape {tuple(self.input_shape)}; pass a matching target_size"
            )

        print(f" Preprocessed: {batched.shape}")
        return batched

    def predict(self, input_tensor: np.ndarray) -> np.ndarray:
        """Run the model on *input_tensor* and return its first output.

        Only static model dims are checked; a mismatch prints a warning and
        the tensor is still passed to ONNX Runtime.
        """
        if not self._shape_matches(input_tensor.shape):
            print(f"Warning: Input shape {input_tensor.shape} != expected {tuple(self.input_shape)}")

        outputs = self.session.run(None, {self.input_name: input_tensor})
        return outputs[0]

    def decode_output(self, logits: np.ndarray, top_k: int = 3) -> Dict:
        """
        Decode model output logits to document categories

        Args:
            logits: Model output logits, shaped (C,), (N, C) or (N, C, ...);
                trailing axes beyond the second are averaged away.
            top_k: Number of top predictions to return (clamped to
                [1, number of categories])

        Returns:
            Dictionary with classification results

        Raises:
            ValueError: if logits or the category list is empty.
        """
        logits = np.asarray(logits)
        if logits.ndim > 2:
            # Pool any trailing axes, e.g. (N, C, H, W) or (N, C, L) -> (N, C).
            logits = logits.mean(axis=tuple(range(2, logits.ndim)))
        logits = logits.reshape(-1)

        num_categories = len(self.categories)
        if logits.size == 0 or num_categories == 0:
            raise ValueError("Cannot decode empty logits or an empty category list")

        if logits.size > num_categories:
            logits = logits[:num_categories]
        elif logits.size < num_categories:
            # Categories the model did not score get -inf, i.e. probability 0,
            # so they don't distort the distribution over real outputs.
            padded = np.full(num_categories, -np.inf)
            padded[:logits.size] = logits
            logits = padded

        probabilities = self._softmax(logits)

        k = max(1, min(int(top_k), probabilities.size))
        top_indices = np.argsort(probabilities)[::-1][:k]

        predictions = [
            {
                "rank": rank,
                "category": self.categories[idx],
                "confidence": float(probabilities[idx]),
                "index": int(idx),
            }
            for rank, idx in enumerate(top_indices, start=1)
        ]

        return {
            "predicted_category": predictions[0]["category"],
            "confidence": predictions[0]["confidence"],
            "top_predictions": predictions,
            "all_probabilities": probabilities.tolist(),
        }

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Apply softmax to convert logits to probabilities (max-shifted for stability)."""
        exp_x = np.exp(x - np.max(x))
        return exp_x / np.sum(exp_x)

    def classify(self, image: Union[str, np.ndarray]) -> Dict:
        """
        Classify document type from image

        Args:
            image: Image path or numpy array

        Returns:
            Dictionary with classification results plus "processing_info"
        """
        print("📄 Processing document image...")
        input_tensor = self.preprocess_image(image)

        print("🔍 Running classification...")
        logits = self.predict(input_tensor)

        print("📊 Decoding results...")
        result = self.decode_output(logits)

        result["processing_info"] = {
            "input_shape": input_tensor.shape,
            "output_shape": logits.shape,
            "inference_successful": True
        }
        return result

    def benchmark(self, num_iterations: int = 100) -> Dict[str, float]:
        """Benchmark inference latency on a random input.

        Runs 5 untimed warm-up inferences, then *num_iterations* timed ones
        using the monotonic, high-resolution ``time.perf_counter`` clock.

        Raises:
            ValueError: if num_iterations < 1.
        """
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

        print(f"🚀 Running benchmark with {num_iterations} iterations...")

        dummy_input = self.create_dummy_input()

        for _ in range(5):
            self.predict(dummy_input)

        times = np.empty(num_iterations)
        for i in range(num_iterations):
            start_time = time.perf_counter()
            self.predict(dummy_input)
            times[i] = time.perf_counter() - start_time

            if (i + 1) % 10 == 0:
                print(f" Progress: {i + 1}/{num_iterations}")

        mean_time = float(np.mean(times))
        return {
            "mean_time_ms": mean_time * 1000,
            "std_time_ms": float(np.std(times) * 1000),
            "min_time_ms": float(np.min(times) * 1000),
            "max_time_ms": float(np.max(times) * 1000),
            "median_time_ms": float(np.median(times) * 1000),
            "throughput_fps": 1.0 / mean_time if mean_time > 0 else float("inf"),
            "total_iterations": num_iterations
        }
|
|
|
|
|
def main() -> int:
    """Run the DocumentClassifier ONNX example from the command line.

    Depending on the flags, benchmarks the model, classifies an image file,
    or (when neither --image nor --benchmark is given) runs a demo on random
    pixel data.

    Returns:
        Process exit status: 0 if every requested step succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="DocumentClassifier ONNX Example")
    parser.add_argument("--model", type=str, default="DocumentClassifier.onnx",
                        help="Path to DocumentClassifier ONNX model")
    parser.add_argument("--image", type=str,
                        help="Path to document image file")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")

    args = parser.parse_args()

    if args.benchmark and args.iterations < 1:
        # parser.error prints usage and exits with status 2.
        parser.error("--iterations must be a positive integer")

    if not os.path.exists(args.model):
        print(f"❌ Error: Model file not found: {args.model}")
        print("Please ensure the ONNX model file is in the current directory.")
        return 1

    print("=" * 60)
    print("DocumentClassifier ONNX Example")
    print("=" * 60)

    try:
        classifier = DocumentClassifierONNX(args.model)
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return 1

    # Tracks whether every requested step succeeded, for the final summary
    # and the exit status.
    success = True

    if args.benchmark:
        print("\n🚀 Running performance benchmark...")
        try:
            stats = classifier.benchmark(args.iterations)

            print("\n📊 Benchmark Results:")
            print(f" Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
            print(f" Median inference time: {stats['median_time_ms']:.2f} ms")
            print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
            print(f" Throughput: {stats['throughput_fps']:.1f} FPS")
        except Exception as e:
            print(f"❌ Benchmark failed: {e}")
            success = False

    if args.image:
        if not os.path.exists(args.image):
            print(f"❌ Error: Image file not found: {args.image}")
            return 1

        print(f"\n📄 Classifying document: {args.image}")

        try:
            result = classifier.classify(args.image)

            print("\n✅ Classification completed:")
            print(f" Document type: {result['predicted_category']}")
            print(f" Confidence: {result['confidence']:.3f}")
            print("\n🏆 Top predictions:")
            for pred in result['top_predictions']:
                print(f" {pred['rank']}. {pred['category']}: {pred['confidence']:.3f}")
        except Exception as e:
            print(f"❌ Error classifying document: {e}")
            import traceback
            traceback.print_exc()
            success = False

    if not args.image and not args.benchmark:
        print("\n🔬 Running demo with dummy data...")

        try:
            # Random 800x600 RGB pixels stand in for a scanned page.
            dummy_image = np.random.randint(0, 255, (800, 600, 3), dtype=np.uint8)
            result = classifier.classify(dummy_image)

            print("✅ Demo completed:")
            print(f" Predicted type: {result['predicted_category']}")
            print(f" Confidence: {result['confidence']:.3f}")
            print(f" Processing info: {result['processing_info']}")
            print("\n💡 Note: This was a demonstration with random data.")
        except Exception as e:
            print(f"❌ Demo failed: {e}")
            success = False

    if success:
        print("\n✅ Example completed successfully!")
    else:
        print("\n⚠️ Example finished with errors (see messages above).")

    print("\nUsage examples:")
    print(" Classify document: python example.py --image document.jpg")
    print(" Run benchmark: python example.py --benchmark --iterations 50")
    # PIL cannot read PDF files, so the combined example uses a raster image.
    print(" Both: python example.py --image document.png --benchmark")

    return 0 if success else 1
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (a None return exits with status 0).
    raise SystemExit(main())