# audiogen2 / audiocraft.py
# (Hugging Face file-view header: commit 0a4b027 "f" by Dionyssos)
import torch
from torch import nn
import torch.nn.functional as F
from omegaconf import OmegaConf
import numpy as np
from huggingface_hub import hf_hub_download
import os
from torch.nn.utils import weight_norm
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
from einops import rearrange
# Global, import-time setting: allow PyTorch's memory-efficient
# scaled-dot-product-attention kernel on the CUDA backend.
# NOTE(review): generation in this file runs on CPU, where this flag presumably has no effect — confirm.
torch.backends.cuda.enable_mem_efficient_sdp(True)
# Number of copies of the prompt (and of the empty null-condition prompt) that
# AudioGen.generate puts in the batch, i.e. how many clips one call produces.
N_REPEAT = 2 # num (virtual batch_size) clones of audio sounds
def _shift(x):
    """Circularly shift every batch item of ``x`` by its own random offset.

    Each item ``x[i]`` is rolled along the time axis by an offset drawn from
    ``[int(.24 * n), max(1, int(.74 * n)))`` with ``n = x.shape[2]``. This keeps
    the clones generated from one prompt from all starting at the same point.
    Offsets come from NumPy's global RNG, which ``torch.manual_seed`` does not
    seed.

    Args:
        x: 3-D tensor ``[batch, channels, time]``; it is modified in place.

    Returns:
        The same tensor object ``x``, after shifting.
    """
    n = x.shape[2]
    # Explicit integer bounds. np.random.randint truncated the former float
    # bounds, so the sampled range is unchanged; `low < high` holds for all n >= 0.
    low = int(.24 * n)
    high = max(1, int(.74 * n))
    for i, _slice in enumerate(x):
        offset = np.random.randint(low, high)
        # torch.roll returns a new tensor, so writing it back into the view is safe.
        x[i, :, :] = torch.roll(_slice, offset, dims=1)  # _slice is 2-D [channels, time]
    return x
