import os

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from torch import nn
from torch.nn.utils import weight_norm
from transformers import T5EncoderModel, T5Tokenizer

torch.backends.cuda.enable_mem_efficient_sdp(True)

# Each prompt is generated N_REPEAT times in parallel; the takes are shifted
# randomly in time and concatenated into the final waveform.
N_REPEAT = 2


def _shift(x):
    # Roll every batch item by a random amount along the time axis so the
    # N_REPEAT parallel takes of the same prompt do not stay in phase.
    for i, _slice in enumerate(x):
        n = x.shape[2]
        offset = np.random.randint(int(.24 * n), max(1, int(.74 * n)))
        x[i, :, :] = torch.roll(_slice, offset, dims=1)
    return x
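

# Minimal usage sketch for `_shift` (hypothetical shapes): an impulse placed at
# t=0 moves to a random index, while the tensor shape is unchanged.
def _demo_shift():
    x = torch.zeros(2, 1, 100)
    x[:, :, 0] = 1.0               # unit impulse at the first sample
    y = _shift(x.clone())          # rolled in place, per batch item
    assert y.shape == (2, 1, 100)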


class AudioGen(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # EnCodec compression model: decodes audio tokens back to a waveform.
        _file_1 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="compression_state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')
        pkg = torch.load(_file_1, map_location='cpu')
        self.compression_model = EncodecModel()
        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model.eval()
        self._chunk_len = 476  # decode this many frames per chunk to bound memory

        # Language model over EnCodec tokens.
        _file_2 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version='1.3.0a1')
        pkg = torch.load(_file_2, map_location='cpu')
        _best = pkg['best_state']
        # Rename the original conditioner keys to match the local T5 wrapper.
        _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')
        _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')
        self.lm = LMModel()
        self.lm.load_state_dict(_best, strict=True)
        self.lm.eval()

    @torch.no_grad()
    def generate(self,
                 prompt='dogs meow',
                 duration=2.24,
                 cache_lim=71,
                 ):
        torch.manual_seed(42)
        self.lm.cache_lim = cache_lim
        # More parallel draws for longer requested durations.
        self.lm.n_draw = int(.8 * duration) + 1
        with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
            gen_tokens = self.lm.generate(
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,
                max_tokens=int(.04 * duration / N_REPEAT * self.compression_model.frame_rate) + 12)

        # Decode in chunks (with 7 frames of left context) to limit the
        # decoder's peak memory use.
        x = []
        for i in range(7, gen_tokens.shape[2], self._chunk_len):
            decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i - 7:i + self._chunk_len])
            x.append(decoded_chunk)
        x = torch.cat(x, 2)
        x = _shift(x)
        return x.reshape(-1)


class EncodecModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.decoder = SEANetDecoder()
        self.quantizer = ResidualVectorQuantizer()
        self.frame_rate = 50  # token frames per second (16 kHz / 320x stride)

    def decode(self, codes):
        # codes: (batch, n_q, frames) -> waveform (batch, channels, samples)
        emb = self.quantizer.decode(codes)
        return self.decoder(emb)
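

# Shape sketch for the decoder path (random weights, hypothetical sizes): 50
# frames of 4-codebook tokens decode to 1 s of mono audio at 16 kHz, since the
# SEANet ratios multiply to a 320x upsampling (8 * 5 * 4 * 2).
def _demo_encodec_decode():
    m = EncodecModel().eval()
    codes = torch.randint(0, 2048, (1, 4, 50))   # (batch, n_q, frames)
    with torch.no_grad():
        wav = m.decode(codes)
    assert wav.shape == (1, 1, 50 * 320)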


class StreamableLSTM(nn.Module):

    def __init__(self,
                 dimension,
                 num_layers=2,
                 skip=True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        # (batch, channels, time) -> (time, batch, channels) for nn.LSTM
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        # back to (batch, channels, time)
        y = y.permute(1, 2, 0)
        return y


class SEANetResnetBlock(nn.Module):

    def __init__(self,
                 dim,
                 kernel_sizes=[3, 1],
                 pad_mode='reflect',
                 compress=2):
        super().__init__()
        hidden = dim // compress  # bottleneck width
        block = []
        for i, kernel_size in enumerate(kernel_sizes):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [nn.ELU(),
                      StreamableConv1d(in_chs,
                                       out_chs,
                                       kernel_size=kernel_size,
                                       pad_mode=pad_mode)]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        return x + self.block(x)


class SEANetDecoder(nn.Module):

    def __init__(self,
                 channels=1,
                 dimension=128,
                 n_filters=64,
                 n_residual_layers=1,
                 ratios=[8, 5, 4, 2],
                 kernel_size=7,
                 last_kernel_size=7,
                 residual_kernel_size=3,
                 pad_mode='constant',
                 compress=2,
                 lstm=2):
        super().__init__()
        mult = int(2 ** len(ratios))
        model = [
            StreamableConv1d(dimension, mult * n_filters,
                             kernel_size,
                             pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters,
                                     num_layers=lstm)]

        # Upsample back to the waveform rate: each stage halves the channel
        # multiplier and upsamples time by `ratio` (8 * 5 * 4 * 2 = 320x total).
        for i, ratio in enumerate(ratios):
            model += [
                nn.ELU(),
                StreamableConvTranspose1d(mult * n_filters,
                                          mult * n_filters // 2,
                                          kernel_size=ratio * 2,
                                          stride=ratio),
            ]
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2,
                                      kernel_sizes=[residual_kernel_size, 1],
                                      pad_mode=pad_mode,
                                      compress=compress)]
            mult //= 2

        model += [
            nn.ELU(),
            StreamableConv1d(n_filters,
                             channels,
                             last_kernel_size,
                             pad_mode=pad_mode)]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        return self.model(z)


def unpad1d(x, paddings):
    # Remove (left, right) padding from the last (time) dimension.
    padding_left, padding_right = paddings
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]
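

# A quick check of the unpadding behaviour:
def _demo_unpad1d():
    x = torch.arange(10.).view(1, 1, 10)
    y = unpad1d(x, (2, 3))
    assert y.shape[-1] == 5 and y[0, 0, 0] == 2.0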


class NormConv1d(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv = weight_norm(nn.Conv1d(*args, **kwargs))

    def forward(self, x):
        return self.conv(x)


class NormConvTranspose1d(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.convtr = weight_norm(nn.ConvTranspose1d(*args, **kwargs))

    def forward(self, x):
        return self.convtr(x)


class StreamableConv1d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 bias=True,
                 pad_mode='reflect'):
        super().__init__()
        if (stride != 1) or (groups != 1):
            raise ValueError('only stride=1 and groups=1 are supported')
        self.conv = NormConv1d(in_channels,
                               out_channels,
                               kernel_size,
                               stride,
                               groups=groups,
                               bias=bias)
        self.pad_mode = pad_mode

    def forward(self, x):
        # Pad so the convolution preserves the time length: total padding is
        # (effective kernel - stride), with the extra sample on the left when
        # the total is odd.
        kernel_size = self.conv.conv.kernel_size[0]
        kernel_size = (kernel_size - 1) * self.conv.conv.dilation[0] + 1
        padding_total = kernel_size - self.conv.conv.stride[0]
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        x = F.pad(x, (padding_left, padding_right), self.pad_mode)
        return self.conv(x)
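

# Padding sketch: a stride-1 StreamableConv1d preserves the time length
# (kernel 7 pads 3 left and 3 right).
def _demo_streamable_conv1d():
    conv = StreamableConv1d(1, 1, kernel_size=7, pad_mode='constant')
    x = torch.randn(1, 1, 50)
    assert conv(x).shape[-1] == 50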


class StreamableConvTranspose1d(nn.Module):

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int = 1):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels,
                                          out_channels,
                                          kernel_size,
                                          stride)

    def forward(self, x):
        # A transposed conv with kernel k and stride s overshoots by k - s
        # samples; trim them so the output length is exactly T * s.
        padding_total = self.convtr.convtr.kernel_size[0] - self.convtr.convtr.stride[0]
        y = self.convtr(x)
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        y = unpad1d(y, (padding_left, padding_right))
        return y


class EuclideanCodebook(nn.Module):

    def __init__(self,
                 dim,
                 codebook_size):
        super().__init__()
        # Filled from the checkpoint; only lookups are needed at inference.
        self.register_buffer("embed", torch.zeros(codebook_size, dim))


class VectorQuantization(nn.Module):

    def __init__(self,
                 dim,
                 codebook_size):
        super().__init__()
        self._codebook = EuclideanCodebook(dim=dim,
                                           codebook_size=codebook_size)

    def decode(self, _ind):
        return F.embedding(_ind, self._codebook.embed)


class ResidualVectorQuantization(nn.Module):

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )

    def decode(self, _ind):
        # Residual VQ: the decoded embedding is the sum of every quantizer's
        # codebook entry for its own code.
        x = 0.0
        for i, _code in enumerate(_ind):
            x = x + self.layers[i].decode(_code)
        return x.transpose(1, 2)
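

# Shape sketch for residual decoding (hypothetical sizes): 4 quantizers over a
# 16-entry, 8-dim codebook; codes are indexed (n_q, batch, frames).
def _demo_rvq_decode():
    rvq = ResidualVectorQuantization(num_quantizers=4, dim=8, codebook_size=16)
    codes = torch.randint(0, 16, (4, 2, 10))
    emb = rvq.decode(codes)
    assert emb.shape == (2, 8, 10)   # (batch, dim, frames)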


class ResidualVectorQuantizer(nn.Module):

    def __init__(self,
                 dimension=128,
                 n_q=4,
                 bins=2048):
        super().__init__()
        self.vq = ResidualVectorQuantization(dim=dimension,
                                             codebook_size=bins,
                                             num_quantizers=n_q)

    def decode(self, codes):
        # codes: (batch, n_q, frames) -> embeddings (batch, dimension, frames)
        return self.vq.decode(codes.transpose(0, 1))


class T5(nn.Module):

    def __init__(self):
        super().__init__()
        self.output_proj = nn.Linear(1024,
                                     1536)
        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
        # Stash the encoder in __dict__ so it is not registered as a submodule
        # and therefore stays out of this module's state_dict.
        self.__dict__['t5'] = t5.to('cpu')

    def forward(self, prompt):
        with torch.set_grad_enabled(False):
            # The first half of the batch is conditional, the second half is
            # the unconditional branch for classifier-free guidance.
            bs = len(prompt) // 2
            d = self.t5_tokenizer(prompt,
                                  return_tensors='pt',
                                  padding=True).to(self.output_proj.bias.device)
            d['attention_mask'][bs:, :] = 0
            x = self.t5(input_ids=d['input_ids'],
                        attention_mask=d['attention_mask']).last_hidden_state
            x = self.output_proj(x)
            x[bs:, :, :] = 0  # zero the unconditional embeddings
            return x
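

# Usage sketch (downloads `t5-large` on first use): two prompts where the
# second is the empty unconditional branch.
#
#   t5 = T5()
#   cond = t5(['dogs bark', ''])   # (2, seq, 1536); row 1 is all zeros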


class LMModel(nn.Module):

    def __init__(self,
                 n_q=4,
                 card=2048,
                 dim=1536):
        super().__init__()
        self.cache_lim = -1   # KV-cache length at which the history is flushed
        self.t5 = T5()
        self.card = card      # codebook size; `card` itself is the special token
        self.n_draw = 1       # parallel samples drawn per step

        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])
        self.transformer = StreamingTransformer()
        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])

    def forward(self,
                sequence,
                condition_tensors=None,
                cache_position=None):
        bs, n_q, time_frames = sequence.shape

        # Sum the per-codebook embeddings of the current step.
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])

        # Duplicate the batch: the first half attends to the text condition,
        # the second half to the zeroed (unconditional) embeddings.
        out = self.transformer(torch.cat([input_, input_], 0),
                               cross_attention_src=condition_tensors,
                               cache_position=cache_position)
        out = self.out_norm(out)
        logits = torch.stack([self.linears[k](out) for k in range(n_q)], dim=1)

        # Classifier-free guidance with a randomised negative scale per draw.
        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]

        tokens = torch.multinomial(torch.softmax(logits.view(bs * self.n_draw * n_q, self.card), dim=1),
                                   num_samples=1)
        return tokens.view(bs, n_q, self.n_draw).transpose(1, 2)

    @torch.no_grad()
    def generate(self,
                 max_tokens=None,
                 text_condition=None):
        x = self.t5(text_condition)
        bs = x.shape[0] // 2
        # Random guidance scale per draw, in [1.94, 2.24).
        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
        cache_position = 0

        # Token buffer laid out for the delay pattern: codebook k trails by k
        # frames, plus priming positions, hence 4 + 3 + max_tokens frames.
        out_codes = torch.full((bs,
                                self.n_draw,
                                4,
                                4 + 3 + max_tokens),
                               self.card,
                               dtype=torch.long,
                               device=x.device)

        for offset in range(0, max_tokens + 4 - 1):

            # Read the current step along an anti-diagonal of the delay
            # pattern: codebook k looks at position (3 - k) + offset.
            next_token = self.forward(out_codes[:, 0, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset][:, :, None],
                                      condition_tensors=x,
                                      cache_position=cache_position)

            # At the edges of the pattern, codebooks that have not started yet
            # (or have already finished) are forced to the special token.
            if offset == 0:
                next_token[:, :, 1:4] = self.card
            elif offset == 1:
                next_token[:, :, 2:4] = self.card
            elif offset == 2:
                next_token[:, :, 3:4] = self.card
            elif offset == max_tokens:
                next_token[:, :, 0:1] = self.card
            elif offset == (max_tokens + 1):
                next_token[:, :, 0:2] = self.card
            elif offset == (max_tokens + 2):
                next_token[:, :, 0:3] = self.card

            out_codes[:, :, [0, 1, 2, 3], torch.tensor([3, 2, 1, 0]) + offset + 1] = next_token

            # Periodically truncate the KV cache, keeping a few positions.
            if (offset > 0) and (offset % self.cache_lim) == 0:
                n_preserve = 4
                self.transformer._flush(n_preserve=n_preserve)
                cache_position = n_preserve
            else:
                cache_position += 1

        # Undo the delay (drop the priming frames) and fold the parallel draws
        # into extra time, yielding (bs, 4, n_draw * max_tokens).
        out_codes = out_codes[:, :, :, 4:max_tokens + 4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)

        self.transformer._flush()
        return out_codes
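

# Delay-pattern sketch (S = the special token `card`): codebook k trails by k
# steps, so each forward pass conditions on one anti-diagonal of this grid.
#
#   cb0:  S  t0 t1 t2 t3 ...
#   cb1:  S  S  t0 t1 t2 ...
#   cb2:  S  S  S  t0 t1 ...
#   cb3:  S  S  S  S  t0 ...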


def create_sin_embedding(positions,
                         dim,
                         max_period=10000):
    # Standard transformer sinusoid: half cosines, half sines, with
    # geometrically spaced periods up to `max_period`.
    half_dim = dim // 2
    positions = positions.to(torch.float)
    adim = torch.arange(half_dim, device=positions.device,
                        dtype=torch.float).view(1, 1, -1)
    max_period_tensor = torch.full([],
                                   max_period,
                                   device=positions.device,
                                   dtype=torch.float)
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
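

# Quick shape sketch: one absolute position expands to a (1, 1, dim) embedding.
def _demo_sin_embedding():
    pos = torch.zeros(1, 1, 1) + 5.0
    emb = create_sin_embedding(pos, 1536)
    assert emb.shape == (1, 1, 1536)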


class StreamingMultiheadAttention(nn.Module):

    def __init__(self,
                 embed_dim,
                 num_heads,
                 cross_attention=False):
        super().__init__()
        self.cross_attention = cross_attention
        # KV cache for streaming self-attention (grows one step per token).
        self.k_history = None
        self.v_history = None
        self.num_heads = num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Registered as a buffer (not a Parameter) so the checkpoint weights
        # load without making it trainable.
        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
                                                          dtype=torch.float))

    def forward(self,
                query,
                key=None,
                value=None):
        layout = "b h t d"
        if self.cross_attention:
            # Separate projections: queries from the decoder stream, keys and
            # values from the (text) conditioning tensor.
            dim = self.in_proj_weight.shape[0] // 3
            q = nn.functional.linear(query, self.in_proj_weight[:dim])
            k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
            v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
            q, k, v = [
                rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
        else:
            # Packed qkv projection, then append this step's k/v to the cache.
            projected = nn.functional.linear(query, self.in_proj_weight, None)
            bound_layout = "b h p t d"
            packed = rearrange(
                projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = packed.unbind(dim=2)
            if self.k_history is not None:
                self.k_history = torch.cat([self.k_history, k], 2)
                self.v_history = torch.cat([self.v_history, v], 2)
            else:
                self.k_history = k
                self.v_history = v
            k = self.k_history
            v = self.v_history

        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        x = self.out_proj(x)
        return x
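

# KV-cache sketch (hypothetical sizes): the history gains one time step per
# self-attention call until it is flushed.
def _demo_kv_cache():
    attn = StreamingMultiheadAttention(embed_dim=8, num_heads=2)
    x = torch.randn(1, 1, 8)
    attn(x)
    attn(x)
    assert attn.k_history.shape[2] == 2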


class StreamingTransformerLayer(nn.Module):

    def __init__(self,
                 d_model,
                 num_heads,
                 dim_feedforward):
        super().__init__()
        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                     num_heads=num_heads)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                           num_heads=num_heads,
                                                           cross_attention=True)
        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

    def forward(self,
                x,
                cross_attention_src=None):
        # Pre-norm residual blocks: self-attention, cross-attention, FFN.
        x = x + self.self_attn(self.norm1(x))
        x = x + self.cross_attention(query=self.norm_cross(x),
                                     key=cross_attention_src,
                                     value=cross_attention_src)
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x


class StreamingTransformer(nn.Module):

    def __init__(self,
                 d_model=1536,
                 num_heads=24,
                 num_layers=48,
                 dim_feedforward=6144):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                StreamingTransformerLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dim_feedforward=dim_feedforward) for _ in range(num_layers)
            ]
        )

    def forward(self,
                x,
                cache_position=None,
                cross_attention_src=None):
        # Only the current step flows through, so the positional encoding is
        # taken at the absolute `cache_position`.
        x = x + create_sin_embedding(
            torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position, 1536)
        for lay in self.layers:
            x = lay(x,
                    cross_attention_src=cross_attention_src)
        return x

    def _flush(self,
               n_preserve=None):
        # Drop the KV cache, optionally keeping the first n_preserve positions.
        for lay in self.layers:
            if n_preserve is not None:
                lay.self_attn.k_history = lay.self_attn.k_history[:, :, :n_preserve, :]
                lay.self_attn.v_history = lay.self_attn.v_history[:, :, :n_preserve, :]
            else:
                lay.self_attn.k_history = None
                lay.self_attn.v_history = None
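

# Flush sketch: after every `cache_lim` generation steps, LMModel.generate
# calls _flush(n_preserve=4), truncating each layer's cache to the first four
# positions and resetting cache_position accordingly, which caps memory use.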


if __name__ == '__main__':
    import audiofile

    model = AudioGen().to('cpu')
    x = model.generate(prompt='frogs swim in a lake', duration=6.4).cpu().numpy()
    audiofile.write('_sound_.wav', x, 16000)  # AudioGen medium outputs 16 kHz audio