|
import os |
|
import glob |
|
import random |
|
import logging |
|
import torch |
|
import gradio as gr |
|
from matchboxnet.model import MatchboxNetForAudioClassification |
|
from matchboxnet.feature_extraction import MatchboxNetFeatureExtractor |
|
|
|
def setup_logging():
    """Configure basic INFO-level logging and hand back this module's logger.

    Returns:
        The ``logging.Logger`` registered under this module's ``__name__``.
    """
    logging.basicConfig(level=logging.INFO)
    module_logger = logging.getLogger(__name__)
    return module_logger
|
|
|
logger = setup_logging()


def _select_device():
    """Return the preferred torch device name: "cuda", then "mps", then "cpu".

    ``torch.backends.mps`` may be absent on older PyTorch builds, so it is
    looked up defensively instead of being accessed directly (which would
    raise ``AttributeError`` at import time).
    """
    if torch.cuda.is_available():
        logger.info("Using CUDA for inference.")
        return "cuda"

    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        logger.info("Using MPS for inference.")
        return "mps"

    logger.info("Using CPU for inference.")
    return "cpu"


# Device name shared by model loading and inference below.
device = _select_device()
|
|
|
|
|
def load_model(repo_id="Panga-Azazia/matchboxnet3x2x64-bambara-a-c"):
    """Load the MatchboxNet classifier and its feature extractor.

    Args:
        repo_id: Hub repository (or local path) both artifacts are loaded from.

    Returns:
        ``(model, feature_extractor)``. The model is moved to the module-level
        ``device`` and switched to evaluation mode, and the keys of
        ``model.config.id2label`` are converted to ``int``.
    """
    # HF_TOKEN used to be read and then ignored. Forward it so private or
    # gated repositories load, but only when set, so the default credential
    # lookup is untouched otherwise.
    auth_token = os.getenv("HF_TOKEN")
    hub_kwargs = {"token": auth_token} if auth_token else {}

    model = MatchboxNetForAudioClassification.from_pretrained(
        repo_id,
        trust_remote_code=True,
        **hub_kwargs,
    ).to(device)
    # This app only runs inference; make sure train-time layers are disabled.
    model.eval()

    feature_extractor = MatchboxNetFeatureExtractor.from_pretrained(
        repo_id,
        **hub_kwargs,
    )

    # Label ids may arrive as strings (e.g. from a JSON config); normalise to
    # int so they line up with integer class indices.
    model.config.id2label = {int(k): v for k, v in model.config.id2label.items()}

    return model, feature_extractor
|
|
|
MODEL, FEATURE_EXTRACTOR = load_model()

# Label names ordered by integer class id, so LABELS[i] is the label for id i.
LABELS = [label for _, label in sorted(MODEL.config.id2label.items())]
|
|
|
|
|
def get_example_files(directory="./examples", max_files=5):
    """Return up to ``max_files`` example files found anywhere below ``directory``.

    Args:
        directory: Root folder searched recursively for example clips.
        max_files: Maximum number of paths returned.

    Returns:
        A list of regular-file paths in random order. The list is empty when
        ``directory`` is not an existing directory.
    """
    if not os.path.isdir(directory):
        logger.warning("Examples directory %s not found.", directory)
        return []

    # ``**`` only recurses when ``recursive=True``. Without it, the old
    # ``f"{directory}/**/**"`` pattern matched exactly two levels deep and also
    # returned sub-directories, which the audio component cannot load. glob's
    # ``*`` does not match dot-files, so entries such as ``.DS_Store`` are skipped.
    candidates = glob.glob(os.path.join(directory, "**", "*"), recursive=True)
    files = sorted(path for path in candidates if os.path.isfile(path))
    random.shuffle(files)

    logger.info("Found %d example(s) files.", len(files))

    return files[:max_files]
