EarthLoc2 model
EarthLoc2 is a DINOv2 base backbone with a SALAD aggregator; the output descriptor dimension is 3072.
It was trained on the original EarthLoc dataset (zoom levels 9, 10 and 11), covering latitudes from -60 to 60 degrees; polar regions are not supported.
Training also included additional queries that were not part of the test or validation sets.
The model achieves an average R@10 of 90.6 on the original EarthLoc test and validation sets (when retrieving against the whole database as is).
Training ran for 5000 iterations with a batch size of 96 and a learning rate of 0.0001; only the last DINOv2 block and the aggregator were trainable.
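The "last block plus aggregator" setup described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the actual training code: `ToyBackbone`, `ToyModel` and `unfreeze_last_blocks` are hypothetical names, and the toy backbone only mimics a ViT in that it exposes a `blocks` list.

```python
import torch.nn as nn


class ToyBackbone(nn.Module):
    """Hypothetical stand-in for a ViT backbone exposing a `blocks` list."""

    def __init__(self, dim=8, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class ToyModel(nn.Module):
    """Hypothetical backbone + aggregator pair (not the EarthLoc2 classes)."""

    def __init__(self, dim=8, out_dim=16):
        super().__init__()
        self.backbone = ToyBackbone(dim)
        self.aggregator = nn.Linear(dim, out_dim)

    def forward(self, x):
        return self.aggregator(self.backbone(x))


def unfreeze_last_blocks(model, num_blocks=1):
    # Freeze every backbone parameter first.
    for p in model.backbone.parameters():
        p.requires_grad = False
    # Re-enable gradients for the last `num_blocks` blocks only.
    if num_blocks > 0:
        for block in model.backbone.blocks[-num_blocks:]:
            for p in block.parameters():
                p.requires_grad = True
    # The aggregator stays trainable.
    for p in model.aggregator.parameters():
        p.requires_grad = True
    return model


model = unfreeze_last_blocks(ToyModel(), num_blocks=1)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)
```

Passing `num_blocks=0` leaves the whole backbone frozen, which is the natural setting for inference-only use.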
To use the model's predictions, see the FAISS index https://huggingface.co/datasets/pawlo2013/EarthLoc2_FAISS, the 2021 database https://huggingface.co/datasets/pawlo2013/EarthLoc_2021_Database, and the inference space https://huggingface.co/spaces/pawlo2013/EarthLoc2.
See EarthLoc for more details about the training, data and use cases: https://earthloc-and-earthmatch.github.io/
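Conceptually, retrieval compares a query descriptor against every database descriptor and returns the closest entries. The sketch below illustrates that step with brute-force cosine similarity in NumPy instead of FAISS. The random descriptors are placeholders, `l2_normalize` and `top_k` are hypothetical helpers, and the similarity metric used by the published index is an assumption here.

```python
import numpy as np


def l2_normalize(x):
    # Scale each row to unit length so the inner product equals cosine similarity.
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def top_k(query_desc, db_desc, k=10):
    # Similarity of every query to every database descriptor.
    sims = l2_normalize(query_desc) @ l2_normalize(db_desc).T
    # Indices of the k most similar database entries, best first.
    return np.argsort(-sims, axis=1)[:, :k]


rng = np.random.default_rng(0)
# Placeholder database: 1000 random descriptors of dimension 3072.
db = rng.standard_normal((1000, 3072)).astype(np.float32)
# Placeholder queries: noisy copies of database entries 5, 42 and 7.
noise = 0.1 * rng.standard_normal((3, 3072)).astype(np.float32)
queries = db[[5, 42, 7]] + noise

ranked = top_k(queries, db, k=10)
print(ranked[:, 0])  # top-1 database index for each query
```

A FAISS index serves the same purpose at scale; the brute-force version is only practical for small databases.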
| Model | Average R@1 | Average R@10 | Average R@100 |
|---|---|---|---|
| EarthLoc | 50.8 | 65.9 | 80.5 |
| EarthLoc2 | 79.6 | 90.0 | 95.5 |
World-wide search: results across the evaluation sets when all images in the 2021 database are encoded.
EarthLoc = ResNet + MixVPR; EarthLoc2 = DINOv2-B + SALAD-B + query data.
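For readers unfamiliar with the metric, R@K is commonly computed as the percentage of queries for which at least one correct database image appears among the top K retrieved results. The sketch below shows that computation on made-up data; `recall_at_k` is a hypothetical helper, and the exact evaluation protocol (including how the averages are formed) is the one defined by EarthLoc, not this snippet.

```python
def recall_at_k(ranked_ids, positives, k):
    """Percentage of queries with at least one correct match in the top k."""
    hits = sum(
        1
        for preds, pos in zip(ranked_ids, positives)
        if any(p in pos for p in preds[:k])
    )
    return 100.0 * hits / len(ranked_ids)


# Made-up retrieval results: one ranked list of database ids per query.
ranked_ids = [[3, 7, 1], [4, 2, 9], [8, 5, 0], [6, 1, 2]]
# Made-up ground truth: the set of correct database ids for each query.
positives = [{3}, {9}, {11}, {1, 6}]

r_at_1 = recall_at_k(ranked_ids, positives, k=1)
r_at_3 = recall_at_k(ranked_ids, positives, k=3)
print(r_at_1, r_at_3)
```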
Loading and Inspecting the DINOv2 Feature Extractor Model
from model import DINOv2FeatureExtractor
import torch
# Set device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Path to the pretrained weights
MODEL_CHECKPOINT_PATH = './weights/best_model_95.6.torch'
# Initialize the model
model = DINOv2FeatureExtractor(
    model_type="vit_base_patch14_reg4_dinov2.lvd142m",
    num_of_layers_to_unfreeze=0,
    desc_dim=768,
    aggregator_type="SALAD",
)
print('Loading model ...')
# Load weights
model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(model_state_dict)
# Move model to device and set to evaluation mode
model = model.to(DEVICE)
model.eval()
print('Model loaded.')
# Print model parameters info
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model total parameters: {num_params:,}")
print(f"Model trainable parameters: {num_trainable:,}")
# Print aggregator type
print(f"Aggregator type: {model.aggregator_type}")