|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
from PIL import Image |
|
import cv2 |
|
import os |
|
import torch |
|
import sys |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
from utils import notification_queue, log_print |
|
|
|
|
|
# Lazily-initialized TrOCR components. Both stay None until the first call to
# initialize_model() (triggered on demand from text()), so importing this
# module does not pay the model-download/load cost.
processor = None  # TrOCRProcessor once initialized

model = None  # VisionEncoderDecoderModel once initialized (on CUDA when available)
|
|
|
def initialize_model():
    """Load the TrOCR processor and model into the module-level globals.

    Downloads/loads ``microsoft/trocr-large-handwritten``, moves the model to
    CUDA when a GPU is available, and logs progress. On any failure the error
    is logged and the exception is re-raised to the caller.
    """
    global processor, model
    model_name = "microsoft/trocr-large-handwritten"
    try:
        log_print("Initializing TrOCR model...")
        # Processor handles image preprocessing + token decoding; the
        # encoder-decoder model performs the actual recognition.
        processor = TrOCRProcessor.from_pretrained(model_name)
        model = VisionEncoderDecoderModel.from_pretrained(model_name)
        if torch.cuda.is_available():
            model = model.to('cuda')
            log_print("Model moved to CUDA")
        log_print("TrOCR model initialized successfully")
    except Exception as exc:
        # Surface the failure in the log, then let the caller decide.
        log_print(f"Error initializing TrOCR model: {str(exc)}", "ERROR")
        raise
|
|
|
def text(image_cv):
    """Run TrOCR handwriting recognition over one image or a list of images.

    Args:
        image_cv: A single OpenCV image (numpy ndarray, BGR 3-channel or 2-D
            grayscale) or a list of such images. ``None`` entries are skipped
            with a warning.

    Returns:
        str: Recognized text for all images, one space-separated chunk per
        image (spaces inside each per-image result are stripped — original
        behavior preserved), or "" if the whole operation fails.
    """
    try:
        # Lazily initialize the model on first use.
        if processor is None or model is None:
            log_print("TrOCR model not initialized, initializing now...")
            initialize_model()
            if processor is None or model is None:
                raise RuntimeError("Failed to initialize TrOCR model")

        # Normalize input so single images and batches share one code path.
        if not isinstance(image_cv, list):
            image_cv = [image_cv]

        # Hoist the device decision out of the loop (was re-checked per image).
        use_cuda = torch.cuda.is_available()

        # Collect per-image results and join once at the end — avoids the
        # quadratic cost of repeated string concatenation.
        pieces = []
        total_images = len(image_cv)
        log_print(f"Processing {total_images} image(s) for text extraction")

        for i, img in enumerate(image_cv):
            try:
                log_print(f"Processing image {i+1}/{total_images}")

                if img is None:
                    log_print(f"Skipping image {i+1} - Image is None", "WARNING")
                    continue

                # OpenCV delivers BGR (or 2-D grayscale); PIL expects RGB.
                # The grayscale branch generalizes the original, which crashed
                # on 2-D arrays inside BGR2RGB.
                if img.ndim == 2:
                    img_rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
                else:
                    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                image = Image.fromarray(img_rgb)

                pixel_values = processor(image, return_tensors="pt").pixel_values
                if use_cuda:
                    pixel_values = pixel_values.to('cuda')

                # Inference only: disable autograd bookkeeping to save memory.
                with torch.no_grad():
                    generated_ids = model.generate(pixel_values)
                generated_text = processor.batch_decode(
                    generated_ids, skip_special_tokens=True
                )[0]

                # NOTE: intra-result spaces are deliberately removed; chunks
                # are re-joined with single spaces below (original behavior —
                # do not "fix" without checking downstream consumers).
                cleaned_text = generated_text.replace(" ", "")
                pieces.append(cleaned_text)

                log_print(f"Successfully extracted text from image {i+1}: {cleaned_text}")

                if use_cuda:
                    # Drop GPU tensors eagerly to keep VRAM pressure low when
                    # processing many images.
                    del pixel_values
                    del generated_ids
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best effort per image: log the failure and continue with the rest.
                log_print(f"Error processing image {i+1}: {str(e)}", "ERROR")
                continue

        return " ".join(pieces)

    except Exception as e:
        error_msg = f"Error in text function: {str(e)}"
        log_print(error_msg, "ERROR")
        notification_queue.put({
            "type": "error",
            "message": error_msg
        })
        return ""
|
|
|
|