class AudioGen(torch.nn.Module):
    """Text-to-audio generator built from the ``facebook/audiogen-medium`` checkpoints.

    ``compression_model`` (EncodecModel) decodes discrete codes to a waveform;
    ``lm`` (LMModel) samples those codes from a text prompt.
    """
    # https://huggingface.co/facebook/audiogen-medium
    def __init__(self):
        """Download both checkpoints from the Hugging Face Hub and load the models.

        Downloads are cached in ``$AUDIOCRAFT_CACHE_DIR`` when that environment
        variable is set (otherwise the hub's default cache is used).
        """
        super().__init__()
        # Codec checkpoint.
        _file_1 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="compression_state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__)
        # NOTE(review): torch.load may unpickle arbitrary objects depending on the
        # torch version / weights_only default; only load trusted checkpoints.
        pkg = torch.load(_file_1, map_location='cpu')# kwargs = OmegaConf.create(pkg['xp.cfg'])
        self.compression_model = EncodecModel()
        # strict=False: EncodecModel defines only decoder + quantizer, while the
        # checkpoint also contains encoder weights that are left unused.
        self.compression_model.load_state_dict(pkg['best_state'], strict=False)
        self.compression_model.eval() # ckpt has also unused encoder weights
        # Stride (in codec frames) between decode windows in generate().
        self._chunk_len = 476
        # Language-model checkpoint (also carries the training config under 'xp.cfg').
        _file_2 = hf_hub_download(
            repo_id='facebook/audiogen-medium',
            filename="state_dict.bin",
            cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
            library_name="audiocraft",
            library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__)
        pkg = torch.load(_file_2, map_location='cpu')
        cfg = OmegaConf.create(pkg['xp.cfg']) # CFG inside torch bin (NOTE(review): cfg is never used)
        _best = pkg['best_state']
        # Rename the text-conditioner projection keys so they match LMModel.t5.output_proj.
        # _best is the same dict object as pkg['best_state'], so the strict load
        # below sees the renamed keys.
        _best['t5.output_proj.weight'] = _best.pop('condition_provider.conditioners.description.output_proj.weight')#.to(torch.float)
        _best['t5.output_proj.bias'] = _best.pop('condition_provider.conditioners.description.output_proj.bias')#.to(torch.float)
        self.lm = LMModel()
        self.lm.load_state_dict(pkg['best_state'], strict=True)
        self.lm.eval()
    @torch.no_grad()
    def generate(self,
                 prompt='dogs mewo',
                 duration=2.24, # seconds of audio
                 cache_lim=71, # flush kv cache after cache_lim tok
                 ):
        """Generate audio clips for ``prompt``.

        Args:
            prompt: text description of the sound.
            duration: scales both the number of LM steps (``max_tokens``) and
                ``self.lm.n_draw``. NOTE(review): the resulting audio length is
                not simply ``duration`` seconds — verify against the formulas below.
            cache_lim: stored on ``self.lm.cache_lim``; the LM truncates its
                self-attention kv cache every ``cache_lim`` steps.

        Returns:
            1-D tensor: the ``N_REPEAT`` decoded clips, each circularly shifted
            by a random offset, flattened end to end.
        """
        # Seeds torch's RNG so LM sampling repeats for identical inputs; _shift
        # draws from NumPy's global RNG, which this does not seed.
        torch.manual_seed(42) # https://github.com/facebookresearch/audiocraft/issues/111#issuecomment-1614732858
        self.lm.cache_lim = cache_lim
        # Number of draws (each with its own CFG scale) sampled per LM step; the LM
        # concatenates the draws along time.
        self.lm.n_draw = int(.8 * duration) + 1
        # LM runs under CPU bfloat16 autocast. The batch is N_REPEAT prompts followed
        # by N_REPEAT empty prompts (null condition for classifier-free guidance).
        with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
            gen_tokens = self.lm.generate(
                text_condition=[prompt] * N_REPEAT + [''] * N_REPEAT,#['dogs', 'dogs...!', '', '']
                max_tokens=int(.04 * duration / N_REPEAT * self.compression_model.frame_rate) + 12) # [bs, 4, n_draw * max_tokens]
        # OOM if vocode all tokens: decode in windows instead.
        # Window i covers frames [i-7, i+_chunk_len). Consecutive windows start
        # _chunk_len apart, so they overlap by 7 frames and that overlapping audio
        # is concatenated twice. NOTE(review): confirm the overlap is intended.
        x = []
        for i in range(7, gen_tokens.shape[2], self._chunk_len): # min soundscape 2s assures 10 tokens
            decoded_chunk = self.compression_model.decode(gen_tokens[:, :, i-7:i+self._chunk_len])
            x.append(decoded_chunk)
        x = torch.cat(x, 2) # [bs, 1, samples]
        x = _shift(x) # rolls each clip in place by its own random offset
        return x.reshape(-1) # (peak normalisation x / (x.abs().max() + 1e-7) is disabled)
class EncodecModel(nn.Module):
    """Decode-only EnCodec: turns discrete codec tokens back into audio.

    Only the residual quantizer (codebook lookup) and the SEANet decoder are
    built; no encoder is defined here.
    """

    def __init__(self):
        super().__init__()
        self.decoder = SEANetDecoder()
        self.quantizer = ResidualVectorQuantizer()
        # Codec frames per second; AudioGen.generate uses it to size the LM output.
        self.frame_rate = 50

    def decode(self, codes):
        """Map codes ``[B, K, T]`` to audio ``[B, C, T']`` (codebook lookup, then decoder)."""
        return self.decoder(self.quantizer.decode(codes))
