import gradio as gr
from transformers import (
    PaliGemmaProcessor,
    PaliGemmaForConditionalGeneration,
)
import torch
from PIL import Image
import numpy as np
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load model and processor
model_id = "google/paligemma2-3b-mix-448"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True,
).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)
print("Model and processor loaded successfully")
# Process a single image with the selected task
def process_image(image, task_type, question="", objects=""):
    try:
        # Map the UI task to the corresponding PaliGemma prompt prefix
        if task_type == "Describe Image":
            prompt = "describe en"
        elif task_type == "OCR Text Recognition":
            prompt = "ocr"
        elif task_type == "Answer Question":
            prompt = f"answer en {question}"
        elif task_type == "Detect Objects":
            prompt = f"detect {objects}"
        else:
            return "Please select a valid task."

        if image is None:
            return "Please upload an image."
        # Gradio delivers images as numpy arrays; convert to PIL for the processor
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        model_inputs = processor(text=prompt, images=image, return_tensors="pt")
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        input_len = model_inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            generation = model.generate(
                **model_inputs,
                max_new_tokens=100,
                do_sample=False,
            )

        # Decode only the newly generated tokens, skipping the prompt
        generation = generation[0][input_len:]
        result = processor.decode(generation, skip_special_tokens=True)
        return result
    except Exception as e:
        return f"Error during processing: {str(e)}"
# Elegant website-style CSS
custom_css = """
"""
# Gradio app
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("""<h1>PaliGemma 2 Visual AI Assistant</h1>""")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", elem_classes="image-preview")
            task_type = gr.Radio(
                choices=["Describe Image", "OCR Text Recognition", "Answer Question", "Detect Objects"],
                label="Choose Task",
                value="Describe Image",
            )
            question_input = gr.Textbox(label="Question", placeholder="Type a question", visible=False)
            objects_input = gr.Textbox(label="Objects to Detect", placeholder="e.g., cat; car", visible=False)
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Result", lines=10)
    # Show the question / objects fields only when the relevant task is selected
    def update_inputs(task):
        return {
            question_input: gr.update(visible=(task == "Answer Question")),
            objects_input: gr.update(visible=(task == "Detect Objects")),
        }

    task_type.change(fn=update_inputs, inputs=[task_type], outputs=[question_input, objects_input])
    submit_btn.click(
        fn=process_image,
        inputs=[image_input, task_type, question_input, objects_input],
        outputs=output_text,
    )
if __name__ == "__main__":
    demo.launch(share=True, inbrowser=True)