---
license: other
---

# xLSTM-7B
This xLSTM-7B model was pre-trained on DCLM and selected high-quality data for a total of approximately 2.3T tokens using the `xlstm-jax` framework.


## How to use it
First, install `xlstm`, which now uses the `mlstm_kernels` package for its Triton kernels (tested on Python 3.11):

```bash
pip install xlstm
pip install accelerate
pip install 'transformers @ git+https://github.com/huggingface/transformers.git@main'
```
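
As a quick sanity check (a minimal sketch, not part of the official instructions), you can verify that the packages import and that a CUDA device is visible, since the Triton kernels require a GPU:

```python
import torch
import xlstm           # noqa: F401 -- import check only
import mlstm_kernels   # noqa: F401 -- import check only
import transformers

print(transformers.__version__)
print(torch.cuda.is_available())  # must be True to use the Triton kernels
```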

If you get an error regarding the Triton library, install it from source:
```bash
pip install 'triton @ git+https://github.com/triton-lang/triton.git@main'
```
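
After the source install, you can confirm which Triton build is active (a minimal check; the exact version string will vary):

```python
import triton
print(triton.__version__)
```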

Use this model as:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")

# the tokenizer is a fork of the EleutherAI/gpt-neox-20b tokenizer
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

# Tokenize and move to the model's device
tokens = tokenizer("Explain quantum computing in simple terms.", return_tensors='pt')['input_ids'].to(xlstm.device)

# Get the BOS token ID from the tokenizer
bos_id = tokenizer.bos_token_id

# Prepend BOS
bos_tensor = torch.tensor([[bos_id]], device=tokens.device, dtype=tokens.dtype)
tokens_with_bos = torch.cat([bos_tensor, tokens], dim=1)

out = xlstm.generate(tokens_with_bos, max_new_tokens=20)

print(tokenizer.decode(out[0]))
```
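
The BOS token is prepended manually here because the GPT-NeoX-style tokenizer does not add special tokens on its own.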

If you cannot or do not want to use the Triton kernels, you can switch to the native PyTorch implementations:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch

xlstm_config = AutoConfig.from_pretrained("NX-AI/xLSTM-7b")
xlstm_config.step_kernel = "native"
xlstm_config.chunkwise_kernel = "chunkwise--native_autograd"
xlstm_config.sequence_kernel = "native_sequence__native"

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b",
                                             config=xlstm_config, device_map="auto")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

# Your prompt
prompt = "Explain quantum computing in simple terms."

# Tokenize and send to the same device as the model
inputs = tokenizer(prompt, return_tensors="pt")['input_ids'].to(xlstm.device)

# Get the BOS token ID from the tokenizer
bos_id = tokenizer.bos_token_id

# Prepend BOS
bos_tensor = torch.tensor([[bos_id]], device=xlstm.device, dtype=inputs.dtype)
tokens_with_bos = torch.cat([bos_tensor, inputs], dim=1)

# Generate
outputs = xlstm.generate(
    tokens_with_bos,
    max_new_tokens=200,  # adjust for output length
    temperature=0.7,     # sampling temperature (randomness)
    top_p=0.9,           # nucleus sampling
    do_sample=True
)

# Decode and print
print(tokenizer.decode(outputs[0]))

# verify selected kernels
from pprint import pprint
pprint(xlstm.backbone.blocks[0].mlstm_layer.config)
```
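
The native kernels are plain PyTorch implementations, so they run wherever PyTorch runs; the trade-off is that they are slower than the fused Triton kernels on CUDA GPUs.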


## Speed results
Generation speed with `torch.cuda.graph` and `torch.compile` optimizations on a single NVIDIA H100:
![generation speed](plot_tokens_per_sec.svg)
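
As a rough illustration (a minimal sketch assuming the model's forward pass is `torch.compile`-compatible, not the benchmark script behind the plot), compilation can be enabled like this:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

# Compile the forward pass; the first call pays the compilation cost,
# later decoding steps reuse the compiled graph.
xlstm.forward = torch.compile(xlstm.forward)

tokens = tokenizer("Hello", return_tensors="pt")["input_ids"].to(xlstm.device)
print(tokenizer.decode(xlstm.generate(tokens, max_new_tokens=20)[0]))
```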

## Performance
![mmlu_train_token](MMLUvsTrainToken.svg)

Using EleutherAI's `lm_eval` (LM Evaluation Harness):

| BBH   | MMLU-Pro | MATH  | MuSR  | GPQA  | IFEval |
|-------|----------|-------|-------|-------|--------|
| 0.381 | 0.242    | 0.036 | 0.379 | 0.280 | 0.244  |
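
For reference, here is a sketch of how such numbers can be produced with the harness's Python API (task names and settings are illustrative and may not match the exact configuration used for the table above):

```python
from lm_eval import simple_evaluate

# Illustrative configuration; the exact tasks and few-shot settings behind
# the table above are not specified in this card.
results = simple_evaluate(
    model="hf",
    model_args="pretrained=NX-AI/xLSTM-7b,dtype=bfloat16",
    tasks=["bbh", "mmlu_pro"],
    batch_size=8,
)
print(results["results"])
```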

Using HuggingFace's `lighteval` with the Leaderboard-v1 settings:

| ARC-Challenge (25-shot) | MMLU (5-shot) | HellaSwag (10-shot) | Winogrande (5-shot) | TruthfulQA (0-shot) | GSM8k (5-shot) | OpenbookQA (5-shot) | PIQA (5-shot) |
|-------------------------|---------------|---------------------|---------------------|---------------------|----------------|---------------------|---------------|
| 0.584                   | 0.589         | 0.710               | 0.742               | 0.420               | 0.004          | 0.443               | 0.817         |

## License 
NXAI Community License (see `LICENSE` file)