class StreamableLSTM(nn.Module):
    """Stacked LSTM over channel-first sequences with an optional residual skip.

    Input and output are ``[batch, dimension, time]``. The tensor is permuted
    to time-major ``[time, batch, dimension]`` for ``nn.LSTM`` and back again.
    With ``skip=True`` the LSTM input is added to its output.
    """

    def __init__(self,
                 dimension,
                 num_layers=2,
                 skip=True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        seq = x.permute(2, 0, 1)  # [B, D, T] -> [T, B, D]
        out = self.lstm(seq)[0]   # discard the final (h, c) state
        if self.skip:
            out = out + seq
        return out.permute(1, 2, 0)  # [T, B, D] -> [B, D, T]
class SEANetResnetBlock(nn.Module):
    """Residual block computing ``x + block(x)``, where ``block`` is a stack of (ELU, conv) pairs.

    There is one (ELU, StreamableConv1d) pair per entry of ``kernel_sizes``.
    The first conv maps ``dim -> dim // compress``, the last maps back to
    ``dim``, and any intermediate convs stay at ``dim // compress``.
    StreamableConv1d pads to keep the time length, so the output shape equals
    the input shape.

    Args:
        dim: number of input/output channels.
        kernel_sizes: kernel size of each conv, in order. The default is an
            immutable tuple so it is never shared mutable state between calls.
        pad_mode: padding mode forwarded to every StreamableConv1d.
        compress: channel reduction factor for the hidden convs.
    """
    def __init__(self,
                 dim,
                 kernel_sizes=(3, 1),
                 pad_mode='reflect',
                 compress=2):
        super().__init__()
        hidden = dim // compress
        last = len(kernel_sizes) - 1
        block = []
        for i, kernel_size in enumerate(kernel_sizes):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == last else hidden
            block += [nn.ELU(),
                      StreamableConv1d(in_chs,
                                       out_chs,
                                       kernel_size=kernel_size,
                                       pad_mode=pad_mode)]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        return x + self.block(x)
class SEANetDecoder(nn.Module):
    """SEANet waveform decoder: latent ``[B, dimension, T]`` -> audio ``[B, channels, T']``.

    Layers, in order (the order fixes the ``model.<index>`` state-dict keys):

    1. an input conv ``dimension -> 2**len(ratios) * n_filters``;
    2. an optional ``lstm``-layer StreamableLSTM;
    3. for every ``ratio``: ELU, a transposed conv (kernel ``2 * ratio``,
       stride ``ratio``) that halves the channel count and upsamples time by
       ``ratio``, then ``n_residual_layers`` SEANetResnetBlocks;
    4. ELU and a final conv ``n_filters -> channels``.

    ``ratios`` defaults to an immutable tuple, so the default is never shared
    mutable state.
    """
    # The defaults below match this reference configuration:
    # channels=1 dimension=128 n_filters=64 n_residual_layers=1 ratios=[8, 5, 4, 2]
    # activation='ELU' activation_params={'alpha': 1.0}, final_activation=None
    # final_activation_params=None norm='weight_norm'
    # norm_params={} kernel_size=7 last_kernel_size=7 residual_kernel_size=3 dilation_base=2
    # causal=False pad_mode='constant'
    # true_skip=True compress=2 lstm=2 disable_norm_outer_blocks=0 trim_right_ratio=1.0
    def __init__(self,
                 channels=1,
                 dimension=128,
                 n_filters=64,
                 n_residual_layers=1,
                 ratios=(8, 5, 4, 2),
                 kernel_size=7,
                 last_kernel_size=7,
                 residual_kernel_size=3,
                 pad_mode='constant',
                 compress=2,
                 lstm=2):
        super().__init__()
        mult = int(2 ** len(ratios))
        model = [
            StreamableConv1d(dimension, mult * n_filters,
                             kernel_size,
                             pad_mode=pad_mode)
        ]
        # (A debug print that fired on every construction was removed here.)
        if lstm:
            model += [StreamableLSTM(mult * n_filters,
                                     num_layers=lstm)]
        # Upsample to raw audio scale
        for ratio in ratios:
            model += [
                nn.ELU(),
                StreamableConvTranspose1d(mult * n_filters,
                                          mult * n_filters // 2,
                                          kernel_size=ratio * 2,
                                          stride=ratio),
            ]
            # Add residual layers
            for _ in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2,
                                      kernel_sizes=[residual_kernel_size, 1],
                                      pad_mode=pad_mode,
                                      compress=compress)]
            mult //= 2
        # Add final layers (after all halvings the width is back to n_filters)
        model += [
            nn.ELU(),
            StreamableConv1d(n_filters,
                             channels,
                             last_kernel_size,
                             pad_mode=pad_mode)]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        return self.model(z)
def unpad1d(x, paddings):
    """Drop ``paddings = (left, right)`` samples from the two ends of the last axis of ``x``.

    Returns a slice (view) of ``x``; all leading axes are kept unchanged.
    """
    left, right = paddings
    return x[..., left:x.shape[-1] - right]
