Yanisadel commited on
Commit
c4e3377
·
1 Parent(s): 13dcc57

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +3 -3
chatNT.py CHANGED
@@ -1763,8 +1763,8 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1763
  concat_input_1 = torch.cat([xf_1, x], dim=1)
1764
  concat_input_2 = torch.cat([xf_2, x], dim=1)
1765
 
1766
- #outs[f"concat_input_1_{layer_idx}"] = concat_input_1.clone()
1767
- #outs[f"concat_input_2_{layer_idx}"] = concat_input_2.clone()
1768
 
1769
  output = layer(
1770
  x=x,
@@ -1774,7 +1774,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
1774
  attention_mask_2=attention_mask_2,
1775
  )
1776
  x = output["embeddings"]
1777
- #outs[f"attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
1778
 
1779
  return x, outs
1780
 
 
1763
  concat_input_1 = torch.cat([xf_1, x], dim=1)
1764
  concat_input_2 = torch.cat([xf_2, x], dim=1)
1765
 
1766
+ outs[f"PERCEIVER_RESAMPLER_concat_input_1_{layer_idx}"] = concat_input_1.clone()
1767
+ outs[f"PERCEIVER_RESAMPLER_concat_input_2_{layer_idx}"] = concat_input_2.clone()
1768
 
1769
  output = layer(
1770
  x=x,
 
1774
  attention_mask_2=attention_mask_2,
1775
  )
1776
  x = output["embeddings"]
1777
+ outs[f"PERCEIVER_RESAMPLER_attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
1778
 
1779
  return x, outs
1780