from transformers.trainer import *


class DistributedTrainer(Trainer):
    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
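        # NOTE: this override closely mirrors the upstream `Trainer._inner_training_loop`.
        # The distributed-specific deviations appear to be: the model is never passed
        # through `accelerator.prepare` (only the optimizer is), the running loss tensor
        # lives on `model.out_device`, and each step's loss is checked against that device.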
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        if self.args.auto_find_batch_size:
            if self.state.train_batch_size != self._train_batch_size:
                from accelerate.utils import release_memory

                (self.model_wrapped,) = release_memory(self.model_wrapped)
                self.model_wrapped = self.model
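                # The batch size changed under `auto_find_batch_size`: temporarily swap in
                # the new per-device value so it can be propagated to the DeepSpeed config.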
                if self.is_deepspeed_enabled:
                    original_bs = self.args.per_device_train_batch_size
                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
                    self.propagate_args_to_deepspeed(True)
                    self.args.per_device_train_batch_size = original_bs
            self.state.train_batch_size = self._train_batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")

        train_dataloader = self.get_train_dataloader()
        if self.is_fsdp_xla_v2_enabled:
            train_dataloader = tpu_spmd_dataloader(train_dataloader)
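        # Set up the training control variables: number of epochs, number of update steps
        # per epoch, and the total number of optimization steps to execute (max_steps).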
        total_train_batch_size = self.get_total_train_batch_size(args)

        (
            num_train_epochs,
            num_update_steps_per_epoch,
            num_examples,
            num_train_samples,
            epoch_based,
            len_dataloader,
            max_steps,
        ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)

        num_train_tokens = None
        if self.args.include_tokens_per_second:
            num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)
            if len_dataloader is not None and epoch_based:
                num_train_tokens *= args.num_train_epochs
            else:
                num_train_tokens *= args.gradient_accumulation_steps

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torchrun or torch.distributed.launch (deprecated))."
                )
            else:
                DebugUnderflowOverflow(self.model)

        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
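        # Optimizer creation cannot be delayed when using FSDP v2 (fsdp_version == 2).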
        is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
        if is_fsdp2:
            delay_optimizer_creation = False

        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState(
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]
        )
        self.state.is_hyper_param_search = trial is not None
        self.state.train_batch_size = self._train_batch_size
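        # Compute absolute logging/eval/save step counts where they were given as ratios.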
        self.state.compute_steps(args, max_steps)

        if args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)

        model = self._wrap_model(self.model_wrapped)
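        # If `_wrap_model` returned the model unchanged, Accelerate would normally handle
        # the wrapping via `accelerator.prepare` below.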
        use_accelerator_prepare = model is self.model

        if use_accelerator_prepare and self.is_fsdp_enabled:
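            # `auto_find_batch_size` may retry with FSDP wrapping already applied: strip it
            # from sub-modules before re-preparing.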
            self.model = unwrap_model(self.model, recursive=True)

        if delay_optimizer_creation:
            if use_accelerator_prepare:
                self._fsdp_qlora_plugin_updates()
                if self.accelerator.mixed_precision != "fp8":
                    self.model = self.accelerator.prepare(self.model)
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
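        # Deviation from the stock Trainer: force-skip `accelerator.prepare` for the model.
        # Only the optimizer is prepared below; the model is assumed to have been placed
        # and parallelized outside of Accelerate.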
        use_accelerator_prepare = False
        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    if self.is_tp_enabled:
                        self.optimizer = self.accelerator.prepare(self.optimizer)
                    else:
                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
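                # Also prepare the scheduler, e.g. when a DeepSpeed "DummyScheduler" is used.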
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model

        if model is not self.model:
            self.model_wrapped = model

        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped
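        # Resume from checkpoint: DeepSpeed and SageMaker MP / FSDP restore weights through
        # the wrapped model.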
        if resume_from_checkpoint is not None:
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(
                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
                )
            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

        self._load_optimizer_and_scheduler(resume_from_checkpoint)
        self._load_scaler(resume_from_checkpoint)
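        # important: at this point self.model is the bare Transformers model, while
        # self.model_wrapped is the (possibly) wrapped variant (DDP / DeepSpeed / FSDP / etc.).
        # `model` refers to the wrapped variant for the rest of this function.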
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {num_examples:,}")
        logger.info(f" Num Epochs = {num_train_epochs:,}")
        logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f" Total optimization steps = {max_steps:,}")
        logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0

        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            self.compare_trainer_and_checkpoint_args(self.args, self.state)
            self._load_callback_state()
            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f" Continuing training from epoch {epochs_trained}")
            logger.info(f" Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f" Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )

        for attr in ("model", "optimizer", "lr_scheduler"):
            setattr(self.callback_handler, attr, getattr(self, attr))
        self.callback_handler.train_dataloader = train_dataloader

        self.state.init_training_references(self, max_steps, num_train_epochs, trial)
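        # Deviation from the stock Trainer: keep the running loss on the model's output
        # device (`model.out_device`) rather than on `args.device`; `out_device` is
        # expected to be provided by the wrapped/parallelized model.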
        tr_loss = torch.tensor(0.0, device=model.out_device)
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
        grad_norm: Optional[float] = None
        learning_rate = None
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        if args.eval_on_start:
            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

        for epoch in range(epochs_trained, num_train_epochs):
            epoch_dataloader = train_dataloader
            if hasattr(epoch_dataloader, "set_epoch"):
                epoch_dataloader.set_epoch(epoch)

            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_dataloader)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            step = -1
            rng_to_sync = False

            if epoch == epochs_trained and resume_from_checkpoint is not None:
                if steps_trained_in_current_epoch > 0 and not args.ignore_data_skip:
                    epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                    step = steps_trained_in_current_epoch - 1
                    rng_to_sync = True
                elif steps_trained_in_current_epoch == 0:
                    self._load_rng_state(resume_from_checkpoint)

            epoch_iterator = iter(epoch_dataloader)
            remainder = steps_in_epoch % args.gradient_accumulation_steps
            if remainder == 0:
                remainder = args.gradient_accumulation_steps
            update_step = -1
            total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(
                remainder < args.gradient_accumulation_steps
            )
            for _ in range(total_updates):
                update_step += 1
                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
                self.current_gradient_accumulation_steps = len(batch_samples)
                for i, inputs in enumerate(batch_samples):
                    step += 1
                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
                    self.accelerator.gradient_state._set_sync_gradients(do_sync_step)

                    if self.args.include_num_input_tokens_seen not in ["no", False]:
                        main_input_name = getattr(self.model, "main_input_name", "input_ids")
                        if main_input_name not in inputs:
                            logger.warning(
                                "Tried to track the number of tokens seen, however the current model is "
                                "not configured properly to know what item is the input. To fix this, add "
                                "a `main_input_name` attribute to the model class you are using."
                            )
                        else:
                            if self.args.include_num_input_tokens_seen == "non_padding":
                                if "attention_mask" in inputs:
                                    input_tokens = inputs["attention_mask"].sum()
                                elif (
                                    self.processing_class is not None
                                    and hasattr(self.processing_class, "pad_token_id")
                                    and self.processing_class.pad_token_id is not None
                                ):
                                    input_tokens = (
                                        inputs[main_input_name] != self.processing_class.pad_token_id
                                    ).sum()
                                else:
                                    logger.warning(
                                        "Could not determine method to count non-padding tokens, falling back to counting all tokens."
                                    )
                                    input_tokens = inputs[main_input_name].numel()
                            else:
                                input_tokens = inputs[main_input_name].numel()

                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()

                    if rng_to_sync:
                        self._load_rng_state(resume_from_checkpoint)
                        rng_to_sync = False

                    if step % args.gradient_accumulation_steps == 0:
                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
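                    # Use `no_sync` to skip gradient all-reduce on all but the last
                    # micro-batch of the accumulation window; DeepSpeed handles gradient
                    # accumulation internally, so it never enters `no_sync`.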
                    context = (
                        functools.partial(self.accelerator.no_sync, model=model)
                        if i != len(batch_samples) - 1
                        and self.accelerator.distributed_type != DistributedType.DEEPSPEED
                        else contextlib.nullcontext
                    )
                    with context():
                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                    if (
                        args.logging_nan_inf_filter
                        and not is_torch_xla_available()
                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                    ):
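                        # If the step loss is NaN/Inf, add the running average of the losses
                        # accumulated since the last log so the total stays finite.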
                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                    else:
                        if tr_loss.device != tr_loss_step.device:
                            raise ValueError(
                                f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                            )
                        tr_loss = tr_loss + tr_loss_step

                    self.current_flos += float(self.floating_point_ops(inputs))

                    if do_sync_step:
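                        # Since batches are prefetched, sync_gradients must be set manually
                        # before the optimizer step.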
                        self.accelerator.gradient_state._set_sync_gradients(True)
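                        # Gradient clipping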
                        if args.max_grad_norm is not None and args.max_grad_norm > 0:
                            if is_sagemaker_mp_enabled() and args.fp16:
                                _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                            elif self.use_apex:
                                from apex import amp

                                _grad_norm = nn.utils.clip_grad_norm_(
                                    amp.master_params(self.optimizer),
                                    args.max_grad_norm,
                                )
                            else:
                                grad_norm_context = contextlib.nullcontext
                                if self.is_tp_enabled:
                                    from torch.distributed._tensor.experimental import implicit_replication

                                    grad_norm_context = implicit_replication
                                with grad_norm_context():
                                    _grad_norm = self.accelerator.clip_grad_norm_(
                                        model.parameters(),
                                        args.max_grad_norm,
                                    )

                            if (
                                is_accelerate_available()
                                and self.accelerator.distributed_type == DistributedType.DEEPSPEED
                            ):
                                grad_norm = model.get_global_grad_norm()
                                if hasattr(grad_norm, "item"):
                                    grad_norm = grad_norm.item()
                            else:
                                grad_norm = _grad_norm

                        self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

                        context = contextlib.nullcontext
                        if self.is_tp_enabled:
                            from torch.distributed._tensor.experimental import implicit_replication

                            context = implicit_replication

                        with context():
                            self.optimizer.step()

                        self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
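                        # Read the learning rate before the scheduler updates it, for logging.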
                        learning_rate = self._get_learning_rate()

                        if not self.accelerator.optimizer_step_was_skipped:
                            if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                                self.lr_scheduler.step()

                        model.zero_grad()
                        self.state.global_step += 1
                        self.state.epoch = epoch + (step + 1) / steps_in_epoch
                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                        self._maybe_log_save_evaluate(
                            tr_loss,
                            grad_norm,
                            model,
                            trial,
                            epoch,
                            ignore_keys_for_eval,
                            start_time,
                            learning_rate=learning_rate,
                        )
                    else:
                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
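                    # PyTorch/XLA relies on the data loader to insert mark_step for each
                    # step; when breaking out of the loop early, insert it manually.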
                    if self.control.should_epoch_stop or self.control.should_training_stop:
                        if is_torch_xla_available():
                            xm.mark_step()
                        break
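                # Also break out of the outer update-step loop.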
                if self.control.should_epoch_stop or self.control.should_training_stop:
                    if is_torch_xla_available():
                        xm.mark_step()
                    break
            if step < 0:
                logger.warning(
                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(
                tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
            )

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_xla_available():
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            self._load_best_model()
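        # Add the remaining (not yet logged) loss into the running total before averaging.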
        self._total_loss_scalar += tr_loss.item()
        effective_global_step = max(self.state.global_step, 0.001)
        train_loss = self._total_loss_scalar / effective_global_step

        metrics = speed_metrics(
            "train",
            start_time,
            num_samples=num_train_samples,
            num_steps=self.state.max_steps,
            num_tokens=num_train_tokens,
        )
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
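        # With save_total_limit=1, delete every checkpoint except the best one.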
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint, ignore_errors=True)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
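        # Wait for any in-flight checkpoint push to the Hub to finish.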
        self._finish_current_push()
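        # Remove the NEFTune forward hook so the embedding layer's original forward is restored.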
        if self.neftune_noise_alpha is not None:
            self._deactivate_neftune(self.model)

        return TrainOutput(self.state.global_step, train_loss, metrics)
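

# Usage sketch (hypothetical; assumes a standard Hugging Face Trainer setup in which the
# wrapped model exposes an `out_device` attribute, as required by the loop above):
#
#     training_args = TrainingArguments(output_dir="out", per_device_train_batch_size=8)
#     trainer = DistributedTrainer(model=model, args=training_args, train_dataset=train_dataset)
#     trainer.train()
#
# `DistributedTrainer` is constructed exactly like `Trainer`; only the inner training loop differs.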