|
import os |
|
import warnings |
|
import torch |
|
import librosa |
|
import huggingface_hub |
|
import gradio as gr |
|
from piano_transcription_inference import PianoTranscription, sample_rate |
|
|
|
|
|
# Silence one specific, known warning so it does not clutter the console.
# The warnings module treats `message` as a regular expression that is
# matched against the start of the warning text.
warnings.filterwarnings(
    action="ignore",
    message="unable to parse version details from package URL.",
)
|
|
|
# Download a snapshot of the model repository (using ./__pycache__ as the
# cache directory) and build the path of the checkpoint file inside it.
# NOTE: this runs at import time, so importing the module may hit the network.
_SNAPSHOT_DIR = huggingface_hub.snapshot_download(
    "Genius-Society/piano_trans",
    cache_dir="./__pycache__",
)
WEIGHTS_PATH = f"{_SNAPSHOT_DIR}/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
|
|
|
|
|
def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe a piano recording into a MIDI file.

    Args:
        audio_path: Path of the audio file to transcribe.
        cache_dir: Directory in which ``output.mid`` is written. It is not
            created here; the caller is expected to have created it.

    Returns:
        A ``(midi_path, title)`` tuple. ``midi_path`` is the path handed to
        the transcriber for writing; ``title`` is the capitalized file name
        of ``audio_path`` with its final extension removed.

    Raises:
        Exception: Any error raised while loading the audio is printed and
            re-raised unchanged.
    """
    print(f"Loading audio from {audio_path}")
    try:
        # Resample to the rate exported by piano_transcription_inference and
        # down-mix to mono before handing the signal to the model.
        audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    except Exception as e:
        print(f"Error loading audio: {e}")
        raise
    print("Audio loaded successfully")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    transcriptor = PianoTranscription(
        device=device,
        checkpoint_path=WEIGHTS_PATH,
    )

    midi_path = os.path.join(cache_dir, "output.mid")
    transcriptor.transcribe(audio, midi_path)

    # splitext() copes with names that have no extension (the previous
    # split(".")[-2] raised IndexError) and keeps the whole stem of
    # multi-dot names such as "take.2.mp3".
    stem, _ext = os.path.splitext(os.path.basename(audio_path))
    return midi_path, stem.capitalize()
|
|
|
|
|
def process_audio(audio_path: str, cache_dir="./__pycache__/uploads"):
    """Run the transcription for the UI and report the outcome as text.

    Args:
        audio_path: Path of the uploaded audio file. May be ``None`` or empty
            when the user submits without providing a file.
        cache_dir: Directory where the generated MIDI file is stored; it is
            created if it does not exist.

    Returns:
        A ``(status, midi)`` tuple. ``status`` is ``"Success"`` when a MIDI
        file was produced, otherwise a human-readable error description
        (including a traceback for unexpected failures). ``midi`` is the path
        of the generated file, or ``None`` on any failure.
    """
    # Local import keeps the module's top-level import block untouched.
    import traceback

    # Nothing was uploaded: answer with a clear message instead of letting
    # os.path.exists(None) surface as a raw TypeError traceback.
    if not audio_path:
        return "No audio provided: please upload an audio file first.", None

    status = "Success"
    midi = None

    try:
        os.makedirs(cache_dir, exist_ok=True)

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        file_size = os.path.getsize(audio_path)
        print(f"Audio file size: {file_size} bytes")

        # The title returned by audio2midi is not needed by the UI.
        midi, _title = audio2midi(audio_path, cache_dir)
        print(f"MIDI generated successfully: {midi}")

    except Exception as e:
        # UI boundary: report the failure in the status box rather than crash.
        status = f"{e}\n{traceback.format_exc()}"

    if midi and not os.path.exists(midi):
        print(f"Warning: MIDI file does not exist: {midi}")
        # Do not report "Success" when there is no file to download.
        status = f"MIDI file was not created: {midi}"
        midi = None

    return status, midi
|
|
|
|
|
if __name__ == "__main__":
    # Probe the compute device once, before building the page that shows it.
    has_gpu = torch.cuda.is_available()
    compute_device = "cuda" if has_gpu else "cpu"

    # NOTE(review): the original indentation was lost, so the nesting of the
    # components below is reconstructed from statement order — confirm layout.
    with gr.Blocks() as demo:
        gr.Markdown("# Piano Transcription Tool: Audio->MIDI")
        gr.Markdown(f"Device: {compute_device}")
        if not has_gpu:
            gr.Markdown("Will run slower on CPU, best on GPU")

        with gr.Row():
            # Left column: upload widget and the trigger button.
            with gr.Column(scale=1):
                audio_upload = gr.Audio(label="Upload an audio", type="filepath")
                transcribe_button = gr.Button("Transcribe")

            # Right column: status text, downloadable result and usage hints.
            with gr.Column(scale=2):
                status_box = gr.Textbox(label="Status", show_copy_button=True)
                midi_download = gr.File(label="Download MIDI")
                gr.HTML("""
                <div style="margin-top: 10px;">
                    <p>For best MIDI playback experience:</p>
                    <ol>
                        <li>Download the MIDI file</li>
                        <li><a href="https://app.midiano.com/" target="_blank">Open MidiAno in a new tab</a></li>
                        <li>Drop the downloaded MIDI file into MidiAno</li>
                    </ol>
                </div>
                """)

        # Clicking the button runs the transcription on the uploaded file.
        transcribe_button.click(
            fn=process_audio,
            inputs=audio_upload,
            outputs=[status_box, midi_download],
        )

        # Bundled sample recordings, wired to the same handler and outputs.
        gr.Examples(
            examples=["jazz_sample.mp3", "5-Octaves_2nd.mp3"],
            inputs=audio_upload,
            outputs=[status_box, midi_download],
            fn=process_audio,
            cache_examples=True,
        )

    demo.launch()
|
|