import itertools
import typing

import hydra.utils
import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics
import transformers

import dataloader
import models.dimamba  # needed for the 'dimamba' classifier backbone below
import models.dit
import noise_schedule


class MicroAveragingMetric(torchmetrics.Metric):
  """Micro-averaging metric.

  Adapted from https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py#L12
  """

  def __init__(self, class_idx: typing.Optional[int] = 1,
               dist_sync_on_step=False):
    super().__init__(dist_sync_on_step=dist_sync_on_step)
    self.class_idx = torch.tensor(class_idx) \
      if class_idx is not None else None
    self.add_state("numerator", default=torch.tensor(0.0),
                   dist_reduce_fx="sum")
    self.add_state("denominator", default=torch.tensor(0.0),
                   dist_reduce_fx="sum")

  def _update(
      self, numerator, denominator, preds, y) -> tuple:
    raise NotImplementedError

  def update(self, logits: torch.Tensor, y: torch.Tensor):
    # Convert logits to hard predictions and accumulate the
    # numerator / denominator states.
    preds = torch.argmax(logits, dim=-1)
    y = y.view(-1)
    assert preds.shape == y.shape, \
      f"preds shape {preds.shape} != y shape {y.shape}"
    self.numerator, self.denominator = self._update(
      self.numerator, self.denominator, preds, y)

  def compute(self):
    # Return the micro-averaged value; guard against division by zero.
    value = self.numerator.float() / self.denominator \
      if self.denominator.item() > 0. else torch.tensor(0.0)
    return value

  def reset(self):
    self.numerator = torch.tensor(0.0).to(self.device)
    self.denominator = torch.tensor(0.0).to(self.device)
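
# Usage sketch (illustrative only, not part of the training pipeline):
# subclasses accumulate a numerator / denominator pair across batches and
# devices, e.g.
#   acc = Accuracy(class_idx=None)
#   acc.update(torch.tensor([[0.1, 0.9], [0.8, 0.2]]), torch.tensor([1, 1]))
#   acc.compute()  # -> 0.5 (one of the two predictions is correct)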


class CrossEntropy(MicroAveragingMetric):
  """Calculates cross-entropy loss."""

  def _update(
      self, numerator, denominator, logits, y) -> tuple:
    with torch.no_grad():
      numerator += F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        y.view(-1),
        ignore_index=-100,
        reduction='sum')
      denominator += y.numel()
    return numerator, denominator

  # Overrides the parent `update`: cross-entropy needs the raw logits
  # rather than argmax predictions.
  def update(self, logits: torch.Tensor, y: torch.Tensor):
    y = y.view(-1)
    self.numerator, self.denominator = self._update(
      self.numerator, self.denominator, logits, y)


class Accuracy(MicroAveragingMetric):
  """Calculates accuracy.

  Can be used to calculate accuracy per class.
  Copied from:
  https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
  """

  def _update(
      self, numerator, denominator, preds, y) -> tuple:
    if self.class_idx is None:
      numerator += (preds == y).sum()
      denominator += y.numel()
    else:
      # Per-class accuracy: count correct predictions both on positions
      # labeled `class_idx` and on positions labeled otherwise.
      class_idx = self.class_idx
      relevant_idxs = (y == class_idx)
      numerator += (preds[relevant_idxs] == class_idx).sum()
      denominator += relevant_idxs.sum()
      relevant_idxs = (y != class_idx)
      numerator += (preds[relevant_idxs] != class_idx).sum()
      denominator += relevant_idxs.sum()
    return numerator, denominator


class Precision(MicroAveragingMetric):
  """Calculates precision.

  Can be used to calculate precision per class.
  Adapted from:
  https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
  """

  def _update(self, numerator, denominator, preds, y) -> tuple:
    class_idx = self.class_idx
    relevant_idxs = (preds == class_idx)
    numerator += (y[relevant_idxs] == class_idx).sum()
    denominator += relevant_idxs.sum()
    return numerator, denominator


class Recall(MicroAveragingMetric):
  """Calculate recall.

  Can be used to calculate recall per class.
  Adapted from:
  https://github.com/HazyResearch/hyena-dna/blob/main/src/tasks/metrics.py
  """

  def _update(self, numerator, denominator, preds, y) -> tuple:
    class_idx = self.class_idx
    relevant_idxs = (y == class_idx)
    numerator += (preds[relevant_idxs] == class_idx).sum()
    denominator += relevant_idxs.sum()
    return numerator, denominator


class Classifier(L.LightningModule):
  def __init__(
      self,
      config,
      tokenizer: transformers.PreTrainedTokenizer,
      pretrained_backbone: typing.Optional[torch.nn.Module] = None):
    super().__init__()
    self.save_hyperparameters(ignore=['pretrained_backbone'])
    self.config = config

    # An "eval" classifier is trained on clean (un-noised) sequences and is
    # only used to score generated samples; it is never conditioned on the
    # diffusion time step.
    self.is_eval_classifier = getattr(
      config, 'is_eval_classifier', False)

    self.tokenizer = tokenizer
    self.vocab_size = tokenizer.vocab_size
    self.antithetic_sampling = config.training.antithetic_sampling
    self.importance_sampling = config.training.importance_sampling
    self.change_of_variables = config.training.change_of_variables
    if (not hasattr(self.tokenizer, 'mask_token')
        or self.tokenizer.mask_token is None):
      self.mask_index = self.vocab_size
      self.vocab_size += 1
    else:
      self.mask_index = self.tokenizer.mask_token_id

    if config.classifier_backbone == 'dit':
      self.classifier_model = models.dit.DITClassifier(
        self.config, vocab_size=self.vocab_size)
    elif self.config.classifier_backbone == 'dimamba':
      self.classifier_model = models.dimamba.DiMambaClassifier(
        self.config, vocab_size=self.vocab_size,
        pad_token_id=self.tokenizer.pad_token_id)
    elif config.classifier_backbone == 'hyenadna':
      hyena_config = transformers.AutoConfig.from_pretrained(
        config.classifier_model.hyena_model_name_or_path,
        n_layer=config.classifier_model.n_layer,
        trust_remote_code=True)
      self.classifier_model = transformers.AutoModelForSequenceClassification.from_config(
        hyena_config,
        pretrained=False,
        num_labels=config.data.num_classes,
        problem_type='single_label_classification',
        trust_remote_code=True)
    else:
      raise NotImplementedError(
        f"Classifier backbone "
        f"{self.config.classifier_backbone} not "
        f"implemented.")
    if pretrained_backbone is not None:
      self.classifier_model.load_pretrained_encoder(
        pretrained_backbone)

    # Metrics are automatically reset at the end of each epoch by Lightning.
    metrics = torchmetrics.MetricCollection({
      'cross_entropy': CrossEntropy(),
      'accuracy': Accuracy(class_idx=None),
    })
    if config.data.num_classes > 2:
      for c in range(config.data.num_classes):
        metrics.add_metrics(
          {f"accuracy_class{c}": Accuracy(class_idx=c),
           f"precision_class{c}": Precision(class_idx=c),
           f"recall_class{c}": Recall(class_idx=c)})
    else:
      metrics.add_metrics(
        {'precision': Precision(class_idx=1),
         'recall': Recall(class_idx=1)})
    metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')
    # `_compute_loss` also accepts prefix='test', so keep a test clone too.
    self.test_metrics = metrics.clone(prefix='test/')

    self.T = config.T
    self.noise = noise_schedule.get_noise(config,
                                          dtype=self.dtype)
    self.sampling_eps = config.training.sampling_eps
    self.lr = config.optim.lr
    self.time_conditioning = config.time_conditioning
    self.fast_forward_epochs = None
    self.fast_forward_batches = None

  def on_load_checkpoint(self, checkpoint):
    # Record how far training had progressed so that the fault-tolerant
    # samplers can fast-forward to the same position when resuming.
    self.fast_forward_epochs = checkpoint['loops'][
      'fit_loop']['epoch_progress']['current']['completed']
    self.fast_forward_batches = checkpoint['loops'][
      'fit_loop']['epoch_loop.batch_progress'][
      'current']['completed']

  def on_save_checkpoint(self, checkpoint):
    # Lightning's batch progress lags behind the optimizer when gradients
    # are accumulated, so reconstruct the number of completed batches from
    # the optimizer's step counters, scaled by the accumulation factor.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['total'][
      'completed'] = checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['total'][
      'completed'] * self.trainer.accumulate_grad_batches
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['current'][
      'completed'] = checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['current'][
      'completed'] * self.trainer.accumulate_grad_batches
    # `_batches_that_stepped` counts global (optimizer) steps, so it is not
    # scaled by `accumulate_grad_batches`.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.state_dict'][
      '_batches_that_stepped'] = \
      checkpoint['loops']['fit_loop'][
      'epoch_loop.automatic_optimization.optim_progress'][
      'optimizer']['step']['total']['completed']
    if 'sampler' not in checkpoint.keys():
      checkpoint['sampler'] = {}
    if hasattr(self.trainer.train_dataloader.sampler,
               'state_dict'):
      sampler_state_dict = self.trainer. \
        train_dataloader.sampler.state_dict()
      checkpoint['sampler'][
        'random_state'] = sampler_state_dict.get(
        'random_state', None)
    else:
      checkpoint['sampler']['random_state'] = None

  def on_train_start(self):
    # Replace the default samplers with fault-tolerant ones so that training
    # can resume mid-epoch from the state saved in the checkpoint.
    distributed = (
      self.trainer._accelerator_connector.use_distributed_sampler
      and self.trainer._accelerator_connector.is_distributed)
    if distributed:
      sampler_cls = dataloader.FaultTolerantDistributedSampler
    else:
      sampler_cls = dataloader.RandomFaultTolerantSampler
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
      if hasattr(dl.sampler, 'shuffle'):
        dl_sampler = sampler_cls(
          dl.dataset, shuffle=dl.sampler.shuffle)
      else:
        dl_sampler = sampler_cls(dl.dataset)
      if (distributed
          and self.fast_forward_epochs is not None
          and self.fast_forward_batches is not None):
        dl_sampler.load_state_dict({
          'epoch': self.fast_forward_epochs,
          'counter': (self.fast_forward_batches
                      * self.config.loader.batch_size)})
      updated_dls.append(
        torch.utils.data.DataLoader(
          dl.dataset,
          batch_size=self.config.loader.batch_size,
          num_workers=self.config.loader.num_workers,
          pin_memory=self.config.loader.pin_memory,
          sampler=dl_sampler,
          shuffle=False,
          persistent_workers=self.config.loader.persistent_workers))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls

  def forward(self, x, sigma=None, x_emb=None, attention_mask=None):
    """Returns logits.

    x_emb can be provided during PPLM / NoS-style guidance
    (see: https://arxiv.org/abs/2305.20009).
    """
    if self.is_eval_classifier:
      logits = self.classifier_model(x)
      if hasattr(logits, 'logits'):
        logits = logits.logits
    else:
      sigma = self._process_sigma(sigma) if sigma is not None else sigma
      with torch.cuda.amp.autocast(dtype=torch.float32):
        logits = self.classifier_model(
          x, sigma, x_emb=x_emb, attention_mask=attention_mask)
    return logits

  def get_log_probs(self, x, sigma, x_emb=None):
    """Returns log probabilities.

    Used for classifier-based (CBG-style) guidance.
    """
    if self.is_eval_classifier:
      raise NotImplementedError(
        '`get_log_probs` is not implemented for classifiers '
        'that are meant to be used for evaluation purposes '
        'only.')
    with torch.cuda.amp.autocast(dtype=torch.float32):
      return torch.nn.functional.log_softmax(
        self.forward(x, sigma, x_emb=x_emb), dim=-1)
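
  # Guidance usage sketch (illustrative; `xt` is a batch of noised token ids
  # and `sigma` the matching noise level from the schedule):
  #   log_probs = classifier.get_log_probs(xt, sigma)  # (batch, num_classes)
  #   score = log_probs[:, target_class]               # log p(y = target | x_t)
  # Gradients of `score` with respect to the input embeddings can then be
  # used to steer the diffusion sampler toward the target class.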

  def training_step(self, batch, batch_idx):
    loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True,
             prog_bar=True)
    self.log(name='lr',
             value=self.trainer.optimizers[0].param_groups[0]['lr'],
             on_step=True,
             on_epoch=False,
             sync_dist=True,
             prog_bar=True,
             logger=False)
    return loss

  def validation_step(self, batch, batch_idx):
    return self._compute_loss(batch, prefix='val')

  def configure_optimizers(self):
    # Optimize the classifier backbone together with any learnable
    # parameters of the noise schedule.
    optimizer = torch.optim.AdamW(
      itertools.chain(self.classifier_model.parameters(),
                      self.noise.parameters()),
      lr=self.config.optim.lr,
      betas=(self.config.optim.beta1,
             self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay)

    scheduler = hydra.utils.instantiate(
      self.config.lr_scheduler, optimizer=optimizer)
    scheduler_dict = {
      'scheduler': scheduler,
      'interval': 'step',
      'monitor': 'val/loss',
      'name': 'trainer/lr',
    }
    return [optimizer], [scheduler_dict]

  def _q_xt(self, x, move_chance):
    """Computes the noisy sample xt.

    Args:
      x: int torch.Tensor with shape (batch_size,
        diffusion_model_input_length), input.
      move_chance: float torch.Tensor with shape
        (batch_size, 1).
    """
    move_indices = torch.rand(
      *x.shape, device=x.device) < move_chance
    if self.config.diffusion == 'absorbing_state':
      return torch.where(move_indices, self.mask_index, x)
    if self.config.diffusion == 'uniform':
      uniform_tensor = torch.randint(
        0, self.vocab_size, x.shape, device=x.device)
      return torch.where(move_indices, uniform_tensor, x)
    raise NotImplementedError(
      f'Diffusion type {self.config.diffusion} not '
      'implemented.')
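
  # Example (illustrative): with move_chance = 0.5 and absorbing-state
  # diffusion, roughly half of the tokens in `x` are replaced by
  # `self.mask_index`; with uniform diffusion, those positions are instead
  # replaced by tokens drawn uniformly at random from the vocabulary.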

  def _compute_loss(self, batch, prefix):
    x0 = batch['input_ids']
    attention_mask = batch['attention_mask']
    t = None
    if self.is_eval_classifier:
      logits = self.forward(x0)
    elif self.config.parameterization == 'ar':
      # Autoregressive classifiers (e.g., for FUDGE-style guidance) are
      # trained on clean sequences without time conditioning.
      logits = self.forward(
        x0, attention_mask=attention_mask)
    else:
      t = self._sample_t(x0.shape[0])
      if self.T > 0:
        t = (t * self.T).to(torch.int)
        t = t / self.T
        # t \in {1/T, 2/T, ..., 1}
        t += (1 / self.T)
      if self.change_of_variables:
        time_conditioning = t[:, None]
        f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
        f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
        move_chance = torch.exp(f_0 + t * (f_T - f_0))
        move_chance = move_chance[:, None]
      else:
        sigma, _ = self.noise(t)
        time_conditioning = sigma[:, None]
        move_chance = 1 - torch.exp(-sigma[:, None])

      xt = self._q_xt(x0, move_chance)
      logits = self.forward(
        xt, time_conditioning, attention_mask=attention_mask)
    if hasattr(self.config.data, 'label_col'):
      if f"{self.config.data.label_col}_threshold" in batch:
        y = batch[f"{self.config.data.label_col}_threshold"]
      else:
        y = batch[self.config.data.label_col]
    else:
      y = batch['label']
    if (not self.is_eval_classifier
        and getattr(self.config.training, 'use_label_smoothing', False)):
      # Noise-dependent label smoothing: interpolate between the one-hot
      # label and the uniform distribution as the diffusion time t grows.
      labels = (
        torch.nn.functional.one_hot(
          y, self.config.data.num_classes) * (1 - t)[..., None]
        + (1 / self.config.data.num_classes) * t[..., None])
    else:
      labels = y.view(-1)
    if getattr(self.config, 'is_fudge_classifier', False):
      # FUDGE-style classifiers predict a label at every position, so expand
      # the sequence-level label and drop padded positions.
      expanded_y = y.unsqueeze(1).expand(-1, logits.shape[1])
      logits = logits.view(
        -1, self.config.data.num_classes)[attention_mask.flatten() == 1, ...]
      y = expanded_y.flatten().long()[attention_mask.flatten() == 1]
      loss = torch.nn.functional.cross_entropy(
        logits,
        y,
        ignore_index=-100,
        reduction='mean')
    else:
      loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels,
        ignore_index=-100,
        reduction='mean')

    if prefix == 'train':
      self.train_metrics.update(logits, y)
      metrics = self.train_metrics
    elif prefix == 'val':
      self.valid_metrics.update(logits, y)
      metrics = self.valid_metrics
    elif prefix == 'test':
      self.test_metrics.update(logits, y)
      metrics = self.test_metrics
    else:
      raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)
    return loss

  def _sample_t(self, n):
    _eps_t = torch.rand(n, device=self.device)
    if self.antithetic_sampling:
      offset = torch.arange(n, device=self.device) / n
      _eps_t = (_eps_t / n + offset) % 1
    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
    if self.importance_sampling:
      return self.noise.importance_sampling_transformation(t)
    return t
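
  # Antithetic sampling stratifies t across the batch: for n = 4 the offsets
  # are [0, 0.25, 0.5, 0.75], so the sampled times cover [sampling_eps, 1)
  # roughly evenly, which lowers the variance of the loss estimate.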

  def _process_sigma(self, sigma):
    if sigma.ndim > 1:
      sigma = sigma.squeeze(-1)
    if not self.time_conditioning:
      sigma = torch.zeros_like(sigma)
    assert sigma.ndim == 1, sigma.shape
    return sigma
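

# Instantiation sketch (illustrative; the hydra config fields and tokenizer
# name below are assumptions, not part of this module):
#   tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')
#   classifier = Classifier(config, tokenizer=tokenizer)
#   trainer = L.Trainer(max_steps=config.trainer.max_steps)
#   trainer.fit(classifier, datamodule=...)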