| import gradio as gr |
| import torch |
| from PIL import Image |
| from transformers import AutoProcessor, AutoModelForCausalLM |
|
|
| |
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
# Both model placement (load_vlm) and input tensors (describe_image) use this.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
def load_vlm(model_name):
    """Load a ``microsoft/<model_name>`` checkpoint together with its processor.

    Args:
        model_name: Repository name under the ``microsoft`` Hugging Face
            organisation, e.g. ``"Florence-2-base"``.

    Returns:
        ``(model, processor)``, with the model moved to the module-level
        ``device`` and switched to eval mode. If any step raises, returns
        ``(None, None)`` instead. The error is printed rather than propagated,
        so a failed load does not abort module import.
    """
    try:
        repo_id = f'microsoft/{model_name}'
        print(f"Loading {model_name}...")
        # trust_remote_code lets transformers execute Python code shipped in
        # the model repository; only enable it for repositories you trust.
        network = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
        network = network.to(device).eval()
        image_text_processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    except Exception as err:
        # Deliberate best-effort: callers check for a None model.
        print(f"Error loading {model_name}: {err}")
        return None, None
    return network, image_text_processor
|
|
| |
# Load both checkpoints once, at import time, base first and then large.
# A failed load leaves that pair as (None, None); describe_image checks the
# model for None and reports the failure to the user.
(model_base, proc_base), (model_large, proc_large) = (
    load_vlm(checkpoint) for checkpoint in ("Florence-2-base", "Florence-2-large")
)
|
|
def describe_image(uploaded_image, model_choice):
    """Generate a detailed caption for an image with the selected Florence-2 model.

    Args:
        uploaded_image: Image from the Gradio ``gr.Image`` input. This is a PIL
            image (the component uses ``type="pil"``) or an array accepted by
            ``Image.fromarray``, and ``None`` when nothing has been uploaded.
        model_choice: ``"Florence-2-base"`` selects the base model; any other
            value selects the large model.

    Returns:
        The generated caption. Instead returns a short user-facing message
        when no image was given or when the selected model failed to load.
    """
    if uploaded_image is None:
        return "Please upload an image."

    # Anything other than the base model name falls back to the large model.
    if model_choice == "Florence-2-base":
        model, processor = model_base, proc_base
    else:
        model, processor = model_large, proc_large

    # load_vlm returns (None, None) when loading failed.
    if model is None:
        return f"{model_choice} failed to load."

    if not isinstance(uploaded_image, Image.Image):
        uploaded_image = Image.fromarray(uploaded_image)
    # The image processor normalizes with a 3-channel (RGB) mean/std, so RGBA,
    # palette ("P") or grayscale ("L") uploads can fail or be mis-normalized.
    # Converting to RGB is a no-op for images that are already RGB.
    if uploaded_image.mode != "RGB":
        uploaded_image = uploaded_image.convert("RGB")

    # Florence-2 selects the task via a special prompt token; the same token
    # is used for post-processing and as the key of the parsed result.
    task = "<MORE_DETAILED_CAPTION>"
    inputs = processor(text=task, images=uploaded_image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )

    # Decode with special tokens intact; presumably post_process_generation
    # expects the raw sequence and strips them itself (as in Florence-2's
    # reference usage).
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    result = processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(uploaded_image.width, uploaded_image.height),
    )
    return result[task]
|
|
| |
# Custom colours for the "Generate Caption" button, matched via its elem_classes.
css = ".submit-btn { background-color: #4682B4 !important; color: white !important; }"

with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    # Title and a one-line hint about the base/large trade-off.
    gr.Markdown("# **Florence-2 Models Image Captions**")
    gr.Markdown("> Select the model to use. **Base** is faster; **Large** is more accurate.")

    with gr.Row():
        # Left column: image upload, model selector and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            model_choice = gr.Radio(
                label="Model Choice",
                choices=["Florence-2-base", "Florence-2-large"],
                value="Florence-2-base",
            )
            generate_btn = gr.Button("Generate Caption", elem_classes="submit-btn")

        # Right column: the generated caption.
        with gr.Column():
            output = gr.Textbox(label="Generated Caption", lines=6, interactive=True)

    # Clicking the button runs describe_image(image, model name) -> caption text.
    generate_btn.click(
        describe_image,
        inputs=[image_input, model_choice],
        outputs=output,
    )

if __name__ == "__main__":
    # Server-side rendering is explicitly disabled for this app.
    demo.launch(ssr_mode=False)