Upload x_transformer_1_23_2.py

x_transformer_1_23_2.py  (+14 -5)  CHANGED
@@ -26,10 +26,16 @@
 from functools import partial
 from typing import Optional, Tuple
 
+import os
+os.environ['USE_FLASH_ATTENTION'] = '1'
+
 import torch
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F
+
+# Flash attention
 from torch.nn.attention import SDPBackend, sdpa_kernel
+torch.backends.cuda.enable_flash_sdp(True)
 
 from collections import namedtuple
 from functools import wraps
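The hunk above opts into the flash-attention SDPA path at import time by setting the USE_FLASH_ATTENTION environment variable and calling torch.backends.cuda.enable_flash_sdp(True). A minimal standalone sketch (not part of the uploaded file, assumes PyTorch 2.x) for checking which scaled_dot_product_attention backends are currently enabled before and after that call:

import torch

# Query the per-backend toggles exposed by torch.backends.cuda.
print('flash SDP:        ', torch.backends.cuda.flash_sdp_enabled())
print('mem-efficient SDP:', torch.backends.cuda.mem_efficient_sdp_enabled())
print('math SDP:         ', torch.backends.cuda.math_sdp_enabled())

# Same opt-in as the commit; the flash kernel is still only picked at
# runtime when a supported CUDA GPU and dtype (fp16/bf16) are in use.
torch.backends.cuda.enable_flash_sdp(True)
print('flash SDP:        ', torch.backends.cuda.flash_sdp_enabled())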
@@ -259,11 +265,14 @@ class Attend(nn.Module):
 
         # Legacy code...
         # with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
+        # with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
 
-        #
-
-        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+        # PyTorch 2.3-2.4 SDPA backend code...
+        with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
 
+            # New PyTorch 2.5 SDPA backend code:
+            # with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
+
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
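This hunk widens the backend list handed to the sdpa_kernel context manager so that flash and cuDNN attention may be selected alongside the math and memory-efficient kernels, and leaves the PyTorch 2.5 single-backend variant commented out. A self-contained sketch of the same pattern (toy shapes, not taken from the commit; SDPBackend.CUDNN_ATTENTION requires a recent PyTorch build):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy (batch, heads, seq_len, head_dim) tensors.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Only the listed backends may be used inside the block; PyTorch picks
# the fastest one the inputs support (flash/cuDNN generally need CUDA
# with fp16/bf16, so on CPU this falls back to the math/efficient kernels).
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

print(out.shape)  # torch.Size([1, 8, 128, 64])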
@@ -508,7 +517,7 @@ class AutoregressiveWrapper(Module):
         # whether to add router z-loss
         self.add_attn_z_loss = add_attn_z_loss
 
-    @torch.no_grad()
+    @torch.inference_mode()
     @eval_decorator
     def generate(
         self,
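Here the autograd guard on generate becomes @torch.inference_mode(), which not only disables gradient recording but also yields inference tensors that skip view/version tracking and cannot be reused in autograd later. A small standalone sketch of the difference (assumes a recent PyTorch 2.x):

import torch

@torch.no_grad()
def step_no_grad(x):
    return x * 2  # no graph recorded, but the result is a regular tensor

@torch.inference_mode()
def step_inference(x):
    return x * 2  # result is an inference tensor with lighter bookkeeping

x = torch.ones(3, requires_grad = True)

a = step_no_grad(x)
b = step_inference(x)

print(a.requires_grad, a.is_inference())  # False False
print(b.requires_grad, b.is_inference())  # False True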
@@ -2462,4 +2471,4 @@ class XTransformer(nn.Module):
             enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
 
         out = self.decoder(tgt, context = enc, context_mask = mask)
-        return out
+        return out