ahmedzein committed
Commit c6a18bd · verified · 1 Parent(s): 46b4413

Upload 7 files

Files changed (7):
  1. .gitignore +7 -0
  2. README.md +2 -9
  3. helper.py +226 -0
  4. main.py +82 -0
  5. model.py +84 -0
  6. pdftoword.py +9 -0
  7. requirements.txt +18 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ *tmp*
+ __pycache__
+ node_modules
+ .env
+ *.pdf
+ *.jpg
+ *.docx
README.md CHANGED
@@ -1,10 +1,3 @@
- ---
- title: Tableocr
- emoji: 📚
- colorFrom: pink
- colorTo: purple
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # table_ocr
+
+ 🎉
helper.py ADDED
@@ -0,0 +1,226 @@
+ import io
+
+ import torch
+
+ from tqdm.auto import tqdm
+
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from matplotlib.patches import Patch
+
+ import numpy as np
+ from PIL import Image
+
+ # ANSI color codes for console logging
+ GREEN = "\033[92m"
+ RESET = "\033[0m"
+
+ class MaxResize(object):
+     """Resize a PIL image so its longer side equals max_size, preserving aspect ratio."""
+     def __init__(self, max_size=800):
+         self.max_size = max_size
+
+     def __call__(self, image):
+         width, height = image.size
+         current_max_size = max(width, height)
+         scale = self.max_size / current_max_size
+         resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
+         return resized_image
+
+ # for output bounding box post-processing
+ def box_cxcywh_to_xyxy(x):
+     x_c, y_c, w, h = x.unbind(-1)
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     return torch.stack(b, dim=1)
+
+ def rescale_bboxes(out_bbox, size):
+     img_w, img_h = size
+     b = box_cxcywh_to_xyxy(out_bbox)
+     b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+     return b
+
+ def outputs_to_objects(outputs, img_size, id2label):
+     """Convert raw DETR-style detection outputs into a list of labeled, scored boxes."""
+     m = outputs.logits.softmax(-1).max(-1)
+     pred_labels = list(m.indices.detach().cpu().numpy())[0]
+     pred_scores = list(m.values.detach().cpu().numpy())[0]
+     pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
+     pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
+
+     objects = []
+     for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
+         class_label = id2label[int(label)]
+         if class_label != 'no object':
+             objects.append({'label': class_label, 'score': float(score),
+                             'bbox': [float(elem) for elem in bbox]})
+
+     return objects
+
+ def fig2img(fig):
+     """Convert a Matplotlib figure to a PIL Image and return it."""
+     buf = io.BytesIO()
+     fig.savefig(buf)
+     buf.seek(0)
+     img = Image.open(buf)
+     return img
+
+ def visualize_detected_tables(img, det_tables, out_path=None):
+     plt.imshow(img, interpolation="lanczos")
+     fig = plt.gcf()
+     fig.set_size_inches(20, 20)
+     ax = plt.gca()
+
+     for det_table in det_tables:
+         bbox = det_table['bbox']
+
+         if det_table['label'] == 'table':
+             facecolor = (1, 0, 0.45)
+             edgecolor = (1, 0, 0.45)
+             alpha = 0.3
+             linewidth = 2
+             hatch = '//////'
+         elif det_table['label'] == 'table rotated':
+             facecolor = (0.95, 0.6, 0.1)
+             edgecolor = (0.95, 0.6, 0.1)
+             alpha = 0.3
+             linewidth = 2
+             hatch = '//////'
+         else:
+             continue
+
+         # draw each table three times: filled interior, solid outline, hatched overlay
+         rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
+                                  edgecolor='none', facecolor=facecolor, alpha=0.1)
+         ax.add_patch(rect)
+         rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
+                                  edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha)
+         ax.add_patch(rect)
+         rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=0,
+                                  edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2)
+         ax.add_patch(rect)
+
+     plt.xticks([], [])
+     plt.yticks([], [])
+
+     legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
+                              label='Table', hatch='//////', alpha=0.3),
+                        Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
+                              label='Table (rotated)', hatch='//////', alpha=0.3)]
+     plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
+                fontsize=10, ncol=2)
+     plt.gcf().set_size_inches(10, 10)
+     plt.axis('off')
+
+     if out_path is not None:
+         plt.savefig(out_path, bbox_inches='tight', dpi=150)
+
+     return fig
+
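+ # NOTE: objects_to_crops() below relies on iob(), which this upload neither
+ # defines nor imports, so any call with a non-empty token list would raise a
+ # NameError. The following is an editorial sketch (an assumption, not part of
+ # the original commit) of the usual "intersection over box" metric from the
+ # Table Transformer reference code: the fraction of bbox1's area covered by
+ # its intersection with bbox2.
+ def iob(bbox1, bbox2):
+     x1 = max(bbox1[0], bbox2[0])
+     y1 = max(bbox1[1], bbox2[1])
+     x2 = min(bbox1[2], bbox2[2])
+     y2 = min(bbox1[3], bbox2[3])
+     intersection = max(0, x2 - x1) * max(0, y2 - y1)
+     area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
+     return intersection / area1 if area1 > 0 else 0.0
+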
+ def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
+     """
+     Process the bounding boxes produced by the table detection model into
+     cropped table images and cropped tokens.
+     """
+     table_crops = []
+     for obj in objects:
+         if obj['score'] < class_thresholds[obj['label']]:
+             continue
+
+         cropped_table = {}
+
+         bbox = obj['bbox']
+         bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
+
+         cropped_img = img.crop(bbox)
+
+         # keep only tokens that lie (mostly) inside the table crop
+         table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
+         for token in table_tokens:
+             token['bbox'] = [token['bbox'][0] - bbox[0],
+                              token['bbox'][1] - bbox[1],
+                              token['bbox'][2] - bbox[0],
+                              token['bbox'][3] - bbox[1]]
+
+         # If table is predicted to be rotated, rotate cropped image and tokens/words:
+         if obj['label'] == 'table rotated':
+             cropped_img = cropped_img.rotate(270, expand=True)
+             for token in table_tokens:
+                 bbox = token['bbox']
+                 bbox = [cropped_img.size[0] - bbox[3] - 1,
+                         bbox[0],
+                         cropped_img.size[0] - bbox[1] - 1,
+                         bbox[2]]
+                 token['bbox'] = bbox
+
+         cropped_table['image'] = cropped_img
+         cropped_table['tokens'] = table_tokens
+
+         table_crops.append(cropped_table)
+
+     return table_crops
+
+ def get_cell_coordinates_by_row(table_data):
+     # Extract rows and columns
+     rows = [entry for entry in table_data if entry['label'] == 'table row']
+     columns = [entry for entry in table_data if entry['label'] == 'table column']
+
+     # Sort rows and columns by their Y and X coordinates, respectively
+     rows.sort(key=lambda x: x['bbox'][1])
+     columns.sort(key=lambda x: x['bbox'][0])
+
+     # A cell is the intersection of a row band and a column band
+     def find_cell_coordinates(row, column):
+         cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
+         return cell_bbox
+
+     # Generate cell coordinates and count cells in each row
+     cell_coordinates = []
+
+     for row in rows:
+         row_cells = []
+         for column in columns:
+             cell_bbox = find_cell_coordinates(row, column)
+             row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
+
+         # Sort cells in the row by X coordinate
+         row_cells.sort(key=lambda x: x['column'][0])
+
+         # Append row information to cell_coordinates
+         cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
+
+     # Sort rows from top to bottom
+     cell_coordinates.sort(key=lambda x: x['row'][1])
+
+     return cell_coordinates
+
+ def apply_ocr(cell_coordinates, cropped_table, reader):
+     # OCR the table row by row
+     data = dict()
+     max_num_columns = 0
+     for idx, row in enumerate(tqdm(cell_coordinates)):
+         row_text = []
+         for cell in row["cells"]:
+             # crop cell out of image
+             cell_image = np.array(cropped_table.crop(cell["cell"]))
+             # apply OCR
+             result = reader.readtext(np.array(cell_image))
+             if len(result) > 0:
+                 text = " ".join([x[1] for x in result])
+                 row_text.append(text)
+
+         if len(row_text) > max_num_columns:
+             max_num_columns = len(row_text)
+
+         data[idx] = row_text
+
+     # pad rows which don't have max_num_columns elements
+     # to make sure all rows have the same number of columns
+     for row, row_data in data.copy().items():
+         if len(row_data) != max_num_columns:
+             row_data = row_data + ["" for _ in range(max_num_columns - len(row_data))]
+         data[row] = row_data
+
+     return data
main.py ADDED
@@ -0,0 +1,82 @@
+ # import io
+ import os
+
+ from fastapi import FastAPI, File, HTTPException, UploadFile
+ from fastapi.responses import JSONResponse
+ from starlette.responses import FileResponse
+ from starlette.middleware.cors import CORSMiddleware
+
+ # from PIL import Image
+ from pdftoword import convertPDFtoWORD
+
+ # from model import inference
+
+ app = FastAPI()
+
+ origins = ["http://localhost:3000"]  # Frontend origin URL; unused while CORS below allows all origins
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allows all origins
+     allow_credentials=True,
+     allow_methods=["*"],  # Allows all methods
+     allow_headers=["*"],  # Allows all headers
+ )
+
+
+ @app.post("/upload")
+ async def extract_table_data(image: UploadFile = File(...)):
+     return "table ocr is disabled 😔"
+     # try:
+     #     # Read image data
+     #     image_data = await image.read()
+
+     #     # Open image in memory
+     #     image = Image.open(io.BytesIO(image_data))
+     #     rgb_img = image.convert("RGB")
+     #     rgb_img.save('output.jpg')
+     #     image = Image.open('output.jpg')
+
+     #     table_frame = inference(image)
+     #     if table_frame.empty:
+     #         return "<h2 style=\"color: darkslategrey;\">💡 the image has no tables 💡</h2>"
+
+     #     return table_frame.to_html(escape=True, border=1, index=False).replace('\n', '')
+
+     # except Exception as e:
+     #     # Handle and log exceptions appropriately
+     #     print(f"Error processing image: {e}")
+     #     raise HTTPException(status_code=500, detail="Internal server error")
+
+
+ @app.post("/convert")
+ async def convert_pdf(docxFile: UploadFile = File(...)):
+     uploaded_file = docxFile
+     try:
+         if not uploaded_file.content_type.startswith("application/pdf"):
+             raise HTTPException(415, detail="Unsupported file format. Please upload a PDF file.")
+
+         # Create uploads directory if it doesn't exist
+         os.makedirs("uploads", exist_ok=True)
+
+         # Save the uploaded file
+         pdf_file_path = os.path.join("uploads", uploaded_file.filename)
+         with open(pdf_file_path, "wb+") as file_object:
+             file_object.write(uploaded_file.file.read())
+
+         # Process the PDF
+         docx_path = convertPDFtoWORD(pdf_file_path)
+
+         # Remove the uploaded PDF
+         os.unlink(pdf_file_path)
+
+         return FileResponse(docx_path,
+                             media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+                             filename="converted_document.docx")
+
+     except HTTPException:
+         # Re-raise the 415 above instead of masking it as a 500
+         raise
+     except FileNotFoundError:
+         # Handle case where conversion fails (e.g., missing converter)
+         return JSONResponse({"error": "Conversion failed. Please check the converter or file."}, status_code=500)
+     except Exception as e:
+         # Catch any unexpected errors
+         return JSONResponse({"error": f"An unexpected error occurred: {str(e)}"}, status_code=500)
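A quick client-side check of the /convert endpoint (a sketch: it assumes the app is running locally via `uvicorn main:app` on port 8000 and that a sample.pdf exists; the multipart field must be named "docxFile" to match the endpoint signature):

    import requests

    # POST a PDF to /convert and save the returned .docx
    with open("sample.pdf", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/convert",
            files={"docxFile": ("sample.pdf", f, "application/pdf")},
        )
    resp.raise_for_status()
    with open("converted_document.docx", "wb") as out:
        out.write(resp.content)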
model.py ADDED
@@ -0,0 +1,84 @@
+ import torch
+ from torchvision import transforms
+
+ from transformers import AutoModelForObjectDetection
+ from transformers import TableTransformerForObjectDetection
+
+ import easyocr
+ import pandas as pd
+
+ from helper import *
+
+ reader = easyocr.Reader(['en'])  # this needs to run only once to load the model into memory
+
+ # table detection model
+ model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ crop_padding = 5
+
+ # table structure recognition model
+ structure_model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-structure-recognition-v1.1-all")
+ structure_model.to(device)
+ structure_transform = transforms.Compose([
+     MaxResize(1000),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ detection_class_thresholds = {
+     "table": 0.5,
+     "table rotated": 0.5,
+     "no object": 1000
+ }
+
+ detection_transform = transforms.Compose([
+     MaxResize(800),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+ # update id2label to include "no object"
+ id2label = model.config.id2label
+ id2label[len(model.config.id2label)] = "no object"
+
+
+ def inference(image):
+     print(f"{GREEN}>>> inference started{RESET}")
+
+     # 1. detect tables in the page image
+     pixel_values = detection_transform(image).unsqueeze(0)
+     pixel_values = pixel_values.to(device)
+
+     with torch.no_grad():
+         outputs = model(pixel_values)
+
+     objects = outputs_to_objects(outputs, image.size, id2label)
+
+     # 2. crop the first detected table (no OCR tokens are passed in here)
+     tokens = []
+     tables_crops = objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)
+     if len(tables_crops) == 0:
+         return pd.DataFrame()
+     cropped_table = tables_crops[0]['image'].convert("RGB")
+
+     # 3. recognize the table structure (rows and columns)
+     pixel_values = structure_transform(cropped_table).unsqueeze(0)
+     pixel_values = pixel_values.to(device)
+
+     with torch.no_grad():
+         outputs = structure_model(pixel_values)
+
+     structure_id2label = structure_model.config.id2label
+     structure_id2label[len(structure_id2label)] = "no object"
+
+     cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)
+
+     # 4. OCR each cell and assemble the result as a DataFrame
+     cell_coordinates = get_cell_coordinates_by_row(cells)
+     data = apply_ocr(cell_coordinates, cropped_table, reader)
+     df = pd.DataFrame(data).T
+
+     return df
+
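A minimal local smoke test for the OCR pipeline (a sketch: table_page.jpg is a hypothetical sample image, and importing model downloads the detection and structure checkpoints on first run):

    from PIL import Image

    from model import inference

    image = Image.open("table_page.jpg").convert("RGB")
    df = inference(image)  # returns an empty DataFrame when no table is detected
    print(df)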
pdftoword.py ADDED
@@ -0,0 +1,9 @@
+ # Import the required modules
+ import os
+
+ from pdf2docx import Converter
+
+ def convertPDFtoWORD(pdfpath: str) -> str:
+     # Write the .docx next to the source PDF, swapping the extension
+     docx_file = os.path.splitext(pdfpath)[0] + '.docx'
+     converter = Converter(pdfpath)
+     converter.convert(docx_file)
+     converter.close()
+     return docx_file
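Usage is a single call; the .docx is written next to the input PDF (the path below is illustrative):

    from pdftoword import convertPDFtoWORD

    docx_path = convertPDFtoWORD("uploads/report.pdf")
    print(docx_path)  # uploads/report.docx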
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ pdf2docx
+
+ fastapi
+ uvicorn[standard]
+
+ # OCR stack, currently disabled along with the /upload endpoint (see main.py):
+ # matplotlib
+ # Pillow
+ # pandas
+ # joblib
+ # scipy
+ # numpy
+ # easyocr
+ # tqdm
+ # torch
+ # torchvision
+ # transformers