from typing import Dict, List, Tuple, Optional, Literal

import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from torchvision.transforms import ToTensor, Normalize
from rfdetr.util.misc import nested_tensor_from_tensor_list
from rfdetr.models.lwdetr import PostProcess


class RFDetrImageProcessor(BaseImageProcessor):
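    """Image processor for RF-DETR in the style of the transformers
    BaseImageProcessor API: `preprocess` converts, normalizes, and pads images
    into a model-ready batch, and `post_process_object_detection` rescales raw
    model outputs back to absolute box coordinates.
    """
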
    model_input_names = ["pixel_values", "pixel_mask"]

    def __init__(
        self,
        model_name: Literal['RFDETRBase', 'RFDETRLarge'] = 'RFDETRBase',
        num_select: int = 300,
        image_mean: List[float] = [0.485, 0.456, 0.406],
        image_std: List[float] = [0.229, 0.224, 0.225],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.config = {
            'image_mean': image_mean,
            'image_std': image_std,
        }
        self.post_process_config = {
            'num_select': num_select,
        }
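        # num_select caps how many detections PostProcess keeps per image,
        # e.g. RFDetrImageProcessor(num_select=100) keeps the top 100 scoring
        # boxes during post-processing.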

    def post_process_object_detection(
        self,
        outputs,
        target_sizes: List[Tuple],
        **kwargs,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Parameters
        ----------
        outputs:
            outputs from a model loaded with AutoModelForObjectDetection, or
            the raw output list of an ONNX model
        target_sizes: List[Tuple]
            original (height, width) of each image, used to rescale the
            predicted boxes to absolute coordinates
        """
        if isinstance(outputs, list):  # ONNX models return a plain list of arrays: [logits, boxes]
            logits = torch.tensor(outputs[0])
            pred_boxes = torch.tensor(outputs[1])
        else:
            logits = outputs.logits
            pred_boxes = outputs.pred_boxes
        
        outputs = {
            'pred_logits': logits,
            'pred_boxes': pred_boxes,
        }

        # delegate to rfdetr's PostProcess, which selects the top num_select
        # detections and rescales boxes to the original image sizes
        post_process = PostProcess(self.post_process_config['num_select'])
        detections = post_process(
            outputs,
            target_sizes=target_sizes,
        )

        return detections
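
    # A hedged sketch of the ONNX path above (`session` as an
    # onnxruntime.InferenceSession and the input name "input" are illustrative
    # assumptions, not part of this module):
    #
    #   raw = session.run(None, {"input": batch["pixel_values"].numpy()})
    #   detections = processor.post_process_object_detection(
    #       raw, target_sizes=batch["target_sizes"]
    #   )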

    def convert_and_validate_boxes(self, annotations, images):
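        """Convert each annotation's COCO boxes [x_min, y_min, w, h] to center
        format [cx, cy, w, h] in place, then assert that every box lies fully
        inside its image. For example, the COCO box [10, 20, 30, 40] becomes
        [10 + 30/2, 20 + 40/2, 30, 40] = [25, 40, 30, 40].
        """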
        for ann, img in zip(annotations, images):
            # convert from COCO format [x_min, y_min, width, height] to [cx, cy, w, h]
            boxes = ann["boxes"].to(torch.float32)
            boxes[:, [0,1]] += boxes[:, [2,3]] / 2
            ann["boxes"] = boxes

            torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
            torch._assert(
                len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                "Expected target boxes to be a tensor of shape [N, 4].",
            )
            for box in boxes:
                torch._assert(
                    box[2] / 2 <= box[0] <= img.shape[2] - box[2] / 2 and box[3] / 2 <= box[1] <= img.shape[1] - box[3] / 2,
                    "Expected w/2 <= cx <= W - w/2 and h/2 <= cy <= H - h/2.",
                )

    def preprocess(
        self,
        images,
        annotations=None,
    ) -> BatchFeature:
        """
        Parameters
        ----------
        images: PIL.Image.Image or List[PIL.Image.Image]
            a single PIL image or a list of PIL images
        annotations: Optional[List[Dict[str, torch.Tensor | List]]]
            list of annotations, one per image. For object detection, each
            annotation should be a dictionary with the following keys:

            - boxes (FloatTensor[N, 4]): the ground-truth boxes in COCO format
              [x_min, y_min, width, height]
            - class_labels (Int64Tensor[N]): the class label for each
              ground-truth box
        """
        totensor = ToTensor() 
        normalize = Normalize(mean=self.config['image_mean'], std=self.config['image_std'])

        if images is not None and not isinstance(images, list):
            images = [images]  # wrap a single image into a batch of one
        if not isinstance(images[0], torch.Tensor):
            images = [totensor(img) for img in images]
        if annotations is not None:
            self.convert_and_validate_boxes(annotations, images)

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1])) 
        target_sizes = torch.tensor(original_image_sizes)

        # transform the input
        # normalize image
        images = [normalize(img) for img in images]
        # pad the images to a common size: pixel tensor [B, C, H, W] plus a padding mask [B, H, W]
        nested_tensor = nested_tensor_from_tensor_list(images)

        data = {
            'pixel_values': nested_tensor.tensors, 
            'pixel_mask': nested_tensor.mask, 
            'target_sizes': target_sizes, 
            'labels': annotations
        }
        return BatchFeature(data=data)
    
    
__all__ = [
    "RFDetrImageProcessor"
]
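
# A minimal end-to-end sketch; the checkpoint path "rf-detr-base" and the
# image file below are illustrative assumptions, not guaranteed names:
#
#   from PIL import Image
#   from transformers import AutoModelForObjectDetection
#
#   processor = RFDetrImageProcessor()
#   model = AutoModelForObjectDetection.from_pretrained("rf-detr-base")
#   batch = processor.preprocess([Image.open("image.jpg")])
#   outputs = model(pixel_values=batch["pixel_values"], pixel_mask=batch["pixel_mask"])
#   detections = processor.post_process_object_detection(
#       outputs, target_sizes=batch["target_sizes"]
#   )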