File size: 4,445 Bytes
55d0a43 6f454d6 3b63824 6f454d6 3b63824 826a6ea 3b63824 826a6ea 6f454d6 3b63824 826a6ea 3b63824 6f454d6 7d51c4b d62b428 0c98b0c e6945d0 6f454d6 3b63824 6f454d6 3b63824 6f454d6 3b63824 6f454d6 3b63824 6f454d6 fe3d61a 6f454d6 3b63824 fe3d61a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import os
import glob
import random
import logging
import torch
import gradio as gr
from matchboxnet.model import MatchboxNetForAudioClassification
from matchboxnet.feature_extraction import MatchboxNetFeatureExtractor
def setup_logging():
    """Turn on INFO-level logging for the process and hand back this module's logger.

    ``logging.basicConfig`` only installs a handler when the root logger has
    none yet, so calling this more than once is harmless.
    """
    logging.basicConfig(level=logging.INFO)
    module_logger = logging.getLogger(__name__)
    return module_logger
logger = setup_logging()


def _select_device():
    """Return the torch device string to run inference on, logging the choice.

    Preference order: CUDA, then Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        logger.info("Using CUDA for inference.")
        return "cuda"
    if torch.backends.mps.is_available():
        logger.info("Using MPS for inference.")
        return "mps"
    logger.info("Using CPU for inference.")
    return "cpu"


# Chosen once at import time; the model and each input batch are moved to it.
device = _select_device()
def load_model(repo_id="Panga-Azazia/matchboxnet3x2x64-bambara-a-c"):
    """Load the MatchboxNet classifier and its feature extractor from the Hub.

    Authentication for private repos is not handled here explicitly:
    ``huggingface_hub`` picks up the ``HF_TOKEN`` environment variable on its
    own, so the previously read-but-unused token variable has been dropped.

    Args:
        repo_id: Hugging Face Hub repository that holds both the model weights
            and the feature-extractor configuration.

    Returns:
        ``(model, feature_extractor)``, with the model moved to the
        module-level ``device`` and ``model.config.id2label`` keyed by ``int``.
    """
    model = MatchboxNetForAudioClassification.from_pretrained(
        repo_id,
        trust_remote_code=True,
    ).to(device)
    feature_extractor = MatchboxNetFeatureExtractor.from_pretrained(repo_id)
    # id2label keys may come back as strings (JSON object keys); normalize them
    # so callers can index the mapping with integer class ids.
    model.config.id2label = {int(k): v for k, v in model.config.id2label.items()}
    return model, feature_extractor
# Load once at import time so every request reuses the same model instance.
MODEL, FEATURE_EXTRACTOR = load_model()
# Label names ordered by their integer class id (ascending).
LABELS = [label for _, label in sorted(MODEL.config.id2label.items())]
def get_example_files(
    directory="./examples",
    limit=5,
    extensions=(".wav", ".mp3", ".m4a", ".flac", ".ogg"),
):
    """Return a random selection of audio files found under ``directory``.

    The directory is searched recursively (including its top level) and only
    regular files whose extension is in ``extensions`` (case-insensitive) are
    kept, so sub-directories and stray files such as READMEs never end up in
    the examples gallery.

    Args:
        directory: Root folder to search.
        limit: Maximum number of paths returned; ``None`` returns all matches.
        extensions: Accepted filename suffixes, including the leading dot.

    Returns:
        A shuffled list of at most ``limit`` file paths, or ``[]`` when the
        directory does not exist.
    """
    if not os.path.isdir(directory):
        logger.warning(f"Examples directory {directory} not found.")
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    # Escape the root so characters like "[" in the path aren't treated as
    # glob syntax; "**" with recursive=True matches zero or more sub-folders.
    pattern = os.path.join(glob.escape(directory), "**", "*")
    files = [
        path
        for path in glob.glob(pattern, recursive=True)
        if os.path.isfile(path) and path.lower().endswith(suffixes)
    ]
    random.shuffle(files)
    logger.info(f"Found {len(files)} example(s) files.")
    return files[:limit]
def predict(audio):
    """Classify one audio clip and return a label -> probability mapping.

    Args:
        audio: A filepath string (what this app's ``gr.Audio(type="filepath")``
            input produces), a ``(sampling_rate, waveform)`` tuple, or ``None``
            when nothing has been recorded or uploaded yet.

    Returns:
        Dict mapping each name in ``LABELS`` to its softmax probability, in the
        shape ``gr.Label`` expects; ``{}`` when ``audio`` is ``None``.
    """
    if audio is None:
        # Clicking "Classify" on an empty input would otherwise crash on the
        # tuple unpacking below; an empty dict simply clears the Label.
        return {}
    if isinstance(audio, str):
        # NOTE(review): sampling_rate=None presumably lets the extractor use the
        # file's own rate when loading it — confirm against the extractor.
        batch = FEATURE_EXTRACTOR(audio, sampling_rate=None, return_tensors="pt")
    else:
        sr, waveform = audio
        batch = FEATURE_EXTRACTOR(waveform, sampling_rate=sr, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = MODEL(**batch)
    # flatten() instead of squeeze(): squeeze() would collapse a single-class
    # output to a bare float. Assumes a one-clip batch — TODO confirm shape.
    probs = torch.softmax(outputs.logits, dim=-1).flatten().tolist()
    return {label: float(p) for label, p in zip(LABELS, probs)}
def build_interface():
    """Build the Gradio Blocks UI for the classifier.

    Layout: an audio input with Classify/Clear buttons next to the predicted
    probabilities, followed — only when example files are found — by an
    examples gallery and a "Random Example" button.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    example_files = get_example_files()
    with gr.Blocks(title="Bambara Audio Classification (ayi, awɔ, foyi)") as demo:
        # Plain concatenation keeps the Markdown free of source indentation,
        # which Markdown would otherwise render as a code block.
        gr.Markdown(
            "# 🎧 Bambara Audio Classification\n"
            "Classify Bambara audio into **ayi**, **awɔ**, and **foyi** using MatchboxNet."
        )
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    label="🎤 Record or Upload Audio",
                    type="filepath",
                    # Gradio 4 replaced the `source` argument with `sources` (a list).
                    sources=["microphone", "upload"],
                )
                predict_btn = gr.Button("🔍 Classify Audio", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
            with gr.Column():
                output_label = gr.Label(num_top_classes=3, label="Predicted Probabilities")

        if example_files:
            gr.Markdown("## 🎵 Try Examples")
            gr.Examples(
                examples=[[f] for f in example_files],
                inputs=[audio_input],
                outputs=[output_label],
                fn=predict,
                cache_examples=False,
            )
            # Only offer random picks when there is something to pick from.
            rnd_btn = gr.Button("🎲 Random Example")
            rnd_btn.click(
                fn=lambda: random.choice(example_files),
                outputs=[audio_input],
            )

        predict_btn.click(
            fn=predict,
            inputs=[audio_input],
            outputs=[output_label],
            # Gradio 4 expects "full" | "minimal" | "hidden", not a bool.
            show_progress="full",
        )
        # Reset both the audio input and the probability display.
        clear_btn.click(
            fn=lambda: (None, {}),
            outputs=[audio_input, output_label],
        )
    return demo
def main(server_name="0.0.0.0", server_port=7860):
    """Build the UI and serve it until the process is stopped.

    Args:
        server_name: Interface to bind to; the default listens on all interfaces.
        server_port: TCP port to serve on.
    """
    logger.info("Starting Gradio app...")
    interface = build_interface()
    logger.info(f"Launching on http://{server_name}:{server_port}")
    # By default launch() blocks the calling thread until the server shuts
    # down, so anything after it only runs on exit — log accordingly.
    interface.launch(
        share=False,
        server_name=server_name,
        server_port=server_port,
    )
    logger.info("App stopped.")
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()