|
|
|
|
|
import argparse |
|
from pathlib import Path |
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
from project_settings import environment, project_path |
|
|
|
|
|
def get_args():
    """Parse the command-line options for the model download script.

    Every option is a string-typed flag with a default:

    * ``--trained_model_dir``: local directory to download into
      (defaults to ``<project_path>/trained_models``).
    * ``--models_repo_id``: Hugging Face Hub repository to pull from.
    * ``--model_pattern``: glob selecting which repository files to fetch.
    * ``--hf_token``: access token, defaulting to whatever
      ``environment.get("hf_token")`` yields (presumably ``None`` when the
      token is not configured — confirm against ``project_settings``).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser()

    # (flag, default) pairs, registered in this exact order.
    option_defaults = (
        ("--trained_model_dir", (project_path / "trained_models").as_posix()),
        ("--models_repo_id", "qgyd2021/vm_sound_classification"),
        ("--model_pattern", "sound-*-ch32.zip"),
        ("--hf_token", environment.get("hf_token")),
    )
    for flag, fallback in option_defaults:
        cli.add_argument(flag, default=fallback, type=str)

    return cli.parse_args()
|
|
|
|
|
def main():
    """Download the Hub files matching ``--model_pattern`` into a local folder.

    The destination ``--trained_model_dir`` is created first, including any
    missing parents. Only repository files matching the single glob in
    ``--model_pattern`` are requested from ``snapshot_download``. The path it
    returns is not used.
    """
    options = get_args()

    target_dir = Path(options.trained_model_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    snapshot_download(
        repo_id=options.models_repo_id,
        allow_patterns=[options.model_pattern],
        local_dir=target_dir.as_posix(),
        token=options.hf_token,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Start the download only when run as a script, not when imported.
    main()
|
|