"""Build a DINOv2-based feature extractor, restore its weights, and report its size.

Constructs ``DINOv2FeatureExtractor`` (ViT-B/14 with registers backbone,
SALAD aggregator, 768-dim descriptor), loads the state dict stored at
``MODEL_CHECKPOINT_PATH``, moves the model to ``DEVICE`` in eval mode,
then prints total/trainable parameter counts and the aggregator type.
"""
import torch

from model import DINOv2FeatureExtractor

# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_CHECKPOINT_PATH = './weights/best_model_95.6.torch'

model = DINOv2FeatureExtractor(
    model_type="vit_base_patch14_reg4_dinov2.lvd142m",
    num_of_layers_to_unfreeze=0,
    desc_dim=768,
    aggregator_type="SALAD",
)

print('loading model ... ')
# NOTE(review): torch.load unpickles the checkpoint, so only load files you
# trust. On torch >= 1.13, consider passing weights_only=True to restrict
# loading to tensors and plain containers.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(model_state_dict)
# Move to the target device once, after the weights are in place, and switch
# to inference mode.
model = model.to(DEVICE)
model.eval()
print('loaded ....')

# Report model size. requires_grad reflects whatever freezing the constructor
# applied (presumably driven by num_of_layers_to_unfreeze; verify in model.py).
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model total parameters: {num_params:,}")
print(f"Model trainable parameters: {num_trainable:,}")
print(model.aggregator_type)