Update chatNT.py
Browse files
chatNT.py
CHANGED
@@ -416,9 +416,11 @@ class TorchBioBrainDecoder(nn.Module):
|
|
416 |
_,
|
417 |
bio_embed_dim,
|
418 |
) = projected_bio_embeddings.shape
|
419 |
-
|
420 |
# Insert the bio embeddings at the SEQ token positions
|
421 |
processed_tokens_ids = english_token_ids.clone()
|
|
|
|
|
422 |
for bio_seq_num in range(num_bio_sequences):
|
423 |
tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
|
424 |
processed_tokens_ids,
|
@@ -426,6 +428,7 @@ class TorchBioBrainDecoder(nn.Module):
|
|
426 |
projected_bio_embeddings[:, bio_seq_num, :, :],
|
427 |
bio_seq_num=bio_seq_num,
|
428 |
)
|
|
|
429 |
|
430 |
# Regular GPT pass through
|
431 |
print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
|
|
|
416 |
_,
|
417 |
bio_embed_dim,
|
418 |
) = projected_bio_embeddings.shape
|
419 |
+
|
420 |
# Insert the bio embeddings at the SEQ token positions
|
421 |
processed_tokens_ids = english_token_ids.clone()
|
422 |
+
print("(debug) Inside : processed tokens ids shape : ", processed_tokens_ids.shape)
|
423 |
+
print("(debug) Inside : projected bio embeddings shape : ", projected_bio_embeddings.shape)
|
424 |
for bio_seq_num in range(num_bio_sequences):
|
425 |
tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
|
426 |
processed_tokens_ids,
|
|
|
428 |
projected_bio_embeddings[:, bio_seq_num, :, :],
|
429 |
bio_seq_num=bio_seq_num,
|
430 |
)
|
431 |
+
print("After call : ", tokens_embeddings.shape)
|
432 |
|
433 |
# Regular GPT pass through
|
434 |
print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
|