import gradio as gr
import torch.nn as nn
import audresample
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
import torch
import librosa
import numpy as np
import types
from transformers import AutoModelForAudioClassification
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
                                                             Wav2Vec2PreTrainedModel)

plt.style.use('seaborn-v0_8-whitegrid')
class ADV(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        x = self.dense(x)
        x = torch.tanh(x)
        return self.out_proj(x)


class Dawn(Wav2Vec2PreTrainedModel):
    r"""https://arxiv.org/abs/2203.07378"""

    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ADV(config)

    def forward(self, x):
        # zero-mean / unit-variance normalisation of the raw waveform
        x -= x.mean(1, keepdim=True)
        variance = (x * x).mean(1, keepdim=True) + 1e-7
        x = self.wav2vec2(x / variance.sqrt())
        # mean-pool the frame embeddings, then predict arousal / dominance / valence
        return self.classifier(x.last_hidden_state.mean(1))
def _forward(self, x):
    '''x: (batch, audio-samples-16KHz)'''
    x = (x + self.config.mean) / self.config.std  # waveform normalisation with the checkpoint's mean/std
    x = self.ssl_model(x, attention_mask=None).last_hidden_state
    # attentive statistics pooling (mean + std weighted by attention)
    h = self.pool_model.sap_linear(x).tanh()
    w = torch.matmul(h, self.pool_model.attention).softmax(1)
    mu = (x * w).sum(1)
    x = torch.cat(
        [
            mu,
            ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
        ], 1)
    return self.ser_model(x)
# WavLM
device = 'cpu'
base = AutoModelForAudioClassification.from_pretrained(
    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
    trust_remote_code=True).to(device).eval()
base.forward = types.MethodType(_forward, base)

# Wav2Vec2
dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
).to(device).eval()
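# Optional smoke test for the two teachers (a sketch with a made-up 1 s silent
# input, not part of the demo; shapes are expected, not guaranteed):
#   dummy = torch.zeros(1, 16000)
#   with torch.no_grad():
#       print(dawn(dummy).shape, base(dummy).shape)   # both expected (1, 3): A/D/V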
# Wav2Small
import torch.nn.functional as F
from transformers import PretrainedConfig
def _prenorm(x, attention_mask=None):
    '''Zero-mean / unit-variance normalisation of the raw waveform.'''
    if attention_mask is not None:
        # attention_mask is the unprocessed sample-level mask of the original input,
        # so its sum gives the number of non-padded audio samples
        N = attention_mask.sum(1, keepdim=True)
        x -= x.sum(1, keepdim=True) / N
        var = (x * x).sum(1, keepdim=True) / N
    else:
        # torch.mean exports to a single ONNX ReduceMean op, which saves a few ops
        # compared to casting the integer length N to float and dividing
        x -= x.mean(1, keepdim=True)
        var = (x * x).mean(1, keepdim=True)
    return x / torch.sqrt(var + 1e-7)
class Spectrogram(nn.Module):

    def __init__(self,
                 n_fft=64,       # rows of the DFT matrix (-> n_fft // 2 + 1 output bins)
                 n_time=64,      # columns of the DFT matrix (analysis-window length)
                 hop_length=32,
                 freeze_parameters=True):
        super().__init__()
        fft_window = librosa.filters.get_window('hann', n_time, fftbins=True)
        fft_window = librosa.util.pad_center(fft_window, size=n_time)
        out_channels = n_fft // 2 + 1
        (x, y) = np.meshgrid(np.arange(n_time), np.arange(n_fft))
        omega = np.exp(-2 * np.pi * 1j / n_time)
        dft_matrix = np.power(omega, x * y)           # (n_fft, n_time)
        dft_matrix = dft_matrix * fft_window[None, :]
        dft_matrix = dft_matrix[0: out_channels, :]   # keep the non-redundant bins
        dft_matrix = dft_matrix[:, None, :]           # (out_channels, 1, n_time)
        # ---- asymmetric (non-square) DFT implemented as 1D convolutions
        self.conv_real = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
        self.conv_imag = nn.Conv1d(1, out_channels, n_fft, stride=hop_length, padding=0, bias=False)
        self.conv_real.weight.data = torch.tensor(np.real(dft_matrix),
                                                  dtype=self.conv_real.weight.dtype).to(self.conv_real.weight.device)
        self.conv_imag.weight.data = torch.tensor(np.imag(dft_matrix),
                                                  dtype=self.conv_imag.weight.dtype).to(self.conv_imag.weight.device)
        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        x = input[:, None, :]
        real = self.conv_real(x)
        imag = self.conv_imag(x)
        return real ** 2 + imag ** 2  # power spectrogram: (bs, freq-bins, time-frames)
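# Shape sketch for the Spectrogram above (assuming the 4 s / 64000-sample clips used
# later in process_audio): out_channels = 64 // 2 + 1 = 33 frequency bins and
# floor((64000 - 64) / 32) + 1 = 1999 frames, i.e. an output of shape (bs, 33, 1999).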
class LogmelFilterBank(nn.Module):

    def __init__(self,
                 sr=16000,
                 n_fft=64,
                 n_mels=26,   # halved to 13 later by the max-pool layer
                 fmin=0.0,
                 freeze_parameters=True):
        super().__init__()
        fmax = sr // 2
        W2 = librosa.filters.mel(sr=sr,
                                 n_fft=n_fft,
                                 n_mels=n_mels,
                                 fmin=fmin,
                                 fmax=fmax).T
        self.register_buffer('melW', torch.Tensor(W2))
        self.register_buffer('amin', torch.Tensor([1e-10]))

    def forward(self, x):
        # project frequency bins onto mel bands; the number of time frames is unchanged
        x = torch.matmul(x[:, None, :, :].transpose(2, 3), self.melW)
        x = torch.where(x > self.amin, x, self.amin)  # clamp (not in place)
        x = 10 * torch.log10(x)
        return x
def length_after_conv_layer(_length, k=None, pad=None, stride=None):
    # standard output-length formula for a conv / pooling layer
    return torch.floor((_length + 2 * pad - k) / stride + 1)
class Conv(nn.Module):

    def __init__(self, c_in, c_out, k=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, k, stride=stride, padding=padding, bias=False)
        self.norm = nn.BatchNorm2d(c_out)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return torch.relu_(x)
class Vgg7(nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = Conv( 1, 13)
        self.l2 = Conv(13, 13)
        self.l3 = Conv(13, 13)
        self.maxpool_A = nn.MaxPool2d(3,
                                      stride=2,
                                      padding=1)
        self.l4 = Conv(13, 13)
        self.l5 = Conv(13, 13)
        self.l6 = Conv(13, 13)
        self.l7 = Conv(13, 13)
        self.lin = nn.Conv2d(13, 13, 1, padding=0, stride=1)
        # attention weights for pooling over time; the mel axis is folded into
        # the channel dimension after pooling
        self.sof = nn.Conv2d(13, 13, 1, padding=0, stride=1)
        self.spectrogram_extractor = Spectrogram()
        self.logmel_extractor = LogmelFilterBank()

    def final_length(self, L):
        # only layers with stride > 1 change the number of time frames
        conv_kernel = [64, 3]   # [n_fft, maxpool kernel]
        conv_stride = [32, 2]   # [hop_length, maxpool stride]
        conv_pad    = [ 0, 1]   # [STFT padding, maxpool padding]
        for k, stride, pad in zip(conv_kernel, conv_stride, conv_pad):
            L = length_after_conv_layer(L, k=k, stride=stride, pad=pad)
        return L
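    # Worked example (a sketch, assuming the 4 s / 64000-sample clips used below):
    #   STFT:    floor((64000 + 0 - 64) / 32 + 1) = 1999 time frames
    #   maxpool: floor((1999  + 2 - 3)  /  2 + 1) = 1000 time frames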
    def final_attention_mask(self, feature_vector_length, attention_mask=None):
        # map the sample-level attention mask to a frame-level mask
        non_padded_lengths = attention_mask.sum(1)
        # note: an out_length of 0 would indicate an attention mask that was never filled
        out_lengths = self.final_length(non_padded_lengths)
        out_lengths = out_lengths.to(torch.long)
        bs, _ = attention_mask.shape
        attention_mask = torch.ones((bs, feature_vector_length),
                                    dtype=attention_mask.dtype,
                                    device=attention_mask.device)
        for b, _len in enumerate(out_lengths):
            attention_mask[b, _len:] = 0
        return attention_mask

    def forward(self, x, attention_mask=None):
        x = _prenorm(x,
                     attention_mask=attention_mask)
        x = self.spectrogram_extractor(x)
        x = self.logmel_extractor(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.maxpool_A(x)  # possible variant: reshape here so the following convs see a larger kernel
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        if attention_mask is not None:
            bs, _, t, _ = x.shape
            a = self.final_attention_mask(feature_vector_length=t,
                                          attention_mask=attention_mask)[:, None, :, None]
            x = torch.masked_fill(x, a < 1, 0)
            # the mask must also be applied inside the attention pooling (lin / sof)
            x = self.lin(x) * (self.sof(x) - 10000. * torch.logical_not(a)).softmax(2)
        else:
            x = self.lin(x) * self.sof(x).softmax(2)
        x = x.sum(2)  # (bs, ch, time-frames, half-mel) -> (bs, ch, half-mel)
        # --
        xT = x.transpose(1, 2)
        x = torch.cat([x,
                       torch.bmm(x, xT),    # correlation (ch x mel) @ (mel x ch)
                       # torch.bmm(x, x),   # correlation ch * ch
                       # torch.bmm(xT, xT)  # correlation mel * mel
                       ], 2)
        # --
        return x.reshape(-1, 338)
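# The 338-dim feature above is 13 channels x (13 mel bins + 13 correlation terms)
# = 13 x 26 = 338, which matches Wav2SmallConfig.hidden = 2 * half_mel * half_mel below.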
class Wav2SmallConfig(PretrainedConfig):
    model_type = "wav2vec2"

    def __init__(self,
                 **kwargs):
        super().__init__(**kwargs)
        self.half_mel = 13
        self.n_fft = 64
        self.n_time = 64
        self.hidden = 2 * self.half_mel * self.half_mel
        self.hop = self.n_time // 2
class Wav2Small(Wav2Vec2PreTrainedModel):

    def __init__(self,
                 config):
        super().__init__(config)
        self.vgg7 = Vgg7()
        self.adv = nn.Linear(config.hidden, 3)  # 0=arousal, 1=dominance, 2=valence

    def forward(self, x, attention_mask=None):
        x = self.vgg7(x, attention_mask=attention_mask)
        return self.adv(x)
def _ccc(x, y):
    '''Negative Concordance Correlation Coefficient (usable as a loss).

    If len(x) = len(y) = 1 the plain CCC is 0/0. Since the numerator and the
    denominator can both be negative, the 1e-7 epsilon must respect the sign of
    the denominator: add 1e-7 if its sign is >= 0, subtract 1e-7 otherwise.'''
    mean_y = y.mean()
    mean_x = x.mean()
    a = x - mean_x
    b = y - mean_y
    L = (mean_x - mean_y).abs() * .1 * x.shape[0]
    numerator = torch.dot(a, b)  # the official CCC has the L term only in the denominator
    # if a and b are equal scalars all dot products are zero
    denominator = torch.dot(a, a) + torch.dot(b, b) + L
    denominator = torch.where(denominator.sign() < 0,
                              denominator - 1e-7,
                              denominator + 1e-7)
    ccc = numerator / denominator
    return -ccc  # + F.l1_loss(a, b)
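# Quick sanity sketch of the sign convention (not used by the demo below; the
# tensors here are made up for illustration). _ccc returns the *negative* CCC,
# so values close to -1 indicate strong concordance:
#   a = torch.tensor([0.1, 0.4, 0.7])
#   b = torch.tensor([0.2, 0.5, 0.6])
#   loss = _ccc(a, b)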
wav2small = Wav2Small.from_pretrained('audeering/wav2small').to(device).eval()
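# Minimal standalone usage sketch (assumes `wav` is a mono 16 kHz numpy waveform;
# the demo below goes through process_audio instead):
#   x = torch.from_numpy(wav[None, :]).to(torch.float)
#   with torch.no_grad():
#       adv = wav2small(x)   # shape (1, 3): arousal, dominance, valence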
# Error figure returned for both plot slots when no audio is provided
fig_error, ax = plt.subplots(figsize=(8, 6))
error_message = "Error: No .wav or Mic. audio provided."
ax.text(0.5, 0.5, error_message,
        ha='center',
        va='center',
        fontsize=24,
        color='gray',
        fontweight='bold',
        transform=ax.transAxes)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_frame_on(True)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)
def process_audio(audio_filepath):
    if audio_filepath is None:
        return fig_error, fig_error
    waveform, sample_rate = librosa.load(audio_filepath, sr=None)
    # Resample audio to 16 kHz if the sample rate is different
    if sample_rate != 16000:
        resampled_waveform_np = audresample.resample(waveform, sample_rate, 16000)
    else:
        resampled_waveform_np = waveform[None, :]
    x = torch.from_numpy(resampled_waveform_np[:, :64000]).to(torch.float)  # only the first 4 s, for speed
    with torch.no_grad():
        logits_dawn = dawn(x).cpu().numpy()[0, :]
        logits_wavlm = base(x).cpu().numpy()[0, :]
        # Wav2Small: 17K parameters
        logits_wav2small = wav2small(x).cpu().numpy()[0, :]
    # --- Plot 1: Wav2Vec2 (Dawn) vs Wav2Small (17K params) outputs ---
    fig, ax = plt.subplots(figsize=(10, 6))
    left_bars_data = logits_dawn.clip(0, 1)
    right_bars_data = logits_wav2small.clip(0, 1)
    bar_labels = ['\nArousal', '\nDominance', '\nValence']
    y_pos = np.arange(len(bar_labels))
    # Define colormaps for each category to ensure distinct colors
    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
    left_filled_colors = []
    right_filled_colors = []
    background_colors = []
    # Assign specific shades for filled bars and background bars
    for i, cmap in enumerate(category_colormaps):
        left_filled_colors.append(cmap(0.74))
        right_filled_colors.append(cmap(0.64))
        background_colors.append(cmap(0.1))
    # Plot transparent background bars
    for i in range(len(bar_labels)):
        ax.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
        ax.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
    # Plot the filled bars for the actual data
    for i in range(len(bar_labels)):
        ax.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
        ax.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
    # Add a central vertical axis divider
    ax.axvline(0, color='black', linewidth=0.8, linestyle='--')
    # Set x-axis limits and y-axis ticks/labels
    ax.set_xlim(-1, 1)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(bar_labels, fontsize=12)

    # Custom formatter for the x-axis to show absolute percentage values
    def abs_tick_formatter(x, pos):
        return f'{int(abs(x) * 100)}%'

    ax.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
    # Set plot title and x-axis label
    ax.set_title('', fontsize=16, pad=20)
    ax.set_xlabel('Wav2Vec2 (Dawn) Wav2Small (17K param.)', fontsize=12)
    # Remove top, right, and left spines for a cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    # Add annotations (percentage values) to the filled bars
    for i in range(len(bar_labels)):
        ax.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
        ax.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
    # --- Plot 2: WavLM vs Wav2Small Teacher (average of both teachers) ---
    fig_2, ax_2 = plt.subplots(figsize=(10, 6))
    left_bars_data = logits_wavlm.clip(0, 1)
    right_bars_data = (.5 * logits_dawn + .5 * logits_wavlm).clip(0, 1)
    bar_labels = ['\nArousal', '\nDominance', '\nValence']
    y_pos = np.arange(len(bar_labels))
    # Same colormaps and shades as Plot 1
    category_colormaps = [plt.cm.Blues, plt.cm.Greys, plt.cm.Oranges]
    left_filled_colors = []
    right_filled_colors = []
    background_colors = []
    for i, cmap in enumerate(category_colormaps):
        left_filled_colors.append(cmap(0.74))
        right_filled_colors.append(cmap(0.64))
        background_colors.append(cmap(0.1))
    # Plot transparent background bars
    for i in range(len(bar_labels)):
        ax_2.barh(y_pos[i], -1, color=background_colors[i], alpha=0.3, height=0.6)
        ax_2.barh(y_pos[i], 1, color=background_colors[i], alpha=0.3, height=0.6)
    # Plot the filled bars for the actual data
    for i in range(len(bar_labels)):
        ax_2.barh(y_pos[i], -left_bars_data[i], color=left_filled_colors[i], alpha=1, height=0.6)
        ax_2.barh(y_pos[i], right_bars_data[i], color=right_filled_colors[i], alpha=1, height=0.6)
    # Add a central vertical axis divider
    ax_2.axvline(0, color='black', linewidth=0.8, linestyle='--')
    # Set x-axis limits and y-axis ticks/labels
    ax_2.set_xlim(-1, 1)
    ax_2.set_yticks(y_pos)
    ax_2.set_yticklabels(bar_labels, fontsize=12)
    # Reuse the absolute-percentage tick formatter from Plot 1
    ax_2.xaxis.set_major_formatter(plt.FuncFormatter(abs_tick_formatter))
    ax_2.set_title('', fontsize=16, pad=20)
    ax_2.set_xlabel('WavLM (Baseline) Wav2Small Teacher (0.4B param.)', fontsize=12)
    ax_2.spines['top'].set_visible(False)
    ax_2.spines['right'].set_visible(False)
    ax_2.spines['left'].set_visible(False)
    # Add annotations (percentage values) to the filled bars
    for i in range(len(bar_labels)):
        ax_2.text(-left_bars_data[i] - 0.05, y_pos[i], f'{int(left_bars_data[i] * 100)}%',
                  va='center', ha='right', color=left_filled_colors[i], fontweight='bold')
        ax_2.text(right_bars_data[i] + 0.05, y_pos[i], f'{int(right_bars_data[i] * 100)}%',
                  va='center', ha='left', color=right_filled_colors[i], fontweight='bold')
    return fig, fig_2
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(
        sources=["microphone", "upload"],
        type="filepath",   # the input is passed as a file path
        label=''
    ),
    outputs=[
        gr.Plot(label="Wav2Vec2 vs Wav2Small (17K params) Plot"),   # first plot output
        gr.Plot(label="WavLM vs Wav2Small Teacher Plot"),           # second plot output
    ],
    title='',
    description='',
    flagging_mode="never",   # disables the flagging feature
    examples=[
        "female-46-neutral.wav",
        "female-20-happy.wav",
        "male-60-angry.wav",
        "male-27-sad.wav",
    ],
    css="footer {visibility: hidden}"   # hides the Gradio footer
)
# Gradio Blocks for the tabbed interface
with gr.Blocks() as demo:
    # First tab: the Arousal / Dominance / Valence plots
    with gr.Tab(label="Arousal / Dominance / Valence"):
        iface.render()
    # Second tab: CCC (Concordance Correlation Coefficient) results
    with gr.Tab(label="CCC"):
        gr.Markdown('''<table style="width:500px"><tr><th colspan=5>CCC MSP Podcast v1.7</th></tr>
<tr> <td> </td><td>Arousal</td> <td>Dominance</td> <td>Valence</td> <td>Associated Paper</td> </tr>
<tr> <td> <a href="https://huggingface.co/audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim">Wav2Vec2</a></td><td>0.744</td><td>0.655</td><td>0.638</td><td> <a href="https://arxiv.org/abs/2203.07378">arXiv</a> </td> </tr>
<tr> <td> <a href="https://huggingface.co/dkounadis/wav2small">Wav2Small Teacher</a></td><td>0.762</td><td>0.684</td><td>0.676</td><td> <a href="https://arxiv.org/abs/2408.13920">arXiv</a> </td> </tr>
</table>
''')
# Launch the Gradio application
if __name__ == "__main__":
    demo.launch(share=False)