PRamoneda commited on
Commit
2e9908b
·
1 Parent(s): 7243f55
Files changed (1) hide show
  1. app.py +26 -35
app.py CHANGED
@@ -16,24 +16,16 @@ CACHE_BASE = "models"
16
  def download_model_checkpoints(model_name: str, num_checkpoints: int = 5):
17
  cache_dir = os.path.join(CACHE_BASE, model_name)
18
  os.makedirs(cache_dir, exist_ok=True)
19
-
20
  for checkpoint_id in range(num_checkpoints):
21
  filename = f"{model_name}/checkpoint_{checkpoint_id}.pth"
22
  local_path = os.path.join(cache_dir, f"checkpoint_{checkpoint_id}.pth")
23
-
24
  if not os.path.exists(local_path):
25
- print(f"Downloading {filename} from {REPO_ID} to {cache_dir}")
26
- path = hf_hub_download(
27
- repo_id=REPO_ID,
28
- filename=filename,
29
- cache_dir=cache_dir
30
- )
31
- # Copy to expected location
32
  if path != local_path:
33
  import shutil
34
  shutil.copy(path, local_path)
35
 
36
- def download_youtube_audio(url):
37
  output_path = "yt_audio.%(ext)s"
38
  ydl_opts = {
39
  "format": "bestaudio/best",
@@ -46,6 +38,8 @@ def download_youtube_audio(url):
46
  "quiet": True,
47
  "no_warnings": True
48
  }
 
 
49
 
50
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
51
  ydl.download([url])
@@ -58,39 +52,31 @@ def convert_to_mp3(input_path):
58
  audio.export(temp_audio.name, format="mp3")
59
  return temp_audio.name
60
 
61
- def process_input(input_file, youtube_url):
 
62
  captured_output = io.StringIO()
63
  sys.stdout = captured_output
64
 
65
- audio_path = None
66
- mp3_path = None
67
-
68
  if youtube_url:
69
- audio_path = download_youtube_audio(youtube_url)
70
  mp3_path = audio_path
71
  elif input_file:
72
  mime_type, _ = mimetypes.guess_type(input_file)
73
- if mime_type and mime_type.startswith("video/"):
74
- audio_path = convert_to_mp3(input_file)
75
- mp3_path = audio_path
76
- else:
77
- audio_path = convert_to_mp3(input_file)
78
- mp3_path = audio_path
79
  else:
80
  sys.stdout = sys.__stdout__
81
- return "No audio or video provided.", None, None, None
82
-
83
- model_cqt = "audio_midi_cqt5_ps_v5"
84
- model_pr = "audio_midi_pianoroll_ps_5_v4"
85
- model_multi = "audio_midi_multi_ps_v5"
86
 
87
- download_model_checkpoints(model_cqt)
88
- download_model_checkpoints(model_pr)
89
- download_model_checkpoints(model_multi)
90
 
91
- diff_cqt = predict_difficulty(audio_path, model_name=model_cqt, rep="cqt5")
92
- diff_pr = predict_difficulty(audio_path, model_name=model_pr, rep="pianoroll5")
93
- diff_multi = predict_difficulty(audio_path, model_name=model_multi, rep="multimodal5")
 
94
 
95
  sys.stdout = sys.__stdout__
96
  log_output = captured_output.getvalue()
