Spaces:
Build error
Build error
File size: 2,155 Bytes
606d9f7 7b87048 606d9f7 7b87048 606d9f7 7b87048 606d9f7 7b87048 606d9f7 7b87048 606d9f7 7b87048 606d9f7 7b87048 606d9f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import gradio as gr
import spaces
from autodistill_grounded_sam_2 import GroundedSAM2
from autodistill_grounded_sam_2.helpers import combine_detections
from autodistill.helpers import load_image
import torch
from autodistill.detection import CaptionOntology
import supervision as sv
import nupmy as np
# Module-level detection/segmentation pipeline: Grounding DINO proposes
# open-vocabulary boxes, SAM 2 turns them into masks. The ontology is empty
# because class prompts arrive per-request from the UI textbox.
base_model = GroundedSAM2(
    ontology=CaptionOntology({}),
    model="Grounding DINO",
    grounding_dino_box_threshold=0.25,
)
@spaces.GPU
def greet(image, prompt):
    """Detect and segment objects in *image* described by *prompt*.

    Args:
        image: Input image as supplied by the Gradio component; converted to
            an OpenCV (BGR ndarray) image via ``load_image``.
        prompt: Comma-separated list of object descriptions, one Grounding
            DINO query per description.

    Returns:
        A copy of the image annotated with the detections that survive a
        0.3 confidence filter.
    """
    # BUG FIX: the original passed the builtin ``input`` instead of the
    # ``image`` parameter, raising a TypeError at runtime.
    image = load_image(image, return_format="cv2")
    if base_model.model == "Florence 2":
        detections = base_model.florence_2_predictor.predict(image)
    elif base_model.model == "Grounding DINO":
        # One Grounding DINO pass per comma-separated description, then merge
        # so each description gets its own class id.
        detections_list = []
        for description in prompt.split(","):
            detections = base_model.grounding_dino_model.predict_with_classes(
                image=image,
                classes=[description],
                box_threshold=base_model.grounding_dino_box_threshold,
                text_threshold=base_model.grounding_dino_text_threshold,
            )
            detections_list.append(detections)
        detections = combine_detections(
            detections_list, overwrite_class_ids=range(len(detections_list))
        )
    # Segment every detected box with SAM 2; bfloat16 autocast keeps GPU
    # memory low without materially changing mask quality.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        base_model.sam_2_predictor.set_image(image)
        result_masks = []
        for box in detections.xyxy:
            masks, scores, _ = base_model.sam_2_predictor.predict(
                box=box, multimask_output=False
            )
            # Keep only the highest-scoring mask for this box.
            index = np.argmax(scores)
            masks = masks.astype(bool)
            result_masks.append(masks[index])
    detections.mask = np.array(result_masks)
    # BUG FIX: the original filtered and annotated an undefined name
    # ``results`` (NameError); filter the detections we just built.
    detections = detections[detections.confidence > 0.3]
    # NOTE(review): despite computing masks above, the original drew boxes
    # (BoxAnnotator) — behavior kept; switch to sv.MaskAnnotator to show masks.
    box_annotator = sv.BoxAnnotator()
    annotated_image = box_annotator.annotate(
        image.copy(), detections=detections
    )
    return annotated_image
# Gradio UI: an image plus a comma-separated text prompt in, annotated image
# out. BUG FIX: the ``gr.inputs`` namespace was removed in Gradio 4.x (which
# the ``spaces`` GPU decorator requires); use the top-level components.
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(), gr.Textbox(lines=2, label="Prompt")],
    outputs="image",
)
demo.launch()
|