OCRFlux / ocrflux /inference.py
mirnaresearch's picture
Initial commit for HF Space (no images)
ca5b08e
import json
import copy
from PIL import Image
from pypdf import PdfReader
from vllm import LLM, SamplingParams
from ocrflux.image_utils import get_page_image
from ocrflux.table_format import table_matrix2html
from ocrflux.prompts import PageResponse, build_page_to_markdown_prompt, build_element_merge_detect_prompt, build_html_table_merge_prompt
def build_qwen2_5_vl_prompt(question):
return (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
def build_page_to_markdown_query(file_path: str, page_number: int, target_longest_image_dim: int = 1024, image_rotation: int = 0) -> dict:
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
image = get_page_image(file_path, page_number, target_longest_image_dim=target_longest_image_dim, image_rotation=image_rotation)
question = build_page_to_markdown_prompt()
prompt = build_qwen2_5_vl_prompt(question)
query = {
"prompt": prompt,
"multi_modal_data": {"image": image},
}
return query
def build_element_merge_detect_query(text_list_1,text_list_2) -> dict:
image = Image.new('RGB', (28, 28), color='black')
question = build_element_merge_detect_prompt(text_list_1,text_list_2)
prompt = build_qwen2_5_vl_prompt(question)
query = {
"prompt": prompt,
"multi_modal_data": {"image": image},
}
return query
def build_html_table_merge_query(text_1,text_2) -> dict:
image = Image.new('RGB', (28, 28), color='black')
question = build_html_table_merge_prompt(text_1,text_2)
prompt = build_qwen2_5_vl_prompt(question)
query = {
"prompt": prompt,
"multi_modal_data": {"image": image},
}
return query
def bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result):
page_to_markdown_keys = list(page_to_markdown_result.keys())
element_merge_detect_keys = list(element_merge_detect_result.keys())
html_table_merge_keys = list(html_table_merge_result.keys())
for page_1,page_2,elem_idx_1,elem_idx_2 in sorted(html_table_merge_keys,key=lambda x: -x[0]):
page_to_markdown_result[page_1][elem_idx_1] = html_table_merge_result[(page_1,page_2,elem_idx_1,elem_idx_2)]
page_to_markdown_result[page_2][elem_idx_2] = ''
for page_1,page_2 in sorted(element_merge_detect_keys,key=lambda x: -x[0]):
for elem_idx_1,elem_idx_2 in element_merge_detect_result[(page_1,page_2)]:
if len(page_to_markdown_result[page_1][elem_idx_1]) == 0 or page_to_markdown_result[page_1][elem_idx_1][-1] == '-' or ('\u4e00' <= page_to_markdown_result[page_1][elem_idx_1][-1] <= '\u9fff'):
page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + '' + page_to_markdown_result[page_2][elem_idx_2]
else:
page_to_markdown_result[page_1][elem_idx_1] = page_to_markdown_result[page_1][elem_idx_1] + ' ' + page_to_markdown_result[page_2][elem_idx_2]
page_to_markdown_result[page_2][elem_idx_2] = ''
document_text_list = []
for page in page_to_markdown_keys:
page_text_list = [s for s in page_to_markdown_result[page] if s]
document_text_list += page_text_list
return "\n\n".join(document_text_list)
def parse(llm,file_path,skip_cross_page_merge=False,max_page_retries=0):
sampling_params = SamplingParams(temperature=0.0,max_tokens=8192)
if file_path.lower().endswith(".pdf"):
try:
reader = PdfReader(file_path)
num_pages = reader.get_num_pages()
except:
return None
else:
num_pages = 1
try:
# Stage 1: Page to Markdown
page_to_markdown_query_list = [build_page_to_markdown_query(file_path,page_num) for page_num in range(1, num_pages + 1)]
responses = llm.generate(page_to_markdown_query_list, sampling_params=sampling_params)
results = [response.outputs[0].text for response in responses]
page_to_markdown_result = {}
retry_list = []
for i,result in enumerate(results):
try:
json_data = json.loads(result)
page_response = PageResponse(**json_data)
natural_text = page_response.natural_text
markdown_element_list = []
for text in natural_text.split('\n\n'):
if text.startswith("<Image>") and text.endswith("</Image>"):
pass
elif text.startswith("<table>") and text.endswith("</table>"):
try:
new_text = table_matrix2html(text)
except:
new_text = text.replace("<t>","").replace("<l>","").replace("<lt>","")
markdown_element_list.append(new_text)
else:
markdown_element_list.append(text)
page_to_markdown_result[i+1] = markdown_element_list
except:
retry_list.append(i)
attempt = 0
while len(retry_list) > 0 and attempt < max_page_retries:
retry_page_to_markdown_query_list = [build_page_to_markdown_query(file_path,page_num) for page_num in retry_list]
retry_sampling_params = SamplingParams(temperature=0.1*attempt, max_tokens=8192)
responses = llm.generate(retry_page_to_markdown_query_list, sampling_params=retry_sampling_params)
results = [response.outputs[0].text for response in responses]
next_retry_list = []
for i,result in zip(retry_list,results):
try:
json_data = json.loads(result)
page_response = PageResponse(**json_data)
natural_text = page_response.natural_text
markdown_element_list = []
for text in natural_text.split('\n\n'):
if text.startswith("<Image>") and text.endswith("</Image>"):
pass
elif text.startswith("<table>") and text.endswith("</table>"):
try:
new_text = table_matrix2html(text)
except:
new_text = text.replace("<t>","").replace("<l>","").replace("<lt>","")
markdown_element_list.append(new_text)
else:
markdown_element_list.append(text)
page_to_markdown_result[i+1] = markdown_element_list
except:
next_retry_list.append(i)
retry_list = next_retry_list
attempt += 1
page_texts = {}
fallback_pages = []
for page_number in range(1, num_pages+1):
if page_number not in page_to_markdown_result.keys():
fallback_pages.append(page_number-1)
else:
page_texts[str(page_number-1)] = "\n\n".join(page_to_markdown_result[page_number])
if skip_cross_page_merge:
document_text_list = []
for i in range(num_pages):
if i not in fallback_pages:
document_text_list.append(page_texts[str(i)])
document_text = "\n\n".join(document_text_list)
return {
"orig_path": file_path,
"num_pages": num_pages,
"document_text": document_text,
"page_texts": page_texts,
"fallback_pages": fallback_pages,
}
# Stage 2: Element Merge Detect
element_merge_detect_keys = []
element_merge_detect_query_list = []
for page_num in range(1,num_pages):
if page_num in page_to_markdown_result.keys() and page_num+1 in page_to_markdown_result.keys():
element_merge_detect_query_list.append(build_element_merge_detect_query(page_to_markdown_result[page_num],page_to_markdown_result[page_num+1]))
element_merge_detect_keys.append((page_num,page_num+1))
responses = llm.generate(element_merge_detect_query_list, sampling_params=sampling_params)
results = [response.outputs[0].text for response in responses]
element_merge_detect_result = {}
for key,result in zip(element_merge_detect_keys,results):
try:
element_merge_detect_result[key] = eval(result)
except:
pass
# Stage 3: HTML Table Merge
html_table_merge_keys = []
for key,result in element_merge_detect_result.items():
page_1,page_2 = key
for elem_idx_1,elem_idx_2 in result:
text_1 = page_to_markdown_result[page_1][elem_idx_1]
text_2 = page_to_markdown_result[page_2][elem_idx_2]
if text_1.startswith("<table>") and text_1.endswith("</table>") and text_2.startswith("<table>") and text_2.endswith("</table>"):
html_table_merge_keys.append((page_1,page_2,elem_idx_1,elem_idx_2))
html_table_merge_keys = sorted(html_table_merge_keys,key=lambda x: -x[0])
html_table_merge_result = {}
page_to_markdown_result_tmp = copy.deepcopy(page_to_markdown_result)
i = 0
while i < len(html_table_merge_keys):
tmp = set()
keys = []
while i < len(html_table_merge_keys):
page_1,page_2,elem_idx_1,elem_idx_2 = html_table_merge_keys[i]
if (page_2,elem_idx_2) in tmp:
break
tmp.add((page_1,elem_idx_1))
keys.append((page_1,page_2,elem_idx_1,elem_idx_2))
i += 1
html_table_merge_query_list = [build_html_table_merge_query(page_to_markdown_result_tmp[page_1][elem_idx_1],page_to_markdown_result_tmp[page_2][elem_idx_2]) for page_1,page_2,elem_idx_1,elem_idx_2 in keys]
responses = llm.generate(html_table_merge_query_list, sampling_params=sampling_params)
results = [response.outputs[0].text for response in responses]
for key,result in zip(keys,results):
if result.startswith("<table>") and result.endswith("</table>"):
html_table_merge_result[key] = result
page_to_markdown_result_tmp[page_1][elem_idx_1] = result
document_text = bulid_document_text(page_to_markdown_result, element_merge_detect_result, html_table_merge_result)
return {
"orig_path": file_path,
"num_pages": num_pages,
"document_text": document_text,
"page_texts": page_texts,
"fallback_pages": fallback_pages,
}
except:
return None
if __name__ == '__main__':
file_path = 'test.pdf'
llm = LLM(model="ChatDOC/OCRFlux-3B",gpu_memory_utilization=0.8,max_model_len=8192)
result = parse(llm,file_path,max_page_retries=4)
if result != None:
document_markdown = result['document_text']
print(document_markdown)
with open('test.md','w') as f:
f.write(document_markdown)
else:
print("Parse failed")