@@ -111,7 +97,8 @@ demo = gr.Interface(
111
  fn=process_input,
112
  inputs=[
113
  gr.File(label="Upload MP3 or MP4", type="filepath"),
114
- gr.Textbox(label="YouTube URL")
 
115
  ],
116
  outputs=[
117
  gr.Textbox(label="Difficulty predictions"),
@@ -121,8 +108,12 @@ demo = gr.Interface(
121
  gr.Textbox(label="Console Output")
122
  ],
123
  title="Music Difficulty Estimator",
124
- description="Upload an MP3, MP4, or provide a YouTube URL. It extracts audio, predicts difficulty, and generates a MIDI file."
 
 
 
 
125
  )
126
 
127
  if __name__ == "__main__":
128
- demo.launch(debug=True, share=True)
 
16
  def download_model_checkpoints(model_name: str, num_checkpoints: int = 5):
17
  cache_dir = os.path.join(CACHE_BASE, model_name)
18
  os.makedirs(cache_dir, exist_ok=True)
 
19
  for checkpoint_id in range(num_checkpoints):
20
  filename = f"{model_name}/checkpoint_{checkpoint_id}.pth"
21
  local_path = os.path.join(cache_dir, f"checkpoint_{checkpoint_id}.pth")
 
22
  if not os.path.exists(local_path):
23
+ path = hf_hub_download(repo_id=REPO_ID, filename=filename, cache_dir=cache_dir)
 
 
 
 
 
 
24
  if path != local_path:
25
  import shutil
26
  shutil.copy(path, local_path)
27
 
28
+ def download_youtube_audio(url, cookie_file=None):
29
  output_path = "yt_audio.%(ext)s"
30
  ydl_opts = {
31
  "format": "bestaudio/best",
 
38
  "quiet": True,
39
  "no_warnings": True
40
  }
41
+ if cookie_file:
42
+ ydl_opts["cookiefile"] = cookie_file # <-- usa el archivo de cookies
43
 
44
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
45
  ydl.download([url])
 
52
  audio.export(temp_audio.name, format="mp3")
53
  return temp_audio.name
54
 
55
+ def process_input(input_file, youtube_url, cookie_file):
56
+ # captura consola
57
  captured_output = io.StringIO()
58
  sys.stdout = captured_output
59
 
60
+ # procesa audio/video
 
 
61
  if youtube_url:
62
+ audio_path = download_youtube_audio(youtube_url, cookie_file)
63
  mp3_path = audio_path
64
  elif input_file:
65
  mime_type, _ = mimetypes.guess_type(input_file)
66
+ audio_path = convert_to_mp3(input_file)
67
+ mp3_path = audio_path
 
 
 
 
68
  else:
69
  sys.stdout = sys.__stdout__
70
+ return "No audio or video provided.", None, None, None, ""
 
 
 
 
71
 
72
+ # descarga checkpoints
73
+ for model in ["audio_midi_cqt5_ps_v5", "audio_midi_pianoroll_ps_5_v4", "audio_midi_multi_ps_v5"]:
74
+ download_model_checkpoints(model)
75
 
76
+ # predicciones
77
+ diff_cqt = predict_difficulty(audio_path, model_name="audio_midi_cqt5_ps_v5", rep="cqt5")
78
+ diff_pr = predict_difficulty(audio_path, model_name="audio_midi_pianoroll_ps_5_v4", rep="pianoroll5")
79
+ diff_multi = predict_difficulty(audio_path, model_name="audio_midi_multi_ps_v5", rep="multimodal5")
80
 
81
  sys.stdout = sys.__stdout__
82
  log_output = captured_output.getvalue()
 
97
  fn=process_input,
98
  inputs=[
99
  gr.File(label="Upload MP3 or MP4", type="filepath"),
100
+ gr.Textbox(label="YouTube URL"),
101
+ gr.File(label="Upload cookies.txt (optional)", file_types=["text"], type="filepath")
102
  ],
103
  outputs=[
104
  gr.Textbox(label="Difficulty predictions"),
 
108
  gr.Textbox(label="Console Output")
109
  ],
110
  title="Music Difficulty Estimator",
111
+ description=(
112
+ "Upload an MP3/MP4 or provide a YouTube URL. "
113
+ "If the video is age-restricted, export your YouTube cookies as a Netscape-format file "
114
+ "and upload it here. Then the app can download and process the audio."
115
+ )
116
  )
117
 
118
  if __name__ == "__main__":
119
+ demo.launch(debug=True)