| # DFlash LoRA OOM 修复记录 |
|
|
| ## OOM 根因分析 |
|
|
| 1. **SHARD_GRAD_OP (ZeRO-2)** — 每卡持有完整 Qwen3-8B 参数 (~16GB bf16),参数未分片 |
| 2. **SDPA + 4D additive mask** — FlashAttention 不支持 4D additive mask,fallback 到 math backend,每层 materialize 完整 attention scores (`bsz × 32heads × 2048 × 2048`) |
| 3. **大 vocab logits** — `[bsz, 2048, 151936]` bf16 ≈ 1.18GB,加上梯度和 boolean indexing 拷贝,峰值 ~3-4GB |
| 4. **机器只有 2 张 H100**,脚本默认 `NUM_GPUS=4` |
|
|
| ## 已完成的改动 |
|
|
| ### 1. FSDP sharding 改为 FULL_SHARD (ZeRO-3) |
| - 文件: `SpecForge/scripts/train_dflash_lora.py:347` |
| - `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD` |
| - 效果: 参数跨卡分片,每卡省 ~8-12GB |
|
|
| ### 2. 降 batch-size,提高 accumulation-steps |
| - 文件: `SpecForge/scripts/run_train_dflash_lora.sh` |
| - `--batch-size 2` → `1`,`--accumulation-steps 4` → `8` |
| - 效果: 等效 global batch size 不变,峰值显存减半 |
|
|
| ## 待验证 / 后续优化 |
|
|
| - [ ] 运行时传 `bash run_train_dflash_lora.sh 2` 确保用 2 卡 |
| - [x] 如仍 OOM,考虑 chunked cross-entropy loss 避免大 vocab logits 全量 materialize |
| - [x] 长期可探索自定义 attention kernel 支持 block-sparse mask,绕过 SDPA math fallback |
|
|
| ### 3. flex_attention + BlockMask 替换 4D additive mask |
| - 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py` |
| - 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景 (Q_LEN == KV_LEN == seq_len) |
| - LoRA 版 mask: context causal + block bidirectional (非 LoRA 版是 [context, noise] concat KV) |
| - 用 `--attention-backend flex_attention` 启用 (默认),退回 `--attention-backend additive` 走原有 4D mask |
| - HuggingFace model 用 `attn_implementation="flex_attention"` 加载 |
| - 效果: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 的显存 |
|
|
| ### 4. chunked cross-entropy loss |
| - 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py` |
| - 从非 LoRA 版 `dflash.py` 移植 `_chunked_lm_loss()` 方法 |
| - 分 chunk 过 lm_head + CE loss + gradient checkpointing,避免 materialize 完整 `[bsz, seq, vocab]` logits |
| - 用 `--lm-head-chunk-size 256` 启用 (默认 0 = 不启用) |
| - `DFlashLoRADraftModel.forward()` 新增 `output_hidden_states` 参数,chunked 时返回 hidden states |
| - 效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size) |
| |