DFlash LoRA OOM 修复记录
OOM 根因分析
- SHARD_GRAD_OP (ZeRO-2) — 每卡持有完整 Qwen3-8B 参数 (~16GB bf16),参数未分片
- SDPA + 4D additive mask — FlashAttention 不支持 4D additive mask,fallback 到 math backend,每层 materialize 完整 attention scores (
bsz × 32heads × 2048 × 2048) - 大 vocab logits —
[bsz, 2048, 151936]bf16 ≈ 1.18GB,加上梯度和 boolean indexing 拷贝,峰值 ~3-4GB - 机器只有 2 张 H100,脚本默认
NUM_GPUS=4
已完成的改动
1. FSDP sharding 改为 FULL_SHARD (ZeRO-3)
- 文件:
SpecForge/scripts/train_dflash_lora.py:347 ShardingStrategy.SHARD_GRAD_OP→ShardingStrategy.FULL_SHARD- 效果: 参数跨卡分片,每卡省 ~8-12GB
2. 降 batch-size,提高 accumulation-steps
- 文件:
SpecForge/scripts/run_train_dflash_lora.sh --batch-size 2→1,--accumulation-steps 4→8- 效果: 等效 global batch size 不变,峰值显存减半
待验证 / 后续优化
- 运行时传
bash run_train_dflash_lora.sh 2确保用 2 卡 - 如仍 OOM,考虑 chunked cross-entropy loss 避免大 vocab logits 全量 materialize
- 长期可探索自定义 attention kernel 支持 block-sparse mask,绕过 SDPA math fallback
3. flex_attention + BlockMask 替换 4D additive mask
- 文件:
SpecForge/specforge/core/dflash_lora.py,specforge/modeling/draft/dflash_lora.py,scripts/train_dflash_lora.py - 从非 LoRA 版
dflash.py移植_get_or_create_block_mask()方法,适配 LoRA 场景 (Q_LEN == KV_LEN == seq_len) - LoRA 版 mask: context causal + block bidirectional (非 LoRA 版是 [context, noise] concat KV)
- 用
--attention-backend flex_attention启用 (默认),退回--attention-backend additive走原有 4D mask - HuggingFace model 用
attn_implementation="flex_attention"加载 - 效果: 不再 fallback 到 SDPA math backend,省去
[bsz, heads, seq, seq]attention scores 的显存
4. chunked cross-entropy loss
- 文件:
SpecForge/specforge/core/dflash_lora.py,specforge/modeling/draft/dflash_lora.py,scripts/train_dflash_lora.py - 从非 LoRA 版
dflash.py移植_chunked_lm_loss()方法 - 分 chunk 过 lm_head + CE loss + gradient checkpointing,避免 materialize 完整
[bsz, seq, vocab]logits - 用
--lm-head-chunk-size 256启用 (默认 0 = 不启用) DFlashLoRADraftModel.forward()新增output_hidden_states参数,chunked 时返回 hidden states- 效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size)