Upload model
#4
by
mranzinger
- opened
- extra_timm_models.py +20 -0
- hf_model.py +2 -0
extra_timm_models.py
CHANGED
|
@@ -97,6 +97,26 @@ def vit_so400m_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
|
|
| 97 |
return model
|
| 98 |
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
@register_model
|
| 101 |
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
|
| 102 |
""" ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
| 97 |
return model
|
| 98 |
|
| 99 |
|
| 100 |
+
@register_model
|
| 101 |
+
def vit_so400m_v2_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
|
| 102 |
+
""" ViT model matching the architecture of the So400M model from
|
| 103 |
+
"Scaling Vision Transformers to 400 Million Parameters" (https://arxiv.org/abs/2302.05442).
|
| 104 |
+
"""
|
| 105 |
+
if pretrained:
|
| 106 |
+
raise ValueError('There is no pretrained weights for vit_so400m_patch16_224')
|
| 107 |
+
|
| 108 |
+
normal_target = 4304
|
| 109 |
+
# TP4 requires channels to be a multiple of 4, and then within that, FP8 requires a multiple of 8,
|
| 110 |
+
# thus, a multiple of 32 is required.
|
| 111 |
+
tp4_fp8_safe_target = ((normal_target + 31) // 32) * 32
|
| 112 |
+
|
| 113 |
+
mlp_ratio = tp4_fp8_safe_target / 1152
|
| 114 |
+
|
| 115 |
+
model_args = dict(patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=mlp_ratio)
|
| 116 |
+
model = _create_vision_transformer('vit_so400m_v2_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
|
| 117 |
+
return model
|
| 118 |
+
|
| 119 |
+
|
| 120 |
@register_model
|
| 121 |
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
|
| 122 |
""" ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
|
hf_model.py
CHANGED
|
@@ -101,6 +101,8 @@ class RADIOModel(PreTrainedModel):
|
|
| 101 |
|
| 102 |
def __init__(self, config: RADIOConfig):
|
| 103 |
super().__init__(config)
|
|
|
|
|
|
|
| 104 |
|
| 105 |
RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
|
| 106 |
args = RADIOArgs(**config.args)
|
|
|
|
| 101 |
|
| 102 |
def __init__(self, config: RADIOConfig):
|
| 103 |
super().__init__(config)
|
| 104 |
+
if hasattr(super(), "post_init"):
|
| 105 |
+
super().post_init()
|
| 106 |
|
| 107 |
RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
|
| 108 |
args = RADIOArgs(**config.args)
|