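"""Gradio demo: classifies Bambara audio clips as "ayi", "awɔ", or "foyi"
using a MatchboxNet model from the Hugging Face Hub."""
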
import os
import glob
import random
import logging
import torch
import gradio as gr
from matchboxnet.model import MatchboxNetForAudioClassification
from matchboxnet.feature_extraction import MatchboxNetFeatureExtractor


def setup_logging():
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)


logger = setup_logging()

# Device selection
if torch.cuda.is_available():
    device = "cuda"
    logger.info("Using CUDA for inference.")
elif torch.backends.mps.is_available():
    device = "mps"
    logger.info("Using MPS for inference.")
else:
    device = "cpu"
    logger.info("Using CPU for inference.")


def load_model():
    # HF_TOKEN is only required if the checkpoint is private; None is fine otherwise.
    auth_token = os.getenv("HF_TOKEN")
    model = MatchboxNetForAudioClassification.from_pretrained(
        "Panga-Azazia/matchboxnet3x2x64-bambara-a-c",
        trust_remote_code=True,
        token=auth_token,
    ).to(device)
    model.eval()  # inference only: disable dropout/batch-norm updates
    feature_extractor = MatchboxNetFeatureExtractor.from_pretrained(
        "Panga-Azazia/matchboxnet3x2x64-bambara-a-c",
        token=auth_token,
    )
    # JSON configs deserialize label keys as strings; restore integer keys.
    model.config.id2label = {int(k): v for k, v in model.config.id2label.items()}
    return model, feature_extractor


MODEL, FEATURE_EXTRACTOR = load_model()
LABELS = [MODEL.config.id2label[i] for i in sorted(MODEL.config.id2label.keys())]
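# LABELS is index-aligned with the model's output logits, so LABELS[i] names
# class i. For this checkpoint the classes are ayi, awɔ, and foyi.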


def get_example_files(directory="./examples"):
    """Return up to five randomly ordered audio files from the examples directory."""
    if not os.path.exists(directory):
        logger.warning(f"Examples directory {directory} not found.")
        return []
    audio_extensions = (".wav", ".mp3", ".m4a", ".flac", ".ogg")
    files = [
        f for f in glob.glob(f"{directory}/**", recursive=True)
        if f.lower().endswith(audio_extensions)
    ]
    random.shuffle(files)
    logger.info(f"Found {len(files)} example file(s).")
    # Return up to five examples.
    return files[:5]
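
# Example (paths are hypothetical, assuming nested label folders under ./examples):
#   get_example_files() -> ["examples/ayi/a1.wav", "examples/foyi/f3.wav", ...]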


def predict(audio):
    if isinstance(audio, str):
        # Filepath input: the feature extractor loads and resamples the file itself.
        batch = FEATURE_EXTRACTOR(audio, sampling_rate=None, return_tensors="pt")
    else:
        # Microphone input arrives as a (sample_rate, waveform) tuple.
        sr, waveform = audio
        batch = FEATURE_EXTRACTOR(waveform, sampling_rate=sr, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = MODEL(**batch)
    probs = torch.softmax(outputs.logits, dim=-1).squeeze().tolist()
    return {lbl: float(probs[i]) for i, lbl in enumerate(LABELS)}
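
# Example call (hypothetical path; probability values are placeholders):
#   predict("examples/ayi/clip.wav")
#   -> {"ayi": ..., "awɔ": ..., "foyi": ...}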


def build_interface():
    example_files = get_example_files()
    with gr.Blocks(title="Bambara Audio Classification (ayi, awɔ, foyi)") as demo:
        gr.Markdown(
            """
            # 🎧 Bambara Audio Classification
            Classify Bambara audio into **ayi**, **awɔ**, and **foyi** using MatchboxNet.
            """
        )
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="🎤 Record or Upload Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                )
                predict_btn = gr.Button("🔍 Classify Audio", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
            with gr.Column():
                output_label = gr.Label(num_top_classes=3, label="Predicted Probabilities")
        if example_files:
            gr.Markdown("## 🎵 Try Examples")
            gr.Examples(
                examples=[[f] for f in example_files],
                inputs=[audio_input],
                outputs=[output_label],
                fn=predict,
                cache_examples=False,
            )
            # Button that loads a random example into the audio input.
            rnd_btn = gr.Button("🎲 Random Example")
            rnd_btn.click(
                fn=lambda: random.choice(example_files) if example_files else None,
                outputs=[audio_input],
            )
        predict_btn.click(
            fn=predict,
            inputs=[audio_input],
            outputs=[output_label],
            show_progress="full",
        )
        clear_btn.click(
            fn=lambda: (None, None),
            outputs=[audio_input, output_label],
        )
    return demo


def main():
    logger.info("Starting Gradio app...")
    interface = build_interface()
    interface.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )
    logger.info("App launched.")


if __name__ == "__main__":
    main()