File size: 4,261 Bytes
58729a4
 
 
 
 
 
 
c66e52a
 
58729a4
 
1b24c84
c66e52a
 
 
 
58729a4
c66e52a
58729a4
 
 
 
 
2e9908b
58729a4
 
 
c66e52a
2e9908b
58729a4
 
 
 
 
 
 
 
 
 
 
 
2e9908b
 
c66e52a
58729a4
 
c66e52a
58729a4
 
 
 
 
 
 
 
1b24c84
2e9908b
 
58729a4
 
 
2e9908b
58729a4
2e9908b
58729a4
 
 
2e9908b
 
58729a4
 
2e9908b
58729a4
2e9908b
 
 
58729a4
2e9908b
 
 
 
58729a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c66e52a
58729a4
 
 
 
2e9908b
 
58729a4
 
 
 
 
 
 
 
 
2e9908b
 
45e5657
2e9908b
94bd7b5
2e9908b
58729a4
c66e52a
 
b30da14
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import contextlib
import io
import mimetypes
import os
import shutil
import sys
import tempfile

# Third-party and local imports keep their original relative order: `spaces`
# (HF ZeroGPU) may be sensitive to what has been imported before it.
import gradio as gr
from get_difficulty import predict_difficulty
from pydub import AudioSegment
import yt_dlp
from huggingface_hub import hf_hub_download
import torch
import spaces

REPO_ID = "pramoneda/audio"  # Hugging Face Hub repository holding the checkpoints
CACHE_BASE = "models"  # local root directory for downloaded checkpoints


def download_model_checkpoints(
    model_name: str,
    num_checkpoints: int = 5,
    repo_id: str = REPO_ID,
    cache_base: str = CACHE_BASE,
) -> None:
    """Ensure checkpoint files ``checkpoint_0.pth`` .. ``checkpoint_{n-1}.pth`` exist locally.

    Each file ``{model_name}/checkpoint_{i}.pth`` is fetched from the Hub
    repository *repo_id* and copied to
    ``{cache_base}/{model_name}/checkpoint_{i}.pth``. Files already present at
    that location are not downloaded again.

    Args:
        model_name: Sub-folder of the Hub repository that holds the checkpoints.
        num_checkpoints: Number of checkpoints to fetch, indexed from 0.
        repo_id: Hub repository to download from.
        cache_base: Local root directory; one sub-directory per model.
    """
    cache_dir = os.path.join(cache_base, model_name)
    os.makedirs(cache_dir, exist_ok=True)
    for checkpoint_id in range(num_checkpoints):
        basename = f"checkpoint_{checkpoint_id}.pth"
        local_path = os.path.join(cache_dir, basename)
        if os.path.exists(local_path):
            continue
        # hf_hub_download places the file inside its own cache layout under
        # cache_dir; the returned path is copied to the flat local_path above.
        path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{model_name}/{basename}",
            cache_dir=cache_dir,
        )
        if path != local_path:
            # Copy under a temporary name and rename atomically, so an
            # interrupted copy never leaves a truncated file that the
            # existence check would accept on the next call.
            tmp_path = local_path + ".part"
            shutil.copy(path, tmp_path)
            os.replace(tmp_path, local_path)

def download_youtube_audio(url, cookie_file=None, output_dir=None):
    """Download the audio of *url* with yt-dlp and convert it to a 192 kbps MP3.

    Args:
        url: Video URL handed to yt-dlp (e.g. a YouTube link).
        cookie_file: Optional path to a Netscape-format cookies file, passed to
            yt-dlp as ``cookiefile``.
        output_dir: Directory to write into. Defaults to a fresh temporary
            directory, so separate requests never overwrite each other's file.

    Returns:
        Path of the extracted ``yt_audio.mp3`` inside *output_dir*.

    Errors raised by yt-dlp are not caught and propagate to the caller.
    """
    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="yt_audio_")
    ydl_opts = {
        "format": "bestaudio/best",
        "outtmpl": os.path.join(output_dir, "yt_audio.%(ext)s"),
        # Re-encode whatever audio stream was fetched into MP3.
        "postprocessors": [{
            "key": "FFmpegExtractAudio",
            "preferredcodec": "mp3",
            "preferredquality": "192",
        }],
        "quiet": True,
        "no_warnings": True,
    }
    if cookie_file:
        ydl_opts["cookiefile"] = cookie_file  # use the user-supplied cookies file

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])

    # FFmpegExtractAudio replaces the downloaded extension with .mp3.
    return os.path.join(output_dir, "yt_audio.mp3")

