Spaces:
Running
Running
# coding: utf-8 | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.init as init | |
import numpy as np | |
import math | |
from torch.nn.functional import silu | |
from torch.nn.functional import softplus | |
from einops import rearrange, einsum | |
from torch import Tensor | |
from torch_geometric.nn import GATConv, RGCNConv, TransformerConv | |
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension of the input.

    Args:
        input_dim: Size of the last dimension of both the input and the output.
        hidden_dim: Width of the intermediate projection.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Project to ``hidden_dim``, apply GELU and dropout, project back."""
        # GELU is used as a smoother activation.
        hidden = F.gelu(self.layer_1(x))
        return self.layer_2(self.dropout(hidden))
class AddAndNorm(nn.Module):
    """Sum two tensors and apply layer normalisation to the result.

    Args:
        input_dim: Size of the normalised (last) dimension.
        dropout: Dropout probability applied to the ``residual`` argument.
    """

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        """Return ``LayerNorm(x + Dropout(residual))``.

        NOTE(review): dropout is applied to ``residual`` rather than to ``x``;
        confirm this is intended, since ``residual`` reads like the skip path.
        """
        regularised = self.dropout(residual)
        summed = x + regularised
        return self.norm(summed)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first sequence.

    Args:
        d_model: Feature size of the encoded sequence (even or odd).
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions for which the table is precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer cosine column than sine column,
        # so trim the angle table instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x`` and apply dropout.

        The table is indexed with ``x.size(1)``, i.e. the sequence axis is taken
        to be dimension 1.
        """
        # ``pe`` is a buffer and never requires gradients, so no detach is needed.
        return self.dropout(x + self.pe[: x.size(1)])
class TransformerEncoderLayer(nn.Module):
    """Attention block followed by a feed-forward block, each with add-and-norm.

    Args:
        input_dim: Embedding size used by the attention and feed-forward parts.
        num_heads: Number of attention heads.
        dropout: Dropout probability shared by all sub-modules.
        positional_encoding: When true, a ``PositionalEncoding`` is applied to
            key, value and query before attention.
    """

    def __init__(self, input_dim, num_heads, dropout=0.1, positional_encoding=False):
        super().__init__()
        self.input_dim = input_dim
        self.self_attention = nn.MultiheadAttention(
            input_dim,
            num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.feed_forward = PositionWiseFeedForward(input_dim, input_dim, dropout=dropout)
        self.add_norm_after_attention = AddAndNorm(input_dim, dropout=dropout)
        self.add_norm_after_ff = AddAndNorm(input_dim, dropout=dropout)
        self.positional_encoding = PositionalEncoding(input_dim) if positional_encoding else None

    def forward(self, key, value, query):
        """Attend from ``query`` over ``key``/``value`` and refine the result.

        Note the argument order: ``key`` and ``value`` come before ``query``.
        """
        if self.positional_encoding is not None:
            # Encoded in the order key, value, query.
            key, value, query = (
                self.positional_encoding(tensor) for tensor in (key, value, query)
            )

        attended, _ = self.self_attention(query, key, value, need_weights=False)
        after_attention = self.add_norm_after_attention(attended, query)
        transformed = self.feed_forward(after_attention)
        return self.add_norm_after_ff(transformed, after_attention)
class GAL(nn.Module):
    """Gated fusion of two feature matrices into one ``gated_dim``-wide output.

    Args:
        input_dim_F1: Feature size of the first input.
        input_dim_F2: Feature size of the second input.
        gated_dim: Feature size of the fused output.
        dropout_rate: Dropout probability used on both branches and the gate.
    """

    def __init__(self, input_dim_F1, input_dim_F2, gated_dim, dropout_rate):
        super().__init__()
        # Per-input projections, Xavier-initialised.
        self.WF1 = nn.Parameter(torch.Tensor(input_dim_F1, gated_dim))
        self.WF2 = nn.Parameter(torch.Tensor(input_dim_F2, gated_dim))
        for weight in (self.WF1, self.WF2):
            init.xavier_uniform_(weight)

        # Gate projection over the concatenation of both inputs.
        self.WF = nn.Parameter(torch.Tensor(input_dim_F1 + input_dim_F2, gated_dim))
        init.xavier_uniform_(self.WF)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, f1, f2):
        """Return ``z * h1 + (1 - z) * h2`` with a softmax gate ``z``.

        The inputs are concatenated along dimension 1 and the gate is
        normalised along dimension 1.
        """
        branch_1 = self.dropout(torch.tanh(f1 @ self.WF1))
        branch_2 = self.dropout(torch.tanh(f2 @ self.WF2))

        joint = torch.cat([f1, f2], dim=1)
        gate_logits = self.dropout(joint @ self.WF)
        gate = torch.softmax(gate_logits, dim=1)

        return gate * branch_1 + (1 - gate) * branch_2
