"""
Example usage of EasyOCR ONNX models for text detection and recognition.
"""

import argparse
import os
from typing import List

import cv2
import numpy as np
import onnxruntime as ort
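
# onnxruntime picks an execution provider automatically; to pin one explicitly
# (e.g. on a machine with CUDA), sessions can be created like this -- a minimal
# sketch, assuming the CUDA provider package is installed:
#   ort.InferenceSession(model_path,
#                        providers=["CUDAExecutionProvider", "CPUExecutionProvider"])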


class EasyOCR_ONNX:
    """ONNX implementation of EasyOCR for text detection and recognition."""

    def __init__(self,
                 detector_path: str = "craft_mlt_25k_jpqd.onnx",
                 recognizer_path: str = "english_g2_jpqd.onnx"):
        """
        Initialize the EasyOCR ONNX models.

        Args:
            detector_path: Path to the CRAFT detection model
            recognizer_path: Path to the text recognition model
        """
        print(f"Loading detector: {detector_path}")
        self.detector = ort.InferenceSession(detector_path)

        print(f"Loading recognizer: {recognizer_path}")
        self.recognizer = ort.InferenceSession(recognizer_path)

        # Read the input tensor names from the sessions rather than hard-coding
        # "input": exported graphs do not always use the same name.
        self.detector_input = self.detector.get_inputs()[0].name
        self.recognizer_input = self.recognizer.get_inputs()[0].name

        self.english_charset = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
        self.latin_charset = self._get_latin_charset()

        # Pick the character set from the recognizer filename; default to English.
        if "english" in recognizer_path.lower():
            self.charset = self.english_charset
        elif "latin" in recognizer_path.lower():
            self.charset = self.latin_charset
        else:
            self.charset = self.english_charset

    def _get_latin_charset(self) -> str:
        """Get the extended Latin character set."""
        basic = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
        extended = 'àáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿĀāĂ㥹ĆćĈĉĊċČčĎďĐđĒēĔĕĖėĘęĚě'
        return basic + extended

    def preprocess_for_detection(self, image: np.ndarray, target_size: int = 640) -> np.ndarray:
        """Preprocess an image for CRAFT text detection."""
        # Resize to the fixed input resolution expected by the exported model.
        image_resized = cv2.resize(image, (target_size, target_size))

        # Scale pixel values to [0, 1].
        image_norm = image_resized.astype(np.float32) / 255.0

        # HWC -> CHW, then add a batch dimension: (1, 3, H, W).
        image_chw = np.transpose(image_norm, (2, 0, 1))
        image_batch = np.expand_dims(image_chw, axis=0)

        return image_batch
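
    # Note: the reference CRAFT pipeline normalizes with ImageNet statistics
    # rather than a plain /255 scale. If the exported detector was trained with
    # that preprocessing, substitute (an assumption -- verify against your export):
    #   mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    #   std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    #   image_norm = (image_resized.astype(np.float32) / 255.0 - mean) / std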

    def preprocess_for_recognition(self, text_region: np.ndarray) -> np.ndarray:
        """Preprocess a text region for CRNN recognition."""
        # Recognition runs on a single grayscale channel.
        if len(text_region.shape) == 3:
            gray = cv2.cvtColor(text_region, cv2.COLOR_RGB2GRAY)
        else:
            gray = text_region

        # Resize to the model's fixed input size (width 100, height 32 here;
        # adjust if your export expects different dimensions).
        resized = cv2.resize(gray, (100, 32))

        # Scale to [0, 1] and add channel + batch dimensions: (1, 1, 32, 100).
        normalized = resized.astype(np.float32) / 255.0
        input_batch = np.expand_dims(np.expand_dims(normalized, axis=0), axis=0)

        return input_batch
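
    # To confirm what the exported recognizer actually expects, the session can
    # be inspected directly (dimensions may be symbolic for dynamic axes):
    #   print(self.recognizer.get_inputs()[0].shape)  # e.g. [1, 1, 32, 100]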

    def detect_text(self, image: np.ndarray) -> np.ndarray:
        """
        Detect text regions in an image using the CRAFT model.

        Args:
            image: Input image (RGB format)

        Returns:
            Detection output maps
        """
        input_batch = self.preprocess_for_detection(image)

        # Run inference using the input name read from the session.
        outputs = self.detector.run(None, {self.detector_input: input_batch})

        # np.asarray is a no-op for ndarrays and converts anything else.
        return np.asarray(outputs[0])

    def recognize_text(self, text_regions: List[np.ndarray]) -> List[str]:
        """
        Recognize text in detected regions.

        Args:
            text_regions: List of cropped text region images

        Returns:
            List of recognized text strings
        """
        results = []
        for region in text_regions:
            input_batch = self.preprocess_for_recognition(region)

            outputs = self.recognizer.run(None, {self.recognizer_input: input_batch})

            output_array = np.asarray(outputs[0])
            results.append(self._decode_text(output_array))

        return results

    def _decode_text(self, output: np.ndarray) -> str:
        """Decode recognition output to a text string using greedy CTC decoding."""
        # Pick the most likely class at each time step: (T, num_classes) -> (T,).
        indices = np.argmax(output[0], axis=1)

        # Collapse repeats and drop blanks. Index 0 is assumed to be the CTC
        # blank, so output index i maps to charset[i - 1].
        text = ''
        prev_idx = 0
        for idx in indices:
            if idx != prev_idx and 0 < idx <= len(self.charset):
                text += self.charset[idx - 1]
            prev_idx = idx

        return text.strip()
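
    # Worked example of the greedy collapse above (indices are illustrative):
    # with blank = 0, output index 12 -> charset[11] == 'b' and 15 -> 'e', so
    #   [0, 12, 12, 0, 12, 15, 0]
    # yields 'b' (the repeated 12 is dropped), then 'b' again (a repeat is
    # allowed after a blank), then 'e' -- i.e. the string "bbe".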

    def extract_simple_regions(self, detection_output: np.ndarray,
                               original_image: np.ndarray,
                               threshold: float = 0.3) -> List[np.ndarray]:
        """
        Extract text regions from the detection output (simplified version).
        In practice, you'd implement proper CRAFT post-processing.
        """
        h, w = original_image.shape[:2]

        # Squeeze batch (and channel) dimensions down to a single 2D score map.
        if len(detection_output.shape) == 4:
            detection_map = detection_output[0, 0]
        elif len(detection_output.shape) == 3:
            detection_map = detection_output[0]
        else:
            detection_map = detection_output

        # Rescale scores into [0, 1] if the model emits a wider range.
        if detection_map.max() > 1.0:
            detection_map = detection_map / detection_map.max()

        # Threshold the score map and bring it back to the original resolution.
        binary_map = (detection_map > threshold).astype(np.uint8) * 255
        binary_map = cv2.resize(binary_map, (w, h))

        # Close small gaps so neighboring characters merge into one region.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
        binary_map = cv2.morphologyEx(binary_map, cv2.MORPH_CLOSE, kernel)

        contours, _ = cv2.findContours(binary_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        text_regions = []
        for contour in contours:
            x, y, w_box, h_box = cv2.boundingRect(contour)

            # Keep only boxes that are plausibly text-sized.
            if w_box > 15 and h_box > 8 and cv2.contourArea(contour) > 100:
                # Pad the box by 2 px on each side, clamped to the image bounds.
                x = max(0, x - 2)
                y = max(0, y - 2)
                w_box = min(w - x, w_box + 4)
                h_box = min(h - y, h_box + 4)

                region = original_image[y:y+h_box, x:x+w_box]
                if region.size > 0:
                    text_regions.append(region)

        # Fallback: if CRAFT found nothing, sample a coarse grid of patches.
        if len(text_regions) == 0:
            print("  No CRAFT regions found, using fallback method...")
            step_y, step_x = max(1, h // 4), max(1, w // 4)
            for y in range(0, h - 32, step_y):
                for x in range(0, w - 100, step_x):
                    region = original_image[y:y+32, x:x+100]
                    # Skip patches that are nearly blank (mostly white).
                    if region.size > 0 and np.mean(region) < 240:
                        text_regions.append(region)
                    if len(text_regions) >= 4:
                        break
                if len(text_regions) >= 4:
                    break

        return text_regions
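
    # Proper CRAFT post-processing combines two score maps -- the character
    # region map and the affinity map that links characters into words -- via
    # connected-component analysis, as in the getDetBoxes routine of the
    # reference CRAFT repository. The single-map thresholding above is a rough
    # stand-in that is good enough for a demo.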


def main():
    parser = argparse.ArgumentParser(description="EasyOCR ONNX Example")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--detector", type=str, default="craft_mlt_25k_jpqd.onnx",
                        help="Path to detection model")
    parser.add_argument("--recognizer", type=str, default="english_g2_jpqd.onnx",
                        help="Path to recognition model")
    parser.add_argument("--output", type=str, help="Path to save a copy of the input image")

    args = parser.parse_args()

    # Validate that all required files exist before loading anything.
    if not os.path.exists(args.image):
        print(f"Error: Image file not found: {args.image}")
        return

    if not os.path.exists(args.detector):
        print(f"Error: Detector model not found: {args.detector}")
        return

    if not os.path.exists(args.recognizer):
        print(f"Error: Recognizer model not found: {args.recognizer}")
        return

    print("Initializing EasyOCR ONNX...")
    ocr = EasyOCR_ONNX(args.detector, args.recognizer)

    print(f"Loading image: {args.image}")
    image = cv2.imread(args.image)
    if image is None:
        print(f"Error: Could not load image: {args.image}")
        return

    # OpenCV loads BGR; the models expect RGB.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    print("Detecting text regions...")
    detection_output = ocr.detect_text(image_rgb)

    text_regions = ocr.extract_simple_regions(detection_output, image_rgb)
    print(f"Found {len(text_regions)} text regions")

    if text_regions:
        print("Recognizing text...")
        recognized_texts = ocr.recognize_text(text_regions)

        print(f"\nRecognized text ({len(recognized_texts)} regions):")
        print("-" * 50)
        for i, text in enumerate(recognized_texts):
            print(f"Region {i+1}: '{text}'")
    else:
        print("No text regions detected")

    if args.output and text_regions:
        # extract_simple_regions returns crops rather than coordinates, so no
        # bounding boxes are drawn; the input image is saved unannotated.
        output_image = image.copy()
        cv2.imwrite(args.output, output_image)
        print(f"Output saved to: {args.output}")


if __name__ == "__main__":
    main()
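
# Example invocation (the script filename is illustrative; model paths fall
# back to the defaults above when omitted):
#   python easyocr_onnx_example.py --image sample.jpg --output copy.jpg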