Spaces:
Sleeping
Sleeping
| import os | |
| import base64 | |
| from io import BytesIO | |
| import gradio as gr | |
| from mistralai import Mistral, TextChunk, ImageURLChunk | |
| from PIL import Image | |
| import requests | |
| from presidio_image_redactor import ImageRedactorEngine | |
# Chat model that turns the OCR markdown into structured JSON.
model = "mistral-small-latest"
# Mistral credentials are read from the environment (None when unset).
API_KEY = os.environ.get("MISTRAL_API_KEY")
# Extensions routed to the PDF upload flow vs. the inline-image flow.
VALID_DOCUMENT_EXTENSIONS = {".pdf"}
VALID_IMAGE_EXTENSIONS = {".png", ".jpeg", ".jpg"}
def upload_pdf(content, filename, client):
    """Upload a PDF to the Mistral file store and return a signed URL for it.

    Args:
        content: Raw bytes of the file to upload.
        filename: Name recorded for the uploaded file.
        client: A ``Mistral`` client; ``client.files`` performs the upload.

    Returns:
        The ``url`` attribute of the signed-URL object for the uploaded file.
    """
    payload = {"file_name": filename, "content": content}
    stored = client.files.upload(file=payload, purpose="ocr")
    return client.files.get_signed_url(file_id=stored.id).url
def encode_64(image_file):
    """Serialize a PIL image as JPEG and return it as a base64 text string.

    JPEG cannot store an alpha channel or a palette, so images in other modes
    (e.g. RGBA / LA / P, which PNG inputs commonly produce) are converted to
    RGB first. Without this, ``Image.save`` raises
    ``OSError: cannot write mode RGBA as JPEG``.

    Args:
        image_file: A PIL ``Image``. It is not modified (``convert`` returns a copy).

    Returns:
        The base64-encoded JPEG bytes as a UTF-8 ``str``, without a data-URL prefix.
    """
    if image_file.mode not in ("1", "L", "RGB", "CMYK"):
        image_file = image_file.convert("RGB")
    buffered = BytesIO()
    image_file.save(buffered, format="jpeg")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# Lazily-created redactor shared across requests (see _get_redactor).
_REDACTOR_ENGINE = None


def _get_redactor():
    """Return a shared ImageRedactorEngine, creating it on first use."""
    global _REDACTOR_ENGINE
    # Building the engine per request is wasteful; Presidio engines typically
    # load NLP models at construction time (NOTE(review): confirm cost). A race
    # between two first calls only builds a spare instance, which is harmless.
    if _REDACTOR_ENGINE is None:
        _REDACTOR_ENGINE = ImageRedactorEngine()
    return _REDACTOR_ENGINE


def _open_image_bytes(data):
    """Open raw bytes as a PIL image, rejecting PDF payloads with a clear error."""
    # PIL cannot decode PDFs; without this check the user sees an opaque
    # "cannot identify image file" error.
    if data.startswith(b"%PDF"):
        raise ValueError("PDF documents cannot be redacted; please provide a JPG or PNG image.")
    return Image.open(BytesIO(data))


def _fetch_image(url, timeout):
    """Download ``url`` (bounded by ``timeout`` seconds) and open it as an image."""
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors (404/403/...) instead of trying to decode an error page.
    response.raise_for_status()
    return _open_image_bytes(response.content)


def redact_imag(document_source, *, timeout=30):
    """Load the image described by ``document_source`` and redact PII from it.

    Args:
        document_source: dict whose ``"type"`` is ``"document_url"`` or
            ``"image_url"``; the URL is stored under the key of the same name.
            An ``image_url`` that starts with ``data:image/`` is base64-decoded
            locally; any other URL is downloaded.
        timeout: Seconds to wait for a remote download before giving up.

    Returns:
        The redacted image returned by ``ImageRedactorEngine.redact``.

    Raises:
        ValueError: for an unknown source type, or when the payload is a PDF.
        requests.RequestException: when a download fails, times out, or
            returns an HTTP error status.
    """
    source_type = document_source["type"]
    # Log only the type: an image_url can embed a multi-megabyte base64 payload.
    print(f"Redacting image from source type: {source_type}")
    if source_type == "document_url":
        prep_img = _fetch_image(document_source["document_url"], timeout)
    elif source_type == "image_url":
        file_url = document_source["image_url"]
        if file_url.startswith("data:image/"):  # Inline base64-encoded image
            _, base64_data = file_url.split(",", 1)
            prep_img = _open_image_bytes(base64.b64decode(base64_data))
        else:
            prep_img = _fetch_image(file_url, timeout)
    else:
        raise ValueError("Invalid document source type")

    # The engine is requested only after the image loaded, so bad input never pays its cost.
    redact_img = _get_redactor().redact(
        image=prep_img,
        entities=["PERSON", "LOCATION", "DATE_TIME", "PHONE_NUMBER", "MEDICAL_LICENSE"],
    )
    print("Redaction complete")
    return redact_img
