Spaces:
Running
Running
import os | |
import argparse | |
from huggingface_hub import snapshot_download | |
# Model configurations for EdgeFace models | |
model_configs = { | |
"edgeface_base": { | |
"repo": "idiap/EdgeFace-Base", | |
"filename": "edgeface_base.pt", | |
"local_dir": "ckpts/idiap" | |
}, | |
"edgeface_s_gamma_05": { | |
"repo": "idiap/EdgeFace-S-GAMMA", | |
"filename": "edgeface_s_gamma_05.pt", | |
"local_dir": "ckpts/idiap" | |
}, | |
"edgeface_xs_gamma_06": { | |
"repo": "idiap/EdgeFace-XS-GAMMA", | |
"filename": "edgeface_xs_gamma_06.pt", | |
"local_dir": "ckpts/idiap" | |
}, | |
"edgeface_xxs": { | |
"repo": "idiap/EdgeFace-XXS", | |
"filename": "edgeface_xxs.pt", | |
"local_dir": "ckpts/idiap" | |
}, | |
"SlimFace_efficientnet_b3": { | |
"repo": "danhtran2mind/SlimFace-sample-checkpoints", | |
"filename": "SlimFace_efficientnet_b3_full_model.pth", | |
"local_dir": "ckpts" | |
}, | |
"SlimFace_efficientnet_v2_s": { | |
"repo": "danhtran2mind/SlimFace-sample-checkpoints", | |
"filename": "SlimFace_efficientnet_v2_s_full_model.pth", | |
"local_dir": "ckpts" | |
}, | |
"SlimFace_regnet_y_800mf": { | |
"repo": "danhtran2mind/SlimFace-sample-checkpoints", | |
"filename": "SlimFace_regnet_y_800mf_full_model.pth", | |
"local_dir": "ckpts" | |
}, | |
"SlimFace_vit_b_16": { | |
"repo": "danhtran2mind/SlimFace-sample-checkpoints", | |
"filename": "SlimFace_vit_b_16_full_model.pth", | |
"local_dir": "ckpts" | |
}, | |
"SlimFace_mapping": { | |
"repo": "danhtran2mind/SlimFace-sample-checkpoints", | |
"filename": "index_to_class_mapping.json", | |
"local_dir": "ckpts" | |
} | |
} | |
def download_models(model_name=None): | |
"""Download specified models from model_configs to their respective local directories. | |
Args: | |
model_name (str, optional): Specific model to download. If None, download all models. | |
""" | |
# Determine files to download | |
if model_name: | |
if model_name not in model_configs: | |
raise ValueError(f"Model {model_name} not found in available models: {list(model_configs.keys())}") | |
configs_to_download = [model_configs[model_name]] | |
else: | |
configs_to_download = list(model_configs.values()) | |
for config in configs_to_download: | |
repo_id = config["repo"] | |
filename = config["filename"] | |
local_dir = config["local_dir"] | |
# Ensure the local directory exists | |
os.makedirs(local_dir, exist_ok=True) | |
try: | |
snapshot_download( | |
repo_id=repo_id, | |
local_dir=local_dir, | |
local_dir_use_symlinks=False, | |
allow_patterns=[filename], | |
cache_dir=None, | |
revision="main" | |
) | |
print(f"Downloaded {filename} to {local_dir}") | |
except Exception as e: | |
print(f"Error downloading {filename}: {e}") | |
def main(): | |
"""Parse command-line arguments and initiate model download.""" | |
parser = argparse.ArgumentParser(description="Download models from Hugging Face Hub.") | |
parser.add_argument( | |
"--model", | |
type=str, | |
default=None, | |
choices=list(model_configs.keys()), | |
help="Specific model to download. If not provided, all models are downloaded." | |
) | |
args = parser.parse_args() | |
download_models(args.model) | |
if __name__ == "__main__": | |
main() |