File size: 2,967 Bytes
18e4106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import logging
import argparse

import cv2
import torch
import numpy as np
from PIL import Image
import unimernet.tasks as tasks
from unimernet.common.config import Config
from unimernet.processors import load_processor

from pdf_extract_kit.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register('formula_recognition_unimernet')
class FormulaRecognitionUniMERNet:
    """Formula recognition model backed by UniMERNet.

    Turns images of formulas into the prediction strings produced by the
    UniMERNet model's ``generate`` method (``output["pred_str"]``).
    """

    def __init__(self, config):
        """
        Initialize the FormulaRecognitionUniMERNet class.

        Args:
            config (dict): Configuration dictionary. Keys read here:
                - ``model_path`` (required): directory used for the weights
                  file ``pytorch_model.pth``, as the model name, and as the
                  tokenizer path.
                - ``cfg_path`` (optional): UniMERNet YAML config; defaults to
                  ``"pdf_extract_kit/configs/unimernet.yaml"``.
                - ``batch_size`` (optional, default 1): stored on the instance;
                  note that ``predict`` currently processes images one at a time.

        Raises:
            KeyError: If ``model_path`` is missing from ``config``.
        """
        # Prefer the GPU when one is available; otherwise run on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_dir = config['model_path']
        self.cfg_path = config.get('cfg_path', "pdf_extract_kit/configs/unimernet.yaml")
        self.batch_size = config.get('batch_size', 1)

        # Load the UniMERNet model and its evaluation-time image processor.
        self.model, self.vis_processor = self.load_model_and_processor()

    def load_model_and_processor(self):
        """Build the UniMERNet model and its image processor.

        The YAML config at ``self.cfg_path`` is loaded, and its model paths are
        overridden so the weights, model name and tokenizer all point at
        ``self.model_dir``.

        Returns:
            tuple: ``(model, vis_processor)``. ``model`` is moved to
            ``self.device`` and switched to evaluation mode. ``vis_processor``
            is the ``formula_image_eval`` processor built from the dataset's
            eval settings.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        try:
            # Config is constructed from an argparse-style namespace.
            args = argparse.Namespace(cfg_path=self.cfg_path, options=None)
            cfg = Config(args)
            cfg.config.model.pretrained = os.path.join(self.model_dir, "pytorch_model.pth")
            cfg.config.model.model_config.model_name = self.model_dir
            cfg.config.model.tokenizer_config.path = self.model_dir
            task = tasks.setup_task(cfg)
            model = task.build_model(cfg).to(self.device)
            # This class only runs inference: switch explicitly to eval mode so
            # training-time layers (e.g. dropout) are disabled.
            model.eval()
            vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
            return model, vis_processor
        except Exception as e:
            logging.error(f"Error loading model and processor: {e}")
            raise

    def predict(self, images, result_path):
        """Recognize the formula in each image file.

        Args:
            images (Iterable[str]): Paths of the images to recognize.
            result_path: Accepted but not used by this implementation.

        Returns:
            list[str]: One prediction string per image that was read and
            processed successfully, in input order. Unreadable images, and
            images that fail during preprocessing or inference, are logged and
            skipped. The result list may therefore be shorter than ``images``,
            and its indices need not match the input indices.
        """
        results = []
        # Inference only: no autograd graph is needed, which saves memory/time.
        with torch.no_grad():
            for image_path in images:
                # cv2.imread signals failure by returning None rather than raising.
                open_cv_image = cv2.imread(image_path)
                if open_cv_image is None:
                    logging.error(f"Error: Unable to open image at {image_path}")
                    continue

                try:
                    # Convert OpenCV's BGR array into an RGB PIL image for the processor.
                    raw_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
                    # Preprocess, add a batch dimension of size 1, and move to the model device.
                    image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
                    output = self.model.generate({"image": image})
                    pred = output["pred_str"][0]
                except Exception as e:
                    # Best effort: keep the traceback in the log and move on to the next image.
                    logging.exception(f"Error processing image {image_path}: {e}")
                    continue

                logging.info(f'Prediction for {image_path}:\n{pred}')
                results.append(pred)

        return results