class NormConv1d(nn.Module):
    """``nn.Conv1d`` with weight normalization.

    All arguments are forwarded to ``nn.Conv1d``. The wrapped layer is stored
    as ``self.conv``, so its parameters appear in the state dict as
    ``conv.weight_g``, ``conv.weight_v`` and ``conv.bias``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        plain_conv = nn.Conv1d(*args, **kwargs)
        self.conv = weight_norm(plain_conv)  # norm = weight_norm

    def forward(self, x):
        out = self.conv(x)
        return out
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` with weight normalization.

    ``causal``, ``norm`` and ``norm_kwargs`` are accepted only for signature
    compatibility and are ignored; weight normalization is always applied.
    ``norm_kwargs`` defaults to ``None`` instead of a shared mutable ``{}``.
    All other positional/keyword arguments go to ``nn.ConvTranspose1d``. The
    wrapped layer is stored as ``self.convtr``.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs=None, **kwargs):
        super().__init__()
        self.convtr = weight_norm(nn.ConvTranspose1d(*args, **kwargs))

    def forward(self, x):
        return self.convtr(x)
class StreamableConv1d(nn.Module):
    """Weight-normalized 1-D convolution that preserves the time length.

    Only ``stride=1`` and ``groups=1`` are supported. The input is padded with
    ``F.pad(..., mode=pad_mode)`` by ``effective_kernel - 1`` samples in total.
    The extra sample, if any, goes on the left. With stride 1 the output
    therefore has the same length as the input.

    Raises:
        ValueError: if ``stride != 1`` or ``groups != 1``.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 bias=True,
                 pad_mode='reflect'):
        super().__init__()
        if (stride != 1) or (groups != 1):
            raise ValueError(
                f'StreamableConv1d supports only stride=1 and groups=1 '
                f'(got stride={stride}, groups={groups})')
        self.conv = NormConv1d(in_channels,
                               out_channels,
                               kernel_size,
                               stride,
                               groups=groups,
                               bias=bias)
        self.pad_mode = pad_mode

    def forward(self, x):
        conv = self.conv.conv
        # Effective (dilated) kernel extent.
        kernel_size = (conv.kernel_size[0] - 1) * conv.dilation[0] + 1
        padding_total = kernel_size - conv.stride[0]
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        x = F.pad(x, (padding_left, padding_right), self.pad_mode)
        return self.conv(x)
class StreamableConvTranspose1d(nn.Module):
    """Weight-normalized transposed 1-D convolution trimmed to ``stride * T`` samples.

    The transposed conv yields ``(T - 1) * stride + kernel_size`` samples.
    ``kernel_size - stride`` of them are removed by ``unpad1d``, with the extra
    sample (for odd totals) taken from the left.

    ``causal``, ``norm``, ``trim_right_ratio`` and ``norm_kwargs`` are accepted
    for signature compatibility and are ignored. ``norm_kwargs`` defaults to
    ``None`` instead of a shared mutable ``{}``.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs=None):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels,
                                          out_channels,
                                          kernel_size,
                                          stride)

    def forward(self, x):
        conv = self.convtr.convtr
        padding_total = conv.kernel_size[0] - conv.stride[0]
        y = self.convtr(x)
        # Asymmetric padding required for odd strides
        padding_right = padding_total // 2
        padding_left = padding_total - padding_right
        return unpad1d(y, (padding_left, padding_right))
# VQ
class EuclideanCodebook(nn.Module):
    """Holds one codebook as an ``embed`` buffer of shape ``[codebook_size, dim]``.

    The buffer is zero-initialised. Presumably it is filled from a checkpoint
    via ``load_state_dict``.
    """
    def __init__(self,
                 dim,
                 codebook_size):
        super().__init__()
        table = torch.zeros((codebook_size, dim))
        self.register_buffer("embed", table)
class VectorQuantization(nn.Module):
    """Single codebook lookup that turns integer code indices into embedding vectors."""
    def __init__(self,
                 dim,
                 codebook_size):
        super().__init__()
        self._codebook = EuclideanCodebook(dim=dim, codebook_size=codebook_size)

    def decode(self, _ind):
        """Return the codebook rows selected by ``_ind``: output shape is ``_ind.shape + (dim,)``."""
        table = self._codebook.embed
        return F.embedding(_ind, table)
class ResidualVectorQuantization(nn.Module):
    """Stack of ``num_quantizers`` codebooks whose decoded vectors are summed.

    Extra keyword arguments (``dim``, ``codebook_size``) are passed to every
    VectorQuantization layer.
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(VectorQuantization(**kwargs) for _ in range(num_quantizers))

    def decode(self, _ind):
        """Decode codebook-major codes ``[K, B, T]`` into ``[B, dim, T]``.

        Row ``k`` of ``_ind`` is looked up in ``self.layers[k]``, and the K
        lookups are summed from left to right.
        """
        summed = sum((self.layers[k].decode(codes) for k, codes in enumerate(_ind)), 0.0)
        return summed.transpose(1, 2)
