# DFlash LoRA OOM 修复记录 ## OOM 根因分析 1. **SHARD_GRAD_OP (ZeRO-2)** — 每卡持有完整 Qwen3-8B 参数 (~16GB bf16),参数未分片 2. **SDPA + 4D additive mask** — FlashAttention 不支持 4D additive mask,fallback 到 math backend,每层 materialize 完整 attention scores (`bsz × 32heads × 2048 × 2048`) 3. **大 vocab logits** — `[bsz, 2048, 151936]` bf16 ≈ 1.18GB,加上梯度和 boolean indexing 拷贝,峰值 ~3-4GB 4. **机器只有 2 张 H100**,脚本默认 `NUM_GPUS=4` ## 已完成的改动 ### 1. FSDP sharding 改为 FULL_SHARD (ZeRO-3) - 文件: `SpecForge/scripts/train_dflash_lora.py:347` - `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD` - 效果: 参数跨卡分片,每卡省 ~8-12GB ### 2. 降 batch-size,提高 accumulation-steps - 文件: `SpecForge/scripts/run_train_dflash_lora.sh` - `--batch-size 2` → `1`,`--accumulation-steps 4` → `8` - 效果: 等效 global batch size 不变,峰值显存减半 ## 待验证 / 后续优化 - [ ] 运行时传 `bash run_train_dflash_lora.sh 2` 确保用 2 卡 - [x] 如仍 OOM,考虑 chunked cross-entropy loss 避免大 vocab logits 全量 materialize - [x] 长期可探索自定义 attention kernel 支持 block-sparse mask,绕过 SDPA math fallback ### 3. flex_attention + BlockMask 替换 4D additive mask - 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py` - 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景 (Q_LEN == KV_LEN == seq_len) - LoRA 版 mask: context causal + block bidirectional (非 LoRA 版是 [context, noise] concat KV) - 用 `--attention-backend flex_attention` 启用 (默认),退回 `--attention-backend additive` 走原有 4D mask - HuggingFace model 用 `attn_implementation="flex_attention"` 加载 - 效果: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 的显存 ### 4. chunked cross-entropy loss - 文件: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py` - 从非 LoRA 版 `dflash.py` 移植 `_chunked_lm_loss()` 方法 - 分 chunk 过 lm_head + CE loss + gradient checkpointing,避免 materialize 完整 `[bsz, seq, vocab]` logits - 用 `--lm-head-chunk-size 256` 启用 (默认 0 = 不启用) - `DFlashLoRADraftModel.forward()` 新增 `output_hidden_states` 参数,chunked 时返回 hidden states - 效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size)