| import os |
| import dataclasses |
| import torch |
| import transformers |
| from transformers import Trainer, TrainingArguments, TrainerCallback |
| from peft import LoraConfig, get_peft_model, TaskType |
| from huggingface_hub import HfApi, login |
| import wandb |
| from dotenv import load_dotenv |
| from config import TrainConfig, ModelConfig |
| from model import MultiModalModel |
| from data import AudioTextDataset, DataCollator |
|
|
|
|
class SamplePredictionCallback(TrainerCallback):
    """Every N steps, log ground-truth vs. model-predicted transcripts for a few samples.

    On an ``on_log`` event whose ``global_step`` is a positive multiple of
    ``sample_every_n_steps``, the first ``num_samples`` items of ``train_dataset``
    (wrapping around if the dataset is shorter) are collated. The model then
    greedily generates up to 120 tokens conditioned on ``prompt`` plus the audio.
    Finally a ``wandb.Table`` (step / index / ground truth / prediction /
    continuation) is logged to the active W&B run.

    Because the hook is ``on_log``, sampling only happens on steps that are both
    logging steps and multiples of ``sample_every_n_steps``.

    Any exception raised while sampling is printed (with traceback) and swallowed,
    so a failing preview never interrupts training.
    """

    def __init__(
        self,
        tokenizer,
        data_collator,
        train_dataset,
        sample_every_n_steps: int = 100,
        num_samples: int = 2,
        prompt: str = "Transcribe the following audio:",
    ):
        """
        Args:
            tokenizer: Tokenizer used to encode ``prompt`` and to decode labels and
                generated ids.
            data_collator: Callable mapping a list of dataset items to a batch dict
                with at least ``"audio_values"`` (tensor) and ``"labels"``, and
                optionally ``"continuation"``.
            train_dataset: Indexable dataset supporting ``len()``; samples are taken
                from its first indices.
            sample_every_n_steps: Sampling interval in global steps; ``<= 0``
                disables sampling.
            num_samples: Number of dataset items previewed per sampling event.
            prompt: Text prompt whose token ids prefix every generation.
        """
        self.tokenizer = tokenizer
        self.data_collator = data_collator
        self.train_dataset = train_dataset
        self.sample_every_n_steps = sample_every_n_steps
        self.num_samples = num_samples
        self.prompt = prompt
        # Last global step we sampled at. The Trainer can emit several on_log
        # events for one step (e.g. the end-of-training summary log).
        self._last_sampled_step = None

    def on_log(self, args, state, control, model=None, **kwargs):
        """Generate and log sample transcripts if the current step is due."""
        step = state.global_step
        # Check the interval first: `step % 0` would raise outside the try below
        # and abort training.
        if self.sample_every_n_steps <= 0 or step == 0 or step % self.sample_every_n_steps != 0:
            return
        if model is None or step == self._last_sampled_step:
            return
        self._last_sampled_step = step

        was_training = model.training
        model.eval()
        try:
            device = next(model.parameters()).device
            indices = [i % len(self.train_dataset) for i in range(self.num_samples)]
            batch = self.data_collator([self.train_dataset[i] for i in indices])
            audio_values = batch["audio_values"].to(device)
            labels_batch = batch["labels"]
            batch_size = audio_values.size(0)
            continuations = batch.get("continuation", [""] * batch_size)

            prompt_ids = self.tokenizer(self.prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
            prompt_ids = prompt_ids.expand(batch_size, -1)

            # `pad_token_id or eos_token_id` would wrongly discard a valid pad id of 0.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=prompt_ids,
                    audio_values=audio_values,
                    max_new_tokens=120,
                    do_sample=False,
                    pad_token_id=pad_token_id,
                )
            # NOTE(review): slicing off prompt_len assumes generate() returns the
            # prompt tokens followed by the new tokens. Confirm against
            # MultiModalModel.generate.
            prompt_len = prompt_ids.size(1)

            columns = ["Step", "Audio Index", "Ground Truth", "Prediction", "Continuation"]
            table = wandb.Table(columns=columns)

            print(f"\n[WandB] Logging sample predictions at step {step}")

            for i in range(batch_size):
                # -100 is the usual ignore-index for masked label positions; drop
                # those positions before decoding.
                gt_tokens = [t for t in labels_batch[i].tolist() if t != -100]
                gt_text = self.tokenizer.decode(gt_tokens, skip_special_tokens=True).strip()
                pred_text = self.tokenizer.decode(gen_ids[i][prompt_len:], skip_special_tokens=True).strip()
                cont_ref = continuations[i] if i < len(continuations) else ""
                table.add_data(step, i, gt_text, pred_text, cont_ref)

            if wandb.run is not None:
                wandb.log({"sample_predictions": table}, step=step)
            else:
                print("Warning: WandB run not active, skipping logging.")

        except Exception as e:
            # Best effort: a broken preview must never interrupt training, but the
            # traceback is needed to debug it.
            import traceback

            print(f"[SamplePredictionCallback] Error: {e}\n")
            traceback.print_exc()
        finally:
            # Restore whatever mode the Trainer had the model in.
            if was_training:
                model.train()
