|
|
|
""" |
|
Example usage of the CodeFormula ONNX model for code and formula recognition.
|
""" |
|
|
|
import onnxruntime as ort |
|
import numpy as np |
|
import cv2 |
|
from typing import Dict, List, Union, Optional |
|
import argparse |
|
import os |
|
from PIL import Image |
|
import time |
|
|
|
class CodeFormulaONNX:
    """ONNX Runtime wrapper for the CodeFormula model.

    Loads an ONNX graph, caches metadata for its first input and first output,
    and provides helpers for inference, logit decoding and latency
    benchmarking.

    NOTE: real CodeFormula preprocessing (image -> token sequence) is not
    implemented here; ``preprocess_image`` returns random placeholder tokens.
    """

    # Fallbacks used when the output shape has rank < 3 or its vocabulary /
    # sequence dimensions are not static positive integers.
    DEFAULT_VOCAB_SIZE = 50827
    DEFAULT_SEQUENCE_LENGTH = 10

    def __init__(self, model_path: str = "CodeFormula.onnx"):
        """
        Initialize CodeFormula ONNX model

        Args:
            model_path: Path to ONNX model file
        """
        print(f"Loading CodeFormula model: {model_path}")
        self.session = ort.InferenceSession(model_path)

        model_input = self.session.get_inputs()[0]
        self.input_name = model_input.name
        self.input_shape = model_input.shape
        self.input_type = model_input.type
        model_outputs = self.session.get_outputs()
        self.output_names = [output.name for output in model_outputs]
        self.output_shape = model_outputs[0].shape

        # The output is treated as [batch, sequence, vocab]. Dynamic axes may be
        # reported as strings or None, so only integer dimensions are trusted;
        # otherwise later arithmetic such as min(vocab_size, 1000) would fail.
        if len(self.output_shape) > 2:
            self.vocab_size = self._static_dim(self.output_shape[-1], self.DEFAULT_VOCAB_SIZE)
            self.sequence_length = self._static_dim(self.output_shape[-2], self.DEFAULT_SEQUENCE_LENGTH)
        else:
            self.vocab_size = self.DEFAULT_VOCAB_SIZE
            self.sequence_length = self.DEFAULT_SEQUENCE_LENGTH

        print("Model loaded successfully")
        print(f" Input: {self.input_name} {self.input_shape} ({self.input_type})")
        print(f" Output: {self.output_shape}")
        print(f" Vocabulary size: {self.vocab_size}")
        print(f" Sequence length: {self.sequence_length}")

    @staticmethod
    def _static_dim(dim, default: Optional[int]) -> Optional[int]:
        """Return ``dim`` as an int if it is a positive integer, else ``default``."""
        if isinstance(dim, (int, np.integer)) and dim > 0:
            return int(dim)
        return default

    def _concrete_input_shape(self) -> tuple:
        """Return the model input shape with every non-static dimension set to 1."""
        return tuple(self._static_dim(dim, 1) for dim in self.input_shape)

    def _shape_matches(self, shape) -> bool:
        """Check ``shape`` against the model input shape; dynamic dims match anything."""
        if len(shape) != len(self.input_shape):
            return False
        return all(
            self._static_dim(expected, None) is None or actual == expected
            for actual, expected in zip(shape, self.input_shape)
        )

    def create_dummy_input(self) -> np.ndarray:
        """Create a random input tensor matching the model input.

        Dynamic dimensions are filled with 1. Integer inputs get random token
        ids in ``[0, min(vocab_size, 1000))``; other inputs get standard-normal
        float32 values.
        """
        shape = self._concrete_input_shape()
        if self.input_type == 'tensor(int64)':
            vocab = self._static_dim(self.vocab_size, self.DEFAULT_VOCAB_SIZE)
            # Keep ids small and inside the vocabulary; never an empty range.
            upper = max(1, min(vocab, 1000))
            return np.random.randint(0, upper, shape).astype(np.int64)
        return np.random.randn(*shape).astype(np.float32)

    def preprocess_image(self, image: Union[str, np.ndarray], target_dpi: int = 120) -> np.ndarray:
        """
        Preprocess image for CodeFormula inference

        Rescales the image as if rendered at ``target_dpi`` (relative to 72
        DPI), converts it to 8-bit grayscale, then applies CLAHE contrast
        enhancement and non-local-means denoising.

        Note: This is a simplified preprocessing. The actual CodeFormula model
        requires specific preprocessing that converts images to token sequences.
        The processed image is currently discarded and random placeholder
        tokens from ``create_dummy_input`` are returned.

        Args:
            image: Image file path, or an image array (H x W or H x W x C).
            target_dpi: Target resolution; must be positive.

        Returns:
            Placeholder input tensor matching the model input.

        Raises:
            ValueError: If ``target_dpi`` is not positive.
        """
        if target_dpi <= 0:
            raise ValueError(f"target_dpi must be positive, got {target_dpi}")

        if isinstance(image, str):
            # The context manager closes the file handle PIL keeps open.
            with Image.open(image) as pil_image:
                image_array = np.array(pil_image.convert('RGB'))
        else:
            # np.array copies, so the caller's array is never modified.
            image_array = np.array(image)

        print(f" Processing image at {target_dpi} DPI...")

        height, width = image_array.shape[:2]
        # Treat the source as 72 DPI and rescale to the target resolution;
        # clamp to 1 so a small target_dpi never yields a zero-sized image.
        scale_factor = target_dpi / 72.0
        new_height = max(1, int(height * scale_factor))
        new_width = max(1, int(width * scale_factor))
        if new_height != height or new_width != width:
            image_array = cv2.resize(image_array, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

        gray = self._to_gray_uint8(image_array)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)
        denoised = cv2.fastNlMeansDenoising(enhanced)

        print(f" Image preprocessed: {image_array.shape} -> {denoised.shape}")

        # Placeholder: real image tokenization is not implemented.
        return self.create_dummy_input()

    @staticmethod
    def _to_gray_uint8(image_array: np.ndarray) -> np.ndarray:
        """Convert an H x W or H x W x C image to a contiguous single-channel uint8 array."""
        if image_array.dtype != np.uint8:
            # Min-max rescale other dtypes to 0..255 so the OpenCV calls below
            # receive 8-bit input.
            values = image_array.astype(np.float64)
            low, high = float(values.min()), float(values.max())
            if high > low:
                values = (values - low) * (255.0 / (high - low))
            else:
                values = np.zeros_like(values)
            image_array = values.astype(np.uint8)

        if image_array.ndim == 3 and image_array.shape[2] == 4:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGBA2GRAY)
        elif image_array.ndim == 3 and image_array.shape[2] == 3:
            gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
        elif image_array.ndim == 3:
            gray = image_array[:, :, 0]
        else:
            gray = image_array
        return np.ascontiguousarray(gray)

    def predict(self, input_tokens: np.ndarray) -> np.ndarray:
        """Run the model on ``input_tokens`` and return its first output.

        A warning is printed (but inference still runs) when the input shape
        disagrees with a static dimension of the model input.
        """
        if not self._shape_matches(input_tokens.shape):
            print(f"Warning: Input shape {input_tokens.shape} != expected {tuple(self.input_shape)}")

        outputs = self.session.run(None, {self.input_name: input_tokens})
        return outputs[0]

    def decode_output(self, logits: np.ndarray, top_k: int = 1) -> Dict:
        """
        Decode model output logits

        Only the first batch element is decoded. Probabilities come from a
        softmax over the full vocabulary, so confidences reflect the whole
        distribution rather than just the top-k subset.

        Args:
            logits: Model output logits [batch, sequence, vocab]; a 2-D
                [sequence, vocab] array is treated as a batch of one.
            top_k: Number of top predictions to return per position
                (clamped to [1, vocab]).

        Returns:
            Dictionary with decoded results. ``top_k_predictions`` lists
            indices and probabilities best-first.

        Raises:
            ValueError: If ``logits`` is not 2-D/3-D or has an empty vocab axis.
        """
        logits = np.asarray(logits)
        if logits.ndim == 2:
            logits = logits[np.newaxis]
        if logits.ndim != 3 or logits.shape[-1] == 0:
            raise ValueError(f"Expected logits of shape [batch, sequence, vocab], got {logits.shape}")

        _, seq_len, vocab_size = logits.shape
        top_k = max(1, min(int(top_k), vocab_size))

        scores = logits[0].astype(np.float64)
        # Normalize over the whole vocabulary *before* picking the top-k; a
        # softmax over only the top-1 logit would always give probability 1.0.
        probabilities = self._softmax(scores)

        # Best-first order; the stable sort breaks ties toward the lower index,
        # matching np.argmax below.
        top_k_indices = np.argsort(-scores, axis=-1, kind="stable")[:, :top_k]
        top_k_probabilities = np.take_along_axis(probabilities, top_k_indices, axis=-1)

        predicted_tokens = np.argmax(scores, axis=-1)
        max_probabilities = np.max(probabilities, axis=-1)

        if max_probabilities.size:
            mean_confidence = float(np.mean(max_probabilities))
            max_confidence = float(np.max(max_probabilities))
            min_confidence = float(np.min(max_probabilities))
        else:
            # Empty sequence: avoid NaN means and max() of an empty array.
            mean_confidence = max_confidence = min_confidence = 0.0

        return {
            "predicted_tokens": predicted_tokens.tolist(),
            "probabilities": max_probabilities.tolist(),
            "mean_confidence": mean_confidence,
            "max_confidence": max_confidence,
            "min_confidence": min_confidence,
            "sequence_length": int(seq_len),
            "top_k_predictions": {
                "indices": top_k_indices.tolist(),
                "probabilities": top_k_probabilities.tolist()
            }
        }

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Apply softmax along the last axis (max-shifted for numerical stability)."""
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def recognize(self, image: Union[str, np.ndarray]) -> Dict:
        """
        Recognize code or formula from image

        Args:
            image: Image path or numpy array

        Returns:
            Dictionary with ``recognition_type`` (heuristic label),
            ``model_output`` (see ``decode_output``) and ``processing_info``.
        """
        print("Processing image...")
        input_tokens = self.preprocess_image(image)

        print("Running inference...")
        logits = self.predict(input_tokens)

        print("Decoding results...")
        decoded = self.decode_output(logits)

        # Heuristic label only; see _classify_content_type.
        output_type = self._classify_content_type(decoded["predicted_tokens"])

        return {
            "recognition_type": output_type,
            "model_output": decoded,
            "processing_info": {
                "input_shape": input_tokens.shape,
                "output_shape": logits.shape,
                "inference_successful": True
            }
        }

    def _classify_content_type(self, tokens: List[int]) -> str:
        """
        Classify if the content is likely code or formula

        This is a simplified heuristic. In practice, you would:
        1. Decode tokens to actual text using the tokenizer
        2. Analyze the text content for patterns
        3. Look for programming language indicators or mathematical notation

        Returns "unknown" for an empty token list.
        """
        if not tokens:
            return "unknown"

        unique_tokens = len(set(tokens))
        token_variance = np.var(tokens) if len(tokens) > 1 else 0

        if unique_tokens > len(tokens) * 0.7:
            return "code"
        if token_variance < 100:
            return "formula"
        return "unknown"

    def benchmark(self, num_iterations: int = 100, warmup_iterations: int = 5) -> Dict[str, float]:
        """Benchmark inference latency on a random dummy input.

        Args:
            num_iterations: Number of timed runs; must be at least 1.
            warmup_iterations: Untimed runs executed before measurement starts.

        Returns:
            Latency statistics in milliseconds, throughput in runs per second,
            and the number of timed iterations.

        Raises:
            ValueError: If ``num_iterations`` is less than 1.
        """
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

        print(f"Running benchmark with {num_iterations} iterations...")

        dummy_input = self.create_dummy_input()

        for _ in range(max(0, warmup_iterations)):
            self.predict(dummy_input)

        times = np.empty(num_iterations, dtype=np.float64)
        for i in range(num_iterations):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            self.predict(dummy_input)
            times[i] = time.perf_counter() - start_time

            if (i + 1) % 10 == 0:
                print(f" Progress: {i + 1}/{num_iterations}")

        mean_time = float(np.mean(times))
        return {
            "mean_time_ms": mean_time * 1000,
            "std_time_ms": float(np.std(times) * 1000),
            "min_time_ms": float(np.min(times) * 1000),
            "max_time_ms": float(np.max(times) * 1000),
            "median_time_ms": float(np.median(times) * 1000),
            "throughput_fps": 1.0 / mean_time if mean_time > 0 else float("inf"),
            "total_iterations": num_iterations
        }
|
|
|
|
|
def main() -> int:
    """Command-line entry point for the CodeFormula ONNX example.

    Loads the model given by ``--model`` and then:
      * runs a latency benchmark when ``--benchmark`` is given,
      * runs recognition on ``--image`` when given,
      * otherwise runs a demo on a random image.

    Returns:
        Process exit status: 0 if every requested step succeeded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="CodeFormula ONNX Example")
    parser.add_argument("--model", type=str, default="CodeFormula.onnx",
                        help="Path to CodeFormula ONNX model")
    parser.add_argument("--image", type=str,
                        help="Path to image file (code snippet or formula)")
    parser.add_argument("--benchmark", action="store_true",
                        help="Run performance benchmark")
    parser.add_argument("--iterations", type=int, default=100,
                        help="Number of benchmark iterations")
    args = parser.parse_args()

    if args.benchmark and args.iterations < 1:
        # Zero iterations would otherwise produce NaN statistics.
        parser.error("--iterations must be at least 1")

    if not os.path.exists(args.model):
        print(f"Error: Model file not found: {args.model}")
        print("Please ensure the ONNX model file is in the current directory.")
        return 1

    print("=" * 60)
    print("CodeFormula ONNX Example")
    print("=" * 60)

    try:
        codeformula = CodeFormulaONNX(args.model)
    except Exception as e:  # top-level boundary: report and exit non-zero
        print(f"Error loading model: {e}")
        return 1

    # Set when any requested step fails, so the summary and exit code are honest.
    failed = False

    if args.benchmark:
        print("\nRunning performance benchmark...")
        try:
            stats = codeformula.benchmark(args.iterations)
        except Exception as e:
            print(f"Benchmark failed: {e}")
            failed = True
        else:
            print("\nBenchmark Results:")
            print(f" Mean inference time: {stats['mean_time_ms']:.2f} ± {stats['std_time_ms']:.2f} ms")
            print(f" Median inference time: {stats['median_time_ms']:.2f} ms")
            print(f" Min/Max: {stats['min_time_ms']:.2f} / {stats['max_time_ms']:.2f} ms")
            print(f" Throughput: {stats['throughput_fps']:.1f} FPS")

    if args.image:
        if not os.path.exists(args.image):
            print(f"Error: Image file not found: {args.image}")
            return 1

        print(f"\nProcessing image: {args.image}")
        try:
            result = codeformula.recognize(args.image)
        except Exception as e:
            print(f"Error processing image: {e}")
            import traceback
            traceback.print_exc()
            failed = True
        else:
            model_output = result['model_output']
            tokens = model_output['predicted_tokens']
            print("\nRecognition completed:")
            print(f" Content type: {result['recognition_type']}")
            print(f" Confidence: {model_output['mean_confidence']:.3f}")
            print(f" Sequence length: {model_output['sequence_length']}")
            print(f" Predicted tokens: {tokens[:10]}{'...' if len(tokens) > 10 else ''}")
            print("\nNote: This example uses dummy token decoding.")
            print(" For actual text output, integrate with CodeFormula tokenizer.")

    if not args.image and not args.benchmark:
        print("\nRunning demo with dummy data...")
        try:
            # randint's upper bound is exclusive; 256 covers the full uint8 range.
            dummy_image = np.random.randint(0, 256, (400, 600, 3), dtype=np.uint8)
            result = codeformula.recognize(dummy_image)
        except Exception as e:
            print(f"Demo failed: {e}")
            failed = True
        else:
            print("Demo completed:")
            print(f" Content type: {result['recognition_type']}")
            print(f" Mean confidence: {result['model_output']['mean_confidence']:.3f}")
            print(f" Processing info: {result['processing_info']}")
            print("\nNote: This was a demonstration with random data.")

    if failed:
        print("\nExample finished with errors.")
    else:
        print("\nExample completed successfully!")
    print("\nUsage examples:")
    print(" Process image: python example.py --image code_snippet.jpg")
    print(" Run benchmark: python example.py --benchmark --iterations 50")
    print(" Both: python example.py --image formula.png --benchmark")

    return 1 if failed else 0
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (None or 0 -> success, non-zero -> failure).
    raise SystemExit(main())