# coding: utf-8
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange, einsum
from torch import Tensor
from torch.nn.functional import silu
from torch.nn.functional import softplus
from torch_geometric.nn import GATConv, RGCNConv, TransformerConv


def _complete_graph_edge_index(num_nodes, device=None):
    """Return the directed edge list of a complete graph without self-loops.

    The result is a ``[2, num_nodes * (num_nodes - 1)]`` long tensor; for
    ``num_nodes <= 1`` it is an empty ``[2, 0]`` long tensor.
    """
    pairs = [
        (src, dst)
        for src in range(num_nodes)
        for dst in range(num_nodes)
        if src != dst
    ]
    edges = torch.tensor(pairs, dtype=torch.long, device=device)
    # Explicit shape so that the empty case still yields a [2, 0] tensor.
    return edges.reshape(len(pairs), 2).t().contiguous()


def _batched_complete_graph_edge_index(num_graphs, nodes_per_graph, device=None):
    """Return edges of ``num_graphs`` disjoint complete graphs.

    Graph ``b`` owns the consecutive node ids
    ``[b * nodes_per_graph, (b + 1) * nodes_per_graph)``.
    """
    base = _complete_graph_edge_index(nodes_per_graph, device=device)  # [2, E]
    offsets = torch.arange(num_graphs, device=device) * nodes_per_graph  # [B]
    shifted = base.unsqueeze(1) + offsets.view(1, -1, 1)  # [2, B, E]
    return shifted.reshape(2, num_graphs * base.size(1))


class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network with GELU activation and dropout."""

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply ``layer_2(dropout(gelu(layer_1(x))))``."""
        x = self.layer_1(x)
        x = F.gelu(x)  # smoother activation than ReLU
        x = self.dropout(x)
        return self.layer_2(x)


