| ### 1. `train_dflash_lora.py` | |
| * 加了lora,原来是调用小模型,现在是hidden states+lora预测。 | |
| * `dflash_lora_mask_fn`函数是在处理预测的那一块草稿Block时,可以同时看到这一块里的所有词。 | |
| ### 2. OOM优化 | |
| * 分片策略ZeRO-3,FSDP切分从`SHARD_GRAD_OP`升级到`FULL_SHARD`。 | |
| * `batch-size=1`,`accumulation-steps=8`。 | |
| * 参考之前的代码用了FlexAttention(`dflash_lora_mask_fn`)。 | |
| * `_chunked_lm_loss()`,把算loss切片成256块来算+梯度检查。 | |
| ### 运行 | |
| * bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2 |