Spaces:
Sleeping
Sleeping
File size: 3,287 Bytes
1d52f21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import gradio as gr
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
import torch
from PIL import Image
import numpy as np
# Device: prefer CUDA when available; otherwise run inference on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load model and processor once at module import so every request reuses them.
# NOTE(review): torch.float32 is forced even on GPU — bfloat16 would halve
# memory on CUDA; presumably float32 was chosen for CPU compatibility. Confirm
# before changing.
model_id = "google/paligemma2-3b-mix-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    # device_map="auto" lets accelerate place weights on available devices.
    device_map="auto",
    low_cpu_mem_usage=True
).eval()  # inference only — disables dropout and training-mode behavior
processor = PaliGemmaProcessor.from_pretrained(model_id)
print("Model and processor loaded successfully")
# Process image
def process_image(image, task_type, question="", objects=""):
try:
if task_type == "Describe Image":
prompt = "describe en"
elif task_type == "OCR Text Recognition":
prompt = "ocr"
elif task_type == "Answer Question":
prompt = f"answer en {question}"
elif task_type == "Detect Objects":
prompt = f"detect {objects}"
else:
return "Please select a valid task."
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
model_inputs = processor(text=prompt, images=image, return_tensors="pt")
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(
**model_inputs,
max_new_tokens=100,
do_sample=False
)
generation = generation[0][input_len:]
result = processor.decode(generation, skip_special_tokens=True)
return result
except Exception as e:
return f"Error during processing: {str(e)}"
# Elegant website-style CSS.
# NOTE(review): intentionally empty placeholder — Gradio's default theme
# applies; the "image-preview" class referenced below is currently unstyled.
custom_css = """
"""
# Gradio app: two-column layout — inputs on the left, result on the right.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("""<h1>PaliGemma 2 Visual AI Assistant</h1>""")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", elem_classes="image-preview")
            # Task selector; choices must match the keys handled in process_image.
            task_type = gr.Radio(
                choices=["Describe Image", "OCR Text Recognition", "Answer Question", "Detect Objects"],
                label="Choose Task",
                value="Describe Image"
            )
            # Task-specific inputs, hidden until their task is selected.
            question_input = gr.Textbox(label="Question", placeholder="Type a question", visible=False)
            objects_input = gr.Textbox(label="Objects to Detect", placeholder="e.g., cat; car", visible=False)
            submit_btn = gr.Button("🔍 Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=10)

    def update_inputs(task):
        # Toggle visibility: question box only for Q&A, objects box only
        # for detection. Returned dict keys are the component objects.
        return {
            question_input: gr.update(visible=(task == "Answer Question")),
            objects_input: gr.update(visible=(task == "Detect Objects"))
        }

    task_type.change(fn=update_inputs, inputs=[task_type], outputs=[question_input, objects_input])
    submit_btn.click(fn=process_image, inputs=[image_input, task_type, question_input, objects_input], outputs=output_text)

if __name__ == "__main__":
    # share=True creates a public tunnel URL; inbrowser opens a local tab.
    # NOTE(review): both flags are no-ops / warnings when hosted on Spaces.
    demo.launch(share=True, inbrowser=True)
|