# Source: Hanrui / SpecForge / scripts/train_eagle3.py
# Hosting-page metadata: uploaded by Lekr0 ("Add files using upload-large-folder tool"),
# commit 7a60a87 (verified).
import argparse
import hashlib
import math
import os
import time
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from accelerate.utils import set_seed
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer
from datasets import Dataset
from specforge import (
AutoDraftModelConfig,
AutoEagle3DraftModel,
OnlineEagle3Model,
QwenVLOnlineEagle3Model,
)
from specforge.args import SGLangBackendArgs, TrackerArgs
from specforge.data import (
build_eagle3_dataset,
build_offline_eagle3_dataset,
generate_vocab_mapping_file,
prepare_dp_dataloaders,
)
from specforge.distributed import (
destroy_distributed,
get_dp_group,
get_draft_dp_group,
get_tp_group,
init_distributed,
)
from specforge.modeling.target import (
Eagle3TargetModel,
TargetHead,
get_eagle3_target_model,
)
from specforge.optimizer import BF16Optimizer
from specforge.tracker import Tracker, create_tracker, get_tracker_class
from specforge.utils import (
create_draft_config_from_target,
get_last_checkpoint,
print_args_with_dots,
print_on_rank0,
print_with_rank,
rank_0_priority,
safe_conversations_generator,
)
def parse_args() -> Tuple[ArgumentParser, Namespace]:
    """
    Define the command-line interface of the Eagle3 training script and parse it.

    Options are registered in named argument groups (model, dataset, training,
    optimization, others, vlm, profiling, sglang target model backend and
    tracker). The last two groups are filled in by ``SGLangBackendArgs`` and
    ``TrackerArgs``.

    Returns:
        A ``(parser, args)`` pair: the configured parser and the namespace
        returned by ``parser.parse_args()``.
    """
    cli = argparse.ArgumentParser(description="Train Eagle3 with online data")

    # ------------------------------------------------------------------
    # model
    # ------------------------------------------------------------------
    model_opts = cli.add_argument_group("model")
    model_opts.add_argument("--target-model-path", type=str, required=True)
    model_opts.add_argument(
        "--trust-remote-code", action="store_true", help="Trust remote code"
    )
    model_opts.add_argument(
        "--draft-model-config",
        type=str,
        required=False,
        help="Draft model config path. If not provided, will auto-generate from target model.",
    )
    model_opts.add_argument(
        "--embedding-key",
        type=str,
        default="model.embed_tokens.weight",
        help="The key of the embedding weight to load from the target model",
    )
    model_opts.add_argument(
        "--lm-head-key",
        type=str,
        default="lm_head.weight",
        help="The key of the lm head weight to load from the target model, this is only required for offline training",
    )
    model_opts.add_argument(
        "--is-vlm", action="store_true", help="Whether the target model is a VLM"
    )
    model_opts.add_argument(
        "--target-model-backend",
        type=str,
        default="sglang",
        choices=["sglang", "hf", "custom"],
        help="The backend of the target model",
    )

    # ------------------------------------------------------------------
    # dataset
    # ------------------------------------------------------------------
    dataset_opts = cli.add_argument_group("dataset")
    dataset_opts.add_argument("--train-data-path", type=str, required=True)
    # Optional paths that all share the same shape (string, default None).
    for optional_path_flag in (
        "--train-hidden-states-path",
        "--eval-hidden-states-path",
        "--eval-data-path",
    ):
        dataset_opts.add_argument(optional_path_flag, type=str, default=None)
    dataset_opts.add_argument("--chat-template", type=str, default="llama3")
    dataset_opts.add_argument(
        "--is-preformatted",
        action="store_true",
        help="Whether the input data is preformatted text with the chat template already applied to the conversation messages.",
    )
    dataset_opts.add_argument(
        "--train-only-last-turn",
        action="store_true",
        help="If set, only the last assistant turn in each conversation contributes to the loss. "
        "Useful for thinking models where conversation history may lack thought processes.",
    )
    dataset_opts.add_argument("--build-dataset-num-proc", type=int, default=8)
    dataset_opts.add_argument(
        "--dataloader-num-workers",
        type=int,
        default=4,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )

    # ------------------------------------------------------------------
    # training hyper-parameters
    # ------------------------------------------------------------------
    train_opts = cli.add_argument_group("training")
    train_opts.add_argument("--num-epochs", type=int, default=10)
    train_opts.add_argument(
        "--max-num-steps",
        type=int,
        default=None,
        help="The maximum number of steps to train. If not provided, will be calculated as num_epochs * steps_per_epoch",
    )
    train_opts.add_argument("--batch-size", type=int, default=1)
    train_opts.add_argument("--learning-rate", type=float, default=1e-4)
    train_opts.add_argument("--max-length", type=int, default=2048)
    train_opts.add_argument("--warmup-ratio", type=float, default=0.015)
    train_opts.add_argument(
        "--total-steps",
        type=int,
        default=None,
        help="Total training steps. If not provided, will be calculated as num_epochs * steps_per_epoch",
    )
    train_opts.add_argument("--max-grad-norm", type=float, default=0.5)
    train_opts.add_argument(
        "--ttt-length",
        type=int,
        default=7,
        help="The length for Test-Time Training (TTT).",
    )
    train_opts.add_argument("--resume", action="store_true")
    train_opts.add_argument(
        "--ckpt-dir",
        type=str,
        default=None,
        help="directory includes the checkpoint to start training with",
    )
    for interval_flag in ("--eval-interval", "--save-interval"):
        train_opts.add_argument(interval_flag, type=int, default=5000)
    train_opts.add_argument(
        "--log-interval",
        type=int,
        default=50,
        help="Log training metrics every N steps",
    )
    train_opts.add_argument("--seed", type=int, default=0)
    train_opts.add_argument("--draft-accumulation-steps", type=int, default=1)

    # ------------------------------------------------------------------
    # optimization: tensor parallelism, sequence parallelism, attention
    # ------------------------------------------------------------------
    optim_opts = cli.add_argument_group("optimization")
    optim_opts.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="The size of the tensor parallel for the target model",
    )
    for sp_flag in ("--sp-ulysses-size", "--sp-ring-size"):
        optim_opts.add_argument(sp_flag, type=int, default=1)
    optim_opts.add_argument(
        "--attention-backend",
        type=str,
        default="flex_attention",
        help="The attention backend for the draft model",
    )

    # ------------------------------------------------------------------
    # miscellaneous
    # ------------------------------------------------------------------
    misc_opts = cli.add_argument_group("others")
    misc_opts.add_argument("--cache-key", type=str, default=None)
    misc_opts.add_argument("--cache-dir", type=str, default="./cache")
    misc_opts.add_argument("--output-dir", type=str, required=True)
    misc_opts.add_argument("--verbose", action="store_true")
    misc_opts.add_argument(
        "--dist-timeout",
        type=int,
        default=20,
        help="Timeout for collective communication in minutes",
    )
    misc_opts.add_argument(
        "--model-download-dir",
        type=str,
        default=None,
        help="The directory to download the target model to",
    )

    # ------------------------------------------------------------------
    # vision-language models
    # ------------------------------------------------------------------
    vlm_opts = cli.add_argument_group("vlm")
    # 50176 == 64*28*28 for qwen2.5-vl
    vlm_opts.add_argument("--min-pixels", type=int, default=50176)
    # 802816 == 1024*28*28 for qwen2.5-vl
    vlm_opts.add_argument("--max-pixels", type=int, default=802816)

    # ------------------------------------------------------------------
    # profiling
    # ------------------------------------------------------------------
    profile_opts = cli.add_argument_group("profiling")
    profile_opts.add_argument("--profile", action="store_true")
    profile_opts.add_argument("--profile-start-step", type=int, default=30)
    profile_opts.add_argument("--profile-num-steps", type=int, default=4)
    profile_opts.add_argument("--profile-record-shapes", action="store_true")

    # ------------------------------------------------------------------
    # groups whose options are declared by helper classes
    # ------------------------------------------------------------------
    SGLangBackendArgs.add_args(cli.add_argument_group("sglang target model backend"))
    TrackerArgs.add_args(cli.add_argument_group("tracker"))

    return cli, cli.parse_args()
