# Source: BiBiER / models/help_layers.py (farbverlauf, commit 960b1a0 "gpu")
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
import math
from torch.nn.functional import silu
from torch.nn.functional import softplus
from einops import rearrange, einsum
from torch import Tensor
from torch_geometric.nn import GATConv, RGCNConv, TransformerConv
class PositionWiseFeedForward(nn.Module):
    """Position-wise MLP applied independently to every position of the input.

    Computes ``layer_2(dropout(gelu(layer_1(x))))``; the output has the same last
    dimension as the input.

    Args:
        input_dim: feature size of the input and of the output.
        hidden_dim: width of the intermediate layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)  # expand
        self.layer_2 = nn.Linear(hidden_dim, input_dim)  # project back
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # GELU is used as a smoother alternative to ReLU.
        hidden = self.dropout(F.gelu(self.layer_1(x)))
        return self.layer_2(hidden)
class AddAndNorm(nn.Module):
    """Residual connection followed by LayerNorm: ``norm(residual + dropout(x))``.

    Args:
        input_dim: size of the normalised (last) feature dimension.
        dropout: dropout probability applied to the sublayer output ``x``.
    """

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        """Add the sublayer output to its input and normalise.

        Args:
            x: sublayer output (in this file: the attention or feed-forward result).
            residual: the sublayer's input, carried by the skip connection.
        """
        # Dropout goes on the sublayer output, not on the identity path
        # (post-norm Transformer: LayerNorm(input + Dropout(Sublayer(input)))).
        return self.norm(residual + self.dropout(x))
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    Even feature columns hold ``sin(pos * w_i)``, odd columns ``cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``. Odd ``d_model`` is supported.

    Args:
        d_model: feature size of the input.
        dropout: dropout probability applied after adding the encodings.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        # One frequency per sine column: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """x: [batch, seq_len, d_model] (indexed by ``x.size(1)``); seq_len <= max_len."""
        x = x + self.pe[: x.size(1)].detach()  # the encodings are constants, no gradient
        return self.dropout(x)
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer block with separately supplied query, key and value.

    Pipeline: batch-first multi-head attention -> add & norm (the query is the skip
    input) -> position-wise feed-forward -> add & norm.

    Args:
        input_dim: model width of query/key/value.
        num_heads: number of attention heads.
        dropout: dropout used by the attention, feed-forward and add & norm sublayers.
        positional_encoding: if True, sinusoidal encodings are added to key, value and
            query before attention.
    """

    def __init__(self, input_dim, num_heads, dropout=0.1, positional_encoding=False):
        super().__init__()
        self.input_dim = input_dim
        self.self_attention = nn.MultiheadAttention(
            input_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.feed_forward = PositionWiseFeedForward(input_dim, input_dim, dropout=dropout)
        self.add_norm_after_attention = AddAndNorm(input_dim, dropout=dropout)
        self.add_norm_after_ff = AddAndNorm(input_dim, dropout=dropout)
        if positional_encoding:
            self.positional_encoding = PositionalEncoding(input_dim)
        else:
            self.positional_encoding = None

    def forward(self, key, value, query):
        """Attend from ``query`` over ``key``/``value`` and return the block output."""
        encode = self.positional_encoding
        if encode is not None:
            # Applied in this order: key, value, query.
            key = encode(key)
            value = encode(value)
            query = encode(query)
        attended, _ = self.self_attention(query, key, value, need_weights=False)
        hidden = self.add_norm_after_attention(attended, query)
        return self.add_norm_after_ff(self.feed_forward(hidden), hidden)
class GAL(nn.Module):
    """Gated fusion of two feature tensors.

    Each input is projected to ``gated_dim`` with a tanh. A gate ``z``, computed from
    the concatenated raw inputs, mixes them as ``z * h1 + (1 - z) * h2``.

    Args:
        input_dim_F1: feature size of ``f1``.
        input_dim_F2: feature size of ``f2``.
        gated_dim: output feature size.
        dropout_rate: dropout applied to both projections and to the gate logits.
    """

    def __init__(self, input_dim_F1, input_dim_F2, gated_dim, dropout_rate):
        super(GAL, self).__init__()
        self.WF1 = nn.Parameter(torch.empty(input_dim_F1, gated_dim))
        self.WF2 = nn.Parameter(torch.empty(input_dim_F2, gated_dim))
        init.xavier_uniform_(self.WF1)
        init.xavier_uniform_(self.WF2)
        self.WF = nn.Parameter(torch.empty(input_dim_F1 + input_dim_F2, gated_dim))
        init.xavier_uniform_(self.WF)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, f1, f2):
        """Fuse ``f1`` [..., input_dim_F1] and ``f2`` [..., input_dim_F2].

        Leading dimensions must match; 2-D [batch, dim] and 3-D [batch, seq, dim]
        inputs are both supported. Returns [..., gated_dim].
        """
        h_f1 = self.dropout(torch.tanh(torch.matmul(f1, self.WF1)))
        h_f2 = self.dropout(torch.tanh(torch.matmul(f2, self.WF2)))
        # Concatenate and normalise over the feature axis (the last one), so the same
        # code serves inputs with or without a sequence dimension.
        gate_logits = torch.matmul(torch.cat([f1, f2], dim=-1), self.WF)
        # NOTE(review): gated multimodal units usually use a sigmoid gate; softmax here
        # makes the gate values sum to 1 across gated_dim. Kept to preserve trained
        # behaviour — confirm it is intended.
        z_f = torch.softmax(self.dropout(gate_logits), dim=-1)
        return z_f * h_f1 + (1 - z_f) * h_f2
