|
from typing import Dict, Literal, List, OrderedDict
|
|
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
from optimum.exporters.onnx.model_configs import ViTOnnxConfig
|
|
|
|
|
|
|
|
class RFDetrConfig(PretrainedConfig):
    """Configuration for RF-DETR models.

    ``model_name`` selects one of two built-in presets. The preset fixes the
    ``encoder``, ``hidden_dim``, ``sa_nheads``, ``ca_nheads``,
    ``dec_n_points``, ``projector_scale`` and ``pretrain_weights``
    attributes. Every other named constructor argument is stored unchanged
    under the attribute of the same name; its meaning is defined by the code
    that consumes this config.

    Args:
        model_name: Preset to use, either ``'RFDETRBase'`` or
            ``'RFDETRLarge'``.
        pretrained: When false, ``pretrain_weights`` is set to ``None``
            instead of the preset's checkpoint file name.
        out_feature_indexes: Stored as a fresh list copy so that instances
            never share (and cannot corrupt) the default list object.
        dec_layers: Stored as ``self.dec_layers``.
        two_stage: Stored as ``self.two_stage``.
        bbox_reparam: Stored as ``self.bbox_reparam``.
        lite_refpoint_refine: Stored as ``self.lite_refpoint_refine``.
        layer_norm: Stored as ``self.layer_norm``.
        amp: Stored as ``self.amp``.
        num_classes: Stored as ``self.num_classes``.
        num_queries: Stored as ``self.num_queries``.
        resolution: Stored as ``self.resolution``.
        group_detr: Stored as ``self.group_detr``.
        gradient_checkpointing: Stored as ``self.gradient_checkpointing``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported presets.
    """

    model_type = 'rf-detr'

    def __init__(
        self,
        model_name: Literal['RFDETRBase', 'RFDETRLarge'] = 'RFDETRBase',
        pretrained: bool = False,
        out_feature_indexes: List[int] = [2, 5, 8, 11],
        dec_layers: int = 3,
        two_stage: bool = True,
        bbox_reparam: bool = True,
        lite_refpoint_refine: bool = True,
        layer_norm: bool = True,
        amp: bool = True,
        num_classes: int = 90,
        num_queries: int = 300,
        resolution: int = 560,
        group_detr: int = 13,
        gradient_checkpointing: bool = False,
        **kwargs
    ):
        # Fail fast: an unknown preset would otherwise leave the config
        # without encoder/hidden_dim/... and only blow up much later.
        if model_name not in ('RFDETRBase', 'RFDETRLarge'):
            raise ValueError(
                f"Unknown model_name {model_name!r}; "
                "expected 'RFDETRBase' or 'RFDETRLarge'."
            )

        self.model_name = model_name
        self.pretrained = pretrained
        # Copy so the signature's default list is never aliased by an
        # instance attribute (mutating it would leak into later instances).
        self.out_feature_indexes = list(out_feature_indexes)
        self.dec_layers = dec_layers
        self.two_stage = two_stage
        self.bbox_reparam = bbox_reparam
        self.lite_refpoint_refine = lite_refpoint_refine
        self.layer_norm = layer_norm
        self.amp = amp
        self.num_classes = num_classes
        self.resolution = resolution
        self.group_detr = group_detr
        self.gradient_checkpointing = gradient_checkpointing
        self.num_queries = num_queries

        # Preset-specific architecture settings.
        if self.model_name == 'RFDETRBase':
            self.encoder = "dinov2_windowed_small"
            self.hidden_dim = 256
            self.sa_nheads = 8
            self.ca_nheads = 16
            self.dec_n_points = 2
            self.projector_scale = ["P4"]
            self.pretrain_weights = "rf-detr-base.pth"
        else:  # 'RFDETRLarge' (validated above)
            self.encoder = "dinov2_windowed_base"
            self.hidden_dim = 384
            self.sa_nheads = 12
            self.ca_nheads = 24
            self.dec_n_points = 4
            self.projector_scale = ["P3", "P5"]
            self.pretrain_weights = "rf-detr-large.pth"

        # Without pretraining there is no checkpoint to point at.
        if not self.pretrained:
            self.pretrain_weights = None

        super().__init__(**kwargs)
|
|
|
|
|
|
class RFDetrOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for RF-DETR, built on ``ViTOnnxConfig``."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the input spec: a single ``pixel_values`` entry.

        The entry maps each of the four axis indices to its axis name.
        """
        pixel_axes = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict({"pixel_values": pixel_axes})

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the output spec inherited from the parent config.

        For the ``"object-detection"`` task the ``logits`` and
        ``pred_boxes`` entries are (re)defined with three named axes each;
        for any other task the parent's mapping is returned untouched.
        """
        axes_by_output = super().outputs

        if self.task != "object-detection":
            return axes_by_output

        # Both detection heads share the first two axes; only the last differs.
        for output_name, last_axis in (("logits", "num_classes"), ("pred_boxes", "4")):
            axes_by_output[output_name] = {0: "batch_size", 1: "num_queries", 2: last_axis}

        return axes_by_output
|
|
|
|
|
|
# Public names exported by this module.
__all__ = ['RFDetrConfig', 'RFDetrOnnxConfig']