def build_tracker(args: Namespace, parser: ArgumentParser) -> Tracker:
    """
    Create the experiment tracker selected by ``args.report_to``.

    Args:
        args: The parsed arguments of the training script.
        parser: The parser that produced ``args``; used to report an unknown
            tracker name and passed to the tracker class for validation.

    Returns:
        The tracker created by ``create_tracker`` for ``args.output_dir``.
    """
    tracker_cls = get_tracker_class(args.report_to)
    if not tracker_cls:
        parser.error(f"Unknown tracker: {args.report_to}")
    else:
        # Let the tracker class check its own arguments before instantiation.
        tracker_cls.validate_args(parser, args)
    return create_tracker(args, args.output_dir)
def build_target_model(
    args: Namespace, draft_model_config: AutoDraftModelConfig, is_online: bool = True
) -> Tuple[Union[Eagle3TargetModel, TargetHead], Optional[AutoProcessor]]:
    """
    Build the target-side component used during training.

    Args:
        args: The parsed arguments of the training script.
        draft_model_config: The draft model config; its ``target_model_type``
            and optional ``eagle_config`` are consulted in online mode.
        is_online: When True a full target model is loaded; when False only a
            ``TargetHead`` is loaded.

    Returns:
        ``(target_model, processor)`` in online mode, where ``processor`` is an
        ``AutoProcessor`` for VLM targets and ``None`` otherwise;
        ``(target_head, None)`` in offline mode.
    """
    # Offline training only needs the head of the target model.
    if not is_online:
        head = TargetHead.from_pretrained(
            model_path=args.target_model_path,
            lm_head_key=args.lm_head_key,
            cache_dir=args.model_download_dir,
            trust_remote_code=args.trust_remote_code,
        )
        return head, None

    use_custom_qwen_vl = (
        args.is_vlm
        and draft_model_config.target_model_type == "qwen2_5_vl"
        and args.target_model_backend == "custom"
    )
    if use_custom_qwen_vl:
        from transformers import Qwen2_5_VLForConditionalGeneration

        loaded = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path=args.target_model_path,
            torch_dtype=torch.bfloat16,
        )
        target_model = loaded.eval().cuda()
    else:
        # Only the sglang backend takes extra backend-specific kwargs.
        backend_kwargs = (
            SGLangBackendArgs.from_args(args).to_kwargs()
            if args.target_model_backend == "sglang"
            else {}
        )
        target_model = get_eagle3_target_model(
            pretrained_model_name_or_path=args.target_model_path,
            backend=args.target_model_backend,
            torch_dtype=torch.bfloat16,
            device="cuda",
            cache_dir=args.model_download_dir,
            **backend_kwargs,
            trust_remote_code=args.trust_remote_code,
        )

        # Set the aux hidden states layers, preferring the ids pinned in the
        # draft config when they are present.
        # NOTE(review): this step is assumed to apply only to models returned
        # by get_eagle3_target_model -- confirm against the upstream layout.
        eagle_cfg = getattr(draft_model_config, "eagle_config", None)
        if eagle_cfg is not None and "eagle_aux_hidden_state_layer_ids" in eagle_cfg:
            target_model.set_aux_hidden_states_layers(
                draft_model_config.eagle_config["eagle_aux_hidden_state_layer_ids"]
            )
        else:
            target_model.set_aux_hidden_states_layers()

    processor = None
    if args.is_vlm:
        processor = AutoProcessor.from_pretrained(
            args.target_model_path,
            min_pixels=args.min_pixels,
            max_pixels=args.max_pixels,
        )
    return target_model, processor
