#!/usr/bin/env python3
"""
Teacher-Student knowledge distillation script.

Distills a teacher model trained with SFT + PPO RLHF into a smaller
student model.
"""

import json
import os
import warnings
from typing import Dict, List

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    logging,
)
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
import wandb

warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)


class DistillationConfig:
    """Configuration for distillation training."""

    # Model paths
    teacher_model_path = "./rlhf_teacher_model"        # Teacher model after RLHF
    student_model_name = "microsoft/DialoGPT-medium"   # Replace with the actual OpenAI OSS 20B model

    # Distillation parameters
    temperature = 4.0  # Distillation temperature
    alpha = 0.7        # Weight of the distillation (KL) loss
    beta = 0.3         # Weight of the student's own LM loss
    gamma = 0.1        # Weight of the feature-matching loss (reserved; not used below)

    # Training parameters
    learning_rate = 1e-4
    num_train_epochs = 3
    per_device_train_batch_size = 2
    per_device_eval_batch_size = 4
    gradient_accumulation_steps = 8
    warmup_ratio = 0.1
    weight_decay = 0.01
    logging_steps = 50
    eval_steps = 500
    save_steps = 1000

    # LoRA configuration (applied to the student for training efficiency)
    use_lora = True
    lora_r = 32
    lora_alpha = 64
    lora_dropout = 0.1

    # Data configuration
    max_length = 512
    num_distill_samples = 10000  # Number of samples used for distillation

    # Output configuration
    output_dir = "./distilled_student_model"
    run_name = "teacher-student-distillation"


class DistillationDataset(Dataset):
    """Dataset of prompts, teacher responses, and the teacher's logits."""

    def __init__(self, teacher_outputs: List[Dict], tokenizer, max_length: int = 512):
        self.data = teacher_outputs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Build the full prompt-response sequence
        full_text = f"### Human: {item['prompt']}\n### Assistant: {item['response']}"

        encoded = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Pad/truncate the stored teacher logits to max_length so every
        # example in a batch has the same shape. Padded positions are
        # zero-filled; ideally they would also be masked out of the KL loss.
        teacher_logits = torch.tensor(item["teacher_logits"], dtype=torch.float).squeeze(0)
        seq_len, vocab_size = teacher_logits.shape
        if seq_len < self.max_length:
            pad = torch.zeros(self.max_length - seq_len, vocab_size)
            teacher_logits = torch.cat([teacher_logits, pad], dim=0)
        else:
            teacher_logits = teacher_logits[: self.max_length]

        # Mask padding positions out of the LM loss
        labels = encoded["input_ids"].squeeze().clone()
        labels[encoded["attention_mask"].squeeze() == 0] = -100

        return {
            "input_ids": encoded["input_ids"].squeeze(),
            "attention_mask": encoded["attention_mask"].squeeze(),
            "teacher_logits": teacher_logits,
            "labels": labels,
        }
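
# A minimal collator sketch (an addition, not part of the original pipeline):
# because every field produced by DistillationDataset is already padded to a
# fixed max_length, a batch can be built by plain stacking. This avoids
# routing the extra "teacher_logits" key through tokenizer.pad(), which the
# DataCollatorForLanguageModeling used further below does not expect.
def distillation_collator(features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    return {key: torch.stack([f[key] for f in features]) for key in features[0]}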


class KnowledgeDistillationTrainer(Trainer):
    """Trainer that combines the student's LM loss with a KL distillation loss."""

    def __init__(self, teacher_model, student_model, temperature=4.0,
                 alpha=0.7, beta=0.3, gamma=0.1, **kwargs):
        super().__init__(model=student_model, **kwargs)
        self.teacher_model = teacher_model
        self.teacher_model.eval()  # freeze the teacher
        self.temperature = temperature
        self.alpha = alpha  # distillation loss weight
        self.beta = beta    # student loss weight
        self.gamma = gamma  # feature-matching loss weight (reserved, unused)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the combined distillation loss."""
        labels = inputs.get("labels")
        teacher_logits = inputs.get("teacher_logits")
        if teacher_logits is not None:
            teacher_logits = teacher_logits.to(model.device)

        # Student forward pass (drop the extra key the model does not accept)
        student_inputs = {k: v for k, v in inputs.items() if k != "teacher_logits"}
        student_outputs = model(**student_inputs)
        student_logits = student_outputs.logits

        losses = {}

        # 1. Standard language-model loss (the student's own loss)
        if labels is not None:
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            losses["student_loss"] = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        # 2. Distillation loss (KL divergence on temperature-softened logits)
        if teacher_logits is not None:
            # Align sequence lengths if they differ
            if teacher_logits.shape != student_logits.shape:
                min_seq_len = min(teacher_logits.shape[1], student_logits.shape[1])
                teacher_logits = teacher_logits[:, :min_seq_len, :]
                student_logits_for_distill = student_logits[:, :min_seq_len, :]
            else:
                student_logits_for_distill = student_logits

            # Soft targets from the teacher
            teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
            student_log_probs = F.log_softmax(
                student_logits_for_distill / self.temperature, dim=-1
            )

            # KL divergence, scaled by T^2 as in standard distillation
            losses["distill_loss"] = F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="batchmean",
            ) * (self.temperature ** 2)

        # 3. Combine into the total loss
        total_loss = torch.tensor(0.0, device=model.device)
        if "student_loss" in losses:
            total_loss = total_loss + self.beta * losses["student_loss"]
        if "distill_loss" in losses:
            total_loss = total_loss + self.alpha * losses["distill_loss"]

        # Log the individual loss terms
        self.log({
            "train/total_loss": total_loss.item(),
            "train/student_loss": losses["student_loss"].item() if "student_loss" in losses else 0.0,
            "train/distill_loss": losses["distill_loss"].item() if "distill_loss" in losses else 0.0,
        })

        return (total_loss, student_outputs) if return_outputs else total_loss


def prepare_student_model(config: DistillationConfig):
    """Prepare the student model."""
    print("🎓 Preparing student model...")

    # Load the student base model
    student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optionally add LoRA adapters for parameter-efficient training
    if config.use_lora:
        print("🔧 Adding LoRA to student model...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            # The module names below assume a LLaMA-style architecture;
            # adjust them to match the student you actually load.
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
        )
        student_model = get_peft_model(student_model, lora_config)
        student_model.print_trainable_parameters()

    return student_model


def load_teacher_model(config: DistillationConfig):
    """Load the teacher model."""
    print("👨‍🏫 Loading teacher model...")
    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    teacher_model.eval()
    return teacher_model
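
# A small sanity check (an added sketch, not in the original flow): the KL
# loss above compares teacher and student distributions position by position,
# which is only meaningful when both models share the same tokenizer and
# vocabulary. Worth running once before spending hours generating data.
def check_vocab_compatibility(teacher_model, student_model):
    teacher_vocab = teacher_model.config.vocab_size
    student_vocab = student_model.config.vocab_size
    if teacher_vocab != student_vocab:
        raise ValueError(
            f"Teacher vocab size ({teacher_vocab}) != student vocab size "
            f"({student_vocab}); logit-level distillation needs a shared vocabulary."
        )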


def generate_distillation_data(teacher_model, tokenizer, config: DistillationConfig):
    """Generate the distillation dataset from teacher responses."""
    print("📊 Generating distillation dataset...")

    # Prompt sources; add more datasets here as needed
    dataset_sources = [
        "smangrul/ad-copy-generation",
    ]

    all_prompts = []
    for source in dataset_sources:
        try:
            ds = load_dataset(source, split="train")
            # Extract the prompt from each conversation
            for item in ds:
                if "conversations" in item and len(item["conversations"]) > 0:
                    prompt = item["conversations"][0].get("value", "")
                    if len(prompt.strip()) > 10:
                        all_prompts.append(prompt.strip())
        except Exception as e:
            print(f"⚠️ Error loading {source}: {e}")

    # Cap the number of samples
    if len(all_prompts) > config.num_distill_samples:
        all_prompts = all_prompts[: config.num_distill_samples]

    print(f"📝 Generating responses for {len(all_prompts)} prompts...")

    distillation_data = []
    teacher_model.eval()

    with torch.no_grad():
        for i, prompt in enumerate(tqdm(all_prompts, desc="Generating teacher responses")):
            try:
                # Format the input
                formatted_prompt = f"### Human: {prompt}\n### Assistant:"
                inputs = tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length // 2,
                ).to(teacher_model.device)

                # Generate a response
                outputs = teacher_model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

                # Decode the response
                generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
                response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

                # Re-run the full sequence through the teacher to get its logits
                full_text = f"### Human: {prompt}\n### Assistant: {response}"
                full_inputs = tokenizer(
                    full_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length,
                ).to(teacher_model.device)
                teacher_outputs = teacher_model(**full_inputs)
                teacher_logits = teacher_outputs.logits.cpu().numpy()

                # Note: storing the full (seq_len x vocab_size) logits as JSON
                # is extremely large; consider keeping only the top-k logits
                # per position for anything beyond a small-scale experiment.
                distillation_data.append({
                    "prompt": prompt,
                    "response": response,
                    "teacher_logits": teacher_logits.tolist(),
                })

                # Periodic progress report
                if (i + 1) % 100 == 0:
                    print(f"Generated {i + 1}/{len(all_prompts)} samples")

            except Exception as e:
                print(f"⚠️ Error generating for prompt {i}: {e}")
                continue

    print(f"✅ Generated {len(distillation_data)} teacher-student pairs")

    # Save the distillation data
    with open("distillation_data.json", "w", encoding="utf-8") as f:
        json.dump(distillation_data, f, ensure_ascii=False, indent=2)

    return distillation_data


def create_data_collator(tokenizer):
    """Create the data collator."""
    # Note: DataCollatorForLanguageModeling does not know about the extra
    # "teacher_logits" key; distillation_collator above is a simple
    # stacking alternative if this collator trips over it.
    return DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8,
    )


def run_distillation():
    """Main distillation training pipeline."""
    print("🚀 Starting Teacher-Student Distillation...")

    config = DistillationConfig()

    # Initialize wandb. vars() on the instance would be empty because all
    # attributes are class-level, so read them from the class dict instead.
    wandb.init(
        project="teacher-student-distillation",
        config={k: v for k, v in vars(DistillationConfig).items() if not k.startswith("__")},
        name=config.run_name,
    )

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the models
    teacher_model = load_teacher_model(config)
    student_model = prepare_student_model(config)

    # Generate (or reuse cached) distillation data
    if os.path.exists("distillation_data.json"):
        print("📂 Loading existing distillation data...")
        with open("distillation_data.json", "r", encoding="utf-8") as f:
            distillation_data = json.load(f)
    else:
        distillation_data = generate_distillation_data(teacher_model, tokenizer, config)

    # Train/eval split
    train_size = int(0.9 * len(distillation_data))
    train_data = distillation_data[:train_size]
    eval_data = distillation_data[train_size:]

    train_dataset = DistillationDataset(train_data, tokenizer, config.max_length)
    eval_dataset = DistillationDataset(eval_data, tokenizer, config.max_length)

    print(f"📊 Training samples: {len(train_dataset)}")
    print(f"📊 Evaluation samples: {len(eval_dataset)}")

    # Training arguments
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=config.logging_steps,
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        run_name=config.run_name,
        fp16=True,
        dataloader_pin_memory=False,
        remove_unused_columns=False,  # keep "teacher_logits" in the batch
        group_by_length=True,
    )

    # Data collator
    data_collator = create_data_collator(tokenizer)

    # Distillation trainer
    trainer = KnowledgeDistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        temperature=config.temperature,
        alpha=config.alpha,
        beta=config.beta,
        gamma=config.gamma,
    )

    # Train
    print("🔥 Starting distillation training...")
    trainer.train()

    # Save the final model
    print("💾 Saving distilled student model...")
    trainer.save_model()
    tokenizer.save_pretrained(config.output_dir)

    # Evaluate
    print("🧪 Evaluating distilled model...")
    evaluate_distilled_model(trainer.model, tokenizer, config)

    wandb.finish()
    print("✅ Distillation training completed!")
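
# An added evaluation sketch (hypothetical helper, not wired into the flow
# above): rough per-token perplexity over a few held-out texts, giving a
# quantitative teacher-vs-student signal to go with the spot checks below.
def compute_perplexity(model, tokenizer, texts: List[str], max_length: int = 512) -> float:
    model.eval()
    total_nll, total_tokens = 0.0, 0
    with torch.no_grad():
        for text in texts:
            enc = tokenizer(text, return_tensors="pt", truncation=True,
                            max_length=max_length).to(model.device)
            out = model(**enc, labels=enc["input_ids"])
            # labels are shifted internally, so seq_len - 1 positions are scored
            n_tokens = enc["input_ids"].size(1) - 1
            total_nll += out.loss.item() * n_tokens
            total_tokens += n_tokens
    return float(torch.exp(torch.tensor(total_nll / max(total_tokens, 1))))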


def evaluate_distilled_model(model, tokenizer, config: DistillationConfig):
    """Evaluate the distilled student model on a few ad-copy prompts."""
    print("📊 Evaluating distilled student model...")

    test_prompts = [
        "Create an advertisement for a revolutionary AI-powered fitness tracker",
        "Write marketing copy for an eco-friendly electric vehicle",
        "Generate a slogan for a productivity app for remote workers",
        "Create ad copy for a sustainable fashion brand targeting millennials",
        "Write promotional content for a mental health app",
    ]

    model.eval()
    results = []

    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        results.append({
            "prompt": prompt,
            "response": generated_text,
        })

        print(f"\n🔍 Prompt: {prompt}")
        print(f"📝 Student Response: {generated_text}")
        print("-" * 80)

    # Save the evaluation results
    with open(f"{config.output_dir}/evaluation_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    return results


if __name__ == "__main__":
    # Environment setup
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Report available GPUs
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("⚠️ Warning: No GPU available, using CPU (very slow)")

    # Run distillation training
    run_distillation()
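
# Post-training note (a sketch under the LoRA path above, not executed here):
# with use_lora enabled, trainer.save_model() stores adapter weights only.
# To ship a standalone student, the adapters can be merged into the base
# model via the peft API, e.g.:
#
#     merged = trainer.model.merge_and_unload()
#     merged.save_pretrained(config.output_dir + "-merged")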