# SMOLLM2_105M / train.py
import torch
import torch.nn as nn
import math
# RMSNorm normalizes the input by its root mean square (the square root of the mean of
# squares over the last dimension); a small epsilon is added inside the square root to
# avoid division by zero.
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):  # hidden_size: feature dimension of the input; eps: small constant for numerical stability
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable per-feature scale
        self.eps = eps
    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)  # reciprocal of the root mean square
        return x * norm * self.weight  # normalize, then apply the learnable scale
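# Quick illustrative sanity check (not executed as part of training): after RMSNorm,
# the mean of squares along the last dimension should be close to 1.
# rms = LlamaRMSNorm(8)
# print(rms(torch.randn(2, 4, 8)).pow(2).mean(-1))  # values near 1.0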
# Rotary positional embedding (RoPE): encodes position by rotating query/key features
# by position-dependent angles. The angles are fixed, derived from `base`, not learned.
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, device=None):  # dim: feature dimension to rotate; base: frequency base; device: where to create the buffer
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))  # inverse frequency for each feature pair
        self.register_buffer("inv_freq", inv_freq)  # stored with the module but not trained
    def forward(self, x, seq_len):
        t = torch.arange(seq_len, device=x.device)  # position indices 0..seq_len-1
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # outer product: one angle per (position, frequency)
        emb = torch.cat((freqs, freqs), dim=-1)  # duplicate so the angles span the full feature dimension
        return emb  # per-position rotation angles; their cos/sin are applied to q and k
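# Note: this script instantiates LlamaRotaryEmbedding but never applies it inside
# LlamaAttention. The helpers below are an illustrative sketch of how the angles would
# typically be applied to q and k; the names rotate_half and apply_rotary_pos_emb
# follow common convention and are not used elsewhere in this file.
def rotate_half(x):
    # Rotate the last dimension by swapping its halves and negating the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, emb):
    # emb: [seq_len, rotary_dim] angles from LlamaRotaryEmbedding.forward
    # (rotary_dim would typically be the per-head dimension)
    cos, sin = emb.cos(), emb.sin()
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin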
class LlamaMLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) # create the gate projection layer with the input dimension and the hidden dimension
self.up_proj = nn.Linear(dim, hidden_dim, bias=False) # create the up projection layer with the input dimension and the hidden dimension
self.down_proj = nn.Linear(hidden_dim, dim, bias=False) # create the down projection layer with the hidden dimension and the output dimension
self.act_fn = nn.SiLU() # create the activation function
    def forward(self, x):
        # SwiGLU: down_proj(SiLU(gate_proj(x)) * up_proj(x))
        gated = self.act_fn(self.gate_proj(x))  # activation is applied to the gate branch only
        hidden = self.up_proj(x)  # up projection stays linear
        return self.down_proj(gated * hidden)  # elementwise product, then project back to dim
class LlamaAttention(nn.Module):
def __init__(self, dim, num_heads=8):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.o_proj = nn.Linear(dim, dim, bias=False)
    def forward(self, x):
        batch_size, seq_len, dim = x.size()  # x: [batch_size, seq_len, dim]
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        # Split heads: [batch_size, num_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with a causal mask so each position can only
        # attend to itself and earlier positions
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
        attention = torch.softmax(scores, dim=-1)
        context = torch.matmul(attention, v)
        # Combine heads back to [batch_size, seq_len, dim]
        context = context.transpose(1, 2).reshape(batch_size, seq_len, dim)
        return self.o_proj(context)
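# Illustrative alternative (assuming PyTorch 2.x): the masked attention above can be
# written with the fused built-in, which applies the same causal mask internally:
# context = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)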
class LlamaDecoderLayer(nn.Module):
def __init__(self, dim, hidden_dim, num_heads):
super().__init__()
self.self_attn = LlamaAttention(dim, num_heads)
self.mlp = LlamaMLP(dim, hidden_dim)
self.input_layernorm = LlamaRMSNorm(dim)
self.post_attention_layernorm = LlamaRMSNorm(dim)
def forward(self, x):
residual = x
x = self.input_layernorm(x)
x = self.self_attn(x)
x = x + residual
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = x + residual
return x
class LlamaModel(nn.Module):
def __init__(self, vocab_size, dim, num_layers, hidden_dim, num_heads):
super().__init__()
self.embed_tokens = nn.Embedding(vocab_size, dim)
self.layers = nn.ModuleList([
LlamaDecoderLayer(dim, hidden_dim, num_heads) for _ in range(num_layers)
])
self.norm = LlamaRMSNorm(dim)
self.rotary_emb = LlamaRotaryEmbedding(dim)
def forward(self, x):
x = self.embed_tokens(x)
for layer in self.layers:
x = layer(x)
return self.norm(x)
class LlamaForCausalLM(nn.Module):
def __init__(self, vocab_size, dim, num_layers, hidden_dim, num_heads):
super().__init__()
self.model = LlamaModel(vocab_size, dim, num_layers, hidden_dim, num_heads)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
def forward(self, x):
x = self.model(x)
return self.lm_head(x)
def get_model(tokenizer):
    vocab_size = len(tokenizer)  # full tokenizer size, including any added special tokens
return LlamaForCausalLM(
vocab_size=vocab_size,
dim=576,
num_layers=30,
hidden_dim=1536,
num_heads=8
)
# model = get_model(tokenizer)  # requires the tokenizer defined below
# print(model)
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, get_scheduler
from torch.optim import AdamW
#import wandb
import os
#from model import get_model
#wandb.init(project="smollm-training", name="llama-smollm-corpus")
BATCH_SIZE = 8
SEQ_LEN = 256
LEARNING_RATE = 1e-4
EPOCHS = 5
WARMUP_STEPS = 1000
GRADIENT_CLIP_VAL = 1.0
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DEVICE = (
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)
def generate_text(
model, tokenizer, prompt, max_length=50, temperature=0.7, top_k=50, device=DEVICE
):
model.eval()
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
with torch.no_grad():
for _ in range(max_length):
outputs = model(input_ids)
next_token_logits = outputs[:, -1, :] / temperature
# Apply top-k sampling
top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)
# Sample from the filtered distribution
next_token_idx = torch.multinomial(probs, num_samples=1)
next_token = top_k_indices[0, next_token_idx[0]]
if next_token.item() == tokenizer.eos_token_id:
break
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
model.train()
return generated_text
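# Illustrative usage (e.g. during evaluation), assuming model and tokenizer are built below:
# print(generate_text(model, tokenizer, "The future of artificial intelligence", max_length=50))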
def save_checkpoint(model, optimizer, scheduler, epoch, step, loss, path):
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict() if scheduler else None,
"loss": loss,
"step": step,
},
path,
)
def load_checkpoint(path, model, optimizer, scheduler):
if os.path.exists(path):
# path = './checkpoints/checkpoint_step_5000.pt'
# print(f"Loading checkpoint from {path}")
checkpoint = torch.load(path, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
if scheduler and checkpoint["scheduler_state_dict"]:
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
return checkpoint["epoch"], checkpoint["step"]
return 0, 0
def count_parameters(model):
"""Count the number of trainable parameters in the model"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    if tokenizer.eos_token:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        # Add a dedicated pad token; the model below sizes its embedding from
        # len(tokenizer), so no separate embedding resize is needed here.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
dataset = load_dataset(
"HuggingFaceTB/smollm-corpus", "cosmopedia-v2", streaming=True, split="train"
)
def tokenize_function(examples):
return tokenizer(
examples["text"], truncation=True, max_length=SEQ_LEN, padding="max_length"
)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
def collate_fn(batch):
    input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
    attention_mask = torch.tensor(
        [item["attention_mask"] for item in batch], dtype=torch.long
    )
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100  # ignore padded positions when computing the loss
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
train_loader = DataLoader(
tokenized_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn
)
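# Note: with streaming=True the dataset is an IterableDataset, so this DataLoader has
# no defined length; each "epoch" below simply iterates until the stream is exhausted.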
# Initialize model, optimizer, and scheduler
model = get_model(tokenizer)
model.to(DEVICE)
# Print model parameters
# total_params = count_parameters(model)
# print(f"\nModel Statistics:")
# print(f"Total Parameters: {total_params:,}")
# print(f"Model Size: {total_params * 4 / (1024 * 1024):.2f} MB") # Assuming float32 (4 bytes)
# print(f"Device: {DEVICE}")
# print(f"Batch Size: {BATCH_SIZE}")
# print(f"Sequence Length: {SEQ_LEN}")
# print(f"Learning Rate: {LEARNING_RATE}")
# print("-" * 50 + "\n")
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=LEARNING_RATE,
total_steps=10000,
pct_start=0.1,
anneal_strategy="cos",
cycle_momentum=False,
)
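# Note: OneCycleLR is configured for a fixed 10000 optimizer steps; stepping the
# scheduler past total_steps raises an error, so 10000 should cover the intended run.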
# Load checkpoint if exists
start_epoch, global_step = load_checkpoint(
f"{CHECKPOINT_DIR}/latest_checkpoint.pt", model, optimizer, lr_scheduler
)
# Sample prompts for evaluation
sample_prompts = [
"The future of artificial intelligence",
"The most important thing in life",
"The best way to learn programming",
]
model.train()
try:
for epoch in range(start_epoch, EPOCHS):
print(f"Epoch {epoch + 1}/{EPOCHS}")
        for step, batch in enumerate(train_loader, start=global_step):  # start only offsets the step counter; it does not skip batches
# Move batch to device
input_ids = batch["input_ids"].to(DEVICE)
attention_mask = batch["attention_mask"].to(DEVICE)
labels = batch["labels"].to(DEVICE)
            # Forward pass
            outputs = model(input_ids)  # [batch_size, seq_len, vocab]
            # Next-token prediction: shift so that position t predicts token t+1
            shift_logits = outputs[:, :-1, :].contiguous().view(-1, outputs.size(-1))
            shift_labels = labels[:, 1:].contiguous().view(-1)
            # Cross-entropy with label smoothing; padded positions (-100) are ignored
            loss = torch.nn.functional.cross_entropy(
                shift_logits, shift_labels, label_smoothing=0.1, ignore_index=-100
            )
# Backward pass with gradient clipping
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
optimizer.step()
lr_scheduler.step()
# Logging
if step % 10 == 0:
print(
f"Step {step}, Loss: {loss.item():.4f}, LR: {lr_scheduler.get_last_lr()[0]:.2e}"
)
#wandb.log(
# {
# "loss": loss.item(),
# "lr": lr_scheduler.get_last_lr()[0],
# "step": step,
# "epoch": epoch,
# }
#)
# Save checkpoint every 100 steps
#if step % 100 == 0:
# save_checkpoint(
# model,
# optimizer,
# lr_scheduler,
# epoch,
# step,
# loss.item(),
# f"{CHECKPOINT_DIR}/latest_checkpoint.pt",
# )
# Also save numbered checkpoint every 1000 steps
if step % 1000 == 0:
save_checkpoint(
model,
optimizer,
lr_scheduler,
epoch,
step,
loss.item(),
f"{CHECKPOINT_DIR}/checkpoint_step.pt",
)
# Generate sample text every 500 steps with different temperatures
if step % 500 == 0:
print("\n=== Generating Sample Texts ===")
for temp in [0.7, 1.0]: # Try different temperatures
for prompt in sample_prompts:
generated = generate_text(
model,
tokenizer,
prompt,
temperature=temp,
max_length=100, # Increased max length
)
print(f"\nPrompt: {prompt}")
print(f"Temperature: {temp}")
print(f"Generated: {generated}")
#wandb.log(
# {
# f"generated_text_temp_{temp}_{prompt[:20]}": wandb.Html(
# generated
# )
# }
#)
print("\n=== End of Samples ===\n")
model.train()
# Save epoch checkpoint
#save_checkpoint(
# model,
# optimizer,
# lr_scheduler,
# epoch,
# step,
# loss.item(),
# f"{CHECKPOINT_DIR}/checkpoint_epoch.pt",
#)
except KeyboardInterrupt:
print("\nTraining interrupted! Saving checkpoint...")
save_checkpoint(
model,
optimizer,
lr_scheduler,
epoch,
step,
loss.item(),
f"{CHECKPOINT_DIR}/interrupted_checkpoint.pt",
)
print("Training complete!")
#wandb.finish()