from typing import Dict, Literal, List, Optional, OrderedDict

from transformers.configuration_utils import PretrainedConfig
from optimum.exporters.onnx.model_configs import ViTOnnxConfig


### modified from https://github.com/roboflow/rf-detr/blob/main/rfdetr/config.py
class RFDetrConfig(PretrainedConfig):
    """Configuration object for RF-DETR detection models.

    ``model_name`` selects a preset that fills in the architecture
    attributes ``encoder``, ``hidden_dim``, ``sa_nheads``, ``ca_nheads``,
    ``dec_n_points``, ``projector_scale`` and ``pretrain_weights``. All
    other constructor arguments are stored as-is, and any extra keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = 'rf-detr'

    def __init__(
        self,
        model_name: Literal['RFDETRBase', 'RFDETRLarge'] = 'RFDETRBase',
        pretrained: bool = False,
        out_feature_indexes: Optional[List[int]] = None,
        dec_layers: int = 3,
        two_stage: bool = True,
        bbox_reparam: bool = True,
        lite_refpoint_refine: bool = True,
        layer_norm: bool = True,
        amp: bool = True,
        num_classes: int = 90,
        num_queries: int = 300,
        resolution: int = 560,
        group_detr: int = 13,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        """Build the config and apply the preset for ``model_name``.

        Args:
            model_name: ``'RFDETRBase'`` or ``'RFDETRLarge'``; selects the
                architecture preset. Any other value skips the presets, so
                the preset attributes are only present if supplied via
                ``kwargs`` (assumption: ``PretrainedConfig`` stores extra
                kwargs as attributes — confirm against transformers).
            pretrained: When False, ``pretrain_weights`` is forced to None.
            out_feature_indexes: Indices stored on the config; defaults to
                ``[2, 5, 8, 11]``. A fresh list is stored per instance.
            dec_layers, two_stage, bbox_reparam, lite_refpoint_refine,
            layer_norm, amp, num_classes, num_queries, resolution,
            group_detr, gradient_checkpointing: Stored unchanged as
                same-named attributes.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        # Build a new list per instance instead of sharing one mutable default.
        if out_feature_indexes is None:
            out_feature_indexes = [2, 5, 8, 11]
        else:
            out_feature_indexes = list(out_feature_indexes)

        self.model_name = model_name
        self.pretrained = pretrained
        self.out_feature_indexes = out_feature_indexes
        self.dec_layers = dec_layers
        self.two_stage = two_stage
        self.bbox_reparam = bbox_reparam
        self.lite_refpoint_refine = lite_refpoint_refine
        self.layer_norm = layer_norm
        self.amp = amp
        self.num_classes = num_classes
        self.resolution = resolution
        self.group_detr = group_detr
        self.gradient_checkpointing = gradient_checkpointing
        self.num_queries = num_queries

        if self.model_name == 'RFDETRBase':
            self.encoder = "dinov2_windowed_small"
            self.hidden_dim = 256
            self.sa_nheads = 8
            self.ca_nheads = 16
            self.dec_n_points = 2
            self.projector_scale = ["P4"]
            self.pretrain_weights = "rf-detr-base.pth"
        elif self.model_name == 'RFDETRLarge':
            self.encoder = "dinov2_windowed_base"
            self.hidden_dim = 384
            self.sa_nheads = 12
            self.ca_nheads = 24
            self.dec_n_points = 4
            self.projector_scale = ["P3", "P5"]
            self.pretrain_weights = "rf-detr-large.pth"
        # NOTE(review): any other model_name leaves the preset attributes
        # above unset here; with pretrained=True, pretrain_weights is then
        # only defined if provided through kwargs.

        # Without pretraining, no checkpoint file is referenced.
        if not self.pretrained:
            self.pretrain_weights = None

        super().__init__(**kwargs)


class RFDetrOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for RF-DETR, based on optimum's ``ViTOnnxConfig``."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Declare the single ``pixel_values`` input and its dynamic axes."""
        return OrderedDict(
            {
                "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
            }
        )

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent's outputs, with detection axes for ``object-detection``.

        For the ``object-detection`` task, ``logits`` and ``pred_boxes`` are
        (re)declared with per-query dynamic axes. Other tasks keep the parent's
        outputs unchanged.
        """
        common_outputs = super().outputs
        if self.task == "object-detection":
            common_outputs["logits"] = {0: "batch_size", 1: "num_queries", 2: "num_classes"}
            # NOTE(review): the last pred_boxes axis is declared dynamic under
            # the name "4"; presumably it is always the 4 box coordinates —
            # confirm whether it should be left static instead.
            common_outputs["pred_boxes"] = {0: "batch_size", 1: "num_queries", 2: "4"}
        return common_outputs


__all__ = [
    'RFDetrConfig',
    'RFDetrOnnxConfig'
]