# Hugging Face Spaces status: Running
import gradio as gr | |
from transformers import pipeline | |
import pandas as pd | |
import plotly.express as px | |
# ------------------------------ | |
# Load pretrained models | |
# ------------------------------ | |
# Text emotion classifier. return_all_scores=True requests a score for every
# label instead of only the top-scoring one.
text_classifier = pipeline(
    task="text-classification",
    model="j-hartmann/emotion-english-distilroberta-base",
    return_all_scores=True,
)

# Speech emotion classifier; predict() calls it with the uploaded audio value.
audio_classifier = pipeline(
    task="audio-classification",
    model="superb/wav2vec2-base-superb-er",
)
# ------------------------------ | |
# Emotion to Emoji mapping | |
# ------------------------------ | |
# Emoji displayed next to each emotion label in chart axes and titles.
# Callers look labels up with .get(label, ''), so a label without an entry is
# simply rendered without an emoji.
# Values are written as Unicode escapes so they survive any re-encoding of
# this source file.
EMOJI_MAP = {
    "joy": "\U0001F60A",       # smiling face with smiling eyes
    "sadness": "\U0001F622",   # crying face
    "anger": "\U0001F620",     # angry face
    "fear": "\U0001F628",      # fearful face
    "love": "\u2764\ufe0f",    # red heart (with emoji variation selector)
    "surprise": "\U0001F632",  # astonished face
    "disgust": "\U0001F922",   # nauseated face
    "neutral": "\U0001F610",   # neutral face
}
# ------------------------------ | |
# Fusion function | |
# ------------------------------ | |
# Short labels that are translated to the long emotion names before fusion.
# NOTE(review): assumed to be the label set of the speech model loaded by this
# app -- confirm against that model's id2label config.
_DEFAULT_LABEL_ALIASES = {
    "neu": "neutral",
    "hap": "joy",
    "ang": "anger",
    "sad": "sadness",
}


def fuse_predictions(text_preds=None, audio_preds=None, w_text=0.5, w_audio=0.5,
                     *, label_aliases=None):
    """Combine text and audio emotion predictions into one weighted score table.

    Each modality's scores are first rescaled to sum to 1, then multiplied by
    that modality's weight and added per label.

    Args:
        text_preds: List of ``{"label": str, "score": float}`` dicts, or a
            falsy value (``None`` / empty list) when there is no text input.
        audio_preds: Same structure for the audio modality.
        w_text: Weight applied to the normalized text scores.
        w_audio: Weight applied to the normalized audio scores.
        label_aliases: Optional mapping used to rename labels before fusing,
            so both modalities share one vocabulary. ``None`` selects
            ``_DEFAULT_LABEL_ALIASES``; pass ``{}`` to keep labels untouched.

    Returns:
        A dict with ``"fused_label"`` (the highest-scoring label, or ``"none"``
        when nothing was supplied), ``"fused_score"`` (its score rounded to
        three decimals) and ``"all_scores"`` (label -> fused score, in
        first-seen order: text labels first, then new audio labels).
    """
    aliases = _DEFAULT_LABEL_ALIASES if label_aliases is None else label_aliases

    def normalize(preds):
        # Sum scores per canonical label so aliased or duplicated labels
        # accumulate instead of overwriting each other.
        merged = {}
        for pred in preds:
            label = aliases.get(pred['label'], pred['label'])
            merged[label] = merged.get(label, 0.0) + pred['score']
        total = sum(merged.values())
        if total <= 0:
            # No probability mass to distribute: avoid dividing by zero.
            return {label: 0.0 for label in merged}
        return {label: score / total for label, score in merged.items()}

    # A plain dict keeps insertion order, which makes chart order and
    # tie-breaking in max() deterministic (a set would not).
    scores = {}
    for preds, weight in ((text_preds, w_text), (audio_preds, w_audio)):
        if not preds:
            continue
        for label, share in normalize(preds).items():
            scores[label] = scores.get(label, 0.0) + weight * share

    best = max(scores.items(), key=lambda item: item[1]) if scores else ("none", 0)
    return {"fused_label": best[0], "fused_score": round(best[1], 3), "all_scores": scores}
