# Hugging Face Space — Running on Zero
from huggingface_hub import hf_hub_download | |
import torch | |
import os | |
# Hugging Face Hub repository that hosts every model checkpoint.
REPO_ID: str = "pramoneda/audio"
# Local root directory; each model gets its own cache directory beneath it.
CACHE_BASE: str = "models"
def _fetch_checkpoint_file(model_name: str, checkpoint_id: int) -> str:
    """Fetch one checkpoint file from ``REPO_ID`` and return its local path.

    The repo-relative file is ``<model_name>/checkpoint_<checkpoint_id>_clean.pth``.
    It is fetched with ``hf_hub_download`` into the per-model cache directory
    ``CACHE_BASE/<model_name>``.
    """
    filename = f"{model_name}/checkpoint_{checkpoint_id}_clean.pth"
    cache_dir = os.path.join(CACHE_BASE, model_name)
    print(f"Downloading {filename} from {REPO_ID} to {cache_dir}")
    return hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        cache_dir=cache_dir,
    )


def download_model_checkpoint(model_name: str, checkpoint_id: int):
    """Download a model checkpoint and load it onto the CPU.

    Args:
        model_name: Sub-directory of ``REPO_ID`` that holds the model's checkpoints.
        checkpoint_id: Number embedded in ``checkpoint_<id>_clean.pth``.

    Returns:
        Whatever ``torch.load`` deserializes from the file. The local name in
        the original code suggests a state dict, but that is not checked here.
    """
    path = _fetch_checkpoint_file(model_name, checkpoint_id)
    # NOTE(review): torch.load unpickles the downloaded file. On PyTorch
    # versions where ``weights_only`` defaults to False, this can execute
    # arbitrary code embedded in the checkpoint. If these checkpoints hold
    # only tensors and plain containers, consider passing weights_only=True.
    return torch.load(path, map_location="cpu")


def ensure_local_checkpoints():
    """Make sure every known checkpoint file is present in the local cache.

    This only downloads the files. Unlike calling ``download_model_checkpoint``,
    it does not deserialize them, so warming the cache no longer loads every
    model into memory just to throw the result away.

    A failure for one model is printed and skipped, so the remaining
    checkpoints are still fetched.
    """
    models = {
        "audio_midi_cqt5_ps_v5": 0,
        "audio_midi_pianoroll_ps_5_v4": 0,
        "audio_midi_multi_ps_v5": 0,
    }
    for model_name, checkpoint_id in models.items():
        try:
            _fetch_checkpoint_file(model_name, checkpoint_id)
        except Exception as e:
            # Deliberately broad: this is a best-effort pre-fetch, and a
            # network/Hub error for one model should not abort the others.
            print(f"❌ Failed to download {model_name}: {e}")
def _main() -> None:
    """Script entry point: pre-fetch all known checkpoints."""
    ensure_local_checkpoints()


if __name__ == "__main__":
    _main()