def sanity_check(args: Namespace) -> None:
    """
    Derive parallelism-dependent settings and validate the arguments.

    Sets ``args.dp_size`` (world size divided by ``tp_size``) and
    ``args.target_batch_size`` (``tp_size * batch_size``) in place, then runs
    the USP-specific checks when the attention backend is ``"usp"``.

    Args:
        args: The parsed arguments of the training script (mutated in place).

    Returns:
        None
    """
    world_size = dist.get_world_size()
    args.dp_size = world_size // args.tp_size
    args.target_batch_size = args.tp_size * args.batch_size

    if args.attention_backend != "usp":
        return
    sp_sanity_check(args)
def sp_sanity_check(args: Namespace) -> None:
    """
    Validate the arguments for the USP (sequence parallel) attention backend.

    ``args.draft_accumulation_steps`` is multiplied in place by the sequence
    parallel degree (``sp_ulysses_size * sp_ring_size``) before validation.

    Args:
        args: The parsed arguments of the training script (mutated in place).

    Raises:
        AssertionError: If ``batch_size`` is not 1, if the sequence parallel
            degree is not greater than 1, or if no training hidden-states
            path is given.
        ValueError: If both ``eval_data_path`` and ``eval_hidden_states_path``
            are set.
    """
    sp_degree = args.sp_ulysses_size * args.sp_ring_size
    args.draft_accumulation_steps = args.draft_accumulation_steps * sp_degree

    assert (
        args.batch_size == 1
    ), f"USP only supports batch_size=1, got batch_size={args.batch_size}"
    assert args.sp_ring_size * args.sp_ulysses_size > 1, (
        f"USP requires sp_ring_size * sp_ulysses_size > 1. "
        f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}."
    )
    assert args.train_hidden_states_path is not None, f"USP only support offline mode"

    # At most one evaluation source may be configured.
    if args.eval_data_path is None or args.eval_hidden_states_path is None:
        return
    raise ValueError(
        "Cannot set both eval_data_path and eval_hidden_states_path. "
        "For online mode, set only eval_data_path. "
        "For offline mode, set only eval_hidden_states_path."
    )