class GraphFusionLayer(nn.Module):
    """Fuse one audio vector and one text vector per sample with two GAT layers.

    Every sample contributes two nodes (audio, text) connected in both directions;
    nodes of different samples are not connected.

    Args:
        hidden_dim: size of the audio/text vectors and of the fused output.
        dropout: dropout probability inside the projection blocks.
        heads: attention heads of the first GAT layer (their outputs are concatenated).
        out_mean: if True, average the two node embeddings and apply ``fc``;
            otherwise return the per-node GAT outputs.
    """

    def __init__(self, hidden_dim, dropout=0.0, heads=2, out_mean=True):
        super().__init__()
        self.out_mean = out_mean
        # Per-modality projection blocks.
        self.proj_audio = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
        )
        self.proj_text = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
        )
        # Graph layers: multi-head GAT, then a single-head GAT back to hidden_dim.
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim)
        # Final projection (used only when out_mean is True).
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
        )

    def build_complete_graph(self, num_nodes, batch_size=1):
        """Return the edge index of ``batch_size`` disjoint complete directed graphs.

        Copy ``b`` uses node ids ``b * num_nodes .. (b + 1) * num_nodes - 1`` and has an
        edge ``i -> j`` for every ordered pair ``i != j`` (no self-loops).

        Returns:
            LongTensor of shape [2, batch_size * num_nodes * (num_nodes - 1)].
        """
        pairs = [[i, j] for i in range(num_nodes) for j in range(num_nodes) if i != j]
        # reshape keeps a [2, 0] shape when there are no edges (num_nodes < 2).
        base = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()  # [2, E]
        offsets = torch.arange(batch_size, dtype=torch.long) * num_nodes  # [B]
        # [2, 1, E] + [1, B, 1] -> [2, B, E]: shift each copy to its own node ids.
        return (base.unsqueeze(1) + offsets.view(1, -1, 1)).reshape(2, -1).contiguous()

    def forward(self, audio_stats, text_stats):
        """Fuse per-sample audio and text vectors.

        Args:
            audio_stats: [batch_size, hidden_dim]
            text_stats: [batch_size, hidden_dim]
        Returns:
            [batch_size, hidden_dim] if ``out_mean`` else [batch_size, 2, hidden_dim].
        """
        batch_size = audio_stats.size(0)
        x_audio = F.relu(self.proj_audio(audio_stats))  # [batch_size, hidden_dim]
        x_text = F.relu(self.proj_text(text_stats))  # [batch_size, hidden_dim]
        # Interleave modalities: node 2*b is the audio of sample b, node 2*b + 1 its text.
        nodes = torch.stack([x_audio, x_text], dim=1)  # [batch_size, 2, hidden_dim]
        nodes = nodes.view(-1, nodes.size(-1))  # [batch_size * 2, hidden_dim]
        # One audio<->text pair of edges per sample. A single 2-node graph would only
        # connect sample 0, leaving every other sample without cross-modal messages.
        edge_index = self.build_complete_graph(2, batch_size).to(audio_stats.device)
        x = F.relu(self.gat1(nodes, edge_index))
        x = self.gat2(x, edge_index)
        x = x.view(batch_size, 2, -1)  # [batch_size, 2, hidden_dim]
        if self.out_mean:
            # Average over the two modalities.
            fused = torch.mean(x, dim=1)  # [batch_size, hidden_dim]
            return self.fc(fused)
        return x
