File size: 3,515 Bytes
a468822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102bf16
 
 
 
 
 
 
 
 
 
a468822
 
 
 
 
 
 
 
fd5f95c
a468822
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import warnings
import torch
import librosa
import huggingface_hub
import gradio as gr
from piano_transcription_inference import PianoTranscription, sample_rate

# Suppress specific Gradio warning about package URL parsing
warnings.filterwarnings("ignore", message="unable to parse version details from package URL.")

# Fetch the "Genius-Society/piano_trans" model repository into ./__pycache__
# (snapshot_download returns the local folder of the repo snapshot) and point
# at the CRNN checkpoint file inside it. This runs at import time.
WEIGHTS_PATH = os.path.join(
    huggingface_hub.snapshot_download(
        "Genius-Society/piano_trans",
        cache_dir="./__pycache__",
    ),
    "CRNN_note_F1=0.9677_pedal_F1=0.9186.pth",
)


def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe a piano recording into a MIDI file.

    Args:
        audio_path: Path to the input audio file. It is decoded with librosa,
            resampled to the model's ``sample_rate`` and mixed down to mono.
        cache_dir: Directory in which ``output.mid`` is written. It is not
            created here; the caller must make sure it exists.

    Returns:
        A ``(midi_path, title)`` tuple: the path of the written MIDI file and
        a display title derived from the audio file name (the name without
        its final extension, passed through ``str.capitalize``).

    Raises:
        Any exception raised by ``librosa.load`` is printed and re-raised;
        errors from the transcription model propagate unchanged.
    """
    print(f"Loading audio from {audio_path}")
    try:
        audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
        print("Audio loaded successfully")
    except Exception as e:
        print(f"Error loading audio: {e}")
        raise

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # NOTE(review): the transcriber is constructed on every call (presumably
    # reloading the checkpoint each time); consider caching it if latency matters.
    transcriptor = PianoTranscription(
        device=device,
        checkpoint_path=WEIGHTS_PATH,
    )

    # NOTE(review): the fixed file name means calls sharing the same cache_dir
    # overwrite each other's output.
    midi_path = os.path.join(cache_dir, "output.mid")
    transcriptor.transcribe(audio, midi_path)

    # splitext strips only the final extension ("my.song.mp3" -> "my.song")
    # and tolerates file names that have no extension at all.
    title = os.path.splitext(os.path.basename(audio_path))[0].capitalize()
    return midi_path, title


def process_audio(audio_path: str, cache_dir="./__pycache__/uploads"):
    """Gradio callback: transcribe an uploaded audio file into MIDI.

    Errors are never raised to the caller; they are reported through the
    returned status string so the UI can display them.

    Args:
        audio_path: Filesystem path of the uploaded audio. May be ``None`` or
            empty when nothing was uploaded.
        cache_dir: Directory where the MIDI output is written; created if
            missing.

    Returns:
        A ``(status, midi)`` tuple. ``status`` is ``"Success"`` or the error
        message followed by its traceback; ``midi`` is the path of the MIDI
        file, or ``None`` on failure or when the reported file does not exist.
    """
    import traceback

    status = "Success"
    midi = None

    try:
        os.makedirs(cache_dir, exist_ok=True)

        # An empty upload (presumably delivered as None by gr.Audio) would
        # otherwise reach os.path.exists(None) and fail with an opaque
        # TypeError; report a clear message instead.
        if not audio_path:
            raise ValueError("No audio file provided. Please upload an audio file first.")

        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Audio file not found: {audio_path}")

        file_size = os.path.getsize(audio_path)
        print(f"Audio file size: {file_size} bytes")

        midi, _title = audio2midi(audio_path, cache_dir)
        print(f"MIDI generated successfully: {midi}")

    except Exception as e:
        # Top-level UI boundary: show the message plus traceback in the status box.
        status = f"{e}\n{traceback.format_exc()}"

    # Do not hand the UI a path that was never actually written.
    if midi and not os.path.exists(midi):
        print(f"Warning: MIDI file does not exist: {midi}")
        midi = None

    return status, midi


if __name__ == "__main__":
    # Build the Gradio UI. Components appear in the order they are created
    # inside each `with` layout context, so statement order here is layout.
    with gr.Blocks() as iface:
        # Same device selection as audio2midi; only used here for display.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        gr.Markdown("# Piano Transcription Tool: Audio->MIDI")
        gr.Markdown(f"Device: {device}")
        if device == "cpu":
            gr.Markdown("Will run slower on CPU, best on GPU")

        with gr.Row():
            # Left column: upload widget and trigger button.
            with gr.Column(scale=1):
                # type="filepath" makes the callback receive a path string.
                audio_input = gr.Audio(label="Upload an audio", type="filepath")
                submit_btn = gr.Button("Transcribe")

            # Right column (relative scale 2): status, MIDI download, and
            # playback instructions.
            with gr.Column(scale=2):
                status_output = gr.Textbox(label="Status", show_copy_button=True)
                midi_file_output = gr.File(label="Download MIDI")
                gr.HTML("""
                <div style="margin-top: 10px;">
                    <p>For best MIDI playback experience:</p>
                    <ol>
                        <li>Download the MIDI file</li>
                        <li><a href="https://app.midiano.com/" target="_blank">Open MidiAno in a new tab</a></li>
                        <li>Drop the downloaded MIDI file into MidiAno</li>
                    </ol>
                </div>
                """)

        # process_audio returns (status, midi); the two values fill the
        # outputs in this order.
        submit_btn.click(
            fn=process_audio,
            inputs=audio_input,
            outputs=[status_output, midi_file_output]
        )

        # NOTE(review): the example files are assumed to sit in the working
        # directory — confirm. With cache_examples=True Gradio presumably runs
        # process_audio on them ahead of time to precompute outputs.
        gr.Examples(
            examples=["jazz_sample.mp3","5-Octaves_2nd.mp3"],
            inputs=audio_input,
            outputs=[status_output, midi_file_output],
            fn=process_audio,
            cache_examples=True
        )

    # Start the web app.
    iface.launch()