Spaces:
Running
on
Zero
Running
on
Zero
import contextlib
import io
import mimetypes
import os
import shutil
import sys
import tempfile

import gradio as gr
import spaces
import torch
import yt_dlp
from huggingface_hub import hf_hub_download
from pydub import AudioSegment

from get_difficulty import predict_difficulty
# Hugging Face Hub repository that hosts the model checkpoints, and the local
# directory (relative to the current working directory) they are copied into.
REPO_ID, CACHE_BASE = "pramoneda/audio", "models"
def download_model_checkpoints(model_name: str, num_checkpoints: int = 5):
    """Make sure the checkpoints of ``model_name`` exist under ``CACHE_BASE``.

    For each id in ``0 .. num_checkpoints - 1`` the file
    ``{model_name}/checkpoint_{id}.pth`` is fetched from the Hub repo
    ``REPO_ID`` with ``hf_hub_download`` and copied to
    ``CACHE_BASE/{model_name}/checkpoint_{id}.pth``. Checkpoints already
    present at that local path are skipped.

    Args:
        model_name: Folder inside ``REPO_ID`` holding the checkpoints; also the
            local sub-folder created under ``CACHE_BASE``.
        num_checkpoints: How many checkpoint files to fetch.

    Returns:
        The local checkpoint paths, in checkpoint-id order.
    """
    cache_dir = os.path.join(CACHE_BASE, model_name)
    os.makedirs(cache_dir, exist_ok=True)
    local_paths = []
    for checkpoint_id in range(num_checkpoints):
        checkpoint_name = f"checkpoint_{checkpoint_id}.pth"
        local_path = os.path.join(cache_dir, checkpoint_name)
        local_paths.append(local_path)
        if os.path.exists(local_path):
            continue
        # Paths inside a Hub repo always use "/", independent of the local OS.
        path = hf_hub_download(
            repo_id=REPO_ID,
            filename=f"{model_name}/{checkpoint_name}",
            cache_dir=cache_dir,
        )
        if os.path.abspath(path) != os.path.abspath(local_path):
            # Copy under a temporary name and rename atomically: an interrupted
            # copy must never leave a truncated file at local_path, because the
            # existence check above would accept it on every later call.
            partial_path = local_path + ".part"
            shutil.copyfile(path, partial_path)
            os.replace(partial_path, local_path)
    return local_paths
def download_youtube_audio(url, cookie_file=None, output_dir=None):
    """Download the audio of ``url`` with yt-dlp and extract it as a 192k MP3.

    Args:
        url: Video URL handed to yt-dlp (e.g. a YouTube link).
        cookie_file: Optional path to a Netscape-format cookies file, passed to
            yt-dlp as ``cookiefile``.
        output_dir: Directory to write into. By default a fresh temporary
            directory is created per call, so concurrent or repeated requests
            never overwrite or reuse each other's ``yt_audio.mp3``.

    Returns:
        Path of the extracted ``yt_audio.mp3``.

    Raises:
        FileNotFoundError: If yt-dlp returns without having produced the MP3.
        Errors raised by yt-dlp itself propagate unchanged.
    """
    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="yt_audio_")
    # outtmpl is a yt-dlp template: a literal "%" in the directory must be
    # doubled so it is not parsed as a template field.
    output_template = os.path.join(output_dir.replace("%", "%%"), "yt_audio.%(ext)s")
    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": output_template,
        "postprocessors": [{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "mp3",
            "preferredquality": "192",
        }],
        "quiet": True,
        "no_warnings": True,
    }
    if cookie_file:
        ydl_opts["cookiefile"] = cookie_file  # use the user-supplied cookies file
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])
    # The FFmpegExtractAudio post-processor swaps the extension to .mp3.
    mp3_path = os.path.join(output_dir, "yt_audio.mp3")
    if not os.path.exists(mp3_path):
        raise FileNotFoundError(f"yt-dlp did not produce {mp3_path}")
    return mp3_path
