Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -1661,14 +1661,24 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
|
|
1661 |
)
|
1662 |
|
1663 |
def mlp(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
1664 |
x = self.norm_mlp(x)
|
|
|
1665 |
if self.use_glu_in_ffn:
|
1666 |
x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
|
1667 |
x = self.activation_fn(x1) * x2
|
1668 |
else:
|
1669 |
-
x = self.
|
1670 |
-
|
|
|
|
|
1671 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1672 |
def forward(
|
1673 |
self,
|
1674 |
x: torch.Tensor,
|
@@ -1703,7 +1713,8 @@ class TorchMultiModalPerceiverResamplerBlock(nn.Module):
|
|
1703 |
outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
|
1704 |
x = res + attn_output
|
1705 |
|
1706 |
-
|
|
|
1707 |
outs_news["ATTENTION_after_mlp"] = x.clone()
|
1708 |
|
1709 |
output = {}
|
|
|
1661 |
)
|
1662 |
|
1663 |
def mlp(self, x: torch.Tensor) -> torch.Tensor:
|
1664 |
+
outs = {}
|
1665 |
x = self.norm_mlp(x)
|
1666 |
+
outs["MLP_layer0_layer_norm"] = x.clone()
|
1667 |
if self.use_glu_in_ffn:
|
1668 |
x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
|
1669 |
x = self.activation_fn(x1) * x2
|
1670 |
else:
|
1671 |
+
x = self.fc1(x)
|
1672 |
+
outs["MLP_layer1_fc1"] = x.clone()
|
1673 |
+
x = self.activation_fn(x)
|
1674 |
+
outs["MLP_layer2_activation"] = x.clone()
|
1675 |
|
1676 |
+
x = self.fc2(x)
|
1677 |
+
outs["MLP_layer3_fc2"] = x.clone()
|
1678 |
+
outs["x"] = x.clone()
|
1679 |
+
|
1680 |
+
return outs
|
1681 |
+
|
1682 |
def forward(
|
1683 |
self,
|
1684 |
x: torch.Tensor,
|
|
|
1713 |
outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
|
1714 |
x = res + attn_output
|
1715 |
|
1716 |
+
mlp_output = self.mlp(x)
|
1717 |
+
x = x + mlp_output["x"]
|
1718 |
outs_news["ATTENTION_after_mlp"] = x.clone()
|
1719 |
|
1720 |
output = {}
|