class GraphFusionLayer(nn.Module):
    """Fuse audio and text vectors with two GAT layers over per-sample pair graphs.

    Every batch element contributes two nodes (audio, text) that are connected
    to each other in both directions.

    Args:
        hidden_dim: Feature size of both inputs and of the output.
        dropout: Dropout probability used in the projection heads.
        heads: Number of attention heads of the first GAT layer.
        out_mean: If true, average the two nodes and project the result;
            otherwise return both node embeddings.
    """

    def __init__(self, hidden_dim, dropout=0.0, heads=2, out_mean=True):
        super().__init__()
        self.out_mean = out_mean

        # Projection heads for the input features.
        self.proj_audio = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.proj_text = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Graph layers.
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim*heads, hidden_dim)

        # Final projection.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

    def build_complete_graph(self, num_nodes):
        """Return the ``[2, E]`` edge index of a complete directed graph.

        Self-loops are excluded. For fewer than two nodes an empty ``[2, 0]``
        long tensor is returned.
        """
        edges = [
            (i, j)
            for i in range(num_nodes)
            for j in range(num_nodes)
            if i != j
        ]
        # reshape(-1, 2) keeps the result two-dimensional when ``edges`` is empty.
        return torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    def _batched_edge_index(self, batch_size, device):
        """Replicate the two-node graph once per sample, offset by ``2 * sample``."""
        pair_edges = self.build_complete_graph(2).to(device)  # [2, 2]
        offsets = 2 * torch.arange(batch_size, device=device)  # [batch_size]
        shifted = pair_edges.unsqueeze(-1) + offsets  # [2, 2, batch_size]
        return shifted.reshape(2, pair_edges.size(1) * batch_size)

    def forward(self, audio_stats, text_stats):
        """
        audio_stats: [batch_size, hidden_dim]
        text_stats: [batch_size, hidden_dim]

        Returns ``[batch_size, hidden_dim]`` when ``out_mean`` is true, otherwise
        the per-modality node embeddings ``[batch_size, 2, hidden_dim]``.
        """
        batch_size = audio_stats.size(0)

        # Project the input features.
        x_audio = F.relu(self.proj_audio(audio_stats))  # [batch_size, hidden_dim]
        x_text = F.relu(self.proj_text(text_stats))  # [batch_size, hidden_dim]

        # Interleave nodes: audio_0, text_0, audio_1, text_1, ...
        nodes = torch.stack([x_audio, x_text], dim=1)  # [batch_size, 2, hidden_dim]
        nodes = nodes.reshape(-1, nodes.size(-1))  # [batch_size*2, hidden_dim]

        # One audio<->text pair graph per batch element. A single two-node edge
        # index would only connect the first sample's nodes.
        edge_index = self._batched_edge_index(batch_size, audio_stats.device)

        # Apply the GAT layers.
        x = F.relu(self.gat1(nodes, edge_index))
        x = self.gat2(x, edge_index)

        # Split back into audio and text nodes.
        x = x.view(batch_size, 2, -1)  # [batch_size, 2, hidden_dim]

        if self.out_mean:
            # Average over the two modalities.
            fused = torch.mean(x, dim=1)  # [batch_size, hidden_dim]
            return self.fc(fused)
        return x
