#!/usr/bin/env python3
"""
Teacher-Student model performance comparison script.

Compares the performance of the RLHF Teacher model with that of the
distilled Student model.
"""

import argparse
import json
import time
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class ModelComparator:
    """Load a teacher and a student causal LM and compare them.

    Both models are loaded in float16 with ``device_map="auto"``; each
    tokenizer falls back to its EOS token when it has no pad token.
    """

    def __init__(self, teacher_path: str, student_path: str):
        """Load both models and their tokenizers.

        Args:
            teacher_path: Path or identifier passed to ``from_pretrained``
                for the teacher model and tokenizer.
            student_path: Path or identifier passed to ``from_pretrained``
                for the student model and tokenizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print("📥 Loading Teacher model...")
        self.teacher_model = AutoModelForCausalLM.from_pretrained(
            teacher_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_path)

        print("📥 Loading Student model...")
        self.student_model = AutoModelForCausalLM.from_pretrained(
            student_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.student_tokenizer = AutoTokenizer.from_pretrained(student_path)

        # Make sure both tokenizers have a pad token.
        for tokenizer in [self.teacher_tokenizer, self.student_tokenizer]:
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

    def generate_response(self, model, tokenizer, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a response and record performance metrics.

        Args:
            model: Model whose ``generate`` method produces the output ids.
            tokenizer: Tokenizer used to encode the prompt and decode the
                output.
            prompt: Raw user prompt; it is wrapped in the
                ``### Human: ... ### Assistant:`` template before encoding.
            **kwargs: Generation options that override the defaults below.

        Returns:
            A dict with the keys ``response``, ``generation_time``,
            ``tokens_generated``, ``tokens_per_second``, ``prompt_tokens``
            and ``total_tokens``.
        """
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        prompt_tokens = inputs.input_ids.shape[1]

        generation_config = {
            "max_new_tokens": 200,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            **kwargs,
        }

        # Time the generation with a monotonic clock, so wall-clock
        # adjustments cannot distort the measurement.
        start_time = time.perf_counter()
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_config)
        generation_time = time.perf_counter() - start_time

        # Output ids are expected to be the prompt ids followed by the new
        # ids. Slicing at the prompt length isolates the continuation without
        # assuming the decoded text reproduces the prompt character for
        # character.
        new_token_ids = outputs[0][prompt_tokens:]
        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

        # Count the ids that were actually generated (including any special
        # tokens such as EOS) instead of re-encoding the stripped text, which
        # can add special tokens and tokenize differently.
        generated_tokens = int(new_token_ids.shape[0])

        return {
            "response": generated_text,
            "generation_time": generation_time,
            "tokens_generated": generated_tokens,
            "tokens_per_second": generated_tokens / generation_time if generation_time > 0 else 0,
            "prompt_tokens": prompt_tokens,
            "total_tokens": outputs.shape[1],
        }

    def calculate_model_size(self, model) -> Dict[str, Any]:
        """Compute the parameter count and in-memory size of a model.

        Args:
            model: Object exposing ``parameters()`` whose items provide
                ``numel()``, ``element_size()`` and ``requires_grad``.

        Returns:
            A dict with total and trainable parameter counts, the size in MB
            and GB, and a ``compression_ratio`` placeholder set to ``None``.
        """
        param_count = 0
        trainable_params = 0
        model_size_bytes = 0  # estimated size in bytes, parameters only
        for param in model.parameters():
            count = param.numel()
            param_count += count
            if param.requires_grad:
                trainable_params += count
            model_size_bytes += count * param.element_size()

        model_size_mb = model_size_bytes / (1024 * 1024)
        model_size_gb = model_size_mb / 1024

        return {
            "total_parameters": param_count,
            "trainable_parameters": trainable_params,
            "model_size_mb": model_size_mb,
            "model_size_gb": model_size_gb,
            "compression_ratio": None,  # to be computed during the comparison
        }

    def evaluate_quality_metrics(self, responses: List[str]) -> Dict[str, float]:
        """Evaluate simple quality metrics over generated responses.

        Args:
            responses: Generated response texts.

        Returns:
            A dict with the average and standard deviation of the response
            length in whitespace-separated words, the vocabulary richness
            (unique words / total words, case-insensitive) and the average
            count of ``.``, ``!`` and ``?`` characters per response. All
            values are ``0.0`` when ``responses`` is empty.
        """
        if not responses:
            # The mean of an empty list would be NaN (with a RuntimeWarning).
            return {
                "avg_response_length": 0.0,
                "response_length_std": 0.0,
                "vocabulary_richness": 0.0,
                "avg_sentences_per_response": 0.0,
            }

        metrics = {}

        # Response length in words: mean and standard deviation.
        lengths = [len(resp.split()) for resp in responses]
        metrics["avg_response_length"] = float(np.mean(lengths))
        metrics["response_length_std"] = float(np.std(lengths))

        # Vocabulary richness (a simplified type-token ratio).
        all_words = []
        for resp in responses:
            all_words.extend(resp.lower().split())

        if all_words:
            metrics["vocabulary_richness"] = len(set(all_words)) / len(all_words)
        else:
            metrics["vocabulary_richness"] = 0.0

        # Sentence count approximated by counting terminal punctuation marks.
        sentence_counts = [
            resp.count('.') + resp.count('!') + resp.count('?')
            for resp in responses
        ]
        metrics["avg_sentences_per_response"] = float(np.mean(sentence_counts))

        return metrics

    def run_comprehensive_comparison(self) -> Dict[str, Any]:
        """Run the comprehensive performance comparison."""
        print("🔍 Running comprehensive Teacher-Student comparison...")

        # Test prompt set
        test_prompts = [
            # Ad copy generation
            "Create an advertisement for a revolutionary smartphone with advanced AI features",
            "Write marketing copy for an eco-friendly electric vehicle targeting urban professionals",
            "Generate a catchy slogan for a fitness app that uses AI personal training",
            "Create promotional content for a sustainable fashion brand targeting Gen Z",
            "Write ad copy for a productivity software targeting remote workers",

            # Tasks of varying complexity
            "Explain the benefits of renewable energy in simple terms",
            "Write a brief product description for wireless headphones with noise cancellation",
            "Create a social media post promoting a new coffee shop opening",
            "Generate marketing text for a luxury watch brand",
            "Write an email subject line for a summer sale promotion",

            # Creative tasks
            "Create a tagline for a travel app that focuses on sustainable tourism",
            "Write a short product pitch for smart home security system",
            "Generate advertising copy for a meal delivery service focusing on healthy options",
            "Create marketing content for an online learning platform",
            "Write promotional text for a mental wellness app"
        ]

        # Initialize result collection
        results = {
            "comparison_date": datetime.now().isoformat(),
            "test_prompts_count": len(test_prompts),
            "teacher_results": {},
            "student_results": {},
            "performance_comparison": {},
            "detailed_responses": []
        }

        # Gather model information
        print("📊 Analyzing model specifications...")
        teacher_info = self.calculate_model_size(self.teacher_model)
        student_info = self.calculate_model_size(self.student_model)