# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

from .help_layers import (TransformerEncoderLayer, GAL, GraphFusionLayer,
                          GraphFusionLayerAtt, MambaBlock, RMSNorm)


class PredictionsFusion(nn.Module):
    """Learnable weighted fusion of several prediction matrices.

    The weights are softmax-normalized across matrices per class, so the
    output is a convex combination of the input predictions.
    """

    def __init__(self, num_matrices=2, num_classes=7):
        super(PredictionsFusion, self).__init__()
        self.weights = nn.Parameter(torch.rand(num_matrices, num_classes))

    def forward(self, pred):
        normalized_weights = torch.softmax(self.weights, dim=0)
        weighted_matrix = sum(mat * normalized_weights[i] for i, mat in enumerate(pred))
        return weighted_matrix


class MultiModalTransformer_v3(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512,
                 num_transformer_heads=2, num_graph_heads=2, seg_len=44,
                 positional_encoding=True, dropout=0, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(MultiModalTransformer_v3, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim

        # Projection layers: pointwise Conv1d over the feature dimension
        self.audio_proj = nn.Sequential(
            nn.Conv1d(audio_dim, hidden_dim, 1),
            nn.GELU(),
        )
        self.text_proj = nn.Sequential(
            nn.Conv1d(text_dim, hidden_dim, 1),
            nn.GELU(),
        )

        # Cross-modal attention layers
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim * 4, out_features),
            # nn.LayerNorm(out_features),
            # nn.GELU(),
            # nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # self._init_weights()

    def forward(self, audio_features, text_features):
        audio_features = audio_features.float()
        text_features = text_features.float()

        # Conv1d expects [batch, channels, seq_len]
        audio_features = self.audio_proj(audio_features.permute(0, 2, 1)).permute(0, 2, 1)
        text_features = self.text_proj(text_features.permute(0, 2, 1)).permute(0, 2, 1)

        # Adaptive pooling to the shorter sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)

        # Transformer blocks with out-of-place residual additions
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # Classification
        if self.mode == 'mean':
            return self.classifier(torch.cat([mean_audio, mean_text], dim=1))
        else:
            return self.classifier(torch.cat([mean_audio, std_audio, mean_text, std_text], dim=1))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
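
# A minimal, self-contained sketch of PredictionsFusion in isolation. The
# batch size and the use of raw logits here are illustrative assumptions,
# not values taken from the training pipeline.
def _demo_predictions_fusion():
    fusion = PredictionsFusion(num_matrices=2, num_classes=7)
    audio_logits = torch.randn(4, 7)  # [batch, num_classes]
    text_logits = torch.randn(4, 7)
    fused = fusion([audio_logits, text_logits])
    return fused.shape  # torch.Size([4, 7])
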
class MultiModalTransformer_v4(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512,
                 num_transformer_heads=2, num_graph_heads=2, seg_len=44,
                 positional_encoding=True, dropout=0, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(MultiModalTransformer_v4, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim

        # Projection layers
        self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
        self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()

        # Cross-modal attention layers
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])

        # Graph fusion instead of GAL
        if self.mode == 'mean':
            self.graph_fusion = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
        else:
            self.graph_fusion = GraphFusionLayer(hidden_dim * 2, heads=num_graph_heads)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim * 2, out_features),
            nn.ReLU(),
            nn.Linear(out_features, num_classes)
        )

    def forward(self, audio_features, text_features):
        audio_features = audio_features.float()
        text_features = text_features.float()

        audio_features = self.audio_proj(audio_features)
        text_features = self.text_proj(text_features)

        # Adaptive pooling to the shorter sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)

        # Transformer blocks
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # Graph fusion of the statistics
        if self.mode == 'mean':
            h_ta = self.graph_fusion(mean_audio, mean_text)
        else:
            h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1),
                                     torch.cat([mean_text, std_text], dim=1))

        # Classification
        return self.classifier(h_ta)
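
# The *_v3/_v4 heads summarize each [batch, seq, dim] attention output with
# temporal mean/std statistics before classification. A shape-only sketch of
# that pooling (the sizes below are assumed defaults):
def _demo_std_mean_pooling():
    x = torch.randn(4, 44, 512)            # [batch, seq_len, hidden_dim]
    std, mean = torch.std_mean(x, dim=1)   # each [batch, hidden_dim]
    stats = torch.cat([mean, std], dim=1)  # [batch, 2 * hidden_dim]
    return stats.shape                     # torch.Size([4, 1024])
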
class MultiModalTransformer_v5(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512,
                 num_transformer_heads=2, num_graph_heads=2, seg_len=44, tr_layer_number=1,
                 positional_encoding=True, dropout=0, mode='mean', device="cuda",
                 out_features=128, num_classes=7):
        super(MultiModalTransformer_v5, self).__init__()
        self.hidden_dim = hidden_dim
        self.mode = mode

        # Project to a shared dimensionality (adaptive projections)
        self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
        self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()

        # Cross-modal attention layers
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])

        # Gated attention: GAL blends the modalities as h = z * h_text + (1 - z) * h_audio
        if self.mode == 'mean':
            self.gal = GAL(hidden_dim, hidden_dim, hidden_dim_gated)
        else:
            self.gal = GAL(hidden_dim * 2, hidden_dim * 2, hidden_dim_gated)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, out_features),
            nn.ReLU(),
            nn.Linear(out_features, num_classes)
        )

    def forward(self, audio_features, text_features):
        bs, seq_audio, audio_feat_dim = audio_features.shape
        bs, seq_text, text_feat_dim = text_features.shape

        text_features = text_features.to(torch.float32)
        audio_features = audio_features.to(torch.float32)

        # Project to the shared dimensionality
        audio_features = self.audio_proj(audio_features)  # (bs, seq_audio, hidden_dim)
        text_features = self.text_proj(text_features)     # (bs, seq_text, hidden_dim)

        # Pool both sequences to the shorter length (1-D pooling over time;
        # the feature dimension is left untouched)
        min_seq_len = min(seq_audio, seq_text)
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)

        # Transformer blocks
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # Gated attention over the pooled statistics
        if self.mode == 'mean':
            h_ta = self.gal(mean_audio, mean_text)
        else:
            h_ta = self.gal(torch.cat([mean_audio, std_audio], dim=1),
                            torch.cat([mean_text, std_text], dim=1))

        # Classification
        output = self.classifier(h_ta)
        return output
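
# GAL's internals live in help_layers; the following is only a minimal sketch
# of the gating idea named in the comment above (h = z * h_text + (1 - z) * h_audio),
# with hypothetical stand-in layers Wa, Wt, W_at rather than GAL's real parameters.
def _demo_gated_fusion():
    dim = 512
    Wa, Wt = nn.Linear(dim, dim), nn.Linear(dim, dim)
    W_at = nn.Linear(2 * dim, dim)
    audio, text = torch.randn(4, dim), torch.randn(4, dim)
    h_audio, h_text = torch.tanh(Wa(audio)), torch.tanh(Wt(text))
    z = torch.sigmoid(W_at(torch.cat([audio, text], dim=1)))  # per-feature gate
    return z * h_text + (1 - z) * h_audio                     # [batch, dim]
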
class MultiModalTransformer_v7(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, num_heads=2,
                 positional_encoding=True, dropout=0, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(MultiModalTransformer_v7, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim

        # Projection layers
        self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
        self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()

        # Cross-modal attention layers
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])

        # Graph fusion (attention variant) instead of GAL
        if self.mode == 'mean':
            self.graph_fusion = GraphFusionLayerAtt(hidden_dim, heads=num_heads)
        else:
            self.graph_fusion = GraphFusionLayerAtt(hidden_dim * 2, heads=num_heads)

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim * 2, out_features),
            nn.ReLU(),
            nn.Linear(out_features, num_classes)
        )

    def forward(self, audio_features, text_features):
        audio_features = audio_features.float()
        text_features = text_features.float()

        audio_features = self.audio_proj(audio_features)
        text_features = self.text_proj(text_features)

        # Adaptive pooling to the shorter sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)

        # Transformer blocks
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text

        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)

        # Graph fusion of the pooled statistics
        if self.mode == 'mean':
            h_ta = self.graph_fusion(mean_audio, mean_text)
        else:
            h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1),
                                     torch.cat([mean_text, std_text], dim=1))

        # Classification
        return self.classifier(h_ta)
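
# All of the models above call their attention layers as attn(query, key, value):
# audio_to_text_attn uses the text sequence as the query over audio keys/values,
# and vice versa. A shape sketch of that convention, using nn.MultiheadAttention
# purely as an assumed stand-in for help_layers.TransformerEncoderLayer:
def _demo_cross_attention_direction():
    attn = nn.MultiheadAttention(embed_dim=512, num_heads=2, batch_first=True)
    audio = torch.randn(4, 44, 512)
    text = torch.randn(4, 44, 512)
    attn_audio, _ = attn(query=text, key=audio, value=audio)  # text attends to audio
    return attn_audio.shape  # torch.Size([4, 44, 512])
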
class BiFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 hidden_dim_gated=128, num_transformer_heads=2, num_graph_heads=2,
                 positional_encoding=True, dropout=0.1, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(BiFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number

        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Transformer layers (reusing the existing implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])

        # Automatically compute the classifier input size
        self._calculate_classifier_input_dim()

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the classifier input feature size via a dry run with dummy data."""
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool = self._pool_features(dummy_audio)
        text_pool = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool, text_pool], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return torch.cat([mean_temp, mean_feat], dim=1)

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool = self._pool_features(audio)
        text_pool = self._pool_features(text)

        # Classification
        features = torch.cat([audio_pool, text_pool], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
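
# A minimal usage sketch for BiFormer. The feature dimensions are the
# constructor defaults; batch size and sequence lengths are assumptions.
# Note: the shorter of the two sequences should match seg_len, because the
# classifier input size is derived from seg_len in the dummy dry run above.
def _demo_biformer():
    model = BiFormer(audio_dim=1024, text_dim=1024, seg_len=44, num_classes=7)
    audio = torch.randn(2, 60, 1024)  # [batch, seq_audio, audio_dim]
    text = torch.randn(2, 44, 1024)   # [batch, seq_text, text_dim]; min length == seg_len
    logits = model(audio, text)
    return logits.shape               # torch.Size([2, 7])
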
class BiGraphFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 hidden_dim_gated=128, num_transformer_heads=2, num_graph_heads=2,
                 positional_encoding=True, dropout=0.1, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGraphFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number

        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Transformer layers (reusing the existing implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])

        # Graph fusion over feature-axis and temporal-axis statistics
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)

        # Automatically compute the classifier input size
        self._calculate_classifier_input_dim()

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # Final graph projections (currently unused in forward)
        self.fc_feat = nn.Sequential(
            nn.Linear(self.seg_len, self.seg_len),
            nn.LayerNorm(self.seg_len),
            nn.Dropout(dropout)
        )
        self.fc_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the classifier input feature size via a dry run with dummy data."""
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)

        # Graph fusion of the pooled statistics
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)

        # Classification
        features = torch.cat([graph_feat, graph_temp], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
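
# BiGraphFormer fuses two complementary summaries per modality: a temporal
# mean over seq_len ([batch, hidden_dim]) and a feature mean over hidden_dim
# ([batch, seq_len]). A shape-only sketch of _pool_features (sizes assumed):
def _demo_dual_pooling():
    x = torch.randn(2, 44, 512)
    mean_temp = x.mean(dim=1)    # [2, 512] — one value per hidden channel
    mean_feat = x.mean(dim=-1)   # [2, 44]  — one value per time step
    return mean_temp.shape, mean_feat.shape
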
class BiGatedGraphFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 hidden_dim_gated=128, num_transformer_heads=2, num_graph_heads=2,
                 positional_encoding=True, dropout=0.1, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGatedGraphFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number

        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Transformer layers (reusing the existing implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])

        # Graph fusion keeps per-node outputs (out_mean=False), which are then
        # blended by gated attention layers
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
        self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
        self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)

        # Automatically compute the classifier input size (kept for reference;
        # the classifier below is sized from hidden_dim_gated)
        self._calculate_classifier_input_dim()

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim_gated * 2, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # Final graph projections
        self.fc_graph_feat = nn.Sequential(
            nn.Linear(self.seg_len, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_graph_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )

        # Final gated projections
        self.fc_gated_feat = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_gated_temp = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )

        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the classifier input feature size via a dry run with dummy data."""
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)

        # Graph fusion of the pooled statistics (per-node outputs)
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)

        # Gated attention over the two graph nodes (audio, text)
        gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
        gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])

        # Combine the graph-mean and gated paths
        fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
        fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)

        # Classification
        features = torch.cat([fused_feat, fused_temp], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
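
# With out_mean=False, the graph fusion output is indexed above as
# graph[:, 0, :] and graph[:, 1, :], which implies a [batch, 2, dim] layout
# with one node per modality. A shape-only illustration of that layout
# (the tensor here is a placeholder, not GraphFusionLayer's real output):
def _demo_graph_nodes():
    graph = torch.randn(2, 2, 512)  # [batch, nodes (audio, text), dim]
    audio_node, text_node = graph[:, 0, :], graph[:, 1, :]  # each [2, 512]
    mean_node = graph.mean(dim=1)   # [2, 512], used for the residual graph path
    return audio_node.shape, text_node.shape, mean_node.shape
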
class BiFormerWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 hidden_dim_gated=128, num_transformer_heads=2, num_graph_heads=2,
                 positional_encoding=True, dropout=0.1, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(BiFormerWithProb, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number

        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Transformer layers (reusing the existing implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])

        # Automatically compute the classifier input size
        self._calculate_classifier_input_dim()

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # Fuses the unimodal predictions with this model's own logits
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)

        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the classifier input feature size via a dry run with dummy data."""
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool = self._pool_features(dummy_audio)
        text_pool = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool, text_pool], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return torch.cat([mean_temp, mean_feat], dim=1)

    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool = self._pool_features(audio)
        text_pool = self._pool_features(text)

        # Classification followed by prediction-level fusion
        features = torch.cat([audio_pool, text_pool], dim=1)
        out = self.classifier(features)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
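
# Usage sketch for BiFormerWithProb: besides the features, it takes the
# per-modality prediction matrices and blends them with its own logits via
# PredictionsFusion. Batch size and the softmaxed dummy predictions below
# are illustrative assumptions.
def _demo_biformer_with_prob():
    model = BiFormerWithProb(seg_len=44, num_classes=7)
    audio = torch.randn(2, 44, 1024)
    text = torch.randn(2, 44, 1024)
    audio_pred = torch.softmax(torch.randn(2, 7), dim=1)  # unimodal predictions
    text_pred = torch.softmax(torch.randn(2, 7), dim=1)
    return model(audio, text, audio_pred, text_pred).shape  # torch.Size([2, 7])
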
class BiGraphFormerWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 hidden_dim_gated=128, num_transformer_heads=2, num_graph_heads=2,
                 positional_encoding=True, dropout=0.1, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGraphFormerWithProb, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number

        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Transformer layers (reusing the existing implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])

        # Graph fusion over feature-axis and temporal-axis statistics
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)

        # Automatically compute the classifier input size
        self._calculate_classifier_input_dim()

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # Final graph projections (currently unused in forward)
        self.fc_feat = nn.Sequential(
            nn.Linear(self.seg_len, self.seg_len),
            nn.LayerNorm(self.seg_len),
            nn.Dropout(dropout)
        )
        self.fc_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Fuses the unimodal predictions with this model's own logits
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)

        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the classifier input feature size via a dry run with dummy data."""
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)

        # Graph fusion of the pooled statistics
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)

        # Classification followed by prediction-level fusion
        features = torch.cat([graph_feat, graph_temp], dim=1)
        out = self.classifier(features)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class BiGatedGraphFormerWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 hidden_dim_gated=128, num_transformer_heads=2, num_graph_heads=2,
                 positional_encoding=True, dropout=0.1, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGatedGraphFormerWithProb, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number

        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Transformer layers (reusing the existing implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])

        # Graph fusion keeps per-node outputs (out_mean=False), which are then
        # blended by gated attention layers
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
        self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
        self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)

        # Automatically compute the classifier input size (kept for reference;
        # the classifier below is sized from hidden_dim_gated)
        self._calculate_classifier_input_dim()

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim_gated * 2, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # Final graph projections
        self.fc_graph_feat = nn.Sequential(
            nn.Linear(self.seg_len, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_graph_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )

        # Final gated projections
        self.fc_gated_feat = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_gated_temp = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )

        # Fuses the unimodal predictions with this model's own logits
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)

        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the classifier input feature size via a dry run with dummy data."""
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat
    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)

        # Graph fusion of the pooled statistics (per-node outputs)
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)

        # Gated attention over the two graph nodes (audio, text)
        gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
        gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])

        # Combine the graph-mean and gated paths
        fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
        fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)

        # Classification followed by prediction-level fusion
        features = torch.cat([fused_feat, fused_temp], dim=1)
        out = self.classifier(features)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
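
# Every *Former class above resamples sequences with linear interpolation
# rather than truncation, so all frames contribute to the pooled sequence.
# A minimal sketch of the same operation (sizes assumed):
def _demo_temporal_resample():
    x = torch.randn(2, 60, 512)  # [batch, seq, dim]
    y = F.interpolate(x.permute(0, 2, 1), size=44,
                      mode='linear', align_corners=False).permute(0, 2, 1)
    return y.shape               # torch.Size([2, 44, 512])
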
class BiGatedFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 hidden_dim_gated=128, num_transformer_heads=2, num_graph_heads=2,
                 positional_encoding=True, dropout=0.1, mode='mean', device="cuda",
                 tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGatedFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number

        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Transformer layers (reusing the existing implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])

        # Gated attention only (no graph fusion in this variant)
        self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
        self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)

        # Automatically compute the classifier input size (kept for reference;
        # the classifier below is sized from hidden_dim_gated)
        self._calculate_classifier_input_dim()

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim_gated * 2, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # Final gated projections (currently unused in forward)
        self.fc_gated_feat = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_gated_temp = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )

        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the classifier input feature size via a dry run with dummy data."""
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)

        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text

        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)

        # Gated attention over the pooled statistics
        gated_feat = self.gated_feat(audio_pool_feat, text_pool_feat)
        gated_temp = self.gated_temp(audio_pool_temp, text_pool_temp)

        # Classification
        features = torch.cat([gated_feat, gated_temp], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class BiMamba(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 mamba_d_state=16, d_discr=None, mamba_ker_size=4, mamba_layer_number=2,
                 dropout=0.1, mode='', positional_encoding=False,
                 out_features=128, num_classes=7, device="cuda"):
        super(BiMamba, self).__init__()
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.num_mamba_layers = mamba_layer_number
        self.device = device

        # Per-modality projection layers
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Modality fusion layer
        self.fusion_proj = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Mamba blocks for the fused features
        mamba_params = {
            'd_input': hidden_dim,
            'd_model': hidden_dim,
            'd_state': mamba_d_state,
            'd_discr': d_discr,
            'ker_size': mamba_ker_size
        }
        self.mamba_blocks = nn.ModuleList([
            nn.Sequential(
                MambaBlock(**mamba_params),
                RMSNorm(hidden_dim)
            ) for _ in range(self.num_mamba_layers)
        ])

        # Classifier (fixed input size seg_len + hidden_dim; see _pool_features)
        self.classifier = nn.Sequential(
            nn.Linear(self.seg_len + self.hidden_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        self._init_weights()

    def _pool_features(self, x):
        """Concatenates temporal and feature statistics, zero-padded to a fixed size."""
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        full_feature = torch.cat([mean_temp, mean_feat], dim=1)
        if full_feature.shape[-1] == self.seg_len + self.hidden_dim:
            return full_feature
        # Pad with zeros when the pooled sequence is shorter than seg_len
        pad_size = self.seg_len + self.hidden_dim - full_feature.shape[-1]
        return F.pad(full_feature, (0, pad_size), mode="constant", value=0)

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())  # [B, T, D]
        text = self.text_proj(text_features.float())     # [B, T, D]

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self._adaptive_pool(audio, min_len)
        text = self._adaptive_pool(text, min_len)

        # Modality fusion
        fused = torch.cat([audio, text], dim=-1)  # [B, T, 2*D]
        fused = self.fusion_proj(fused)           # [B, T, D]

        # Process the fused features through the Mamba blocks
        for mamba_block in self.mamba_blocks:
            out, _ = mamba_block[0](fused, None)
            out = mamba_block[1](out)
            fused = fused + out  # residual connection

        # Feature aggregation and classification
        pooled = self._pool_features(fused)
        return self.classifier(pooled)

    def _adaptive_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
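
# BiMamba._pool_features zero-pads its pooled vector up to
# seg_len + hidden_dim when the runtime sequence is shorter than seg_len,
# so the classifier input size stays fixed. A minimal sketch of that padding
# (sizes assumed):
def _demo_pad_pooled_features():
    seg_len, hidden_dim = 44, 512
    pooled = torch.randn(2, hidden_dim + 30)  # sequence was pooled to 30 < seg_len
    pad_size = (seg_len + hidden_dim) - pooled.shape[-1]
    padded = F.pad(pooled, (0, pad_size), mode="constant", value=0)
    return padded.shape                       # torch.Size([2, 556])
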
class BiMambaWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512,
                 mamba_d_state=16, d_discr=None, mamba_ker_size=4, mamba_layer_number=2,
                 dropout=0.1, mode='', positional_encoding=False,
                 out_features=128, num_classes=7, device="cuda"):
        super(BiMambaWithProb, self).__init__()
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.num_mamba_layers = mamba_layer_number
        self.device = device

        # Per-modality projection layers
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Modality fusion layer
        self.fusion_proj = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )

        # Mamba blocks for the fused features
        mamba_params = {
            'd_input': hidden_dim,
            'd_model': hidden_dim,
            'd_state': mamba_d_state,
            'd_discr': d_discr,
            'ker_size': mamba_ker_size
        }
        self.mamba_blocks = nn.ModuleList([
            nn.Sequential(
                MambaBlock(**mamba_params),
                RMSNorm(hidden_dim)
            ) for _ in range(self.num_mamba_layers)
        ])

        # Classifier (fixed input size seg_len + hidden_dim; see _pool_features)
        self.classifier = nn.Sequential(
            nn.Linear(self.seg_len + self.hidden_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )

        # Fuses the unimodal predictions with this model's own logits
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)

        self._init_weights()

    def _pool_features(self, x):
        """Concatenates temporal and feature statistics, zero-padded to a fixed size."""
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        full_feature = torch.cat([mean_temp, mean_feat], dim=1)
        if full_feature.shape[-1] == self.seg_len + self.hidden_dim:
            return full_feature
        # Pad with zeros when the pooled sequence is shorter than seg_len
        pad_size = self.seg_len + self.hidden_dim - full_feature.shape[-1]
        return F.pad(full_feature, (0, pad_size), mode="constant", value=0)
    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        audio = self.audio_proj(audio_features.float())  # [B, T, D]
        text = self.text_proj(text_features.float())     # [B, T, D]

        # Adaptive pooling to the shorter sequence
        min_len = min(audio.size(1), text.size(1))
        audio = self._adaptive_pool(audio, min_len)
        text = self._adaptive_pool(text, min_len)

        # Modality fusion
        fused = torch.cat([audio, text], dim=-1)  # [B, T, 2*D]
        fused = self.fusion_proj(fused)           # [B, T, D]

        # Process the fused features through the Mamba blocks
        for mamba_block in self.mamba_blocks:
            out, _ = mamba_block[0](fused, None)
            out = mamba_block[1](out)
            fused = fused + out  # residual connection

        # Feature aggregation, classification, and prediction-level fusion
        pooled = self._pool_features(fused)
        out = self.classifier(pooled)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def _adaptive_pool(self, x, target_len):
        """Adaptively resamples the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)
длины""" if x.size(1) == target_len: return x return F.interpolate( x.permute(0, 2, 1), size=target_len, mode='linear', align_corners=False ).permute(0, 2, 1) def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0)