File size: 2,967 Bytes
18e4106 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import logging
import argparse
import cv2
import torch
import numpy as np
from PIL import Image
import unimernet.tasks as tasks
from unimernet.common.config import Config
from unimernet.processors import load_processor
from pdf_extract_kit.registry import MODEL_REGISTRY
@MODEL_REGISTRY.register('formula_recognition_unimernet')
class FormulaRecognitionUniMERNet:
    """Formula recognition (image file -> predicted string) backed by UniMERNet.

    Registered in ``MODEL_REGISTRY`` under ``'formula_recognition_unimernet'``.
    """

    def __init__(self, config):
        """
        Initialize the FormulaRecognitionUniMERNet class.

        Args:
            config (dict): Configuration dictionary containing model parameters.
                - ``model_path`` (required): directory containing
                  ``pytorch_model.pth``; it is also used as the model-config
                  name and the tokenizer path.
                - ``cfg_path`` (optional): UniMERNet YAML config, default
                  ``"pdf_extract_kit/configs/unimernet.yaml"``.
                - ``batch_size`` (optional, default 1): stored on the instance.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = config['model_path']
        self.cfg_path = config.get('cfg_path', "pdf_extract_kit/configs/unimernet.yaml")
        # NOTE(review): batch_size is not consumed by predict(), which handles
        # one image at a time.
        self.batch_size = config.get('batch_size', 1)
        # Load the UniMERNet model and its evaluation image processor.
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the UniMERNet model and its evaluation visual processor.

        Returns:
            tuple: ``(model, vis_processor)``. The model has been moved to
            ``self.device`` and put in evaluation mode.

        Raises:
            Exception: any error raised while parsing the config or building
            the model/processor is logged and re-raised unchanged.
        """
        try:
            args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
            cfg = Config(args)
            # Point the config at the local checkpoint directory, overriding
            # whatever the YAML file specifies for these paths.
            cfg.config.model.pretrained = os.path.join(self.model_dir, "pytorch_model.pth")
            cfg.config.model.model_config.model_name = self.model_dir
            cfg.config.model.tokenizer_config.path = self.model_dir
            task = tasks.setup_task(cfg)
            model = task.build_model(cfg).to(self.device)
            # This wrapper only does inference: switch off dropout and other
            # train-time behaviour. Idempotent if build_model already did so.
            # (Assumes build_model returns a torch nn.Module, as .to() implies.)
            model.eval()
            vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
            return model, vis_processor
        except Exception as e:
            logging.error("Error loading model and processor: %s", e)
            raise

    def predict(self, images, result_path=None):
        """Run formula recognition on each image file and collect predictions.

        Args:
            images: an iterable of image file paths, or a single path
                (``str`` / ``os.PathLike``), which is treated as a one-item list.
            result_path: accepted for interface compatibility; currently unused.

        Returns:
            list[str]: ``output["pred_str"][0]`` for every image that was read
            and processed successfully. Images that cannot be opened, or whose
            processing raises, are logged and skipped, so the returned list may
            be shorter than ``images`` and its indices need not line up with
            the input order of failed items.
        """
        # A bare path string would otherwise be iterated character by character.
        if isinstance(images, (str, os.PathLike)):
            images = [images]

        results = []
        # Inference only: no autograd graph (and retained activations) needed.
        with torch.no_grad():
            for image_path in images:
                # Read the image using OpenCV (returns BGR, or None on failure).
                open_cv_image = cv2.imread(os.fspath(image_path))
                if open_cv_image is None:
                    logging.error("Error: Unable to open image at %s", image_path)
                    continue
                try:
                    # BGR -> RGB, wrapped as a PIL.Image for the visual processor.
                    raw_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
                    # Add a leading dimension (presumably a batch of 1) and move
                    # the tensor to the model's device.
                    image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
                    output = self.model.generate({"image": image})
                    pred = output["pred_str"][0]
                except Exception as e:
                    # Best-effort per image: keep the traceback, skip this one.
                    logging.exception("Error processing image %s: %s", image_path, e)
                    continue
                logging.info("Prediction for %s:\n%s", image_path, pred)
                results.append(pred)
        return results