class ResidualVectorQuantizer(nn.Module):
    """Decode-only residual vector quantizer that accepts batch-first codes.

    Reference configuration (only dimension, n_q and bins are used here):
    dimension=128 n_q=4 q_dropout=False bins=2048 decay=0.99 kmeans_init=True
    kmeans_iters=50 threshold_ema_dead_code=2
    orthogonal_reg_weight=0.0 orthogonal_reg_active_codes_only=False
    orthogonal_reg_max_codes=None
    """
    def __init__(self,
                 dimension=128,
                 n_q=4,
                 bins=2048):
        super().__init__()
        self.vq = ResidualVectorQuantization(num_quantizers=n_q,
                                             dim=dimension,
                                             codebook_size=bins)

    def decode(self, codes):
        """Decode ``codes`` ``[B, K, T]`` (K codebooks, T frames) into ``[B, dimension, T]``."""
        # The inner quantizer iterates over codebooks first, so put K in front.
        per_codebook = codes.transpose(0, 1)
        return self.vq.decode(per_codebook)
class T5(nn.Module):
    """Text conditioner: a frozen T5-large encoder followed by a linear projection to the LM width.

    Only ``output_proj`` is a registered submodule; AudioGen loads its weights
    from the LM checkpoint. The encoder is deliberately kept out of the module
    tree (see ``__init__``).
    """
    def __init__(self):
        super().__init__()
        # 1024 = T5-large hidden size -> 1536 = LM hidden size.
        self.output_proj = nn.Linear(1024, # t5-large
                                     1536) # lm hidden
        # from_pretrained may download 't5-large' on first use.
        self.t5_tokenizer = T5Tokenizer.from_pretrained('t5-large', legacy=True)
        # train(mode=False) puts the encoder in eval mode.
        t5 = T5EncoderModel.from_pretrained('t5-large').train(mode=False)
        # this makes sure that the t5 is not part
        # of the saved checkpoint:
        # writing to __dict__ bypasses nn.Module.__setattr__, so the encoder is not
        # registered as a child. It is absent from state_dict()/parameters() and is
        # not moved by .to() on this module, so it stays on CPU.
        self.__dict__['t5'] = t5.to('cpu')
    def forward(self, prompt):
        """Encode prompts into cross-attention conditioning for the LM.

        Args:
            prompt: list of strings. The first ``len(prompt) // 2`` are the real
                prompts; the remaining ones are treated as null conditions.

        Returns:
            Tensor ``[len(prompt), seq_len, 1536]``. Rows from index
            ``len(prompt) // 2`` onward (the null conditions) are all zeros.
        """
        with torch.set_grad_enabled(False): #, torch.autocast(device_type='cpu', dtype=torch.float32):
            bs = len(prompt) // 2
            # Pad every prompt to the longest one in the batch.
            # NOTE(review): inputs are moved to output_proj's device while the encoder
            # stays on CPU, so this assumes output_proj is on CPU too — confirm.
            d = self.t5_tokenizer(prompt,
                                  return_tensors='pt',
                                  padding=True).to(self.output_proj.bias.device)
            d['attention_mask'][bs:, :] = 0 # null condition t5 attn_mask should be zero
            x = self.t5(input_ids=d['input_ids'],
                        attention_mask=d['attention_mask']).last_hidden_state # no kv
            # Precision: the encoder runs outside any autocast set here, but
            # output_proj runs inside the caller's autocast (bfloat16 in
            # AudioGen.generate).
            x = self.output_proj(x) # nn.Linear() - produces different result if there is no duplicate txt condition here
            # Zero the null-condition half entirely (mirrors audiocraft's conditioners.py tokenize()).
            x[bs:, :, :] = 0 # venv/../site-packages/audiocraft/modules/conditioners.py -> tokenize()
            return x
