1. train_dflash_lora.py
- 加了lora,原来是调用小模型,现在是hidden states+lora预测。
dflash_lora_mask_fn函数是在处理预测的那一块草稿Block时,可以同时看到这一块里的所有词。
2. OOM优化
- 分片策略ZeRO-3,FSDP切分从
SHARD_GRAD_OP升级到FULL_SHARD。 batch-size=1,accumulation-steps=8。- 参考之前的代码用了FlexAttention(
dflash_lora_mask_fn)。 _chunked_lm_loss(),把算loss切片成256块来算+梯度检查。
运行
- bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2