from dataclasses import dataclass
from typing import Dict, List

import torch
from torchvision.transforms import Resize
from transformers import PreTrainedModel
from transformers.utils import ModelOutput, torch_int

from rfdetr import RFDETRBase, RFDETRLarge
from rfdetr.util.misc import NestedTensor

from .configuration_rf_detr import RFDetrConfig

### ONLY WORKS WITH Transformers version 4.50.3 and Python 3.11


@dataclass
class RFDetrObjectDetectionOutput(ModelOutput):
    loss: torch.Tensor = None
    loss_dict: Dict[str, torch.Tensor] = None
    logits: torch.FloatTensor = None
    pred_boxes: torch.FloatTensor = None
    aux_outputs: List[Dict[str, torch.Tensor]] = None
    enc_outputs: Dict[str, torch.Tensor] = None


class RFDetrModelForObjectDetection(PreTrainedModel):
    config_class = RFDetrConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        models = {
            'RFDETRBase': RFDETRBase,
            'RFDETRLarge': RFDETRLarge,
        }
        rf_detr_model = models[config.model_name](
            out_feature_indexes=config.out_feature_indexes,
            dec_layers=config.dec_layers,
            two_stage=config.two_stage,
            bbox_reparam=config.bbox_reparam,
            lite_refpoint_refine=config.lite_refpoint_refine,
            layer_norm=config.layer_norm,
            amp=config.amp,
            num_classes=config.num_classes,
            resolution=config.resolution,
            group_detr=config.group_detr,
            gradient_checkpointing=config.gradient_checkpointing,
            num_queries=config.num_queries,
            encoder=config.encoder,
            hidden_dim=config.hidden_dim,
            sa_nheads=config.sa_nheads,
            ca_nheads=config.ca_nheads,
            dec_n_points=config.dec_n_points,
            projector_scale=config.projector_scale,
            pretrain_weights=config.pretrain_weights,
        )
        # keep only the inner nn.Module and its matching criterion from the rfdetr wrapper
        self.model = rf_detr_model.model.model
        self.criterion = rf_detr_model.model.criterion

    def compute_loss(self, outputs, labels=None):
        """
        Parameters
        ----------
        outputs : Dict[str, torch.Tensor]
            Raw outputs from the rfdetr model.
        labels : Optional[List[Dict[str, torch.Tensor]]]
            List of bounding boxes and labels for each image in the batch.

        Returns
        -------
        Tuple of (loss, loss_dict); both are None when no labels are provided.
        """
        loss = None
        loss_dict = None
        # labels may be None at inference time; the criterion only runs when targets are provided
        if labels is not None:
            losses = self.criterion(outputs, targets=labels)
            loss_dict = {
                'loss_fl': losses["loss_ce"],
                ### class error and cardinality error are for logging purposes only, no backpropagation
                'class_error': losses["class_error"],
                'cardinality_error': losses["cardinality_error"],
                'loss_bbox': losses["loss_bbox"],
                'loss_giou': losses["loss_giou"],
            }
            loss = sum(loss_dict[k] for k in ['loss_fl', 'loss_bbox', 'loss_giou'])
        return loss, loss_dict

    def validate_labels(self, labels):
        # Check for degenerate boxes (non-positive width or height)
        for label_idx, label in enumerate(labels):
            boxes = label["boxes"]
            degenerate_boxes = boxes[:, 2:] <= 0
            if degenerate_boxes.any():
                # report the first degenerate box
                bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                degen_bb: List[float] = boxes[bb_idx].tolist()
                torch._assert(
                    False,
                    "All bounding boxes should have positive height and width."
f" Found invalid box {degen_bb} for target at index {label_idx}.", ) # rename key class_labels to labels for compute_loss if 'class_labels' in label.keys(): label['labels'] = label.pop('class_labels') def resize_labels(self, labels, h, w): """ Resize boxes coordinates to model's resolution """ hr = self.config.resolution / float(h) wr = self.config.resolution / float(w) for label in labels: boxes = label["boxes"] # resize boxes to model's resolution boxes[:, [0, 2]] *= wr boxes[:, [1, 3]] *= hr # normalize to [0, 1] by model's resolution boxes[:] /= self.config.resolution label["boxes"] = boxes ### modified from https://github.com/roboflow/rf-detr/blob/develop/rfdetr/models/backbone/dinov2_with_windowed_attn.py def _onnx_interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility with the original implementation. Adapted from: - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py """ position_embeddings = self.model.backbone[0].encoder.encoder.embeddings.position_embeddings config = self.model.backbone[0].encoder.encoder.embeddings.config num_patches = embeddings.shape[1] - 1 num_positions = position_embeddings.shape[1] - 1 # Skip interpolation for matching dimensions (unless tracing) if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return position_embeddings # Handle class token and patch embeddings separately class_pos_embed = position_embeddings[:, 0] patch_pos_embed = position_embeddings[:, 1:] dim = embeddings.shape[-1] # Calculate new dimensions height = height // config.patch_size width = width // config.patch_size # Reshape for interpolation sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) # Store original dtype for restoration after interpolation target_dtype = patch_pos_embed.dtype # Interpolate at float32 precision ### disable antialiasing for ONNX export patch_pos_embed = torch.nn.functional.interpolate( patch_pos_embed.to(dtype=torch.float32), size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor mode="bicubic", align_corners=False, antialias=False, ).to(dtype=target_dtype) # Validate output dimensions if not tracing if not torch.jit.is_tracing(): if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: raise ValueError("Width or height does not match with the interpolated position embeddings") # Reshape back to original format patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) # Combine class and patch embeddings return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor=None, labels=None, **kwargs) -> ModelOutput: """ Parameters ---------- pixel_values : torch.Tensor Input tensor representing image pixel values. labels : Optional[List[Dict[str, torch.Tensor | List]]] List of annotations associated with the image or batch of images. 
            For object detection, each annotation should be a dictionary with the following keys:
            - boxes (FloatTensor[N, 4]): the ground-truth boxes in [center_x, center_y, width, height]
              format, in absolute pixel coordinates of the input image
            - class_labels (Int64Tensor[N]): the class label for each ground-truth box

        Returns
        -------
        RFDetrObjectDetectionOutput
            Object containing
            - loss: sum of the focal loss, bounding box loss, and generalized IoU loss
            - loss_dict: dictionary of the individual losses
            - logits
            - pred_boxes
            - aux_outputs
            - enc_outputs
        """
        if torch.jit.is_tracing():
            ### disable antialiasing for ONNX export
            resize = Resize((self.config.resolution, self.config.resolution), antialias=False)
            self.model.backbone[0].encoder.encoder.embeddings.interpolate_pos_encoding = self._onnx_interpolate_pos_encoding
        else:
            resize = Resize((self.config.resolution, self.config.resolution))

        if labels is not None:
            self.validate_labels(labels)
            _, _, h, w = pixel_values.shape
            self.resize_labels(labels, h, w)  # rescale labels to the model's resolution
        else:
            # no labels provided: put the wrapped model and criterion into eval-mode behaviour
            self.model.training = False
            self.model.transformer.training = False
            for layer in self.model.transformer.decoder.layers:
                layer.training = False
            self.criterion.training = False

        # resize pixel values and mask to the model's resolution
        pixel_values = resize(pixel_values)
        if pixel_mask is None:
            pixel_mask = torch.zeros(
                [pixel_values.shape[0], self.config.resolution, self.config.resolution],
                dtype=torch.bool,
            )
        else:
            pixel_mask = resize(pixel_mask)
        samples = NestedTensor(pixel_values, pixel_mask)

        outputs = self.model(samples)

        # compute loss; returns (None, None) when labels are not provided
        loss, loss_dict = self.compute_loss(outputs, labels)

        return RFDetrObjectDetectionOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=outputs["pred_logits"],
            pred_boxes=outputs["pred_boxes"],
            aux_outputs=outputs["aux_outputs"],
            enc_outputs=outputs["enc_outputs"],
        )


__all__ = ["RFDetrModelForObjectDetection"]
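

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It is a hedged
# illustration of how the wrapper might be called, assuming `RFDetrConfig`
# provides sensible defaults that resolve to a valid rfdetr checkpoint; the
# tensor shapes, class ids, and the name `demo_config` are illustrative only.
# Labels follow the format documented in `forward`: absolute [cx, cy, w, h]
# boxes plus integer `class_labels`.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    demo_config = RFDetrConfig(model_name="RFDETRBase")  # assumption: other fields have defaults
    model = RFDetrModelForObjectDetection(demo_config)

    pixel_values = torch.rand(2, 3, 640, 480)  # batch of 2 RGB images, shape (B, C, H, W)
    labels = [
        {
            "boxes": torch.tensor([[240.0, 320.0, 100.0, 80.0]]),  # [cx, cy, w, h] in pixels
            "class_labels": torch.tensor([3]),
        },
        {
            "boxes": torch.tensor([[100.0, 150.0, 60.0, 40.0]]),
            "class_labels": torch.tensor([1]),
        },
    ]

    # training-style call: loss and loss_dict are populated
    outputs = model(pixel_values=pixel_values, labels=labels)
    print(outputs.loss, outputs.logits.shape, outputs.pred_boxes.shape)

    # inference-style call: loss is None; use logits / pred_boxes for post-processing
    with torch.no_grad():
        predictions = model(pixel_values=pixel_values)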