---
license: mit
datasets:
- pawlo2013/EarthLoc_2021_Database
base_model:
- facebook/dinov2-base
pipeline_tag: image-feature-extraction
---
## EarthLoc2 model

EarthLoc2 is a DINOv2-base backbone with a SALAD aggregator, producing 3072-dimensional descriptors.
It was trained on the original EarthLoc dataset (zoom levels 9, 10, and 11), covering latitudes between -60 and 60; polar regions are not supported.
Training included additional queries that were not part of the test/val sets.
The model achieves an average R@10 of 90.6 on the original EarthLoc test and validation sets (when retrieving against the whole database as-is).
Training ran for 5000 iterations with a batch size of 96 and a learning rate of 0.0001, with only the last block of DINOv2 and the aggregator trainable.

To use the model's predictions, see the FAISS index https://huggingface.co/datasets/pawlo2013/EarthLoc2_FAISS, the 2021 database https://huggingface.co/datasets/pawlo2013/EarthLoc_2021_Database, and the inference space https://huggingface.co/spaces/pawlo2013/EarthLoc2; a hedged retrieval sketch is also included after the loading example below.
See the EarthLoc project page for more details about the training, data, and use cases: https://earthloc-and-earthmatch.github.io/

| **Model** | **Average R@1** | **Average R@10** | **Average R@100** |
|------------|-----------------|------------------|-------------------|
| EarthLoc | 50.8 | 65.9 | 80.5 |
| **EarthLoc2** | **79.6** | **90.0** | **95.5** |
*Wide-world search: results across the evaluation sets when every image in the 2021 database is encoded.*
*EarthLoc = ResNet + MixVPR; EarthLoc2 = DINOv2-B + SALAD-B + query data.*
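
For reference, Recall@K counts a query as correct when at least one of its top-K retrieved database images is a true match. Below is a minimal sketch of the metric; the function and variable names are illustrative and not part of this repo.

```python
import numpy as np

def recall_at_k(retrieved: np.ndarray, positives: list, k: int) -> float:
    """retrieved: (num_queries, >=k) array of ranked database indices per query.
    positives: per-query set of database indices that count as correct."""
    hits = sum(
        bool(set(retrieved[q, :k]) & positives[q])
        for q in range(retrieved.shape[0])
    )
    return 100.0 * hits / retrieved.shape[0]

# Toy example: 2 queries with their top-3 retrieved database indices.
retrieved = np.array([[4, 7, 1], [3, 0, 9]])
positives = [{7}, {5}]                # query 0 hits at rank 2, query 1 misses
print(recall_at_k(retrieved, positives, k=3))  # 50.0
```
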
## Loading and Inspecting the DINOv2 Feature Extractor Model
```python
from model import DINOv2FeatureExtractor
import torch

# Set device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Path to the pretrained weights
MODEL_CHECKPOINT_PATH = './weights/best_model_95.6.torch'

# Initialize the model
model = DINOv2FeatureExtractor(
    model_type="vit_base_patch14_reg4_dinov2.lvd142m",
    num_of_layers_to_unfreeze=0,
    desc_dim=768,
    aggregator_type="SALAD",
)

print('Loading model ...')

# Load weights
model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(model_state_dict)

# Move model to device and set to evaluation mode
model = model.to(DEVICE)
model.eval()
print('Model loaded.')

# Print model parameters info
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model total parameters: {num_params:,}")
print(f"Model trainable parameters: {num_trainable:,}")

# Print aggregator type
print(f"Aggregator type: {model.aggregator_type}")
```