# Spaces: Sleeping
import base64
import os
from io import BytesIO

import gradio as gr
from mistralai import Mistral, TextChunk, ImageURLChunk
from PIL import Image
def _decode_base64_image(encoded):
    """Decode a base64 string (optionally a ``data:`` URI) into a PIL image."""
    # A data URI looks like "data:image/png;base64,<payload>"; only the part
    # after the first comma is base64. Plain payloads contain no comma.
    _, sep, payload = encoded.partition(",")
    raw = base64.b64decode(payload if sep else encoded)
    return Image.open(BytesIO(raw))


def process_msg(client, target):
    """Run OCR on *target* and build the chat prompt for medication extraction.

    Args:
        client: Object exposing ``ocr.process`` (the Mistral client created
            by the caller).
        target: URL handed to the OCR endpoint as an image URL.

    Returns:
        A ``(messages, images)`` tuple. ``messages`` is a one-element list
        holding a user message with the image chunk and the text prompt.
        ``images`` is a list of PIL images decoded from the images embedded
        in the first OCR page; it is empty when the OCR result carries no
        pages or no images.
    """
    image_response = client.ocr.process(
        document=ImageURLChunk(image_url=target),
        model="mistral-ocr-latest",
        include_image_base64=True,
    )

    # Guard against an empty OCR result instead of indexing blindly.
    pages = image_response.pages
    first_page = pages[0] if pages else None
    processed_output = first_page.markdown if first_page is not None else ""

    # Decode every embedded image of the first page, skipping entries that
    # came back without base64 data.
    page_images = (first_page.images or []) if first_page is not None else []
    images = [
        _decode_base64_image(entry.image_base64)
        for entry in page_images
        if entry.image_base64
    ]

    # Adjacent literals are concatenated verbatim, so each fragment carries
    # its own trailing space to keep the prompt readable.
    prompt = (
        f"This is the image's OCR in markdown:\n{processed_output}\n.\n"
        "Convert this into a structured JSON response "
        "with the medication details only, including name, dosage, frequency, "
        "prescriber name, phone number and ID. "
        "return in json message only"
    )
    messages = [
        {
            "role": "user",
            "content": [
                ImageURLChunk(image_url=target),
                TextChunk(text=prompt),
            ],
        }
    ]
    return messages, images
def chat_response(client, model, messages):
    """Request a chat completion and hand back the raw response.

    The arguments are forwarded unchanged to ``client.chat.complete`` as the
    ``model`` and ``messages`` keyword arguments.
    """
    return client.chat.complete(model=model, messages=messages)
# Configuration.
# URL suffixes treated as images. No document extensions are listed: do_ocr
# treats any URL that does not end in one of these suffixes as a document.
# Frozen so the module-level constant cannot be mutated by accident.
VALID_IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png"})
def do_ocr(input_type, url=None):
    """Run OCR plus structured extraction for the URL entered in the UI.

    Args:
        input_type: Selected input mode; only ``"URL"`` is handled.
        url: Address of the image or document to process.

    Returns:
        A ``(text, images)`` pair on every path: ``text`` is either the chat
        model's reply or a human-readable error message, and ``images`` is
        the list returned by ``process_msg`` (empty on error).
    """
    # Validate cheap inputs first so bad requests never reach the API.
    if input_type != "URL":
        return "Invalid input type ", []

    url = (url or "").strip()
    if not url:
        return "Please provide a valid URL.", []

    # Read the key from the environment so the secret never lives in source.
    # NOTE(review): any key previously committed to this file should be rotated.
    api_key = os.environ.get("MISTRAL_API_KEY")
    if not api_key:
        return "MISTRAL_API_KEY environment variable is not set.", []

    model = "mistral-small-latest"
    client = Mistral(api_key=api_key)

    # Classify on the stripped URL so trailing whitespace cannot hide the
    # extension. str.endswith accepts a tuple of suffixes.
    if url.lower().endswith(tuple(VALID_IMAGE_EXTENSIONS)):
        document_source = {"type": "image_url", "image_url": url}
    else:
        document_source = {"type": "document_url", "document_url": url}

    # process_msg only receives the URL itself, looked up via the "type" key.
    messages, images = process_msg(client, document_source[document_source["type"]])
    response = chat_response(client, model, messages)
    return response.choices[0].message.content, images
# Stylesheet passed to gr.Blocks(css=...). Each rule is kept on one line.
custom_css = """
body {font-family: 'Helvetica Neue', Helvetica;}
.gr-button {background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px;}
.gr-button:hover {background-color: #45a049;}
.gr-textbox {margin-bottom: 15px;}
.example-button {background-color: #1E90FF; color: white; border: none; padding: 8px 15px; border-radius: 5px; margin: 5px;}
.example-button:hover {background-color: #FF4500;}
.tall-radio .gr-radio-item {padding: 15px 0; min-height: 50px; display: flex; align-items: center;}
.tall-radio label {font-size: 16px;}
"""
# Build the UI: inputs and example button on the left, results on the right.
with gr.Blocks(
    title="Tech Demo",
    css=custom_css,
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("<h1 style='text-align: center; color: #333;'>Extract text from images using AI OCR model</h1>")

    with gr.Row():
        with gr.Column(scale=1):
            input_type = gr.Radio(
                choices=["URL", "Upload file"],
                label="Input Type",
                value="URL",
                elem_classes="tall-radio",
            )
            url_input = gr.Textbox(
                label="Document or Image URL",
                placeholder="e.g., https://arxiv.org/pdf/2501.12948",
                visible=True,
                lines=1,
            )
            submit_btn = gr.Button("Extract Text and Images")
            gr.Markdown("### Try These Examples")
            img_example = gr.Button("Image", elem_classes="example-button")

        with gr.Column(scale=2):
            cleaned_output = gr.Textbox(label="Extracted Plain Text", lines=10, show_copy_button=True)
            image_output = gr.Gallery(label="OCR Extracted Images", columns=10, height="auto")

    def update_visibility(choice):
        """Show the URL box only while the "URL" input type is selected."""
        # Exactly one update is returned because the listener below declares
        # a single output component.
        return gr.update(visible=(choice == "URL"))

    input_type.change(fn=update_visibility, inputs=input_type, outputs=[url_input])

    def set_url_and_type(url):
        """Return values that fill the URL box and select the "URL" mode."""
        return url, "URL"

    img_example.click(
        fn=lambda: set_url_and_type("https://everythingmedschool.com/wp-content/uploads/2022/09/sample-rx2-sm.jpg"),
        outputs=[url_input, input_type],
    )

    # do_ocr's second return value is the image list, so it goes to the
    # gallery rather than back into the URL textbox.
    submit_btn.click(
        fn=do_ocr,
        inputs=[input_type, url_input],
        outputs=[cleaned_output, image_output],
    )

if __name__ == "__main__":
    demo.launch()