class LMModel(nn.Module):
    """Text-conditioned autoregressive LM over parallel codebook streams.

    Codebooks follow a delay pattern: row ``r`` lags row 0 by ``r`` steps.
    ``card`` is both the codebook size and the id of the special "empty" token
    used for delay padding. That is why each embedding has ``card + 1`` rows
    while each output head has ``card`` logits.

    Sampling uses classifier-free guidance, ``3 * cond - scale * uncond``.
    ``n_draw`` tokens are drawn per step, each with its own random ``scale``.

    NOTE: ``generate`` assumes exactly 4 codebooks (the default ``n_q``).
    """
    def __init__(self,
                 n_q=4,
                 card=2048,
                 dim=1536
                 ):
        super().__init__()
        # generate() truncates the self-attention kv cache every `cache_lim` steps.
        # A value <= 0 (the default -1) disables truncation; AudioGen.generate sets it.
        self.cache_lim = -1
        self.t5 = T5()
        self.card = card  # 2048; also the id of the special "empty" token
        self.n_draw = 1  # draw > 1 tokens of different CFG scale
        # batch size > 1 is slower from n_draw as calls transformer on larger batch
        # card + 1 rows: the extra row embeds the special token `card`.
        self.emb = nn.ModuleList([nn.Embedding(self.card + 1, dim) for _ in range(n_q)])
        self.transformer = StreamingTransformer()
        self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        # Heads predict only the `card` real codes, never the special token.
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=False) for _ in range(n_q)])

    def forward(self,
                sequence,
                condition_tensors=None,
                cache_position=None
                ):
        """Run one decoding step and sample the next token of every codebook.

        Args:
            sequence: ``[bs, n_q, 1]`` token ids (one column of the delayed grid).
            condition_tensors: text conditioning for ``2 * bs`` rows
                (conditional rows first, then null rows).
            cache_position: position used for the sinusoidal embedding.

        Returns:
            ``[bs, n_draw, n_q]`` sampled token ids. Requires ``self._scale``,
            which ``generate`` sets.
        """
        bs, n_q, _ = sequence.shape  # [bs, 4, time]
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(n_q)])
        # The same audio context is run against the text condition and the null condition.
        out = self.transformer(torch.cat([input_, input_], 0),
                               cross_attention_src=condition_tensors,
                               cache_position=cache_position)
        out = self.out_norm(out)
        logits = torch.stack([self.linears[k](out) for k in range(n_q)], dim=1)  # [2*bs, n_q, 1, card]
        # Broadcasting the [1, 1, n_draw, 1] scale yields one guided logit set per draw.
        logits = 3 * logits[:bs, :, :, :] - self._scale * logits[bs:, :, :, :]  # [bs, n_q, n_draw, card]
        tokens = torch.multinomial(torch.softmax(logits.view(bs * self.n_draw * n_q, self.card), dim=1),
                                   num_samples=1)
        return tokens.view(bs, n_q, self.n_draw).transpose(1, 2)

    @torch.no_grad()
    def generate(self,
                 max_tokens=None,
                 text_condition=None
                 ):
        """Sample ``max_tokens`` codec frames for each of ``n_draw`` draws.

        Args:
            max_tokens: number of frames to generate per draw.
            text_condition: prompts for ``self.t5``. The second half are the null
                (empty) conditions used for classifier-free guidance.

        Returns:
            Long tensor ``[bs, 4, n_draw * max_tokens]``, where ``bs`` is half the
            number of prompts. The draws of each item are concatenated along time.

        ``out_codes`` is stored un-delayed: column ``4 + t`` holds frame ``t`` of
        every row. The delay pattern comes from reading the LM input along one
        anti-diagonal and writing the prediction to the next one. Only draw 0's
        history is fed back to the LM; every draw is sampled from that shared
        history. The kv cache is reset before returning.
        """
        x = self.t5(text_condition)
        bs = x.shape[0] // 2  # the second half of the batch is the null condition
        # One CFG scale per draw, uniform in [1.94, 2.24).
        self._scale = .3 * torch.rand(1, 1, self.n_draw, 1, device=x.device) + 1.94
        rows = [0, 1, 2, 3]
        delay = torch.tensor([3, 2, 1, 0])  # anti-diagonal column offset of each row
        cache_position = 0
        # 4 leading padding columns, max_tokens frames, 3 trailing columns for the
        # delayed tail. Everything starts as the special token `card`.
        out_codes = torch.full((bs,
                                self.n_draw,
                                4,
                                4 + 3 + max_tokens),
                               self.card,
                               dtype=torch.long,
                               device=x.device)  # [bs, n_draw, 4, dur]
        for offset in range(0, max_tokens + 4 - 1):  # max_tokens + n_q - 1 steps
            # Row r reads column 3 - r + offset of draw 0 -> LM input [bs, 4, 1].
            next_token = self.forward(out_codes[:, 0, rows, delay + offset][:, :, None],
                                      condition_tensors=x,
                                      cache_position=cache_position)  # [bs, n_draw, 4]
            # Row r holds a real frame only for offsets r .. max_tokens - 1 + r.
            # Outside that span its slot on the anti-diagonal must keep `card`.
            # The two checks are independent, so short generations (max_tokens < 3)
            # mask both ends correctly.
            if offset < 3:
                next_token[:, :, offset + 1:4] = self.card
            if offset >= max_tokens:
                next_token[:, :, 0:offset - max_tokens + 1] = self.card
            out_codes[:, :, rows, delay + offset + 1] = next_token
            # Sink Attn: every `cache_lim` steps keep only the first n_preserve cached
            # positions. A non-positive cache_lim disables this; previously -1 flushed
            # on every step and 0 raised ZeroDivisionError.
            if self.cache_lim > 0 and offset > 0 and offset % self.cache_lim == 0:
                n_preserve = 4
                self.transformer._flush(n_preserve=n_preserve)
                cache_position = n_preserve
            else:
                cache_position += 1
        # [bs, n_draw, 4, time+xtra] -> [bs, 4, n_draw, time] -> [bs, 4, n_draw * time]
        out_codes = out_codes[:, :, :, 4:max_tokens + 4].transpose(1, 2).reshape(bs, 4, self.n_draw * max_tokens)
        # flush for next API call
        self.transformer._flush()
        return out_codes
