# Hugging Face Space — status: Running on Zero.
#!/usr/bin/env python

import pathlib

import gradio as gr
import numpy as np
import PIL.Image
import spaces  # NOTE(review): kept ahead of `torch`; ZeroGPU docs ask for this ordering — confirm before regrouping.
import torch
from sahi.prediction import ObjectPrediction
from sahi.utils.cv import visualize_object_predictions
from transformers import AutoImageProcessor, DetaForObjectDetection

# Markdown heading shown at the top of the demo page.
DESCRIPTION = "# DETA (Detection Transformers with Assignment)"

# Repository id handed to both `from_pretrained` calls below.
MODEL_ID = "jozhang97/deta-swin-large"

# Use the first CUDA device when torch reports one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

image_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# `Module.to` returns the module itself, so the moved model is bound directly.
model = DetaForObjectDetection.from_pretrained(MODEL_ID).to(device)
# Attach a GPU for the duration of each call when running on ZeroGPU hardware
# (the `spaces` import at the top of the file was otherwise unused).
@spaces.GPU
# Pure inference: disable autograd tracking to save memory and time.
@torch.inference_mode()
def run(image_path: str, threshold: float) -> np.ndarray:
    """Detect objects in an image with DETA and draw the detections on it.

    Args:
        image_path: Path of the image file to analyse (the Gradio input
            component is configured with ``type="filepath"``).
        threshold: Minimum score a detection must reach to be kept; forwarded
            to ``post_process_object_detection``.

    Returns:
        The ``"image"`` array produced by ``visualize_object_predictions``
        for the RGB input image and the kept detections.
    """
    # Convert to 3-channel RGB so grayscale / RGBA / palette files reach the
    # processor and the visualiser in one layout; the `with` block closes the
    # underlying file handle once the pixels have been copied.
    with PIL.Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt").to(device)
    outputs = model(**inputs)

    # PIL reports (width, height); the post-processor takes (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    boxes = results["boxes"].cpu().numpy()
    scores = results["scores"].cpu().numpy()
    cat_ids = results["labels"].cpu().numpy().tolist()

    preds = []
    for box, score, cat_id in zip(boxes, scores, cat_ids, strict=True):
        box_int = np.round(box).astype(int)
        preds.append(
            ObjectPrediction(
                bbox=box_int.tolist(),
                category_id=cat_id,
                category_name=model.config.id2label[cat_id],
                score=float(score),
            )
        )
    return visualize_object_predictions(np.asarray(image), preds)["image"]
# One [image_path, threshold] row per *.jpg under ./images, in sorted order,
# each paired with the same 0.1 threshold the slider starts at.
_example_rows = [[path, 0.1] for path in sorted(pathlib.Path("images").glob("*.jpg"))]

with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # Inputs stacked in a column, laid out in one row next to the result.
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input image", type="filepath")
            threshold = gr.Slider(
                label="Score threshold",
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.1,
            )
            run_button = gr.Button()
        result = gr.Image(label="Result")

    gr.Examples(
        fn=run,
        examples=_example_rows,
        inputs=[image, threshold],
        outputs=result,
    )

    # The button handler is also published under the API name "predict".
    run_button.click(
        fn=run,
        inputs=[image, threshold],
        outputs=result,
        api_name="predict",
    )


if __name__ == "__main__":
    demo.launch()