|
from model import DINOv2FeatureExtractor |
|
import torch |
|
|
|
|
|
# Pick the compute device once for the whole script: the GPU when CUDA is
# available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Checkpoint to restore; a relative path, so it resolves against the current
# working directory the script is launched from.
MODEL_CHECKPOINT_PATH = './weights/best_model_95.6.torch'
|
|
|
|
|
# Build the feature extractor (DINOv2 ViT-B/14 backbone with register tokens,
# SALAD aggregation head). The configuration presumably has to match the one
# used when the checkpoint was saved: load_state_dict below runs with its
# default strict=True, so any architecture mismatch surfaces as a key/shape
# error rather than silently loading partial weights.
model = DINOv2FeatureExtractor(
    model_type="vit_base_patch14_reg4_dinov2.lvd142m",
    num_of_layers_to_unfreeze=0,
    desc_dim=768,
    aggregator_type="SALAD",
)

print('loading model ... ')

# The checkpoint object is handed directly to load_state_dict, so it must be a
# plain mapping of tensors. weights_only=True restricts unpickling to tensors
# and primitive containers, which is all a state dict needs. It also prevents
# a tampered checkpoint from executing arbitrary pickled code. map_location
# lets a GPU-saved checkpoint load on a CPU-only machine.
model_state_dict = torch.load(
    MODEL_CHECKPOINT_PATH, map_location=DEVICE, weights_only=True
)
model.load_state_dict(model_state_dict)

# Move the model to the target device once, then switch it to inference mode.
# eval() makes layers such as dropout/normalisation use their inference
# behaviour.
model = model.to(DEVICE)
model.eval()

print('loaded ....')
|
|
|
|
|
# Report the model's size. The first count covers every parameter; the second
# covers only the parameters still marked requires_grad. Which modules remain
# trainable depends on DINOv2FeatureExtractor's freezing logic (not visible
# here).
num_params = sum(map(torch.Tensor.numel, model.parameters()))
num_trainable = sum(
    map(
        torch.Tensor.numel,
        filter(lambda tensor: tensor.requires_grad, model.parameters()),
    )
)

print(f"Model total parameters: {num_params:,}")
print(f"Model trainable parameters: {num_trainable:,}")

# Echo the aggregator attribute exposed by the model (presumably the
# aggregator_type it was constructed with — confirm in the model class).
print(model.aggregator_type)
|
|
|
|
|
|