mrfakename committed on
Commit
d5c9390
·
verified ·
1 Parent(s): 21d1159

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from transformers import pipeline
4
+ import os
5
+
6
# Hub repository that hosts the fine-tuned background-music detector.
model_path = "podcasts-org/detect-background-music"

# Build the audio-classification pipeline once, at import time, so every
# request reuses the same loaded model. The access token is read from the
# HF_TOKEN environment variable (None when the variable is unset).
classifier = pipeline(
    "audio-classification",
    model=model_path,
    token=os.getenv("HF_TOKEN"),
)
9
+
10
def classify_audio(audio):
    """Classify whether audio has background music or not.

    Args:
        audio: ``None`` when nothing was supplied, otherwise a
            ``(sample_rate, audio_array)`` tuple where ``audio_array`` is a
            NumPy array of samples.

    Returns:
        A prompt string when there is no audio to classify (``None`` or a
        zero-length recording); otherwise a dict mapping each predicted
        label to its score, suitable for a Gradio ``Label`` component.
    """
    if audio is None:
        return "Please provide an audio file"

    # audio is a tuple of (sample_rate, audio_array)
    sample_rate, audio_array = audio
    waveform = _to_mono_float32(audio_array)

    # A zero-length recording carries nothing to classify; ask for real
    # input instead of sending an empty array to the model.
    if waveform.size == 0:
        return "Please provide an audio file"

    # Pipeline expects dict with "array" and "sampling_rate" keys
    predictions = classifier({"array": waveform, "sampling_rate": sample_rate})

    # Convert list of dicts to single dict for Gradio Label component
    return {pred["label"]: pred["score"] for pred in predictions}


def _to_mono_float32(audio_array):
    """Return *audio_array* as a float32 waveform with channels averaged.

    Integer PCM of any width is scaled by the range of its dtype (int16 is
    divided by 32768 and int32 by 2**31); unsigned integer samples are first
    re-centred on the dtype midpoint. Non-integer input is cast to float32
    without rescaling.
    """
    samples = np.asarray(audio_array)

    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        if info.min < 0:
            # Signed PCM: divide by the magnitude of the most negative value.
            samples = samples.astype(np.float32) / float(-info.min)
        else:
            # Unsigned PCM: silence sits at the midpoint of the range.
            midpoint = float((info.max + 1) // 2)
            samples = (samples.astype(np.float32) - midpoint) / midpoint
    else:
        samples = samples.astype(np.float32)

    # Convert stereo to mono if needed; assumes channels lie along axis 1,
    # i.e. a (samples, channels) layout -- TODO confirm against the caller.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    return samples.astype(np.float32, copy=False)
36
+
37
# Web UI: one audio input fed to classify_audio, one label output showing
# the top two classes.
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="numpy", label="Upload Audio"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="Background Music Detection",
    description="Upload an audio file to detect whether it contains background music (BGM) or not. Model: Whisper-base fine-tuned on podcasts-org/bgm dataset.",
    examples=None,
)

# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()