import os
import logging
import argparse

import cv2
import torch
import numpy as np
from PIL import Image
import unimernet.tasks as tasks
from unimernet.common.config import Config
from unimernet.processors import load_processor

from pdf_extract_kit.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register('formula_recognition_unimernet')
class FormulaRecognitionUniMERNet:
    def __init__(self, config):
        """
        Initialize the FormulaRecognitionUniMERNet class.

        Args:
            config (dict): Configuration dictionary with the keys:
                - model_path (str): Path to the UniMERNet model directory (required).
                - cfg_path (str, optional): Path to the UniMERNet YAML config file.
                  Defaults to "pdf_extract_kit/configs/unimernet.yaml".
                - batch_size (int, optional): Inference batch size. Defaults to 1.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = config['model_path']
        self.cfg_path = config.get('cfg_path', "pdf_extract_kit/configs/unimernet.yaml")
        self.batch_size = config.get('batch_size', 1)

        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """
        Build the UniMERNet model and its image processor from the config file.

        Returns:
            tuple: (model, vis_processor), with the model moved to self.device.
        """
        try:
            args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
            cfg = Config(args)
            # Point the config at the local checkpoint, model, and tokenizer paths.
            cfg.config.model.pretrained = os.path.join(self.model_dir, "pytorch_model.pth")
            cfg.config.model.model_config.model_name = self.model_dir
            cfg.config.model.tokenizer_config.path = self.model_dir
            task = tasks.setup_task(cfg)
            model = task.build_model(cfg).to(self.device)
            vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
            return model, vis_processor
        except Exception as e:
            logging.error(f"Error loading model and processor: {e}")
            raise

    def predict(self, images, result_path):
        """
        Run formula recognition on a list of image files.

        Args:
            images (list): Paths to formula image files.
            result_path (str): Output path; accepted for interface consistency
                but not used inside this method.

        Returns:
            list: Predicted LaTeX strings, one per successfully processed image.
        """
        results = []
        for image_path in images:
            # Read with OpenCV, then convert BGR -> RGB for PIL.
            open_cv_image = cv2.imread(image_path)
            if open_cv_image is None:
                logging.error(f"Error: Unable to open image at {image_path}")
                continue

            raw_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))

            try:
                image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
                output = self.model.generate({"image": image})
                pred = output["pred_str"][0]
                logging.info(f'Prediction for {image_path}:\n{pred}')
                results.append(pred)
            except Exception as e:
                logging.error(f"Error processing image {image_path}: {e}")

        return results
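

# Minimal usage sketch (not part of the original module): shows how this class
# can be instantiated with a config dict and used to recognize formula crops,
# assuming a local UniMERNet checkpoint directory. All paths below are
# hypothetical placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    recognizer = FormulaRecognitionUniMERNet({
        "model_path": "models/unimernet",  # hypothetical checkpoint directory
        "cfg_path": "pdf_extract_kit/configs/unimernet.yaml",
        "batch_size": 1,
    })

    latex_predictions = recognizer.predict(
        images=["formula_crops/eq_001.png"],  # hypothetical input crop
        result_path="outputs/formulas",       # accepted but unused by predict()
    )
    for latex in latex_predictions:
        print(latex)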