def convert_to_mp3(input_path):
    """Decode *input_path* with pydub and re-encode it as MP3.

    Returns:
        Path of a new temporary ``.mp3`` file. It is created with
        ``delete=False``, so it persists after this function returns.
    """
    audio = AudioSegment.from_file(input_path)
    # Only a unique, persistent path is needed; close our handle immediately
    # so no descriptor leaks and the file can be reopened on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_audio:
        output_path = temp_audio.name
    # export() opens the target itself and returns that file object; close it.
    audio.export(output_path, format="mp3").close()
    return output_path

@spaces.GPU
def process_input(input_file, youtube_url, cookie_file):
    """Run the three difficulty models on an uploaded file or a YouTube URL.

    A non-blank *youtube_url* takes precedence over *input_file*. Everything
    printed to stdout while downloading and predicting is captured and
    returned as the last element.

    Returns:
        A 5-tuple ``(difficulty_text, midi_path, midi_path, mp3_path, log)``
        matching the Interface outputs. On failure, the first element is an
        error message and the file slots are ``None``.
    """
    # (model checkpoint folder, input representation) for each predictor.
    models = (
        ("audio_midi_cqt5_ps_v5", "cqt5"),
        ("audio_midi_pianoroll_ps_5_v4", "pianoroll5"),
        ("audio_midi_multi_ps_v5", "multimodal5"),
    )
    youtube_url = youtube_url.strip() if youtube_url else youtube_url

    captured_output = io.StringIO()
    # redirect_stdout restores the previous sys.stdout even if a download or a
    # prediction raises; manual reassignment would leave it redirected.
    with contextlib.redirect_stdout(captured_output):
        if youtube_url:
            mp3_path = download_youtube_audio(youtube_url, cookie_file)
        elif input_file:
            mp3_path = convert_to_mp3(input_file)
        else:
            return "No audio or video provided.", None, None, None, ""

        for model_name, _ in models:
            download_model_checkpoints(model_name)

        diff_cqt, diff_pr, diff_multi = (
            predict_difficulty(mp3_path, model_name=model_name, rep=rep)
            for model_name, rep in models
        )
    log_output = captured_output.getvalue()

    # NOTE(review): temp.mid is presumably written by predict_difficulty; a
    # stale file left by an earlier request would also pass this check.
    midi_path = "temp.mid"
    if not os.path.exists(midi_path):
        return "MIDI not generated.", None, None, None, log_output

    difficulty_text = (
        f"CQT difficulty: {diff_cqt}\n"
        f"Pianoroll difficulty: {diff_pr}\n"
        f"Multimodal difficulty: {diff_multi}"
    )

    return difficulty_text, midi_path, midi_path, mp3_path, log_output

# Gradio UI: the inputs map positionally onto process_input's parameters and
# the outputs onto the 5-tuple it returns.
demo = gr.Interface(
    fn=process_input,
    inputs=[
        gr.File(label="Upload MP3 or MP4", type="filepath"),
        gr.Textbox(label="YouTube URL"),
        gr.File(label="Upload cookies.txt (optional)", file_types=["text"], type="filepath"),
    ],
    outputs=[
        gr.Textbox(label="Difficulty predictions"),
        gr.File(label="Generated MIDI"),
        gr.Audio(label="MIDI Playback", type="filepath"),
        gr.Audio(label="Extracted MP3 Preview", type="filepath"),
        gr.Textbox(label="Console Output"),
    ],
    title="Music Difficulty Estimator",
    description=(
        "Upload an MP3/MP4 or provide a YouTube URL. "
        "If you want to predict the difficulty directly from YouTube, export your YouTube cookies as a Netscape-format file "
        "and upload it here. Then the app can download and process the audio.\n\n"
        "Related publication: [IEEE TASLP paper](https://ieeexplore.ieee.org/document/10878288)"
    ),
)

if __name__ == "__main__":
    demo.launch(debug=True)