def process_msg(client, image):
    """Run Mistral OCR on ``image`` and build the chat prompt for extraction.

    Args:
        client: A ``Mistral`` client used for the OCR call.
        image: The (already redacted) PIL image to OCR; encoded via ``encode_64``.

    Returns:
        ``(messages, processed_output)``:
        - ``messages`` is the chat payload: one user turn holding the image
          and an instruction prompt that embeds the OCR markdown.
        - ``processed_output`` is the markdown of every OCR page, joined by
          blank lines (``""`` when the OCR response has no pages).
    """
    # Encode once and reuse the same data URL for OCR and the chat turn.
    # encode_64 produces JPEG, whose registered MIME type is image/jpeg.
    image_data_url = f"data:image/jpeg;base64,{encode_64(image)}"
    ocr_response = client.ocr.process(
        model="mistral-ocr-latest",
        document={"type": "image_url", "image_url": image_data_url},
    )
    # TODO: in a production pipeline, run Tesseract locally and then apply
    # Presidio text redaction to the OCR output as well.
    # Join all pages instead of indexing pages[0] (IndexError when empty).
    processed_output = "\n\n".join(page.markdown for page in ocr_response.pages)

    # Each fragment ends with a separator; adjacent string literals are
    # concatenated with nothing in between.
    prompt = (
        f"This is the image's OCR in markdown:\n{processed_output}\n.\n"
        "per medication found, include medication generic name (search if you have to), "
        "and calculated dosage (based on strength and number of units), "
        "and route, "
        "and frequency (in medical short forms), "
        "dispense quantity, "
        "and number of refills. "
        "return in JSON structured message without other details."
    )
    messages = [
        {
            "role": "user",
            "content": [
                ImageURLChunk(image_url=image_data_url),
                TextChunk(text=prompt),
            ],
        }
    ]
    return messages, processed_output
def chat_response(client, model, messages):
    """Send ``messages`` to the Mistral chat endpoint and return the raw response.

    Args:
        client: A ``Mistral`` client.
        model: Name of the chat model to use.
        messages: Chat message list, e.g. as built by ``process_msg``.

    Returns:
        The object returned by ``client.chat.complete``; the app reads
        ``response.choices[0].message.content`` from it.
    """
    # Log only the model and message count. The messages embed a base64
    # image and the OCR text, which would bloat the logs and copy document
    # contents into them.
    print(f"generating chat response with model: {model} and {len(messages)} message(s)")
    response = client.chat.complete(model=model, messages=messages)
    print("chat complete")
    return response
def _upload_path(file):
    """Return the on-disk path of an upload: its ``.name`` attribute, or the value itself."""
    return getattr(file, "name", file)


def _build_document_source(input_type, url, file, client):
    """Turn the UI inputs into a document-source dict for ``redact_imag``.

    Returns ``(document_source, None)`` on success, or ``(None, message)`` when
    the input is missing or unsupported. PDF uploads are pushed to Mistral's
    file store with ``upload_pdf`` and referenced by their signed URL.
    """
    from urllib.parse import urlsplit

    if input_type == "URL":
        if not url or not url.strip():
            return None, "Please provide a valid URL."
        clean_url = url.strip()
        # Inspect only the URL path so query strings (e.g. "?w=800") and
        # fragments don't hide the file extension.
        url_path = urlsplit(clean_url).path.lower()
        if url_path.endswith(tuple(VALID_IMAGE_EXTENSIONS)):
            return {"type": "image_url", "image_url": clean_url}, None
        return {"type": "document_url", "document_url": clean_url}, None

    if input_type == "Upload file":
        if not file:
            return None, "Please upload a file."
        file_path = _upload_path(file)
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext in VALID_DOCUMENT_EXTENSIONS:
            # Read from disk by path. NOTE(review): gr.File is assumed to hand
            # over a path-like value (no .read()); this also works for file objects.
            with open(file_path, "rb") as fh:
                file_content = fh.read()
            signed_url = upload_pdf(file_content, os.path.basename(file_path), client)
            return {"type": "document_url", "document_url": signed_url}, None
        if file_ext in VALID_IMAGE_EXTENSIONS:
            # Re-encode as PNG; the context manager closes the file handle.
            with Image.open(file_path) as open_image:
                buffered = BytesIO()
                open_image.save(buffered, format="PNG")
            open_image_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
            return {"type": "image_url", "image_url": f"data:image/png;base64,{open_image_string}"}, None
        supported = ", ".join(sorted(VALID_DOCUMENT_EXTENSIONS | VALID_IMAGE_EXTENSIONS))
        return None, f"Error: Unsupported file type. Supported types: {supported}"

    return None, "Invalid input type"


