|
--- |
|
license: apache-2.0 |
|
pipeline_tag: text-generation |
|
tags: |
|
- model_hub_mixin |
|
- pytorch_model_hub_mixin |
|
- RxNN |
|
- SparseQueryAttention |
|
- SQA |
|
language: |
|
- en |
|
datasets: |
|
- wikimedia/wikipedia |
|
library_name: RxNN |
|
--- |
|
|
|
# SQAT-m: Sparse Query Attention Transformer mini |
|
Research model for [Sparse Query Attention](https://github.com/RxAI-dev/RxNN/blob/main/docs/research/sparse_query_attention.md) experiments - an extension of Grouped Query Attention that also reduces the number of query heads used, instead of further reducing the key/value head count (down to Multi Query Attention). This approach results in a large reduction in computational complexity and much faster training, while performance stays at the GQA level (an almost unnoticeable decrease compared to GQA, and noticeably better than MQA).
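
For intuition, here is a minimal PyTorch sketch of the idea under this model's configuration (dim 256, 16 heads for the dimension split, 8 query groups, 4 key/value groups). It is a simplified, bias-free illustration, not the RxNN implementation, and omits RoPE:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseQueryAttentionSketch(nn.Module):
    """Simplified SQA sketch: fewer query heads than the dimension split, GQA-style K/V sharing."""
    def __init__(self, dim=256, num_heads=16, query_groups=8, kv_groups=4):
        super().__init__()
        self.head_dim = dim // num_heads                                          # 16
        self.query_groups = query_groups                                          # 8 query heads used
        self.kv_groups = kv_groups                                                # 4 key/value heads
        self.q_proj = nn.Linear(dim, query_groups * self.head_dim, bias=False)    # 256 -> 128
        self.k_proj = nn.Linear(dim, kv_groups * self.head_dim, bias=False)       # 256 -> 64
        self.v_proj = nn.Linear(dim, kv_groups * self.head_dim, bias=False)       # 256 -> 64
        self.out_proj = nn.Linear(query_groups * self.head_dim, dim, bias=False)  # 128 -> 256

    def forward(self, x):
        b, n, _ = x.shape
        q = self.q_proj(x).view(b, n, self.query_groups, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, n, self.kv_groups, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, n, self.kv_groups, self.head_dim).transpose(1, 2)
        # Repeat K/V heads so each serves query_groups // kv_groups query heads (as in GQA)
        k = k.repeat_interleave(self.query_groups // self.kv_groups, dim=1)
        v = v.repeat_interleave(self.query_groups // self.kv_groups, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(out.transpose(1, 2).reshape(b, n, -1))
```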
|
|
|
### Architecture details: |
|
- trainable params: ~10.7M |
|
- dim: 256 |
|
- layers: 8 |
|
- self-attention: Sparse Query Attention |
|
- heads: 16 (for dimension split) |
|
- query groups: 8 |
|
- key/value groups: 4 |
|
- SwiGLU feed forward with 768 dim |
|
- RoPE |
|
- RMS Norm |
|
- vocab: 10k (English only)
|
- message length: 1024 |
|
- Library: RxNN |
|
|
|
### Training details: |
|
This model was trained only for research purposes, on a small number of training steps. As it's the most promising of the tested attention architectures, it will be developed further soon.
|
- dataset: 50% of the English subset of [wikimedia/wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia) (45% train / 5% validation)
|
- single epoch |
|
- 1.5B processed tokens |
|
- learning rate: 5e-4, cosine annealing scheduler with 25% warmup steps (a generic sketch of this schedule is shown below)
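
The sketch below shows the general shape of this schedule (linear warmup over the first 25% of steps, then cosine annealing). It is a generic illustration, not necessarily identical to RxNN's `get_transformer_lr_scheduler` used in the training script further down:

```python
import math
import torch

def warmup_cosine_scheduler(optimizer, warmup_steps, total_steps):
    """Linear warmup to the peak LR, then cosine annealing towards zero."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)                      # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))           # cosine decay
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```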
|
|
|
### Results |
|
Validation mean loss/accuracy: |
|
- MHA: 1.1976 / ~77.35% |
|
- GQA: 1.2177 / ~77.12% |
|
- MQA: 1.2497 / ~76.64% |
|
- **SQA: 1.2272 / ~76.97%** |
|
|
|
Training time / time per batch: |
|
- MHA: ~269 min / 0.7173s |
|
- GQA: ~258 min / 0.6877s |
|
- MQA: ~261 min / 0.6947s |
|
- **SQA: ~241 min / 0.6417s** |
|
|
|
### Computational complexity comparison |
|
- MHA: `O(N*d * N*d)` |
|
- GQA: `O(N*d * N*(d/heads*groups))`

- MQA: `O(N*d * N*(d/heads))`

- SQA: `O(N*(d/heads*query_groups) * N*(d/heads*groups))`
|
|
|
SQA reduces two factors instead of one. That means it will scale better to longer sequences, and the training time gains will be even greater.
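
For example, plugging this model's configuration into the formulas above (an informal illustration of the attention-score term only):

```python
# Model config from the architecture details above
N, d, heads, query_groups, kv_groups = 1024, 256, 16, 8, 4

mha_cost = (N * d) * (N * d)
sqa_cost = (N * (d // heads * query_groups)) * (N * (d // heads * kv_groups))

# Query factor shrinks by query_groups/heads = 1/2, key/value factor by kv_groups/heads = 1/4,
# so by these formulas the SQA term is 8x smaller than MHA's
print(mha_cost / sqa_cost)  # 8.0
```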
|
|
|
Furthermore, even _the extreme version_ of **SQA**, with only 4/16 query heads used (and also 4/16 key/value heads), seems to perform a little better than the reference MQA model, with even shorter training times. This suggests that **SQA** could be a game changer for efficient long-context handling. More info in [ReactiveAI/xSQAT-m](https://huggingface.co/ReactiveAI/xSQAT-m)
|
|
|
### Model size difference |
|
SQA reduces the dimensions of the query heads' linear projection and of the output projection, which results in a slightly smaller model (a rough per-layer check is sketched after the list):
|
- MHA: 12M Params |
|
- GQA: 11.2M Params |
|
- MQA: 11M Params |
|
- **SQA: 10.7M Params** |
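
A rough per-layer check of the attention projection sizes, assuming bias-free projections (dim 256, head_dim 16, 8 layers):

```python
d, head_dim = 256, 16

mha_per_layer = 4 * d * d                      # q, k, v, out projections: 262,144 params

q_sqa = d * (8 * head_dim)                     # 256 -> 128 (8 query groups): 32,768
k_sqa = v_sqa = d * (4 * head_dim)             # 256 -> 64 (4 kv groups): 16,384 each
o_sqa = (8 * head_dim) * d                     # 128 -> 256: 32,768
sqa_per_layer = q_sqa + k_sqa + v_sqa + o_sqa  # 98,304 params

print(8 * (mha_per_layer - sqa_per_layer))     # ~1.3M fewer params over 8 layers,
                                               # roughly the MHA vs SQA gap above
```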
|
|
|
### Usage |
|
The model requires the [RxNN framework](https://github.com/RxAI-dev/RxNN) for training/inference. It's integrated with the Hugging Face Hub and libraries.
|
|
|
#### Inference: |
|
- Install RxNN, PyTorch and dependencies: `pip install rxnn torch transformers tokenizers` |
|
```python |
|
import torch |
|
from rxnn.experimental.models import ExperimentalAttentionTransformer |
|
from rxnn.transformers.sampler import Sampler, SampleDecoder |
|
from rxnn.training.tokenizer import load_tokenizer_from_hf_hub |
|
|
|
# Load the pretrained model and tokenizer from the Hugging Face Hub
model = ExperimentalAttentionTransformer.from_pretrained('ReactiveAI/SQAT-m')
|
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/SQAT-m') |
|
# Sampler generates tokens on the selected device; SampleDecoder combines it with the tokenizer for text in/out
sampler = Sampler(model, torch.device('cuda' if torch.cuda.is_available() else 'cpu'), end_token_id=3)
|
sample = SampleDecoder(sampler, tokenizer) |
|
|
|
# 0.1 and 0.9 are default values for temperature and top_p |
|
generated = sample('Example model input for text generation...', temperature=0.1, top_p=0.9, max_seq_len=1024) |
|
sample('Example model input for text generation - print streamed response...', temperature=0.1, top_p=0.9, max_seq_len=1024, print_stream=True) |
|
``` |
|
|
|
#### Train: |
|
- Install RxNN, PyTorch and dependencies: `pip install rxnn torch transformers tokenizers tensorboard` (`tensorboard` is optional) |
|
```python |
|
import torch |
|
from rxnn.experimental.models import ExperimentalAttentionTransformer |
|
from rxnn.training.tokenizer import load_tokenizer_from_hf_hub |
|
from rxnn.training.dataset import AutoregressiveLMDataset |
|
from rxnn.training.bml import AutoregressiveTrainer |
|
from rxnn.training.callbacks import PrintLossCallback, PrintAccuracyCallback, TokenCounterCallback, ModelSaveCallback |
|
from rxnn.training.scheduler import get_transformer_lr_scheduler |
|
|
|
model = ExperimentalAttentionTransformer.from_pretrained('ReactiveAI/SQAT-m') |
|
tokenizer = load_tokenizer_from_hf_hub('ReactiveAI/SQAT-m') |
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
batch_size = 128 # Requires ~40GB of GPU memory (trained on an L40S)
|
epochs = 1 |
|
gradient_acc_steps = 1 |
|
seq_len = 1024 |
|
vocab_size = 10_000 |
|
|
|
peak_lr = 5e-4 * gradient_acc_steps |
|
|
|
train_dataset = AutoregressiveLMDataset.from_hf_hub('hf-dataset-id', 'subset', tokenizer=tokenizer, max_seq_len=seq_len) # split is 'train' by default |
|
valid_dataset = AutoregressiveLMDataset.from_hf_hub('hf-dataset-id', split='validation', tokenizer=tokenizer, max_seq_len=seq_len) |
|
|
|
dataset_len = len(train_dataset) |
|
|
|
steps_per_epoch = int(dataset_len / batch_size - 1) |
|
total_steps = int((epochs * steps_per_epoch) / gradient_acc_steps) |
|
warmup_steps = int(0.25 * steps_per_epoch) |
|
|
|
|
|
logs_dir = './tensorboard_logs' # requires tensorboard: `pip install tensorboard`
|
|
|
print_cb = PrintLossCallback(batches_per_epoch=steps_per_epoch) |
|
count_cb = TokenCounterCallback() |
|
acc_cb = PrintAccuracyCallback() |
|
save_cb = ModelSaveCallback('./path/to/save', push_to_hub=True, |
|
hub_model_id='your-model-id', private_repo=True, |
|
                              push_checkpoint_weights=True, final_commit_message='Final commit message', hf_token='YOUR_HF_TOKEN')
|
|
|
trainer = AutoregressiveTrainer(model, device, dataset=train_dataset, validation_dataset=valid_dataset, |
|
vocab_size=vocab_size, callbacks=[print_cb, acc_cb, count_cb, save_cb], use_amp=True, |
|
dtype=torch.bfloat16, log_dir=logs_dir, gradient_accumulation_steps=gradient_acc_steps) |
|
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.01) |
|
scheduler = get_transformer_lr_scheduler( |
|
optimizer, |
|
warmup_steps=warmup_steps, |
|
num_training_steps=total_steps |
|
) |
|
|
|
trainer(epochs=epochs, batch_size=batch_size, optimizer=optimizer, scheduler=scheduler) |
|
``` |
|
|
|
## Summary |
|
According to the experiment results, SparseQueryAttention seems to be the most cost-effective variant of GroupedQueryAttention, leading to a noticeable training time reduction, and is a promising research direction. For our **Reactive Transformer** architectures, which were initially designed with GQA for self-attention and MQA for memory attention, we are currently considering using SQA instead for all attention layer types. More info will be released soon.
|
|