yamanavijayavardhan's picture
fix circular call 3
237de6f
raw
history blame
2.11 kB
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import cv2
import os
import torch
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import notification_queue, log_print
# Module-level TrOCR handles. Both start out as None; initialize_model()
# fills them in, and text() calls it lazily while either is still None.
processor, model = None, None
def initialize_model():
    """Load the TrOCR handwriting processor and model into the module globals.

    The processor global is assigned before the model is loaded. The model is
    moved to CUDA when torch reports an available GPU. Progress is reported
    through log_print. Any exception is logged at ERROR level and re-raised
    unchanged.
    """
    global processor, model
    checkpoint = "microsoft/trocr-large-handwritten"
    try:
        log_print("Initializing TrOCR model...")
        processor = TrOCRProcessor.from_pretrained(checkpoint)
        model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

        gpu_available = torch.cuda.is_available()
        if gpu_available:
            model = model.to('cuda')
            log_print("Model moved to CUDA")

        log_print("TrOCR model initialized successfully")
    except Exception as exc:
        # Log, then propagate the original exception to the caller.
        log_print(f"Error initializing TrOCR model: {exc}", "ERROR")
        raise
def _to_pil_rgb(image_bgr):
    """Convert an OpenCV image array into an RGB PIL image for the TrOCR processor.

    Arrays with 2 dimensions, or with a single channel, are treated as
    grayscale. 4-channel arrays are treated as BGRA. Everything else is
    treated as OpenCV's default BGR order.
    """
    if image_bgr.ndim == 2 or image_bgr.shape[2] == 1:
        code = cv2.COLOR_GRAY2RGB
    elif image_bgr.shape[2] == 4:
        code = cv2.COLOR_BGRA2RGB
    else:
        code = cv2.COLOR_BGR2RGB
    return Image.fromarray(cv2.cvtColor(image_bgr, code))


def text(image_cv):
    """Run TrOCR handwriting recognition on one or more OpenCV images.

    Args:
        image_cv: A single image array or a list of image arrays in OpenCV
            channel order (BGR/BGRA colour, or grayscale).

    Returns:
        For each image, the recognised text with all its spaces removed and a
        single space appended. The pieces are concatenated in input order, so
        a non-empty result ends with a trailing space. On any error, the
        error is logged, an "error" notification is queued, and "" is returned.
    """
    try:
        # Load the model on first use.
        if processor is None or model is None:
            initialize_model()

        images = image_cv if isinstance(image_cv, list) else [image_cv]

        notification_queue.put({
            "type": "info",
            "message": f"Processing {len(images)} image(s)..."
        })

        # initialize_model() may move the model to CUDA, while the processor
        # returns CPU tensors. The inputs must follow the model's device, or
        # generate() fails with a device-mismatch error.
        device = model.device
        pieces = []
        for image in images:
            pixel_values = processor(_to_pil_rgb(image), return_tensors="pt").pixel_values
            generated_ids = model.generate(pixel_values.to(device))
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            pieces.append(generated_text.replace(" ", ""))
        # Same output as appending piece + " " for each image.
        return "".join(f"{piece} " for piece in pieces)
    except Exception as e:
        error_msg = str(e)
        # Log as well as notify, so failures are not only visible to queue consumers.
        log_print(f"Error in text function: {error_msg}", "ERROR")
        notification_queue.put({
            "type": "error",
            "message": f"Error in text function: {error_msg}"
        })
        return ""