#!/usr/bin/env python3
"""
Example usage of EasyOCR ONNX models for text detection and recognition.
"""
import onnxruntime as ort
import cv2
import numpy as np
from typing import List
import argparse
import os
class EasyOCR_ONNX:
"""ONNX implementation of EasyOCR for text detection and recognition."""
def __init__(self,
detector_path: str = "craft_mlt_25k_jpqd.onnx",
recognizer_path: str = "english_g2_jpqd.onnx"):
"""
Initialize EasyOCR ONNX models.
Args:
detector_path: Path to CRAFT detection model
recognizer_path: Path to text recognition model
"""
print(f"Loading detector: {detector_path}")
self.detector = ort.InferenceSession(detector_path)
print(f"Loading recognizer: {recognizer_path}")
self.recognizer = ort.InferenceSession(recognizer_path)
# Character sets
self.english_charset = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
self.latin_charset = self._get_latin_charset()
# Determine charset based on model
if "english" in recognizer_path.lower():
self.charset = self.english_charset
elif "latin" in recognizer_path.lower():
self.charset = self.latin_charset
else:
self.charset = self.english_charset
def _get_latin_charset(self) -> str:
"""Get extended Latin character set."""
        # This is a simplified version; in practice you would load the full
        # 352-character set (see the _load_charset_from_file sketch below)
basic = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
extended = 'àáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿĀāĂ㥹ĆćĈĉĊċČčĎďĐđĒēĔĕĖėĘęĚě'
return basic + extended
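    # Hypothetical helper, not part of the original release: a minimal sketch of
    # loading the full 352-character Latin set from a plain-text file, assuming
    # one character per line. Adjust the path and format to match your recognizer.
    def _load_charset_from_file(self, charset_path: str) -> str:
        """Load a character set from a text file (one character per line)."""
        with open(charset_path, encoding="utf-8") as f:
            return ''.join(line.rstrip('\n') for line in f)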
def preprocess_for_detection(self, image: np.ndarray, target_size: int = 640) -> np.ndarray:
"""Preprocess image for CRAFT text detection."""
# Resize to target size
image_resized = cv2.resize(image, (target_size, target_size))
# Normalize to [0, 1]
image_norm = image_resized.astype(np.float32) / 255.0
# Convert HWC to CHW
image_chw = np.transpose(image_norm, (2, 0, 1))
# Add batch dimension
image_batch = np.expand_dims(image_chw, axis=0)
return image_batch
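    # Alternative preprocessing sketch: the reference CRAFT pipeline normalizes
    # with ImageNet-style mean/variance rather than a plain [0, 1] scaling.
    # Whether the JPQD-quantized export expects this is an assumption to verify
    # against your model; it is provided here for comparison only.
    def preprocess_for_detection_meanvar(self, image: np.ndarray, target_size: int = 640) -> np.ndarray:
        """Resize and apply ImageNet-style mean/variance normalization (NCHW output)."""
        image_resized = cv2.resize(image, (target_size, target_size)).astype(np.float32)
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255.0
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32) * 255.0
        image_norm = (image_resized - mean) / std
        image_chw = np.transpose(image_norm, (2, 0, 1))
        return np.expand_dims(image_chw, axis=0)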
def preprocess_for_recognition(self, text_region: np.ndarray) -> np.ndarray:
"""Preprocess text region for CRNN recognition."""
# Convert to grayscale if needed
if len(text_region.shape) == 3:
gray = cv2.cvtColor(text_region, cv2.COLOR_RGB2GRAY)
else:
gray = text_region
# Resize to model input size (32 height, 100 width)
resized = cv2.resize(gray, (100, 32))
# Normalize to [0, 1]
normalized = resized.astype(np.float32) / 255.0
# Add batch and channel dimensions [1, 1, 32, 100]
input_batch = np.expand_dims(np.expand_dims(normalized, axis=0), axis=0)
return input_batch
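    # Sketch of an aspect-ratio-preserving variant: CRNN-style recognizers are
    # usually fed crops scaled to the target height and padded to the target
    # width rather than squashed. The padding value and side expected by this
    # model are assumptions; treat this as an alternative to experiment with.
    def preprocess_for_recognition_padded(self, text_region: np.ndarray,
                                          target_h: int = 32, target_w: int = 100) -> np.ndarray:
        """Scale to target height, right-pad with zeros to target width ([1, 1, H, W])."""
        gray = cv2.cvtColor(text_region, cv2.COLOR_RGB2GRAY) if text_region.ndim == 3 else text_region
        h, w = gray.shape[:2]
        new_w = min(target_w, max(1, int(round(w * target_h / float(h)))))
        resized = cv2.resize(gray, (new_w, target_h)).astype(np.float32) / 255.0
        canvas = np.zeros((target_h, target_w), dtype=np.float32)
        canvas[:, :new_w] = resized
        return canvas[np.newaxis, np.newaxis, :, :]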
def detect_text(self, image: np.ndarray) -> np.ndarray:
"""
Detect text regions in image using CRAFT model.
Args:
image: Input image (RGB format)
Returns:
Detection output maps
"""
# Preprocess
input_batch = self.preprocess_for_detection(image)
# Run inference
outputs = self.detector.run(None, {"input": input_batch})
# Ensure we return a numpy array
if isinstance(outputs[0], np.ndarray):
return outputs[0]
else:
return np.array(outputs[0]) # Convert to numpy array if needed
def recognize_text(self, text_regions: List[np.ndarray]) -> List[str]:
"""
Recognize text in detected regions.
Args:
text_regions: List of cropped text region images
Returns:
List of recognized text strings
"""
results = []
for region in text_regions:
# Preprocess
input_batch = self.preprocess_for_recognition(region)
# Run inference
outputs = self.recognizer.run(None, {"input": input_batch})
# Ensure output is numpy array and decode text
output_array = outputs[0] if isinstance(outputs[0], np.ndarray) else np.array(outputs[0])
text = self._decode_text(output_array)
results.append(text)
return results
    def _decode_text(self, output: np.ndarray) -> str:
        """Decode recognition output to a text string using greedy CTC decoding."""
        # Most likely class per time step; output[0] is assumed to be [seq_len, num_classes]
        indices = np.argmax(output[0], axis=1)
        # Greedy CTC collapse: drop the blank token (assumed to be class 0) and
        # merge only consecutive repeats of the same class, so repeated letters
        # separated by a blank are kept. Depending on how the model was exported,
        # character classes may be offset by one relative to self.charset.
        text = ''
        prev_idx = -1
        for idx in indices:
            if idx != prev_idx and 0 < idx < len(self.charset):
                text += self.charset[idx]
            prev_idx = idx
        return text.strip()
def extract_simple_regions(self, detection_output: np.ndarray,
original_image: np.ndarray,
threshold: float = 0.3) -> List[np.ndarray]:
"""
Extract text regions from detection output (simplified version).
In practice, you'd implement proper CRAFT post-processing.
"""
        # Simplified demonstration only: proper CRAFT post-processing (combining the
        # region and affinity score maps) would yield more precise regions. A
        # coordinate-returning variant is sketched in extract_boxes below.
h, w = original_image.shape[:2]
# Handle different output shapes
if len(detection_output.shape) == 4: # [batch, channels, height, width]
detection_map = detection_output[0, 0] # First channel of first batch
elif len(detection_output.shape) == 3: # [channels, height, width]
detection_map = detection_output[0] # First channel
else:
detection_map = detection_output
# Normalize detection map to [0, 1] if needed
if detection_map.max() > 1.0:
detection_map = detection_map / detection_map.max()
        # Binarize the score map at the given threshold
binary_map = (detection_map > threshold).astype(np.uint8) * 255
binary_map = cv2.resize(binary_map, (w, h))
# Apply morphological operations to improve detection
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
binary_map = cv2.morphologyEx(binary_map, cv2.MORPH_CLOSE, kernel)
contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
text_regions = []
for contour in contours:
# Get bounding box
x, y, w_box, h_box = cv2.boundingRect(contour)
# Filter small regions but be more permissive
if w_box > 15 and h_box > 8 and cv2.contourArea(contour) > 100:
# Add some padding
x = max(0, x - 2)
y = max(0, y - 2)
w_box = min(w - x, w_box + 4)
h_box = min(h - y, h_box + 4)
# Extract region from original image
region = original_image[y:y+h_box, x:x+w_box]
if region.size > 0: # Make sure region is not empty
text_regions.append(region)
# If no regions found with CRAFT, fall back to simple grid sampling
if len(text_regions) == 0:
print(" No CRAFT regions found, using fallback method...")
# Sample some regions from the image for demonstration
step_y, step_x = h // 4, w // 4
for y in range(0, h - 32, step_y):
for x in range(0, w - 100, step_x):
region = original_image[y:y+32, x:x+100]
if region.size > 0 and np.mean(region) < 240: # Skip mostly white regions
text_regions.append(region)
if len(text_regions) >= 4: # Limit to 4 samples
break
if len(text_regions) >= 4:
break
return text_regions
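    # Sketch of a coordinate-returning variant, useful for drawing detections in
    # main(). It reuses the same simplified thresholding as extract_simple_regions;
    # faithful CRAFT post-processing would also combine the affinity map and run
    # connected-component analysis on the score maps.
    def extract_boxes(self, detection_output: np.ndarray,
                      original_image: np.ndarray,
                      threshold: float = 0.3) -> List[tuple]:
        """Return (x, y, w, h) bounding boxes instead of cropped regions."""
        h, w = original_image.shape[:2]
        if detection_output.ndim == 4:
            detection_map = detection_output[0, 0]
        elif detection_output.ndim == 3:
            detection_map = detection_output[0]
        else:
            detection_map = detection_output
        if detection_map.max() > 1.0:
            detection_map = detection_map / detection_map.max()
        binary_map = cv2.resize((detection_map > threshold).astype(np.uint8) * 255, (w, h))
        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100]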
def main():
parser = argparse.ArgumentParser(description="EasyOCR ONNX Example")
parser.add_argument("--image", type=str, required=True, help="Path to input image")
parser.add_argument("--detector", type=str, default="craft_mlt_25k_jpqd.onnx",
help="Path to detection model")
parser.add_argument("--recognizer", type=str, default="english_g2_jpqd.onnx",
help="Path to recognition model")
parser.add_argument("--output", type=str, help="Path to save output image with detections")
args = parser.parse_args()
# Check if files exist
if not os.path.exists(args.image):
print(f"Error: Image file not found: {args.image}")
return
if not os.path.exists(args.detector):
print(f"Error: Detector model not found: {args.detector}")
return
if not os.path.exists(args.recognizer):
print(f"Error: Recognizer model not found: {args.recognizer}")
return
# Initialize OCR
print("Initializing EasyOCR ONNX...")
ocr = EasyOCR_ONNX(args.detector, args.recognizer)
# Load image
print(f"Loading image: {args.image}")
image = cv2.imread(args.image)
if image is None:
print(f"Error: Could not load image: {args.image}")
return
# Convert BGR to RGB
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Detect text
print("Detecting text regions...")
detection_output = ocr.detect_text(image_rgb)
# Extract text regions (simplified)
text_regions = ocr.extract_simple_regions(detection_output, image_rgb)
print(f"Found {len(text_regions)} text regions")
# Recognize text
if text_regions:
print("Recognizing text...")
recognized_texts = ocr.recognize_text(text_regions)
# Print results
print(f"\nRecognized text ({len(recognized_texts)} regions):")
print("-" * 50)
for i, text in enumerate(recognized_texts):
print(f"Region {i+1}: '{text}'")
else:
print("No text regions detected")
# Save output image with bounding boxes (if requested)
if args.output and text_regions:
output_image = image.copy()
        # Note: no boxes are drawn here because extract_simple_regions returns
        # crops rather than coordinates; a hedged drawing sketch follows.
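        # Illustrative only, assuming the extract_boxes sketch defined above is used:
        # for (bx, by, bw, bh) in ocr.extract_boxes(detection_output, image_rgb):
        #     cv2.rectangle(output_image, (bx, by), (bx + bw, by + bh), (0, 255, 0), 2)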
cv2.imwrite(args.output, output_image)
print(f"Output saved to: {args.output}")
if __name__ == "__main__":
main()