"""Gradio demo that classifies short Bambara audio clips with MatchboxNet.

The model and feature extractor are loaded once at import time and shared by
every request handled by the Gradio callbacks defined below.
"""

import glob
import logging
import os
import random

import gradio as gr
import torch

from matchboxnet.feature_extraction import MatchboxNetFeatureExtractor
from matchboxnet.model import MatchboxNetForAudioClassification


def setup_logging():
    """Configure root logging at INFO level and return this module's logger."""
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)


logger = setup_logging()


def _select_device():
    """Return the torch device string to run inference on (cuda > mps > cpu)."""
    if torch.cuda.is_available():
        logger.info("Using CUDA for inference.")
        return "cuda"
    # Reach `torch.backends.mps` through getattr so that a torch build without
    # that attribute falls through to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        logger.info("Using MPS for inference.")
        return "mps"
    logger.info("Using CPU for inference.")
    return "cpu"


# Device selection
device = _select_device()


def load_model():
    """Load the pretrained classifier and its feature extractor.

    The model is moved to the selected ``device`` and its ``id2label`` mapping
    is re-keyed with integer ids so that labels can be ordered numerically.

    Returns:
        A ``(model, feature_extractor)`` tuple.
    """
    repo_id = "Panga-Azazia/matchboxnet3x2x64-bambara-a-c"
    # NOTE(review): `trust_remote_code=True` is kept from the original call;
    # confirm it is actually required, since it may allow code shipped with
    # the checkpoint to run locally.
    model = MatchboxNetForAudioClassification.from_pretrained(
        repo_id,
        trust_remote_code=True,
    ).to(device)
    # Make inference mode explicit. Presumably the model is a torch
    # `nn.Module` (it supports `.to(device)`); `eval()` is a no-op if
    # `from_pretrained` already switched the model to eval mode.
    model.eval()
    feature_extractor = MatchboxNetFeatureExtractor.from_pretrained(repo_id)
    # Label ids may be stored as strings in the config; normalise them to int.
    model.config.id2label = {int(k): v for k, v in model.config.id2label.items()}
    return model, feature_extractor


MODEL, FEATURE_EXTRACTOR = load_model()
# Class names ordered by their integer class id.
LABELS = [MODEL.config.id2label[i] for i in sorted(MODEL.config.id2label)]


def get_example_files(directory="./examples"):
    """Return up to five randomly chosen example files found under *directory*.

    The directory is searched recursively and only regular files are kept, so
    sub-directories are never offered as audio examples.

    Args:
        directory: Root folder that holds the example clips.

    Returns:
        A list of at most five file paths; empty when *directory* is missing.
    """
    if not os.path.isdir(directory):
        logger.warning("Examples directory %s not found.", directory)
        return []

    # `**` only descends into sub-directories when `recursive=True` is given;
    # the directory name is escaped so glob metacharacters in it are literal.
    pattern = os.path.join(glob.escape(directory), "**", "*")
    files = [
        path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)
    ]
    random.shuffle(files)
    logger.info("Found %d example(s) files.", len(files))
    # Return up to the first 5 (shuffled) examples.
    return files[:5]


def predict(audio):
    """Classify one audio clip and return a probability per label.

    Args:
        audio: Either a path to an audio file (``str``) or a
            ``(sampling_rate, waveform)`` pair.

    Returns:
        A dict mapping each entry of ``LABELS`` to its softmax probability.

    Raises:
        gr.Error: If no audio was provided (``audio`` is ``None``).
    """
    if audio is None:
        # The callback can be triggered with an empty audio widget; report it
        # to the user instead of failing while unpacking `None` below.
        raise gr.Error("Please record or upload an audio clip first.")

    if isinstance(audio, str):
        batch = FEATURE_EXTRACTOR(audio, sampling_rate=None, return_tensors="pt")
    else:
        sr, waveform = audio
        batch = FEATURE_EXTRACTOR(waveform, sampling_rate=sr, return_tensors="pt")

    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = MODEL(**batch)

    # assumes a single clip per call, so squeeze() leaves one probability per
    # label — TODO confirm against the model's output shape.
    probs = torch.softmax(outputs.logits, dim=-1).squeeze().tolist()
    return {label: float(probs[i]) for i, label in enumerate(LABELS)}


def build_interface():
    """Build the Gradio Blocks UI and wire its callbacks.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    example_files = get_example_files()

    with gr.Blocks(title="Bambara Audio Classification (ayi, awɔ, foyi)") as demo:
        gr.Markdown(
            "# 🎧 Bambara Audio Classification\n\n"
            "Classify Bambara audio into **ayi**, **awɔ**, and **foyi** "
            "using MatchboxNet."
        )

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="🎤 Record or Upload Audio",
                    type="filepath",
                    # The keyword is `sources` (plural, a list), not `source`.
                    sources=["microphone", "upload"],
                )
                predict_btn = gr.Button("🔍 Classify Audio", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
            with gr.Column():
                output_label = gr.Label(
                    num_top_classes=3,
                    label="Predicted Probabilities",
                )

        if example_files:
            gr.Markdown("## 🎵 Try Examples")
            gr.Examples(
                examples=[[path] for path in example_files],
                inputs=[audio_input],
                outputs=[output_label],
                fn=predict,
                cache_examples=False,
            )
            # Random example button: only offered when there is something to
            # pick from, so the callback needs no emptiness check of its own.
            rnd_btn = gr.Button("🎲 Random Example")
            rnd_btn.click(
                fn=lambda: random.choice(example_files),
                outputs=[audio_input],
            )

        predict_btn.click(
            fn=predict,
            inputs=[audio_input],
            outputs=[output_label],
            show_progress=True,
        )
        # Reset both the audio widget and the probability display.
        clear_btn.click(
            fn=lambda: (None, {}),
            outputs=[audio_input, output_label],
        )

    return demo


def main():
    """Build the interface and serve it on port 7860."""
    logger.info("Starting Gradio app...")
    interface = build_interface()
    interface.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )
    # NOTE(review): `launch()` presumably blocks while the server is running
    # in script mode, in which case this line is only reached at shutdown —
    # confirm.
    logger.info("App launched.")


if __name__ == "__main__":
    main()