# Hosting-page metadata captured with the file (not part of the program):
#   ddxorg's picture
#   add preview
#   f99b3ea
#   raw
#   history blame
#   4.92 kB
import os
import base64
from io import BytesIO
import gradio as gr
from mistralai import Mistral, TextChunk, ImageURLChunk
from PIL import Image
def process_msg(client, target):
    """Run OCR on an image URL and build the chat prompt for medication extraction.

    Args:
        client: Mistral client; must expose ``ocr.process``.
        target: URL of the image to send to the OCR model.

    Returns:
        A ``(messages, images)`` tuple. ``messages`` is a one-element list
        holding a user turn with the image and its OCR markdown, ready to be
        passed to a chat-completion call. ``images`` is a list of PIL images
        decoded from the OCR response; it is empty when the response carries
        no embedded images.
    """
    image_response = client.ocr.process(
        document=ImageURLChunk(image_url=target),
        model="mistral-ocr-latest",
        include_image_base64=True,
    )

    # Tolerate responses without pages instead of raising IndexError.
    pages = image_response.pages or []
    processed_output = "\n\n".join(page.markdown or "" for page in pages)

    # Collect every embedded image from every page; a response with no
    # embedded images simply yields an empty list.
    images = []
    for page in pages:
        for ocr_image in page.images or []:
            base64_str = ocr_image.image_base64
            if not base64_str:
                continue
            # Drop a "data:<mime>;base64," prefix when one is present.
            if "," in base64_str:
                base64_str = base64_str.split(",", 1)[1]
            img_bytes = base64.b64decode(base64_str)
            images.append(Image.open(BytesIO(img_bytes)))

    # Adjacent literals are concatenated verbatim, so each fragment carries
    # its own trailing space to keep the sentences separated.
    prompt = (
        f"This is the image's OCR in markdown:\n{processed_output}\n.\n"
        "Convert this into a structured JSON response "
        "with the medication details only, including name, dosage, frequency, "
        "prescriber name, phone number and ID. "
        "return in json message only"
    )
    messages = [
        {
            "role": "user",
            "content": [
                ImageURLChunk(image_url=target),
                TextChunk(text=prompt),
            ],
        }
    ]
    return messages, images
def chat_response(client, model, messages):
    """Forward ``messages`` to ``client.chat.complete`` and return its result.

    Args:
        client: Object exposing ``chat.complete(model=..., messages=...)``.
        model: Name of the chat model to use.
        messages: Chat payload to send.

    Returns:
        Whatever ``client.chat.complete`` returns, unmodified.
    """
    return client.chat.complete(model=model, messages=messages)
# --- Configuration ---
# Document (PDF) extensions are not enabled at the moment:
# VALID_DOCUMENT_EXTENSIONS = {".pdf"}

# Lower-case file extensions (leading dot included) recognised as images.
VALID_IMAGE_EXTENSIONS = {
    ".jpg",
    ".jpeg",
    ".png",
}
def do_ocr(input_type, url=None):
    """OCR the image at ``url`` and return the extracted medication JSON.

    Args:
        input_type: Selected input mode; only ``"URL"`` is handled.
        url: Address of the image to process.

    Returns:
        A ``(text, images)`` pair on every path, so the result always maps
        onto two output components. ``text`` is the chat model's reply, or a
        human-readable error message; ``images`` is the list of images
        extracted by OCR (empty on error).
    """
    # Validate cheap inputs first so no client is created for bad requests.
    if input_type != "URL":
        return "Invalid input type ", []
    if not url or not url.strip():
        return "Please provide a valid URL.", []

    # Secrets must not live in source control; read the key from the
    # environment instead.
    api_key = os.environ.get("MISTRAL_API_KEY")
    if not api_key:
        return "MISTRAL_API_KEY environment variable is not set.", []

    model = "mistral-small-latest"
    client = Mistral(api_key=api_key)

    # NOTE: image and non-image URLs used to be tagged differently here, but
    # both tags resolved to the same stripped URL, so the URL is passed on
    # directly.
    message, images = process_msg(client, url.strip())
    response = chat_response(client, model, message)
    return response.choices[0].message.content, images
# Stylesheet passed to the Gradio Blocks app.
# The `body` rule previously contained a duplicated "body {font-family:"
# fragment, which made the rule malformed CSS.
custom_css = """
body {font-family: 'Helvetica Neue', Helvetica;}
.gr-button {background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px;}
.gr-button:hover {background-color: #45a049;}
.gr-textbox {margin-bottom: 15px;}
.example-button {background-color: #1E90FF; color: white; border: none; padding: 8px 15px; border-radius: 5px; margin: 5px;}
.example-button:hover {background-color: #FF4500;}
.tall-radio .gr-radio-item {padding: 15px 0; min-height: 50px; display: flex; align-items: center;}
.tall-radio label {font-size: 16px;}
"""
# Gradio UI: a URL input column on the left, OCR results on the right.
with gr.Blocks(
    title="Tech Demo",
    css=custom_css,
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("<h1 style='text-align: center; color: #333;'>Extract text from images using AI OCR model</h1>")

    with gr.Row():
        with gr.Column(scale=1):
            input_type = gr.Radio(
                choices=["URL", "Upload file"],
                label="Input Type",
                value="URL",
                elem_classes="tall-radio",
            )
            url_input = gr.Textbox(
                label="Document or Image URL",
                placeholder="e.g., https://arxiv.org/pdf/2501.12948",
                visible=True,
                lines=1,
            )
            submit_btn = gr.Button("Extract Text and Images")
            gr.Markdown("### Try These Examples")
            img_example = gr.Button("Image", elem_classes="example-button")

        with gr.Column(scale=2):
            cleaned_output = gr.Textbox(label="Extracted Plain Text", lines=10, show_copy_button=True)
            image_output = gr.Gallery(label="OCR Extracted Images", columns=10, height="auto")

    def update_visibility(choice):
        """Show the URL textbox only while the "URL" input type is selected."""
        # Exactly one update is returned because the handler is wired to a
        # single output component (url_input).
        return gr.update(visible=(choice == "URL"))

    input_type.change(fn=update_visibility, inputs=input_type, outputs=[url_input])

    def set_url_and_type(url):
        """Return the values for the URL textbox and the input-type radio."""
        return url, "URL"

    img_example.click(
        fn=lambda: set_url_and_type("https://everythingmedschool.com/wp-content/uploads/2022/09/sample-rx2-sm.jpg"),
        outputs=[url_input, input_type],
    )

    # do_ocr's second return value is the list of extracted images, so it
    # belongs in the gallery rather than in the URL textbox.
    submit_btn.click(
        fn=do_ocr,
        inputs=[input_type, url_input],
        outputs=[cleaned_output, image_output],
    )

if __name__ == "__main__":
    demo.launch()