from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from distiller.model2vec.distill.utils import select_optimal_device
from distiller.model2vec.model import StaticModel
if TYPE_CHECKING:
from tokenizers import Tokenizer
logger = logging.getLogger(__name__)
class StaticModelFineTuner(nn.Module):
def __init__(self, vectors: torch.Tensor, out_dim: int, pad_id: int) -> None:
"""
Initialize from a model.
:param vectors: The vectors to use.
:param out_dim: The output dimension.
:param pad_id: The padding id.
"""
super().__init__()
self.pad_id = pad_id
norms = vectors.norm(dim=1)
# Normalize the vectors
vectors = vectors / norms[:, None]
self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
self.n_out = out_dim
self.out_layer = nn.Linear(vectors.shape[1], self.n_out)
        # Per-token weights start at the original vector norms; the padding token gets zero weight
        weights = norms.clone()
        weights[pad_id] = 0
        self.w = nn.Parameter(weights)
def sub_forward(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Forward pass through the mean."""
        # Guard against out-of-range token ids by redirecting them to index 0
        valid_mask = (input_ids >= 0) & (input_ids < self.embeddings.num_embeddings)
        if not valid_mask.all():
            input_ids = torch.where(valid_mask, input_ids, 0)
        w = self.w[input_ids]
        # Mask out padding tokens so they do not contribute to the mean
        pad_mask = (input_ids != self.pad_id).float()
        w = w * pad_mask
        # Add a small epsilon to avoid division by zero for all-padding rows
        length = pad_mask.sum(1) + 1e-16
        embedded = self.embeddings(input_ids)
        # Weighted sum over the sequence dimension (padding weights are already zero)
        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
        # Divide by the number of non-padding tokens to simulate an actual mean
        return embedded / length[:, None]
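    # In effect, sub_forward computes a norm-weighted mean per sequence:
    #   pooled_i = (sum_j w[ids_ij] * E[ids_ij]) / (number of non-padding tokens in row i)
    # where E is the re-normalized embedding table and w the learned per-token weight vector.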
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass through the mean, and a classifier layer after."""
embedded = self.sub_forward(x)
return self.out_layer(embedded), embedded
@property
def device(self) -> torch.device:
"""Get the device of the model."""
return self.embeddings.weight.device
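# Illustrative sketch (not executed; the shapes and numbers below are placeholders): the fine-tuner
# wraps a re-normalized embedding table plus a learned per-token weight, and a forward pass on a
# batch of token-id sequences returns both the projected and the mean-pooled representations.
#
#   vectors = torch.randn(30_000, 256)        # hypothetical vocabulary of 30k tokens, dim 256
#   tuner = StaticModelFineTuner(vectors, out_dim=768, pad_id=0)
#   ids = torch.randint(1, 30_000, (4, 12))   # batch of 4 sequences, 12 tokens each
#   projected, pooled = tuner(ids)            # shapes: (4, 768) and (4, 256)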
class TextDataset(Dataset):
def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None:
"""
Initialize the dataset.
:param texts: The texts to tokenize.
:param targets: The targets.
:param tokenizer: The tokenizer to use.
:raises ValueError: If the number of labels does not match the number of texts.
"""
if len(targets) != len(texts):
msg = "Number of labels does not match number of texts."
raise ValueError(msg)
self.texts = [x[:20_000] for x in texts]
self.tokenized_texts: list[list[int]] = [
encoding.ids[:512] for encoding in tokenizer.encode_batch_fast(self.texts, add_special_tokens=False)
]
self.targets = targets
self.tokenizer = tokenizer
def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.tokenized_texts)
def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
"""Gets an item."""
return self.tokenized_texts[index], self.targets[index]
@staticmethod
    def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
"""Collate function."""
texts, targets = zip(*batch, strict=False)
        tensors = [torch.tensor(x, dtype=torch.int32) for x in texts]
padded = pad_sequence(tensors, batch_first=True, padding_value=0)
return padded, torch.stack(targets)
def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
"""Convert the dataset to a DataLoader."""
return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
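# Illustrative sketch (placeholder data): pair raw texts with per-text target vectors and batch
# them for training. `tok` is assumed to be a `tokenizers.Tokenizer` compatible with the model.
#
#   texts = ["first document", "second document"]
#   targets = torch.randn(2, 768)              # one target vector per text (dimension is arbitrary)
#   dataset = TextDataset(texts, targets, tokenizer=tok)
#   loader = dataset.to_dataloader(shuffle=True, batch_size=32)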
def train_supervised(
train_dataset: TextDataset,
validation_dataset: TextDataset,
model: StaticModel,
patience: int | None = 5,
device: str | None = None,
batch_size: int = 256,
lr: float = 1e-3,
) -> StaticModel:
"""
Train a tokenlearn model.
:param train_dataset: The training dataset.
:param validation_dataset: The validation dataset.
:param model: The model to train.
:param patience: The number of epochs to wait before early stopping.
:param device: The device to train on.
:param batch_size: The batch size.
:param lr: The learning rate.
:return: The trained model.
"""
device = select_optimal_device(device)
train_dataloader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)
# Initialize the model
trainable_model = StaticModelFineTuner(
torch.from_numpy(model.embedding),
out_dim=train_dataset.targets.shape[1],
pad_id=model.tokenizer.token_to_id("[PAD]"),
)
trainable_model.to(device)
    # Collect the trainable parameters: embeddings, per-token weights, and the output layer
    model_params = [
        *trainable_model.embeddings.parameters(),
        trainable_model.w,
        *trainable_model.out_layer.parameters(),
    ]
    optimizer = torch.optim.AdamW(params=model_params, lr=lr)
lowest_loss = float("inf")
    # Snapshot the initial state; values are cloned so later in-place updates do not overwrite it
    param_dict = {k: v.detach().clone() for k, v in trainable_model.state_dict().items()}
curr_patience = patience
stop = False
criterion = nn.MSELoss()
try:
for epoch in range(100_000):
logger.info(f"Epoch {epoch}")
trainable_model.train()
# Track train loss separately
train_losses = []
barred_train = tqdm(train_dataloader, desc=f"Epoch {epoch:03d} [Train]")
for idx, (x, y) in enumerate(barred_train):
optimizer.zero_grad()
x = x.to(trainable_model.device)
y_hat, _ = trainable_model(x)
                # MSE between the projected output and the target vector (reduction is already mean)
                train_loss = criterion(y_hat, y.to(trainable_model.device))
                train_loss.backward()
optimizer.step()
train_losses.append(train_loss.item())
barred_train.set_description_str(f"Train Loss: {np.mean(train_losses[-10:]):.3f}")
# Evaluate every 1000 steps and at the end of the epoch
if (idx > 0 and idx % 1000 == 0) or idx == len(train_dataloader) - 1:
trainable_model.eval()
with torch.no_grad():
validation_losses = []
barred_val = tqdm(
validation_dataset.to_dataloader(shuffle=False, batch_size=batch_size), desc="Validation"
)
for x_val, y_val in barred_val:
x_val = x_val.to(trainable_model.device)
y_hat_val, _ = trainable_model(x_val)
val_loss = criterion(y_hat_val, y_val.to(trainable_model.device)).mean()
validation_losses.append(val_loss.item())
barred_val.set_description_str(f"Validation Loss: {np.mean(validation_losses):.3f}")
validation_loss = np.mean(validation_losses)
# Early stopping logic based on validation loss
if patience is not None and curr_patience is not None:
if (lowest_loss - validation_loss) > 1e-4:
                            # Save the best model state so far (based on validation loss); clone to keep a true snapshot
                            param_dict = {k: v.detach().clone() for k, v in trainable_model.state_dict().items()}
curr_patience = patience
lowest_loss = validation_loss
else:
curr_patience -= 1
if curr_patience == 0:
stop = True
break
logger.info(f"Patience level: {patience - curr_patience}")
logger.info(f"Validation loss: {validation_loss:.3f}")
logger.info(f"Lowest loss: {lowest_loss:.3f}")
trainable_model.train()
if stop:
logger.info("Early stopping")
break
except KeyboardInterrupt:
logger.info("Training interrupted")
trainable_model.eval()
    # Restore the best model state (selected by validation loss)
trainable_model.load_state_dict(param_dict)
    # Re-embed every token id through the weighted-mean layer to produce the final static vectors
    embeddings_weight = trainable_model.embeddings.weight.to(device)
    with torch.no_grad():
        vectors = trainable_model.sub_forward(torch.arange(len(embeddings_weight))[:, None].to(device)).cpu().numpy()
return StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config)
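# Example usage (a minimal sketch; paths and variable names are placeholders, and it assumes the
# vendored StaticModel keeps model2vec's from_pretrained/save_pretrained API):
#
#   base = StaticModel.from_pretrained("path/to/a/model2vec/model")
#   train_ds = TextDataset(train_texts, train_targets, tokenizer=base.tokenizer)
#   val_ds = TextDataset(val_texts, val_targets, tokenizer=base.tokenizer)
#   finetuned = train_supervised(train_ds, val_ds, base, batch_size=256, lr=1e-3)
#   finetuned.save_pretrained("path/to/output")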