| | import open_clip |
| | import torch |
| | import os |
| | import random |
| | import numpy as np |
| | import argparse |
| | from inference_tool import (zeroshot_evaluation, |
| | retrieval_evaluation, |
| | semantic_localization_evaluation, |
| | get_preprocess |
| | ) |
| |
|
| |
|
def random_seed(seed: int, *, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Seed applied to every random number generator.
        deterministic: Controls the cuDNN backend flags. With the default
            (False) cuDNN autotuning is enabled (``benchmark=True``) and
            non-deterministic kernels are allowed. This is faster, but GPU
            results may still vary between runs despite the fixed seed. Pass
            True to request deterministic cuDNN kernels and disable autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two flags are deliberately opposite: autotuning picks kernels per run,
    # which defeats determinism.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic
| |
|
| |
|
def build_model(model_name, ckpt_path, device):
    """Create an open_clip model, load RS5M weights into it and move it to ``device``.

    Args:
        model_name: Either ``"ViT-B-32"`` or ``"ViT-H-14"``.
        ckpt_path: Checkpoint file read with ``torch.load``. Its contents are
            passed straight to ``model.load_state_dict`` with the default
            strict matching, so the file must be a state dict for the chosen
            architecture.
        device: Target passed to ``model.to``, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        A ``(model, preprocess_val)`` tuple. ``preprocess_val`` is the transform
        returned by ``get_preprocess(image_resolution=224)``.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    # Maps each CLI model name to (open_clip architecture, open_clip pretrained tag).
    # The pretrained weights are overwritten by the checkpoint below, but the tag
    # is kept. With open_clip it can also affect the instantiated architecture
    # (e.g. the OpenAI weights use QuickGELU), presumably matching the variant
    # the RS5M checkpoint was fine-tuned from.
    model_specs = {
        "ViT-B-32": ("ViT-B/32", "openai"),
        "ViT-H-14": ("ViT-H/14", "laion2b_s32b_b79k"),
    }
    try:
        arch, pretrained = model_specs[model_name]
    except KeyError:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of {sorted(model_specs)}"
        ) from None

    model, _, _ = open_clip.create_model_and_transforms(arch, pretrained=pretrained)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    msg = model.load_state_dict(checkpoint)
    print(msg)
    model = model.to(device)
    print("loaded RSCLIP")

    preprocess_val = get_preprocess(
        image_resolution=224,
    )
    return model, preprocess_val
| |
|
| |
|
def evaluate(
    model,
    preprocess,
    args,
    *,
    zeroshot_datasets=("EuroSAT", "RESISC45", "AID"),
    retrieval_datasets=("rsitmd", "rsicd"),
    selo_datasets=("AIR-SLT",),
):
    """Run zero-shot, retrieval and semantic-localization evaluations and merge the metrics.

    The model is switched to eval mode first. Each benchmark is then run in
    order: zero-shot datasets, then retrieval datasets, then
    semantic-localization datasets. After each dataset the accumulated metrics
    dict is printed.

    Args:
        model: Model under test. ``model.eval()`` is called on it.
        preprocess: Image transform. It is printed, then forwarded to every
            evaluation helper.
        args: Forwarded unchanged to every evaluation helper. In this file,
            ``main`` passes its argparse namespace.
        zeroshot_datasets: Dataset names passed to ``zeroshot_evaluation``.
        retrieval_datasets: ``dataset_name`` values passed to
            ``retrieval_evaluation``, each with ``recall_k_list=[1, 5, 10]``.
        selo_datasets: Dataset names passed to
            ``semantic_localization_evaluation``.

    Returns:
        A dict merging every helper's returned metrics via ``dict.update``.
        If two helpers return the same key, the later value wins.
    """
    print("making val dataset with transformation: ")
    print(preprocess)

    model.eval()
    all_metrics = {}

    for dataset_name in zeroshot_datasets:
        all_metrics.update(zeroshot_evaluation(model, dataset_name, preprocess, args))
        print(all_metrics)

    for dataset_name in retrieval_datasets:
        # A fresh list is built per call in case the helper mutates it.
        all_metrics.update(
            retrieval_evaluation(
                model, preprocess, args,
                recall_k_list=[1, 5, 10],
                dataset_name=dataset_name,
            )
        )
        print(all_metrics)

    for dataset_name in selo_datasets:
        all_metrics.update(
            semantic_localization_evaluation(model, dataset_name, preprocess, args)
        )
        print(all_metrics)

    return all_metrics
| |
|
| |
|
def main():
    """Evaluate an RS5M model from the command line.

    Steps: parse CLI arguments, pick CUDA when available, seed the RNGs, build
    the model, run every evaluation and print each metric as ``key: value``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-name", default="ViT-B-32", type=str,
        choices=["ViT-B-32", "ViT-H-14"],
        help="ViT-B-32 or ViT-H-14",
    )
    parser.add_argument(
        "--ckpt-path", default="/home/zilun/RS5M_v5/ckpt/RS5M_ViT-B-32.pt", type=str,
        help="Path to the RS5M checkpoint matching --model-name (e.g. RS5M_ViT-B-32.pt)",
    )
    parser.add_argument(
        "--random-seed", default=3407, type=int,
        help="random seed",
    )
    parser.add_argument(
        "--test-dataset-dir", default="/home/zilun/RS5M_v5/data/rs5m_test_data", type=str,
        help="test dataset dir",
    )
    parser.add_argument(
        "--batch-size", default=500, type=int,
        help="batch size",
    )
    parser.add_argument(
        "--workers", default=8, type=int,
        help="number of workers",
    )
    args = parser.parse_args()
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(args)

    # Apply --random-seed before any model or data work so it actually takes effect.
    random_seed(args.random_seed)

    model, img_preprocess = build_model(args.model_name, args.ckpt_path, args.device)
    eval_result = evaluate(model, img_preprocess, args)

    for key, value in eval_result.items():
        print(f"{key}: {value}")


if __name__ == "__main__":
    main()
| |
|