def create_sin_embedding(positions,
                         dim,
                         max_period=10000
                         ):
    """Sinusoidal embedding of ``positions``: cosines fill the first half of the last axis, sines the second.

    Frequency ``i`` (for ``i`` in ``[0, dim // 2)``) is
    ``max_period ** (-i / (dim // 2 - 1))``. ``positions`` is cast to float32
    and should end in a size-1 axis (e.g. ``[B, T, 1]``) so that it broadcasts
    against the ``[1, 1, dim // 2]`` frequency grid.
    """
    half = dim // 2
    pos = positions.to(torch.float)
    dev = pos.device
    exponent = torch.arange(half, device=dev, dtype=torch.float).view(1, 1, -1) / (half - 1)
    # A 0-dim tensor rather than a Python scalar, to avoid a sync point.
    period = torch.full([], max_period, device=dev, dtype=torch.float)
    angle = pos / period ** exponent
    # OFFICIAL is torch.float32 HOWEVER self_attn.in_prod_weight = torch.float16
    return torch.cat((angle.cos(), angle.sin()), dim=-1)
class StreamingMultiheadAttention(nn.Module):
    """Bias-free multi-head attention with an append-only key/value cache.

    * ``cross_attention=True``: queries come from ``query`` and keys/values
      from ``key``/``value`` (the text conditioning); nothing is cached.
    * ``cross_attention=False``: self-attention on ``query``. Each call's keys
      and values are appended along the time axis to ``k_history`` /
      ``v_history``, and attention runs over the whole history without a mask.
      Each call should therefore pass only new time step(s). Callers reset or
      truncate the history by assigning to those attributes.

    ``in_proj_weight`` (``[3 * embed_dim, embed_dim]``, rows ordered q, k, v) is
    a buffer initialised to ones, presumably overwritten by ``load_state_dict``.
    """
    def __init__(self,
                 embed_dim,
                 num_heads,
                 cross_attention=False,
                 ):
        super().__init__()
        self.cross_attention = cross_attention
        # Self-attention kv cache; None until the first call (or after a reset).
        self.k_history = None
        self.v_history = None
        self.num_heads = num_heads
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
                                                          dtype=torch.float))

    def forward(self,
                query,
                key=None,
                value=None):
        layout = "b h t d"
        heads = self.num_heads
        weight = self.in_proj_weight
        if not self.cross_attention:
            # Self-attention: one packed projection, split into q and the new k / v step.
            bound_layout = "b h p t d"
            packed = rearrange(nn.functional.linear(query, weight, None),
                               f"b t (p h d) -> {bound_layout}", p=3, h=heads)
            q, k_step, v_step = packed.unbind(dim=2)
            if self.k_history is None:
                self.k_history = k_step
                self.v_history = v_step
            else:
                # NOTE: the two updates are not atomic; an interrupt between them
                # leaves k/v histories of different lengths.
                self.k_history = torch.cat([self.k_history, k_step], 2)
                self.v_history = torch.cat([self.v_history, v_step], 2)
            k, v = self.k_history, self.v_history
        else:
            # Cross-attention: separate q / k / v slices of the packed weight.
            span = weight.shape[0] // 3
            q = nn.functional.linear(query, weight[:span])
            k = nn.functional.linear(key, weight[span: 2 * span])
            v = nn.functional.linear(value, weight[2 * span:])
            q, k, v = [rearrange(t, f"b t (h d) -> {layout}", h=heads) for t in (q, k, v)]
        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False, dropout_p=0.0)
        merged = rearrange(attended, f"{layout} -> b t (h d)", h=heads)
        return self.out_proj(merged)
