# asr-model / model_simple.py
import warnings
import os
import logging
from itertools import chain
import torch
from torch import nn, Tensor, einsum
from typing import Optional
import numpy as np
from dataclasses import dataclass
from einops import rearrange
from datasets import load_dataset, Audio
from echoutils import extract_features, setup_tokenizer, compute_metrics, DataCollator, preprocess_logits_for_metrics, sinusoids, get_activation
from datetime import datetime
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)
def there_is_a(val):
return val is not None
@dataclass
class Dimensions:
vocab: int
mels: int
ctx: int
dims: int
head: int
layer: int
act: str
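# Example configuration (mirrors main() below):
#   Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
# giving a per-head dimension of dims // head = 128.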
def qkv_init(dims, head):
head_dim = dims // head
q = nn.Linear(dims, dims)
k = nn.Linear(dims, dims)
v = nn.Linear(dims, dims)
o = nn.Linear(dims, dims)
lna = nn.LayerNorm(dims)
lnb = nn.LayerNorm(dims)
lnc = nn.LayerNorm(head_dim)
lnd = nn.LayerNorm(head_dim)
return q, k, v, o, lna, lnb, lnc, lnd
def shape(dims, head, q, k, v):
batch_size = q.shape[0]
seq_len_q = q.shape[1]
seq_len_kv = k.shape[1]
head_dim = dims // head
q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
return q, k, v
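# shape() splits the model dimension into attention heads:
# (batch, ctx, dims) -> (batch, head, ctx, dims // head); with dims=512 and
# head=4, a (2, 100, 512) tensor becomes (2, 4, 100, 128).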
class rotary(nn.Module):
def __init__(self, dims, head):
super(rotary, self).__init__()
self.dims = dims
self.head = head
self.head_dim = dims // head
        self.theta = nn.Parameter(torch.tensor(10000, device=device, dtype=dtype), requires_grad=True)
self.register_buffer('freqs_base', self._compute_freqs_base(), persistent=False)
def _compute_freqs_base(self):
mel_scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), self.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
return 200 * mel_scale / 1000
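    # The per-channel rotation rates follow a mel curve spanning roughly
    # 200 Hz to 4 kHz (note the 1 + 4000/200 term and the 200/1000 rescale)
    # instead of the usual 10000^(-2i/d) geometric spacing; the learned theta
    # rescales them globally through theta / 220.0 in forward().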
def forward(self, x) -> Tensor:
freqs = (self.theta / 220.0) * self.freqs_base
pos = torch.arange(x.shape[2], device=device, dtype=dtype)
freqs = pos[:, None] * freqs
        freqs = torch.polar(torch.ones_like(freqs), freqs)
x1 = x[..., :freqs.shape[-1]*2]
x2 = x[..., freqs.shape[-1]*2:]
orig_shape = x1.shape
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
x1 = torch.view_as_complex(x1) * freqs
x1 = torch.view_as_real(x1).flatten(-2)
x1 = x1.view(orig_shape)
return torch.cat([x1.type_as(x), x2], dim=-1)
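# rotary.forward applies the rotation to the first (head_dim // 2) * 2 channels
# as complex pairs and passes any leftover channels through unrotated, the
# standard treatment when head_dim is not evenly divisible.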
def calculate_attention(q, k, v, mask=None, temp=1.0, pytorch=True):
    # q, k, v arrive as (batch, head, ctx, head_dim). Temperature rescales the
    # queries: temp > 1 flattens the attention distribution, temp < 1 sharpens it.
    scaled_q = q
    if temp != 1.0 and temp > 0:
        scaled_q = q * (1.0 / temp)**.5
    if pytorch:
        out = torch.nn.functional.scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[2] > 1)
    else:
        scale = q.shape[-1] ** -0.35
        qk = (scaled_q * scale) @ (k * scale).transpose(-1, -2)
        if there_is_a(mask):
            mask = mask[:qk.shape[2], :qk.shape[2]]
            qk = qk.masked_fill(mask.bool(), -torch.inf)
        qk = qk.float()
        w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
        out = w @ v  # keep (batch, head, ctx, head_dim) to match the SDPA branch
    return out
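# Note the unconventional head_dim ** -0.35 score scaling in the manual path
# (also set as self.scale in attentionb._focus below); standard attention uses
# head_dim ** -0.5, so this softens the scores' dependence on head size.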
class LocalOut(nn.Module):
def __init__(self, dims: int, head: int):
super().__init__()
self.head_dim = dims // head
self.dims = dims
self.q_hd = nn.Linear(self.head_dim, self.head_dim)
self.k_hd = nn.Linear(self.head_dim, self.head_dim)
self.v_hd = nn.Linear(self.head_dim, self.head_dim)
self.out = nn.Linear(self.head_dim, self.head_dim)
def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
batch, _, ctx, _ = attn_output.shape
return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
class attentionb(nn.Module):
def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.5, temp = 1.0):
super(attentionb, self).__init__()
self.head = head
self.dims = dims
self.head_dim = dims // head
self.que = nn.Linear(dims, dims, bias=False)
self.kv = nn.Linear(dims, dims * 2, bias=False)
self.out = nn.Linear(dims, dims, bias=False)
self.lna = nn.LayerNorm(dims)
self.lnb = nn.LayerNorm(dims // head)
self.rope = rotary(dims, head)
self.max_iter = max_iter
self.threshold = nn.Parameter(torch.tensor(threshold), requires_grad=True)
self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
self.local = LocalOut(dims, head)
def update_win(self, win_size=None):
if win_size is not None:
self.win_size = win_size
return win_size
elif hasattr(self, 'win_size') and self.win_size is not None:
win_size = self.win_size
return win_size
return None
def _focus(self, x, xa = None, mask = None, win_size=None):
q = self.que(self.lna(x))
k, v = self.kv(self.lna(x if xa is None else xa)).chunk(2, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
self.scale = q.shape[-1] ** -0.35
q = self.rope(q)
k = self.rope(k)
iteration = 0
temp = self.temp.item()
prev_out = torch.zeros_like(q)
attn_out = torch.zeros_like(q)
threshold = self.threshold
        curq = q
while iteration < self.max_iter:
eff_span = curq.shape[2]
if eff_span == 0:
break
qiter = curq[:, :, :eff_span, :]
kiter = k[:, :, :eff_span, :]
viter = v[:, :, :eff_span, :]
            # project into per-head space without clobbering q/k/v, which are
            # reused on later iterations
            q_proj = self.local.q_hd(qiter)
            k_proj = self.local.k_hd(kiter)
            v_proj = self.local.v_hd(viter)
iter_mask = None
if mask is not None:
if mask.dim() == 4:
iter_mask = mask[:, :, :eff_span, :eff_span]
elif mask.dim() == 2:
iter_mask = mask[:eff_span, :eff_span]
            attn_iter = calculate_attention(
                self.lnb(q_proj), self.lnb(k_proj), v_proj,
                mask=iter_mask, temp=temp)
iter_out = torch.zeros_like(curq)
iter_out[:, :, :eff_span, :] = attn_iter
diff = torch.abs(iter_out - prev_out).mean()
if diff < threshold and iteration > 0:
attn_out = iter_out
break
prev_out = iter_out.clone()
curq = curq + iter_out
attn_out = iter_out
iteration += 1
temp -= 0.005
return rearrange(attn_out, 'b h c d -> b c (h d)')
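    # _focus re-attends iteratively: each pass projects the running queries
    # through LocalOut, attends over the same span, adds the result back into
    # the queries, and stops early once successive outputs change by less than
    # the learned threshold; the temperature anneals by 0.005 per pass.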
def _slide_win_local(self, x, mask = None) -> Tensor:
win = self.update_win()
win_size = win if win is not None else self.head_dim
span_len = win_size + win_size // self.head
_, ctx, _ = x.shape
out = torch.zeros_like(x)
windows = (ctx + win_size - 1) // win_size
for i in range(windows):
qstart = i * win_size
qend = min(qstart + win_size, ctx)
qlen = qend - qstart
if qlen == 0:
continue
kstart = max(0, qend - span_len)
qwin = x[:, qstart:qend, :]
kwin = x[:, kstart:qend, :]
win_mask = None
if mask is not None:
if mask.dim() == 4:
win_mask = mask[:, :, qstart:qend, kstart:qend]
elif mask.dim() == 2:
win_mask = mask[qstart:qend, kstart:qend]
attn_out = self._focus(x=qwin, xa=kwin, mask=win_mask, win_size=win_size)
out[:, qstart:qend, :] = attn_out
return out
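    # _slide_win_local tiles the sequence into non-overlapping query windows of
    # win_size positions, each attending to a key span extended backwards by
    # win_size // head positions, so cost scales with the window rather than ctx**2.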
    def forward(self, x, xa = None, mask = None):
        x = self._slide_win_local(x, mask=None)
        xa = self._slide_win_local(xa, mask=None) if xa is not None else None
        out = self._focus(x, xa, mask=None)
        return self.out(out)
def scaled_relu(x, sequence_length):
relu_output = torch.relu(x)
return relu_output / sequence_length
def taylor_softmax(x, order=2, mask=None):
    # Truncated-exponential softmax: exp(x) ~ 1 + x + x**2/2 + ... The optional
    # boolean mask zeroes positions after the polynomial; an -inf fill would
    # blow up under the square term and produce NaNs.
    taylor_approx = torch.ones_like(x)
    for i in range(1, order + 1):
        factorial_i = torch.exp(torch.lgamma(torch.tensor(i + 1, dtype=torch.float32)))
        taylor_approx = taylor_approx + x**i / factorial_i
    if mask is not None:
        taylor_approx = taylor_approx.masked_fill(mask, 0.0)
    return taylor_approx / torch.sum(taylor_approx, dim=-1, keepdim=True).clamp_min(1e-9)
def taylor_softmax_2nd_order(x):
exp_approx = 1 + x + (x**2) / 2
return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True)
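# Worked example for the order-2 approximation: for x = [1, 2, 3] the
# numerators 1 + x + x**2/2 are [2.5, 5.0, 8.5], which normalize to
# [0.15625, 0.3125, 0.53125]. Since 1 + x + x**2/2 >= 0.5 for all real x,
# the weights stay positive without ever computing exp.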
def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
qk_cosine = qk_cosine + mask
    weights = torch.nn.functional.softmax(qk_cosine, dim=-1)
out = torch.matmul(weights, v)
return out
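# Note: cos_sim expects an additive mask (0 / -inf values added to the scores),
# unlike the boolean masked_fill masks used by the attention classes below.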
class attentiona(nn.Module):
def __init__(self, dims: int, head: int, dropout_rate: float = 0.1):
super().__init__()
self.head = head
self.dims = dims
self.que = nn.Linear(dims, dims, bias=False)
self.kv = nn.Linear(dims, dims * 2, bias=False)
self.out = nn.Linear(dims, dims, bias=False)
self.ln = nn.LayerNorm(dims)
self.rope = rotary(dims, head)
def forward(self, x, xa = None, mask = None):
q = self.que(self.ln(x))
k, v = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b c (h d) -> b h c d', h = self.head), (q, k, v))
scale = q.shape[-1] ** -0.5
q = self.rope(q)
k = self.rope(k)
        qk = einsum('b h q d, b h k d -> b h q k', q, k) * scale
        bool_mask = None
        if there_is_a(mask):
            bool_mask = mask[:qk.shape[2], :qk.shape[2]].bool()
        qk = taylor_softmax(qk, order=2, mask=bool_mask) # alternative: torch.nn.functional.softmax(qk, dim=-1)
        wv = einsum('b h q k, b h k d -> b h q d', qk, v)
        wv = rearrange(wv, 'b h c d -> b c (h d)')
out = self.out(wv)
return out
class attentiond(nn.Module):
def __init__(self, dims: int, head: int):
super().__init__()
self.head = head
self.dims = dims
self.que = nn.Linear(dims, dims, bias=False)
self.kv = nn.Linear(dims, dims * 2, bias=False)
self.out = nn.Linear(dims, dims, bias=False)
self.ln = nn.LayerNorm(dims)
self.rope = rotary(dims, head)
self.x = nn.Conv2d(head, head, 1, bias = False)
self.xa = nn.Conv2d(head, head, 1, bias = False)
def forward(self, x, xa = None, mask = None):
qk, v = self.kv(self.ln(x)).chunk(2, dim=-1)
qka, va = self.kv(self.ln(x if xa is None else xa)).chunk(2, dim=-1)
qk, qka, v, va = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.head), (qk, qka, v, va))
qk = einsum('b h q d, b h k d -> b h q k', qk, qka)
if there_is_a(mask):
mask = mask[:qk.shape[2], :qk.shape[2]]
qk = qk.masked_fill(mask.bool(), -torch.inf)
x = qk.softmax(dim = -1)
xa = qk.softmax(dim = -2)
x = self.x(x)
xa = self.xa(xa)
x = einsum('b h i j, b h j d -> b h i d', x, va)
xa = einsum('b h j i, b h j d -> b h i d', xa, v)
x, xa = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (x, xa))
out = self.out(x)
return out
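# attentiond scores x against xa once, then softmaxes along rows (x attending
# to xa) and columns (the reverse), mixes heads with 1x1 convolutions, and
# currently returns only the row-softmax stream; the column stream xa is
# computed but discarded.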
class tgate(nn.Module):
def __init__(self, dims, num_types=4):
super().__init__()
self.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, dims), nn.Sigmoid()) for _ in range(num_types)])
        self.classifier = nn.Sequential(nn.Linear(dims, num_types), nn.Softmax(dim=-1))
def forward(self, x):
types = self.classifier(x)
gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
return cgate
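# tgate is a small mixture of gates: the classifier produces a distribution
# over num_types sigmoid gate networks and their outputs are blended by those
# probabilities.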
class residual(nn.Module):
def __init__(self, dims: int, head: int, act: str = "silu"):
super().__init__()
self.lna = nn.LayerNorm(dims, bias=False)
self.atta = attentiona(dims, head)
self.attb = attentionb(dims, head, max_iter=1)
self.attc = attentiond(dims, head)
self.tgate = tgate(dims, num_types=1)
self.mlp = nn.Sequential(nn.Linear(dims, dims*4), get_activation(act), nn.Linear(dims*4, dims))
def forward(self, x: Tensor, xa = None, mask = None):
out = self.atta(x, mask=mask)
if x.shape == out.shape:
x = x + out
else:
x = out
if xa is not None:
x = x + self.atta(x, xa, mask=None)
x = x + self.tgate(x)
x = x + self.mlp(self.lna(x))
return x
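# Note: residual instantiates attb (attentionb) and attc (attentiond) but its
# forward currently routes everything through atta, keeping the alternatives
# available for experiments.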
class processor(nn.Module):
def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int, act: str = "gelu"):
super(processor, self).__init__()
self.ln = nn.LayerNorm(dims)
self.token = nn.Embedding(vocab, dims)
self.audio = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
self.positions = nn.Parameter(torch.empty(ctx, dims), requires_grad=True)
self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
act_fn = get_activation(act)
self.encoder = nn.Sequential(
nn.Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        self.blocka = nn.ModuleList([residual(dims, head, act) for _ in range(layer)])
        self.blockm = nn.ModuleList([residual(dims, head, act) for _ in range(2)])
        mask = torch.empty(ctx, ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
def forward(self, x, xa, xb, sequential=False, modal=False, blend=False, kv_cache=None) -> Tensor:
if xa.dim() == 2:
xa = xa.unsqueeze(0)
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
x = (self.token(x.long()) + self.positions[offset : offset + x.shape[-1]])
xa = self.encoder(xa).permute(0, 2, 1)
xa = xa + self.audio(xa.shape[1], xa.shape[-1], 36000.0).to(device, dtype)
        y = x  # blend reference; previously undefined on the first non-sequential blend
        for block in self.blocka:
xa = block(xa, mask=None)
x = block(x, mask=self.mask)
x = block(x, xa, mask=None)
if blend:
if sequential:
y = x
else:
a = torch.sigmoid(self.blend)
x = a * x + (1 - a) * y
        for block in self.blockm:
xm = block(torch.cat([x, xa], dim=1), torch.cat([x, xa], dim=1), mask=None) if modal else None
x = block(xm[:, :x.shape[1]], xm[:, x.shape[1]:], mask=None) if modal else x
if blend:
if sequential:
y = x
else:
a = torch.sigmoid(self.blend)
x = a * x + (1 - a) * y
x = nn.functional.dropout(x, p=0.001, training=self.training)
x = self.ln(x)
x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()
return x
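# The final projection ties weights with the token embedding (x @ token.weight.T),
# so the vocabulary logits reuse the input embedding matrix instead of a
# separate LM head.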
class Model(nn.Module):
def __init__(self, param: Dimensions):
super().__init__()
self.param = param
self.processor = processor(
vocab=param.vocab,
mels=param.mels,
ctx=param.ctx,
dims=param.dims,
head=param.head,
layer=param.layer,
act=param.act)
self.best_loss = float('inf')
self.factor = nn.Parameter(torch.tensor(2), requires_grad=False)
def update(self, win_size):
for name, module in self.processor.named_modules():
if isinstance(module, (attentionb)):
module.update_win(win_size)
    def adjust_window(self, loss, ctx):
        self.win_size = ctx // self.param.head
        if loss < self.best_loss:
            win_size = int(self.win_size * self.factor)
        else:
            win_size = max(1, int(self.win_size // self.factor))  # never shrink to zero
        self.win_size = win_size
        self.best_loss = loss
        self.update(win_size)
        return win_size
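    # adjust_window is a multiplicative-increase / multiplicative-decrease
    # schedule on the sliding-attention window: widen by `factor` when the loss
    # improves on the best seen so far, otherwise shrink.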
def forward(self, labels=None, input_ids=None, pitch=None, pitch_tokens=None, spectrogram=None, waveform=None):
x = input_ids
xa = pitch
xb = spectrogram
enc = {}
if spectrogram is not None:
enc["spectrogram"] = spectrogram
if waveform is not None:
enc["waveform"] = waveform
if pitch is not None:
enc["pitch"] = pitch
logits = self.processor(x, xa, xb)
loss = None
if labels is not None:
loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)
self.adjust_window(loss=loss.item(), ctx=xa.shape[1])
return {"logits": logits, "loss": loss}
    def _init_weights(self, module):
        # called once per submodule via self.apply(); updates the shared counters
        if isinstance(module, nn.RMSNorm):
            nn.init.ones_(module.weight)
            self.init_counts["RMSNorm"] += 1
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["LayerNorm"] += 1
        elif isinstance(module, nn.Linear):
            if module.weight is not None:
                nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Linear"] += 1
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv1d"] += 1
        elif isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            self.init_counts["Conv2d"] += 1
        elif isinstance(module, residual):
            self.init_counts["Residual"] += 1
        elif isinstance(module, processor):
            self.init_counts["processor"] += 1
    def init_weights(self):
        print("Initializing model weights...")
        self.init_counts = {
            "Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0,
            "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
        self.apply(self._init_weights)
        print("Initialization summary:")
        for module_type, count in self.init_counts.items():
            if count > 0:
                print(f"{module_type}: {count}")
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.param.ctx:
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
        def install_hooks(layer: nn.Module):
            if isinstance(layer, attentiona):
                # attentiona fuses its key/value projections into self.kv, so the
                # hook caches the fused output rather than separate k/v modules
                hooks.append(layer.kv.register_forward_hook(save_to_cache))
self.processor.apply(install_hooks)
return cache, hooks
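# The hooks follow the Whisper-style kv-cache pattern: each cached projection
# appends its new output along the sequence dimension and resets once the
# cached length would exceed the model context.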
### "pipeline"
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=True, load_saved=False, save_dataset=True, cache_dir='E:/hf', extract_args=None, max_ctx=2048):
    if load_saved:
        if cache_dir is None:
            cache_dir = "./dataset_cache"  # fallback path (hypothetical default)
os.makedirs(cache_dir, exist_ok=True)
cache_file_train = os.path.join(cache_dir, "train.arrow")
cache_file_test = os.path.join(cache_dir, "test.arrow")
if os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
from datasets import Dataset
train_dataset = Dataset.load_from_disk(cache_file_train)
test_dataset = Dataset.load_from_disk(cache_file_test)
return train_dataset, test_dataset
def filter_func(x):
return (0 < len(x["transcription"]) < max_ctx and
len(x["audio"]["array"]) > 0 and
len(x["audio"]["array"]) < max_ctx * 160)
    raw_train = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="train", trust_remote_code=True, streaming=streaming).rename_column("sentence", "transcription")
    raw_test = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="test", trust_remote_code=True, streaming=streaming).rename_column("sentence", "transcription").take(1000)
raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
train_dataset = raw_train.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
test_dataset = raw_test.map(lambda x: extract_features(x, tokenizer, **extract_args)).remove_columns(["audio", "transcription"])
    if save_dataset:
        train_dataset.save_to_disk(cache_file_train)
        test_dataset.save_to_disk(cache_file_test)
return train_dataset, test_dataset
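# prepare_datasets pulls Common Voice 17.0 (English), keeps clips whose
# transcription is shorter than max_ctx characters and whose audio has fewer
# than max_ctx * 160 samples (about 20 s at 16 kHz with the default ctx=2048),
# then maps everything through extract_features from echoutils.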
def main():
token = ""
log_dir = os.path.join('D:/newmodel/output/logs/', datetime.now().strftime('%m-%d_%H_%M_%S'))
os.makedirs(log_dir, exist_ok=True)
tokenizer = setup_tokenizer("D:/newmodel/mod5/tokenizer.json")
extract_args = {
"waveform": True,
"spec": True,
"pitch_tokens": True,
"pitch": True,
"harmonics": False,
"aperiodics": False,
"phase_mod": False,
"crepe": False,
"sample_rate": 16000,
"hop_length": 256,
"mode": "mean",
"debug": False,
}
param = Dimensions(vocab=40000, mels=128, ctx=2048, dims=512, head=4, layer=4, act="swish")
train_dataset, test_dataset = prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
load_saved=False, save_dataset=False, cache_dir=None, extract_args=extract_args, max_ctx=param.ctx)
model = Model(param).to('cuda')
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
from functools import partial
metrics_fn = partial(compute_metrics, print_pred=True, num_samples=1, tokenizer=tokenizer, model=model)
training_args = Seq2SeqTrainingArguments(
output_dir=log_dir,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
max_steps=100000,
eval_steps=1000,
save_steps=1000,
warmup_steps=1000,
logging_steps=100,
logging_dir=log_dir,
logging_strategy="steps",
eval_strategy="steps",
save_strategy="no",
report_to=["tensorboard"],
push_to_hub=False,
save_total_limit=1,
label_names=["labels"],
save_safetensors=False,
eval_on_start=False,
batch_eval_metrics=False,
disable_tqdm=False,
include_tokens_per_second=True,
include_num_input_tokens_seen=True,
learning_rate=0.00025,
weight_decay=0.025,
)
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate, eps=1e-8, weight_decay=training_args.weight_decay, betas=(0.9, 0.999), foreach=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_args.max_steps, eta_min=1e-9, last_epoch=-1)
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=train_dataset,
eval_dataset=test_dataset,
data_collator=DataCollator(tokenizer=tokenizer),
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
compute_metrics=metrics_fn,
optimizers=(optimizer, scheduler)
)
model.init_weights()
trainer.train()
if __name__ == "__main__":
main()