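"""Gradio demo: classify Bambara audio clips as ayi, awɔ, or foyi with a
MatchboxNet model loaded from the Hugging Face Hub."""
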
import os
import glob
import random
import logging
import torch
import gradio as gr
from matchboxnet.model import MatchboxNetForAudioClassification
from matchboxnet.feature_extraction import MatchboxNetFeatureExtractor

def setup_logging():
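    """Configure basic logging and return a logger for this module."""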
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)

logger = setup_logging()

# Device selection
if torch.cuda.is_available():
    device = "cuda"
    logger.info("Using CUDA for inference.")
elif torch.backends.mps.is_available():
    device = "mps"
    logger.info("Using MPS for inference.")
else:
    device = "cpu"
    logger.info("Using CPU for inference.")


def load_model():
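    """Load the fine-tuned MatchboxNet model and its feature extractor."""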
    # Forward HF_TOKEN (if set) so gated/private checkpoints can be fetched;
    # the original read the variable but never passed it on.
    auth_token = os.getenv("HF_TOKEN")
    model = MatchboxNetForAudioClassification.from_pretrained(
        "Panga-Azazia/matchboxnet3x2x64-bambara-a-c",
        trust_remote_code=True,
        token=auth_token,
    ).to(device)
    model.eval()  # disable dropout etc. for inference
    feature_extractor = MatchboxNetFeatureExtractor.from_pretrained(
        "Panga-Azazia/matchboxnet3x2x64-bambara-a-c",
        token=auth_token,
    )
    # JSON round-trips turn integer keys into strings; normalize them back.
    model.config.id2label = {int(k): v for k, v in model.config.id2label.items()}
    return model, feature_extractor

MODEL, FEATURE_EXTRACTOR = load_model()
LABELS = [MODEL.config.id2label[i] for i in sorted(MODEL.config.id2label.keys())]


def get_example_files(directory="./examples"):
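    """Collect up to five audio files from the examples directory, shuffled."""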
    if not os.path.exists(directory):
        logger.warning(f"Examples directory {directory} not found.")
        return []
    audio_extensions = (".wav", ".mp3", ".m4a", ".flac", ".ogg")
    files = [
        f for f in glob.glob(f"{directory}/**/*", recursive=True)
        if os.path.isfile(f) and f.lower().endswith(audio_extensions)
    ]
    random.shuffle(files)
    logger.info(f"Found {len(files)} example file(s).")
    # Return up to five examples.
    return files[:5]


def predict(audio):
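    """Run inference on one clip and map each label to its probability."""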
    if audio is None:
        return {}
    if isinstance(audio, str):
        # Filepath input: the extractor is expected to load the file itself.
        batch = FEATURE_EXTRACTOR(audio, sampling_rate=None, return_tensors="pt")
    else:
        # Microphone input arrives as a (sampling_rate, waveform) tuple.
        sr, waveform = audio
        batch = FEATURE_EXTRACTOR(waveform, sampling_rate=sr, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = MODEL(**batch)
        probs = torch.softmax(outputs.logits, dim=-1).squeeze().tolist()
    return {lbl: float(probs[i]) for i, lbl in enumerate(LABELS)}


def build_interface():
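    """Assemble the Gradio Blocks UI: audio input, predictions, and examples."""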
    example_files = get_example_files()
    with gr.Blocks(title="Bambara Audio Classification (ayi, awɔ, foyi)") as demo:
        gr.Markdown(
            """
            # 🎧 Bambara Audio Classification

            Classify Bambara audio into **ayi**, **awɔ**, and **foyi** using MatchboxNet.
            """
        )
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="🎀 Record or Upload Audio",
                    type="filepath",
                    sources=["microphone", "upload"]  # allow both mic recording and file upload
                )
                predict_btn = gr.Button("πŸ” Classify Audio", variant="primary")
                clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
            with gr.Column():
                output_label = gr.Label(num_top_classes=3, label="Predicted Probabilities")
        if example_files:
            gr.Markdown("## 🎡 Try Examples")
            gr.Examples(
                examples=[[f] for f in example_files],
                inputs=[audio_input],
                outputs=[output_label],
                fn=predict,
                cache_examples=False
            )
        # Random example button
        rnd_btn = gr.Button("🎲 Random Example")
        rnd_btn.click(
            fn=lambda: random.choice(example_files) if example_files else None,
            outputs=[audio_input]
        )
        predict_btn.click(
            fn=predict,
            inputs=[audio_input],
            outputs=[output_label],
            show_progress=True
        )
        clear_btn.click(
            fn=lambda: (None, {}),
            outputs=[audio_input, output_label]
        )
    return demo


def main():
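    """Build the interface and serve it on 0.0.0.0:7860."""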
    logger.info("Starting Gradio app...")
    interface = build_interface()
    # launch() blocks until the server is shut down.
    interface.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )
    logger.info("App stopped.")

if __name__ == "__main__":
    main()