Yanisadel commited on
Commit
e98c2c7
·
1 Parent(s): bdfd38c

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +4 -1
chatNT.py CHANGED
@@ -416,9 +416,11 @@ class TorchBioBrainDecoder(nn.Module):
416
  _,
417
  bio_embed_dim,
418
  ) = projected_bio_embeddings.shape
419
-
420
  # Insert the bio embeddings at the SEQ token positions
421
  processed_tokens_ids = english_token_ids.clone()
 
 
422
  for bio_seq_num in range(num_bio_sequences):
423
  tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
424
  processed_tokens_ids,
@@ -426,6 +428,7 @@ class TorchBioBrainDecoder(nn.Module):
426
  projected_bio_embeddings[:, bio_seq_num, :, :],
427
  bio_seq_num=bio_seq_num,
428
  )
 
429
 
430
  # Regular GPT pass through
431
  print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
 
416
  _,
417
  bio_embed_dim,
418
  ) = projected_bio_embeddings.shape
419
+
420
  # Insert the bio embeddings at the SEQ token positions
421
  processed_tokens_ids = english_token_ids.clone()
422
+ print("(debug) Inside : processed tokens ids shape : ", processed_tokens_ids.shape)
423
+ print("(debug) Inside : projected bio embeddings shape : ", projected_bio_embeddings.shape)
424
  for bio_seq_num in range(num_bio_sequences):
425
  tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
426
  processed_tokens_ids,
 
428
  projected_bio_embeddings[:, bio_seq_num, :, :],
429
  bio_seq_num=bio_seq_num,
430
  )
431
+ print("After call : ", tokens_embeddings.shape)
432
 
433
  # Regular GPT pass through
434
  print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)