### 1. `train_dflash_lora.py` * 加了lora,原来是调用小模型,现在是hidden states+lora预测。 * `dflash_lora_mask_fn`函数是在处理预测的那一块草稿Block时,可以同时看到这一块里的所有词。 ### 2. OOM优化 * 分片策略ZeRO-3,FSDP切分从`SHARD_GRAD_OP`升级到`FULL_SHARD`。 * `batch-size=1`,`accumulation-steps=8`。 * 参考之前的代码用了FlexAttention(`dflash_lora_mask_fn`)。 * `_chunked_lm_loss()`,把算loss切片成256块来算+梯度检查。 ### 运行 * bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2