class GraphFusionLayerAtt(nn.Module):
    """Audio/text fusion with two GAT layers and learned attention pooling over modalities.

    Every sample contributes two nodes (audio, text) connected in both directions;
    nodes of different samples are not connected.

    Args:
        hidden_dim: size of the audio/text vectors and of the fused output.
        heads: attention heads of the first GAT layer (their outputs are concatenated).
    """

    def __init__(self, hidden_dim, heads=2):
        super().__init__()
        # Per-modality projections.
        self.proj_audio = nn.Linear(hidden_dim, hidden_dim)
        self.proj_text = nn.Linear(hidden_dim, hidden_dim)
        # Graph layers.
        self.gat1 = GATConv(hidden_dim, hidden_dim, heads=heads)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim)
        # Scores each modality node; a softmax over the two nodes gives pooling weights.
        self.attention_fusion = nn.Linear(hidden_dim, 1)
        # Final projection.
        self.fc = nn.Linear(hidden_dim, hidden_dim)

    def build_complete_graph(self, num_nodes, batch_size=1):
        """Return the edge index of ``batch_size`` disjoint complete directed graphs.

        Copy ``b`` uses node ids ``b * num_nodes .. (b + 1) * num_nodes - 1`` and has an
        edge ``i -> j`` for every ordered pair ``i != j`` (no self-loops).

        Returns:
            LongTensor of shape [2, batch_size * num_nodes * (num_nodes - 1)].
        """
        pairs = [[i, j] for i in range(num_nodes) for j in range(num_nodes) if i != j]
        # reshape keeps a [2, 0] shape when there are no edges (num_nodes < 2).
        base = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()  # [2, E]
        offsets = torch.arange(batch_size, dtype=torch.long) * num_nodes  # [B]
        # [2, 1, E] + [1, B, 1] -> [2, B, E]: shift each copy to its own node ids.
        return (base.unsqueeze(1) + offsets.view(1, -1, 1)).reshape(2, -1).contiguous()

    def forward(self, audio_stats, text_stats):
        """Fuse per-sample audio and text vectors.

        Args:
            audio_stats: [batch_size, hidden_dim]
            text_stats: [batch_size, hidden_dim]
        Returns:
            [batch_size, hidden_dim]
        """
        batch_size = audio_stats.size(0)
        x_audio = F.relu(self.proj_audio(audio_stats))  # [batch_size, hidden_dim]
        x_text = F.relu(self.proj_text(text_stats))  # [batch_size, hidden_dim]
        # Interleave modalities: node 2*b is the audio of sample b, node 2*b + 1 its text.
        nodes = torch.stack([x_audio, x_text], dim=1)  # [batch_size, 2, hidden_dim]
        nodes = nodes.view(-1, nodes.size(-1))  # [batch_size * 2, hidden_dim]
        # One audio<->text pair of edges per sample. A single 2-node graph would only
        # connect sample 0, leaving every other sample without cross-modal messages.
        edge_index = self.build_complete_graph(2, batch_size).to(audio_stats.device)
        x = F.relu(self.gat1(nodes, edge_index))
        x = self.gat2(x, edge_index)
        x = x.view(batch_size, 2, -1)  # [batch_size, 2, hidden_dim]
        # Attention-weighted sum over the two modality nodes.
        weights = F.softmax(self.attention_fusion(x), dim=1)  # [batch_size, 2, 1]
        fused = torch.sum(weights * x, dim=1)  # [batch_size, hidden_dim]
        return self.fc(fused)
