from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import numpy as np
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from distiller.model2vec.distill.utils import select_optimal_device
from distiller.model2vec.model import StaticModel

if TYPE_CHECKING:
	from tokenizers import Tokenizer

logger = logging.getLogger(__name__)


class StaticModelFineTuner(nn.Module):
	def __init__(self, vectors: torch.Tensor, out_dim: int, pad_id: int) -> None:
		"""
		Initialize from a model.

		:param vectors: The vectors to use.
		:param out_dim: The output dimension.
		:param pad_id: The padding id.
		"""
		super().__init__()
		self.pad_id = pad_id
		norms = vectors.norm(dim=1)
		# Normalize the vectors, guarding against zero-norm rows
		vectors = vectors / norms.clamp(min=1e-16)[:, None]
		self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=False, padding_idx=pad_id)
		self.n_out = out_dim
		self.out_layer = nn.Linear(vectors.shape[1], self.n_out)
		# Per-token weights, initialized to the original norms; the padding token gets weight 0
		weights = norms.clone()
		weights[pad_id] = 0
		self.w = nn.Parameter(weights)

	def sub_forward(self, input_ids: torch.Tensor) -> torch.Tensor:
		"""Forward pass through the weighted mean pooling."""
		# Guard against out-of-range token ids by mapping them to index 0
		valid_mask = (input_ids >= 0) & (input_ids < self.embeddings.num_embeddings)
		if not valid_mask.all():
			input_ids = torch.where(valid_mask, input_ids, 0)
		# Per-token weights, zeroed out for padding tokens
		w = self.w[input_ids]
		not_pad = (input_ids != self.pad_id).float()
		w = w * not_pad
		# Add a small epsilon to avoid division by zero on empty sequences
		length = not_pad.sum(1) + 1e-16
		embedded = self.embeddings(input_ids)
		# Weighted sum over the sequence dimension
		embedded = torch.bmm(w[:, None, :], embedded).squeeze(1)
		# Divide by the number of non-padding tokens to simulate an actual mean
		return embedded / length[:, None]

	def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
		"""Forward pass through the mean, and a classifier layer after."""
		embedded = self.sub_forward(x)
		return self.out_layer(embedded), embedded

	@property
	def device(self) -> torch.device:
		"""Get the device of the model."""
		return self.embeddings.weight.device


class TextDataset(Dataset):
	def __init__(self, texts: list[str], targets: torch.Tensor, tokenizer: Tokenizer) -> None:
		"""
		Initialize the dataset.

		:param texts: The texts to tokenize.
		:param targets: The targets.
		:param tokenizer: The tokenizer to use.
		:raises ValueError: If the number of labels does not match the number of texts.
		"""
		if len(targets) != len(texts):
			msg = "Number of labels does not match number of texts."
			raise ValueError(msg)
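		# Truncate very long texts to 20k characters to keep tokenization cheap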
		self.texts = [x[:20_000] for x in texts]
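		# Pre-tokenize everything once, capping each sequence at 512 tokens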
		self.tokenized_texts: list[list[int]] = [
			encoding.ids[:512] for encoding in tokenizer.encode_batch_fast(self.texts, add_special_tokens=False)
		]
		self.targets = targets
		self.tokenizer = tokenizer

	def __len__(self) -> int:
		"""Return the length of the dataset."""
		return len(self.tokenized_texts)

	def __getitem__(self, index: int) -> tuple[list[int], torch.Tensor]:
		"""Gets an item."""
		return self.tokenized_texts[index], self.targets[index]

	@staticmethod
	def collate_fn(batch: list[tuple[list[int], torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
		"""Collate tokenized texts and targets into a padded batch."""
		texts, targets = zip(*batch, strict=False)

		tensors = [torch.tensor(x, dtype=torch.int32) for x in texts]
		# Pad with id 0; this assumes the padding token id is 0 so padded positions are masked downstream
		padded = pad_sequence(tensors, batch_first=True, padding_value=0)

		return padded, torch.stack(targets)

	def to_dataloader(self, shuffle: bool, batch_size: int = 32) -> DataLoader:
		"""Convert the dataset to a DataLoader."""
		return DataLoader(self, collate_fn=self.collate_fn, shuffle=shuffle, batch_size=batch_size)


def train_supervised(
	train_dataset: TextDataset,
	validation_dataset: TextDataset,
	model: StaticModel,
	patience: int | None = 5,
	device: str | None = None,
	batch_size: int = 256,
	lr: float = 1e-3,
) -> StaticModel:
	"""
	Train a tokenlearn model.

	:param train_dataset: The training dataset.
	:param validation_dataset: The validation dataset.
	:param model: The model to train.
	:param patience: The number of epochs to wait before early stopping.
	:param device: The device to train on.
	:param batch_size: The batch size.
	:param lr: The learning rate.
	:return: The trained model.
	"""
	device = select_optimal_device(device)
	train_dataloader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size)

	# Initialize the model
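	# (assumes the tokenizer defines a "[PAD]" token; token_to_id returns None otherwise)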
	trainable_model = StaticModelFineTuner(
		torch.from_numpy(model.embedding),
		out_dim=train_dataset.targets.shape[1],
		pad_id=model.tokenizer.token_to_id("[PAD]"),
	)
	trainable_model.to(device)

	# Collect all trainable parameters: embeddings, per-token weights, and the output layer
	model_params = [
		*trainable_model.embeddings.parameters(),
		trainable_model.w,
		*trainable_model.out_layer.parameters(),
	]

	# Create the optimizer
	optimizer = torch.optim.AdamW(params=model_params, lr=lr)

	lowest_loss = float("inf")
	param_dict = trainable_model.state_dict()
	curr_patience = patience
	stop = False

	criterion = nn.MSELoss()  # Regress onto the target vectors

	try:
		for epoch in range(100_000):
			logger.info(f"Epoch {epoch}")
			trainable_model.train()

			# Track train loss separately
			train_losses = []
			barred_train = tqdm(train_dataloader, desc=f"Epoch {epoch:03d} [Train]")

			for idx, (x, y) in enumerate(barred_train):
				optimizer.zero_grad()
				x = x.to(trainable_model.device)
				y_hat, _ = trainable_model(x)
				# MSE loss between predictions and target vectors (MSELoss already averages)
				train_loss = criterion(y_hat, y.to(trainable_model.device))

				train_loss.backward()

				optimizer.step()
				train_losses.append(train_loss.item())

				barred_train.set_description_str(f"Train Loss: {np.mean(train_losses[-10:]):.3f}")

				# Evaluate every 1000 steps and at the end of the epoch
				if (idx > 0 and idx % 1000 == 0) or idx == len(train_dataloader) - 1:
					trainable_model.eval()
					with torch.no_grad():
						validation_losses = []
						barred_val = tqdm(
							validation_dataset.to_dataloader(shuffle=False, batch_size=batch_size), desc="Validation"
						)
						for x_val, y_val in barred_val:
							x_val = x_val.to(trainable_model.device)
							y_hat_val, _ = trainable_model(x_val)
							val_loss = criterion(y_hat_val, y_val.to(trainable_model.device))
							validation_losses.append(val_loss.item())
							barred_val.set_description_str(f"Validation Loss: {np.mean(validation_losses):.3f}")

						validation_loss = np.mean(validation_losses)
					# Early stopping logic based on validation loss
					if patience is not None and curr_patience is not None:
						if (lowest_loss - validation_loss) > 1e-4:
							param_dict = trainable_model.state_dict()  # Save best model state based on validation loss
							curr_patience = patience
							lowest_loss = validation_loss
						else:
							curr_patience -= 1
							if curr_patience == 0:
								stop = True
								break
						logger.info(f"Patience level: {patience - curr_patience}")
						logger.info(f"Validation loss: {validation_loss:.3f}")
						logger.info(f"Lowest loss: {lowest_loss:.3f}")

					trainable_model.train()

			if stop:
				logger.info("Early stopping")
				break

	except KeyboardInterrupt:
		logger.info("Training interrupted")

	trainable_model.eval()
	# Load the best model state based on validation loss
	trainable_model.load_state_dict(param_dict)

	# The embedding weights already live on the training device
	embeddings_weight = trainable_model.embeddings.weight.to(device)

	# Re-embed each token id individually to bake the learned weights into static vectors
	with torch.no_grad():
		vectors = trainable_model.sub_forward(torch.arange(len(embeddings_weight))[:, None].to(device)).cpu().numpy()

	return StaticModel(vectors=vectors, tokenizer=model.tokenizer, config=model.config)
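

# Example usage (a minimal sketch; `model`, `texts`/`targets`, and `val_texts`/`val_targets` are
# placeholders for a loaded StaticModel and your own data, with targets given as 2D float tensors
# of shape (n_texts, target_dim)):
#
#     train_data = TextDataset(texts, targets, model.tokenizer)
#     val_data = TextDataset(val_texts, val_targets, model.tokenizer)
#     finetuned = train_supervised(train_data, val_data, model, patience=5, batch_size=256, lr=1e-3)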