Yanisadel committed
Commit d91714f · 1 Parent(s): c4e3377

Update chatNT.py

Files changed (1)
  1. chatNT.py +19 -4
chatNT.py CHANGED
@@ -1677,8 +1677,10 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
         attention_mask_1: Optional[torch.Tensor] = None,
         attention_mask_2: Optional[torch.Tensor] = None,
     ) -> Dict[str, torch.Tensor]:
+        outs_news = {}
         res = x
         x = self.norm_cross_attention_1(x)
+        outs_news["ATTENTION_layer0_layer_norm_cross_attention_1"] = x.clone()
 
         attn_output = self.cross_attention_1(
             query=x,
@@ -1686,21 +1688,30 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
             value=cross_attention_embeddings_1,
             attention_mask=attention_mask_1,
         )["embeddings"]
+        outs_news["ATTENTION_layer1_cross_attention_layer_1"] = attn_output.clone()
         x = res + attn_output
 
         res = x
         x = self.norm_cross_attention_2(x)
+        outs_news["ATTENTION_layer2_cross_attention_layer_2"] = x.clone()
         attn_output = self.cross_attention_2(
             query=x,
             key=cross_attention_embeddings_2,
             value=cross_attention_embeddings_2,
             attention_mask=attention_mask_2,
         )["embeddings"]
+        outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
         x = res + attn_output
 
         x = x + self.mlp(x)
+        outs_news["ATTENTION_after_mlp"] = x.clone()
 
-        return {"embeddings": x}
+        output = {}
+        for key in outs_news.keys():
+            output[key] = outs_news[key]
+
+        output["embeddings"] = x
+        return output
 
 
 class TorchMultiModalPerceiverResampler(nn.Module):
@@ -1763,8 +1774,8 @@ class TorchMultiModalPerceiverResampler(nn.Module):
             concat_input_1 = torch.cat([xf_1, x], dim=1)
             concat_input_2 = torch.cat([xf_2, x], dim=1)
 
-            outs[f"PERCEIVER_RESAMPLER_concat_input_1_{layer_idx}"] = concat_input_1.clone()
-            outs[f"PERCEIVER_RESAMPLER_concat_input_2_{layer_idx}"] = concat_input_2.clone()
+            #outs[f"PERCEIVER_RESAMPLER_concat_input_1_{layer_idx}"] = concat_input_1.clone()
+            #outs[f"PERCEIVER_RESAMPLER_concat_input_2_{layer_idx}"] = concat_input_2.clone()
 
             output = layer(
                 x=x,
@@ -1774,7 +1785,11 @@ class TorchMultiModalPerceiverResampler(nn.Module):
                 attention_mask_2=attention_mask_2,
             )
             x = output["embeddings"]
-            outs[f"PERCEIVER_RESAMPLER_attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
+            #outs[f"PERCEIVER_RESAMPLER_attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
+
+            for key in output.keys():
+                if key != "embeddings":
+                    outs[f"{key}_{layer_idx}"] = output[key].clone()
 
         return x, outs
 
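After this change, each TorchMultiModalPerceiverResamplerBlock returns its intermediate activations alongside "embeddings", and the parent TorchMultiModalPerceiverResampler re-keys each one as f"{key}_{layer_idx}" in the outs dict it already returns. Below is a minimal sketch of consuming that dict; it is not part of the commit, and the toy outs stand-in assumes only the key pattern visible in the diff above.

import torch
from collections import defaultdict

# Toy stand-in for the `outs` dict returned by the resampler after this
# commit; keys follow f"{block_key}_{layer_idx}" (pattern assumed from the
# diff above, shapes are illustrative only).
outs = {
    "ATTENTION_layer0_layer_norm_cross_attention_1_0": torch.zeros(1, 16, 64),
    "ATTENTION_after_mlp_0": torch.zeros(1, 16, 64),
    "ATTENTION_after_mlp_1": torch.zeros(1, 16, 64),
}

# Group the cloned activations by layer index (the trailing "_{layer_idx}").
per_layer = defaultdict(dict)
for key, tensor in outs.items():
    name, _, idx = key.rpartition("_")
    per_layer[int(idx)][name] = tensor

for idx in sorted(per_layer):
    for name, tensor in per_layer[idx].items():
        print(f"layer {idx}: {name} {tuple(tensor.shape)}")

As a side note, the block's new return construction (an empty dict filled by a copy loop, then output["embeddings"] = x) is equivalent to {**outs_news, "embeddings": x}; the commit simply spells the copy out.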