File size: 1,979 Bytes
c1c9e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/bin/bash
# launch_rlhf.sh - Launch PPO RLHF training.
#
# Usage: ./launch_rlhf.sh [single]
#   single  run ppo_rlhf_teacher.py with plain python on GPU 0 only
#   (none)  run ppo_rlhf_teacher.py via `accelerate launch` (multi-GPU, recommended)
#
# Prerequisite: the SFT-trained, merged teacher model must exist in ./merged_model.

echo "🚀 Starting PPO RLHF Training..."

# Check prerequisites before doing any expensive work.
echo "📋 Checking prerequisites..."

# The teacher model (output of SFT training + merge) must already exist.
# Diagnostics go to stderr so they are visible even when stdout is redirected.
if [[ ! -d ./merged_model ]]; then
    echo "❌ Error: Teacher model not found at ./merged_model" >&2
    echo "   Please run SFT training first and merge the model" >&2
    exit 1
fi

# Report GPU resources and warn when total free GPU memory is below the
# ~80 GB recommended for RLHF. This check is advisory only: the script
# continues whether or not nvidia-smi is available or the check passes.
echo "📊 GPU Resources:"
if command -v nvidia-smi >/dev/null 2>&1; then
    nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv

    # Free memory in MB, one line per GPU reported by nvidia-smi. Note: this
    # sums over every GPU on the host, not only those later selected via
    # CUDA_VISIBLE_DEVICES. The query's own exit status is checked so a failing
    # nvidia-smi does not yield an empty/garbage value for the comparison below.
    if free_mb_per_gpu=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits); then
        AVAILABLE_MEMORY=$(awk '{sum += $1} END {print sum + 0}' <<< "$free_mb_per_gpu")
        echo "Available GPU Memory: ${AVAILABLE_MEMORY} MB"

        # Recommended minimum: 80000 MB of free GPU memory in total.
        if (( AVAILABLE_MEMORY < 80000 )); then
            echo "⚠️  Warning: RLHF training requires significant GPU memory (>80GB recommended)" >&2
            echo "   Consider using gradient checkpointing or smaller batch sizes" >&2
        fi
    else
        echo "⚠️  Warning: could not query free GPU memory; skipping memory check" >&2
    fi
else
    echo "⚠️  Warning: nvidia-smi not found; skipping GPU memory check" >&2
fi

# Environment for the training run.
export CUDA_VISIBLE_DEVICES=0,1,2,3  # adjust to the GPUs available on this host
# Presumably disables Hugging Face tokenizers' internal parallelism — confirm.
export TOKENIZERS_PARALLELISM=false
# Weights & Biases project and a timestamped run name; presumably read by the
# training script's wandb integration — verify in ppo_rlhf_teacher.py.
export WANDB_PROJECT="rlhf-teacher-training"
export WANDB_RUN_NAME="ppo-rlhf-$(date +%Y%m%d_%H%M%S)"

# Output directories: ./rlhf_teacher_model is later passed to the evaluation
# step as --model_path; ./rlhf_logs receives the tee'd training log.
if ! mkdir -p ./rlhf_teacher_model ./rlhf_logs; then
    echo "❌ Error: could not create output directories" >&2
    exit 1
fi

# Install extra dependencies. Abort on failure rather than launching a
# long-running training job that would fail later on a missing package.
echo "📦 Installing RLHF dependencies..."
if ! pip install -r rlhf_requirements.txt; then
    echo "❌ Error: failed to install dependencies from rlhf_requirements.txt" >&2
    exit 1
fi

# Launch training, then evaluate the result.
#   $1 == "single" -> plain python, restricted to GPU 0
#   otherwise      -> `accelerate launch` with one process per visible GPU
# Combined stdout/stderr is mirrored to a timestamped log in ./rlhf_logs.
echo "🔥 Starting PPO RLHF training..."

log_file="./rlhf_logs/rlhf_$(date +%Y%m%d_%H%M%S).log"

if [[ "${1:-}" == "single" ]]; then
    # Single-GPU training: pin this one process to GPU 0.
    train_cmd=(env CUDA_VISIBLE_DEVICES=0 python ppo_rlhf_teacher.py)
else
    # Multi-GPU training (recommended). Derive the process count from
    # CUDA_VISIBLE_DEVICES so it stays in sync when the GPU list is adjusted
    # (falls back to 4 GPUs, i.e. 0,1,2,3, if the variable is unset/empty).
    IFS=',' read -r -a visible_gpus <<< "${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
    train_cmd=(accelerate launch
        --config_file accelerate_config.yaml
        --num_processes "${#visible_gpus[@]}"
        --main_process_port 29500
        ppo_rlhf_teacher.py)
fi

# A `cmd | tee` pipeline reports tee's exit status, so take the trainer's own
# status from PIPESTATUS; otherwise a crashed run would be reported as
# completed and then evaluated.
"${train_cmd[@]}" 2>&1 | tee "$log_file"
train_status=${PIPESTATUS[0]}
if (( train_status != 0 )); then
    echo "❌ RLHF training failed (exit code ${train_status}). See ${log_file}" >&2
    exit "$train_status"
fi

echo "✅ RLHF training completed. Check logs for details."

# Post-training evaluation. As the last command, its exit status becomes the
# script's exit status.
echo "🧪 Running post-training evaluation..."
python evaluate_rlhf_model.py --model_path ./rlhf_teacher_model