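"""HuggingFace-style image processor for RF-DETR.

Wraps RF-DETR's preprocessing (conversion to tensors, ImageNet normalization,
and padding via nested tensors) and its postprocessing (top-k box selection
through rfdetr's PostProcess) behind the transformers BaseImageProcessor API.
"""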
from typing import Dict, List, Tuple, Optional, Literal, Union
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from torchvision.transforms import ToTensor, Normalize
from rfdetr.util.misc import nested_tensor_from_tensor_list
from rfdetr.models.lwdetr import PostProcess


class RFDetrImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values", "pixel_mask"]

    def __init__(
        self,
        model_name: Literal['RFDETRBase', 'RFDETRLarge'] = 'RFDETRBase',
        num_select: int = 300,
        image_mean: List[float] = [0.485, 0.456, 0.406],
        image_std: List[float] = [0.229, 0.224, 0.225],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.config = {
            'image_mean': image_mean,
            'image_std': image_std,
        }
        self.post_process_config = {
            'num_select': num_select,
        }

    def post_process_object_detection(
        self,
        outputs,
        target_sizes: Union[List[Tuple[int, int]], torch.Tensor],
        **kwargs,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Parameters
        ----------
        outputs:
            outputs from a model loaded with AutoModelForObjectDetection, or the
            raw output list of an ONNX model
        target_sizes: list[tuple] or Tensor
            the original (height, width) of each image in the batch
        """
        if isinstance(outputs, list):  # handle raw ONNX outputs: [logits, boxes]
            logits = torch.tensor(outputs[0])
            pred_boxes = torch.tensor(outputs[1])
        else:
            logits = outputs.logits
            pred_boxes = outputs.pred_boxes
        outputs = {
            'pred_logits': logits,
            'pred_boxes': pred_boxes,
        }
        # PostProcess indexes into target_sizes as a tensor, so convert lists of tuples
        if not isinstance(target_sizes, torch.Tensor):
            target_sizes = torch.tensor(target_sizes)
        # delegate to rfdetr's PostProcess for top-k selection and rescaling
        # to the original image sizes
        post_process = PostProcess(self.post_process_config['num_select'])
        detections = post_process(
            outputs,
            target_sizes=target_sizes,
        )
        return detections
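
    # Each dict returned above holds 'scores', 'labels', and 'boxes' in absolute
    # [x_min, y_min, x_max, y_max] coordinates, following the DETR-family
    # PostProcess convention that rfdetr's implementation is derived from
    # (an assumption based on that lineage, not verified against every version).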

    def convert_and_validate_boxes(self, annotations, images):
        for ann, img in zip(annotations, images):
            boxes = ann["boxes"]
            # validate before mutating, so bad inputs fail with a clear message
            torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
            torch._assert(
                len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                "Expected target boxes to be a tensor of shape [N, 4].",
            )
            # convert from COCO format [x_min, y_min, width, height] to center format [cx, cy, w, h]
            boxes = boxes.to(torch.float32)
            boxes[:, [0, 1]] += boxes[:, [2, 3]] / 2
            ann["boxes"] = boxes
            # every box center must be far enough from the borders for the
            # full box to fit inside the image (img is [C, H, W])
            for box in boxes:
                torch._assert(
                    box[2] / 2 <= box[0] <= img.shape[2] - box[2] / 2
                    and box[3] / 2 <= box[1] <= img.shape[1] - box[3] / 2,
                    "Expected w/2 <= cx <= W - w/2 and h/2 <= cy <= H - h/2.",
                )
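
    # Worked example of the conversion above: a COCO box [x_min=10, y_min=20, w=30, h=40]
    # becomes [cx=25, cy=40, w=30, h=40], since cx = x_min + w/2 and cy = y_min + h/2.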

    def preprocess(
        self,
        images,
        annotations=None,
    ) -> BatchFeature:
        """
        Parameters
        ----------
        images: PIL.Image.Image or List[PIL.Image.Image]
            a single PIL image or a list of PIL images
        annotations: Optional[List[Dict[str, torch.Tensor | List]]]
            list of annotations associated with the image or batch of images. For object
            detection, each annotation should be a dictionary with the following keys:
            - boxes (FloatTensor[N, 4]): the ground-truth boxes in COCO format [x_min, y_min, width, height]
            - class_labels (Int64Tensor[N]): the class label for each ground-truth box
        """
        totensor = ToTensor()
        normalize = Normalize(mean=self.config['image_mean'], std=self.config['image_std'])
        # wrap a single image in a list so the rest of the pipeline sees a batch
        if images is not None and not isinstance(images, list):
            images = [images]
        if not isinstance(images[0], torch.Tensor):
            images = [totensor(img) for img in images]
        if annotations is not None:
            self.convert_and_validate_boxes(annotations, images)
        # record the original image sizes for postprocessing
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W, instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))
        target_sizes = torch.tensor(original_image_sizes)
        # normalize each image with the configured mean/std
        images = [normalize(img) for img in images]
        # pad the list of images into a batched tensor of size [B, C, H, W]
        # plus a padding mask of size [B, H, W]
        nested_tensor = nested_tensor_from_tensor_list(images)
        data = {
            'pixel_values': nested_tensor.tensors,
            'pixel_mask': nested_tensor.mask,
            'target_sizes': target_sizes,
            'labels': annotations,
        }
        return BatchFeature(data=data)


__all__ = [
    "RFDetrImageProcessor",
]
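

# Minimal usage sketch (illustrative only): it assumes the RF-DETR checkpoint can be
# loaded through transformers' AutoModelForObjectDetection with trust_remote_code, and
# the checkpoint id "roboflow/rfdetr-base" and image path below are placeholders.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoModelForObjectDetection

    processor = RFDetrImageProcessor()
    model = AutoModelForObjectDetection.from_pretrained(
        "roboflow/rfdetr-base", trust_remote_code=True  # placeholder checkpoint id
    )
    image = Image.open("example.jpg")  # placeholder image path
    batch = processor.preprocess([image])
    outputs = model(pixel_values=batch["pixel_values"], pixel_mask=batch["pixel_mask"])
    detections = processor.post_process_object_detection(
        outputs, target_sizes=batch["target_sizes"]
    )
    print(detections)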