Hanrui / progress /oom_fix_progress.md
Lekr0's picture
Add files using upload-large-folder tool
40d87dd verified
# DFlash LoRA OOM 修复记录
## OOM 根因分析
1. **SHARD_GRAD_OP (ZeRO-2)** — 每卡持有完整 Qwen3-8B 参数 (~16GB bf16),参数未分片
2. **SDPA + 4D additive mask** — FlashAttention 不支持 4D additive mask,fallback 到 math backend,每层 materialize 完整 attention scores (`bsz × 32heads × 2048 × 2048`)
3. **大 vocab logits**`[bsz, 2048, 151936]` bf16 ≈ 1.18GB,加上梯度和 boolean indexing 拷贝,峰值 ~3-4GB
4. **机器只有 2 张 H100**,脚本默认 `NUM_GPUS=4`
## 已完成的改动
### 1. FSDP sharding 改为 FULL_SHARD (ZeRO-3)
- 文件: `SpecForge/scripts/train_dflash_lora.py:347`
- `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD`
- 效果: 参数跨卡分片,每卡省 ~8-12GB
### 2. 降 batch-size,提高 accumulation-steps
- 文件: `SpecForge/scripts/run_train_dflash_lora.sh`
- `--batch-size 2``1``--accumulation-steps 4``8`
- 效果: 等效 global batch size 不变,峰值显存减半
## 待验证 / 后续优化
- [ ] 运行时传 `bash run_train_dflash_lora.sh 2` 确保用 2 卡
- [x] 如仍 OOM,考虑 chunked cross-entropy loss 避免大 vocab logits 全量 materialize
- [x] 长期可探索自定义 attention kernel 支持 block-sparse mask,绕过 SDPA math fallback
### 3. flex_attention + BlockMask 替换 4D additive mask
- 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
- 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景 (Q_LEN == KV_LEN == seq_len)
- LoRA 版 mask: context causal + block bidirectional (非 LoRA 版是 [context, noise] concat KV)
- 用 `--attention-backend flex_attention` 启用 (默认),退回 `--attention-backend additive` 走原有 4D mask
- HuggingFace model 用 `attn_implementation="flex_attention"` 加载
- 效果: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 的显存
### 4. chunked cross-entropy loss
- 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
- 从非 LoRA 版 `dflash.py` 移植 `_chunked_lm_loss()` 方法
- 分 chunk 过 lm_head + CE loss + gradient checkpointing,避免 materialize 完整 `[bsz, seq, vocab]` logits
- 用 `--lm-head-chunk-size 256` 启用 (默认 0 = 不启用)
- `DFlashLoRADraftModel.forward()` 新增 `output_hidden_states` 参数,chunked 时返回 hidden states
- 效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size)