---
license: mit
datasets:
- pawlo2013/EarthLoc_2021_Database
base_model:
- facebook/dinov2-base
pipeline_tag: image-feature-extraction
---

## EarthLoc2 model

EarthLoc2 is a DINOv2 base backbone with a SALAD aggregator; the output descriptor dimension is 3072.

The model was trained on the original EarthLoc dataset (zoom levels 9, 10, and 11) within latitudes -60° to 60°; polar regions are not supported.

Training included additional queries that were not part of the test or validation sets.

It achieves an average R@10 of 90.6 on the original EarthLoc test and validation sets when retrieving against the whole database as is.
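
For readers unfamiliar with the metric: under the usual image-retrieval definition, R@K (recall at K) counts a query as correct when at least one of its top K retrieved database images is a true match, and reports the percentage of correct queries. The sketch below is a minimal illustration of that definition, not the evaluation code behind the numbers in this card; `recall_at_k`, `predictions`, and `positives` are hypothetical names and the data is a toy example.

```python
def recall_at_k(predictions, positives, ks=(1, 10, 100)):
    """Percentage of queries with at least one correct match in the top k.

    predictions: one ranked list of database indices per query (best first).
    positives:   one set of correct database indices per query.
    """
    if len(predictions) != len(positives):
        raise ValueError("predictions and positives must have the same length")
    if not predictions:
        raise ValueError("at least one query is required")

    recalls = {}
    for k in ks:
        # A query is a hit if any of its first k predictions is a true match
        hits = sum(
            1
            for ranked, correct in zip(predictions, positives)
            if any(idx in correct for idx in ranked[:k])
        )
        recalls[k] = 100.0 * hits / len(predictions)
    return recalls


# Toy example with three queries (hypothetical data)
predictions = [[4, 2, 7], [1, 0, 3], [5, 6, 8]]
positives = [{4}, {3}, {9}]
print(recall_at_k(predictions, positives, ks=(1, 3)))
```

An "average" R@K, as reported here, would then be the mean of this value over the individual evaluation sets.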

Training ran for 5000 iterations with a batch size of 96 and a learning rate of 0.0001; only the last block of DINOv2 and the aggregator were trainable.
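
The partial fine-tuning described above (only the last DINOv2 block and the aggregator receive gradients) can be sketched in PyTorch as follows. This is a minimal illustration on a toy stand-in, not the actual training code: `ToyModel`, its `blocks` attribute, and its `aggregator` layer are hypothetical, and the choice of Adam is an assumption. Only the learning rate of 0.0001 comes from the description above.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in: a stack of backbone blocks plus an aggregator head."""

    def __init__(self, dim=8, num_blocks=4, out_dim=16):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])
        self.aggregator = nn.Linear(dim, out_dim)

    def forward(self, x):
        for block in self.blocks:
            x = torch.relu(block(x))
        return self.aggregator(x)


model = ToyModel()

# Freeze everything, then re-enable gradients for the last block and the aggregator
for param in model.parameters():
    param.requires_grad = False
for param in model.blocks[-1].parameters():
    param.requires_grad = True
for param in model.aggregator.parameters():
    param.requires_grad = True

# Hand only the trainable parameters to the optimizer (Adam is an assumption)
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)

num_trainable = sum(p.numel() for p in trainable)
num_total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {num_trainable:,} of {num_total:,}")
```

Passing only the trainable parameters to the optimizer keeps optimizer state small and makes the frozen part of the backbone explicit.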

To use the model's predictions, see the FAISS index (https://huggingface.co/datasets/pawlo2013/EarthLoc2_FAISS), the 2021 database (https://huggingface.co/datasets/pawlo2013/EarthLoc_2021_Database), and the inference space (https://huggingface.co/spaces/pawlo2013/EarthLoc2).
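
Conceptually, retrieval means encoding a query image into a descriptor and looking up its nearest neighbors among the database descriptors, which is the lookup a FAISS index performs at scale. The sketch below illustrates that idea with a brute-force cosine-similarity search in plain Python. It is a stand-in for FAISS, not the FAISS API; the similarity measure is an assumption, and the function names and toy 3-dimensional vectors are hypothetical (the real descriptors have 3072 dimensions).

```python
import heapq
import math


def l2_normalize(vector):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in vector]


def top_k_matches(query, database, k=10):
    """Return (database index, similarity) pairs for the k most similar descriptors."""
    query = l2_normalize(query)
    scored = (
        (index, sum(q * d for q, d in zip(query, l2_normalize(descriptor))))
        for index, descriptor in enumerate(database)
    )
    # Keep the k highest-scoring entries, best first
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Toy database of three descriptors (hypothetical values)
database = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0]]
query = [0.9, 0.1, 0.0]
for index, similarity in top_k_matches(query, database, k=2):
    print(f"database image {index}: similarity {similarity:.3f}")
```

The returned indices would then be mapped back to the corresponding database images and their geographic footprints.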

See EarthLoc for more details about the training, data, and use cases: https://earthloc-and-earthmatch.github.io/

| **Model** | **Average R@1** | **Average R@10** | **Average R@100** |
|------------|-----------------|------------------|-------------------|
| EarthLoc | 50.8 | 65.9 | 80.5 |
| **EarthLoc2** | **79.6** | **90.0** | **95.5** |

*Worldwide search. Results across the evaluation sets when all database images from 2021 are encoded.*

*EarthLoc = ResNet + MixVPR; EarthLoc2 = DINOv2-B + SALAD-B + query data.*

## Loading and Inspecting the DINOv2 Feature Extractor Model

```python
import torch

from model import DINOv2FeatureExtractor

# Use the GPU when available, otherwise fall back to the CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Path to the pretrained weights
MODEL_CHECKPOINT_PATH = './weights/best_model_95.6.torch'

# Initialize the model
model = DINOv2FeatureExtractor(
    model_type="vit_base_patch14_reg4_dinov2.lvd142m",
    num_of_layers_to_unfreeze=0,
    desc_dim=768,
    aggregator_type="SALAD",
)

print('Loading model ...')
# Load the weights, mapping tensors onto the selected device
model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(model_state_dict)

# Move model to device and set to evaluation mode
model = model.to(DEVICE)
model.eval()
print('Model loaded.')

# Print total and trainable parameter counts
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model total parameters: {num_params:,}")
print(f"Model trainable parameters: {num_trainable:,}")

# Print aggregator type
print(f"Aggregator type: {model.aggregator_type}")
```