# tableocr / model.py
# Last commit by ahmedzein: "Update model.py" (8d05638, verified), 2.36 kB
import torch
from torchvision import transforms
from transformers import AutoModelForObjectDetection
from transformers import TableTransformerForObjectDetection
import easyocr
import pandas as pd
from helper import *
# Run both models on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ImageNet normalization statistics, shared by both preprocessing pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Stage 1: table detection model (locates tables on the page image).
model = AutoModelForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection", revision="no_timm"
)
model.to(device)

# Padding handed to objects_to_crops when cutting detected tables out of the
# page (presumably in pixels -- confirm in helper.objects_to_crops).
crop_padding = 5

# Stage 2: table structure recognition model (rows/columns inside a table).
structure_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-structure-recognition-v1.1-all"
)
structure_model.to(device)

# Preprocessing for the cropped table fed to structure_model.
structure_transform = transforms.Compose([
    MaxResize(1000),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Per-class score thresholds used when cropping detections. The 1000 for
# "no object" is presumably chosen to exceed any achievable score so those
# boxes are never kept -- confirm against helper.objects_to_crops.
detection_class_thresholds = {
    "table": 0.5,
    "table rotated": 0.5,
    "no object": 1000,
}

# Preprocessing for the full page image fed to the detection model.
detection_transform = transforms.Compose([
    MaxResize(800),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Detection label map extended so index len(labels) maps to "no object".
# Built as a copy so the model's own config is not mutated in place.
id2label = dict(model.config.id2label)
id2label[len(id2label)] = "no object"
# EasyOCR reader shared by every inference() call. Loading the OCR model into
# memory only needs to happen once, so it is created lazily on first use.
_ocr_reader = None


def _get_ocr_reader():
    """Return the shared English EasyOCR reader, creating it on first call."""
    global _ocr_reader
    if _ocr_reader is None:
        _ocr_reader = easyocr.Reader(['en'])
    return _ocr_reader


def inference(image):
    """Detect a table in *image*, recognize its cells and OCR them.

    Args:
        image: page image; must provide ``.size`` and ``.convert()``
            (assumed to be a PIL image -- TODO confirm with callers).

    Returns:
        pandas.DataFrame built from the OCR'd cell data (transposed), or an
        empty DataFrame when no table is detected. Only the first crop
        returned by ``objects_to_crops`` is processed.
    """
    print(f"{GREEN}>>> inference started{RESET}")

    # Both transforms normalize with 3-channel statistics, so force RGB;
    # grayscale/RGBA input would otherwise not match. No-op for RGB images.
    image = image.convert("RGB")

    # --- Stage 1: detect tables on the full page ---------------------------
    pixel_values = detection_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(pixel_values)
    objects = outputs_to_objects(outputs, image.size, id2label)

    # No pre-extracted word tokens are available; pass an empty list.
    tokens = []
    tables_crops = objects_to_crops(
        image, tokens, objects, detection_class_thresholds, padding=crop_padding
    )
    if not tables_crops:
        return pd.DataFrame()

    # --- Stage 2: recognize the structure of the first detected table -----
    cropped_table = tables_crops[0]['image'].convert("RGB")
    pixel_values = structure_transform(cropped_table).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = structure_model(pixel_values)

    # Copy the label map before appending "no object": mutating the model's
    # config directly would add another entry on every call.
    structure_id2label = dict(structure_model.config.id2label)
    structure_id2label[len(structure_id2label)] = "no object"
    cells = outputs_to_objects(outputs, cropped_table.size, structure_id2label)

    # --- Stage 3: OCR each cell and assemble the table ---------------------
    cell_coordinates = get_cell_coordinates_by_row(cells)
    data = apply_ocr(cell_coordinates, cropped_table, _get_ocr_reader())
    df = pd.DataFrame(data).T
    return df