# For the full original code, see https://github.com/leson502/CORECT_EMNLP2023/tree/master/corect/model
class GNN(nn.Module):
    """RGCN layer optionally followed by a multi-head graph transformer layer.

    Args:
        g_dim: input node feature size.
        h1_dim: output size of the RGCN layer; also the input size of the transformer.
        h2_dim: per-head output size of the transformer layer.
        num_relations: number of relation (edge) types for ``RGCNConv``.
        num_modals: number of modalities (stored only).
        gcn_conv: ``"rgcn"`` adds the ``RGCNConv`` layer; any other value skips it.
        use_graph_transformer: add ``TransformerConv`` + ``BatchNorm1d`` + LeakyReLU.
        graph_transformer_nheads: heads of ``TransformerConv`` (outputs concatenated).
    """

    def __init__(self, g_dim, h1_dim, h2_dim, num_relations, num_modals, gcn_conv, use_graph_transformer, graph_transformer_nheads):
        super(GNN, self).__init__()
        self.gcn_conv = gcn_conv
        self.use_graph_transformer = use_graph_transformer
        self.num_modals = num_modals
        if self.gcn_conv == "rgcn":
            print("GNN --> Use RGCN")
            self.conv1 = RGCNConv(g_dim, h1_dim, num_relations)
        if self.use_graph_transformer:
            print("GNN --> Use Graph Transformer")
            in_dim = h1_dim
            self.conv2 = TransformerConv(in_dim, h2_dim, heads=graph_transformer_nheads, concat=True)
            self.bn = nn.BatchNorm1d(h2_dim * graph_transformer_nheads)

    def forward(self, node_features, node_type, edge_index, edge_type):
        """Encode the graph.

        Args:
            node_features: node feature matrix.
            node_type: unused; kept for interface compatibility.
            edge_index: connectivity passed to both graph layers.
            edge_type: relation id per edge, consumed by the RGCN layer.
        Returns:
            Node embeddings; ``node_features`` unchanged when both layers are disabled.
        """
        # Start from the raw features so that a configuration without RGCN (conv1 is
        # optional in __init__) does not hit an unbound local. In that case the
        # transformer layer expects g_dim == h1_dim.
        x = node_features
        if self.gcn_conv == "rgcn":
            x = self.conv1(x, edge_index, edge_type)
        if self.use_graph_transformer:
            x = nn.functional.leaky_relu(self.bn(self.conv2(x, edge_index)))
        return x