def build_draft_model(
    args: Namespace,
) -> Tuple[AutoDraftModelConfig, nn.Module, Tuple[int, int], Optional[dict]]:
    """
    Build the draft model, optionally from a base checkpoint or the last
    checkpoint found in the output directory.

    Args:
        args: The parsed arguments of the training script.

    Returns:
        A 4-tuple ``(draft_model_config, draft_model, ckpt_info, resume_state)``:
            * ``draft_model_config``: the draft config (auto-generated from the
              target model, read from ``--draft-model-config``, or read from
              ``<ckpt_dir>/config.json``).
            * ``draft_model``: the draft model on CUDA with the target
              embedding loaded and frozen.
            * ``ckpt_info``: ``(0, 0)`` unless ``get_last_checkpoint`` supplied
              a value (presumably ``(epoch, step)`` -- confirm in
              ``specforge.utils``).
            * ``resume_state``: the content of ``training_state.pt`` of the
              resumed checkpoint, or ``None`` when not resuming or when the
              file does not exist.

    Raises:
        ValueError: If ``--ckpt-dir`` is given but is not a directory.
    """
    # ckpt info(epoch, step)
    ckpt_info = (0, 0)

    # Handle draft model config
    if args.draft_model_config is None:
        # Auto-generate and save config file
        auto_config_path = create_draft_config_from_target(
            target_model_path=args.target_model_path, cache_dir=args.model_download_dir
        )
        draft_model_config = AutoDraftModelConfig.from_file(auto_config_path)
    else:
        # Use provided config file
        draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config)

    # Handle base ckpt, config file
    draft_model_last_checkpoint = None
    is_resume_checkpoint = False
    if args.ckpt_dir is not None:
        if os.path.isdir(args.ckpt_dir):
            # The config stored with the base checkpoint takes precedence.
            draft_model_config = AutoDraftModelConfig.from_file(
                os.path.join(args.ckpt_dir, "config.json")
            )
            draft_model_last_checkpoint = args.ckpt_dir
            print_on_rank0(f"Finetuning from base model: {draft_model_last_checkpoint}")
        else:
            raise ValueError(
                f"Provided base model dir {args.ckpt_dir} is not a valid directory."
            )

    # detecting last ckpt for draft model; this overrides --ckpt-dir
    if args.resume and os.path.isdir(args.output_dir):
        print_on_rank0(args.output_dir)
        draft_model_last_checkpoint, ckpt_info = get_last_checkpoint(args.output_dir)
        print(f"Last checkpoint detected: {draft_model_last_checkpoint}")
        is_resume_checkpoint = True

    if draft_model_last_checkpoint:
        draft_model = AutoEagle3DraftModel.from_pretrained(
            draft_model_last_checkpoint,
            attention_backend=args.attention_backend,
            torch_dtype=torch.bfloat16,
        ).cuda()
    else:
        draft_model = AutoEagle3DraftModel.from_config(
            draft_model_config,
            attention_backend=args.attention_backend,
            torch_dtype=torch.bfloat16,
        ).cuda()

    # Load training state (optimizer, scheduler, epoch, step) for true resume
    resume_state = None
    if is_resume_checkpoint and draft_model_last_checkpoint:
        training_state_path = os.path.join(
            draft_model_last_checkpoint, "training_state.pt"
        )
        if os.path.exists(training_state_path):
            # SECURITY NOTE: weights_only=False unpickles arbitrary Python
            # objects (the saved state embeds the argparse Namespace). Only
            # resume from checkpoint directories you trust.
            resume_state = torch.load(
                training_state_path, map_location="cpu", weights_only=False
            )
            print_on_rank0(
                f"Loaded training state from {training_state_path}: "
                f"epoch={resume_state['epoch']}, step={resume_state['global_step']}"
            )

    # The embedding always comes from the target model and is not trained.
    draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
    draft_model.freeze_embedding()
    return draft_model_config, draft_model, ckpt_info, resume_state