def do_ocr(input_type, url=None, file=None):
    """Gradio handler: redact PII from the input image, OCR it, and extract JSON.

    Args:
        input_type: ``"URL"`` or ``"Upload file"`` (the Input Type radio values).
        url: Image or document URL; used when ``input_type == "URL"``.
        file: Upload from ``gr.File``; its path is ``file.name``, or ``file``
            itself when it is a plain path string.

    Returns:
        ``(json_text, ocr_text, gallery_images)``, matching the outputs
        ``[cleaned_output, ocr_result, image_output]``. On any error the first
        element holds the message, ``ocr_text`` is ``""`` and
        ``gallery_images`` is ``[]``. The old code returned ``[]``/``""`` in
        the wrong slots.
    """
    print(f"starting do_ocr with input_type: {input_type}, url: {url}, file: {file}")
    client = Mistral(api_key=API_KEY)
    try:
        document_source, error = _build_document_source(input_type, url, file, client)
        if error is not None:
            return error, "", []
        redact_img = redact_imag(document_source)
        message, ocr_result = process_msg(client, redact_img)
        response = chat_response(client, model, message)
    except Exception as e:
        # UI boundary: report download/redaction/OCR/API failures to the user
        # instead of letting the handler crash.
        print(f"do_ocr failed: {e!r}")
        return f"Error processing OCR: {str(e)}", "", []
    # gr.Gallery expects a list of images.
    return response.choices[0].message.content, ocr_result, [redact_img]
# Page styles passed to gr.Blocks(css=...). The body rule previously repeated
# "body {font-family:" inside itself, leaving an unbalanced brace.
custom_css = """
body {font-family: 'Helvetica Neue', Helvetica, sans-serif;}
.gr-button {background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 5px;}
.gr-button:hover {background-color: #45a049;}
.gr-textbox {margin-bottom: 15px;}
.example-button {background-color: #1E90FF; color: white; border: none; padding: 8px 15px; border-radius: 5px; margin: 5px;}
.example-button:hover {background-color: #FF4500;}
.tall-radio .gr-radio-item {padding: 15px 0; min-height: 50px; display: flex; align-items: center;}
.tall-radio label {font-size: 16px;}
"""
# Gradio UI: inputs and example button on the left, redacted image on the
# right, then the anonymized OCR text and extracted JSON side by side.
with gr.Blocks(
    title="Tech Demo",
    css=custom_css,
) as demo:
    gr.Markdown("<h1 style='text-align: center; color: #333;'>Extract text from images using AI OCR model</h1>")
    with gr.Row():
        with gr.Column(scale=1):
            input_type = gr.Radio(
                choices=["URL", "Upload file"],
                label="Input Type",
                value="URL",
                elem_classes="tall-radio",
            )
            url_input = gr.Textbox(
                label="Document or Image URL",
                placeholder="e.g., https://arxiv.org/pdf/2501.12948",
                visible=True,
                lines=1,
            )
            file_input = gr.File(
                label="Upload PDF or Image",
                file_types=[".pdf", ".jpg", ".jpeg", ".png"],
                interactive=True,
                visible=False,
            )
            submit_btn = gr.Button("Extract Text and Images")
            gr.Markdown("### Try These Examples")
            img_example = gr.Button("Rx Example URL", elem_classes="example-button")
        with gr.Column(scale=3):
            image_output = gr.Gallery(label="Redacted Image", height="contain")
    with gr.Row():
        with gr.Column():
            # Label typo fixed ("Annoymized" -> "Anonymized").
            ocr_result = gr.Textbox(label="Anonymized Text", show_copy_button=True)
        with gr.Column():
            cleaned_output = gr.Textbox(label="Extracted JSON")

    def update_visibility(choice):
        """Show the URL box for "URL" and the file picker for "Upload file"."""
        return gr.update(visible=(choice == "URL")), gr.update(visible=(choice == "Upload file"))

    input_type.change(fn=update_visibility, inputs=input_type, outputs=[url_input, file_input])

    def set_url_and_type(url):
        """Return values that fill the URL box and switch the radio to "URL"."""
        return url, "URL"

    img_example.click(
        fn=lambda: set_url_and_type("https://as2.ftcdn.net/v2/jpg/00/56/61/71/1000_F_56617167_ZGbrr3mHPUmLoksQmpuY7SPA8ihTI5Dh.jpg"),
        outputs=[url_input, input_type],
    )
    # do_ocr returns (json_text, ocr_text, gallery_images) in this output order.
    submit_btn.click(
        fn=do_ocr,
        inputs=[input_type, url_input, file_input],
        outputs=[cleaned_output, ocr_result, image_output],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()