"""Download EdgeFace and SlimFace checkpoints from the Hugging Face Hub.

Run without arguments to fetch every file listed in ``model_configs``, or pass
``--model NAME`` to fetch a single entry. The process exits with status 1 if
any requested file could not be downloaded.
"""

import argparse
import os
import sys
from typing import List, Optional

from huggingface_hub import snapshot_download

# Model configurations for EdgeFace models and the SlimFace sample checkpoints.
# Each entry names the Hub repository ("repo"), the single file to fetch from
# it ("filename"), and the local directory the file is written into
# ("local_dir").
model_configs = {
    "edgeface_base": {
        "repo": "idiap/EdgeFace-Base",
        "filename": "edgeface_base.pt",
        "local_dir": "ckpts/idiap",
    },
    "edgeface_s_gamma_05": {
        "repo": "idiap/EdgeFace-S-GAMMA",
        "filename": "edgeface_s_gamma_05.pt",
        "local_dir": "ckpts/idiap",
    },
    "edgeface_xs_gamma_06": {
        "repo": "idiap/EdgeFace-XS-GAMMA",
        "filename": "edgeface_xs_gamma_06.pt",
        "local_dir": "ckpts/idiap",
    },
    "edgeface_xxs": {
        "repo": "idiap/EdgeFace-XXS",
        "filename": "edgeface_xxs.pt",
        "local_dir": "ckpts/idiap",
    },
    "SlimFace_efficientnet_b3": {
        "repo": "danhtran2mind/SlimFace-sample-checkpoints",
        "filename": "SlimFace_efficientnet_b3_full_model.pth",
        "local_dir": "ckpts",
    },
    "SlimFace_efficientnet_v2_s": {
        "repo": "danhtran2mind/SlimFace-sample-checkpoints",
        "filename": "SlimFace_efficientnet_v2_s_full_model.pth",
        "local_dir": "ckpts",
    },
    "SlimFace_regnet_y_800mf": {
        "repo": "danhtran2mind/SlimFace-sample-checkpoints",
        "filename": "SlimFace_regnet_y_800mf_full_model.pth",
        "local_dir": "ckpts",
    },
    "SlimFace_vit_b_16": {
        "repo": "danhtran2mind/SlimFace-sample-checkpoints",
        "filename": "SlimFace_vit_b_16_full_model.pth",
        "local_dir": "ckpts",
    },
    "SlimFace_mapping": {
        "repo": "danhtran2mind/SlimFace-sample-checkpoints",
        "filename": "index_to_class_mapping.json",
        "local_dir": "ckpts",
    },
}


def download_models(
    model_name: Optional[str] = None, *, revision: str = "main"
) -> List[str]:
    """Download specified models from model_configs to their respective local directories.

    Each download failure is reported on stderr and the remaining downloads
    still run; callers can inspect the returned list to detect failures.

    Args:
        model_name: Specific model key from ``model_configs`` to download.
            If None, all models are downloaded.
        revision: Hub revision (branch, tag, or commit hash) to download from.
            Defaults to ``"main"``.

    Returns:
        The filenames that could not be downloaded (empty on full success).

    Raises:
        ValueError: If ``model_name`` is not a key of ``model_configs``.
    """
    # Compare against None explicitly so that an empty string is rejected
    # as an unknown model instead of silently meaning "download everything".
    if model_name is not None:
        if model_name not in model_configs:
            raise ValueError(
                f"Model {model_name} not found in available models: {list(model_configs.keys())}"
            )
        configs_to_download = [model_configs[model_name]]
    else:
        configs_to_download = list(model_configs.values())

    failed: List[str] = []
    for config in configs_to_download:
        repo_id = config["repo"]
        filename = config["filename"]
        local_dir = config["local_dir"]

        # Ensure the local directory exists
        os.makedirs(local_dir, exist_ok=True)

        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,
                # NOTE: deprecated (and ignored) in newer huggingface_hub
                # releases; kept so older releases still copy real files.
                local_dir_use_symlinks=False,
                allow_patterns=[filename],
                cache_dir=None,
                revision=revision,
            )
        except Exception as e:  # CLI boundary: report and continue with the rest
            print(f"Error downloading {filename}: {e}", file=sys.stderr)
            failed.append(filename)
            continue

        # snapshot_download downloads nothing (without raising) when
        # allow_patterns matches no file in the repo, so confirm the file
        # actually exists before reporting success.
        if not os.path.isfile(os.path.join(local_dir, filename)):
            print(
                f"Error downloading {filename}: file not found in {repo_id} at revision {revision}",
                file=sys.stderr,
            )
            failed.append(filename)
            continue

        print(f"Downloaded {filename} to {local_dir}")

    return failed


def main() -> int:
    """Parse command-line arguments and initiate model download.

    Returns:
        Process exit status: 0 if every requested file was downloaded, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Download models from Hugging Face Hub.")
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        choices=list(model_configs.keys()),
        help="Specific model to download. If not provided, all models are downloaded.",
    )
    args = parser.parse_args()
    failed = download_models(args.model)
    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(main())