def build_dataloaders(
    args: Namespace,
    draft_model_config: AutoDraftModelConfig,
    processor: Optional[AutoProcessor] = None,
) -> Tuple[DataLoader, str, Optional[DataLoader]]:
    """
    Build the training dataloader, the vocab mapping file and, when an
    evaluation source is configured, the evaluation dataloader.

    Args:
        args: The parsed arguments of the training script.
        draft_model_config: The draft model config; supplies ``vocab_size`` and
            ``draft_vocab_size`` for the vocab mapping.
        processor: Optional processor forwarded to the dataset builder.

    Returns:
        ``(train_dataloader, vocab_mapping_path, eval_dataloader)`` where
        ``eval_dataloader`` is ``None`` if neither ``eval_data_path`` nor
        ``eval_hidden_states_path`` is set.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.target_model_path, trust_remote_code=args.trust_remote_code
    )

    # The cache key depends on the data file, the sequence length, the chat
    # template and the target model path (the tokenizer may differ per model).
    cache_fingerprint = (
        f"{args.train_data_path}-"
        f"{args.max_length}-"
        f"{args.chat_template}-"
        f"{args.target_model_path}"
    )
    cache_key = hashlib.md5(cache_fingerprint.encode()).hexdigest()

    raw_train_dataset = Dataset.from_generator(
        generator=safe_conversations_generator,
        gen_kwargs={"file_path": args.train_data_path},
    )
    # Online mode: conversations are given but no precomputed hidden states.
    is_online = (
        args.train_data_path is not None and args.train_hidden_states_path is None
    )
    uses_usp = args.attention_backend == "usp"

    with rank_0_priority():
        train_eagle3_dataset = build_eagle3_dataset(
            dataset=raw_train_dataset,
            tokenizer=tokenizer,
            chat_template=args.chat_template,
            max_length=args.max_length,
            cache_dir=os.path.join(args.cache_dir, "processed_dataset"),
            cache_key=cache_key,
            is_vlm=args.is_vlm,
            is_preformatted=args.is_preformatted,
            processor=processor,
            num_proc=args.build_dataset_num_proc,
            train_only_last_turn=args.train_only_last_turn,
        )
        # The vocab mapping is always derived from the tokenized conversations,
        # even when training itself reads hidden states from disk.
        vocab_mapping_path = generate_vocab_mapping_file(
            dataset=train_eagle3_dataset,
            target_vocab_size=draft_model_config.vocab_size,
            draft_vocab_size=draft_model_config.draft_vocab_size,
            cache_dir=os.path.join(args.cache_dir, "vocab_mapping"),
            cache_key=cache_key,
        )
        if not is_online:
            train_eagle3_dataset = build_offline_eagle3_dataset(
                args.train_hidden_states_path,
                args.max_length,
                ttt_length=args.ttt_length,
                use_usp_preprocess=uses_usp,
            )

    # Offline USP training shards data over the draft DP group; every other
    # configuration uses the regular DP group.
    if uses_usp and not is_online:
        data_group = get_draft_dp_group()
    else:
        data_group = get_dp_group()

    train_dataloader = prepare_dp_dataloaders(
        train_eagle3_dataset,
        args.target_batch_size,
        num_workers=args.dataloader_num_workers,
        shuffle=True,
        process_group=data_group,
        is_vlm=args.is_vlm,
    )

    if args.eval_data_path is None and args.eval_hidden_states_path is None:
        return train_dataloader, vocab_mapping_path, None

    if args.eval_data_path is not None:
        raw_eval_dataset = Dataset.from_generator(
            generator=safe_conversations_generator,
            gen_kwargs={"file_path": args.eval_data_path},
        )
        eval_eagle3_dataset = build_eagle3_dataset(
            raw_eval_dataset,
            tokenizer,
            args.chat_template,
            args.max_length,
            is_vlm=args.is_vlm,
            processor=processor,
            num_proc=args.build_dataset_num_proc,
            is_preformatted=args.is_preformatted,
            train_only_last_turn=args.train_only_last_turn,
        )
    else:
        # Only the hidden-states path is set at this point.
        eval_eagle3_dataset = build_offline_eagle3_dataset(
            args.eval_hidden_states_path,
            args.max_length,
            ttt_length=args.ttt_length,
            use_usp_preprocess=uses_usp,
        )

    eval_dataloader = prepare_dp_dataloaders(
        eval_eagle3_dataset,
        args.target_batch_size,
        num_workers=args.dataloader_num_workers,
        shuffle=False,
        process_group=data_group,
        is_vlm=args.is_vlm,
    )
    print_with_rank("Initialized eval dataloader")
    return train_dataloader, vocab_mapping_path, eval_dataloader
def save_checkpoints(
    args: Namespace,
    epoch: int,
    step: int,
    eagle3_model: nn.Module,
    optimizer: Optimizer,
):
    """
    Save the draft model weights and the training state for one checkpoint.

    The checkpoint directory is ``<output_dir>/epoch_<epoch>_step_<step>``.
    Every rank gathers the full state dict; only rank 0 writes
    ``training_state.pt`` (epoch, global step, args and the optimizer state
    dict) and the draft model. Keys containing "embed" are not saved with the
    draft weights.

    Args:
        args: The parsed arguments of the training script.
        epoch: The epoch index recorded in the checkpoint.
        step: The global step recorded in the checkpoint.
        eagle3_model: The FSDP-wrapped Eagle3 model.
        optimizer: The optimizer whose state dict is stored.
    """
    ckpt_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
    on_rank0 = dist.get_rank() == 0

    if on_rank0:
        os.makedirs(ckpt_dir, exist_ok=True)
    # Make sure the directory exists before any rank proceeds.
    dist.barrier()

    with FSDP.state_dict_type(eagle3_model, StateDictType.FULL_STATE_DICT):
        full_state = eagle3_model.state_dict()

        training_state = {"epoch": epoch, "global_step": step, "args": args}
        training_state.update(optimizer.state_dict())

        # Keep only draft model weights, stripped of their wrapper prefix,
        # and leave out the embedding.
        draft_weights = {}
        for key, value in full_state.items():
            if "draft_model." not in key or "embed" in key.lower():
                continue
            draft_weights[key.replace("draft_model.", "")] = value

        if on_rank0:
            torch.save(
                training_state,
                os.path.join(ckpt_dir, "training_state.pt"),
            )
            print_on_rank0(f"Saved full training state to {ckpt_dir}/training_state.pt")
            eagle3_model.draft_model.save_pretrained(
                ckpt_dir,
                state_dict=draft_weights,
            )
            print_on_rank0(f"Saved model configuration to {ckpt_dir}")
        dist.barrier()
def run_forward(
    args: Namespace,
    eagle3_model: nn.Module,
    data: dict,
    target_model: Optional[Eagle3TargetModel] = None,
    is_online: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Run one forward pass of the Eagle3 model on a batch.

    Three paths exist:
        * VLM with the "custom" backend: the batch is passed straight to
          ``eagle3_model`` together with the image inputs.
        * Online: ``target_model.generate_eagle3_data`` produces the training
          inputs, which are then sharded across the TP group.
        * Offline: hidden states come from ``data`` and ``target_model`` is
          used to preprocess the batch and to transform ``data["target"]``.

    Args:
        args: The parsed arguments of the training script.
        eagle3_model: The (wrapped) Eagle3 model to call.
        data: The batch dictionary produced by the dataloader.
        target_model: The target model (online) or target head (offline).
        is_online: Selects between the online and the offline path.

    Returns:
        ``(plosses, acces)`` as returned by ``eagle3_model``.
    """
    if args.is_vlm and args.target_model_backend == "custom":
        plosses, _, acces = eagle3_model(
            input_ids=data["input_ids"].cuda(),
            attention_mask=data["attention_mask"].cuda(),
            loss_mask=data["loss_mask"].cuda(),
            pixel_values=data["pixel_values"].cuda(),
            image_grid_thw=data["image_grid_thw"].cuda(),
        )
        return plosses, acces

    image_grid_thw = None
    if is_online:
        # The target model generates the eagle3 training data on the fly.
        if args.is_vlm:
            # image_grid_thw arrives as a list of tensors, one per sample.
            image_grid_thw = [thw.cuda().squeeze() for thw in data["image_grid_thw"]]
            vlm_inputs = {
                "is_vlm": args.is_vlm,
                "pixel_values": data["pixel_values"].cuda(),
                "image_grid_thw": image_grid_thw,
            }
        else:
            vlm_inputs = {}
        eagle3_data = target_model.generate_eagle3_data(
            input_ids=data["input_ids"].cuda(),
            attention_mask=data["attention_mask"].cuda(),
            loss_mask=data["loss_mask"].cuda(),
            **vlm_inputs,
        )

        # Each rank keeps only its own slice of the TP-wide batch.
        input_ids, attention_mask, loss_mask, target, hidden_states = map(
            get_dp_data_shard_from_tp,
            (
                eagle3_data.input_ids,
                eagle3_data.attention_mask,
                eagle3_data.loss_mask,
                eagle3_data.target,
                eagle3_data.hidden_states,
            ),
        )
    else:
        # The hidden states were precomputed and are loaded from disk.
        attention_mask = data["attention_mask"].cuda()
        hidden_states = data["hidden_state"].cuda()
        input_ids, target, loss_mask = target_model.preprocess(
            data["input_ids"], data["target"], data["loss_mask"]
        )
        input_ids = input_ids.cuda()
        # `data['target']` is large on the GPU ([seqlen, vocab_size]), so it is
        # preprocessed first and only then moved to the device.
        target = target_model(target.cuda())
        loss_mask = loss_mask.cuda()

    position_ids = data["position_ids"].cuda() if "position_ids" in data else None
    plosses, _, acces = eagle3_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        loss_mask=loss_mask,
        target=target,
        hidden_states=hidden_states,
        position_ids=position_ids,
        image_grid_thw=image_grid_thw,
        is_vlm=args.is_vlm,
    )
    return plosses, acces
