import copy
import gc
import io
import json
import os
import random
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler, Subset
from tqdm import tqdm

from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import list_str_to_idx

# torch.autograd.set_detect_anomaly(True)
# os.environ['HYDRA_FULL_ERROR'] = 'True'


def safe_sample(logits, temperature=1.0):
    """
    Sample one class per batch element and return it one-hot encoded.

    Args:
        logits: Tensor of shape (B, n_class)
        temperature: Sampling temperature (higher => more random)
    """
    # Apply temperature scaling
    scaled_logits = logits / temperature
    # Compute categorical distribution
    probs = F.softmax(scaled_logits, dim=-1)
    # Sample from the distribution once per batch element
    samples = torch.multinomial(probs, num_samples=1)  # (B, 1)
    # Convert to one-hot encoding
    one_hot_samples = torch.zeros_like(probs).scatter_(1, samples, 1)
    return one_hot_samples
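
# Example (sketch): drawing one-hot duration samples from random logits.
# The shapes here are illustrative, not taken from the training pipeline.
#   logits = torch.randn(4, 301)                    # (B=4, n_class=301)
#   one_hot = safe_sample(logits, temperature=0.8)  # (4, 301), one 1 per row
#   classes = one_hot.argmax(dim=-1)                # (4,) sampled class indices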


class GRPODurationTrainer:
    """
    Trainer that applies GRPO (Group Relative Policy Optimization)
    to a duration predictor for text-to-speech synthesis.
    """
    def __init__(
        self,
        model,  # Duration predictor model
        inference_fn,  # Function to generate speech
        reward_fn,  # Function to compute rewards from generated speech
        vocab_size: int,  # Size of the vocabulary
        vocab_char_map: dict,  # Mapping from characters to token IDs
        # Duration model parameters
        n_class: int = 301,  # Number of duration classes
        n_frame_per_class: int = 10,  # Number of frames per class
        gumbel_tau: float = 0.7,  # Gumbel-softmax temperature
        # GRPO parameters
        beta: float = 0.04,  # KL regularization weight
        clip_param: float = 0.2,  # PPO clip parameter
        num_pre_samples: int = 8,  # Number of samples per prompt
        compute_gen_logps: bool = True,  # Whether to compute generation log probabilities
        # Training parameters
        learning_rate: float = 5e-6,
        num_warmup_updates: int = 10000,
        save_per_updates: int = 10000,
        checkpoint_path: Optional[str] = None,
        all_steps: int = 100000,  # Total training steps
        # Batch parameters
        batch_size: int = 8,
        batch_size_type: str = "sample",
        max_samples: int = 16,
        grad_accumulation_steps: int = 2,
        max_grad_norm: float = 1.0,
        # Logging parameters
        logger: Optional[str] = "wandb",
        wandb_project: str = "tts-duration-grpo",
        wandb_run_name: str = "grpo_run",
        wandb_resume_id: Optional[str] = None,
        accelerate_kwargs: dict = dict(),
    ):
        # Initialize accelerator for distributed training
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        if logger == "wandb" and not wandb.api.api_key:
            logger = None
        print(f"Using logger: {logger}")
        self.accelerator = Accelerator(
            log_with=logger if logger == "wandb" else None,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )
        self.logger = logger
        if self.logger == "wandb":
            if wandb_resume_id:
                init_kwargs = {
                    "wandb": {
                        "resume": "allow",
                        "name": wandb_run_name,
                        "id": wandb_resume_id,
                    }
                }
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs=init_kwargs,
                config={
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "beta": beta,
                    "clip_param": clip_param,
                    "num_pre_samples": num_pre_samples,
                    "n_class": n_class,
                    "n_frame_per_class": n_frame_per_class,
                    "all_steps": all_steps,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                },
            )
        elif self.logger == "tensorboard":
            from torch.utils.tensorboard import SummaryWriter

            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
        # Store model, inference function, and reward function
        self.model = model
        # Create reference model (frozen clone of the initial model)
        self.ref_model = copy.deepcopy(model)
        for param in self.ref_model.parameters():
            param.requires_grad = False
        self.ref_model.eval()
        # Prepare inference_fn
        self.inference_fn = inference_fn
        self.inference_fn.scale = self.inference_fn.scale.to(self.accelerator.device)
        self.inference_fn.tts_model = self.inference_fn.tts_model.to(
            self.accelerator.device
        )
        # Prepare reward_fn
        self.reward_fn = reward_fn
        # Store vocabulary and mapping
        self.vocab_size = vocab_size
        self.vocab_char_map = vocab_char_map
        # Store duration model parameters
        self.n_class = n_class
        self.n_frame_per_class = n_frame_per_class
        self.gumbel_tau = gumbel_tau
        # Store GRPO parameters
        self.beta = beta
        self.clip_param = clip_param
        self.num_pre_samples = num_pre_samples
        self.compute_gen_logps = compute_gen_logps
        # Store training parameters
        self.learning_rate = learning_rate
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.checkpoint_path = checkpoint_path or f"ckpts/{wandb_run_name}"
        self.all_steps = all_steps
        # Store batch parameters
        self.batch_size = batch_size
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm
        # Initialize optimizer
        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        # Prepare model and optimizer with accelerator
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )
        self.ref_model = self.accelerator.prepare(self.ref_model)
        self.reward_fn, self.inference_fn = self.accelerator.prepare(
            self.reward_fn, self.inference_fn
        )
        # GRPO batch queue
        self.batch_queue = []
        # Store distributed rank
        self.rank = (
            torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        )
        self.device = f"cuda:{self.rank}"

    @property
    def is_main(self):
        # Property so that `if self.is_main:` tests the flag rather than the
        # truthiness of a bound method (which is always True)
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        """Save model, optimizer, and scheduler state"""
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_state_dict=self.optimizer.state_dict(),
                scheduler_state_dict=(
                    self.scheduler.state_dict() if hasattr(self, "scheduler") else None
                ),
                step=step,
            )
            if not os.path.exists(self.checkpoint_path):
                os.makedirs(self.checkpoint_path)
            if last:
                self.accelerator.save(
                    checkpoint, f"{self.checkpoint_path}/model_last.pt"
                )
            else:
                self.accelerator.save(
                    checkpoint, f"{self.checkpoint_path}/model_{step}.pt"
                )

    def load_checkpoint(self):
        """Load the latest checkpoint if available"""
        if (
            not self.checkpoint_path
            or not os.path.exists(self.checkpoint_path)
            or not any(
                filename.endswith(".pt")
                for filename in os.listdir(self.checkpoint_path)
            )
        ):
            return 0
        self.accelerator.wait_for_everyone()
        if "model_last.pt" in os.listdir(self.checkpoint_path):
            latest_checkpoint = "model_last.pt"
        else:
            latest_checkpoint = sorted(
                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
                key=lambda x: int("".join(filter(str.isdigit, x))),
            )[-1]
        print(f"Loading checkpoint: {latest_checkpoint}")
        checkpoint = torch.load(
            f"{self.checkpoint_path}/{latest_checkpoint}",
            weights_only=True,
            map_location="cpu",
        )
        if "step" in checkpoint:
            self.accelerator.unwrap_model(self.model).load_state_dict(
                checkpoint["model_state_dict"]
            )
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            if hasattr(self, "scheduler") and checkpoint["scheduler_state_dict"]:
                self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            step = checkpoint["step"]
        else:
            self.accelerator.unwrap_model(self.model).load_state_dict(
                checkpoint["model_state_dict"]
            )
            step = 0
        del checkpoint
        gc.collect()
        print(f"Successfully loaded checkpoint at step {step}")
        return step

    def get_ref_logps(self, text_ids, mel, sampled_classes):
        """
        Get log probabilities from the reference model for the sampled classes
        """
        B = text_ids.shape[0]
        K = self.num_pre_samples
        with torch.no_grad():
            ref_logits = self.ref_model(text_ids=text_ids, mel=mel)[:, -1, :]
            ref_logits = ref_logits.unsqueeze(1).repeat(1, K, 1).view(B * K, -1)
            ref_log_probs = F.log_softmax(ref_logits, dim=-1)
            ref_logps = torch.gather(
                ref_log_probs, dim=-1, index=sampled_classes.unsqueeze(-1)
            ).squeeze(-1)
        return ref_logps
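
    # Shape note (sketch, illustrative values): each of the B prompts is
    # expanded to K samples, so `sampled_classes` is flat with length B*K and
    # must line up row-for-row with the repeated logits, e.g. with B=2, K=3:
    #   ref_logits: (2, n_class) -> repeat -> (6, n_class)
    #   sampled_classes: (6,)  ->  ref_logps: (6,)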

    def generate_duration_samples(self, batch_inputs):
        """
        Generate multiple duration predictions from the model for each input
        and evaluate them using the inference function and reward model

        Args:
            batch_inputs: Dictionary with text, prompt audio, etc.

        Returns:
            Dictionary with duration samples, rewards, and reference log probabilities
        """
        if self.rank == 0:
            print("Generating duration samples...")
        all_sampled_classes = []
        all_durations = []
        all_rewards = []
        all_gen_logps = []
        all_ctc_loss = []
        all_sv_loss = []
        # Fetch batch inputs
        prompt_mel = batch_inputs["mel"].permute(0, 2, 1)  # (B, T, 100)
        prompt_text = batch_inputs["text"]
        # Shift text to unpair 'mel' and 'text'; the shifted text will be synthesized
        try:
            target_text = batch_inputs["target_text"]
            target_text_lengths = torch.LongTensor([len(t) for t in target_text]).to(
                prompt_mel.device
            )
            full_text = [
                prompt + [" "] + target
                for prompt, target in zip(prompt_text, target_text)
            ]
        except (KeyError, TypeError):
            # Fallback: roll the batch texts by one position so each mel is
            # paired with another sample's text
            target_text = [batch_inputs["text"][-1]] + batch_inputs["text"][:-1]
            target_text_lengths = batch_inputs["text_lengths"].clone().roll(1, 0)
            full_text = [
                prompt + [" "] + target
                for prompt, target in zip(prompt_text, target_text)
            ]
        # Goes to reward model; the dataloader only gives lists, so move to device
        target_text_ids = list_str_to_idx(target_text, self.vocab_char_map).to(
            self.accelerator.device
        )
        # Goes to duration model and TTS
        full_text_ids = list_str_to_idx(full_text, self.vocab_char_map).to(
            self.accelerator.device
        )
        # Clone to separate the text_ids for the duration predictor (SLP) and TTS
        slp_text_ids = full_text_ids.detach().clone()
        slp_text_ids = slp_text_ids.masked_fill(
            slp_text_ids == -1, self.vocab_size
        )  # (B, L)
        # Pre-compute duration logits
        K = self.num_pre_samples
        B, T, _ = prompt_mel.shape
        _, L = slp_text_ids.shape
        # Run the model once for B inputs
        old_logits = self.model(
            text_ids=slp_text_ids,  # (B, L)
            mel=prompt_mel,  # (B, T, 100)
        )[:, -1, :]  # (B, n_class)
        # Repeat each result K times along a new sample dimension
        old_logits = old_logits.unsqueeze(1).repeat(1, K, 1)  # (B, K, n_class)
        for (
            _full_text_ids,
            _target_text_ids,
            _target_text_lengths,
            _prompt_mel,
            _old_logits,
        ) in zip(
            full_text_ids, target_text_ids, target_text_lengths, prompt_mel, old_logits
        ):
            # Draw K one-hot duration samples via straight-through Gumbel-softmax
            duration_sample = F.gumbel_softmax(
                _old_logits, tau=self.gumbel_tau, hard=True, dim=-1
            )  # (K, n_class)
            duration2frames = (
                torch.arange(self.n_class).float().to(self.accelerator.device)
                * self.n_frame_per_class
            )
            est_frames = (duration_sample * duration2frames).sum(-1)  # (K,)
            # Compute log probabilities of the samples
            sampled_classes = duration_sample.argmax(dim=-1)  # (K,)
            log_probs = F.log_softmax(_old_logits, dim=-1)
            gen_logps = torch.gather(
                log_probs, dim=-1, index=sampled_classes.unsqueeze(-1)
            ).squeeze(-1)  # (K,)
            # Generate speech using the sampled durations
            sampled_rewards = []
            for i in range(K):
                cur_duration = est_frames[i]
                if cur_duration == 0:
                    cur_duration = cur_duration + 50  # prevent zero duration
                infer_full_text_ids = _full_text_ids.unsqueeze(0)
                infer_prompt_mel = _prompt_mel.unsqueeze(0)
                cur_duration = cur_duration.unsqueeze(0)
                infer_target_text_ids = _target_text_ids.unsqueeze(0)
                infer_target_text_lengths = _target_text_lengths.unsqueeze(0)
                with torch.inference_mode():
                    try:
                        _est_mel = self.inference_fn(
                            full_text_ids=infer_full_text_ids,
                            prompt_mel=infer_prompt_mel,
                            target_duration=cur_duration,
                            teacher_steps=0,
                        )
                        _est_mel = _est_mel.permute(0, 2, 1)  # (1, T, 100)
                        loss_dict = self.reward_fn(
                            prompt_mel=infer_prompt_mel,
                            est_mel=_est_mel,
                            target_text_id=infer_target_text_ids,
                            target_text_length=infer_target_text_lengths,
                        )
                        # TODO: reweight the losses used for the reward
                        reward_sim = loss_dict["loss_sim"]  # 0 to 1
                        reward_ctc = loss_dict["loss_ctc"]
                        reward = -(reward_ctc + reward_sim * 3)
                        all_ctc_loss.append(reward_ctc)
                        all_sv_loss.append(reward_sim)
                    except Exception as e:
                        if self.rank == 0:
                            print(f"Error in speech synthesis: {e}")
                        reward = torch.tensor(-1.0).to(cur_duration.device)
                sampled_rewards.append(reward)
            sampled_rewards = torch.stack(sampled_rewards)  # (K,)
            # Normalize rewards within the group (GRPO advantage)
            if (sampled_rewards.max() - sampled_rewards.min()).item() > 1e-6:
                sampled_rewards = (sampled_rewards - sampled_rewards.mean()) / (
                    sampled_rewards.std() + 1e-8
                )
            # Store all data
            all_sampled_classes.append(sampled_classes)
            all_durations.append(est_frames)
            all_gen_logps.append(gen_logps)
            all_rewards.extend(sampled_rewards)  # list with length B*K
        # Concatenate all data
        sampled_classes = torch.cat(all_sampled_classes, dim=0)
        durations = torch.cat(all_durations, dim=0)
        rewards = torch.stack(
            all_rewards
        )  # use stack to keep the elements on their original devices
        gen_logps = torch.cat(all_gen_logps, dim=0)
        ctc_losses = torch.stack(all_ctc_loss)
        sv_losses = torch.stack(all_sv_loss)
        if self.is_main:
            self.accelerator.log(
                {
                    "ctc_loss": ctc_losses.mean().item(),
                    "sv_loss": sv_losses.mean().item(),
                    "reward": rewards.mean().item(),
                    "reward_min": rewards.min().item(),
                    "reward_max": rewards.max().item(),
                },
                step=self.global_step,
            )
        ref_logps = self.get_ref_logps(slp_text_ids, prompt_mel, sampled_classes)
        # Create batch dict similar to the Qwen2.5 implementation
        batch_outputs = {
            "text_ids": slp_text_ids,
            "prompt_mel": prompt_mel,
            "rewards": rewards,
            "refs": ref_logps,
            "sampled_classes": sampled_classes,
            "durations": durations,
        }
        if self.compute_gen_logps:
            batch_outputs["gen_logps"] = gen_logps
        if self.rank == 0:
            print(
                f"Generated {len(rewards)} samples with reward min/mean/max: "
                f"{rewards.min().item():.4f}/{rewards.mean().item():.4f}/{rewards.max().item():.4f}"
            )
        return batch_outputs
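
    # Contents of `batch_outputs` (sketch; shapes assume B prompts, K samples
    # per prompt, L text tokens, T mel frames):
    #   text_ids:        (B, L)     prompt+target token IDs for the predictor
    #   prompt_mel:      (B, T, 100)
    #   rewards:         (B*K,)     group-normalized rewards
    #   refs:            (B*K,)     reference-model log-probs of sampled classes
    #   sampled_classes: (B*K,)     sampled duration class indices
    #   durations:       (B*K,)     estimated frame counts
    #   gen_logps:       (B*K,)     behavior-policy log-probs (optional)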

    def GRPO_step(self, batch):
        """
        Perform a GRPO update step

        Args:
            batch: Dictionary with inputs, rewards, reference log probabilities, etc.

        Returns:
            Loss value
        """
        # Extract batch data
        rewards = batch["rewards"]  # (B*K,)
        ref_logps = batch["refs"]  # (B*K,)
        sampled_classes = batch["sampled_classes"]  # (B*K,)
        prompt_mel = batch["prompt_mel"]
        text_ids = batch["text_ids"]
        # Forward pass to get current model logits
        K = self.num_pre_samples
        B, T, _ = prompt_mel.shape
        _, L = text_ids.shape
        cur_logits = self.model(
            text_ids=text_ids,  # (B, L)
            mel=prompt_mel,  # (B, T, 100)
        )[:, -1, :]
        cur_logits = cur_logits.unsqueeze(1).repeat(1, K, 1).view(B * K, -1)
        # Compute current log probabilities for the sampled actions
        log_probs = F.log_softmax(cur_logits, dim=-1)
        cur_logps = torch.gather(
            log_probs, dim=-1, index=sampled_classes.unsqueeze(-1)
        ).squeeze(-1)  # (B*K,)
        # KL divergence estimate (same as in the Qwen2.5 code):
        # KL = exp(ref - cur) - (ref - cur) - 1
        kl_div = torch.exp(ref_logps - cur_logps) - (ref_logps - cur_logps) - 1  # (B*K,)
        # Compute probability ratio for PPO-style clipping
        if "gen_logps" in batch:
            gen_logps = batch["gen_logps"]
            ratio = torch.exp(cur_logps - gen_logps)
            clipped_ratio = torch.clamp(ratio, 1 - self.clip_param, 1 + self.clip_param)
            loss = torch.min(ratio * rewards, clipped_ratio * rewards)
        else:
            # Simplification if gen_logps is not available: the ratio is 1
            # but still carries the gradient of cur_logps
            loss = torch.exp(cur_logps - cur_logps.detach()) * rewards
        # Final GRPO loss with KL regularization
        loss = -(loss - self.beta * kl_div)  # (B*K,)
        loss = loss.mean()
        return loss
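
    # Note (sketch): the KL term above is the low-variance "k3" estimator
    # KL ≈ exp(ref - cur) - (ref - cur) - 1, which is nonnegative and zero
    # when the two policies agree, e.g.:
    #   ref, cur = torch.tensor(-1.2), torch.tensor(-1.2)
    #   torch.exp(ref - cur) - (ref - cur) - 1  # -> tensor(0.)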

    def get_batch(self):
        """Get a batch from the queue, or None if it is empty"""
        if not self.batch_queue:
            return None
        return self.batch_queue.pop(0)

    def generate_mode(self, num_batches=5):
        """
        Generate samples and add them to the batch queue

        Args:
            num_batches: Number of batches to generate
        """
        if self.rank == 0:
            print("Entering generate mode...")
        tic = time.time()
        for _ in range(num_batches):
            try:
                batch_inputs = next(self.train_iterator)
            except StopIteration:
                self.train_iterator = iter(self.train_dataloader)
                batch_inputs = next(self.train_iterator)
            # Generate samples and compute rewards
            batch_outputs = self.generate_duration_samples(batch_inputs)
            # Check if the batch has sufficient reward diversity
            rewards = batch_outputs["rewards"]
            if (rewards.max() - rewards.min()).item() < 0.01:
                if self.rank == 0:
                    print("Skipping batch with low reward diversity")
                continue
            # Add batch to queue
            self.batch_queue.append(batch_outputs)
        if self.rank == 0:
            print(f"Exiting generate mode: {time.time() - tic:.3f}s")

    def train(
        self, train_dataset, valid_dataset=None, num_workers=64, resumable_with_seed=666
    ):
        """
        Train the model using GRPO

        Args:
            train_dataset: Training dataset
            valid_dataset: Validation dataset (optional)
            num_workers: Number of workers for data loading
            resumable_with_seed: Seed for reproducible shuffling/batching
        """
        # Create training dataloader using the appropriate batching strategy
        if self.batch_size_type == "sample":
            generator = torch.Generator()
            generator.manual_seed(resumable_with_seed)
            self.train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_size=self.batch_size,
                shuffle=True,
                generator=generator,
            )
            # Create validation dataloader (always sequential, no shuffling)
            self.valid_dataloader = DataLoader(
                valid_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                batch_size=self.batch_size,
                shuffle=False,
            )
            self.train_dataloader, self.valid_dataloader = self.accelerator.prepare(
                self.train_dataloader, self.valid_dataloader
            )
            self.train_iterator = iter(self.train_dataloader)
            self.valid_iterator = iter(self.valid_dataloader)
        elif self.batch_size_type == "frame":
            self.accelerator.even_batches = False
            sampler = SequentialSampler(train_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler,
                self.batch_size,
                max_samples=self.max_samples,
                random_seed=resumable_with_seed,
                drop_last=False,
            )
            self.train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_sampler=batch_sampler,
            )
            # Create validation dataloader (always sequential, no shuffling)
            sampler = SequentialSampler(valid_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler,
                self.batch_size,
                max_samples=self.max_samples,
                random_seed=resumable_with_seed,
                drop_last=False,
            )
            self.valid_dataloader = DataLoader(
                valid_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_sampler=batch_sampler,
            )
            self.train_dataloader, self.valid_dataloader = self.accelerator.prepare(
                self.train_dataloader, self.valid_dataloader
            )
            self.train_iterator = iter(self.train_dataloader)
            self.valid_iterator = iter(self.valid_dataloader)
        else:
            raise ValueError(
                f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}"
            )
        # Set up LR schedulers: linear warmup followed by linear decay
        warmup_steps = self.num_warmup_updates * self.accelerator.num_processes
        total_steps = self.all_steps
        decay_steps = total_steps - warmup_steps
        warmup_scheduler = LinearLR(
            self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps
        )
        decay_scheduler = LinearLR(
            self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps
        )
        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[warmup_scheduler, decay_scheduler],
            milestones=[warmup_steps],
        )
        self.scheduler = self.accelerator.prepare(self.scheduler)
        # Load checkpoint if available
        start_step = self.load_checkpoint()
        self.global_step = start_step
        # Generate initial batches
        self.generate_mode()
        # Training loop
        progress = range(1, self.all_steps + 1)
        # Skip steps that are already done
        progress = [step for step in progress if step > start_step]
        if self.is_main:
            progress = tqdm(progress, desc="Training", unit="step")
        for step in progress:
            # Get a batch from the queue, generating more if needed
            batch = self.get_batch()
            while batch is None:
                self.generate_mode()
                batch = self.get_batch()
            # GRPO update
            with self.accelerator.accumulate(self.model):
                loss = self.GRPO_step(batch)
                self.accelerator.backward(loss)
                if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                    total_norm = self.accelerator.clip_grad_norm_(
                        self.model.parameters(), self.max_grad_norm
                    )
                else:
                    total_norm = torch.norm(
                        torch.stack(
                            [
                                torch.norm(p.grad.detach(), 2)
                                for p in self.model.parameters()
                                if p.grad is not None
                            ]
                        ),
                        2,
                    )
                self.accelerator.log(
                    {"grad_norm": total_norm.item()}, step=self.global_step
                )
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            self.global_step += 1
            # Log metrics
            if self.is_main:
                self.accelerator.log(
                    {
                        "loss": loss.item(),
                        "lr": self.scheduler.get_last_lr()[0],
                    },
                    step=self.global_step,
                )
                progress.set_postfix(
                    loss=f"{loss.item():.4f}",
                    lr=f"{self.scheduler.get_last_lr()[0]:.8f}",
                )
            # Save checkpoint
            if self.global_step % self.save_per_updates == 0:
                self.save_checkpoint(self.global_step)
            # Optional validation logic could be added here
        # Save final checkpoint
        self.save_checkpoint(self.global_step, last=True)
        self.accelerator.end_training()
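

# Example usage (sketch; `duration_predictor`, `tts_infer`, and `reward_model`
# are hypothetical stand-ins for the project's actual components):
#   trainer = GRPODurationTrainer(
#       model=duration_predictor,
#       inference_fn=tts_infer,
#       reward_fn=reward_model,
#       vocab_size=len(vocab_char_map),
#       vocab_char_map=vocab_char_map,
#   )
#   trainer.train(train_dataset, valid_dataset)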