Hanrui / progress /oom_fix_progress.md
Lekr0's picture
Add files using upload-large-folder tool
40d87dd verified

DFlash LoRA OOM 修复记录

OOM 根因分析

  1. SHARD_GRAD_OP (ZeRO-2) — 每卡持有完整 Qwen3-8B 参数 (~16GB bf16),参数未分片
  2. SDPA + 4D additive mask — FlashAttention 不支持 4D additive mask,fallback 到 math backend,每层 materialize 完整 attention scores (bsz × 32heads × 2048 × 2048)
  3. 大 vocab logits[bsz, 2048, 151936] bf16 ≈ 1.18GB,加上梯度和 boolean indexing 拷贝,峰值 ~3-4GB
  4. 机器只有 2 张 H100,脚本默认 NUM_GPUS=4

已完成的改动

1. FSDP sharding 改为 FULL_SHARD (ZeRO-3)

  • 文件: SpecForge/scripts/train_dflash_lora.py:347
  • ShardingStrategy.SHARD_GRAD_OPShardingStrategy.FULL_SHARD
  • 效果: 参数跨卡分片,每卡省 ~8-12GB

2. 降 batch-size,提高 accumulation-steps

  • 文件: SpecForge/scripts/run_train_dflash_lora.sh
  • --batch-size 21--accumulation-steps 48
  • 效果: 等效 global batch size 不变,峰值显存减半

待验证 / 后续优化

  • 运行时传 bash run_train_dflash_lora.sh 2 确保用 2 卡
  • 如仍 OOM,考虑 chunked cross-entropy loss 避免大 vocab logits 全量 materialize
  • 长期可探索自定义 attention kernel 支持 block-sparse mask,绕过 SDPA math fallback

3. flex_attention + BlockMask 替换 4D additive mask

  • 文件: SpecForge/specforge/core/dflash_lora.py, specforge/modeling/draft/dflash_lora.py, scripts/train_dflash_lora.py
  • 从非 LoRA 版 dflash.py 移植 _get_or_create_block_mask() 方法,适配 LoRA 场景 (Q_LEN == KV_LEN == seq_len)
  • LoRA 版 mask: context causal + block bidirectional (非 LoRA 版是 [context, noise] concat KV)
  • --attention-backend flex_attention 启用 (默认),退回 --attention-backend additive 走原有 4D mask
  • HuggingFace model 用 attn_implementation="flex_attention" 加载
  • 效果: 不再 fallback 到 SDPA math backend,省去 [bsz, heads, seq, seq] attention scores 的显存

4. chunked cross-entropy loss

  • 文件: SpecForge/specforge/core/dflash_lora.py, specforge/modeling/draft/dflash_lora.py, scripts/train_dflash_lora.py
  • 从非 LoRA 版 dflash.py 移植 _chunked_lm_loss() 方法
  • 分 chunk 过 lm_head + CE loss + gradient checkpointing,避免 materialize 完整 [bsz, seq, vocab] logits
  • --lm-head-chunk-size 256 启用 (默认 0 = 不启用)
  • DFlashLoRADraftModel.forward() 新增 output_hidden_states 参数,chunked 时返回 hidden states
  • 效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size)