def run_backward_and_update(
    args: Namespace, plosses: List[torch.Tensor], optimizer: Optimizer, global_step: int
) -> None:
    """
    Back-propagate the weighted sum of the per-position losses and step the
    optimizer at gradient-accumulation boundaries.

    The loss at position ``i`` is weighted by ``0.8**i``; the sum is divided by
    ``args.draft_accumulation_steps`` before ``backward()``.

    Args:
        args: The parsed arguments; ``draft_accumulation_steps`` is read.
        plosses: The per-position losses.
        optimizer: The optimizer; ``step()`` is called only when
            ``global_step`` is a multiple of ``draft_accumulation_steps``.
        global_step: The current global (micro) step.
    """
    accumulation = args.draft_accumulation_steps

    # Later positions contribute geometrically less to the total loss.
    weighted_total = sum(0.8**position * loss for position, loss in enumerate(plosses))
    scaled_loss = weighted_total / accumulation
    scaled_loss.backward()

    if global_step % accumulation != 0:
        return
    optimizer.step()
def record_metrcs(
    args: Namespace,
    accuracies: List[torch.Tensor],
    plosses: List[torch.Tensor],
    global_step: int,
    tracker: Tracker,
    optimizer: Optional[Optimizer] = None,
    mode: str = "train",
) -> None:
    """
    Average the per-position metrics across ranks, print them on rank 0 and
    send them to the tracker.

    Args:
        args: The parsed arguments; ``ttt_length`` is checked against the
            number of accuracy entries.
        accuracies: One accuracy tensor per position.
        plosses: One loss tensor per position.
        global_step: The step under which the metrics are logged.
        tracker: The experiment tracker receiving the metrics.
        optimizer: When given in train mode, its learning rate is logged as
            ``train/lr``.
        mode: Prefix of the metric names (e.g. ``"train"`` or ``"eval"``); it
            also labels the printed lines.

    Returns:
        None
    """
    logdict = {}
    if mode == "train" and optimizer is not None:
        logdict["train/lr"] = optimizer.get_learning_rate()

    accuracies = torch.stack(accuracies)
    plosses = torch.stack(plosses)
    assert accuracies.shape[0] == args.ttt_length

    # Label console lines with the actual mode so that train metrics are not
    # reported as evaluation results.
    label = mode.capitalize()

    dist.all_reduce(accuracies, op=dist.ReduceOp.AVG)
    for i, acc in enumerate(accuracies.cpu().tolist()):
        logdict[f"{mode}/acc_{i}"] = acc
        print_on_rank0(f"{label} - Step {global_step}, position {i}, Acc: {acc:.2f}")

    dist.all_reduce(plosses, op=dist.ReduceOp.AVG)
    for i, ploss in enumerate(plosses.cpu().tolist()):
        logdict[f"{mode}/ploss_{i}"] = ploss
        print_on_rank0(f"{label} - Step {global_step}, position {i}, pLoss: {ploss}")

    tracker.log(logdict, step=global_step)