|
|
|
|
| import shutil |
| import glob |
| from transformers.trainer_utils import get_last_checkpoint |
|
|
class AggressiveDeleteCallback(TrainerCallback):
    """Delete every ``checkpoint-*`` directory in ``output_dir`` right before the Trainer saves.

    Old checkpoints are removed *before* the new one is written, so at most one
    checkpoint ever exists on disk. This keeps disk usage minimal.

    WARNING: if the subsequent save fails, NO checkpoint remains on disk. This
    risk has been accepted deliberately.
    """

    def __init__(self, output_dir):
        # Directory whose ``checkpoint-*`` subdirectories are purged before each save.
        self.output_dir = output_dir

    def on_step_end(self, args, state, control, **kwargs):
        """Purge existing checkpoints if the Trainer is about to save at this step."""
        # Only the process that actually writes checkpoints may delete them. If every
        # rank rmtree's the same paths, a lagging rank can delete the checkpoint
        # that rank 0 has already started writing.
        if not getattr(args, "should_save", state.is_world_process_zero):
            return
        # DefaultFlowCallback runs before user callbacks and sets
        # control.should_save when the Trainer will checkpoint after this step.
        # Relying on it, rather than re-deriving `global_step % args.save_steps`,
        # also handles ratio-valued save_steps (< 1) and the forced final save.
        if not control.should_save:
            return

        print(f"\n[AggressiveDeleteCallback] Step {state.global_step}: Deleting old checkpoints to free space before saving...")

        for ckpt in glob.glob(os.path.join(self.output_dir, "checkpoint-*")):
            # Trainer checkpoints are directories; ignore stray files matching the glob.
            if not os.path.isdir(ckpt):
                continue
            try:
                shutil.rmtree(ckpt)
                print(f" Deleted {ckpt}")
            except OSError as e:
                print(f" Failed to delete {ckpt}: {e}")
|
|
def _redact_secrets(config_dict):
    """Return a copy of *config_dict* with credential-like values replaced by ``"<redacted>"``.

    A string key counts as secret when, lower-cased, it ends with ``token``,
    ``secret``, ``password`` or ``api_key`` (e.g. ``hub_token``). Nested dicts,
    as produced by ``dataclasses.asdict``, are handled recursively. Falsy values
    (``None``, ``""``) are kept so the config still shows that no credential
    was set.
    """
    secret_suffixes = ("token", "secret", "password", "api_key")
    redacted = {}
    for key, value in config_dict.items():
        if isinstance(value, dict):
            redacted[key] = _redact_secrets(value)
        elif value and isinstance(key, str) and key.lower().endswith(secret_suffixes):
            redacted[key] = "<redacted>"
        else:
            redacted[key] = value
    return redacted


def _find_resume_checkpoint(output_dir):
    """Return the newest ``checkpoint-*`` directory under *output_dir*, or None.

    ``get_last_checkpoint`` lists the directory and raises if it does not exist,
    which is the normal situation on a fresh run. So check existence first.
    """
    if not os.path.isdir(output_dir):
        return None
    return get_last_checkpoint(output_dir)


def _push_to_hub(train_config):
    """Upload ``output_dir`` and the project's source files to the Hub (best effort).

    Failures are printed rather than raised, so a finished training run is never
    lost because of a network or permission error.
    """
    print(f"\n>>> Pushing model to Hugging Face Hub: {train_config.hub_model_id}")
    if train_config.hub_token:
        login(token=train_config.hub_token)

    api = HfApi()

    try:
        api.create_repo(repo_id=train_config.hub_model_id, private=train_config.hub_private_repo, exist_ok=True)
    except Exception as e:
        print(f"Warning: Could not create repo {train_config.hub_model_id}. Error: {e}")

    try:
        api.upload_folder(
            folder_path=train_config.output_dir,
            repo_id=train_config.hub_model_id,
            repo_type="model",
            # Intermediate Trainer checkpoints duplicate the weights and carry
            # optimizer/scheduler/RNG state that the published model does not need.
            ignore_patterns=["checkpoint-*"],
        )

        # Ship the code needed to rebuild and run the custom model.
        for file in ["model.py", "config.py", "data.py", "inference.py"]:
            if os.path.exists(file):
                api.upload_file(
                    path_or_fileobj=file,
                    path_in_repo=file,
                    repo_id=train_config.hub_model_id,
                    repo_type="model",
                )

        print(f">>> Successfully pushed to {train_config.hub_model_id}")
    except Exception as e:
        print(f"Error pushing to hub: {e}")


