# visionbuddy / app.py
# Author: Dy100 - commit 1d52f21 (verified): "Upload folder using huggingface_hub"
import gradio as gr
from transformers import (
PaliGemmaProcessor,
PaliGemmaForConditionalGeneration,
)
import torch
from PIL import Image
import numpy as np
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the PaliGemma 2 checkpoint (Hugging Face Hub id or local cache) and its
# matching processor.
model_id = "google/paligemma2-3b-mix-448"
# NOTE(review): weights are loaded as float32 on every device; bfloat16 on GPU
# would roughly halve memory use -- confirm hardware support/accuracy first.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.float32,
)
model.eval()  # inference only: disable training-mode behaviour such as dropout
processor = PaliGemmaProcessor.from_pretrained(model_id)
print("Model and processor loaded successfully")
# Process image
def process_image(image, task_type, question="", objects="", *, max_new_tokens=100):
    """Run one PaliGemma task on an image and return the decoded text.

    Args:
        image: The uploaded image as a NumPy array or a PIL image; ``None``
            means nothing was uploaded.
        task_type: One of "Describe Image", "OCR Text Recognition",
            "Answer Question" or "Detect Objects".
        question: Question text; only used by "Answer Question".
        objects: Objects to locate; only used by "Detect Objects".
        max_new_tokens: Upper bound on generated tokens (keyword-only).

    Returns:
        The model's decoded output, or a human-readable message when the
        inputs are incomplete or processing fails. Errors are reported as a
        returned string (and logged) so the UI always has something to show.
    """
    try:
        # Gradio textboxes may hand over None or whitespace-only strings.
        question = (question or "").strip()
        objects = (objects or "").strip()

        if task_type == "Describe Image":
            prompt = "describe en"
        elif task_type == "OCR Text Recognition":
            prompt = "ocr"
        elif task_type == "Answer Question":
            if not question:
                return "Please enter a question."
            prompt = f"answer en {question}"
        elif task_type == "Detect Objects":
            if not objects:
                return "Please enter at least one object to detect."
            prompt = f"detect {objects}"
        else:
            return "Please select a valid task."

        # Without an image the processor fails with an opaque error; say so plainly.
        if image is None:
            return "Please upload an image."
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Normalise palette/RGBA/grayscale images to the 3-channel RGB input.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        model_inputs = processor(text=prompt, images=image, return_tensors="pt")
        # The model was placed with device_map="auto"; send inputs to the
        # model's own device rather than a generic "cuda"/"cpu" string.
        model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
        input_len = model_inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding: deterministic output
            )
        # The returned sequence starts with the prompt tokens; keep only new ones.
        generation = generation[0][input_len:]
        return processor.decode(generation, skip_special_tokens=True)
    except Exception as e:
        import logging

        logging.getLogger(__name__).exception("process_image failed")
        return f"Error during processing: {str(e)}"
# Extra stylesheet passed to gr.Blocks. It is currently empty, so Gradio's
# default styling applies; add rules here (e.g. for the ".image-preview" class).
custom_css = "\n"
# Gradio app: image + task inputs on the left, model output on the right.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("""<h1>PaliGemma 2 Visual AI Assistant</h1>""")
    with gr.Row():
        # Input column.
        with gr.Column():
            image_input = gr.Image(label="Upload Image", elem_classes="image-preview")
            task_type = gr.Radio(
                choices=["Describe Image", "OCR Text Recognition", "Answer Question", "Detect Objects"],
                label="Choose Task",
                value="Describe Image",
            )
            # Task-specific fields start hidden; update_inputs reveals the one
            # the selected task needs.
            question_input = gr.Textbox(label="Question", placeholder="Type a question", visible=False)
            objects_input = gr.Textbox(label="Objects to Detect", placeholder="e.g., cat; car", visible=False)
            # Label was mis-encoded ("πŸ” Analyze"); restore the magnifier emoji.
            submit_btn = gr.Button("🔍 Analyze")
        # Output column.
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=10)

    def update_inputs(task):
        """Show only the text box used by the selected task."""
        return {
            question_input: gr.update(visible=(task == "Answer Question")),
            objects_input: gr.update(visible=(task == "Detect Objects")),
        }

    task_type.change(fn=update_inputs, inputs=[task_type], outputs=[question_input, objects_input])
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, task_type, question_input, objects_input],
        outputs=output_text,
    )
if __name__ == "__main__":
    # Start the server, open the UI in a local browser, and request a public share link.
    demo.launch(inbrowser=True, share=True)