🥦 Pretraining Thread

by burtenshaw

2. Pre-training

To pretrain, we first need to download a larger slice of the data (240 shards, fetched in the background):

python -m nanochat.dataset -n 240 &
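While that runs in the background, you can sanity-check progress by counting the downloaded shard files. A minimal sketch, assuming the default cache location of ~/.cache/nanochat/base_data (the exact path may differ on your setup):

import os

# Assumed default download location for the pretraining shards.
data_dir = os.path.expanduser("~/.cache/nanochat/base_data")
shards = [f for f in os.listdir(data_dir) if f.endswith(".parquet")]
print(f"{len(shards)} shards downloaded")  # expect 240 once the job finishes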

We can then launch an 8-GPU training run with Trackio logging wired in via environment variables:

export TRACKIO_SPACE_ID="nanochat-students/trackio"
export TRACKIO_PROJECT="nanochat-pretraining"
export TRACKIO_DATASET_ID="nanochat-students/trackio-dataset"
export HF_TOKEN="<your-token>"
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20

This starts logging to the CLI like so:

step 14218/21400 (66.44%) | loss: 2.885312 | lrm: 1.00 | dt: 478.80ms | tok/sec: 1,095,011 | mfu: 48.33 | total time: 114.67m
step 14219/21400 (66.44%) | loss: 2.874319 | lrm: 1.00 | dt: 479.48ms | tok/sec: 1,093,459 | mfu: 48.26 | total time: 114.68m
step 14220/21400 (66.45%) | loss: 2.880379 | lrm: 1.00 | dt: 478.54ms | tok/sec: 1,095,590 | mfu: 48.35 | total time: 114.69m
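For reference: lrm is the learning-rate schedule multiplier, dt the per-step wall time, and mfu the model FLOPs utilization. The mfu number roughly checks out from the throughput alone; a back-of-the-envelope sketch, assuming d20 is a ~560M-parameter model, the ~6N FLOPs-per-token heuristic, and 8 GPUs at ~989 TFLOPS bf16 peak each (all assumptions, not values from the log):

# Rough MFU sanity check (model size and hardware peak are assumptions).
n_params = 560e6                         # approximate parameter count of the d20 model
tok_per_sec = 1_095_011                  # aggregate throughput from the log above
achieved = 6 * n_params * tok_per_sec    # ~6N FLOPs per token heuristic
peak = 8 * 989e12                        # 8 GPUs at ~989 TFLOPS bf16 peak
print(f"mfu ~ {100 * achieved / peak:.1f}%")  # ~46.5%, close to the logged 48.33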

Metrics are also reported to the shared Trackio space: https://nanochat-students-trackio.hf.space
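Under the hood this is a wandb-style API; a minimal sketch of what the logging calls look like, assuming the environment variables above supply the defaults (values illustrative):

import trackio

# The TRACKIO_* env vars exported above can supply these arguments as defaults.
trackio.init(project="nanochat-pretraining", space_id="nanochat-students/trackio")
trackio.log({"step": 14218, "loss": 2.885, "mfu": 48.33})  # illustrative values
trackio.finish()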

... pretraining is still running, so I'll report back once it finishes.


The weights from pre-training are here: https://huggingface.co/nanochat-students/base-d20
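If you just want the checkpoint files locally, huggingface_hub can fetch the whole repo; a small sketch:

from huggingface_hub import snapshot_download

# Downloads the full model repo to the local HF cache and returns the path.
local_dir = snapshot_download("nanochat-students/base-d20")
print(local_dir)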


The hardest part of this was porting the custom inference code to Transformers. But it was fun!

from transformers import AutoModel, AutoTokenizer
import torch

model_dir = "nanochat-students/base-d20"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# trust_remote_code is required: the repo ships nanochat's custom model and tokenizer code.
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True)
model = model.to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

# The custom tokenizer keeps nanochat's encode API, so the BOS token is prepended explicitly.
prompt = "The capital of Belgium is "
input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
ids = torch.tensor([input_ids], dtype=torch.long, device=device)

# Greedy decoding loop: no KV cache, so the full prefix is re-run through the model each step.
max_new_tokens = 50
with torch.inference_mode():
    for _ in range(max_new_tokens):
        outputs = model(input_ids=ids)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # most likely token
        ids = torch.cat([ids, next_token], dim=1)

decoded = tokenizer.decode(ids[0].tolist())
print(decoded)
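Greedy argmax decoding is deterministic, so every run produces the same completion. For more varied samples, you can swap the argmax for temperature sampling inside the loop (0.8 here is an arbitrary choice):

probs = torch.softmax(logits[:, -1, :] / 0.8, dim=-1)  # temperature-scaled distribution
next_token = torch.multinomial(probs, num_samples=1)   # sample instead of argmax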
