import json
import copy
from PIL import Image
from pypdf import PdfReader
from vllm import LLM, SamplingParams
from ocrflux.image_utils import get_page_image
from ocrflux.table_format import table_matrix2html
from ocrflux.prompts import PageResponse, build_page_to_markdown_prompt, build_element_merge_detect_prompt, build_html_table_merge_prompt
def build_qwen2_5_vl_prompt(question):
    """Wrap ``question`` in the chat-style prompt string used by this module.

    The prompt consists of a fixed system turn, a user turn containing a
    single image placeholder followed by ``question``, and an opened
    assistant turn left for the model to complete.
    """
    system_turn = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    # Exactly one image placeholder is emitted ahead of the question text.
    image_slot = "<|vision_start|><|image_pad|><|vision_end|>"
    user_turn = f"<|im_start|>user\n{image_slot}{question}<|im_end|>\n"
    assistant_opening = "<|im_start|>assistant\n"
    return system_turn + user_turn + assistant_opening
def build_page_to_markdown_query(file_path: str, page_number: int, target_longest_image_dim: int = 1024, image_rotation: int = 0) -> dict:
    """Build the generation request that converts one page to markdown.

    Args:
        file_path: Path of the document, forwarded to ``get_page_image``.
        page_number: Page to render, forwarded to ``get_page_image``.
        target_longest_image_dim: Forwarded to ``get_page_image``.
        image_rotation: Rotation in degrees; must be 0, 90, 180 or 270.

    Returns:
        A dict with the text ``"prompt"`` and a ``"multi_modal_data"`` entry
        holding the rendered page image under ``"image"``.

    Raises:
        AssertionError: If ``image_rotation`` is not one of the allowed values.
    """
    assert image_rotation in (0, 90, 180, 270), "Invalid image rotation provided in build_page_query"
    page_image = get_page_image(
        file_path,
        page_number,
        target_longest_image_dim=target_longest_image_dim,
        image_rotation=image_rotation,
    )
    return {
        "prompt": build_qwen2_5_vl_prompt(build_page_to_markdown_prompt()),
        "multi_modal_data": {"image": page_image},
    }
def build_element_merge_detect_query(text_list_1,text_list_2) -> dict:
    """Build the generation request for the element-merge detection prompt.

    The question text comes from ``build_element_merge_detect_prompt`` applied
    to the two text lists. The returned dict has the text ``"prompt"`` and a
    ``"multi_modal_data"`` entry carrying a placeholder image.
    """
    # The shared prompt template always contains an image slot, so a tiny
    # black 28x28 image is supplied to fill it.
    placeholder = Image.new('RGB', (28, 28), color='black')
    merge_question = build_element_merge_detect_prompt(text_list_1,text_list_2)
    return {
        "prompt": build_qwen2_5_vl_prompt(merge_question),
        "multi_modal_data": {"image": placeholder},
    }
def build_html_table_merge_query(text_1,text_2) -> dict:
    """Build the generation request for the HTML-table merge prompt.

    The question text comes from ``build_html_table_merge_prompt`` applied to
    the two inputs. The returned dict has the text ``"prompt"`` and a
    ``"multi_modal_data"`` entry carrying a placeholder image.
    """
    # The shared prompt template always contains an image slot, so a tiny
    # black 28x28 image is supplied to fill it.
    placeholder = Image.new('RGB', (28, 28), color='black')
    merge_question = build_html_table_merge_prompt(text_1,text_2)
    return {
        "prompt": build_qwen2_5_vl_prompt(merge_question),
        "multi_modal_data": {"image": placeholder},
    }
def bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result):
    """Assemble the final document text from per-page elements and merges.

    The (misspelled) function name is kept as-is for existing callers.

    Args:
        page_to_markdown_result: Dict mapping a numeric page key to the list
            of element strings for that page. NOTE: the lists are modified in
            place -- merged text is written into the earlier element and the
            later element is blanked.
        element_merge_detect_result: Dict keyed by ``(page_1, page_2)`` whose
            values are iterables of ``(elem_idx_1, elem_idx_2)`` pairs; the
            element at ``elem_idx_2`` on ``page_2`` is appended to the element
            at ``elem_idx_1`` on ``page_1``.
        html_table_merge_result: Dict keyed by
            ``(page_1, page_2, elem_idx_1, elem_idx_2)`` whose value replaces
            the element at ``elem_idx_1`` on ``page_1``; the element at
            ``elem_idx_2`` on ``page_2`` is blanked.

    Returns:
        All non-empty elements joined with blank lines, in ascending page
        order regardless of the insertion order of ``page_to_markdown_result``.
    """
    # Process merges from the last page backwards so that a chain spanning
    # several pages (1<-2, 2<-3) folds the tail in before the head absorbs it.
    for merge_key in sorted(html_table_merge_result, key=lambda k: -k[0]):
        page_1, page_2, elem_idx_1, elem_idx_2 = merge_key
        page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[merge_key]
        page_to_markdown_result[page_2][elem_idx_2] = ''

    for merge_key in sorted(element_merge_detect_result, key=lambda k: -k[0]):
        page_1, page_2 = merge_key
        for elem_idx_1, elem_idx_2 in element_merge_detect_result[merge_key]:
            head = page_to_markdown_result[page_1][elem_idx_1]
            tail = page_to_markdown_result[page_2][elem_idx_2]
            # No separator when the head is empty, ends with a hyphen, or
            # ends with a CJK ideograph (U+4E00..U+9FFF); otherwise one space.
            joins_directly = (
                not head
                or head[-1] == '-'
                or '\u4e00' <= head[-1] <= '\u9fff'
            )
            separator = '' if joins_directly else ' '
            page_to_markdown_result[page_1][elem_idx_1] = head + separator + tail
            page_to_markdown_result[page_2][elem_idx_2] = ''

    # Emit pages in page order; dict insertion order can differ from page
    # order when pages were filled in late (e.g. after a retry).
    document_text_list = []
    for page in sorted(page_to_markdown_result):
        document_text_list.extend(s for s in page_to_markdown_result[page] if s)
    return "\n\n".join(document_text_list)
def parse(llm,file_path,skip_cross_page_merge=False,max_page_retries=0):
sampling_params = SamplingParams(temperature=0.0,max_tokens=8192)
if file_path.lower().endswith(".pdf"):
try:
reader = PdfReader(file_path)
num_pages = reader.get_num_pages()
except:
return None
else:
num_pages = 1
try:
# Stage 1: Page to Markdown
page_to_markdown_query_list = [build_page_to_markdown_query(file_path,page_num) for page_num in range(1, num_pages + 1)]
responses = llm.generate(page_to_markdown_query_list, sampling_params=sampling_params)
results = [response.outputs[0].text for response in responses]
page_to_markdown_result = {}
retry_list = []
for i,result in enumerate(results):
try:
json_data = json.loads(result)
page_response = PageResponse(**json_data)
natural_text = page_response.natural_text
markdown_element_list = []
for text in natural_text.split('\n\n'):
if text.startswith("