|
|
|
|
|
def predict(audio):
    """Classify one audio clip and return a probability for every label.

    Args:
        audio: A file path (what ``gr.Audio(type="filepath")`` and the examples
            supply) or a ``(sampling_rate, waveform)`` tuple. ``None`` means
            nothing was recorded or uploaded.

    Returns:
        ``{label: probability}`` covering every entry of ``LABELS``, in the
        format ``gr.Label`` expects.

    Raises:
        gr.Error: if no audio was provided. Previously this path crashed with
            a ``TypeError`` while unpacking ``None``.
    """
    if audio is None:
        raise gr.Error("Please record or upload an audio clip first.")

    if isinstance(audio, str):
        # Presumably the extractor reads the file and its native rate itself
        # when sampling_rate is None — TODO confirm against the extractor.
        batch = FEATURE_EXTRACTOR(audio, sampling_rate=None, return_tensors="pt")
    else:
        sr, waveform = audio
        batch = FEATURE_EXTRACTOR(waveform, sampling_rate=sr, return_tensors="pt")

    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = MODEL(**batch)

    # One clip is classified per call. flatten() always yields a list with one
    # probability per label, whereas squeeze() would collapse a single-label
    # output into a bare float and break the pairing below.
    probs = torch.softmax(outputs.logits, dim=-1).flatten().tolist()

    return {label: float(p) for label, p in zip(LABELS, probs)}
|
|
|
|
|
def build_interface():
    """Build the Gradio Blocks UI for the Bambara audio classifier.

    The layout has:
        - an audio input (microphone or upload) with Classify and Clear buttons;
        - a probability label covering every class in ``LABELS``;
        - when example files exist, an examples section;
        - a "Random Example" button that loads a random example into the
          input, or clears it if there are none.

    Returns:
        The ``gr.Blocks`` app; it is not launched here.
    """
    example_files = get_example_files()

    with gr.Blocks(title="Bambara Audio Classification (ayi, awɔ, foyi)") as demo:
        # Emoji and the Bambara "ɔ" were previously mojibake (UTF-8 decoded
        # as ISO-8859-7), e.g. "π§" for 🎧 and "awΙ" for "awɔ".
        gr.Markdown(
            """
            # 🎧 Bambara Audio Classification

            Classify Bambara audio into **ayi**, **awɔ**, and **foyi** using MatchboxNet.
            """
        )

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="🎤 Record or Upload Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                )
                predict_btn = gr.Button("🔍 Classify Audio", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

            with gr.Column():
                # Show every class rather than a hard-coded count.
                output_label = gr.Label(
                    num_top_classes=len(LABELS),
                    label="Predicted Probabilities",
                )

        if example_files:
            gr.Markdown("## 🎵 Try Examples")
            gr.Examples(
                examples=[[f] for f in example_files],
                inputs=[audio_input],
                outputs=[output_label],
                fn=predict,
                cache_examples=False,
            )

        rnd_btn = gr.Button("🎲 Random Example")
        rnd_btn.click(
            fn=lambda: random.choice(example_files) if example_files else None,
            outputs=[audio_input],
        )

        # show_progress is left at its default (full progress display). The
        # old bool value does not match the string options newer Gradio
        # versions declare.
        predict_btn.click(
            fn=predict,
            inputs=[audio_input],
            outputs=[output_label],
        )

        clear_btn.click(
            fn=lambda: (None, {}),
            outputs=[audio_input, output_label],
        )

    return demo
|
|
|
|
|
def main():
    """Build the UI and serve it on all network interfaces.

    The port defaults to 7860 and can be overridden with the
    ``GRADIO_SERVER_PORT`` environment variable. A hard-coded port would
    otherwise override that setting, which Gradio itself honours.
    """
    server_name = "0.0.0.0"
    server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))

    logger.info("Starting Gradio app...")
    interface = build_interface()

    # When run as a script, launch() blocks until the server shuts down.
    # Anything logged after it therefore runs only once the app has stopped;
    # the old "App launched." message was misleading.
    interface.launch(
        share=False,
        server_name=server_name,
        server_port=server_port,
    )
    logger.info("App stopped.")


if __name__ == "__main__":
    main()