# ------------------------------ | |
# Create bar chart with emojis | |
# ------------------------------ | |
def make_bar_chart(scores_dict, title="Emotion Scores"):
    """Build a Plotly bar chart with one bar per emotion score.

    Args:
        scores_dict: Mapping of emotion label -> score.
        title: Chart title.

    Returns:
        The Plotly figure. Bars are labelled with the emoji from EMOJI_MAP
        (empty when the label has no entry) followed by the label, and the
        y axis is fixed to the range 0..1.
    """
    emotions = []
    values = []
    for label, value in scores_dict.items():
        emotions.append(f"{EMOJI_MAP.get(label, '')} {label}")
        values.append(value)
    frame = pd.DataFrame({"Emotion": emotions, "Score": values})

    figure = px.bar(
        frame,
        x="Emotion",
        y="Score",
        text="Score",
        title=title,
        range_y=[0, 1],
        color="Emotion",
        color_discrete_sequence=px.colors.qualitative.Bold,
    )
    # Print each score with two decimals above its bar.
    figure.update_traces(texttemplate='%{text:.2f}', textposition='outside')
    figure.update_layout(
        yaxis_title="Probability",
        xaxis_title="Emotion",
        showlegend=False,
    )
    return figure
# ------------------------------ | |
# Prediction function | |
# ------------------------------ | |
def predict(text, audio, w_text, w_audio):
    """Classify the given text and/or audio and return the charts to display.

    Args:
        text: Text to classify; empty or whitespace-only text is ignored.
        audio: Audio value passed to the audio classifier; falsy when no
            audio was provided.
        w_text: Fusion weight for the text scores.
        w_audio: Fusion weight for the audio scores.

    Returns:
        A list of exactly three entries ``[text_chart, audio_chart,
        fused_chart]``. The entry for a modality that was not provided is
        ``None``. The length is fixed so it always matches the three outputs
        wired to the Predict button.

    Raises:
        gr.Error: If neither usable text nor audio was supplied.
    """
    text_preds, audio_preds = None, None
    if text and text.strip():
        raw = text_classifier(text)
        # The classifier result is expected to be nested (one score list per
        # input); also accept an already-flat list of score dicts.
        text_preds = raw[0] if raw and isinstance(raw[0], list) else raw
    if audio:
        audio_preds = audio_classifier(audio)

    if not text_preds and not audio_preds:
        # Surface a readable message in the UI instead of plotting nothing.
        raise gr.Error("Please enter some text or upload an audio clip first.")

    fused = fuse_predictions(text_preds, audio_preds, w_text, w_audio)

    text_chart = None
    if text_preds:
        text_chart = make_bar_chart(
            {p['label']: p['score'] for p in text_preds}, "Text Emotion Scores"
        )
    audio_chart = None
    if audio_preds:
        audio_chart = make_bar_chart(
            {p['label']: p['score'] for p in audio_preds}, "Audio Emotion Scores"
        )

    fused_label = fused['fused_label']
    # Plotly titles use HTML-style markup, so the line break must be <br>.
    fused_chart = make_bar_chart(
        fused['all_scores'],
        f"Fused Emotion Scores<br>Prediction: {EMOJI_MAP.get(fused_label, '')} {fused_label}",
    )
    return [text_chart, audio_chart, fused_chart]
# ------------------------------ | |
# Build Gradio interface with emojis | |
# ------------------------------ | |
# Two-column layout: inputs and fusion weights on the left, charts on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 🎭 Multimodal Emotion Classification (Text + Speech) 🎭")
    with gr.Row():
        with gr.Column():
            txt = gr.Textbox(label="📝 Text input", placeholder="Type something emotional...")
            aud = gr.Audio(type="filepath", label="🎤 Upload speech (wav/mp3)")
            w1 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="🔹 Text weight (w_text)")
            w2 = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="🔹 Audio weight (w_audio)")
            btn = gr.Button("✨ Predict")
        with gr.Column():
            # One plot per chart (text, audio, fused). Listing a single Plot
            # three times as the output would squeeze all charts into one panel.
            text_chart_output = gr.Plot(label="Text Emotion Scores")
            audio_chart_output = gr.Plot(label="Audio Emotion Scores")
            chart_output = gr.Plot(label="Fused Emotion Scores")
    btn.click(
        fn=predict,
        inputs=[txt, aud, w1, w2],
        outputs=[text_chart_output, audio_chart_output, chart_output],
    )

demo.launch()