Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -1763,8 +1763,8 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
1763 |
concat_input_1 = torch.cat([xf_1, x], dim=1)
|
1764 |
concat_input_2 = torch.cat([xf_2, x], dim=1)
|
1765 |
|
1766 |
-
|
1767 |
-
|
1768 |
|
1769 |
output = layer(
|
1770 |
x=x,
|
@@ -1774,7 +1774,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
|
|
1774 |
attention_mask_2=attention_mask_2,
|
1775 |
)
|
1776 |
x = output["embeddings"]
|
1777 |
-
|
1778 |
|
1779 |
return x, outs
|
1780 |
|
|
|
1763 |
concat_input_1 = torch.cat([xf_1, x], dim=1)
|
1764 |
concat_input_2 = torch.cat([xf_2, x], dim=1)
|
1765 |
|
1766 |
+
outs[f"PERCEIVER_RESAMPLER_concat_input_1_{layer_idx}"] = concat_input_1.clone()
|
1767 |
+
outs[f"PERCEIVER_RESAMPLER_concat_input_2_{layer_idx}"] = concat_input_2.clone()
|
1768 |
|
1769 |
output = layer(
|
1770 |
x=x,
|
|
|
1774 |
attention_mask_2=attention_mask_2,
|
1775 |
)
|
1776 |
x = output["embeddings"]
|
1777 |
+
outs[f"PERCEIVER_RESAMPLER_attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
|
1778 |
|
1779 |
return x, outs
|
1780 |
|