import gradio as gr
from transformers import pipeline
import pandas as pd
import plotly.express as px

# ------------------------------
# Load pretrained models
# ------------------------------
text_classifier = pipeline(
    "text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    return_all_scores=True
)

audio_classifier = pipeline(
    "audio-classification",
    model="superb/wav2vec2-base-superb-er"
)
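
# NOTE (hedged assumption about label sets): the text model above predicts fine-grained
# emotions such as anger/disgust/fear/joy/neutral/sadness/surprise, while the speech model
# (superb/wav2vec2-base-superb-er) typically returns abbreviated labels such as
# neu/hap/ang/sad. fuse_predictions() below therefore averages over the union of whatever
# labels the two pipelines return, treating a label missing from one modality as score 0.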

# ------------------------------
# Emotion to Emoji mapping
# ------------------------------
EMOJI_MAP = {
    "joy": "😊",
    "sadness": "😒",
    "anger": "😠",
    "fear": "😨",
    "love": "❀️",
    "surprise": "😲",
    "disgust": "🀒",
    "neutral": "😐"
}

# ------------------------------
# Fusion function
# ------------------------------
def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5):
    labels = set()
    if text_preds:
        labels |= {p['label'] for p in text_preds}
    if audio_preds:
        labels |= {p['label'] for p in audio_preds}
    scores = {l: 0.0 for l in labels}

    def normalize(preds):
        s = sum(p['score'] for p in preds)
        return {p['label']: p['score']/s for p in preds}

    if text_preds:
        t_norm = normalize(text_preds)
        for l in labels:
            scores[l] += w_text * t_norm.get(l, 0)
    if audio_preds:
        a_norm = normalize(audio_preds)
        for l in labels:
            scores[l] += w_audio * a_norm.get(l, 0)

    best = max(scores.items(), key=lambda x: x[1]) if scores else ("none", 0)
    return {"fused_label": best[0], "fused_score": round(best[1], 3), "all_scores": scores}

# ------------------------------
# Create bar chart with emojis
# ------------------------------
def make_bar_chart(scores_dict, title="Emotion Scores"):
    df = pd.DataFrame({
        "Emotion": [f"{EMOJI_MAP.get(k, '')} {k}" for k in scores_dict.keys()],
        "Score": list(scores_dict.values())
    })
    fig = px.bar(df, x="Emotion", y="Score", text="Score",
                 title=title, range_y=[0,1],
                 color="Emotion", color_discrete_sequence=px.colors.qualitative.Bold)
    fig.update_traces(texttemplate='%{text:.2f}', textposition='outside')
    fig.update_layout(yaxis_title="Probability", xaxis_title="Emotion", showlegend=False)
    return fig

# ------------------------------
# Prediction function
# ------------------------------
def predict(text, audio, w_text, w_audio):
    text_preds, audio_preds = None, None
    if text:
        text_preds = text_classifier(text)[0]
    if audio:
        audio_preds = audio_classifier(audio)
    fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)

    # One chart per modality (None leaves that plot empty) plus the fused result,
    # matching the three gr.Plot outputs wired up below.
    text_chart = (
        make_bar_chart({p['label']: p['score'] for p in text_preds}, "Text Emotion Scores")
        if text_preds else None
    )
    audio_chart = (
        make_bar_chart({p['label']: p['score'] for p in audio_preds}, "Audio Emotion Scores")
        if audio_preds else None
    )
    fused_chart = make_bar_chart(
        fused['all_scores'],
        f"Fused Emotion Scores<br>Prediction: {EMOJI_MAP.get(fused['fused_label'], '')} {fused['fused_label']}"
    )
    return text_chart, audio_chart, fused_chart

# ------------------------------
# Build Gradio interface with emojis
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech) 😎")

    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(label="📝 Text input", placeholder="Type something emotional...")
            aud = gr.Audio(type="filepath", label="🎤 Upload speech (wav/mp3)")
            w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="🔹 Text weight (w_text)")
            w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="🔹 Audio weight (w_audio)")
            btn = gr.Button("✨ Predict")
        with gr.Column():
            text_chart = gr.Plot(label="📝 Text Emotion Scores")
            audio_chart = gr.Plot(label="🎤 Audio Emotion Scores")
            fused_chart = gr.Plot(label="✨ Fused Emotion Scores")

    btn.click(fn=predict, inputs=[txt, aud, w1, w2], outputs=[text_chart, audio_chart, fused_chart])

demo.launch()