class GraphFusionLayerAtt(nn.Module):
    """Fuse audio and text vectors with two GAT layers and attention pooling.

    Every batch element contributes two nodes (audio, text) that are connected
    to each other in both directions. The two node embeddings are then combined
    with softmax weights predicted by a linear scorer.

    Args:
        hidden_dim: Feature size of both inputs and of the output.
        heads: Number of attention heads of the first GAT layer.
    """

    def __init__(self, hidden_dim, heads=2):
        super().__init__()

        # Projection layers for the input features.
        self.proj_audio = nn.Linear(hidden_dim, hidden_dim)
        self.proj_text = nn.Linear(hidden_dim, hidden_dim)

        # Graph layers.
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim*heads, hidden_dim)

        # Scores one weight per modality node.
        self.attention_fusion = nn.Linear(hidden_dim, 1)

        # Final projection.
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def build_complete_graph(self, num_nodes):
        """Return the ``[2, E]`` edge index of a complete directed graph.

        Self-loops are excluded. For fewer than two nodes an empty ``[2, 0]``
        long tensor is returned.
        """
        edges = [
            (i, j)
            for i in range(num_nodes)
            for j in range(num_nodes)
            if i != j
        ]
        # reshape(-1, 2) keeps the result two-dimensional when ``edges`` is empty.
        return torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    def _batched_edge_index(self, batch_size, device):
        """Replicate the two-node graph once per sample, offset by ``2 * sample``."""
        pair_edges = self.build_complete_graph(2).to(device)  # [2, 2]
        offsets = 2 * torch.arange(batch_size, device=device)  # [batch_size]
        shifted = pair_edges.unsqueeze(-1) + offsets  # [2, 2, batch_size]
        return shifted.reshape(2, pair_edges.size(1) * batch_size)

    def forward(self, audio_stats, text_stats):
        """
        audio_stats: [batch_size, hidden_dim]
        text_stats: [batch_size, hidden_dim]

        Returns the fused representation ``[batch_size, hidden_dim]``.
        """
        batch_size = audio_stats.size(0)

        # Project the input features.
        x_audio = F.relu(self.proj_audio(audio_stats))  # [batch_size, hidden_dim]
        x_text = F.relu(self.proj_text(text_stats))  # [batch_size, hidden_dim]

        # Interleave nodes: audio_0, text_0, audio_1, text_1, ...
        nodes = torch.stack([x_audio, x_text], dim=1)  # [batch_size, 2, hidden_dim]
        nodes = nodes.reshape(-1, nodes.size(-1))  # [batch_size*2, hidden_dim]

        # One audio<->text pair graph per batch element. A single two-node edge
        # index would only connect the first sample's nodes.
        edge_index = self._batched_edge_index(batch_size, audio_stats.device)

        # Apply the GAT layers.
        x = F.relu(self.gat1(nodes, edge_index))
        x = self.gat2(x, edge_index)

        # Split back into audio and text nodes.
        x = x.view(batch_size, 2, -1)  # [batch_size, 2, hidden_dim]

        # Attention-weighted sum over the two modalities.
        weights = F.softmax(self.attention_fusion(x), dim=1)
        fused = torch.sum(weights * x, dim=1)  # [batch_size, hidden_dim]
        return self.fc(fused)
# For the full code, see https://github.com/leson502/CORECT_EMNLP2023/tree/master/corect/model
class GNN(nn.Module):
    """Relational GCN layer optionally followed by a graph-transformer layer.

    Args:
        g_dim: Input node feature size.
        h1_dim: Output size of the RGCN layer.
        h2_dim: Per-head output size of the graph-transformer layer.
        num_relations: Number of edge types for the RGCN layer.
        num_modals: Number of modalities (stored, not used in ``forward``).
        gcn_conv: Convolution type; only ``"rgcn"`` is implemented.
        use_graph_transformer: Whether to add the ``TransformerConv`` stage.
        graph_transformer_nheads: Number of heads of the transformer stage.

    Raises:
        ValueError: If ``gcn_conv`` is not ``"rgcn"``.
    """

    def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, gcn_conv, use_graph_transformer, graph_transformer_nheads):
        super().__init__()

        # Fail fast: with any other value no convolution would be created and
        # ``forward`` would crash on an undefined ``x``.
        if gcn_conv != "rgcn":
            raise ValueError(
                f"Unsupported gcn_conv {gcn_conv!r}; only 'rgcn' is implemented"
            )

        self.gcn_conv = gcn_conv
        self.use_graph_transformer = use_graph_transformer
        self.num_modals = num_modals

        print("GNN --> Use RGCN")
        self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)

        if self.use_graph_transformer:
            print("GNN --> Use Graph Transformer")
            # The transformer stage consumes the RGCN output.
            in_dim = h1_dim
            self.conv2 = TransformerConv(in_dim, h2_dim, heads=graph_transformer_nheads, concat=True)
            # Heads are concatenated, hence the widened normalisation.
            self.bn = nn.BatchNorm1d(h2_dim * graph_transformer_nheads)

    def forward(self, node_features, node_type, edge_index, edge_type):
        """Run the graph layers over the node features.

        ``node_type`` is accepted for interface compatibility but is not used.
        """
        x = self.conv1(node_features, edge_index, edge_type)
        if self.use_graph_transformer:
            x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))
        return x
