import argparse import hashlib import math import os import time from argparse import ArgumentParser, Namespace from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn from accelerate.utils import set_seed from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType from torch.optim import Optimizer from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoProcessor, AutoTokenizer from datasets import Dataset from specforge import ( AutoDraftModelConfig, AutoEagle3DraftModel, OnlineEagle3Model, QwenVLOnlineEagle3Model, ) from specforge.args import SGLangBackendArgs, TrackerArgs from specforge.data import ( build_eagle3_dataset, build_offline_eagle3_dataset, generate_vocab_mapping_file, prepare_dp_dataloaders, ) from specforge.distributed import ( destroy_distributed, get_dp_group, get_draft_dp_group, get_tp_group, init_distributed, ) from specforge.modeling.target import ( Eagle3TargetModel, TargetHead, get_eagle3_target_model, ) from specforge.optimizer import BF16Optimizer from specforge.tracker import Tracker, create_tracker, get_tracker_class from specforge.utils import ( create_draft_config_from_target, get_last_checkpoint, print_args_with_dots, print_on_rank0, print_with_rank, rank_0_priority, safe_conversations_generator, ) def parse_args() -> Tuple[ArgumentParser, Namespace]: """ This function is used to parse the arguments for the training script. """ parser = argparse.ArgumentParser(description="Train Eagle3 with online data") # add model-related arguments model_group = parser.add_argument_group("model") model_group.add_argument("--target-model-path", type=str, required=True) model_group.add_argument( "--trust-remote-code", action="store_true", help="Trust remote code" ) model_group.add_argument( "--draft-model-config", type=str, required=False, help="Draft model config path. If not provided, will auto-generate from target model.", ) model_group.add_argument( "--embedding-key", type=str, default="model.embed_tokens.weight", help="The key of the embedding weight to load from the target model", ) model_group.add_argument( "--lm-head-key", type=str, default="lm_head.weight", help="The key of the lm head weight to load from the target model, this is only required for offline training", ) model_group.add_argument( "--is-vlm", action="store_true", help="Whether the target model is a VLM" ) model_group.add_argument( "--target-model-backend", type=str, default="sglang", choices=["sglang", "hf", "custom"], help="The backend of the target model", ) # dataset arguments dataset_group = parser.add_argument_group("dataset") dataset_group.add_argument("--train-data-path", type=str, required=True) dataset_group.add_argument("--train-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-data-path", type=str, default=None) dataset_group.add_argument("--chat-template", type=str, default="llama3") dataset_group.add_argument( "--is-preformatted", action="store_true", help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.", ) dataset_group.add_argument( "--train-only-last-turn", action="store_true", help="If set, only the last assistant turn in each conversation contributes to the loss. " "Useful for thinking models where conversation history may lack thought processes.", ) dataset_group.add_argument("--build-dataset-num-proc", type=int, default=8) dataset_group.add_argument( "--dataloader-num-workers", type=int, default=4, help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", ) # training hyper params training_group = parser.add_argument_group("training") training_group.add_argument("--num-epochs", type=int, default=10) training_group.add_argument( "--max-num-steps", type=int, default=None, help="The maximum number of steps to train. If not provided, will be calculated as num_epochs * steps_per_epoch", ) training_group.add_argument("--batch-size", type=int, default=1) training_group.add_argument("--learning-rate", type=float, default=1e-4) training_group.add_argument("--max-length", type=int, default=2048) training_group.add_argument("--warmup-ratio", type=float, default=0.015) training_group.add_argument( "--total-steps", type=int, default=None, help="Total training steps. If not provided, will be calculated as num_epochs * steps_per_epoch", ) training_group.add_argument("--max-grad-norm", type=float, default=0.5) training_group.add_argument( "--ttt-length", type=int, default=7, help="The length for Test-Time Training (TTT).", ) training_group.add_argument("--resume", action="store_true") training_group.add_argument( "--ckpt-dir", type=str, default=None, help="directory includes the checkpoint to start training with", ) training_group.add_argument("--eval-interval", type=int, default=5000) training_group.add_argument("--save-interval", type=int, default=5000) training_group.add_argument( "--log-interval", type=int, default=50, help="Log training metrics every N steps", ) training_group.add_argument("--seed", type=int, default=0) training_group.add_argument("--draft-accumulation-steps", type=int, default=1) # data processing type optimization_group = parser.add_argument_group("optimization") optimization_group.add_argument( "--tp-size", type=int, default=1, help="The size of the tensor parallel for the target model", ) # distributed training optimization_group.add_argument("--sp-ulysses-size", type=int, default=1) optimization_group.add_argument("--sp-ring-size", type=int, default=1) optimization_group.add_argument( "--attention-backend", type=str, default="flex_attention", help="The attention backend for the draft model", ) # other args other_group = parser.add_argument_group("others") other_group.add_argument("--cache-key", type=str, default=None) other_group.add_argument("--cache-dir", type=str, default="./cache") other_group.add_argument("--output-dir", type=str, required=True) other_group.add_argument("--verbose", action="store_true") other_group.add_argument( "--dist-timeout", type=int, default=20, help="Timeout for collective communication in minutes", ) other_group.add_argument( "--model-download-dir", type=str, default=None, help="The directory to download the target model to", ) # vlm related args vlm_group = parser.add_argument_group("vlm") vlm_group.add_argument( "--min-pixels", type=int, default=50176 ) # 64*28*28 for qwen2.5-vl vlm_group.add_argument( "--max-pixels", type=int, default=802816 ) # 1024*28*28 for qwen2.5-vl # profiling related args profiling_group = parser.add_argument_group("profiling") profiling_group.add_argument("--profile", action="store_true") profiling_group.add_argument("--profile-start-step", type=int, default=30) profiling_group.add_argument("--profile-num-steps", type=int, default=4) profiling_group.add_argument("--profile-record-shapes", action="store_true") # sglang target model backend related args sglang_group = parser.add_argument_group("sglang target model backend") SGLangBackendArgs.add_args(sglang_group) # tracker related args tracker_group = parser.add_argument_group("tracker") TrackerArgs.add_args(tracker_group) args = parser.parse_args() return parser, args def build_tracker(args: Namespace, parser: ArgumentParser) -> Tracker: """ Build the experiment tracker according to the report_to argument. Args: args: The arguments for the training script. parser: The parser for the training script. Returns: The experiment tracker. """ tracker_class = get_tracker_class(args.report_to) if tracker_class: tracker_class.validate_args(parser, args) else: parser.error(f"Unknown tracker: {args.report_to}") tracker = create_tracker(args, args.output_dir) return tracker def build_target_model( args: Namespace, draft_model_config: AutoDraftModelConfig, is_online: bool = True ) -> Tuple[Union[Eagle3TargetModel, TargetHead], Optional[AutoProcessor]]: """ Build the target model according to the arguments. Args: args: The arguments for the training script. draft_model_config: The draft model config. Returns: The target model. """ if is_online: if ( args.is_vlm and draft_model_config.target_model_type == "qwen2_5_vl" and args.target_model_backend == "custom" ): from transformers import Qwen2_5_VLForConditionalGeneration target_model = ( Qwen2_5_VLForConditionalGeneration.from_pretrained( pretrained_model_name_or_path=args.target_model_path, torch_dtype=torch.bfloat16, ) .eval() .cuda() ) else: if args.target_model_backend == "sglang": target_model_kwargs = SGLangBackendArgs.from_args(args).to_kwargs() else: target_model_kwargs = {} target_model = get_eagle3_target_model( pretrained_model_name_or_path=args.target_model_path, backend=args.target_model_backend, torch_dtype=torch.bfloat16, device="cuda", cache_dir=args.model_download_dir, **target_model_kwargs, trust_remote_code=args.trust_remote_code, ) # set the aux hidden states layers if ( hasattr(draft_model_config, "eagle_config") and draft_model_config.eagle_config is not None and "eagle_aux_hidden_state_layer_ids" in draft_model_config.eagle_config ): target_model.set_aux_hidden_states_layers( draft_model_config.eagle_config["eagle_aux_hidden_state_layer_ids"] ) else: target_model.set_aux_hidden_states_layers() if args.is_vlm: processor = AutoProcessor.from_pretrained( args.target_model_path, min_pixels=args.min_pixels, max_pixels=args.max_pixels, ) else: processor = None return target_model, processor else: target_head = TargetHead.from_pretrained( model_path=args.target_model_path, lm_head_key=args.lm_head_key, cache_dir=args.model_download_dir, trust_remote_code=args.trust_remote_code, ) return target_head, None def sanity_check(args: Namespace) -> None: """ Perform sanity checks on the arguments. Args: args: The arguments for the training script. Returns: None """ args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size if args.attention_backend == "usp": sp_sanity_check(args) def sp_sanity_check(args: Namespace) -> None: args.draft_accumulation_steps = ( args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size ) assert ( args.batch_size == 1 ), f"USP only supports batch_size=1, got batch_size={args.batch_size}" assert args.sp_ring_size * args.sp_ulysses_size > 1, ( f"USP requires sp_ring_size * sp_ulysses_size > 1. " f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}." ) assert args.train_hidden_states_path is not None, f"USP only support offline mode" if args.eval_data_path is not None and args.eval_hidden_states_path is not None: raise ValueError( "Cannot set both eval_data_path and eval_hidden_states_path. " "For online mode, set only eval_data_path. " "For offline mode, set only eval_hidden_states_path." ) def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]: # ckpt info(epoch, step) ckpt_info = (0, 0) # Handle draft model config if args.draft_model_config is None: # Auto-generate and save config file auto_config_path = create_draft_config_from_target( target_model_path=args.target_model_path, cache_dir=args.model_download_dir ) draft_model_config = AutoDraftModelConfig.from_file(auto_config_path) else: # Use provided config file draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) # Handle base ckpt, config file draft_model_last_checkpoint = None is_resume_checkpoint = False if args.ckpt_dir is not None: if os.path.isdir(args.ckpt_dir): draft_model_config = AutoDraftModelConfig.from_file( os.path.join(args.ckpt_dir, "config.json") ) draft_model_last_checkpoint = args.ckpt_dir print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}") else: raise ValueError( f"Provided base model dir {args.ckpt_dir} is not a valid directory." ) # detecting last ckpt for draft model if args.resume and os.path.isdir(args.output_dir): print_on_rank0(args.output_dir) draft_model_last_checkpoint, ckpt_info = get_last_checkpoint(args.output_dir) print(f"Last checkpoint detected: {draft_model_last_checkpoint}") is_resume_checkpoint = True if draft_model_last_checkpoint: draft_model = AutoEagle3DraftModel.from_pretrained( draft_model_last_checkpoint, attention_backend=args.attention_backend, torch_dtype=torch.bfloat16, ).cuda() else: draft_model = AutoEagle3DraftModel.from_config( draft_model_config, attention_backend=args.attention_backend, torch_dtype=torch.bfloat16, ).cuda() # Load training state (optimizer, scheduler, epoch, step) for true resume resume_state = None if is_resume_checkpoint and draft_model_last_checkpoint: training_state_path = os.path.join( draft_model_last_checkpoint, "training_state.pt" ) if os.path.exists(training_state_path): resume_state = torch.load( training_state_path, map_location="cpu", weights_only=False ) print_on_rank0( f"Loaded training state from {training_state_path}: " f"epoch={resume_state['epoch']}, step={resume_state['global_step']}" ) draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) draft_model.freeze_embedding() return draft_model_config, draft_model, ckpt_info, resume_state def build_dataloaders( args: Namespace, draft_model_config: AutoDraftModelConfig, processor: Optional[AutoProcessor] = None, ) -> Tuple[DataLoader, str, Optional[DataLoader]]: # build dataloaders tokenizer = AutoTokenizer.from_pretrained( args.target_model_path, trust_remote_code=args.trust_remote_code ) # convert to dataloader cache_params_string = ( f"{args.train_data_path}-" f"{args.max_length}-" f"{args.chat_template}-" f"{args.target_model_path}" # Tokenizer may also different ) cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() train_dataset = Dataset.from_generator( generator=safe_conversations_generator, gen_kwargs={"file_path": args.train_data_path}, ) is_online = ( args.train_data_path is not None and args.train_hidden_states_path is None ) with rank_0_priority(): train_eagle3_dataset = build_eagle3_dataset( dataset=train_dataset, tokenizer=tokenizer, chat_template=args.chat_template, max_length=args.max_length, cache_dir=os.path.join(args.cache_dir, "processed_dataset"), cache_key=cache_key, is_vlm=args.is_vlm, is_preformatted=args.is_preformatted, processor=processor, num_proc=args.build_dataset_num_proc, train_only_last_turn=args.train_only_last_turn, ) vocab_mapping_path = generate_vocab_mapping_file( dataset=train_eagle3_dataset, target_vocab_size=draft_model_config.vocab_size, draft_vocab_size=draft_model_config.draft_vocab_size, cache_dir=os.path.join(args.cache_dir, "vocab_mapping"), cache_key=cache_key, ) if not is_online: train_eagle3_dataset = build_offline_eagle3_dataset( args.train_hidden_states_path, args.max_length, ttt_length=args.ttt_length, use_usp_preprocess=(args.attention_backend == "usp"), ) train_dataloader = prepare_dp_dataloaders( train_eagle3_dataset, args.target_batch_size, num_workers=args.dataloader_num_workers, shuffle=True, process_group=( get_draft_dp_group() if args.attention_backend == "usp" and not is_online else get_dp_group() ), is_vlm=args.is_vlm, ) if args.eval_data_path is not None or args.eval_hidden_states_path is not None: if args.eval_data_path is not None: eval_dataset = Dataset.from_generator( generator=safe_conversations_generator, gen_kwargs={"file_path": args.eval_data_path}, ) eval_eagle3_dataset = build_eagle3_dataset( eval_dataset, tokenizer, args.chat_template, args.max_length, is_vlm=args.is_vlm, processor=processor, num_proc=args.build_dataset_num_proc, is_preformatted=args.is_preformatted, train_only_last_turn=args.train_only_last_turn, ) elif args.eval_hidden_states_path is not None: eval_eagle3_dataset = build_offline_eagle3_dataset( args.eval_hidden_states_path, args.max_length, ttt_length=args.ttt_length, use_usp_preprocess=(args.attention_backend == "usp"), ) eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, args.target_batch_size, num_workers=args.dataloader_num_workers, shuffle=False, process_group=( get_draft_dp_group() if args.attention_backend == "usp" and not is_online else get_dp_group() ), is_vlm=args.is_vlm, ) print_with_rank("Initialized eval dataloader") else: eval_dataloader = None return ( train_dataloader, vocab_mapping_path, eval_dataloader, ) def save_checkpoints( args: Namespace, epoch: int, step: int, eagle3_model: nn.Module, optimizer: Optimizer, ): epoch_output_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}") if dist.get_rank() == 0: os.makedirs(epoch_output_dir, exist_ok=True) dist.barrier() with FSDP.state_dict_type(eagle3_model, StateDictType.FULL_STATE_DICT): model_state_dict = eagle3_model.state_dict() state_to_save = { "epoch": epoch, "global_step": step, "args": args, } state_to_save.update(optimizer.state_dict()) draft_model_state_dict = { k.replace("draft_model.", ""): v for k, v in model_state_dict.items() if "draft_model." in k and "embed" not in k.lower() } if dist.get_rank() == 0: torch.save( state_to_save, os.path.join(epoch_output_dir, "training_state.pt"), ) print_on_rank0( f"Saved full training state to {epoch_output_dir}/training_state.pt" ) eagle3_model.draft_model.save_pretrained( epoch_output_dir, state_dict=draft_model_state_dict, ) print_on_rank0(f"Saved model configuration to {epoch_output_dir}") dist.barrier() def run_forward( args: Namespace, eagle3_model: nn.Module, data: dict, target_model: Optional[Eagle3TargetModel] = None, is_online: bool = True, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: if args.is_vlm and args.target_model_backend == "custom": plosses, _, acces = eagle3_model( input_ids=data["input_ids"].cuda(), attention_mask=data["attention_mask"].cuda(), loss_mask=data["loss_mask"].cuda(), pixel_values=data["pixel_values"].cuda(), image_grid_thw=data["image_grid_thw"].cuda(), ) else: image_grid_thw = None if is_online: # we generate the eagle3 using the target model in an online fashion # Handle VLM data: pixel_values and image_grid_thw are lists # pixel_values = [pv.cuda() for pv in data["pixel_values"]] if args.is_vlm else None if args.is_vlm: image_grid_thw = ( [thw.cuda().squeeze() for thw in data["image_grid_thw"]] if args.is_vlm else None ) pixel_values = data["pixel_values"].cuda() eagle3_data = target_model.generate_eagle3_data( input_ids=data["input_ids"].cuda(), attention_mask=data["attention_mask"].cuda(), loss_mask=data["loss_mask"].cuda(), is_vlm=args.is_vlm, pixel_values=pixel_values, image_grid_thw=image_grid_thw, ) else: eagle3_data = target_model.generate_eagle3_data( input_ids=data["input_ids"].cuda(), attention_mask=data["attention_mask"].cuda(), loss_mask=data["loss_mask"].cuda(), ) input_ids = get_dp_data_shard_from_tp(eagle3_data.input_ids) attention_mask = get_dp_data_shard_from_tp(eagle3_data.attention_mask) loss_mask = get_dp_data_shard_from_tp(eagle3_data.loss_mask) target = get_dp_data_shard_from_tp(eagle3_data.target) hidden_states = get_dp_data_shard_from_tp(eagle3_data.hidden_states) else: # we generate the logits using the hidden states loaded from disk attention_mask = data["attention_mask"].cuda() hidden_states = data["hidden_state"].cuda() input_ids, target, loss_mask = target_model.preprocess( data["input_ids"], data["target"], data["loss_mask"] ) input_ids = input_ids.cuda() target = target_model( target.cuda() ) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU. loss_mask = loss_mask.cuda() plosses, _, acces = eagle3_model( input_ids=input_ids, attention_mask=attention_mask, loss_mask=loss_mask, target=target, hidden_states=hidden_states, position_ids=( data["position_ids"].cuda() if "position_ids" in data else None ), image_grid_thw=image_grid_thw, is_vlm=args.is_vlm, ) return plosses, acces def run_backward_and_update( args: Namespace, plosses: List[torch.Tensor], optimizer: Optimizer, global_step: int ) -> None: ploss_weight = [0.8**i for i in range(len(plosses))] ploss = ( sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) / args.draft_accumulation_steps ) ploss.backward() if global_step % args.draft_accumulation_steps == 0: optimizer.step() def record_metrcs( args: Namespace, accuracies: List[torch.Tensor], plosses: List[torch.Tensor], global_step: int, tracker: Tracker, optimizer: Optional[Optimizer] = None, mode: str = "train", ) -> None: logdict = {} if mode == "train" and optimizer is not None: logdict["train/lr"] = optimizer.get_learning_rate() accuracies = torch.stack(accuracies) plosses = torch.stack(plosses) assert accuracies.shape[0] == args.ttt_length dist.all_reduce(accuracies, op=dist.ReduceOp.AVG) accuracies = accuracies.cpu().tolist() for i in range(len(accuracies)): logdict[f"{mode}/acc_{i}"] = accuracies[i] print_on_rank0( f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, Acc: {accuracies[i]:.2f}" ) dist.all_reduce(plosses, op=dist.ReduceOp.AVG) plosses = plosses.cpu().tolist() for i in range(len(plosses)): logdict[f"{mode}/ploss_{i}"] = plosses[i] print_on_rank0( f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, pLoss: {plosses[i]}" ) tracker.log(logdict, step=global_step) def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor: """ Get the data shard from the tensor. """ tp_size = dist.get_world_size(get_tp_group()) tp_rank = dist.get_rank(get_tp_group()) return tensor.chunk(tp_size, dim=0)[tp_rank] def main(): # ================================================ # 1. Initialize # ================================================ parser, args = parse_args() set_seed(args.seed) init_distributed( timeout=args.dist_timeout, tp_size=args.tp_size, sp_ring_size=args.sp_ring_size, sp_ulysses_size=args.sp_ulysses_size, ) is_online = ( args.train_data_path is not None and args.train_hidden_states_path is None ) sanity_check(args) print_args_with_dots(args) print_with_rank("Initialized distributed environment") # ================================================ # 2. Build models # ================================================ draft_model_config, draft_model, ckpt_info, resume_state = build_draft_model(args) target_model, processor = build_target_model(args, draft_model_config, is_online) # ================================================ # 3. Build dataloader # ================================================ train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders( args, draft_model_config, processor ) # we load the vocab mapping then draft_model.load_vocab_mapping(vocab_mapping_path) print_with_rank("Loaded vocab mapping") # Calculate total steps if not provided if args.total_steps is None: steps_per_epoch = math.ceil( len(train_dataloader) / args.draft_accumulation_steps ) args.total_steps = args.num_epochs * steps_per_epoch print_with_rank( f"Auto-calculated total_steps: {args.total_steps} (num_epochs={args.num_epochs} * steps_per_epoch={steps_per_epoch})" ) else: print_with_rank(f"Using provided total_steps: {args.total_steps}") # ================================================ # 4. Build Eagle3 model # ================================================ if ( args.is_vlm and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl" and args.tp_size == 1 and args.target_model_backend != "sglang" ): eagle3_model = QwenVLOnlineEagle3Model( target_model=target_model, draft_model=draft_model, processor=processor, length=args.ttt_length, attention_backend=args.attention_backend, ) else: if is_online: eagle3_model = OnlineEagle3Model( target_model=target_model, draft_model=draft_model, length=args.ttt_length, attention_backend=args.attention_backend, ) else: # offline: the target_model is TargetHead not a model eagle3_model = OnlineEagle3Model( draft_model=draft_model, length=args.ttt_length, attention_backend=args.attention_backend, ) eagle3_model = FSDP( eagle3_model, use_orig_params=True, mixed_precision=MixedPrecision( param_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16, ), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, process_group=dist.group.WORLD, # the draft model should run dp for all processes ) print_with_rank("Initialized Eagle3 FSDP model") # ================================================ # 5. Build optimizer and scheduler # ================================================ optimizer = BF16Optimizer( draft_model, lr=args.learning_rate, max_grad_norm=args.max_grad_norm, warmup_ratio=args.warmup_ratio, total_steps=args.total_steps, ) print_with_rank("Initialized optimizer and scheduler") # Restore optimizer/scheduler state for true resume if resume_state is not None: optimizer.load_state_dict(resume_state) start_epoch = resume_state["epoch"] global_step = resume_state["global_step"] print_on_rank0( f"Restored optimizer/scheduler state: " f"epoch={start_epoch}, step={global_step}, " f"lr={optimizer.get_learning_rate():.6f}" ) del resume_state else: start_epoch = ckpt_info[0] global_step = ckpt_info[1] # Calculate how many steps to skip in the current epoch (for dataloader fast-forward) skip_steps = global_step - start_epoch * len(train_dataloader) # ================================================ # 6. Build tracker # ================================================ tracker = build_tracker(args, parser) dist.barrier() last_time = time.time() # ================================================ # 7. Start training # ================================================ print_on_rank0( f"Starting training from epoch:{start_epoch} step:{global_step}" ) for epoch in range(start_epoch, args.num_epochs): # Run training train_dataloader.sampler.set_epoch(epoch + 1) draft_model.train() if dist.get_rank() == 0: progress_bar = tqdm( train_dataloader, desc=f"Training Epoch {epoch}", leave=True ) else: progress_bar = train_dataloader for step_in_epoch, data in enumerate(progress_bar): # Skip steps already processed in the current epoch when resuming if epoch == start_epoch and step_in_epoch < skip_steps: continue global_step += 1 # ================================================ # 7.0 Profiling # ================================================ if args.profile: # we add the step by 1 to align with global step if global_step == args.profile_start_step + 1: print("Start profile") torch_profiler = torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], with_stack=True, record_shapes=args.profile_record_shapes, ) torch_profiler.start() if global_step == args.profile_start_step + args.profile_num_steps + 1: output_path = os.path.join( args.output_dir, f"profile_rank{torch.distributed.get_rank()}_{time.time()}.trace.json.gz", ) print(f"End profile {output_path=}") torch_profiler.stop() torch_profiler.export_chrome_trace(output_path) # ================================================ # 7.1 Training Step # ================================================ plosses, acces = run_forward( args, eagle3_model, data, target_model, is_online, ) run_backward_and_update(args, plosses, optimizer, global_step) # log training metrics if global_step % (args.log_interval * args.draft_accumulation_steps) == 0: record_metrcs( args, acces, plosses, global_step // args.draft_accumulation_steps, tracker, optimizer, mode="train", ) if dist.get_rank() == 0: time_per_step = time.time() - last_time last_time = time.time() avg_loss = sum(pl for pl in plosses) / len(plosses) avg_acc = sum(acces) / len(acces) progress_bar.set_postfix( { "loss": f"{avg_loss:.2f}", "acc": f"{avg_acc:.2f}", "time": f"{time_per_step:.2f}s", } ) # ================================================ # 7.2 Evaluation Step # ================================================ should_evaluate = ( args.eval_data_path is not None or args.eval_hidden_states_path is not None ) if ( should_evaluate and global_step % (args.eval_interval * args.draft_accumulation_steps) == 0 ): # Run evaluation draft_model.eval() eval_acces = [[] for _ in range(eagle3_model.length)] eval_plosses = [[] for _ in range(eagle3_model.length)] for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"): with torch.no_grad(): plosses, acces = run_forward( args, eagle3_model, data, target_model, is_online ) eval_acces = [ eval_acces[i] + [acces[i]] for i in range(len(acces)) ] eval_plosses = [ eval_plosses[i] + [plosses[i]] for i in range(len(plosses)) ] # compute average over all minibatches eval_acces = [torch.stack(acc).mean() for acc in eval_acces] eval_plosses = [torch.stack(pl).mean() for pl in eval_plosses] record_metrcs( args, eval_acces, eval_plosses, global_step // args.draft_accumulation_steps, tracker, mode="eval", ) # ================================================ # 7.3 Save Checkpoints # ================================================ if global_step % args.save_interval == 0: # Save the model save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) if args.max_num_steps is not None and global_step >= args.max_num_steps: break if args.max_num_steps is not None and global_step >= args.max_num_steps: break # Save final checkpoint if training ended without saving if global_step % args.save_interval != 0: print_on_rank0( f"Training completed at step {global_step}, saving final checkpoint..." ) save_checkpoints(args, epoch, global_step, eagle3_model, optimizer) # Close the tracker tracker.close() destroy_distributed() if __name__ == "__main__": main()