🥦 Pretraining Thread
#2, opened by burtenshaw
2. Pre-training
To pretrain, we need to download a larger slice of the data:
python -m nanochat.dataset -n 240 &
We can then run training with trackio integrated:
export TRACKIO_SPACE_ID="nanochat-students/trackio"
export TRACKIO_PROJECT="nanochat-pretraining"
export TRACKIO_DATASET_ID="nanochat-students/trackio-dataset"
export HF_TOKEN="<your-token>"
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20
This will start logging to the CLI like so:
step 14218/21400 (66.44%) | loss: 2.885312 | lrm: 1.00 | dt: 478.80ms | tok/sec: 1,095,011 | mfu: 48.33 | total time: 114.67m
step 14219/21400 (66.44%) | loss: 2.874319 | lrm: 1.00 | dt: 479.48ms | tok/sec: 1,093,459 | mfu: 48.26 | total time: 114.68m
step 14220/21400 (66.45%) | loss: 2.880379 | lrm: 1.00 | dt: 478.54ms | tok/sec: 1,095,590 | mfu: 48.35 | total time: 114.69m
It also reports metrics to the shared trackio space: https://nanochat-students-trackio.hf.space
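If you want to sanity-check a run from the terminal output alone, the log lines above are regular enough to parse. Here is a minimal sketch, assuming the line format stays exactly as shown; `parse_log_line` and `eta_minutes` are hypothetical helper names, not part of nanochat or trackio.

```python
import re

# Pattern written against the log format shown above (an assumption:
# the field order and labels may differ in other nanochat versions).
LOG_PATTERN = re.compile(
    r"step (?P<step>\d+)/(?P<total>\d+) \((?P<pct>[\d.]+)%\)"
    r" \| loss: (?P<loss>[\d.]+)"
    r" \| lrm: (?P<lrm>[\d.]+)"
    r" \| dt: (?P<dt_ms>[\d.]+)ms"
    r" \| tok/sec: (?P<tok_per_sec>[\d,]+)"
    r" \| mfu: (?P<mfu>[\d.]+)"
    r" \| total time: (?P<total_min>[\d.]+)m"
)


def parse_log_line(line):
    """Return the numeric fields of one training log line, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    fields = match.groupdict()
    return {
        "step": int(fields["step"]),
        "total_steps": int(fields["total"]),
        "loss": float(fields["loss"]),
        "dt_ms": float(fields["dt_ms"]),
        "tok_per_sec": int(fields["tok_per_sec"].replace(",", "")),
        "mfu": float(fields["mfu"]),
        "total_min": float(fields["total_min"]),
    }


def eta_minutes(record):
    """Estimate remaining minutes, assuming the step time stays constant."""
    remaining_steps = record["total_steps"] - record["step"]
    return remaining_steps * record["dt_ms"] / 1000 / 60


line = (
    "step 14218/21400 (66.44%) | loss: 2.885312 | lrm: 1.00 | dt: 478.80ms "
    "| tok/sec: 1,095,011 | mfu: 48.33 | total time: 114.67m"
)
record = parse_log_line(line)
tokens_per_step = record["tok_per_sec"] * record["dt_ms"] / 1000

print(f"~{eta_minutes(record):.0f} min remaining")
print(f"~{tokens_per_step:,.0f} tokens per step")
```

For the first line above this gives about 57 minutes remaining and roughly 524k tokens per step. The estimate ignores evaluation and checkpointing pauses, so treat it as a lower bound.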
... pretraining is still running, so I'll report back.
The hardest part of this was porting the custom inference code to transformers. But it was fun!
from transformers import AutoModel, AutoTokenizer
import torch

model_dir = "nanochat-students/base-d20"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# trust_remote_code=True loads the custom model and tokenizer code shipped in the repo
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
model = model.to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

prompt = "The capital of Belgium is "
# Prepend the BOS token, as the custom tokenizer expects
input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
ids = torch.tensor([input_ids], dtype=torch.long, device=device)

max_new_tokens = 50
with torch.inference_mode():
    # Greedy decoding: rerun the full sequence and append the most likely token each step
    for _ in range(max_new_tokens):
        outputs = model(input_ids=ids)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        ids = torch.cat([ids, next_token], dim=1)

decoded = tokenizer.decode(ids[0].tolist())
print(decoded)
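The loop above always picks the most likely token (greedy decoding). If you want more varied completions, one option is to swap the `torch.argmax` line for temperature and top-k sampling. This is a generic sketch, not something taken from the nanochat or transformers code; `sample_next_token` is a hypothetical helper name.

```python
import torch


def sample_next_token(logits, temperature=1.0, top_k=None, generator=None):
    """Pick one token id per row from last-position logits of shape (batch, vocab).

    temperature=0.0 falls back to greedy argmax, matching the loop above.
    """
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    logits = logits / temperature
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        # Smallest logit that is still inside the top k, per row
        kth_value = torch.topk(logits, k, dim=-1).values[:, -1:]
        # Mask everything below it so it gets zero probability
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)


# Toy logits for a 4-token vocabulary, just to show the call
toy_logits = torch.tensor([[0.1, 2.0, -1.0, 0.5]])
greedy_token = sample_next_token(toy_logits, temperature=0.0)
sampled_token = sample_next_token(toy_logits, temperature=0.8, top_k=2)
print(greedy_token.tolist(), sampled_token.tolist())
```

In the generation loop this would be called as `next_token = sample_next_token(logits[:, -1, :], temperature=0.8, top_k=50)`; the values 0.8 and 50 are only illustrative.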