class AddAndNorm(nn.Module):
    """Residual connection followed by layer normalisation."""

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        """Return ``LayerNorm(residual + Dropout(x))``.

        Args:
            x: output of the sub-layer (attention or feed-forward).
            residual: the skip-connection input of that sub-layer.
        """
        # Dropout regularises the sub-layer output; the skip path stays
        # intact (previously dropout was applied to the residual instead).
        return self.norm(residual + self.dropout(x))


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` encodings to ``x`` and apply dropout.

        ``x.size(1)`` is used as the sequence length, so ``x`` is expected to
        be laid out as (batch, seq, d_model).
        """
        # ``pe`` is a buffer, so no gradient flows into it.
        x = x + self.pe[: x.size(1)]
        return self.dropout(x)


class TransformerEncoderLayer(nn.Module):
    """Multi-head attention block followed by a position-wise feed-forward."""

    def __init__(self, input_dim, num_heads, dropout=0.1, positional_encoding=False):
        super().__init__()
        self.input_dim = input_dim
        self.self_attention = nn.MultiheadAttention(
            input_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = PositionWiseFeedForward(input_dim, input_dim, dropout=dropout)
        self.add_norm_after_attention = AddAndNorm(input_dim, dropout=dropout)
        self.add_norm_after_ff = AddAndNorm(input_dim, dropout=dropout)
        self.positional_encoding = (
            PositionalEncoding(input_dim) if positional_encoding else None
        )

    def forward(self, key, value, query):
        """Attend from ``query`` to ``key``/``value`` and refine the result.

        Note the argument order: ``key, value, query``. The query acts as the
        residual input of the attention sub-layer.
        """
        if self.positional_encoding is not None:
            key = self.positional_encoding(key)
            value = self.positional_encoding(value)
            query = self.positional_encoding(query)

        attn_output, _ = self.self_attention(query, key, value, need_weights=False)
        x = self.add_norm_after_attention(attn_output, query)

        ff_output = self.feed_forward(x)
        return self.add_norm_after_ff(ff_output, x)


class GAL(nn.Module):
    """Gated fusion of two feature matrices into a ``gated_dim`` space."""

    def __init__(self, input_dim_F1, input_dim_F2, gated_dim, dropout_rate):
        super(GAL, self).__init__()

        self.WF1 = nn.Parameter(torch.empty(input_dim_F1, gated_dim))
        self.WF2 = nn.Parameter(torch.empty(input_dim_F2, gated_dim))
        init.xavier_uniform_(self.WF1)
        init.xavier_uniform_(self.WF2)

        dim_size_f = input_dim_F1 + input_dim_F2
        self.WF = nn.Parameter(torch.empty(dim_size_f, gated_dim))
        init.xavier_uniform_(self.WF)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, f1, f2):
        """Return ``z * h_f1 + (1 - z) * h_f2``.

        ``f1`` and ``f2`` are concatenated along dim 1 to compute the gate
        ``z``, so the feature axis is expected to be dim 1.
        """
        h_f1 = self.dropout(torch.tanh(torch.matmul(f1, self.WF1)))
        h_f2 = self.dropout(torch.tanh(torch.matmul(f2, self.WF2)))
        gate_logits = torch.matmul(torch.cat([f1, f2], dim=1), self.WF)
        z_f = torch.softmax(self.dropout(gate_logits), dim=1)
        return z_f * h_f1 + (1 - z_f) * h_f2


class GraphFusionLayer(nn.Module):
    """Fuse audio and text vectors with two GAT layers over a 2-node graph."""

    def __init__(self, hidden_dim, dropout=0.0, heads=2, out_mean=True):
        super().__init__()
        self.out_mean = out_mean

        # Projection layers for the input features.
        self.proj_audio = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
        )
        self.proj_text = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
        )

        # Graph layers.
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim)

        # Final projection.
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
        )

    def build_complete_graph(self, num_nodes):
        """Return the ``[2, E]`` edge index of a complete graph (no self-loops)."""
        return _complete_graph_edge_index(num_nodes)

    def forward(self, audio_stats, text_stats):
        """
        audio_stats: [batch_size, hidden_dim]
        text_stats: [batch_size, hidden_dim]

        Returns ``[batch_size, hidden_dim]`` when ``out_mean`` is set,
        otherwise the per-modality node states ``[batch_size, 2, hidden_dim]``.
        """
        batch_size = audio_stats.size(0)

        # Project the features.
        x_audio = F.relu(self.proj_audio(audio_stats))  # [batch_size, hidden_dim]
        x_text = F.relu(self.proj_text(text_stats))  # [batch_size, hidden_dim]

        # Interleave nodes: audio and text of sample b are rows 2b and 2b + 1.
        nodes = torch.stack([x_audio, x_text], dim=1)  # [batch_size, 2, hidden_dim]
        nodes = nodes.reshape(-1, nodes.size(-1))  # [batch_size*2, hidden_dim]

        # One complete audio<->text graph per batch element. The edges must be
        # offset per sample, otherwise only the first pair exchanges messages.
        edge_index = _batched_complete_graph_edge_index(
            batch_size, 2, device=nodes.device
        )

        # Apply GAT.
        x = F.relu(self.gat1(nodes, edge_index))
        x = self.gat2(x, edge_index)

        # Split back into audio and text nodes.
        x = x.view(batch_size, 2, -1)  # [batch_size, 2, hidden_dim]

        if self.out_mean:
            # Average over the two modalities.
            fused = torch.mean(x, dim=1)  # [batch_size, hidden_dim]
            return self.fc(fused)
        return x


class GraphFusionLayerAtt(nn.Module):
    """Like ``GraphFusionLayer`` but fuses modalities with learned attention."""

    def __init__(self, hidden_dim, heads=2):
        super().__init__()
        # Projection layers for the input features.
        self.proj_audio = nn.Linear(hidden_dim, hidden_dim)
        self.proj_text = nn.Linear(hidden_dim, hidden_dim)

        # Graph layers.
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim)

        self.attention_fusion = nn.Linear(hidden_dim, 1)

        # Final projection.
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def build_complete_graph(self, num_nodes):
        """Return the ``[2, E]`` edge index of a complete graph (no self-loops)."""
        return _complete_graph_edge_index(num_nodes)

    def forward(self, audio_stats, text_stats):
        """
        audio_stats: [batch_size, hidden_dim]
        text_stats: [batch_size, hidden_dim]

        Returns the attention-weighted fusion, ``[batch_size, hidden_dim]``.
        """
        batch_size = audio_stats.size(0)

        # Project the features.
        x_audio = F.relu(self.proj_audio(audio_stats))  # [batch_size, hidden_dim]
        x_text = F.relu(self.proj_text(text_stats))  # [batch_size, hidden_dim]

        # Interleave nodes: audio and text of sample b are rows 2b and 2b + 1.
        nodes = torch.stack([x_audio, x_text], dim=1)  # [batch_size, 2, hidden_dim]
        nodes = nodes.reshape(-1, nodes.size(-1))  # [batch_size*2, hidden_dim]

        # One complete audio<->text graph per batch element. The edges must be
        # offset per sample, otherwise only the first pair exchanges messages.
        edge_index = _batched_complete_graph_edge_index(
            batch_size, 2, device=nodes.device
        )

        # Apply GAT.
        x = F.relu(self.gat1(nodes, edge_index))
        x = self.gat2(x, edge_index)

        # Split back into audio and text nodes.
        x = x.view(batch_size, 2, -1)  # [batch_size, 2, hidden_dim]

        # Attention-weighted sum over the two modalities (instead of a mean).
        weights = F.softmax(self.attention_fusion(x), dim=1)
        fused = torch.sum(weights * x, dim=1)  # [batch_size, hidden_dim]

        return self.fc(fused)


# Full code see https://github.com/leson502/CORECT_EMNLP2023/tree/master/corect/model
class GNN(nn.Module):
    """Relational GCN optionally followed by a graph-transformer layer."""

    def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals,
                 gcn_conv, use_graph_transformer, graph_transformer_nheads):
        super(GNN, self).__init__()
        self.gcn_conv = gcn_conv
        self.use_graph_transformer = use_graph_transformer

        self.num_modals = num_modals

        if self.gcn_conv == "rgcn":
            print("GNN --> Use RGCN")
            self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)

        if self.use_graph_transformer:
            print("GNN --> Use Graph Transformer")
            in_dim = h1_dim
            self.conv2 = TransformerConv(
                in_dim, h2_dim, heads=graph_transformer_nheads, concat=True
            )
            self.bn = nn.BatchNorm1d(h2_dim * graph_transformer_nheads)

    def forward(self, node_features, node_type, edge_index, edge_type):
        """Run the relational convolution and, if enabled, the transformer conv.

        ``node_type`` is accepted for interface compatibility but not used.

        Raises:
            ValueError: if ``gcn_conv`` is not ``"rgcn"`` (only RGCN is
                implemented in this file).
        """
        if self.gcn_conv != "rgcn":
            raise ValueError(
                f"Unsupported gcn_conv {self.gcn_conv!r}; only 'rgcn' is implemented"
            )

        x = self.conv1(node_features, edge_index, edge_type)

        if self.use_graph_transformer:
            x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))

        return x


class GraphModel(nn.Module):
    """Build a multimodal temporal graph from padded batches and run ``GNN``."""

    def __init__(self, g_dim, h1_dim, h2_dim, device, modalities, wp, wf,
                 edge_type, gcn_conv, use_graph_transformer, graph_transformer_nheads):
        super(GraphModel, self).__init__()

        self.n_modals = len(modalities)
        self.wp = wp
        self.wf = wf
        self.device = device
        self.gcn_conv = gcn_conv
        self.use_graph_transformer = use_graph_transformer

        print(f"GraphModel --> Edge type: {edge_type}")
        print(f"GraphModel --> Window past: {wp}")
        print(f"GraphModel --> Window future: {wf}")
        edge_temp = "temp" in edge_type
        edge_multi = "multi" in edge_type

        # Relation key: str(temporal direction) + source modality + target
        # modality, mapped to a consecutive relation id.
        edge_type_to_idx = {}

        if edge_temp:
            temporal = [-1, 1, 0]
            for j in temporal:
                for k in range(self.n_modals):
                    edge_type_to_idx[str(j) + str(k) + str(k)] = len(edge_type_to_idx)
        else:
            for j in range(self.n_modals):
                edge_type_to_idx['0' + str(j) + str(j)] = len(edge_type_to_idx)

        if edge_multi:
            for j in range(self.n_modals):
                for k in range(self.n_modals):
                    if j != k:
                        edge_type_to_idx['0' + str(j) + str(k)] = len(edge_type_to_idx)

        self.edge_type_to_idx = edge_type_to_idx
        self.num_relations = len(edge_type_to_idx)
        self.edge_multi = edge_multi
        self.edge_temp = edge_temp

        self.gnn = GNN(g_dim, h1_dim, h2_dim, self.num_relations, self.n_modals,
                       self.gcn_conv, self.use_graph_transformer,
                       graph_transformer_nheads)

    def forward(self, x, lengths):
        """Pack the features into nodes, run the GNN and re-join modalities.

        ``x`` is iterated per modality and ``lengths`` holds the valid length
        of every batch element (see ``feature_packing``).
        """
        node_features = feature_packing(x, lengths)

        node_type, edge_index, edge_type, edge_index_lengths = \
            self.batch_graphify(lengths)

        out_gnn = self.gnn(node_features, node_type, edge_index, edge_type)
        return multi_concat(out_gnn, lengths, self.n_modals)

    def batch_graphify(self, lengths):
        """Build node types, edge index and edge types for a whole batch.

        Nodes are ordered modality-major: modality ``k`` occupies the global
        ids ``[k * total, (k + 1) * total)`` where ``total = sum(lengths)``,
        and inside a modality the samples follow each other.

        Returns:
            ``(node_type [N], edge_index [2, E], edge_type [E],
            edge_index_lengths [B])`` as long tensors on ``self.device``.
        """
        lengths = lengths.tolist()
        total_length = sum(lengths)

        node_type = []
        for k in range(self.n_modals):
            node_type.extend([k] * total_length)

        edges, edge_type, edge_index_lengths = [], [], []
        sum_length = 0  # offset of the current sample inside one modality
        for cur_len in lengths:
            perms = self.edge_perms(cur_len, total_length)
            edge_index_lengths.append(len(perms))

            for vertices, neighbor in perms:
                src = vertices + sum_length
                dst = neighbor + sum_length
                edges.append((src, dst))

                # Position inside the sample, independent of the modality.
                src_pos = vertices % total_length
                dst_pos = neighbor % total_length
                if src_pos > dst_pos:
                    temporal_type = 1
                elif src_pos < dst_pos:
                    temporal_type = -1
                else:
                    temporal_type = 0
                edge_type.append(
                    self.edge_type_to_idx[
                        str(temporal_type) + str(node_type[src]) + str(node_type[dst])
                    ]
                )

            sum_length += cur_len

        node_type = torch.tensor(node_type, dtype=torch.long).to(self.device)
        # Build the edge tensor in one go; the explicit shape keeps the
        # edge-free case valid ([2, 0]).
        edge_index = (
            torch.tensor(edges, dtype=torch.long)
            .reshape(len(edges), 2)
            .t()
            .contiguous()
            .to(self.device)
        )  # [2, E]
        edge_type = torch.tensor(edge_type, dtype=torch.long).to(self.device)  # [E]
        edge_index_lengths = torch.tensor(
            edge_index_lengths, dtype=torch.long
        ).to(self.device)  # [B]

        return node_type, edge_index, edge_type, edge_index_lengths

    def edge_perms(self, length, total_lengths):
        """Return the sample-local ``(source, target)`` edges of one sequence.

        Node ``j`` of modality ``k`` has the local id ``j + k * total_lengths``.
        ``wp``/``wf`` bound the temporal window (``-1`` means unbounded on that
        side). The result is sorted, so the edge order is deterministic.
        """
        all_perms = set()
        for j in range(length):
            start = 0 if self.wp == -1 else max(0, j - self.wp)
            stop = length if self.wf == -1 else min(length, j + self.wf)

            for k in range(self.n_modals):
                node_index = j + k * total_lengths

                if self.edge_temp:
                    # Temporal edges inside the same modality.
                    for item in range(start, stop):
                        all_perms.add((node_index, item + k * total_lengths))
                else:
                    all_perms.add((node_index, node_index))

                if self.edge_multi:
                    # Cross-modal edges at the same time step.
                    for other in range(self.n_modals):
                        if other != k:
                            all_perms.add((node_index, j + other * total_lengths))

        return sorted(all_perms)


def feature_packing(multimodal_feature, lengths):
    """Concatenate the valid (unpadded) rows of every modality.

    For each modality tensor in ``multimodal_feature`` and each batch index
    ``j`` the slice ``feature[j, :lengths[j]]`` is taken; all slices are
    concatenated along dim 0, modality-major.
    """
    batch_size = lengths.size(0)
    valid_lengths = lengths.tolist()  # single host transfer instead of one per item
    node_features = [
        feature[j, :valid_lengths[j]]
        for feature in multimodal_feature
        for j in range(batch_size)
    ]
    return torch.cat(node_features, dim=0)


def multi_concat(nodes_feature, lengths, n_modals):
    """Join the per-modality node blocks along the feature axis.

    ``nodes_feature`` is split into ``n_modals`` consecutive row blocks of
    ``lengths.sum()`` rows each, which are concatenated on the last dim.
    """
    sum_length = lengths.sum().item()
    feature = [
        nodes_feature[j * sum_length: (j + 1) * sum_length]
        for j in range(n_modals)
    ]
    return torch.cat(feature, dim=-1)


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned scale."""

    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class Mamba(nn.Module):
    """Stack of pre-norm residual ``MambaBlock`` layers with a linear head."""

    def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None,
                 ker_size=4, num_classes=7, pooling=None):
        super().__init__()

        mamba_par = {
            'd_input': d_input,
            'd_model': d_model,
            'd_state': d_state,
            'd_discr': d_discr,
            'ker_size': ker_size,
        }

        self.layers = nn.ModuleList([
            nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)])
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_input, num_classes)
        # Kept for backward compatibility; forward() follows the module's
        # actual parameter device instead of this attribute.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        """Classify a feature sequence.

        Args:
            seq: feature sequence with ``d_input`` features in the last dim;
                it is mean-pooled over dim 1 before the classification head.
            cache: optional cache passed through the blocks.

        Returns:
            Logits produced by ``fc_out``.
        """
        # The model has no embedding layer (the previous call to
        # ``self.embedding`` raised AttributeError); inputs are already
        # features. Move them to wherever the parameters live.
        seq = torch.as_tensor(seq).to(self.fc_out.weight.device)

        for mamba, norm in self.layers:
            out, cache = mamba(norm(seq), cache)
            seq = out + seq

        return self.fc_out(seq.mean(dim=1))


class MambaBlock(nn.Module):
    """Gated selective state-space block (conv branch * SiLU gate branch)."""

    def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
        super().__init__()
        d_discr = d_discr if d_discr is not None else d_model // 16

        self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_input, bias=False)
        self.s_B = nn.Linear(d_model, d_state, bias=False)
        self.s_C = nn.Linear(d_model, d_state, bias=False)
        self.s_D = nn.Sequential(
            nn.Linear(d_model, d_discr, bias=False),
            nn.Linear(d_discr, d_model, bias=False),
        )
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=ker_size,
            padding=ker_size - 1,
            groups=d_model,
            bias=True,
        )
        self.A = nn.Parameter(
            torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1)
        )
        self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
        # Kept for backward compatibility; tensors are created on the input's
        # device rather than on this attribute.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        """Process ``seq`` of shape (batch, length, d_input).

        Args:
            seq: input sequence.
            cache: ``None`` or a pair ``(prev_hid, prev_inp)``; pass
                ``(None, None)`` to start caching.

        Returns:
            ``(out, cache)`` where ``out`` has the shape of ``seq`` and
            ``cache`` is ``None`` when no cache was supplied.
        """
        _, seq_len, _ = seq.shape
        (prev_hid, prev_inp) = cache if cache is not None else (None, None)

        conv_branch, gate_branch = self.in_proj(seq).chunk(2, dim=-1)

        # Convolution branch; the padded conv output is cut to seq_len steps.
        x = rearrange(conv_branch, 'b l d -> b d l')
        x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
        conv_out = self.conv(x)[..., :seq_len]
        conv_out = silu(rearrange(conv_out, 'b d l -> b l d'))
        ssm_out, hid = self.ssm(conv_out, prev_hid=prev_hid)

        # Gating branch.
        out = self.out_proj(ssm_out * silu(gate_branch))

        if cache is not None:
            # Keep only the most recent hidden state (b, d, s). A blanket
            # squeeze() would also drop unrelated singleton dims and keeps
            # the time axis whenever seq_len > 1.
            cache = (hid[:, -1], x[..., 1:])

        return out, cache

    def ssm(self, seq, prev_hid):
        """Run the selective state-space scan.

        Returns ``(out, hid)`` with ``out`` shaped like ``seq`` (b, l, d) and
        ``hid`` the hidden states (b, l, d, s).
        """
        A = -self.A
        D = self.D

        B = self.s_B(seq)
        C = self.s_C(seq)
        s = softplus(D + self.s_D(seq))

        A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
        B_bar = einsum(B, s, 'b l s, b l d -> b l d s')
        X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')

        hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
        out = einsum(hid, C, 'b l d s, b l s -> b l d')
        out = out + D * seq

        return out, hid

    def _hid_states(self, A, X, prev_hid=None):
        """Compute ``h_t = A_t * h_{t-1} + X_t`` for every time step.

        ``A`` and ``X`` are (b, l, d, s). The recurrence starts from
        ``prev_hid`` when given, otherwise from zeros. Returns (b, l, d, s).
        """
        b, l, d, s = A.shape

        A = rearrange(A, 'b l d s -> l b d s')
        X = rearrange(X, 'b l d s -> l b d s')

        if prev_hid is not None:
            h = prev_hid
        else:
            # Allocate on the input's device, not on "cuda if available".
            h = torch.zeros(b, d, s, device=A.device)

        states = []
        for A_t, X_t in zip(A, X):
            h = A_t * h + X_t
            states.append(h)
        return torch.stack(states, dim=1)