Feature Extraction
Transformers
Safetensors
custom_code
Files changed (2) hide show
  1. extra_timm_models.py +20 -0
  2. hf_model.py +2 -0
extra_timm_models.py CHANGED
@@ -97,6 +97,26 @@ def vit_so400m_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
97
  return model
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @register_model
101
  def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
102
  """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
 
97
  return model
98
 
99
 
100
+ @register_model
101
+ def vit_so400m_v2_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
102
+ """ ViT model matching the architecture of the So400M model from
103
+ "Getting ViT in Shape: Scaling Laws for Compute-Optimal Model Design" (https://arxiv.org/abs/2305.13035).
104
+ """
105
+ if pretrained:
106
+ raise ValueError('There are no pretrained weights for vit_so400m_v2_patch16_224')
107
+
108
+ normal_target = 4304
109
+ # TP4 requires the channel count to be divisible by 4, and FP8 requires each resulting shard to be a multiple of 8,
110
+ # so the target is rounded up to the next multiple of 4 * 8 = 32.
111
+ tp4_fp8_safe_target = ((normal_target + 31) // 32) * 32
112
+
113
+ mlp_ratio = tp4_fp8_safe_target / 1152
114
+
115
+ model_args = dict(patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=mlp_ratio)
116
+ model = _create_vision_transformer('vit_so400m_v2_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
117
+ return model
118
+
119
+
120
  @register_model
121
  def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
122
  """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
hf_model.py CHANGED
@@ -101,6 +101,8 @@ class RADIOModel(PreTrainedModel):
101
 
102
  def __init__(self, config: RADIOConfig):
103
  super().__init__(config)
 
 
104
 
105
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
106
  args = RADIOArgs(**config.args)
 
101
 
102
  def __init__(self, config: RADIOConfig):
103
  super().__init__(config)
104
+ if hasattr(super(), "post_init"):
105
+ super().post_init()
106
 
107
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
108
  args = RADIOArgs(**config.args)