def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return this rank's slice of a tensor shared by the tensor-parallel group.

    The tensor is chunked along dim 0 into as many pieces as the TP group has
    ranks, and the piece at this rank's index within the group is returned.
    """
    tp_group = get_tp_group()
    num_shards = dist.get_world_size(tp_group)
    shard_index = dist.get_rank(tp_group)
    shards = tensor.chunk(num_shards, dim=0)
    return shards[shard_index]
def main():
    """
    Entry point of the Eagle3 training script.

    Steps: parse arguments and initialize the distributed environment, build
    the draft and target models, build the dataloaders, wrap the Eagle3 model
    in FSDP, create (and optionally restore) the optimizer, then run the
    training loop with periodic logging, evaluation and checkpointing. A final
    checkpoint is written if the last step did not fall on a save boundary.
    """
    # ================================================
    # 1. Initialize
    # ================================================
    parser, args = parse_args()
    set_seed(args.seed)
    init_distributed(
        timeout=args.dist_timeout,
        tp_size=args.tp_size,
        sp_ring_size=args.sp_ring_size,
        sp_ulysses_size=args.sp_ulysses_size,
    )
    # Online mode: conversations are given but no precomputed hidden states.
    is_online = (
        args.train_data_path is not None and args.train_hidden_states_path is None
    )

    sanity_check(args)
    print_args_with_dots(args)
    print_with_rank("Initialized distributed environment")

    # ================================================
    # 2. Build models
    # ================================================
    draft_model_config, draft_model, ckpt_info, resume_state = build_draft_model(args)
    target_model, processor = build_target_model(args, draft_model_config, is_online)

    # ================================================
    # 3. Build dataloader
    # ================================================
    train_dataloader, vocab_mapping_path, eval_dataloader = build_dataloaders(
        args, draft_model_config, processor
    )

    # we load the vocab mapping then
    draft_model.load_vocab_mapping(vocab_mapping_path)
    print_with_rank("Loaded vocab mapping")

    # Calculate total steps if not provided
    if args.total_steps is None:
        steps_per_epoch = math.ceil(
            len(train_dataloader) / args.draft_accumulation_steps
        )
        args.total_steps = args.num_epochs * steps_per_epoch
        print_with_rank(
            f"Auto-calculated total_steps: {args.total_steps} (num_epochs={args.num_epochs} * steps_per_epoch={steps_per_epoch})"
        )
    else:
        print_with_rank(f"Using provided total_steps: {args.total_steps}")

    # ================================================
    # 4. Build Eagle3 model
    # ================================================
    # NOTE(review): this branch tests `backend != "sglang"` while
    # build_target_model and run_forward test `backend == "custom"`; confirm
    # the intended behaviour for a VLM target with the "hf" backend.
    if (
        args.is_vlm
        and getattr(draft_model_config, "target_model_type", None) == "qwen2_5_vl"
        and args.tp_size == 1
        and args.target_model_backend != "sglang"
    ):
        eagle3_model = QwenVLOnlineEagle3Model(
            target_model=target_model,
            draft_model=draft_model,
            processor=processor,
            length=args.ttt_length,
            attention_backend=args.attention_backend,
        )
    else:
        if is_online:
            eagle3_model = OnlineEagle3Model(
                target_model=target_model,
                draft_model=draft_model,
                length=args.ttt_length,
                attention_backend=args.attention_backend,
            )
        else:
            # offline: the target_model is TargetHead not a model
            eagle3_model = OnlineEagle3Model(
                draft_model=draft_model,
                length=args.ttt_length,
                attention_backend=args.attention_backend,
            )

    eagle3_model = FSDP(
        eagle3_model,
        use_orig_params=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        process_group=dist.group.WORLD,  # the draft model should run dp for all processes
    )
    print_with_rank("Initialized Eagle3 FSDP model")

    # ================================================
    # 5. Build optimizer and scheduler
    # ================================================
    optimizer = BF16Optimizer(
        draft_model,
        lr=args.learning_rate,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        total_steps=args.total_steps,
    )
    print_with_rank("Initialized optimizer and scheduler")

    # Restore optimizer/scheduler state for true resume
    if resume_state is not None:
        optimizer.load_state_dict(resume_state)
        start_epoch = resume_state["epoch"]
        global_step = resume_state["global_step"]
        print_on_rank0(
            f"Restored optimizer/scheduler state: "
            f"epoch={start_epoch}, step={global_step}, "
            f"lr={optimizer.get_learning_rate():.6f}"
        )
        del resume_state
    else:
        start_epoch = ckpt_info[0]
        global_step = ckpt_info[1]

    # Calculate how many steps to skip in the current epoch (for dataloader fast-forward)
    skip_steps = global_step - start_epoch * len(train_dataloader)

    # ================================================
    # 6. Build tracker
    # ================================================
    tracker = build_tracker(args, parser)
    dist.barrier()

    last_time = time.time()

    # ================================================
    # 7. Start training
    # ================================================
    print_on_rank0(
        f"Starting training from epoch:{start_epoch} step:{global_step}"
    )

    # Evaluation is configured once; it does not change during training.
    should_evaluate = (
        args.eval_data_path is not None or args.eval_hidden_states_path is not None
    )
    # The profiler only exists between its start and stop steps. Keeping an
    # explicit None avoids touching an unbound name when a resumed run begins
    # after the start step.
    torch_profiler = None
    # Keep `epoch` defined for the final checkpoint even if the loop below
    # does not execute (start_epoch >= num_epochs).
    epoch = start_epoch

    for epoch in range(start_epoch, args.num_epochs):
        # Run training
        train_dataloader.sampler.set_epoch(epoch + 1)
        draft_model.train()

        if dist.get_rank() == 0:
            progress_bar = tqdm(
                train_dataloader, desc=f"Training Epoch {epoch}", leave=True
            )
        else:
            progress_bar = train_dataloader

        for step_in_epoch, data in enumerate(progress_bar):
            # Skip steps already processed in the current epoch when resuming
            if epoch == start_epoch and step_in_epoch < skip_steps:
                continue

            global_step += 1

            # ================================================
            # 7.0 Profiling
            # ================================================
            if args.profile:
                # we add the step by 1 to align with global step
                if global_step == args.profile_start_step + 1:
                    print("Start profile")
                    torch_profiler = torch.profiler.profile(
                        activities=[
                            torch.profiler.ProfilerActivity.CPU,
                            torch.profiler.ProfilerActivity.CUDA,
                        ],
                        with_stack=True,
                        record_shapes=args.profile_record_shapes,
                    )
                    torch_profiler.start()
                if (
                    torch_profiler is not None
                    and global_step
                    == args.profile_start_step + args.profile_num_steps + 1
                ):
                    output_path = os.path.join(
                        args.output_dir,
                        f"profile_rank{torch.distributed.get_rank()}_{time.time()}.trace.json.gz",
                    )
                    print(f"End profile {output_path=}")
                    torch_profiler.stop()
                    torch_profiler.export_chrome_trace(output_path)
                    torch_profiler = None

            # ================================================
            # 7.1 Training Step
            # ================================================
            plosses, acces = run_forward(
                args,
                eagle3_model,
                data,
                target_model,
                is_online,
            )
            run_backward_and_update(args, plosses, optimizer, global_step)

            # log training metrics
            if global_step % (args.log_interval * args.draft_accumulation_steps) == 0:
                record_metrcs(
                    args,
                    acces,
                    plosses,
                    global_step // args.draft_accumulation_steps,
                    tracker,
                    optimizer,
                    mode="train",
                )

            if dist.get_rank() == 0:
                time_per_step = time.time() - last_time
                last_time = time.time()
                avg_loss = sum(plosses) / len(plosses)
                avg_acc = sum(acces) / len(acces)
                progress_bar.set_postfix(
                    {
                        "loss": f"{avg_loss:.2f}",
                        "acc": f"{avg_acc:.2f}",
                        "time": f"{time_per_step:.2f}s",
                    }
                )

            # ================================================
            # 7.2 Evaluation Step
            # ================================================
            if (
                should_evaluate
                and global_step % (args.eval_interval * args.draft_accumulation_steps)
                == 0
            ):
                # Run evaluation
                draft_model.eval()
                eval_acces = [[] for _ in range(eagle3_model.length)]
                eval_plosses = [[] for _ in range(eagle3_model.length)]

                for eval_batch in tqdm(
                    eval_dataloader, desc=f"Evaluating Epoch {epoch}"
                ):
                    with torch.no_grad():
                        batch_plosses, batch_acces = run_forward(
                            args, eagle3_model, eval_batch, target_model, is_online
                        )
                    # Collect one value per position for every minibatch.
                    for bucket, acc in zip(eval_acces, batch_acces):
                        bucket.append(acc)
                    for bucket, ploss in zip(eval_plosses, batch_plosses):
                        bucket.append(ploss)

                # compute average over all minibatches
                eval_acces = [torch.stack(acc).mean() for acc in eval_acces]
                eval_plosses = [torch.stack(pl).mean() for pl in eval_plosses]

                record_metrcs(
                    args,
                    eval_acces,
                    eval_plosses,
                    global_step // args.draft_accumulation_steps,
                    tracker,
                    mode="eval",
                )

                # Training continues within this epoch, so switch the draft
                # model back from eval mode.
                draft_model.train()

            # ================================================
            # 7.3 Save Checkpoints
            # ================================================
            if global_step % args.save_interval == 0:
                # Save the model
                save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)

            if args.max_num_steps is not None and global_step >= args.max_num_steps:
                break

        # Leave the epoch loop as well once the step budget is exhausted.
        if args.max_num_steps is not None and global_step >= args.max_num_steps:
            break

    # Save final checkpoint if training ended without saving
    if global_step % args.save_interval != 0:
        print_on_rank0(
            f"Training completed at step {global_step}, saving final checkpoint..."
        )
        save_checkpoints(args, epoch, global_step, eagle3_model, optimizer)

    # Close the tracker
    tracker.close()
    destroy_distributed()
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()