import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Wav2Vec2Model


class Wav2VecIntent(nn.Module):
    """Intent classifier built on a pretrained wav2vec 2.0 encoder.

    Frame-level encoder outputs are pooled into a single utterance
    embedding with a learned attention layer, then classified.
    """

    def __init__(self, num_classes=31, pretrained_model="facebook/wav2vec2-large"):
        super().__init__()

        # Pretrained speech encoder; its hidden size drives the layers below.
        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_model)
        hidden_size = self.wav2vec.config.hidden_size

        # Normalize frame embeddings before pooling.
        self.layer_norm = nn.LayerNorm(hidden_size)

        # Scores one scalar per frame for attention pooling over time.
        self.attention = nn.Linear(hidden_size, 1)

        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_values, attention_mask=None):
        # input_values: raw waveform batch of shape (batch, samples).
        outputs = self.wav2vec(
            input_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # (batch, frames, hidden_size)
        hidden_states = outputs.last_hidden_state

        hidden_states = self.layer_norm(hidden_states)

        # Attention pooling: softmax the per-frame scores over the time axis,
        # then take the weighted sum to collapse frames into one utterance
        # vector. (Padding frames, if any, are not excluded from this softmax.)
        attn_weights = F.softmax(self.attention(hidden_states), dim=1)
        x = torch.sum(hidden_states * attn_weights, dim=1)

        x = self.dropout(x)
        x = self.fc(x)  # (batch, num_classes) logits
        return x
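

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original source): feeds a
    # random batch of 1-second waveforms at the 16 kHz sampling rate that
    # wav2vec 2.0 expects. Downloads the pretrained checkpoint on first run.
    model = Wav2VecIntent(num_classes=31)
    model.eval()

    dummy_waveforms = torch.randn(2, 16000)  # (batch, samples)
    with torch.no_grad():
        logits = model(dummy_waveforms)
    print(logits.shape)  # expected: torch.Size([2, 31])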