def train():
    """Fine-tune ``MultiModalModel`` on ``AudioTextDataset`` and optionally publish it.

    Steps:
    1. Load ``.env``, build the configs and start a W&B run.
    2. Load the tokenizer, processor and model, optionally wrapping ``model.llm``
       with LoRA.
    3. Train with the HF ``Trainer``, resuming from the newest checkpoint in
       ``output_dir`` if there is one.
    4. Save the model, tokenizer and processor to ``output_dir``.
    5. On the main process only, push to the Hub when enabled.
    """
    load_dotenv()

    train_config = TrainConfig()
    model_config = ModelConfig()

    # Only global rank 0 should own the W&B run. torchrun/accelerate export RANK;
    # single-process runs leave it unset.
    if os.environ.get("RANK", "0") == "0":
        wandb.init(
            project=train_config.wandb_project,
            entity=train_config.wandb_entity,
            name=train_config.wandb_run_name,
            # Credentials such as hub_token must not be uploaded to the tracker.
            config=_redact_secrets(dataclasses.asdict(train_config)),
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.text_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    processor = transformers.AutoProcessor.from_pretrained(model_config.audio_model_id)

    model = MultiModalModel(model_config)

    if train_config.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=train_config.lora_r,
            lora_alpha=train_config.lora_alpha,
            lora_dropout=train_config.lora_dropout,
            target_modules=["q_proj", "v_proj"],
        )
        model.llm = get_peft_model(model.llm, peft_config)
        model.llm.print_trainable_parameters()

    train_dataset = AudioTextDataset(train_config, processor, model_config, tokenizer)
    data_collator = DataCollator(processor, tokenizer)

    training_args = TrainingArguments(
        output_dir=train_config.output_dir,
        per_device_train_batch_size=train_config.batch_size,
        gradient_accumulation_steps=train_config.accum_steps,
        learning_rate=train_config.learning_rate,
        lr_scheduler_type=train_config.lr_scheduler_type,
        num_train_epochs=train_config.num_epochs,
        max_steps=train_config.max_steps,
        bf16=train_config.use_bf16,
        gradient_checkpointing=train_config.gradient_checkpointing,
        dataloader_num_workers=train_config.dataloader_num_workers,
        dataloader_pin_memory=train_config.dataloader_pin_memory,
        logging_steps=train_config.log_steps,
        logging_first_step=True,
        logging_nan_inf_filter=True,
        save_steps=train_config.save_steps,
        save_total_limit=train_config.save_total_limit,
        eval_strategy="no",
        # The custom model consumes batch keys (e.g. audio_values) that are not
        # in a standard forward signature, so keep every column.
        remove_unused_columns=False,
        report_to="wandb",
        log_level="info",
        log_level_replica="info",
    )

    sample_callback = SamplePredictionCallback(
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_dataset,
        sample_every_n_steps=train_config.sample_pred_every_steps,
        num_samples=2,
        prompt="Transcribe the following audio:",
    )

    aggressive_delete_callback = AggressiveDeleteCallback(train_config.output_dir)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        callbacks=[sample_callback, aggressive_delete_callback],
    )

    total_steps = train_config.max_steps
    print(f"\n>>> Training: max_steps={total_steps}, batch_size={train_config.batch_size}, "
          f"grad_accum={train_config.accum_steps} (effective batch={train_config.batch_size * train_config.accum_steps})")
    print(f">>> Sample predictions (GT vs predicted transcript) every {train_config.sample_pred_every_steps} steps.\n")

    last_checkpoint = _find_resume_checkpoint(train_config.output_dir)
    if last_checkpoint is not None:
        print(f">>> Resuming from checkpoint: {last_checkpoint}")
    # resume_from_checkpoint=None is the Trainer default (start from scratch).
    trainer.train(resume_from_checkpoint=last_checkpoint)

    # save_model coordinates across processes itself, so every rank calls it.
    trainer.save_model(train_config.output_dir)

    # The remaining steps write shared files or talk to the Hub; run them once.
    if not trainer.is_world_process_zero():
        return

    tokenizer.save_pretrained(train_config.output_dir)
    processor.save_pretrained(train_config.output_dir)

    if train_config.push_to_hub:
        _push_to_hub(train_config)
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    train()
|
|