File size: 4,676 Bytes
18e4106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import json
import random
from PIL import Image, ImageDraw
from pdf_extract_kit.registry.registry import TASK_REGISTRY
from pdf_extract_kit.utils.data_preprocess import load_pdf
from pdf_extract_kit.tasks.base_task import BaseTask


@TASK_REGISTRY.register("ocr")
class OCRTask(BaseTask):
    """OCR task: text detection + recognition on image files and PDF pages.

    Wraps a model exposing ``predict(image)`` and adds processing of a file or
    directory, JSON export of the results and optional bbox visualization.
    """

    def __init__(self, model):
        """Initialize the task with the given model.

        Args:
            model: task model; must provide a ``predict`` method.
        """
        super().__init__(model)

    def predict_image(self, image):
        """Predict on one image, returning text detection and recognition results.

        Args:
            image: PIL.Image.Image. If ``model.predict`` supports other input
                types, remember to add the format conversion in ``model.predict``.

        Returns:
            List[dict]: list of text bboxes with their content.

        Return example:
            [
                {
                    "category_type": "text",
                    "poly": [
                        380.6792698635707,
                        159.85058512958923,
                        765.1419999999998,
                        159.85058512958923,
                        765.1419999999998,
                        192.51073013642917,
                        380.6792698635707,
                        192.51073013642917
                    ],
                    "text": "this is an example text",
                    "score": 0.97
                },
                ...
            ]
        """
        return self.model.predict(image)

    def prepare_input_files(self, input_path):
        """Collect the input files to process.

        Args:
            input_path: a single file path, or a directory whose direct children
                that are regular files (subdirectories are skipped) are processed.

        Returns:
            List[str]: file paths. Directory entries are sorted by name so the
            processing order, and thus the order of ``process`` results, is
            deterministic.
        """
        if os.path.isdir(input_path):
            return [
                os.path.join(input_path, fname)
                for fname in sorted(os.listdir(input_path))
                if os.path.isfile(os.path.join(input_path, fname))
            ]
        return [input_path]

    def process(self, input_path, save_dir=None, visualize=False):
        """Run OCR on every input file and optionally save the results.

        Args:
            input_path: file path or directory (see ``prepare_input_files``).
            save_dir: if given, JSON results (and visualizations when
                ``visualize`` is True) are written under this directory:
                PDF pages to ``<save_dir>/<name>/page_<i>.json`` / ``.jpg``,
                other files to ``<save_dir>/<name>.json`` / ``.png``.
            visualize: whether to also save images with the boxes drawn.

        Returns:
            list: one entry per input file, in ``prepare_input_files`` order.
            For a PDF the entry is a list of per-page results; for any other
            file it is the result of ``predict_image``.
        """
        res_list = []
        for fpath in self.prepare_input_files(input_path):
            # splitext copes with extensions of any length (".jpeg", ".tiff", none),
            # unlike chopping a fixed 4 characters off the file name.
            basename, ext = os.path.splitext(os.path.basename(fpath))
            if ext.lower() == ".pdf":
                res_list.append(self._process_pdf(fpath, basename, save_dir, visualize))
            else:
                res_list.append(self._process_image(fpath, basename, save_dir, visualize))
        return res_list

    def _process_pdf(self, fpath, basename, save_dir, visualize):
        """OCR every page of a PDF and return the list of per-page results."""
        page_dir = None
        if save_dir:
            page_dir = os.path.join(save_dir, basename)
            os.makedirs(page_dir, exist_ok=True)
        pdf_res = []
        for page_no, img in enumerate(load_pdf(fpath), start=1):
            page_res = self.predict_image(img)
            pdf_res.append(page_res)
            if page_dir:
                self.save_json_result(page_res, os.path.join(page_dir, f"page_{page_no}.json"))
                if visualize:
                    self.visualize_image(img, page_res, os.path.join(page_dir, f"page_{page_no}.jpg"))
        return pdf_res

    def _process_image(self, fpath, basename, save_dir, visualize):
        """OCR a single image file and return its result."""
        # Copy the pixels out so the file handle is closed right away instead of
        # lingering until garbage collection (Image.open loads lazily).
        with Image.open(fpath) as opened:
            image = opened.copy()
        img_res = self.predict_image(image)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            self.save_json_result(img_res, os.path.join(save_dir, f"{basename}.json"))
            if visualize:
                self.visualize_image(image, img_res, os.path.join(save_dir, f"{basename}.png"))
        return img_res

    def visualize_image(self, image, ocr_res, save_path="", cate2color=None):
        """Draw each result's bounding box and category label on the image.

        Args:
            image: PIL.Image.Image. An RGB image is drawn on in place; any other
                mode (e.g. "L", "P", "RGBA") is first converted to an RGB copy so
                the RGB colors can be drawn and the result can be saved as JPEG.
            ocr_res: list of results in the format returned by ``predict_image``.
            save_path: path to save the visualized image; nothing is saved if empty.
            cate2color: optional mapping from ``category_type`` to an RGB box
                color; categories not in the mapping are drawn in green.

        Returns:
            PIL.Image.Image: the image that was drawn on.
        """
        if cate2color is None:
            cate2color = {}
        if image.mode != "RGB":
            image = image.convert("RGB")
        draw = ImageDraw.Draw(image)
        for res in ocr_res:
            category = res['category_type']
            box_color = cate2color.get(category, (0, 255, 0))
            # Take the extremes over all vertices: the polygon may be rotated or
            # not start at the top-left corner, and recent Pillow versions raise
            # when a rectangle's second corner precedes the first.
            xs = res['poly'][0::2]
            ys = res['poly'][1::2]
            x_min, y_min = int(min(xs)), int(min(ys))
            x_max, y_max = int(max(xs)), int(max(ys))
            draw.rectangle([x_min, y_min, x_max, y_max], fill=None, outline=box_color, width=1)
            draw.text((x_min, y_min), category, (255, 0, 0))
        if save_path:
            image.save(save_path)
        return image

    def save_json_result(self, ocr_res, save_path):
        """Save OCR results to a UTF-8 encoded JSON file.

        Args:
            ocr_res: list of results in the format returned by ``predict_image``.
            save_path: path of the JSON file to write (overwritten if it exists).
        """
        with open(save_path, "w", encoding="utf-8") as f:
            # ensure_ascii=False keeps recognized non-ASCII text human-readable.
            json.dump(ocr_res, f, indent=2, ensure_ascii=False)