Yanisadel committed on
Commit
d31120f
·
1 Parent(s): d91714f

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +14 -3
chatNT.py CHANGED
@@ -1661,14 +1661,24 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
1661
  )
1662
 
1663
def mlp(self, x: torch.Tensor) -> torch.Tensor:
    """Feed-forward sub-block: pre-LayerNorm, then an FFN that is either
    GLU-gated or a plain activation, followed by the output projection.

    Args:
        x: input activations (shape determined by the surrounding block).

    Returns:
        The FFN output, same leading shape as ``x``.
    """
    hidden = self.norm_mlp(x)
    projected = self.fc1(hidden)
    if self.use_glu_in_ffn:
        # Gated Linear Unit: fc1 produces both the gate and the value.
        gate, value = torch.chunk(projected, 2, dim=-1)
        hidden = self.activation_fn(gate) * value
    else:
        hidden = self.activation_fn(projected)
    return self.fc2(hidden)
 
 
1671
 
 
 
 
 
 
 
1672
  def forward(
1673
  self,
1674
  x: torch.Tensor,
@@ -1703,7 +1713,8 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
1703
  outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
1704
  x = res + attn_output
1705
 
1706
- x = x + self.mlp(x)
 
1707
  outs_news["ATTENTION_after_mlp"] = x.clone()
1708
 
1709
  output = {}
 
1661
  )
1662
 
1663
def mlp(self, x: torch.Tensor) -> "dict[str, torch.Tensor]":
    """Feed-forward sub-block that also records intermediate activations.

    Applies pre-LayerNorm, then either a GLU-gated FFN or a plain
    activation FFN, then the output projection. Each intermediate tensor
    is cloned into ``outs`` for inspection by the caller.

    Args:
        x: input activations (shape determined by the surrounding block).

    Returns:
        Dict of cloned intermediate tensors. ``outs["x"]`` is the final
        FFN output; other keys (``MLP_layer0_layer_norm``,
        ``MLP_layer1_fc1``, ``MLP_layer2_activation``,
        ``MLP_layer3_fc2``) hold intermediates. Note: the fc1/activation
        intermediates are only recorded on the non-GLU path.
    """
    # BUG FIX: the annotation previously claimed ``-> torch.Tensor`` even
    # though the body returns the ``outs`` dict; corrected to match.
    outs = {}
    x = self.norm_mlp(x)
    outs["MLP_layer0_layer_norm"] = x.clone()
    if self.use_glu_in_ffn:
        # Gated Linear Unit: fc1 produces both the gate and the value.
        x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
        x = self.activation_fn(x1) * x2
    else:
        x = self.fc1(x)
        outs["MLP_layer1_fc1"] = x.clone()
        x = self.activation_fn(x)
        outs["MLP_layer2_activation"] = x.clone()

    x = self.fc2(x)
    outs["MLP_layer3_fc2"] = x.clone()
    outs["x"] = x.clone()

    return outs
1681
+
1682
  def forward(
1683
  self,
1684
  x: torch.Tensor,
 
1713
  outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
1714
  x = res + attn_output
1715
 
1716
+ mlp_output = self.mlp(x)
1717
+ x = x + mlp_output["x"]
1718
  outs_news["ATTENTION_after_mlp"] = x.clone()
1719
 
1720
  output = {}