class GraphModel(nn.Module):
    """Build a multimodal temporal graph for a batch and run a ``GNN`` over it.

    Nodes are laid out modality-major: all nodes of modality 0 for every sample
    in the batch, then all nodes of modality 1, and so on.

    Args:
        g_dim, h1_dim, h2_dim: Feature sizes forwarded to ``GNN``.
        device: Device on which the graph index tensors are created.
        modalities: Sized collection; only its length is used.
        wp: Past window size (``-1`` = unlimited).
        wf: Future window bound (``-1`` = unlimited).
        edge_type: Enables temporal edges if it contains ``"temp"`` and
            cross-modal edges if it contains ``"multi"``.
        gcn_conv, use_graph_transformer, graph_transformer_nheads: Forwarded
            to ``GNN``.
    """

    def __init__(self, g_dim, h1_dim, h2_dim, device, modalities, wp, wf, edge_type, gcn_conv, use_graph_transformer, graph_transformer_nheads):
        super().__init__()

        self.n_modals = len(modalities)
        self.wp = wp
        self.wf = wf
        self.device = device
        self.gcn_conv = gcn_conv
        self.use_graph_transformer = use_graph_transformer

        print(f"GraphModel --> Edge type: {edge_type}")
        print(f"GraphModel --> Window past: {wp}")
        print(f"GraphModel --> Window future: {wf}")

        edge_temp = "temp" in edge_type
        edge_multi = "multi" in edge_type

        # Relation key = temporal direction + source modality + target modality.
        edge_type_to_idx = {}
        if edge_temp:
            for direction in (-1, 1, 0):
                for k in range(self.n_modals):
                    edge_type_to_idx[str(direction) + str(k) + str(k)] = len(edge_type_to_idx)
        else:
            for j in range(self.n_modals):
                edge_type_to_idx['0' + str(j) + str(j)] = len(edge_type_to_idx)

        if edge_multi:
            for j in range(self.n_modals):
                for k in range(self.n_modals):
                    if j != k:
                        edge_type_to_idx['0' + str(j) + str(k)] = len(edge_type_to_idx)

        self.edge_type_to_idx = edge_type_to_idx
        self.num_relations = len(edge_type_to_idx)
        self.edge_multi = edge_multi
        self.edge_temp = edge_temp

        self.gnn = GNN(g_dim, h1_dim, h2_dim, self.num_relations, self.n_modals, self.gcn_conv, self.use_graph_transformer, graph_transformer_nheads)

    def forward(self, x, lengths):
        """Pack features into nodes, run the GNN and concatenate per modality.

        ``x`` is iterated once per modality and ``lengths`` is a tensor of
        per-sample sequence lengths.
        """
        node_features = feature_packing(x, lengths)

        node_type, edge_index, edge_type, edge_index_lengths = self.batch_graphify(lengths)

        out_gnn = self.gnn(node_features, node_type, edge_index, edge_type)
        return multi_concat(out_gnn, lengths, self.n_modals)

    def batch_graphify(self, lengths):
        """Create node types, edge index and edge types for the whole batch.

        Args:
            lengths: Tensor of per-sample sequence lengths.

        Returns:
            Tuple ``(node_type [N], edge_index [2, E], edge_type [E],
            edge_index_lengths [B])`` of long tensors on ``self.device``.
        """
        lengths = lengths.tolist()
        total_length = sum(lengths)

        # Modality-major layout: total_length nodes per modality.
        node_type = []
        for k in range(self.n_modals):
            node_type.extend([k] * total_length)

        edge_pairs, edge_type, edge_index_lengths = [], [], []
        sum_length = 0  # offset of the current sample inside one modality block
        for cur_len in lengths:
            perms = self.edge_perms(cur_len, total_length)
            edge_index_lengths.append(len(perms))

            for vertices, neighbor in perms:
                src = int(vertices) + sum_length
                dst = int(neighbor) + sum_length
                edge_pairs.append((src, dst))

                # Position inside the sample, independent of the modality block.
                src_pos = vertices % total_length
                dst_pos = neighbor % total_length
                if src_pos > dst_pos:
                    temporal_type = 1
                elif src_pos < dst_pos:
                    temporal_type = -1
                else:
                    temporal_type = 0

                edge_type.append(
                    self.edge_type_to_idx[
                        str(temporal_type) + str(node_type[src]) + str(node_type[dst])
                    ]
                )

            sum_length += cur_len

        node_type = torch.tensor(node_type, dtype=torch.long).to(self.device)
        # Build the index from the Python pairs in one go; reshape(-1, 2) keeps
        # a [2, 0] result when there are no edges.
        edge_index = (
            torch.tensor(edge_pairs, dtype=torch.long)
            .reshape(-1, 2)
            .t()
            .contiguous()
            .to(self.device)
        )  # [2, E]
        edge_type = torch.tensor(edge_type, dtype=torch.long).to(self.device)  # [E]
        edge_index_lengths = torch.tensor(edge_index_lengths, dtype=torch.long).to(self.device)  # [B]

        return node_type, edge_index, edge_type, edge_index_lengths

    def edge_perms(self, length, total_lengths):
        """List the (source, target) node pairs for one sample of ``length`` steps.

        Node ``j`` of modality ``k`` has index ``j + k * total_lengths``.
        Temporal edges link a node to the same-modality nodes inside its window,
        otherwise only a self-loop is added. Cross-modal edges link the same
        step across modalities.

        NOTE(review): the future bound ``j + wf`` is exclusive, so the window
        covers positions ``j - wp .. j + wf - 1``; confirm this is intended.
        """
        all_perms = set()

        for j in range(length):
            # A window size of -1 means "unbounded" on that side.
            start = 0 if self.wp == -1 else max(0, j - self.wp)
            stop = length if self.wf == -1 else min(length, j + self.wf)

            for k in range(self.n_modals):
                base = k * total_lengths
                node_index = j + base

                if self.edge_temp:
                    all_perms.update((node_index, item + base) for item in range(start, stop))
                else:
                    all_perms.add((node_index, node_index))

                if self.edge_multi:
                    all_perms.update(
                        (node_index, j + other * total_lengths)
                        for other in range(self.n_modals)
                        if other != k
                    )

        return list(all_perms)
