from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import cv2
import os
import torch
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import notification_queue, log_print

# Global variables for model and processor
processor = None
model = None


def initialize_model():
    """Initialize the TrOCR model and processor."""
    global processor, model
    MODEL_NAME = "microsoft/trocr-large-handwritten"
    try:
        log_print("Initializing TrOCR model...")
        processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
        model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
        if torch.cuda.is_available():
            model = model.to('cuda')
            log_print("Model moved to CUDA")
        log_print("TrOCR model initialized successfully")
    except Exception as e:
        error_msg = str(e)
        log_print(f"Error initializing TrOCR model: {error_msg}", "ERROR")
        raise


def text(image_cv):
    try:
        # Initialize the model if it has not been loaded yet
        if processor is None or model is None:
            log_print("TrOCR model not initialized, initializing now...")
            initialize_model()
            if processor is None or model is None:
                raise RuntimeError("Failed to initialize TrOCR model")

        # Accept either a single OpenCV image or a list of images
        if not isinstance(image_cv, list):
            image_cv = [image_cv]

        t = ""
        total_images = len(image_cv)
        log_print(f"Processing {total_images} image(s) for text extraction")

        for i, img in enumerate(image_cv):
            try:
                log_print(f"Processing image {i+1}/{total_images}")

                # Validate image
                if img is None:
                    log_print(f"Skipping image {i+1} - Image is None", "WARNING")
                    continue

                # Convert from OpenCV BGR to RGB and wrap as a PIL image
                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                image = Image.fromarray(img_rgb)

                # Get pixel values
                pixel_values = processor(image, return_tensors="pt").pixel_values
                if torch.cuda.is_available():
                    pixel_values = pixel_values.to('cuda')

                # Generate text
                generated_ids = model.generate(pixel_values)
                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

                # Clean up the text: strip internal spaces, then append with a separator
                cleaned_text = generated_text.replace(" ", "")
                t = t + cleaned_text + " "
                log_print(f"Successfully extracted text from image {i+1}: {cleaned_text}")

                # Clean up CUDA memory
                if torch.cuda.is_available():
                    del pixel_values
                    del generated_ids
                    torch.cuda.empty_cache()
            except Exception as e:
                log_print(f"Error processing image {i+1}: {str(e)}", "ERROR")
                continue

        return t.strip()
    except Exception as e:
        error_msg = f"Error in text function: {str(e)}"
        log_print(error_msg, "ERROR")
        notification_queue.put({
            "type": "error",
            "message": error_msg
        })
        return ""