class GraphModel(nn.Module):
    """Turn a padded multimodal batch into one graph and encode it with ``GNN``.

    Node layout is modality-major: the valid time steps of every sample for modality 0,
    then the same for modality 1, and so on (the order produced by ``feature_packing``).

    Edges stay inside one sample:
    * If ``edge_type`` contains "temp", time steps of the same modality are linked
      inside a window of ``wp`` past / ``wf`` future steps (-1 = unbounded).
      Otherwise each node only gets a self-loop.
    * If ``edge_type`` contains "multi", the same time step is linked across
      modalities.

    Relation ids come from ``edge_type_to_idx``, keyed by
    "<temporal sign><source modality><target modality>".
    """

    def __init__(self, g_dim, h1_dim, h2_dim, device, modalities, wp, wf, edge_type, gcn_conv, use_graph_transformer, graph_transformer_nheads):
        super(GraphModel, self).__init__()
        self.n_modals = len(modalities)
        self.wp = wp
        self.wf = wf
        self.device = device
        self.gcn_conv = gcn_conv
        self.use_graph_transformer = use_graph_transformer
        print(f"GraphModel --> Edge type: {edge_type}")
        print(f"GraphModel --> Window past: {wp}")
        print(f"GraphModel --> Window future: {wf}")
        edge_temp = "temp" in edge_type
        edge_multi = "multi" in edge_type
        edge_type_to_idx = {}
        if edge_temp:
            # Intra-modality relations for target-after (-1), target-before (1) and same step (0).
            temporal = [-1, 1, 0]
            for j in temporal:
                for k in range(self.n_modals):
                    edge_type_to_idx[str(j) + str(k) + str(k)] = len(edge_type_to_idx)
        else:
            for j in range(self.n_modals):
                edge_type_to_idx['0' + str(j) + str(j)] = len(edge_type_to_idx)
        if edge_multi:
            # Cross-modality relations at the same time step.
            for j in range(self.n_modals):
                for k in range(self.n_modals):
                    if j != k:
                        edge_type_to_idx['0' + str(j) + str(k)] = len(edge_type_to_idx)
        self.edge_type_to_idx = edge_type_to_idx
        self.num_relations = len(edge_type_to_idx)
        self.edge_multi = edge_multi
        self.edge_temp = edge_temp
        self.gnn = GNN(g_dim, h1_dim, h2_dim, self.num_relations, self.n_modals, self.gcn_conv, self.use_graph_transformer, graph_transformer_nheads)

    def forward(self, x, lengths):
        """Encode a batch.

        Args:
            x: per-modality padded features, iterated by ``feature_packing`` (presumably
                [n_modals, batch, max_len, g_dim] — confirm against callers).
            lengths: 1-D tensor with the number of valid steps per sample.
        Returns:
            The GNN node outputs with the modalities concatenated along the last dim.
        """
        node_features = feature_packing(x, lengths)
        node_type, edge_index, edge_type, _ = self.batch_graphify(lengths)
        out_gnn = self.gnn(node_features, node_type, edge_index, edge_type)
        return multi_concat(out_gnn, lengths, self.n_modals)

    def batch_graphify(self, lengths):
        """Build the batch graph.

        Args:
            lengths: 1-D tensor with the number of valid steps per sample.
        Returns:
            node_type: [N] modality id of every node.
            edge_index: [2, E] global (source, target) node ids.
            edge_type: [E] relation ids from ``edge_type_to_idx``.
            edge_index_lengths: [B] number of edges contributed by each sample.
        """
        lengths = lengths.tolist()
        total_length = sum(lengths)
        # Modality-major node types, matching the packing order of the features.
        node_type = [k for k in range(self.n_modals) for cur_len in lengths for _ in range(cur_len)]
        src, dst, edge_type, edge_index_lengths = [], [], [], []
        sum_length = 0  # offset of the current sample inside every modality block
        for cur_len in lengths:
            perms = self.edge_perms(cur_len, total_length)
            edge_index_lengths.append(len(perms))
            for vertex, neighbor in perms:
                u = vertex + sum_length
                v = neighbor + sum_length
                # Compare within-sample time steps (local ids modulo the block size).
                t_u = vertex % total_length
                t_v = neighbor % total_length
                if t_u > t_v:
                    temporal_type = 1
                elif t_u < t_v:
                    temporal_type = -1
                else:
                    temporal_type = 0
                src.append(u)
                dst.append(v)
                edge_type.append(self.edge_type_to_idx[str(temporal_type) + str(node_type[u]) + str(node_type[v])])
            sum_length += cur_len
        # Build each tensor once (instead of one small tensor per edge); an empty
        # edge list still yields a well-formed [2, 0] edge_index.
        node_type = torch.tensor(node_type, dtype=torch.long, device=self.device)
        edge_index = torch.tensor([src, dst], dtype=torch.long, device=self.device)  # [2, E]
        edge_type = torch.tensor(edge_type, dtype=torch.long, device=self.device)  # [E]
        edge_index_lengths = torch.tensor(edge_index_lengths, dtype=torch.long, device=self.device)  # [B]
        return node_type, edge_index, edge_type, edge_index_lengths

    def edge_perms(self, length, total_lengths):
        """Return the (source, target) pairs of one sample, in sample-local node ids.

        Node ``j + k * total_lengths`` is time step ``j`` of modality ``k``.
        """
        all_perms = set()
        array = np.arange(length)
        for j in range(length):
            # Window bounds; -1 means unbounded on that side.
            # NOTE(review): ``stop`` is exclusive, so the future window reaches only
            # j + wf - 1 while the past window reaches j - wp; confirm whether
            # j + wf + 1 was intended.
            start = 0 if self.wp == -1 else max(0, j - self.wp)
            stop = length if self.wf == -1 else min(length, j + self.wf)
            eff_array = array[start:stop]
            for k in range(self.n_modals):
                node_index = j + k * total_lengths
                if self.edge_temp:
                    all_perms.update((node_index, int(item) + k * total_lengths) for item in eff_array)
                else:
                    all_perms.add((node_index, node_index))
                if self.edge_multi:
                    all_perms.update(
                        (node_index, j + other * total_lengths)
                        for other in range(self.n_modals)
                        if other != k
                    )
        return list(all_perms)