def feature_packing(multimodal_feature, lengths):
    """Flatten padded per-modality batches into one node-feature matrix.

    For every modality in ``multimodal_feature`` and every batch index ``j``,
    the first ``lengths[j]`` rows of ``feature[j]`` are kept. All pieces are
    concatenated along dimension 0, modality by modality.
    """
    batch_size = lengths.size(0)
    pieces = [
        feature[j, : lengths[j].item()]
        for feature in multimodal_feature
        for j in range(batch_size)
    ]
    return torch.cat(pieces, dim=0)
def multi_concat(nodes_feature, lengths, n_modals):
    """Place the per-modality row blocks of ``nodes_feature`` side by side.

    The rows are split into ``n_modals`` consecutive blocks of
    ``lengths.sum()`` rows each, and the blocks are concatenated along the
    last dimension.
    """
    rows_per_modal = lengths.sum().item()
    blocks = [
        nodes_feature[start : start + rows_per_modal]
        for start in range(0, n_modals * rows_per_modal, rows_per_modal)
    ] if rows_per_modal else [nodes_feature[0:0] for _ in range(n_modals)]
    return torch.cat(blocks, dim=-1)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Args:
        d_model: Size of the normalised (last) dimension.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x / sqrt(mean(x**2) + eps) * weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
class Mamba(nn.Module):
    """Stack of pre-norm residual ``MambaBlock`` layers with a classifier head.

    Args:
        num_layers: Number of (MambaBlock, RMSNorm) pairs.
        d_input: Feature size of the sequence and of the residual stream.
        d_model: Inner width of every ``MambaBlock``.
        d_state: State size of every ``MambaBlock``.
        d_discr: Bottleneck width forwarded to ``MambaBlock``.
        ker_size: Convolution kernel size forwarded to ``MambaBlock``.
        num_classes: Output size of the classifier.
        pooling: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, pooling=None):
        super().__init__()
        mamba_par = {
            'd_input': d_input,
            'd_model': d_model,
            'd_state': d_state,
            'd_discr': d_discr,
            'ker_size': ker_size,
        }
        self.layers = nn.ModuleList([
            nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)])
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_input, num_classes)
        # Kept for backward compatibility; ``forward`` follows the parameters'
        # actual device instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        """Classify a sequence by mean-pooling the residual stream over dim 1.

        Args:
            seq: Input sequence. It is passed through ``self.embedding`` first
                if such an attribute has been attached; this class does not
                define one itself.
            cache: Optional cache forwarded to (and updated by) every block.

        Returns:
            Tensor of class scores produced by ``fc_out``.
        """
        embedding = getattr(self, "embedding", None)
        if embedding is not None:
            seq = embedding(seq)

        # as_tensor keeps an existing tensor (and its autograd history) intact.
        # Use the parameters' device so a CPU model on a CUDA machine works.
        seq = torch.as_tensor(seq).to(self.fc_out.weight.device)

        for mamba, norm in self.layers:
            out, cache = mamba(norm(seq), cache)
            seq = out + seq

        return self.fc_out(seq.mean(dim=1))
class MambaBlock(nn.Module):
    """Selective state-space block: gated causal convolution plus an SSM scan.

    Args:
        d_input: Feature size of the input and output sequence.
        d_model: Inner width of the block.
        d_state: Size of the per-channel hidden state.
        d_discr: Bottleneck width of the step-size projection; defaults to
            ``d_model // 16``.
        ker_size: Kernel size of the depthwise convolution.
    """

    def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
        super().__init__()
        d_discr = d_discr if d_discr is not None else d_model // 16

        self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_input, bias=False)

        self.s_B = nn.Linear(d_model, d_state, bias=False)
        self.s_C = nn.Linear(d_model, d_state, bias=False)
        self.s_D = nn.Sequential(
            nn.Linear(d_model, d_discr, bias=False),
            nn.Linear(d_discr, d_model, bias=False),
        )

        # Depthwise convolution; padding of ker_size - 1 plus truncation in
        # ``forward`` keeps the first ``l`` outputs only.
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=ker_size,
            padding=ker_size - 1,
            groups=d_model,
            bias=True,
        )

        self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
        self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
        # Kept for backward compatibility; tensors created inside the block
        # follow the device of their inputs instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        """Apply the block to ``seq`` of shape ``(batch, length, d_input)``.

        Args:
            seq: Input sequence (unpacked as three dimensions).
            cache: Optional ``(prev_hid, prev_inp)`` tuple.

        Returns:
            ``(out, cache)`` where ``cache`` is refreshed only if one was given.
        """
        _, seq_len, _ = seq.shape
        prev_hid, prev_inp = cache if cache is not None else (None, None)

        # Split the projection into the SSM branch and the gating branch.
        ssm_branch, gate_branch = self.in_proj(seq).chunk(2, dim=-1)

        x = rearrange(ssm_branch, 'b l d -> b d l')
        if prev_inp is not None:
            x = torch.cat((prev_inp, x), dim=-1)

        ssm_branch = self.conv(x)[..., :seq_len]
        ssm_branch = rearrange(ssm_branch, 'b d l -> b l d')
        ssm_branch = silu(ssm_branch)

        ssm_branch, hid = self.ssm(ssm_branch, prev_hid=prev_hid)
        gate_branch = silu(gate_branch)

        out = self.out_proj(ssm_branch * gate_branch)

        if cache is not None:
            cache = (hid.squeeze(), x[..., 1:])

        return out, cache

    def ssm(self, seq, prev_hid):
        """Run the input-dependent state-space model over ``seq``.

        Returns ``(out, hid)``: the output sequence and the hidden states.
        """
        A = -self.A
        D = +self.D

        B = self.s_B(seq)
        C = self.s_C(seq)
        # Input-dependent step size, kept positive by softplus.
        s = softplus(D + self.s_D(seq))

        A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
        B_bar = einsum(B, s, 'b l s, b l d -> b l d s')
        X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')

        hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)

        out = einsum(hid, C, 'b l d s, b l s -> b l d')
        out = out + D * seq

        return out, hid

    def _hid_states(self, A, X, prev_hid=None):
        """Compute the hidden states ``h_t = A_t * h_{t-1} + X_t``.

        Args:
            A: Tensor of shape ``(b, l, d, s)``.
            X: Tensor of shape ``(b, l, d, s)``.
            prev_hid: Optional previous state; if given, it is applied to all
                steps at once instead of scanning.

        Returns:
            Tensor of shape ``(b, l, d, s)``.
        """
        b, _, d, s = A.shape

        if prev_hid is not None:
            # Broadcast against the time-major layout, then restore batch-major.
            return (A.transpose(0, 1) * prev_hid + X.transpose(0, 1)).transpose(0, 1)

        # Start from a zero state on the same device/dtype as the inputs, so
        # the module works wherever it has been moved.
        h = torch.zeros(b, d, s, device=A.device, dtype=A.dtype)
        states = []
        for A_t, X_t in zip(A.unbind(dim=1), X.unbind(dim=1)):
            h = A_t * h + X_t
            states.append(h)
        return torch.stack(states, dim=1)