def convert_to_mp3(input_path):
    """Decode ``input_path`` with pydub and re-encode it into a new temporary MP3.

    The output file is created with ``delete=False`` so it outlives this call
    (it is returned to the UI); nothing here removes it afterwards.

    Args:
        input_path: Path of any audio/video file pydub can decode.

    Returns:
        Path of the newly written ``.mp3`` file.
    """
    audio = AudioSegment.from_file(input_path)
    # Only reserve a unique name here and close our handle right away; the
    # previous version kept it open, leaking one descriptor per call.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
        output_path = temp_audio.name
    # export() hands back the file object it wrote; close it so it is not leaked.
    exported = audio.export(output_path, format="mp3")
    exported.close()
    return output_path
def process_input(input_file, youtube_url, cookie_file): | |
# captura consola | |
captured_output = io.StringIO() | |
sys.stdout = captured_output | |
# procesa audio/video | |
if youtube_url: | |
audio_path = download_youtube_audio(youtube_url, cookie_file) | |
mp3_path = audio_path | |
elif input_file: | |
mime_type, _ = mimetypes.guess_type(input_file) | |
audio_path = convert_to_mp3(input_file) | |
mp3_path = audio_path | |
else: | |
sys.stdout = sys.__stdout__ | |
return "No audio or video provided.", None, None, None, "" | |
# descarga checkpoints | |
for model in ["audio_midi_cqt5_ps_v5", "audio_midi_pianoroll_ps_5_v4", "audio_midi_multi_ps_v5"]: | |
download_model_checkpoints(model) | |
# predicciones | |
diff_cqt = predict_difficulty(audio_path, model_name="audio_midi_cqt5_ps_v5", rep="cqt5") | |
diff_pr = predict_difficulty(audio_path, model_name="audio_midi_pianoroll_ps_5_v4", rep="pianoroll5") | |
diff_multi = predict_difficulty(audio_path, model_name="audio_midi_multi_ps_v5", rep="multimodal5") | |
sys.stdout = sys.__stdout__ | |
log_output = captured_output.getvalue() | |
midi_path = "temp.mid" | |
if not os.path.exists(midi_path): | |
return "MIDI not generated.", None, None, None, log_output | |
difficulty_text = ( | |
f"CQT difficulty: {diff_cqt}\n" | |
f"Pianoroll difficulty: {diff_pr}\n" | |
f"Multimodal difficulty: {diff_multi}" | |
) | |
return difficulty_text, midi_path, midi_path, mp3_path, log_output | |
# Gradio UI. The five outputs line up positionally with the 5-tuple returned by
# process_input: text, MIDI file, MIDI playback, MP3 preview, console log.
demo = gr.Interface(
    fn=process_input,
    inputs=[
        gr.File(label="Upload MP3 or MP4", type="filepath"),
        gr.Textbox(label="YouTube URL"),
        gr.File(label="Upload cookies.txt (optional)", file_types=["text"], type="filepath"),
    ],
    outputs=[
        gr.Textbox(label="Difficulty predictions"),
        gr.File(label="Generated MIDI"),
        # NOTE(review): this receives the .mid path; browsers generally cannot
        # play raw MIDI, so this player may stay silent — verify.
        gr.Audio(label="MIDI Playback", type="filepath"),
        gr.Audio(label="Extracted MP3 Preview", type="filepath"),
        gr.Textbox(label="Console Output"),
    ],
    title="Music Difficulty Estimator",
    description=(
        "Upload an MP3/MP4 or provide a YouTube URL. "
        "If you want to predict the difficulty directly from youtube, export your YouTube cookies as a Netscape-format file "
        # Adjacent literals are concatenated; without this break the sentence
        # ran straight into "Related publication:".
        "and upload it here. Then the app can download and process the audio.\n\n"
        "Related publication: [IEEE TASLP paper](https://ieeexplore.ieee.org/document/10878288)"
    ),
)
def main():
    """Launch the Gradio app defined above, with Gradio's debug mode enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    main()