# rf-detr-base / configuration_rf_detr.py
# Uploaded by Thastp, commit bb67c7f ("Upload model", verified).
from typing import Dict, List, Literal, Optional, OrderedDict

from transformers.configuration_utils import PretrainedConfig
from optimum.exporters.onnx.model_configs import ViTOnnxConfig
### modified from https://github.com/roboflow/rf-detr/blob/main/rfdetr/config.py
class RFDetrConfig(PretrainedConfig):
    """Configuration for RF-DETR models, built on ``PretrainedConfig``.

    Every constructor argument is stored as a same-named attribute. In
    addition, ``model_name`` selects an architecture preset that sets
    ``encoder``, ``hidden_dim``, ``sa_nheads``, ``ca_nheads``,
    ``dec_n_points``, ``projector_scale`` and ``pretrain_weights``.

    Args:
        model_name: Architecture preset, either ``'RFDETRBase'`` or
            ``'RFDETRLarge'``.
        pretrained: When False, ``pretrain_weights`` is forced to ``None``.
            When True, it keeps the preset's checkpoint filename
            (``'rf-detr-base.pth'`` or ``'rf-detr-large.pth'``).
        out_feature_indexes: Stored as a fresh list. ``None`` (the default)
            means ``[2, 5, 8, 11]``.
        dec_layers, two_stage, bbox_reparam, lite_refpoint_refine,
        layer_norm, amp, num_classes, num_queries, resolution, group_detr,
        gradient_checkpointing: Stored unchanged. Their meaning is defined
            by the model code that consumes this config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``model_name`` is not one of the known presets.
    """

    model_type = 'rf-detr'

    def __init__(
        self,
        model_name: Literal['RFDETRBase', 'RFDETRLarge'] = 'RFDETRBase',
        pretrained: bool = False,
        out_feature_indexes: Optional[List[int]] = None,
        dec_layers: int = 3,
        two_stage: bool = True,
        bbox_reparam: bool = True,
        lite_refpoint_refine: bool = True,
        layer_norm: bool = True,
        amp: bool = True,
        num_classes: int = 90,
        num_queries: int = 300,
        resolution: int = 560,
        group_detr: int = 13,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        self.model_name = model_name
        self.pretrained = pretrained
        # Copy so instances never share a default list or alias the caller's list.
        self.out_feature_indexes = (
            [2, 5, 8, 11] if out_feature_indexes is None else list(out_feature_indexes)
        )
        self.dec_layers = dec_layers
        self.two_stage = two_stage
        self.bbox_reparam = bbox_reparam
        self.lite_refpoint_refine = lite_refpoint_refine
        self.layer_norm = layer_norm
        self.amp = amp
        self.num_classes = num_classes
        self.resolution = resolution
        self.group_detr = group_detr
        self.gradient_checkpointing = gradient_checkpointing
        self.num_queries = num_queries

        # Architecture presets keyed by model_name.
        if model_name == 'RFDETRBase':
            self.encoder = "dinov2_windowed_small"
            self.hidden_dim = 256
            self.sa_nheads = 8
            self.ca_nheads = 16
            self.dec_n_points = 2
            self.projector_scale = ["P4"]
            self.pretrain_weights = "rf-detr-base.pth"
        elif model_name == 'RFDETRLarge':
            self.encoder = "dinov2_windowed_base"
            self.hidden_dim = 384
            self.sa_nheads = 12
            self.ca_nheads = 24
            self.dec_n_points = 4
            self.projector_scale = ["P3", "P5"]
            self.pretrain_weights = "rf-detr-large.pth"
        else:
            # Without a preset the architecture attributes would be left
            # undefined and only fail later, far from the bad input.
            raise ValueError(
                f"Unknown model_name {model_name!r}; "
                "expected 'RFDETRBase' or 'RFDETRLarge'."
            )

        if not self.pretrained:
            self.pretrain_weights = None

        # Kept last, as in the original. Attributes set by the parent from
        # **kwargs (presumably including extra keys restored from a saved
        # config) therefore take precedence over the presets above. Confirm
        # against the transformers version in use.
        super().__init__(**kwargs)
class RFDetrOnnxConfig(ViTOnnxConfig):
    """ONNX export configuration for RF-DETR, derived from optimum's ``ViTOnnxConfig``.

    Declares the dynamic axes of the ``pixel_values`` input. For the
    ``object-detection`` task it also declares the dynamic axes of the
    ``logits`` and ``pred_boxes`` outputs.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return dynamic-axis names for the single ``pixel_values`` input (axes 0-3)."""
        pixel_axes = {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict({"pixel_values": pixel_axes})

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the parent's output axes, with RF-DETR detection heads added.

        For any task other than ``object-detection``, the parent's mapping is
        returned untouched. Otherwise the ``logits`` and ``pred_boxes``
        entries are (re)assigned, replacing whatever the parent declared for
        those keys.
        """
        axes_by_output = super().outputs
        if self.task != "object-detection":
            return axes_by_output

        axes_by_output["logits"] = {0: "batch_size", 1: "num_queries", 2: "num_classes"}
        # NOTE(review): axis 2 of pred_boxes is declared dynamic under the name
        # "4". If the box dimension is meant to be a fixed size of 4, it should
        # probably not be listed as a dynamic axis at all. Confirm.
        axes_by_output["pred_boxes"] = {0: "batch_size", 1: "num_queries", 2: "4"}
        return axes_by_output
# Public API exported by `from <module> import *`.
__all__ = [
    "RFDetrConfig",
    "RFDetrOnnxConfig",
]