File size: 1,998 Bytes
c1c9e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/bin/bash
# launch_distillation.sh - Launch Teacher-Student distillation training.
#
# Prerequisites:
#   - RLHF teacher model at ./rlhf_teacher_model (SFT + RLHF completed)
#   - nvidia-smi available (NVIDIA driver installed)
# Outputs:
#   - ./distilled_student_model/  trained student weights
#   - ./distillation_logs/        training log (timestamped)
#   - ./comparison_results.json   teacher-vs-student comparison

# Fail fast: abort on errors and unset vars; without pipefail a failing
# python stage would be masked by `tee` below.
set -euo pipefail

echo "🎓 Starting Teacher-Student Distillation Training..."

# --- Prerequisite checks -------------------------------------------------
echo "📋 Checking prerequisites..."

# Teacher model must exist before distillation can start.
if [[ ! -d "./rlhf_teacher_model" ]]; then
    echo "❌ Error: RLHF Teacher model not found at ./rlhf_teacher_model" >&2
    echo "   Please complete SFT and RLHF training first" >&2
    exit 1
fi

# nvidia-smi is needed both for the report and the memory check below;
# fail with a clear message instead of a "command not found" mid-script.
if ! command -v nvidia-smi >/dev/null 2>&1; then
    echo "❌ Error: nvidia-smi not found; NVIDIA GPU driver is required" >&2
    exit 1
fi

# --- GPU resources -------------------------------------------------------
echo "📊 GPU Resources:"
nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv

# Sum free memory (MB) across all visible GPUs; `sum+0` guarantees a
# numeric result even if nvidia-smi prints nothing.
AVAILABLE_MEMORY=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits \
    | awk '{sum+=$1} END {print sum+0}')
echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"

if (( AVAILABLE_MEMORY < 40000 )); then
    echo "⚠️  Warning: Distillation training requires significant GPU memory (>40GB recommended)"
    echo "   Consider using gradient checkpointing or smaller batch sizes"
fi

# --- Environment ---------------------------------------------------------
# One shared timestamp so the W&B run name and the log file name match.
RUN_STAMP=$(date +%Y%m%d_%H%M%S)
export CUDA_VISIBLE_DEVICES=0,1  # adjust to the GPUs available on this host
export TOKENIZERS_PARALLELISM=false
export WANDB_PROJECT="teacher-student-distillation"
export WANDB_RUN_NAME="distillation-${RUN_STAMP}"

# Output directories.
mkdir -p ./distilled_student_model
mkdir -p ./distillation_logs

# Reuse cached distillation data when present (regeneration is expensive).
if [[ -f "./distillation_data.json" ]]; then
    echo "📂 Found existing distillation data, will reuse it"
else
    echo "📊 Will generate new distillation data from teacher model"
fi

echo "🔥 Starting distillation training..."

# --- Training ------------------------------------------------------------
# pipefail (set above) ensures a python failure aborts despite tee.
python teacher_student_distillation.py 2>&1 \
    | tee "./distillation_logs/distillation_${RUN_STAMP}.log"

echo "✅ Distillation training completed!"

# --- Post-training comparison --------------------------------------------
echo "⚖️ Comparing Teacher vs Student performance..."
python compare_teacher_student.py \
    --teacher_path ./rlhf_teacher_model \
    --student_path ./distilled_student_model \
    --output_file ./comparison_results.json

echo "📊 Results saved to comparison_results.json"