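"""Gradio demo for Grounded SAM 2: Grounding DINO proposes boxes for each
comma-separated text prompt, and SAM 2 turns each box into a segmentation
mask."""
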
import gradio as gr
import spaces
from autodistill_grounded_sam_2 import GroundedSAM2
from autodistill_grounded_sam_2.helpers import combine_detections
from autodistill.helpers import load_image
import torch
from autodistill.detection import CaptionOntology
import supervision as sv
import numpy as np
import cv2

# The ontology is left empty; class prompts come from the user at request time.
base_model = GroundedSAM2(
    ontology=CaptionOntology({}),
    model="Grounding DINO",
    grounding_dino_box_threshold=0.25,
)

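# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call.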
@spaces.GPU
def greet(image, prompt):
    # Convert the uploaded image to an OpenCV (BGR numpy) array.
    image = load_image(image, return_format="cv2")

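    # Route to whichever grounding backbone the base model was configured with.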
    if base_model.model == "Florence 2":
        detections = base_model.florence_2_predictor.predict(image)
    elif base_model.model == "Grounding DINO":
        # GroundingDINO predictions
        detections_list = []

        for description in prompt.split(","):
            # detect objects
            detections = base_model.grounding_dino_model.predict_with_classes(
                image=image,
                classes=[description],
                box_threshold=base_model.grounding_dino_box_threshold,
                text_threshold=base_model.grounding_dino_text_threshold,
            )

            detections_list.append(detections)

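        # Merge per-prompt detections, assigning class ids in prompt order.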
        detections = combine_detections(
            detections_list, overwrite_class_ids=range(len(detections_list))
        )

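    # Run SAM 2 in inference mode with bfloat16 autocast; prompt it with each
    # Grounding DINO box and keep the highest-scoring mask per box.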
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        base_model.sam_2_predictor.set_image(image)
        result_masks = []
        for box in detections.xyxy:
            masks, scores, _ = base_model.sam_2_predictor.predict(
                box=box, multimask_output=False
            )
            index = np.argmax(scores)
            masks = masks.astype(bool)
            result_masks.append(masks[index])

    detections.mask = np.array(result_masks)

    # Drop low-confidence detections before annotating.
    detections = detections[detections.confidence > 0.3]

    # Overlay the predicted segmentation masks on a copy of the input image.
    mask_annotator = sv.MaskAnnotator()

    annotated_image = mask_annotator.annotate(
        image.copy(), detections=detections
    )

    # Gradio's image output expects RGB; the cv2 image is BGR, so convert.
    return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)

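# Simple UI: an image plus a comma-separated prompt in, an annotated image out.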
demo = gr.Interface(
    fn=greet,
    inputs=[gr.Image(), gr.Textbox(lines=2, label="Prompt")],
    outputs="image",
)
demo.launch()