# florence_sam / app.py
# Uploaded by rayuga2503 (commit 93c21cd, "Create app.py"), 2.53 kB.
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, SamModel, SamProcessor
import time
# Pick the inference device: the first CUDA GPU if PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Public Hugging Face Hub checkpoints (no authentication token needed).
# NOTE: the UI and log messages say "Florence", but the captioner loaded here
# is Salesforce BLIP, not Microsoft Florence-2.
CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
SAM_MODEL_ID = "facebook/sam-vit-base"

# Image-captioning model plus the processor that prepares its inputs and
# decodes its generated token ids.
processor = BlipProcessor.from_pretrained(CAPTION_MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(CAPTION_MODEL_ID).to(device)

# Segment Anything (ViT-B backbone) plus its image pre/post-processor.
sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device)
def _caption(pil_image):
    """Return the BLIP caption for *pil_image* (a PIL RGB image)."""
    inputs = processor(pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(**inputs)
    return processor.decode(out[0], skip_special_tokens=True)


def _segment(pil_image):
    """Return SAM's best mask for *pil_image* as a boolean (H, W) NumPy array.

    No point/box prompt is passed, so SAM returns several candidate masks;
    the one with the highest predicted IoU score is kept. The raw
    low-resolution logits are upscaled to the image's own (H, W) and
    binarized by ``post_process_masks``.
    """
    encoding = sam_processor(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(pixel_values=encoding["pixel_values"])
    # One entry per image; each is (point_batch, num_masks, H, W) booleans.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        encoding["original_sizes"].cpu(),
        encoding["reshaped_input_sizes"].cpu(),
    )[0]
    best = int(outputs.iou_scores[0, 0].argmax())
    return masks[0, best].numpy()


def process_image(image):
    """Caption an uploaded image and highlight a SAM segmentation in green.

    Args:
        image: The upload as a NumPy array (Gradio ``type="numpy"``), or
            ``None`` when the user submits without choosing an image.

    Returns:
        ``(description, image_out)``. On success, *image_out* is an RGB copy
        of the upload with the segmented pixels painted green. If captioning
        fails, an error message and the original image are returned. If
        segmentation fails, the caption and the original image are returned.
    """
    if image is None:
        return "Please upload an image.", None

    start_time = time.time()
    # Normalise to 3-channel RGB so grayscale/RGBA uploads work with both
    # models and with the RGB colour written into the overlay below.
    rgb_image = Image.fromarray(image).convert("RGB")
    # Smaller copy used only for captioning, to bound its cost.
    pil_image = rgb_image.resize((512, 512))
    print("✅ Image loaded and resized.")

    # Generate caption (BLIP).
    try:
        description = _caption(pil_image)
    except Exception as e:
        print(f"❌ Error in Florence: {e}")
        return "Failed to generate description.", image
    print(f"📝 Florence Captioning done in {time.time() - start_time:.2f} sec")

    # Segment the full-resolution image so the mask lines up pixel-for-pixel
    # with the overlay array built from the same image.
    sam_start = time.time()
    try:
        mask = _segment(rgb_image)
        mask_overlay = np.array(rgb_image)  # writable (H, W, 3) copy
        mask_overlay[mask] = [0, 255, 0]  # green overlay for segmentation
    except Exception as e:
        print(f"❌ Error in SAM: {e}")
        return description, image
    print(f"🎨 SAM Segmentation done in {time.time() - sam_start:.2f} sec")
    return description, mask_overlay
# --- Gradio UI ---------------------------------------------------------------
# One numpy image upload feeds process_image. Its two return values are shown,
# in order, as a caption textbox and an annotated image.
upload_input = gr.Image(type="numpy")
result_outputs = [
    gr.Textbox(label="Image Description"),
    gr.Image(label="Segmented Image"),
]
demo = gr.Interface(
    fn=process_image,
    inputs=upload_input,
    outputs=result_outputs,
    title="Florence + SAM Image Processing",
    description="Upload an image to get its description using Florence and segmentation using SAM (loaded from Hugging Face).",
)
if __name__ == "__main__":
    # Listen on every network interface (not just localhost) on port 7860,
    # so the app is reachable from outside the host/container.
    host, port = "0.0.0.0", 7860
    demo.launch(server_name=host, server_port=port)