# Author: PRamoneda
# Commit c66e52a — "Initial commit for Hugging Face Space" (file size: 1.03 kB)
from huggingface_hub import hf_hub_download
import torch
import os
# Hugging Face Hub repository that hosts every model checkpoint fetched below.
REPO_ID: str = "pramoneda/audio"
# Local root directory for downloads; each model is cached in its own
# sub-folder beneath it (see the os.path.join in download_model_checkpoint).
CACHE_BASE: str = "models"
def download_model_checkpoint(model_name: str, checkpoint_id: int):
    """Fetch one checkpoint file from the Hub and load it onto the CPU.

    The remote file is ``<model_name>/checkpoint_<checkpoint_id>_clean.pth``
    inside the ``REPO_ID`` repository. It is cached locally under
    ``CACHE_BASE/<model_name>``.

    Args:
        model_name: Repository sub-folder that holds the checkpoint.
        checkpoint_id: Numeric id embedded in the checkpoint file name.

    Returns:
        The object deserialized by ``torch.load``, with storages mapped to
        CPU. Presumably a state dict, but this function does not verify that.
    """
    local_cache = os.path.join(CACHE_BASE, model_name)
    remote_file = f"{model_name}/checkpoint_{checkpoint_id}_clean.pth"
    print(f"Downloading {remote_file} from {REPO_ID} to {local_cache}")
    checkpoint_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=remote_file,
        cache_dir=local_cache,
    )
    # NOTE(review): torch.load unpickles the downloaded file. Depending on the
    # installed torch version (the weights_only default changed in 2.6), this
    # may permit arbitrary code execution from the checkpoint. Consider passing
    # weights_only=True once the checkpoint contents are confirmed to be
    # tensors only.
    return torch.load(checkpoint_path, map_location="cpu")
def ensure_local_checkpoints(models=None):
    """Download and load each requested checkpoint, reporting failures without stopping.

    Args:
        models: Optional mapping of model name -> checkpoint id to fetch.
            When omitted, the three hard-coded models below are fetched,
            each at checkpoint 0 (the original behavior).

    Each entry is passed to ``download_model_checkpoint``. That call both
    downloads the file and loads it with ``torch.load``, so a successful run
    also shows that the file deserializes. The loaded object is discarded, so
    the lasting effect is a populated local cache. If one model fails, its
    error is printed and the loop continues with the next model.
    """
    if models is None:
        models = {
            "audio_midi_cqt5_ps_v5": 0,
            "audio_midi_pianoroll_ps_5_v4": 0,
            "audio_midi_multi_ps_v5": 0,
        }
    for model_name, checkpoint_id in models.items():
        try:
            download_model_checkpoint(model_name, checkpoint_id)
        except Exception as e:
            # Deliberate best-effort boundary: one unavailable model should
            # not prevent the remaining checkpoints from being fetched.
            print(f"❌ Failed to download {model_name}: {e}")
if __name__ == "__main__":
    # Script entry point: pre-fetch every known checkpoint into the local cache.
    ensure_local_checkpoints()