def feature_packing(multimodal_feature, lengths):
    """Flatten padded per-modality features into a single node matrix.

    Rows are ordered modality-major: for each modality in ``multimodal_feature``, the
    first ``lengths[j]`` rows of sample ``j`` for every sample in order.

    Args:
        multimodal_feature: iterable of per-modality tensors indexed as ``feature[j, :n]``.
        lengths: 1-D tensor of valid lengths, one per sample.
    Returns:
        The valid rows of all modalities concatenated along dim 0.
    """
    valid = [lengths[j].item() for j in range(lengths.size(0))]
    chunks = [
        feature[j, :n]
        for feature in multimodal_feature
        for j, n in enumerate(valid)
    ]
    return torch.cat(chunks, dim=0)
def multi_concat(nodes_feature, lengths, n_modals):
    """Undo modality-major node packing.

    Splits ``nodes_feature`` into ``n_modals`` consecutive row blocks of
    ``lengths.sum()`` rows each and concatenates them along the last dimension.
    """
    block = lengths.sum().item()
    return torch.cat(
        [nodes_feature[m * block:(m + 1) * block] for m in range(n_modals)],
        dim=-1,
    )
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    ``y = x / sqrt(mean(x**2) + eps) * weight``; no mean-centering and no bias.
    """

    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
class Mamba(nn.Module):
    """Stack of pre-norm residual MambaBlocks, mean pooling over time and a linear head.

    Args:
        num_layers: number of (MambaBlock, RMSNorm) pairs.
        d_input: feature size of the input sequence (no embedding layer is applied).
        d_model: inner width of each MambaBlock.
        d_state: SSM state size of each MambaBlock.
        d_discr: step-size bottleneck width of each MambaBlock.
        ker_size: causal convolution kernel size of each MambaBlock.
        num_classes: output size of the classifier.
        pooling: accepted for interface compatibility; mean pooling is always used.
    """

    def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, pooling=None):
        super().__init__()
        mamba_par = {
            'd_input': d_input,
            'd_model': d_model,
            'd_state': d_state,
            'd_discr': d_discr,
            'ker_size': ker_size,
        }
        self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_input, num_classes)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        """Classify a batch of sequences.

        Args:
            seq: [batch, length, d_input] tensor or array-like of features.
            cache: forwarded to the blocks. NOTE(review): the cache returned by one
                layer is fed to the next and none is returned to the caller, so only
                ``cache=None`` is meaningful here.
        Returns:
            Logits of shape [batch, num_classes].
        """
        # No embedding layer exists: inputs are already d_input-dimensional features.
        # Place them on the parameters' device/dtype rather than assuming CUDA.
        weight = self.fc_out.weight
        seq = torch.as_tensor(seq, dtype=weight.dtype, device=weight.device)
        for mamba, norm in self.layers:
            out, cache = mamba(norm(seq), cache)
            seq = out + seq
        return self.fc_out(seq.mean(dim=1))
class MambaBlock(nn.Module):
    """Mamba-style selective state-space block in plain PyTorch.

    ``in_proj`` splits the input into an SSM branch and a gate. The SSM branch runs a
    causal depthwise Conv1d, SiLU and a selective scan. The two branches are multiplied
    (gate through SiLU) and projected back to ``d_input``.

    Args:
        d_input: feature size of the input and output.
        d_model: inner width of the block.
        d_state: SSM state size per inner channel.
        d_discr: bottleneck width of the step-size projection ``s_D``
            (defaults to ``d_model // 16``).
        ker_size: kernel size of the causal depthwise convolution.
    """

    def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
        super().__init__()
        d_discr = d_discr if d_discr is not None else d_model // 16
        self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_input, bias=False)
        self.s_B = nn.Linear(d_model, d_state, bias=False)
        self.s_C = nn.Linear(d_model, d_state, bias=False)
        self.s_D = nn.Sequential(
            nn.Linear(d_model, d_discr, bias=False),
            nn.Linear(d_discr, d_model, bias=False),
        )
        # Padded on both sides; forward keeps only the causal part of the output.
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=ker_size,
            padding=ker_size - 1,
            groups=d_model,
            bias=True,
        )
        self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
        self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
        # Kept for compatibility; computations use the device of their inputs.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        """Run the block.

        Args:
            seq: [batch, length, d_input].
            cache: None for a plain full-sequence pass, or ``(prev_hid, prev_inp)`` for
                incremental processing; ``(None, None)`` starts a new incremental run.
        Returns:
            ``(out, cache)``: out is [batch, length, d_input]. cache is None when None
            was passed; otherwise it holds the last hidden state [batch, d_model, d_state]
            and the last ``ker_size - 1`` pre-convolution inputs [batch, d_model, <= ker_size - 1].
        """
        _, length, _ = seq.shape
        prev_hid, prev_inp = cache if cache is not None else (None, None)

        branch, gate = self.in_proj(seq).chunk(2, dim=-1)
        x = branch.transpose(1, 2)  # [batch, d_model, length]
        if prev_inp is not None:
            x = torch.cat((prev_inp, x), dim=-1)
        # Causal outputs for every position of x, then keep those of the new steps only
        # (with a cache, x starts with older inputs that were already processed).
        a = self.conv(x)[..., : x.size(-1)][..., -length:]
        a = silu(a.transpose(1, 2))  # [batch, length, d_model]
        a, hid = self.ssm(a, prev_hid=prev_hid)
        out = self.out_proj(a * silu(gate))

        if cache is not None:
            keep = self.conv.kernel_size[0] - 1
            # Last hidden state (keeps the batch dim) and the conv receptive-field history.
            cache = (hid[:, -1], x[..., -keep:] if keep > 0 else x[..., :0])
        return out, cache

    def ssm(self, seq, prev_hid):
        """Selective scan over ``seq`` [batch, length, d_model].

        Returns ``(out, hid)`` with out [batch, length, d_model] and the hidden states
        [batch, length, d_model, d_state].
        """
        A = -self.A  # [d_model, d_state]
        D = +self.D  # [d_model]
        B = self.s_B(seq)  # [batch, length, d_state]
        C = self.s_C(seq)  # [batch, length, d_state]
        # Input-dependent step size; D doubles as its bias.
        s = softplus(D + self.s_D(seq))  # [batch, length, d_model]
        # NOTE(review): the Mamba zero-order-hold discretisation is A_bar = exp(s * A);
        # this computes exp(A) * s. Kept unchanged so trained weights behave the same —
        # confirm whether this is intended.
        A_bar = torch.einsum("ds,bld->blds", torch.exp(A), s)
        B_bar = torch.einsum("bls,bld->blds", B, s)
        X_bar = torch.einsum("blds,bld->blds", B_bar, seq)
        hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
        out = torch.einsum("blds,bls->bld", hid, C)
        out = out + D * seq
        return out, hid

    def _hid_states(self, A, X, prev_hid=None):
        """Scan ``h_t = A_t * h_{t-1} + X_t`` along the length axis.

        A, X: [batch, length, d_model, d_state]. The scan starts from ``prev_hid`` when
        given, otherwise from zeros. Returns the stacked states [batch, length, d_model, d_state].
        """
        b, _, d, s = A.shape
        # Zero state on the inputs' device (a fixed self.device breaks CPU models on CUDA hosts).
        h = prev_hid if prev_hid is not None else torch.zeros(b, d, s, device=A.device)
        states = []
        for A_t, X_t in zip(A.unbind(dim=1), X.unbind(dim=1)):
            h = A_t * h + X_t
            states.append(h)
        return torch.stack(states, dim=1)