Yanisadel committed on
Commit
6535c3d
·
1 Parent(s): d31120f

Delete chatNT.py

Files changed (1)
  1. chatNT.py +0 -1939
chatNT.py DELETED
@@ -1,1939 +0,0 @@
1
- # This file stores ChatNT and all associated layers and configs
2
-
3
- from dataclasses import asdict, dataclass, field
4
- from typing import Dict, List, Optional, Tuple
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F # noqa: N812
10
- from transformers import PretrainedConfig, PreTrainedModel
11
-
12
-
13
- @dataclass
14
- class RotaryEmbeddingConfig:
15
- """
16
- Rotary Positional Embedding configuration
17
- max_seq_len: The number of positions to encode and cache.
18
- dim: Dimension of RoPE.
19
- theta: Rotation angle.
20
- """
21
-
22
- max_seq_len: int
23
- dim: int
24
- theta: float
25
-
26
-
27
- @dataclass
28
- class PerceiverResamplerConfig:
29
- """
30
- Parameters to initialize a PerceiverResampler model. Based on the ESM architecture.
31
-
32
- Args:
33
- emb_layer_norm_before: Whether to use layer norm before the first attention
34
- layer.
35
- attention_heads: Number of attention heads.
36
- key_size: The dimension of the query, key, and values within each attention
37
- head, if not specified, it is set to embed_dim // attention_heads.
38
- It can be useful to set a custom key size if we want to impose the size of
39
- the query, key and value tensor ( for example, tensors shaped with
40
- power of 2 are more efficiently handled on TPUs ).
41
- Note: Parametrizing the model with a custom key size has been done in :
42
- Brown, Tom, et al. "Language models are few-shot learners."
43
- Advances in neural information processing systems 33 (2020): 1877-1901.
44
- embed_dim: Embedding dimension.
45
- ffn_embed_dim: Feed forward embedding dimension.
46
- num_layers: Number of attention blocks.
47
- ffn_activation_name: Activation function to be used in FFN block. Supported
48
- names are "gelu", "relu", "swish".
49
- use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
50
- Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
51
- to True and use swish as ffn_activation_name.
52
- Same principle for a gated-relu. To keep the same number of parameters in
53
- the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
54
- See https://arxiv.org/pdf/2002.05202.pdf for more details.
55
- resampled_length: length of the resampled output of the module
56
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
57
- gradients in the forward pass to reduce the computation in the backward).
58
- """
59
-
60
- # architecture
61
- emb_layer_norm_before: bool = False
62
- attention_heads: int = 20
63
- key_size: Optional[int] = None
64
- embed_dim: int = 1280
65
- ffn_embed_dim: int = 5120
66
- num_layers: int = 24
67
- add_bias_kv: bool = False
68
- add_bias_ffn: bool = True
69
- ffn_activation_name: str = "gelu-no-approx"
70
- use_glu_in_ffn: bool = False
71
- resampled_length: int = 64
72
-
73
- # performance
74
- use_gradient_checkpointing: bool = False
75
-
76
- def __post_init__(self) -> None:
77
- """
78
- Checks that the given values are compatible.
79
- """
80
-
81
- if self.key_size is None:
82
- if not self.embed_dim % self.attention_heads == 0:
83
- raise ValueError(
84
- f"When no key size is provided, the embedding dimension should be "
85
- f"divisible by the number of heads, however provided embedding "
86
- f"dimension is {self.embed_dim} and the number of heads is "
87
- f"{self.attention_heads}."
88
- )
89
- self.key_size = self.embed_dim // self.attention_heads
90
-
91
-
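# Illustrative sketch (not from the original file): the note above about multiplying
# ffn_embed_dim by 2/3 when using GLU can be checked directly. Dimensions are examples only.
import torch.nn as nn
embed_dim, ffn_dim = 1280, 5120
standard_ffn = nn.Sequential(nn.Linear(embed_dim, ffn_dim), nn.Linear(ffn_dim, embed_dim))
glu_dim = int(ffn_dim * 2 / 3)  # scaled hidden width for the gated variant
gated_ffn = nn.Sequential(nn.Linear(embed_dim, 2 * glu_dim), nn.Linear(glu_dim, embed_dim))
n_standard = sum(p.numel() for p in standard_ffn.parameters())
n_gated = sum(p.numel() for p in gated_ffn.parameters())
# n_gated differs from n_standard by well under 1%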
92
- @dataclass
93
- class GptConfig:
94
- """
95
- Parameters to initialize a Gpt model.
96
-
97
- NOTE: the pad token is not defined
98
-
99
- Args:
100
- vocab_size: Token vocabulary.
101
- eos_token_id: used to stop sentence generation
102
- embed_dim: Embedding dimension.
103
- ffn_embed_dim: Feed forward embedding dimension.
104
- num_heads: Number of attention heads.
105
- num_kv_heads: Number of key and value heads to support Grouped-Query and
106
- Multi-Query Attention. If None, the number of key and value heads is
107
- equal to the number of attention heads.
108
- num_layers: Number of Decoder layer_stack
109
- rope_config: The configuration for the rotary positional embeddings
110
- add_bias_ffn: Add bias in feed forward network block.
111
- ffn_activation_name: Activation function to be used in FFN block. Supported
112
- names are "gelu", "gelu-no-approx", "relu", "swish".
113
- use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
114
- Forward Network (FFN) block.
115
- example: To do a swiGLU (gated-swish) put this arg
116
- to True and use swish as ffn_activation_name.
117
- Same principle for a gated-relu.
118
- add_bias_lm_head: whether to use bias in the final LM layer
119
- norm_type: The type of norm used (pre-normalization scheme). Can be
120
- one of ["layer_norm", "RMS_norm"]
121
- parallel_attention_ff: Whether to do the attention and the MLP in parallel,
122
- and then sum up the results as it is done in Gpt-NeoX :
123
- Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
124
- language model." arXiv preprint arXiv:2204.06745 (2022).
125
- It is said to improve training time by 15% when compiling with JAX
126
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
127
- gradients in the forward pass to reduce the computation in the backward).
128
- add_bias_attn: Add bias to the attention mechanism (key, query, value, and
129
- output projections).
130
- """
131
-
132
- # vocabulary
133
- vocab_size: int
134
- eos_token_id: int
135
-
136
- # architecture
137
- embed_dim: int = 16
138
- ffn_embed_dim: int = 64
139
- num_heads: int = 2
140
- num_kv_heads: Optional[int] = None
141
- num_layers: int = 2
142
- rope_config: RotaryEmbeddingConfig = field(
143
- default_factory=lambda: RotaryEmbeddingConfig(
144
- max_seq_len=512, dim=8, theta=10000.0
145
- )
146
- )
147
- add_bias_ffn: bool = False
148
- ffn_activation_name: str = "swish"
149
- use_glu_in_ffn: bool = True
150
- add_bias_lm_head: bool = False
151
- norm_type: str = "RMS_norm"
152
- rms_norm_eps: float = 1e-6
153
- parallel_attention_ff: bool = True
154
-
155
- # inference / backward behavior
156
- use_gradient_checkpointing: bool = False
157
-
158
- # architecture params with default values
159
- add_bias_attn: bool = False
160
-
161
- def __post_init__(self) -> None:
162
- """
163
- Checks that the given values are compatible.
164
- """
165
- if not self.embed_dim % self.num_heads == 0:
166
- raise ValueError(
167
- f"The embedding dimension should be "
168
- f"divisible by the number of heads, however provided embedding "
169
- f"dimension is {self.embed_dim} and the number of heads is "
170
- f"{self.num_heads}."
171
- )
172
-
173
- if not self.embed_dim // self.num_heads > 1:
174
- raise ValueError(
175
- "embed_dim / num_heads must be higher than 2 to apply rotary embeddings"
176
- )
177
-
178
- if not self.embed_dim // self.num_heads >= self.rope_config.dim:
179
- raise ValueError(
180
- "embed_dim // num_heads must be higher than rope_config.dim "
181
- "to apply rotary embeddings"
182
- )
183
-
184
- def to_dict(self): # type: ignore
185
- output = asdict(self)
186
- output["rope_config"] = asdict(self.rope_config)
187
- return output
188
-
189
-
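# Illustrative sketch (not from the original file): instantiating GptConfig with its
# defaults shows the constraints enforced in __post_init__. Values are examples only.
cfg = GptConfig(vocab_size=32000, eos_token_id=3)  # defaults: embed_dim=16, num_heads=2
head_dim = cfg.embed_dim // cfg.num_heads          # 8
assert cfg.embed_dim % cfg.num_heads == 0          # divisibility check
assert head_dim > 1                                # required to apply rotary embeddings
assert head_dim >= cfg.rope_config.dim             # rope dim (8) fits within the head dim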
190
- @dataclass
191
- class ESMTransformerConfig:
192
- """
193
- Parameters to initialize an ESM model. While the ESM architecture is an encoder-only
194
- model, different choices have been made for each version and this configuration aims
195
- to cover most of them.
196
-
197
- Args:
198
- alphabet_size: Token vocabulary.
199
- pad_token_id: ID of pad token.
200
- mask_token_id: ID of mask token.
201
- max_positions: Maximum sequence length.
202
- embed_scale: Correction ratio applied to the embeddings to make up for the
203
- norm difference between the input during training and inference.
204
- emb_layer_norm_before: Whether to use layer norm before the first attention
205
- layer.
206
- attention_heads: Number of attention heads.
207
- key_size: The dimension of the query, key, and values within each attention
208
- head, if not specified, it is set to embed_dim // attention_heads.
209
- It can be useful to set a custom key size if we want to impose the size of
210
- the query, key and value tensor ( for example, tensors shaped with
211
- power of 2 are more efficiently handled on TPUs ).
212
- Note: Parametrizing the model with a custom key size has been done in :
213
- Brown, Tom, et al. "Language models are few-shot learners."
214
- Advances in neural information processing systems 33 (2020): 1877-1901.
215
- embed_dim: Embedding dimension.
216
- ffn_embed_dim: Feed forward embedding dimension.
217
- num_layers: Number of attention blocks.
218
- positional_embedding: Type of positional embedding to use before the first
219
- attention layer. Options: "learned", "learned_standard", "sinusoidal" or
220
- None.
221
- NOTE: "learned" is the positional embedding of ESM, and "learned_standard"
222
- is a more standard one, used for example in DNAbert.
223
- lm_head: type of language model head. Options: "simple", "roberta" or None.
224
- add_bias_kv: Add bias in attention layer.
225
- add_bias_ffn: Add bias in feed forward network block.
226
- use_rotary_embedding: Whether to use rotary embeddings (for ESM2). Requires:
227
- positional_embeddings = None.
228
- rescaling_factor: Scaling factor to use for rotary embeddings.
229
- ffn_activation_name: Activation function to be used in FFN block. Supported
230
- names are "gelu", "relu", "swish".
231
- use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
232
- Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
233
- to True and use swish as ffn_activation_name.
234
- Same principle for a gated-relu. To keep the same number of parameters in
235
- the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
236
- See https://arxiv.org/pdf/2002.05202.pdf for more details.
237
- mask_before_attention: Use mask before attention layers (for ESM1b and ESM2).
238
- layer_norm_eps: the eps factor in the different layer norms of the model (refer
239
- to layer norm implementation)
240
- token_dropout: Token dropout.
241
- masking_ratio: Masking ratio (used if token dropout is enabled).
242
- masking_prob: Masking probability (used if token dropout is enabled).
243
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
244
- gradients in the forward pass to reduce the computation in the backward).
245
- """
246
-
247
- alphabet_size: int
248
- pad_token_id: int
249
- mask_token_id: int
250
-
251
- max_positions: int = 1024
252
- embed_scale: float = 1.0
253
-
254
- # architecture
255
- emb_layer_norm_before: bool = False
256
- attention_heads: int = 20
257
- key_size: Optional[int] = None
258
- embed_dim: int = 1280
259
- ffn_embed_dim: int = 5120
260
- num_layers: int = 24
261
- positional_embedding: Optional[str] = "learned"
262
- lm_head: Optional[str] = "simple"
263
- add_bias_kv: bool = False
264
- add_bias_ffn: bool = True
265
- use_rotary_embedding: bool = False
266
- rescaling_factor: Optional[float] = None
267
- ffn_activation_name: str = "gelu-no-approx"
268
- use_glu_in_ffn: bool = False
269
- mask_before_attention: bool = False
270
- layer_norm_eps: float = 1e-5
271
- pre_layer_norm: bool = True
272
- bias_word_embedding: bool = False
273
-
274
- # dropout
275
- token_dropout: bool = False
276
- masking_ratio: float = 0.1
277
- masking_prob: float = 0.8
278
-
279
- # logging
280
- use_gradient_checkpointing: bool = False
281
-
282
- # return
283
- embeddings_layers_to_save: List[int] = field(default_factory=list)
284
- attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)
285
-
286
- def __post_init__(self) -> None:
287
- """
288
- Checks that the given values are compatible.
289
- """
290
-
291
- if self.key_size is None:
292
- if not self.embed_dim % self.attention_heads == 0:
293
- raise ValueError(
294
- f"When no key size is provided, the embedding dimension should be "
295
- f"divisible by the number of heads, however provided embedding "
296
- f"dimension is {self.embed_dim} and the number of heads is "
297
- f"{self.attention_heads}."
298
- )
299
- self.key_size = self.embed_dim // self.attention_heads
300
- if self.positional_embedding is not None:
301
- if type(self.positional_embedding) != str:
302
- raise TypeError
303
-
304
- if self.positional_embedding not in [
305
- "learned",
306
- "sinusoidal",
307
- "learned_standard",
308
- "alibi_dnabert_2",
309
- ]:
310
- raise ValueError(
311
- "The positional_embedding argument should either be None,"
312
- "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
313
- )
314
- if self.lm_head is not None:
315
- if type(self.lm_head) != str:
316
- raise TypeError
317
-
318
- if self.lm_head not in ["simple", "roberta"]:
319
- raise ValueError(
320
- "The lm_head argument should either be None,"
321
- "`simple` or `roberta`."
322
- )
323
-
324
- if self.use_rotary_embedding and self.positional_embedding is not None:
325
- raise ValueError(
326
- "When using rotary embedding, positional_embedding must be set to none"
327
- )
328
-
329
- if self.add_bias_kv and self.use_rotary_embedding:
330
- raise ValueError(
331
- "Biases on key and values are not compatible with Rotary embeddings."
332
- )
333
-
334
- if self.positional_embedding == "alibi_dnabert_2":
335
- assert not self.add_bias_kv
336
-
337
-
338
- @dataclass
339
- class ChatNTConfig(PretrainedConfig):
340
- model_type = "ChatNT"
341
-
342
- def __init__(self, **kwargs): # type: ignore
343
- self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
344
- self.esm_config: ESMTransformerConfig = kwargs.get(
345
- "esm_config", ESMTransformerConfig(4000, 1, 4)
346
- )
347
- self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
348
- "perceiver_resampler_config", PerceiverResamplerConfig()
349
- )
350
- self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
351
- self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
352
- self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
353
- super().__init__(**kwargs)
354
-
355
- def to_dict(self): # type: ignore
356
- output = super().to_dict()
357
-
358
- def serialize(obj): # type: ignore
359
- return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
360
-
361
- output["gpt_config"] = serialize(self.gpt_config) # type: ignore
362
- output["esm_config"] = serialize(self.esm_config) # type: ignore
363
- output["perceiver_resampler_config"] = serialize( # type: ignore
364
- self.perceiver_resampler_config
365
- )
366
- return output
367
-
368
-
369
- class TorchBioBrainDecoder(nn.Module):
370
- def __init__(
371
- self,
372
- gpt_config: GptConfig,
373
- seq_token_id: int,
374
- ):
375
- """
376
- Initializes the BioBrain decoder, using a GPT model for text generation with
377
- bio embeddings.
378
-
379
- Args:
380
- gpt_config: Configuration for the GPT model
381
- seq_token_id: Index of the SEQ token
382
- """
383
- super(TorchBioBrainDecoder, self).__init__()
384
- self.gpt_config = gpt_config
385
- self.seq_token_id = seq_token_id
386
-
387
- # Initialize the GPT model (assumed you have it already in PyTorch)
388
- self.gpt_model = TorchGptDecoder(self.gpt_config)
389
-
390
- def forward(
391
- self, english_token_ids: torch.Tensor, projected_bio_embeddings: torch.Tensor
392
- ) -> torch.Tensor:
393
- """
394
- Forward pass through the model.
395
-
396
- Args:
397
- english_token_ids: Tensor of English token IDs with shape
398
- (batch_size, num_english_tokens).
399
- projected_bio_embeddings: Optional tensor of bio embeddings with shape
400
- (batch_size, num_bio_sequences, ?, embed_dim).
401
-
402
- Returns:
403
- torch.Tensor: The logits from the GPT model,
404
- shaped (batch_size, num_english_tokens, vocab_size).
405
- """
406
-
407
- # Compute English token embeddings
408
- tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
409
-
410
- if projected_bio_embeddings is not None:
411
- (
412
- batch_size,
413
- num_bio_sequences,
414
- _,
415
- bio_embed_dim,
416
- ) = projected_bio_embeddings.shape
417
-
418
- # Insert the bio embeddings at the SEQ token positions
419
- processed_tokens_ids = english_token_ids.clone()
420
- for bio_seq_num in range(num_bio_sequences):
421
- tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
422
- processed_tokens_ids,
423
- tokens_embeddings,
424
- projected_bio_embeddings[:, bio_seq_num, :, :],
425
- bio_seq_num=bio_seq_num,
426
- )
427
-
428
- # Regular GPT pass through
429
- embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
430
- embeddings = self.gpt_model.final_norm(embeddings)
431
-
432
- # Compute logits
433
- logits = self.gpt_model.lm_head(embeddings)
434
-
435
- if projected_bio_embeddings is not None:
436
- # Clean logits sequentially
437
- processed_tokens_ids = english_token_ids.clone()
438
- resampled_length = projected_bio_embeddings.shape[-2]
439
- for _ in range(num_bio_sequences):
440
- logits, processed_tokens_ids = self.cleanup_logits(
441
- tokens=processed_tokens_ids,
442
- logits=logits,
443
- resampled_length=resampled_length,
444
- )
445
-
446
- return logits
447
-
448
- def insert_embeddings(
449
- self,
450
- tokens: torch.Tensor,
451
- input_embeddings: torch.Tensor,
452
- resampled_embeddings: torch.Tensor,
453
- bio_seq_num: int,
454
- ) -> Tuple[torch.Tensor, torch.Tensor]:
455
- """
456
- Inserts resampled embeddings in input_embeddings, starting at the SEQ token
457
-
458
- Args:
459
- tokens (torch.Tensor): Shape (batch_size, num_tokens)
460
- input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
461
- resampled_embeddings (torch.Tensor):
462
- Shape (batch_size, bio_sequence_length, embed_dim)
463
-
464
- Returns:
465
- Tuple[torch.Tensor, torch.Tensor]:
466
- - input_embeddings with resampled_embeddings inserted at the SEQ token
467
- - tokens with the SEQ token set to -1
468
- """
469
-
470
- def _insert(
471
- tokens_1d: torch.Tensor,
472
- input_embeddings_1d: torch.Tensor,
473
- resampled_embeddings_1d: torch.Tensor,
474
- ) -> Tuple[torch.Tensor, torch.Tensor]:
475
- """
476
- Args:
477
- tokens (torch.Tensor): Shape (num_tokens,)
478
- input_embeddings (torch.Tensor): Shape (num_tokens, embed_dim,)
479
- resampled_embeddings (torch.Tensor):
480
- Shape (bio_sequence_length, embed_dim,)
481
- """
482
- indices = torch.where(tokens_1d == self.seq_token_id)[0]
483
- if indices.numel() > 0:
484
- idx = indices[0].item()
485
- insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
486
- x = torch.cat(
487
- [
488
- input_embeddings_1d[:insertion_pos, :],
489
- resampled_embeddings_1d,
490
- input_embeddings_1d[insertion_pos:, :],
491
- ],
492
- dim=0,
493
- )[: tokens_1d.shape[0] + 1, :]
494
- x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[
495
- :-1, :
496
- ]
497
- tokens_1d[idx] = -1
498
- return x, tokens_1d
499
- else:
500
- return (
501
- input_embeddings_1d,
502
- tokens_1d,
503
- ) # Return unchanged if seq_token_id is not found
504
-
505
- tokens_acc = []
506
- embeddings_acc = []
507
-
508
- for i in range(tokens.shape[0]):
509
- embeddings_out, tokens_out = _insert(
510
- tokens[i].clone(),
511
- input_embeddings[i].clone(),
512
- resampled_embeddings[i].clone(),
513
- )
514
- tokens_acc.append(tokens_out)
515
- embeddings_acc.append(embeddings_out)
516
-
517
- tokens_acc = torch.stack(tokens_acc)
518
- embeddings_acc = torch.stack(embeddings_acc)
519
-
520
- return embeddings_acc, tokens_acc
521
-
522
- def cleanup_logits(
523
- self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
524
- ) -> Tuple[torch.Tensor, torch.Tensor]:
525
- """
526
- Removes the logits corresponding to the unused embeddings.
527
-
528
- Args:
529
- tokens: Input english tokens.
530
- logits: Input logits.
531
-
532
- Returns:
533
- Cleaned logits, last values will be equal to 0.
534
- """
535
-
536
- def _clean(
537
- token: torch.Tensor, logit: torch.Tensor
538
- ) -> Tuple[torch.Tensor, torch.Tensor]:
539
- indices = torch.where(token == self.seq_token_id)[0]
540
- if indices.numel() > 0:
541
- idx = indices[0].item()
542
-
543
- mask_idx = (
544
- torch.arange(logit.shape[0] - resampled_length, device=logit.device)
545
- > idx
546
- )
547
- mask_idx = mask_idx.unsqueeze(1)
548
-
549
- # Remove values corresponding to bio tokens
550
- logit = (
551
- logit[:-resampled_length] * (~mask_idx)
552
- + logit[resampled_length:] * mask_idx
553
- )
554
-
555
- # Append zeros at the end
556
- logit = torch.cat(
557
- (
558
- logit,
559
- torch.zeros(
560
- (resampled_length, logit.shape[1]),
561
- dtype=logit.dtype,
562
- device=logit.device,
563
- ),
564
- )
565
- )
566
-
567
- # Update token
568
- token[idx] = -1
569
-
570
- return logit, token
571
-
572
- else:
573
- return logit, token
574
-
575
- tokens_acc = []
576
- logits_acc = []
577
-
578
- for i in range(tokens.shape[0]):
579
- logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone())
580
- tokens_acc.append(tokens_out)
581
- logits_acc.append(logits_out)
582
- tokens_acc = torch.stack(tokens_acc)
583
- logits_acc = torch.stack(logits_acc)
584
-
585
- return logits_acc, tokens_acc
586
-
587
-
588
- class TorchMultiOmicsModel(PreTrainedModel):
589
- config_class = ChatNTConfig
590
-
591
- def __init__(self, config: ChatNTConfig) -> None:
592
- if isinstance(config, dict):
593
- # If config is a dictionary instead of ChatNTConfig (which can happen
594
- # depending how the config was saved), we convert it to the config
595
- config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
596
- **config["gpt_config"]["rope_config"]
597
- )
598
- config["gpt_config"] = GptConfig(**config["gpt_config"])
599
- config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
600
- config["perceiver_resampler_config"] = PerceiverResamplerConfig(
601
- **config["perceiver_resampler_config"]
602
- )
603
- config = ChatNTConfig(**config) # type: ignore
604
-
605
- else:
606
- if isinstance(config.gpt_config, dict):
607
- config.gpt_config["rope_config"] = RotaryEmbeddingConfig(
608
- **config.gpt_config["rope_config"]
609
- )
610
- config.gpt_config = GptConfig(**config.gpt_config)
611
-
612
- if isinstance(config.esm_config, dict):
613
- config.esm_config = ESMTransformerConfig(**config.esm_config)
614
-
615
- if isinstance(config.perceiver_resampler_config, dict):
616
- config.perceiver_resampler_config = PerceiverResamplerConfig(**config.perceiver_resampler_config)
617
-
618
- super().__init__(config=config)
619
- self.gpt_config = config.gpt_config
620
- self.esm_config = config.esm_config
621
- self.perceiver_resampler_config = config.perceiver_resampler_config
622
- self.seq_token_id = config.seq_token_id
623
- self.bio_pad_token_id = config.bio_pad_token_id
624
- self.english_pad_token_id = config.english_pad_token_id
625
-
626
- # Correct seq_token_id
627
- self.seq_token_id -= 1
628
-
629
- self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
630
- self.biobrain_decoder = TorchBioBrainDecoder(
631
- gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
632
- )
633
- self.projection_model = TorchMultiModalPerceiverResamplerProjection(
634
- perceiver_resampler_config=self.perceiver_resampler_config,
635
- input_embed_dim=self.esm_config.embed_dim,
636
- embed_dim=self.gpt_config.embed_dim,
637
- english_vocab_size=self.gpt_config.vocab_size,
638
- bio_pad_token_id=self.bio_pad_token_id,
639
- english_pad_token_id=self.english_pad_token_id,
640
- )
641
-
642
- def forward(
643
- self,
644
- multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
645
- projection_english_tokens_ids: torch.Tensor,
646
- projected_bio_embeddings: torch.Tensor = None,
647
- ) -> dict[str, torch.Tensor]:
648
- """
649
-
650
- Args:
651
- multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
652
- english_tokens_ids: Represents the prompt tokens (english tokens)
653
- Shape (batch_size, num_english_tokens)
654
-
655
- bio_tokens_ids: Represents the bio sequences tokens
656
- Shape (batch_size, num_bio_sequences, num_bio_tokens)
657
-
658
- projection_english_tokens_ids (torch.Tensor):
659
- Shape (batch_size, num_english_tokens)
660
-
661
- projected_bio_embeddings (projected_bio_embeddings, optional):
662
- Shape (batch_size, num_bio_sequences, ?, embed_dim).
663
- Defaults to None.
664
-
665
- Returns:
666
- dict[str, torch.Tensor] containing:
667
- - logits:
668
- Shape (batch_size, num_tokens, vocab_size)
669
-
670
- - projected_bio_embeddings:
671
- Shape (batch_size, num_bio_sequences, ?, embed_dim)
672
- """
673
- english_token_ids, bio_token_ids = multi_omics_tokens_ids
674
- english_token_ids = english_token_ids.clone()
675
- bio_token_ids = bio_token_ids.clone()
676
- projection_english_tokens_ids = projection_english_tokens_ids.clone()
677
- if projected_bio_embeddings is not None:
678
- projected_bio_embeddings = projected_bio_embeddings.clone()
679
-
680
- # Remap token ids that fall outside the GPT vocabulary.
681
- # The default vocab size (32000) does not account for the extra seq_token_id
682
- # (=32000) that was appended to the English vocabulary.
683
- # Therefore, seq_token_id is remapped to 31999, and token 31999 is remapped
684
- # to 0, the unknown token.
685
- # This is a workaround that avoids changing the vocab size in the config.
686
- vocab_size = self.gpt_config.vocab_size
687
- # Replace vocab
688
- english_token_ids[english_token_ids == vocab_size - 1] = 0
689
- projection_english_tokens_ids[
690
- projection_english_tokens_ids == vocab_size - 1
691
- ] = 0
692
- english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
693
- projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
694
- vocab_size - 1
695
- )
696
-
697
- outs = {}
698
- if bio_token_ids is None:
699
- projected_bio_embeddings = None
700
- else:
701
- num_bio_sequences = bio_token_ids.shape[1]
702
-
703
- if projected_bio_embeddings is None:
704
- # Compute bio sequences embeddings
705
- bio_embeddings_list = [
706
- self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
707
- for bio_seq_num in range(num_bio_sequences)
708
- ]
709
-
710
-
711
- # Project these embeddings
712
- projected_bio_embeddings = []
713
- print("(debug) remember to remove loop for projected")
714
- for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list):
715
- proj, output = self.projection_model(
716
- bio_token_ids=bio_token_ids[:, bio_seq_num],
717
- bio_embeddings=bio_embeddings,
718
- english_token_ids=projection_english_tokens_ids,
719
- )
720
- projected_bio_embeddings.append(proj)
721
- for key in output.keys():
722
- outs[f"{key}_{bio_seq_num}"] = output[key]
723
- outs[f"bio_embeddings_list_{bio_seq_num}"] = proj
724
-
725
-
726
- projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
727
- outs["projected_bio_embeddings"] = projected_bio_embeddings.clone()
728
-
729
- # decode
730
- logits = self.biobrain_decoder(
731
- english_token_ids=english_token_ids,
732
- projected_bio_embeddings=projected_bio_embeddings,
733
- )
734
-
735
- outs["logits"] = logits
736
- outs["projected_bio_embeddings_end"] = projected_bio_embeddings.clone()
737
-
738
- return outs
739
-
740
-
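# Illustrative sketch (not from the original file) of the token-id remapping done at the
# start of forward(), shown on a toy vocabulary of size 8 where id 8 is the appended <SEQ> token:
import torch
toy_vocab_size = 8
toy_ids = torch.tensor([[3, 7, 8, 5]])
toy_ids[toy_ids == toy_vocab_size - 1] = 0               # reuse id 7 as the unknown token
toy_ids[toy_ids == toy_vocab_size] = toy_vocab_size - 1  # map the <SEQ> id back into the vocab
print(toy_ids)                                           # tensor([[3, 0, 7, 5]])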
741
- class TorchRotaryEmbedding(torch.nn.Module):
742
- def __init__(self, config: RotaryEmbeddingConfig):
743
- super().__init__()
744
-
745
- self.max_seq_len = config.max_seq_len
746
- self.dim = config.dim
747
- self.theta = config.theta
748
- self.sincos_cache = self._create_sinusoidal_positions()
749
-
750
- def _create_sinusoidal_positions(self) -> torch.Tensor:
751
- """
752
- Create the sines and cosines for the RoPE.
753
-
754
- Returns:
755
- Sinusoidal positions of shape (self.max_seq_len, self.dim).
756
- """
757
- # Create the inverse frequency based on theta and dim
758
- inv_freq = 1.0 / (
759
- self.theta ** (torch.arange(0, self.dim, 2).float() / self.dim)
760
- )
761
-
762
- # Compute sinusoidal input using the broadcasting
763
- sinusoid_inp = torch.einsum(
764
- "i,j->ij", torch.arange(self.max_seq_len).float(), inv_freq
765
- )
766
-
767
- # Apply sin and cos to the sinusoidal input
768
- sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
769
-
770
- # Allocate a tensor for the final sin-cos values
771
- sincos = torch.zeros((self.max_seq_len, self.dim), dtype=torch.float32)
772
-
773
- # Fill the sincos tensor with sin and cos values
774
- sentinel = self.dim // 2 + self.dim % 2
775
- sincos[:, :sentinel] = sin
776
- sincos[:, sentinel:] = cos
777
-
778
- return sincos
779
-
780
- def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
781
- """
782
- Prepare a tensor to apply the RoPE mechanism.
783
-
784
- Args:
785
- x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
786
- typically this is the key or query tensor.
787
-
788
- Returns:
789
- Each adjacent pair (x0, x1) in the last dimension is mapped to (-x1, x0).
790
- Tensor of shape (batch_size, seq_len, num_heads, head_dim).
791
- """
792
- # Split the tensor into two halves (odd and even indexed dimensions)
793
- rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)
794
-
795
- # Reshape the tensor to the original shape
796
- rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
797
- return rotate_half
798
-
799
- def _apply_rotary_pos_emb(
800
- self, x: torch.Tensor, sincos: torch.Tensor
801
- ) -> torch.Tensor:
802
- """
803
- Applies rotary embeddings to x.
804
-
805
- Args:
806
- x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
807
- typically this is the key or query tensor.
808
- sincos: Tuple of sine and cosine tensors for position encoding.
809
-
810
- Returns:
811
- RoPE embeddings tensor.
812
- """
813
- sin_pos, cos_pos = sincos
814
-
815
- # Reshape the sin and cos tensors for broadcasting
816
- sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
817
- cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)
818
-
819
- # Apply the rotary embedding mechanism
820
- return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)
821
-
822
- def __call__(
823
- self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
824
- ) -> tuple[torch.Tensor, torch.Tensor]:
825
- """
826
- Applies rotary embeddings to k and q.
827
-
828
- Args:
829
- k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
830
- q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
831
- positions: optional positions offset useful when caching,
832
-
833
- Returns:
834
- RoPE embeddings for the keys and queries.
835
- """
836
- batch_size, seq_len, num_heads, head_dim = k.shape
837
-
838
- # Generate position ids
839
- position_ids = (
840
- torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
841
- )
842
-
843
- if positions is not None:
844
- position_ids += positions
845
-
846
- # Retrieve sincos values using the position_ids
847
- sincos = self.sincos_cache[position_ids]
848
-
849
- # Split sincos into sin_pos and cos_pos
850
- sincos = torch.chunk(sincos, 2, dim=-1)
851
-
852
- # Apply rotary position embedding to key (k) and query (q)
853
- k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
854
- k_pass = k[..., self.dim :]
855
-
856
- q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)
857
- q_pass = q[..., self.dim :]
858
-
859
- # Concatenate the rotated and non-rotated parts
860
- keys = torch.cat([k_rot, k_pass], dim=-1)
861
- queries = torch.cat([q_rot, q_pass], dim=-1)
862
-
863
- return keys, queries
864
-
865
-
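# Illustrative sketch (not from the original file): _rotate_every_two maps each adjacent
# pair (x0, x1) to (-x1, x0); combined with the interleaved sin/cos cache this implements
# the 2-D rotations of RoPE. Toy values:
import torch
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
rotated = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1).view(x.shape[:-1] + (-1,))
print(rotated)  # tensor([[-2., 1., -4., 3.]])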
866
- class TorchGptGroupedQueryAttention(nn.Module):
867
- def __init__(
868
- self,
869
- embed_dim: int,
870
- num_heads: int,
871
- rope_config: RotaryEmbeddingConfig,
872
- num_kv_heads: int = None, # type: ignore
873
- head_dim: int = None, # type: ignore
874
- add_bias_attn: bool = False, # type: ignore
875
- ) -> None:
876
- super().__init__()
877
- self.num_heads = num_heads
878
- self.num_kv_heads = num_kv_heads or num_heads
879
- self.embed_dim = embed_dim
880
- self.head_dim = head_dim or (embed_dim // num_heads)
881
- self.add_bias_attn = add_bias_attn
882
- self.rope = TorchRotaryEmbedding(rope_config)
883
-
884
- self.query_linear = nn.Linear(
885
- embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
886
- )
887
- self.key_linear = nn.Linear(
888
- embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
889
- )
890
- self.value_linear = nn.Linear(
891
- embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
892
- )
893
- self.out_linear = nn.Linear(
894
- self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
895
- )
896
-
897
- def forward(
898
- self,
899
- query_inputs: torch.Tensor,
900
- key_inputs: torch.Tensor,
901
- value_inputs: torch.Tensor,
902
- attention_mask: torch.Tensor = None,
903
- ) -> torch.Tensor:
904
- batch_size, seq_len, _ = query_inputs.shape
905
-
906
- queries = self.query_linear(query_inputs).view( # noqa
907
- batch_size, seq_len, self.num_heads, self.head_dim
908
- )
909
- keys = self.key_linear(key_inputs).view( # noqa
910
- batch_size, seq_len, self.num_kv_heads, self.head_dim
911
- )
912
- values = self.value_linear(value_inputs).view( # noqa
913
- batch_size, seq_len, self.num_kv_heads, self.head_dim
914
- )
915
-
916
- keys, queries = self.rope(keys, queries)
917
-
918
- n_rep = self.num_heads // self.num_kv_heads
919
- keys = keys.repeat_interleave(n_rep, dim=2)
920
- values = values.repeat_interleave(n_rep, dim=2)
921
-
922
- attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
923
- self.head_dim**0.5
924
- )
925
-
926
- if attention_mask is not None:
927
- attention_logits = attention_logits.masked_fill(
928
- attention_mask == 0, float("-inf")
929
- )
930
-
931
- attention_weights = nn.functional.softmax(attention_logits, dim=-1)
932
-
933
- values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
934
- values = values.contiguous().view(batch_size, seq_len, -1)
935
-
936
- return self.out_linear(values)
937
-
938
-
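# Illustrative sketch (not from the original file): with grouped-query attention each
# key/value head is shared by num_heads // num_kv_heads query heads, which is what the
# repeat_interleave in the forward above implements. Toy shapes:
import torch
batch, seq, num_heads, num_kv_heads, head_dim = 1, 3, 4, 2, 8
keys = torch.randn(batch, seq, num_kv_heads, head_dim)
n_rep = num_heads // num_kv_heads                 # 2 query heads per kv head
expanded = keys.repeat_interleave(n_rep, dim=2)   # (1, 3, 4, 8)
assert torch.equal(expanded[:, :, 0], expanded[:, :, 1])  # query heads 0 and 1 share kv head 0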
939
- class TorchGptDecoder(nn.Module):
940
- def __init__(self, config: GptConfig, name: Optional[str] = None):
941
- super().__init__()
942
- self.config = config
943
-
944
- self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)
945
-
946
- if config.norm_type == "layer_norm":
947
- self.final_norm = nn.LayerNorm(config.embed_dim)
948
- elif config.norm_type == "RMS_norm":
949
- self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
950
- else:
951
- raise ValueError(f"unrecognized norm_type in config {config.norm_type}")
952
-
953
- self.layers = nn.ModuleList(
954
- [
955
- TorchGptDecoderLayer(
956
- embed_dim=config.embed_dim,
957
- ffn_embed_dim=config.ffn_embed_dim,
958
- num_heads=config.num_heads,
959
- rope_config=config.rope_config,
960
- norm_type=config.norm_type,
961
- parallel_attention_ff=config.parallel_attention_ff,
962
- add_bias_ffn=config.add_bias_ffn,
963
- ffn_activation_name=config.ffn_activation_name,
964
- use_glu_in_ffn=config.use_glu_in_ffn,
965
- num_kv_heads=config.num_kv_heads, # type: ignore
966
- add_bias_attn=config.add_bias_attn,
967
- rms_norm_eps=config.rms_norm_eps,
968
- )
969
- for _ in range(config.num_layers)
970
- ]
971
- )
972
-
973
- self.lm_head = TorchSimpleLMHead(
974
- embed_dim=config.embed_dim,
975
- alphabet_size=config.vocab_size,
976
- add_bias_lm_head=config.add_bias_lm_head,
977
- )
978
-
979
- def apply_transformer_layers(
980
- self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
981
- ) -> torch.Tensor:
982
- if attention_mask is None:
983
- attention_mask = build_causal_attention_mask(1, embeddings.shape[1]).to(embeddings.device)
984
- for layer in self.layers:
985
- embeddings = layer(embeddings, attention_mask)
986
-
987
- return embeddings
988
-
989
- def forward(
990
- self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
991
- ) -> dict[str, torch.Tensor]:
992
- if attention_mask is None:
993
- attention_mask = build_causal_attention_mask(1, token_ids.shape[1]).to(token_ids.device)
994
-
995
- tokens_embeddings = self.token_embed(token_ids)
996
-
997
- after_transformer_embeddings = self.apply_transformer_layers(
998
- tokens_embeddings, attention_mask=attention_mask
999
- )
1000
-
1001
- embeddings = self.final_norm(after_transformer_embeddings)
1002
- logits = self.lm_head(embeddings)
1003
- return {"embeddings": embeddings, "logits": logits}
1004
-
1005
-
1006
- class TorchSimpleLMHead(nn.Module):
1007
- def __init__(
1008
- self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
1009
- ) -> None:
1010
- super().__init__()
1011
- self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head)
1012
-
1013
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1014
- return self.fc(x)
1015
-
1016
-
1017
- class TorchGptDecoderLayer(nn.Module):
1018
- def __init__(
1019
- self,
1020
- embed_dim: int,
1021
- ffn_embed_dim: int,
1022
- num_heads: int,
1023
- rope_config: RotaryEmbeddingConfig,
1024
- norm_type: str,
1025
- parallel_attention_ff: bool,
1026
- add_bias_ffn: bool,
1027
- ffn_activation_name: str,
1028
- use_glu_in_ffn: bool,
1029
- num_kv_heads: int,
1030
- add_bias_attn: bool,
1031
- rms_norm_eps: float = 1e-6,
1032
- ) -> None:
1033
- super().__init__()
1034
- self.num_heads = num_heads
1035
- self.parallel_attention_ff = parallel_attention_ff
1036
- self.use_glu_in_ffn = use_glu_in_ffn
1037
-
1038
- # Self-Attention layer
1039
- self.self_attn = TorchGptGroupedQueryAttention(
1040
- embed_dim=embed_dim,
1041
- num_heads=num_heads,
1042
- num_kv_heads=num_kv_heads,
1043
- rope_config=rope_config,
1044
- add_bias_attn=add_bias_attn,
1045
- )
1046
-
1047
- # Normalization layers
1048
- if norm_type == "layer_norm":
1049
- self.attn_norm = nn.LayerNorm(embed_dim)
1050
- if not self.parallel_attention_ff:
1051
- self.ffn_norm = nn.LayerNorm(embed_dim)
1052
- elif norm_type == "RMS_norm":
1053
- self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1054
- if not self.parallel_attention_ff:
1055
- self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1056
- else:
1057
- raise ValueError(f"unrecognized norm_type: {norm_type}")
1058
-
1059
- # Feedforward network
1060
- self.activation = get_activation_fn(ffn_activation_name)
1061
- ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1)
1062
- self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn)
1063
- self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1064
-
1065
- def forward(
1066
- self, embeddings: torch.Tensor, attention_mask: torch.Tensor
1067
- ) -> torch.Tensor:
1068
- residuals = embeddings
1069
-
1070
- if self.parallel_attention_ff:
1071
- # Parallel Attention + MLP
1072
- embeddings_normed = self.attn_norm(embeddings)
1073
-
1074
- attn_output = self.self_attn(
1075
- embeddings_normed,
1076
- embeddings_normed,
1077
- embeddings_normed,
1078
- attention_mask=attention_mask,
1079
- )
1080
- ffn_output = self.mlp(embeddings_normed) # type: ignore
1081
-
1082
- return residuals + attn_output + ffn_output
1083
- else:
1084
- # Sequential Attention + MLP
1085
- normed_embeddings = self.attn_norm(embeddings)
1086
-
1087
- attn_output = embeddings + self.self_attn(
1088
- normed_embeddings,
1089
- normed_embeddings,
1090
- normed_embeddings,
1091
- attention_mask=attention_mask,
1092
- )
1093
-
1094
- normed_embeddings2 = self.ffn_norm(attn_output)
1095
- ffn_output = self.mlp(normed_embeddings2) # type: ignore
1096
- return attn_output + ffn_output # Residual connection
1097
-
1098
- def mlp(self, x: torch.Tensor) -> torch.Tensor:
1099
- """Applies the feedforward network (MLP) with optional GLU."""
1100
- ffn_output = self.fc1(x)
1101
-
1102
- if self.use_glu_in_ffn:
1103
- ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1)
1104
- ffn_output = self.activation(ffn_output1) * ffn_output2
1105
- else:
1106
- ffn_output = self.activation(ffn_output)
1107
-
1108
- return self.fc2(ffn_output)
1109
-
1110
-
1111
- class TorchRMSNorm(nn.Module):
1112
- def __init__(self, dim: int, eps: float = 1e-6) -> None:
1113
- super().__init__()
1114
- self.eps = eps
1115
- self.scale = nn.Parameter(torch.ones(dim))
1116
-
1117
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1118
- return (
1119
- x
1120
- * self.scale
1121
- / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
1122
- )
1123
-
1124
-
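# Illustrative sketch (not from the original file), assuming TorchRMSNorm above is in scope:
# with the scale initialised to ones, the layer simply divides by the root-mean-square
# of the features.
import torch
norm = TorchRMSNorm(dim=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + norm.eps)
assert torch.allclose(norm(x), x / rms)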
1125
- def get_activation_fn(activation_name: str): # type: ignore
1126
- activations = {
1127
- "gelu": nn.functional.gelu,
1128
- "relu": nn.functional.relu,
1129
- "swish": nn.functional.silu,
1130
- "silu": nn.functional.silu,
1131
- }
1132
- return activations.get(activation_name, nn.functional.relu)
1133
-
1134
-
1135
- def build_causal_attention_mask(batch_size: int, seq_len: int) -> torch.Tensor:
1136
- """
1137
- Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
1138
- to an attention layer.
1139
-
1140
- Args:
1141
- batch_size: Batch size.
1142
- seq_len: Length of the sequences.
1143
-
1144
- Returns:
1145
- Batch of causal masks.
1146
- """
1147
- mask = torch.ones((batch_size, 1, seq_len, seq_len))
1148
- causal_mask = torch.tril(mask)
1149
- return causal_mask
1150
-
1151
-
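# Illustrative sketch (not from the original file), assuming the helper above is in scope:
# the returned mask is lower triangular, so position t can only attend to positions <= t.
import torch
mask = build_causal_attention_mask(batch_size=1, seq_len=4)
print(mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])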
1152
- @dataclass
1153
- class RotaryEmbeddingConfigBis:
1154
- """
1155
- Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
1156
- the rotary embeddings to be adapted to longer lengths than those used for training.
1157
- One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
1158
- Args:
- rescaling_factor: Optional rescaling factor applied to the rotary frequency base
- to extrapolate to longer sequences.
1159
- """
1160
-
1161
- rescaling_factor: Optional[float]
1162
-
1163
-
1164
- class RotaryEmbeddingBis(torch.nn.Module):
1165
- """
1166
- Rotary position embeddings based on those in
1167
- [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
1168
- Query and keys are transformed by rotation
1169
- matrices which depend on their relative positions.
1170
- """
1171
-
1172
- def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
1173
- super().__init__()
1174
-
1175
- # Extract argument from the config
1176
- self.rescaling_factor = rotary_embedding_config.rescaling_factor
1177
- self.upper_freq = 10000
1178
- self.dim = dim
1179
-
1180
- self._seq_len_cached = None
1181
- self._cos_cached = None
1182
- self._sin_cached = None
1183
-
1184
- def _apply_rotary_pos_emb(
1185
- self,
1186
- heads: torch.Tensor,
1187
- cos: torch.Tensor,
1188
- sin: torch.Tensor,
1189
- ) -> torch.Tensor:
1190
- """ """
1191
- x_first, x_second = (
1192
- heads[..., : heads.shape[-1] // 2],
1193
- heads[..., heads.shape[-1] // 2 :],
1194
- )
1195
-
1196
- first_part = x_first * cos - x_second * sin
1197
- second_part = x_second * cos + x_first * sin
1198
-
1199
- return torch.cat((first_part, second_part), dim=-1)
1200
-
1201
- def _compute_cos_sin_tables(
1202
- self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1203
- ) -> tuple[torch.Tensor, torch.Tensor]:
1204
- seq_len = x.shape[seq_dimension]
1205
- # Reset the tables if the sequence length has changed,
1206
- # or if we're on a new device (possibly due to tracing for instance)
1207
- self._seq_len_cached = seq_len
1208
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1209
- # freqs = torch.outer(t, inv_freq)
1210
- freqs = torch.einsum("i, j -> ij", t, inv_freq)
1211
-
1212
- self._cos_cached = torch.cos(freqs)[None, :, None, :]
1213
- self._sin_cached = torch.sin(freqs)[None, :, None, :]
1214
- # emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1215
-
1216
- # self._cos_cached = emb.cos()[None, None, :, :]
1217
- # self._sin_cached = emb.sin()[None, None, :, :]
1218
-
1219
- return self._cos_cached, self._sin_cached
1220
-
1221
- def forward(
1222
- self, q: torch.Tensor, k: torch.Tensor
1223
- ) -> Tuple[torch.Tensor, torch.Tensor]:
1224
- if self.rescaling_factor is None:
1225
- inv_freq = 1.0 / (
1226
- self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
1227
- )
1228
- else:
1229
- updated_base = self.upper_freq * (
1230
- self.rescaling_factor ** (self.dim / (self.dim - 2))
1231
- )
1232
- inv_freq = 1.0 / (
1233
- updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
1234
- )
1235
-
1236
- self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1237
- q,
1238
- inv_freq,
1239
- seq_dimension=-3,
1240
- )
1241
-
1242
- return (
1243
- self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1244
- self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1245
- )
1246
-
1247
-
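# Illustrative sketch (not from the original file): the rescaling_factor branch in
# forward() enlarges the frequency base, which lowers the rotation frequencies so that
# longer sequences stay within the angular range seen during training. Toy numbers:
import torch
dim, base, rescaling_factor = 64, 10000.0, 2.0
updated_base = base * rescaling_factor ** (dim / (dim - 2))
inv_freq = 1.0 / (updated_base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq_no_rescale = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
assert torch.all(inv_freq[1:] < inv_freq_no_rescale[1:])  # every non-constant frequency shrinks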
1248
- class MultiHeadAttention(nn.Module):
1249
- def __init__(
1250
- self,
1251
- num_heads: int,
1252
- key_size: int,
1253
- rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1254
- add_bias_kv: bool = False,
1255
- value_size: Optional[int] = None,
1256
- model_size: Optional[int] = None,
1257
- name: Optional[str] = None,
1258
- ):
1259
- super().__init__()
1260
- if not model_size:
1261
- model_size = key_size * num_heads
1262
- if not value_size:
1263
- value_size = key_size
1264
- self.model_size = model_size
1265
- self.key_size = key_size
1266
- self.value_size = value_size
1267
- self.add_bias_kv = add_bias_kv
1268
- self.name = name
1269
- self.num_heads = num_heads
1270
- self._rotary_embedding_config = rotary_embedding_config
1271
-
1272
- self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
1273
- self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
1274
- self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
1275
- self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
1276
- if self._rotary_embedding_config:
1277
- self._rotary_embedding = RotaryEmbeddingBis(
1278
- self.key_size, self._rotary_embedding_config
1279
- )
1280
-
1281
- def apply_rotary_embeddings(
1282
- self,
1283
- query: torch.Tensor,
1284
- key: torch.Tensor,
1285
- ) -> tuple[torch.Tensor, torch.Tensor]:
1286
- """ """
1287
- query, key = self._rotary_embedding(query, key)
1288
- return query, key
1289
-
1290
- def forward(
1291
- self,
1292
- query: torch.Tensor,
1293
- key: torch.Tensor,
1294
- value: torch.Tensor,
1295
- attention_mask: Optional[torch.Tensor] = None,
1296
- attention_weight_bias: Optional[torch.Tensor] = None,
1297
- ) -> dict[str, torch.Tensor]:
1298
- """
1299
- Returns:
1300
- dictionary containing attention weights
1301
- and outputs.
1302
- """
1303
- key_heads = self.w_k(key).reshape(
1304
- (*key.shape[:-1], self.num_heads, self.key_size)
1305
- )
1306
- query_heads = self.w_q(query).reshape(
1307
- (*query.shape[:-1], self.num_heads, self.key_size)
1308
- )
1309
- value_heads = self.w_v(value).reshape(
1310
- (*value.shape[:-1], self.num_heads, self.value_size)
1311
- )
1312
- if self._rotary_embedding_config:
1313
- query_heads, key_heads = self.apply_rotary_embeddings(
1314
- query_heads, key_heads
1315
- )
1316
- attention_weights = torch.einsum(
1317
- "...thd, ...Thd -> ...htT", query_heads, key_heads
1318
- )
1319
- sqrt_key_size = np.sqrt(self.key_size)
1320
- attention_weights = attention_weights / sqrt_key_size
1321
- if attention_mask is not None:
1322
- attention_weights = torch.where(attention_mask, attention_weights, -1e30)
1323
- if attention_weight_bias is not None:
1324
- attention_weights = F.softmax(
1325
- attention_weights + attention_weight_bias, dim=-1
1326
- )
1327
- else:
1328
- attention_weights = F.softmax(attention_weights, dim=-1)
1329
- value_out = torch.einsum(
1330
- "...htT, ...Thd->...thd", attention_weights, value_heads
1331
- )
1332
- value_out = value_out.reshape((*value_out.shape[:-2], -1))
1333
- embeddings = self.output(value_out)
1334
-
1335
- return {"attention_weights": attention_weights, "embeddings": embeddings}
1336
-
1337
-
1338
- class SelfAttentionBlock(nn.Module):
1339
- def __init__(
1340
- self,
1341
- num_heads: int,
1342
- embed_dim: int,
1343
- ffn_embed_dim: int,
1344
- key_size: Optional[int] = None,
1345
- add_bias_kv: bool = False,
1346
- add_bias_fnn: bool = True,
1347
- ffn_activation_name: str = "gelu-no-approx",
1348
- use_glu_in_ffn: bool = False,
1349
- layer_norm_eps: float = 1e-5, # this is the default haiku value
1350
- pre_layer_norm: bool = True,
1351
- name: Optional[str] = None,
1352
- rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1353
- ):
1354
- super().__init__()
1355
- if key_size is None:
1356
- if embed_dim % num_heads != 0:
1357
- raise ValueError(
1358
- f"The embedding dimension should be divisible by the number of "
1359
- f"heads, however provided embedding dimension is {embed_dim} and "
1360
- f"the number of heads is {num_heads}."
1361
- )
1362
- else:
1363
- key_size = embed_dim // num_heads
1364
-
1365
- # Get ffn activation function
1366
- self._pre_layer_norm = pre_layer_norm
1367
- self._use_glu_in_fnn = use_glu_in_ffn
1368
- # Define layers
1369
- if use_glu_in_ffn:
1370
- # user should multiply ffn_embed_dim by 2/3 when using GLU
1371
- # to keep total number of parameters equal
1372
- # see https://arxiv.org/pdf/2002.05202.pdf. for more details
1373
- # we multiply by 2 here as the output will be split in 2 for GLU
1374
- self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
1375
- else:
1376
- self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
1377
-
1378
- self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
1379
-
1380
- self.layer_norm_self_attention = nn.LayerNorm(
1381
- embed_dim, eps=layer_norm_eps,
1382
- )
1383
- self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
1384
- if ffn_activation_name == "swish":
1385
- self._ffn_activation_fn = nn.SiLU()
1386
- elif ffn_activation_name == "gelu-no-approx":
1387
- self._ffn_activation_fn = nn.GELU(approximate="none")
1388
- else:
1389
- self._ffn_activation_fn = getattr(torch.nn.functional, ffn_activation_name)
1390
-
1391
- self.mha = MultiHeadAttention(
1392
- num_heads=num_heads,
1393
- key_size=key_size,
1394
- add_bias_kv=add_bias_kv,
1395
- model_size=embed_dim,
1396
- name="self_attention",
1397
- rotary_embedding_config=rotary_embedding_config,
1398
- )
1399
-
1400
- def mlp(self, embed: torch.Tensor) -> torch.Tensor:
1401
-
1402
- if self._pre_layer_norm:
1403
- x = self.layer_norm_mlp(embed)
1404
- else:
1405
- x = embed
1406
-
1407
- if self._use_glu_in_fnn:
1408
- x = self.fc1(x)
1409
- x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
1410
- x = self._ffn_activation_fn(x1) * x2
1411
- else:
1412
- x = self._ffn_activation_fn(self.fc1(x))
1413
- x = self.fc2(x)
1414
-
1415
- if not self._pre_layer_norm:
1416
- x = self.layer_norm_mlp(x + embed)
1417
- return x
1418
-
1419
- def forward(
1420
- self,
1421
- x: torch.Tensor,
1422
- attention_mask: Optional[torch.Tensor] = None,
1423
- attention_weight_bias: Optional[torch.Tensor] = None,
1424
- ) -> dict[str, torch.Tensor]:
1425
-
1426
- res = x
1427
- if self._pre_layer_norm:
1428
- x = self.layer_norm_self_attention(x)
1429
-
1430
- output: dict[str, torch.Tensor] = self.mha(
1431
- x,
1432
- x,
1433
- x,
1434
- attention_mask=attention_mask,
1435
- attention_weight_bias=attention_weight_bias,
1436
- )
1437
-
1438
- if not self._pre_layer_norm:
1439
- output["embeddings"] = self.layer_norm_self_attention(
1440
- output["embeddings"] + res
1441
- )
1442
-
1443
- x = output["embeddings"]
1444
- else:
1445
- x = output["embeddings"]
1446
- x = res + x
1447
-
1448
- # MLP
1449
- if not self._pre_layer_norm:
1450
- x = self.mlp(x)
1451
- else:
1452
- x = x + self.mlp(x)
1453
-
1454
- output["embeddings"] = x
1455
- return output
1456
-
1457
-
1458
- class RobertaLMHead(nn.Module):
1459
- """
1460
- Roberta Language Model head. Transforms final attention layer output into a
1461
- distribution over tokens at each position.
1462
- """
1463
-
1464
- def __init__(self, embed_dim: int, alphabet_size: int):
1465
- """
1466
- Args:
1467
- embed_dim: Embedding dimension.
1468
- alphabet_size: Number of tokens in the alphabet.
1469
- """
1470
- super().__init__()
1471
- self.embed_dim = embed_dim
1472
- self.alphabet_size = alphabet_size
1473
-
1474
- # Define layers
1475
- self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1476
- self._fc1 = nn.Linear(embed_dim, embed_dim)
1477
- self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1478
- self._final_fc = nn.Linear(embed_dim, alphabet_size)
1479
-
1480
- def forward(self, x: torch.Tensor) -> dict:
1481
- x = self._first_layer_norm(x)
1482
- embeddings = x
1483
- x = self._fc1(x)
1484
- x = nn.functional.gelu(x)
1485
- x = self._second_layer_norm(x)
1486
- logits = self._final_fc(x)
1487
- return {"embeddings": embeddings, "logits": logits}
1488
-
1489
-
1490
- class TorchESMTransformer(nn.Module):
1491
- def __init__(
1492
- self,
1493
- esm_config: ESMTransformerConfig,
1494
- ):
1495
- super(TorchESMTransformer, self).__init__()
1496
- self.esm_config = esm_config
1497
-
1498
- # Other cases are not implemented
1499
- assert esm_config.positional_embedding is None
1500
- assert esm_config.lm_head == "roberta"
1501
- assert esm_config.use_rotary_embedding is True
1502
- assert esm_config.token_dropout is False
1503
- assert esm_config.emb_layer_norm_before is False
1504
- assert esm_config.mask_before_attention is False
1505
- assert esm_config.bias_word_embedding is False
1506
- assert esm_config.use_gradient_checkpointing is False
1507
-
1508
- self.embed_layer = nn.Embedding(esm_config.alphabet_size, esm_config.embed_dim)
1509
-
1510
- self.lm_head = RobertaLMHead(
1511
- embed_dim=esm_config.embed_dim,
1512
- alphabet_size=esm_config.alphabet_size,
1513
- )
1514
-
1515
- self.rotary_embedding_config = RotaryEmbeddingConfigBis(
1516
- rescaling_factor=esm_config.rescaling_factor
1517
- )
1518
-
1519
- self.attention_blocks = nn.ModuleList(
1520
- [
1521
- SelfAttentionBlock( # type: ignore
1522
- num_heads=esm_config.attention_heads,
1523
- embed_dim=esm_config.embed_dim,
1524
- key_size=esm_config.key_size,
1525
- ffn_embed_dim=esm_config.ffn_embed_dim,
1526
- add_bias_kv=esm_config.add_bias_kv,
1527
- add_bias_fnn=esm_config.add_bias_ffn,
1528
- ffn_activation_name=esm_config.ffn_activation_name,
1529
- use_glu_in_ffn=esm_config.use_glu_in_ffn,
1530
- rotary_embedding_config=self.rotary_embedding_config,
1531
- layer_norm_eps=esm_config.layer_norm_eps,
1532
- pre_layer_norm=esm_config.pre_layer_norm,
1533
- )
1534
- for _ in range(esm_config.num_layers)
1535
- ]
1536
- )
1537
-
1538
- def forward(
1539
- self, tokens: torch.Tensor, attention_mask: torch.Tensor = None
1540
- ) -> torch.Tensor:
1541
- """
1542
- Computes the embeddings based on the input tokens.
1543
-
1544
- Args:
1545
- tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
1546
- attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
1547
- If no mask is provided, a mask by default which equals 1 over all non
1548
- pad tokens and 0 over pad tokens is computed.
1549
-
1550
- Returns:
1551
- Final embeddings after the language model head, of shape (batch_size, seq_len, embed_dim).
1552
- """
1553
- x = self.embed_layer(tokens)
1554
-
1555
- # Apply the embedding scale correction factor
1556
- x = self.esm_config.embed_scale * x
1557
-
1558
- if attention_mask is None:
1559
- attention_mask = build_padding_attention_mask(
1560
- tokens=tokens, pad_token_id=self.esm_config.pad_token_id
1561
- )
1562
-
1563
- for layer in self.attention_blocks:
1564
- x = layer(x, attention_mask)["embeddings"]
1565
-
1566
- assert self.esm_config.lm_head == "roberta"
1567
- x = self.lm_head(x)["embeddings"]
1568
-
1569
- return x
1570
-
1571
-
1572
- def build_padding_attention_mask(
1573
- tokens: torch.Tensor, pad_token_id: int
1574
- ) -> torch.Tensor:
1575
- """
1576
- Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
1577
-
1578
- Args:
1579
- tokens: Batch of sequences of shape (batch_size, seq_len).
1580
- pad_token_id: Int corresponding to the <pad> token to mask.
1581
-
1582
- Returns:
1583
- Batch of attention masks, masking out <pad> tokens.
1584
- """
1585
- padding_mask = tokens != pad_token_id
1586
- padding_mask = padding_mask.unsqueeze(1)
1587
- padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask)
1588
- return padding_mask
1589
-
1590
-
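# Illustrative sketch (not from the original file): an equivalent way to build the same
# pairwise padding mask with boolean broadcasting, shown on a toy batch where 1 is the pad id.
import torch
tokens = torch.tensor([[5, 6, 7, 1, 1]])
keep = (tokens != 1).unsqueeze(1)                    # (batch, 1, seq)
pairwise = keep.unsqueeze(-1) & keep.unsqueeze(-2)   # (batch, 1, seq, seq)
print(pairwise[0, 0, 0])                             # tensor([ True,  True,  True, False, False])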
1591
- class TorchBioBrainEncoder(nn.Module):
1592
- def __init__(
1593
- self,
1594
- esm_config: ESMTransformerConfig,
1595
- ):
1596
- super(TorchBioBrainEncoder, self).__init__()
1597
- self.esm_config = esm_config
1598
- self.esm_model = TorchESMTransformer(self.esm_config)
1599
-
1600
- def forward(
1601
- self,
1602
- bio_token_ids: torch.Tensor,
1603
- ) -> torch.Tensor:
1604
- """
1605
- Args:
1606
- bio_token_ids (torch.Tensor):
1607
- Shape (batch_size, num_bio_tokens)
1608
-
1609
- Returns:
1610
- torch.Tensor:
1611
- Shape (batch_size, num_bio_tokens, embed_dim)
1612
- """
1613
- bio_embeddings = self.esm_model(tokens=bio_token_ids)
1614
-
1615
- return bio_embeddings
1616
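The encoder is a thin wrapper around TorchESMTransformer; a hedged sketch of driving it, reusing the assumed `esm_config` from the sketch above:

import torch

encoder = TorchBioBrainEncoder(esm_config=esm_config)
bio_token_ids = torch.randint(0, esm_config.alphabet_size, (2, 512))
bio_embeddings = encoder(bio_token_ids)   # (2, 512, esm_config.embed_dim)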
-
1617
-
1618
- class TorchMultiModalPerceiverResamplerBlock(nn.Module):
1619
- def __init__(
1620
- self,
1621
- num_heads: int,
1622
- embed_dim: int,
1623
- ffn_embed_dim: int,
1624
- key_size: Optional[int] = None,
1625
- add_bias_kv: bool = False,
1626
- add_bias_ffn: bool = True,
1627
- ffn_activation_name: str = "gelu",
1628
- use_glu_in_ffn: bool = False,
1629
- ):
1630
- super().__init__()
1631
-
1632
- if key_size is None:
1633
- if embed_dim % num_heads != 0:
1634
- raise ValueError(
1635
- f"Embedding dimension {embed_dim} should be divisible by "
1636
- f"num_heads {num_heads}."
1637
- )
1638
- key_size = embed_dim // num_heads
1639
-
1640
- self.num_heads = num_heads
1641
- self.embed_dim = embed_dim
1642
- self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim
1643
- self.use_glu_in_ffn = use_glu_in_ffn
1644
-
1645
- self.cross_attention_1 = MultiHeadAttention(
1646
- num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1647
- )
1648
- self.cross_attention_2 = MultiHeadAttention(
1649
- num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
1650
- )
1651
-
1652
- self.norm_cross_attention_1 = nn.LayerNorm(embed_dim)
1653
- self.norm_cross_attention_2 = nn.LayerNorm(embed_dim)
1654
- self.norm_mlp = nn.LayerNorm(embed_dim)
1655
-
1656
- self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn)
1657
- self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)  # GLU gating halves fc1's doubled output back to ffn_embed_dim
1658
-
1659
- self.activation_fn = getattr(
1660
- nn.functional, ffn_activation_name, nn.functional.gelu
1661
- )
1662
-
1663
- def mlp(self, x: torch.Tensor) -> torch.Tensor:
1664
- outs = {}
1665
- x = self.norm_mlp(x)
1666
- outs["MLP_layer0_layer_norm"] = x.clone()
1667
- if self.use_glu_in_ffn:
1668
- x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
1669
- x = self.activation_fn(x1) * x2
1670
- else:
1671
- x = self.fc1(x)
1672
- outs["MLP_layer1_fc1"] = x.clone()
1673
- x = self.activation_fn(x)
1674
- outs["MLP_layer2_activation"] = x.clone()
1675
-
1676
- x = self.fc2(x)
1677
- outs["MLP_layer3_fc2"] = x.clone()
1678
- outs["x"] = x.clone()
1679
-
1680
- return outs
1681
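For clarity, the gated (GLU) branch of the `mlp` method above computes `fc2(act(x1) * x2)`, where `x1, x2` are the two halves of `fc1(x)`; a standalone toy illustration with arbitrary sizes:

import torch
import torch.nn as nn

embed_dim, ffn_embed_dim = 8, 16
fc1 = nn.Linear(embed_dim, 2 * ffn_embed_dim)  # doubled when use_glu_in_ffn=True
fc2 = nn.Linear(ffn_embed_dim, embed_dim)      # gating halves the width again
x = torch.randn(2, 5, embed_dim)
x1, x2 = torch.chunk(fc1(x), 2, dim=-1)
out = fc2(nn.functional.gelu(x1) * x2)         # (2, 5, embed_dim)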
-
1682
- def forward(
1683
- self,
1684
- x: torch.Tensor,
1685
- cross_attention_embeddings_1: torch.Tensor,
1686
- cross_attention_embeddings_2: torch.Tensor,
1687
- attention_mask_1: Optional[torch.Tensor] = None,
1688
- attention_mask_2: Optional[torch.Tensor] = None,
1689
- ) -> Dict[str, torch.Tensor]:
1690
- outs_news = {}
1691
- res = x
1692
- x = self.norm_cross_attention_1(x)
1693
- outs_news["ATTENTION_layer0_layer_norm_cross_attention_1"] = x.clone()
1694
-
1695
- attn_output = self.cross_attention_1(
1696
- query=x,
1697
- key=cross_attention_embeddings_1,
1698
- value=cross_attention_embeddings_1,
1699
- attention_mask=attention_mask_1,
1700
- )["embeddings"]
1701
- outs_news["ATTENTION_layer1_cross_attention_layer_1"] = attn_output.clone()
1702
- x = res + attn_output
1703
-
1704
- res = x
1705
- x = self.norm_cross_attention_2(x)
1706
- outs_news["ATTENTION_layer2_cross_attention_layer_2"] = x.clone()
1707
- attn_output = self.cross_attention_2(
1708
- query=x,
1709
- key=cross_attention_embeddings_2,
1710
- value=cross_attention_embeddings_2,
1711
- attention_mask=attention_mask_2,
1712
- )["embeddings"]
1713
- outs_news["ATTENTION_layer3_cross_attention_layer_2"] = attn_output.clone()
1714
- x = res + attn_output
1715
-
1716
- mlp_output = self.mlp(x)
1717
- x = x + mlp_output["x"]
1718
- outs_news["ATTENTION_after_mlp"] = x.clone()
1719
-
1720
- output = dict(outs_news)
1723
-
1724
- output["embeddings"] = x
1725
- return output
1726
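A hedged smoke-test sketch of the block above: the latent queries cross-attend to one stream, then to the other, and finally pass through the FFN; all dimensions below are arbitrary:

import torch

block = TorchMultiModalPerceiverResamplerBlock(
    num_heads=4, embed_dim=32, ffn_embed_dim=64
)
latents = torch.randn(2, 16, 32)           # resampled queries
stream_1 = torch.randn(2, 20, 32)          # e.g. projected bio embeddings
stream_2 = torch.randn(2, 12, 32)          # e.g. English token embeddings
out = block(latents, stream_1, stream_2)   # attention masks default to None
print(out["embeddings"].shape)             # (2, 16, 32)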
-
1727
-
1728
- class TorchMultiModalPerceiverResampler(nn.Module):
1729
- """
1730
- Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
1731
- """
1732
-
1733
- def __init__(
1734
- self,
1735
- config: PerceiverResamplerConfig,
1736
- name: Optional[str] = None,
1737
- ):
1738
- """
1739
- Initialize a Perceiver Resampler model.
1740
-
1741
- Args:
1742
- config: Dataclass containing model hyperparameters.
1743
- name: Name for module (custom will break weight loading).
1744
- """
1745
- super().__init__()
1746
- self.config = config
1747
- self.name = name
1748
- self.layers = nn.ModuleList(
1749
- [
1750
- TorchMultiModalPerceiverResamplerBlock(
1751
- num_heads=self.config.attention_heads,
1752
- embed_dim=self.config.embed_dim,
1753
- key_size=self.config.key_size,
1754
- ffn_embed_dim=self.config.ffn_embed_dim,
1755
- add_bias_kv=self.config.add_bias_kv,
1756
- add_bias_ffn=self.config.add_bias_ffn,
1757
- ffn_activation_name=self.config.ffn_activation_name,
1758
- use_glu_in_ffn=self.config.use_glu_in_ffn,
1759
- )
1760
- for _ in range(self.config.num_layers)
1761
- ]
1762
- )
1763
-
1764
- self.latent_queries = torch.nn.Parameter(
1765
- torch.randn(self.config.resampled_length, self.config.embed_dim)
1766
- * (
1767
- 1.0
1768
- / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32))
1769
- )
1770
- )
1771
-
1772
- def apply_attention_blocks(
1773
- self,
1774
- x: torch.Tensor,
1775
- xf_1: torch.Tensor,
1776
- xf_2: torch.Tensor,
1777
- outs: Dict[str, torch.Tensor],
1778
- attention_mask_1: Optional[torch.Tensor] = None,
1779
- attention_mask_2: Optional[torch.Tensor] = None,
1780
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
1781
- """
1782
- Applies the successive cross-attention blocks to the latent queries.
1783
- """
1784
- for layer_idx, layer in enumerate(self.layers):
1785
- concat_input_1 = torch.cat([xf_1, x], dim=1)
1786
- concat_input_2 = torch.cat([xf_2, x], dim=1)
1787
-
1788
- #outs[f"PERCEIVER_RESAMPLER_concat_input_1_{layer_idx}"] = concat_input_1.clone()
1789
- #outs[f"PERCEIVER_RESAMPLER_concat_input_2_{layer_idx}"] = concat_input_2.clone()
1790
-
1791
- output = layer(
1792
- x=x,
1793
- cross_attention_embeddings_1=concat_input_1,
1794
- cross_attention_embeddings_2=concat_input_2,
1795
- attention_mask_1=attention_mask_1,
1796
- attention_mask_2=attention_mask_2,
1797
- )
1798
- x = output["embeddings"]
1799
- #outs[f"PERCEIVER_RESAMPLER_attention_embeddings_{layer_idx}"] = output["embeddings"].clone()
1800
-
1801
- for key in output.keys():
1802
- if key != "embeddings":
1803
- outs[f"{key}_{layer_idx}"] = output[key].clone()
1804
-
1805
- return x, outs
1806
-
1807
- def forward(
1808
- self,
1809
- input_embeddings_1: torch.Tensor,
1810
- input_embeddings_2: torch.Tensor,
1811
- attention_mask_1: Optional[torch.Tensor] = None,
1812
- attention_mask_2: Optional[torch.Tensor] = None,
1813
- ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
1814
- """
1815
- Computes the resampled embeddings from the two input embedding streams.
1816
- """
1817
- new_outs = {}
1818
- new_outs["input_embeddings_1"] = input_embeddings_1.clone()
1819
- new_outs["input_embeddings_2"] = input_embeddings_2.clone()
1820
-
1821
- assert (
1822
- input_embeddings_1.shape[-1] == self.config.embed_dim
1823
- ), "The input embedding dim should match the model embed dim"
1824
- assert (
1825
- input_embeddings_2.shape[-1] == self.config.embed_dim
1826
- ), "The input embedding dim should match the model embed dim"
1827
-
1828
- batch_size = input_embeddings_1.shape[0]
1829
-
1830
- latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)
1831
-
1832
- outs: Dict[str, torch.Tensor] = {}
1833
- x = latent_queries
1834
-
1835
- new_outs["latent_queries"] = x.clone()
1836
-
1837
- x, outs = self.apply_attention_blocks(
1838
- x=x,
1839
- xf_1=input_embeddings_1,
1840
- xf_2=input_embeddings_2,
1841
- outs=outs,
1842
- attention_mask_1=attention_mask_1,
1843
- attention_mask_2=attention_mask_2,
1844
- )
1845
-
1846
- for key in outs.keys():
1847
- new_outs[key] = outs[key].clone()
1848
-
1849
- outs["embeddings"] = x
1850
-
1851
- return outs, new_outs
1852
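A usage sketch for the full resampler, assuming the PerceiverResamplerConfig defaults can be overridden as below (the values are illustrative, not ChatNT's real hyperparameters):

import torch

resampler_config = PerceiverResamplerConfig(
    attention_heads=4,
    embed_dim=32,
    ffn_embed_dim=64,
    num_layers=2,
    resampled_length=16,
)
resampler = TorchMultiModalPerceiverResampler(config=resampler_config)
emb_1 = torch.randn(2, 20, 32)               # first modality stream
emb_2 = torch.randn(2, 12, 32)               # second modality stream
outs, debug_outs = resampler(emb_1, emb_2)   # masks default to None
print(outs["embeddings"].shape)              # (2, 16, 32): resampled_length latents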
-
1853
-
1854
- class TorchMultiModalPerceiverResamplerProjection(nn.Module):
1855
- def __init__(
1856
- self,
1857
- perceiver_resampler_config: PerceiverResamplerConfig,
1858
- input_embed_dim: int,
1859
- embed_dim: int,
1860
- bio_pad_token_id: int,
1861
- english_pad_token_id: int,
1862
- english_vocab_size: int,
1863
- ):
1864
- super().__init__()
1865
- self.config = perceiver_resampler_config
1866
- self.input_embed_dim = input_embed_dim
1867
- self.embed_dim = embed_dim
1868
- self.bio_pad_token_id = bio_pad_token_id
1869
- self.english_pad_token_id = english_pad_token_id
1870
- self.english_vocab_size = english_vocab_size
1871
-
1872
- self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
1873
- self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
1874
- self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
1875
-
1876
- def forward(
1877
- self,
1878
- bio_token_ids: torch.Tensor,
1879
- bio_embeddings: torch.Tensor,
1880
- english_token_ids: torch.Tensor,
1881
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
1882
- """
1883
- Args:
1884
- bio_token_ids (torch.Tensor):
1885
- Shape (batch_size, num_bio_tokens)
1886
-
1887
- bio_embeddings (torch.Tensor):
1888
- Shape (batch_size, num_bio_tokens, embed_dim)
1889
-
1890
- english_token_ids (torch.Tensor):
1891
- Shape (batch_size, num_english_tokens)
1892
- """
1893
- outs = {}
1894
- projected_bio_embeddings = self.bio_projection(bio_embeddings)
1895
- print("(debug) remember to remove this projected_bio_embeddings out, and 'outs' output")
1896
- outs['projected_bio_embeddings'] = projected_bio_embeddings
1897
- english_embeddings = self.token_embedding(english_token_ids)
1898
- outs['english_embeddings'] = english_embeddings
1899
-
1900
- bio_attention_mask = build_perceiver_padding_attention_mask(
1901
- bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
1902
- )
1903
- english_attention_mask = build_perceiver_padding_attention_mask(
1904
- english_token_ids, self.config.resampled_length, self.english_pad_token_id
1905
- )
1906
-
1907
- projected_embeddings, new_outs = self.perceiver_resampler(
1908
- input_embeddings_1=projected_bio_embeddings,
1909
- attention_mask_1=bio_attention_mask,
1910
- input_embeddings_2=english_embeddings,
1911
- attention_mask_2=english_attention_mask,
1912
- )
1913
- projected_embeddings = projected_embeddings["embeddings"]
1914
-
1915
- for key in new_outs.keys():
1916
- outs[f"PERCEIVER_{key}"] = new_outs[key]
1917
-
1918
- return projected_embeddings, outs
1919
-
1920
-
1921
- def build_perceiver_padding_attention_mask(
1922
- tokens: torch.Tensor, resampled_length: int, pad_token_id: int
1923
- ) -> torch.Tensor:
1924
- batch_size, seq_len = tokens.shape
1925
- padding_mask = tokens != pad_token_id # (batch_size, seq_len)
1926
-
1927
- padding_mask = torch.cat(
1928
- [
1929
- padding_mask,
1930
- torch.ones(
1931
- (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
1932
- ),
1933
- ],
1934
- dim=1,
1935
- ) # (batch_size, seq_len + resampled_length)
1936
-
1937
- padding_mask = padding_mask[:, None, None, :]
1938
- padding_mask = padding_mask.repeat(1, 1, resampled_length, 1) # noqa
1939
- return padding_mask
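Finally, a quick shape check for the mask helper above (toy values; the pad id is assumed to be 1):

import torch

tokens = torch.tensor([[5, 7, 1, 1]])
mask = build_perceiver_padding_attention_mask(
    tokens=tokens, resampled_length=16, pad_token_id=1
)
print(mask.shape)  # (1, 1, 16, 20): keys cover the 4 input tokens plus the 16 latents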