class StreamingTransformerLayer(nn.Module):
    """Pre-norm transformer layer: cached self-attention, text cross-attention, then a GELU MLP.

    Each sub-layer reads a LayerNorm-ed copy of the running activations and
    its output is added back residually.
    """
    def __init__(self,
                 d_model,
                 num_heads,
                 dim_feedforward):
        super().__init__()
        self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                     num_heads=num_heads)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False)
        self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                           num_heads=num_heads,
                                                           cross_attention=True)
        # Registered in the same order as before: norm_cross, norm1, norm2.
        self.norm_cross, self.norm1, self.norm2 = (nn.LayerNorm(d_model, eps=1e-5) for _ in range(3))

    def forward(self,
                x,
                cross_attention_src=None):
        h = x + self.self_attn(self.norm1(x))
        txt = cross_attention_src  # text conditioning serves as both keys and values
        h = h + self.cross_attention(query=self.norm_cross(h), key=txt, value=txt)
        mlp = self.linear2(F.gelu(self.linear1(self.norm2(h))))
        return h + mlp
class StreamingTransformer(nn.Module):
    """Stack of ``num_layers`` StreamingTransformerLayers with a sinusoidal position offset.

    Each layer's self-attention keeps a key/value history across calls.
    ``_flush`` resets or truncates that history.
    """
    def __init__(self,
                 d_model=1536,
                 num_heads=24,
                 num_layers=48,
                 dim_feedforward=6144):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                StreamingTransformerLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dim_feedforward=dim_feedforward) for _ in range(num_layers)
            ]
        )

    def forward(self,
                x,
                cache_position=None,
                cross_attention_src=None):
        """Add the sinusoidal embedding of ``cache_position`` to ``x`` and run every layer.

        Every time step in ``x`` receives the same position, so callers pass one
        step at a time. ``cache_position=None`` is treated as position 0 (it
        previously raised a TypeError). The embedding width follows
        ``x.shape[-1]`` instead of a hard-coded 1536, so non-default
        ``d_model`` values work.
        """
        if cache_position is None:
            cache_position = 0
        positions = torch.zeros(x.shape[0], 1, 1, device=x.device) + cache_position
        x = x + create_sin_embedding(positions, x.shape[-1])
        for lay in self.layers:
            x = lay(x,
                    cross_attention_src=cross_attention_src)
        return x

    def _flush(self,
               n_preserve=None):
        """Reset (``n_preserve=None``) or truncate to the first ``n_preserve`` positions every layer's kv cache.

        Truncating an empty cache is a no-op; previously it raised a TypeError.
        """
        for lay in self.layers:
            attn = lay.self_attn
            if n_preserve is None:
                attn.k_history = None
                attn.v_history = None
            elif attn.k_history is not None:
                # cache position is difficult to choose to also preserve kv from end
                attn.k_history = attn.k_history[:, :, :n_preserve, :]
                attn.v_history = attn.v_history[:, :, :n_preserve, :]
if __name__ == '__main__':
    # Demo: generate clips for a fixed prompt on CPU and save them as a 16 kHz WAV file.
    import audiofile

    model = AudioGen().to('cpu')
    waveform = model.generate(prompt='swims in lake frogs', duration=6.4)
    audiofile.write('_sound_.wav', waveform.cpu().numpy(), 16000)