"""from https://github.com/lucidrains/x-transformers"""

import math
from random import random

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from functools import partial, wraps
from inspect import isfunction

from einops import rearrange, repeat, reduce

DEFAULT_DIM_HEAD = 64
def exists(val):
    return val is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth


def init_zero_(layer):
    nn.init.constant_(layer.weight, 0.)
    if exists(layer.bias):
        nn.init.constant_(layer.bias, 0.)
# keyword argument helpers

def pick_and_pop(keys, d):
    values = list(map(lambda key: d.pop(key), keys))
    return dict(zip(keys, values))


def group_dict_by_key(cond, d):
    return_val = [dict(), dict()]
    for key in d.keys():
        match = bool(cond(key))
        ind = int(not match)
        return_val[ind][key] = d[key]
    return (*return_val,)


def string_begins_with(prefix, s):
    return s.startswith(prefix)


def group_by_key_prefix(prefix, d):
    return group_dict_by_key(partial(string_begins_with, prefix), d)


def groupby_prefix_and_trim(prefix, d):
    # split kwargs into those starting with `prefix` (with the prefix stripped) and the rest
    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
    return kwargs_without_prefix, kwargs
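# Usage sketch (illustrative, not part of the original module): splitting kwargs by
# prefix, as done for the 'ff_' / 'attn_' arguments in AttentionLayers below.
#
#   ff_kwargs, rest = groupby_prefix_and_trim('ff_', {'ff_mult': 2, 'attn_dropout': 0.1})
#   # ff_kwargs == {'mult': 2}, rest == {'attn_dropout': 0.1}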
def deepnorm_init(
        transformer,
        beta,
        module_name_match_list=('.ff.', '.to_v', '.to_out')
):
    for name, module in transformer.named_modules():
        if type(module) != nn.Linear:
            continue

        needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
        gain = beta if needs_beta_gain else 1
        nn.init.xavier_normal_(module.weight.data, gain=gain)

        if exists(module.bias):
            nn.init.constant_(module.bias.data, 0)
class ReluSquared(nn.Module):
    def forward(self, x):
        return F.relu(x) ** 2


class Scale(nn.Module):
    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        scale_fn = lambda t: t * self.value

        if not isinstance(out, tuple):
            return scale_fn(out)

        return (scale_fn(out[0]), *out[1:])
class ScaleNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True)
        return x / norm.clamp(min=self.eps) * self.g


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
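# Note (illustrative): since torch.norm(x, dim=-1) * dim ** -0.5 equals
# sqrt(mean(x ** 2, dim=-1)), RMSNorm above computes x / RMS(x) * g with a learned
# per-channel gain g. A quick equivalence check:
#
#   rms = RMSNorm(dim=8)
#   x = torch.randn(2, 4, 8)
#   ref = x / x.pow(2).mean(dim=-1, keepdim=True).sqrt().clamp(min=rms.eps) * rms.g
#   # rms(x) and ref agree up to floating point error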
class Residual(nn.Module):
    def __init__(self, dim, scale_residual=False, scale_residual_constant=1.):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
        self.scale_residual_constant = scale_residual_constant

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        if self.scale_residual_constant != 1:
            residual = residual * self.scale_residual_constant

        return x + residual


class GRUGating(nn.Module):
    def __init__(self, dim, scale_residual=False, **kwargs):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        if exists(self.residual_scale):
            residual = residual * self.residual_scale

        gated_output = self.gru(
            rearrange(x, 'b n d -> (b n) d'),
            rearrange(residual, 'b n d -> (b n) d')
        )

        return gated_output.reshape_as(x)
class GLU(nn.Module):
    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * self.act(gate)
class FeedForward(nn.Module):
    def __init__(
            self,
            dim,
            dim_out=None,
            mult=4,
            glu=False,
            swish=False,
            relu_squared=False,
            post_act_ln=False,
            dropout=0.,
            no_bias=False,
            zero_init_output=False
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if relu_squared:
            activation = ReluSquared()
        elif swish:
            activation = nn.SiLU()
        else:
            activation = nn.GELU()

        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim, bias=not no_bias),
            activation
        ) if not glu else GLU(dim, inner_dim, activation)

        self.ff = nn.Sequential(
            project_in,
            nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out, bias=not no_bias)
        )

        if zero_init_output:
            init_zero_(self.ff[-1])

    def forward(self, x):
        return self.ff(x)
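# Usage sketch (illustrative, not part of the original module): a GLU feedforward
# block; the last dimension is preserved unless dim_out is given.
#
#   ff = FeedForward(dim=512, mult=4, glu=True, dropout=0.1)
#   y = ff(torch.randn(2, 128, 512))   # -> (2, 128, 512)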
class Attention(nn.Module):
    def __init__(
            self,
            dim,
            kv_dim=None,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            dropout=0.,
            zero_init_output=False,
            shared_kv=False,
            value_dim_head=None,
            flash_attention=True,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        if kv_dim is None:
            kv_dim = dim

        self.heads = heads
        self.causal = causal

        value_dim_head = default(value_dim_head, dim_head)
        q_dim = k_dim = dim_head * heads
        v_dim = out_dim = value_dim_head * heads

        self.to_q = nn.Linear(dim, q_dim, bias=False)
        self.to_k = nn.Linear(kv_dim, k_dim, bias=False)

        assert not (shared_kv and value_dim_head != dim_head), \
            'key and value head dimensions must be equal for shared key / values'
        self.to_v = nn.Linear(kv_dim, v_dim, bias=False) if not shared_kv else None

        self.to_out = nn.Linear(out_dim, dim, bias=False)

        self.dropout_p = dropout
        self.dropout = nn.Dropout(dropout)

        self.flash = flash_attention
        assert self.flash, 'this Attention module only implements flash / memory-efficient attention'

        # fall back to xformers' memory-efficient attention when
        # torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0) is unavailable
        self.use_xformer = not hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        if zero_init_output:
            init_zero_(self.to_out)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
    ):
        h = self.heads
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input) if exists(self.to_v) else k

        # note: neither `mask` nor `context_mask` is applied inside the attention
        # computation below; `mask` is only used to zero out padded output positions
        if self.use_xformer:
            dtype = q.dtype
            q, k, v = map(lambda t: t.bfloat16() if t.dtype == torch.float32 else t, (q, k, v))

            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v))
            try:
                import xformers.ops as xops
            except ImportError as e:
                print("Please install xformers to use flash attention for PyTorch < 2.0.0.")
                raise e

            if self.causal:
                attention_bias = xops.LowerTriangularMask()
            else:
                attention_bias = None

            # note: attention dropout is not applied on this path
            out = xops.memory_efficient_attention(
                q, k, v, attn_bias=attention_bias,
            )

            out = out.to(dtype)
            out = rearrange(out, 'b n h d -> b n (h d)')
        else:
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

            out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None,
                dropout_p=self.dropout_p if self.training else 0.,  # apply attention dropout only at train time
                is_causal=self.causal,
            )
            out = rearrange(out, 'b h n d -> b n (h d)')

        out = self.to_out(out)

        if exists(mask):
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out

    def extra_repr(self) -> str:
        return f"causal: {self.causal}, flash attention: {self.flash}, " \
               f"use_xformers (if False, use torch.nn.functional.scaled_dot_product_attention): {self.use_xformer}"
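# Usage sketch (illustrative, not part of the original module): causal self-attention
# and cross-attention against a context of a different width.
#
#   attn = Attention(dim=512, heads=8, causal=True)
#   x = torch.randn(2, 128, 512)
#   y = attn(x)                          # -> (2, 128, 512)
#
#   cross = Attention(dim=512, kv_dim=256, heads=8)
#   ctx = torch.randn(2, 77, 256)
#   y = cross(x, context=ctx)            # -> (2, 128, 512)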
def modulate(x, shift, scale):
    # adaptive shift / scale: x is (b, n, d); shift and scale are (b, d)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
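# Example (illustrative): broadcasting in modulate; shift and scale are per-sample
# vectors applied across the sequence dimension.
#
#   x = torch.randn(2, 10, 512)          # (batch, seq, dim)
#   shift = torch.zeros(2, 512)
#   scale = torch.zeros(2, 512)
#   modulate(x, shift, scale)            # == x when shift and scale are zero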
class AttentionLayers(nn.Module):
    def __init__(
            self,
            dim,
            depth,
            heads=8,
            ctx_dim=None,
            causal=False,
            cross_attend=False,
            only_cross=False,
            use_scalenorm=False,
            use_rmsnorm=False,
            residual_attn=False,
            cross_residual_attn=False,
            macaron=False,
            pre_norm=True,
            gate_residual=False,
            scale_residual=False,
            scale_residual_constant=1.,
            deepnorm=False,
            sandwich_norm=False,
            zero_init_branch_output=False,
            layer_dropout=0.,
            # adaLN-style modulation from an external feature vector
            # (disabled when modulate_feature_size is None or <= 0)
            modulate_feature_size=-1,
            checkpointing=False,
            checkpoint_every=1,
            **kwargs
    ):
        super().__init__()

        self.checkpointing = checkpointing
        self.checkpoint_every = checkpoint_every

        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])

        if deepnorm:
            assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
            pre_norm = sandwich_norm = False
            scale_residual = True
            scale_residual_constant = (2 * depth) ** 0.25

        assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
        self.pre_norm = pre_norm
        self.sandwich_norm = sandwich_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
        self.cross_attend = cross_attend

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        if zero_init_branch_output:
            attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
            ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

        layer_types = default_block * depth

        self.layer_types = layer_types

        self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))

        for ind, layer_type in enumerate(self.layer_types):
            is_last_layer = ind == (len(self.layer_types) - 1)

            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, kv_dim=ctx_dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            residual_fn = GRUGating if gate_residual else Residual
            residual = residual_fn(dim, scale_residual=scale_residual, scale_residual_constant=scale_residual_constant)

            pre_branch_norm = norm_fn() if pre_norm else None
            post_branch_norm = norm_fn() if sandwich_norm else None
            post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None

            # per-layer projection of the conditioning vector to shift / scale / gate (each of size dim)
            modulation = None
            if modulate_feature_size is not None and modulate_feature_size > 0:
                modulation = nn.Sequential(
                    nn.LayerNorm(modulate_feature_size),
                    nn.GELU(),
                    nn.Linear(modulate_feature_size, 3 * dim, bias=True)
                )

            norms = nn.ModuleList([
                pre_branch_norm,
                post_branch_norm,
                post_main_norm,
            ])

            self.layers.append(nn.ModuleList([
                norms,
                layer,
                residual,
                modulation,
            ]))

        if deepnorm:
            init_gain = (8 * depth) ** -0.25
            deepnorm_init(self, init_gain)
    def forward(
            self,
            x,
            context=None,
            modulation=None,
            mask=None,
            context_mask=None,
    ):
        assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'

        num_layers = len(self.layer_types)
        assert num_layers % self.checkpoint_every == 0, 'number of layers must be divisible by checkpoint_every'

        for start_layer_idx in range(0, num_layers, self.checkpoint_every):
            end_layer_idx = min(start_layer_idx + self.checkpoint_every, num_layers)

            def run_layers(x, context, modulation, start, end):
                # note: layer_dropout values are carried along here but not applied
                for ind, (layer_type, (norm, block, residual_fn, modulation_fn), layer_dropout) in enumerate(
                        zip(self.layer_types[start: end], self.layers[start: end], self.layer_dropouts[start: end])):
                    residual = x

                    pre_branch_norm, post_branch_norm, post_main_norm = norm

                    if exists(pre_branch_norm):
                        x = pre_branch_norm(x)

                    if modulation_fn is not None:
                        # adaLN-style conditioning: per-sample shift / scale / gate, each of size dim
                        shift, scale, gate = modulation_fn(modulation).chunk(3, dim=1)
                        x = modulate(x, shift, scale)

                    if layer_type == 'a':
                        out = block(x, mask=mask)
                    elif layer_type == 'c':
                        out = block(x, context=context, mask=mask, context_mask=context_mask)
                    elif layer_type == 'f':
                        out = block(x)

                    if exists(post_branch_norm):
                        out = post_branch_norm(out)

                    if modulation_fn is not None:
                        # gate the branch output before the residual connection
                        out = out * gate.unsqueeze(1)

                    x = residual_fn(out, residual)

                    if exists(post_main_norm):
                        x = post_main_norm(x)

                return x

            if self.checkpointing:
                # gradient checkpointing over groups of `checkpoint_every` layers
                x = checkpoint(run_layers, x, context, modulation, start_layer_idx, end_layer_idx)
            else:
                x = run_layers(x, context, modulation, start_layer_idx, end_layer_idx)

        return x
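# Usage sketch (illustrative, not part of the original module): a cross-attending
# stack with adaLN-style modulation. `modulate_feature_size` and the conditioning
# vector below are example values, not defaults used elsewhere in this repo.
#
#   layers = AttentionLayers(
#       dim=512, depth=4, heads=8,
#       cross_attend=True, ctx_dim=256,
#       modulate_feature_size=768,
#       attn_dropout=0.1, ff_mult=4,        # routed via the 'attn_' / 'ff_' prefixes
#   )
#   x = torch.randn(2, 128, 512)
#   ctx = torch.randn(2, 77, 256)
#   cond = torch.randn(2, 768)              # projected to per-layer shift / scale / gate
#   out = layers(x, context=ctx, modulation=cond)   # -> (2, 128, 512)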