from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from distiller.model2vec.distill.utils import select_optimal_device
from distiller.model2vec.model import StaticModel

if TYPE_CHECKING:
    from tokenizers import Tokenizer

logger = logging.getLogger(__name__)


class StaticModelFineTuner(nn.Module):
    def __init__(self, vectors: torch.Tensor, out_dim: int, pad_id: int) -> None:
        """
        Initialize from a static embedding matrix.

        :param vectors: The embedding vectors to fine-tune.
        :param out_dim: The output dimension of the linear head.
        :param pad_id: The padding token id.
        """
        super().__init__()
        self.pad_id = pad_id
        norms = vectors.norm(dim=1)
        # Normalize the vectors to unit length; the original norms are reused as per-token weights.
        vectors = vectors / norms[:, None]
        self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
        self.n_out = out_dim
        self.out_layer = nn.Linear(vectors.shape[1], self.n_out)
        # Learnable per-token weights, initialized to the original norms; the padding token gets weight 0.
        weights = norms.clone()
        weights[pad_id] = 0
        self.w = nn.Parameter(weights)

    def sub_forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Forward pass through the weighted mean."""
        # Guard against out-of-range token ids by mapping them to index 0.
        valid_mask = (input_ids >= 0) & (input_ids < self.embeddings.num_embeddings)
        if not valid_mask.all():
            input_ids = torch.where(valid_mask, input_ids, 0)
        w = self.w[input_ids]
        # Zero out the weights of padding positions so they do not contribute.
        not_pad = (input_ids != self.pad_id).float()
        w = w * not_pad
        # Add a small epsilon to avoid division by zero for empty sequences.
        length = not_pad.sum(1) + 1e-16
        embedded = self.embeddings(input_ids)
        # Weighted sum over the sequence: (batch, 1, seq) @ (batch, seq, dim) -> (batch, dim).
        embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
        # Divide by the number of non-padding tokens to simulate an actual mean.
        return embedded / length[:, None]

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass through the weighted mean, followed by the linear head."""
        embedded = self.sub_forward(x)
        return self.out_layer(embedded), embedded

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return self.embeddings.weight.device
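
# Shape sketch (illustrative, not executed): for a padded id batch of shape
# (batch, seq), `sub_forward` returns norm-weighted mean embeddings of shape
# (batch, dim), and `forward` additionally projects them through the linear head:
#
#     tuner = StaticModelFineTuner(vectors, out_dim=256, pad_id=0)  # example values
#     preds, means = tuner(input_ids)  # preds: (batch, 256), means: (batch, dim)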


class TextDataset(Dataset):
    def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None:
        """
        Initialize the dataset.

        :param texts: The texts to tokenize.
        :param targets: The target vectors, one per text.
        :param tokenizer: The tokenizer to use.
        :raises ValueError: If the number of targets does not match the number of texts.
        """
        if len(targets) != len(texts):
            msg = "Number of targets does not match number of texts."
            raise ValueError(msg)
        # Truncate very long texts before tokenizing, and cap each sequence at 512 tokens.
        self.texts = [x[:20_000] for x in texts]
        self.tokenized_texts: list[list[int]] = [
            encoding.ids[:512] for encoding in tokenizer.encode_batch_fast(self.texts, add_special_tokens=False)
        ]
        self.targets = targets
        self.tokenizer = tokenizer

    def __len__(self) -> int:
        """Return the length of the dataset."""
        return len(self.tokenized_texts)

    def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
        """Get a single (token ids, target) pair."""
        return self.tokenized_texts[index], self.targets[index]

    @staticmethod
    def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
        """Collate a batch by padding the token ids and stacking the targets."""
        texts, targets = zip(*batch, strict=False)
        # Cast to int32 to keep the padded batch small; nn.Embedding accepts integer indices.
        tensors = [torch.LongTensor(x).int() for x in texts]
        padded = pad_sequence(tensors, batch_first=True, padding_value=0)
        return padded, torch.stack(targets)

    def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
        """Convert the dataset to a DataLoader."""
        return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)
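
# Batch-flow sketch (illustrative, not executed): assuming `tok` is a
# `tokenizers.Tokenizer` and the targets are 256-dimensional vectors,
# a batch comes out as padded integer ids plus stacked targets:
#
#     ds = TextDataset(["a cat", "a dog sat"], torch.randn(2, 256), tok)
#     loader = ds.to_dataloader(shuffle=False, batch_size=2)
#     input_ids, targets = next(iter(loader))  # input_ids: (2, max_len), targets: (2, 256)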


def train_supervised(
    train_dataset: TextDataset,
    validation_dataset: TextDataset,
    model: StaticModel,
    patience: int | None = 5,
    device: str | None = None,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> StaticModel:
    """
    Train a tokenlearn model.

    :param train_dataset: The training dataset.
    :param validation_dataset: The validation dataset.
    :param model: The model to train.
    :param patience: The number of validation checks without improvement before early stopping.
    :param device: The device to train on.
    :param batch_size: The batch size.
    :param lr: The learning rate.
    :return: The trained model.
    """
    device = select_optimal_device(device)
    train_dataloader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)

    # Initialize the trainable model from the static embeddings.
    trainable_model = StaticModelFineTuner(
        torch.from_numpy(model.embedding),
        out_dim=train_dataset.targets.shape[1],
        pad_id=model.tokenizer.token_to_id("[PAD]"),
    )
    trainable_model.to(device)

    # Collect all trainable parameters: embeddings, per-token weights, and the linear head.
    model_params = [
        *trainable_model.embeddings.parameters(),
        trainable_model.w,
        *trainable_model.out_layer.parameters(),
    ]
    optimizer = torch.optim.AdamW(params=model_params, lr=lr)

    lowest_loss = float("inf")
    # Copy the state dict so later optimizer steps cannot overwrite the saved snapshot.
    param_dict = {k: v.detach().clone() for k, v in trainable_model.state_dict().items()}
    curr_patience = patience
    stop = False
    criterion = nn.MSELoss()

    try:
        for epoch in range(100_000):
            logger.info(f"Epoch {epoch}")
            trainable_model.train()

            train_losses = []
            barred_train = tqdm(train_dataloader, desc=f"Epoch {epoch:03d} [Train]")
            for idx, (x, y) in enumerate(barred_train):
                optimizer.zero_grad()
                x = x.to(trainable_model.device)
                y_hat, _ = trainable_model(x)
                train_loss = criterion(y_hat, y.to(trainable_model.device)).mean()
                train_loss.backward()
                optimizer.step()
                train_losses.append(train_loss.item())
                barred_train.set_description_str(f"Train Loss: {np.mean(train_losses[-10:]):.3f}")

                # Evaluate every 1000 steps and at the end of the epoch.
                if (idx > 0 and idx % 1000 == 0) or idx == len(train_dataloader) - 1:
                    trainable_model.eval()
                    with torch.no_grad():
                        validation_losses = []
                        barred_val = tqdm(
                            validation_dataset.to_dataloader(shuffle=False, batch_size=batch_size), desc="Validation"
                        )
                        for x_val, y_val in barred_val:
                            x_val = x_val.to(trainable_model.device)
                            y_hat_val, _ = trainable_model(x_val)
                            val_loss = criterion(y_hat_val, y_val.to(trainable_model.device)).mean()
                            validation_losses.append(val_loss.item())
                            barred_val.set_description_str(f"Validation Loss: {np.mean(validation_losses):.3f}")
                        validation_loss = np.mean(validation_losses)

                    # Early stopping based on the validation loss.
                    if patience is not None and curr_patience is not None:
                        if (lowest_loss - validation_loss) > 1e-4:
                            # Save the best model state based on the validation loss.
                            param_dict = {k: v.detach().clone() for k, v in trainable_model.state_dict().items()}
                            curr_patience = patience
                            lowest_loss = validation_loss
                        else:
                            curr_patience -= 1
                            if curr_patience == 0:
                                stop = True
                                break
                        logger.info(f"Patience level: {patience - curr_patience}")
                        logger.info(f"Validation loss: {validation_loss:.3f}")
                        logger.info(f"Lowest loss: {lowest_loss:.3f}")
                    trainable_model.train()

            if stop:
                logger.info("Early stopping")
                break
    except KeyboardInterrupt:
        logger.info("Training interrupted")

    trainable_model.eval()
    # Restore the best model state based on the validation loss.
    trainable_model.load_state_dict(param_dict)

    # Move the embedding weights to the selected device.
    embeddings_weight = trainable_model.embeddings.weight.to(device)
    # Re-embed every token id through the learned weighted mean to obtain the new static vectors.
    with torch.no_grad():
        vectors = trainable_model.sub_forward(torch.arange(len(embeddings_weight))[:, None].to(device)).cpu().numpy()

    return StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config)
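
# End-to-end sketch (illustrative, not part of the module): assuming `base_model`
# is a StaticModel whose tokenizer has a "[PAD]" token, and `texts`/`targets`
# form a featurized corpus with targets shaped (n_texts, target_dim):
#
#     tokenizer = base_model.tokenizer
#     train_ds = TextDataset(texts[:-1000], targets[:-1000], tokenizer)
#     val_ds = TextDataset(texts[-1000:], targets[-1000:], tokenizer)
#     tuned = train_supervised(train_ds, val_ds, base_model, batch_size=256, lr=1e-3)
#
# The returned StaticModel contains the re-embedded vocabulary vectors.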