Update modeling_gpt_refact.py
Browse files- modeling_gpt_refact.py +2 -2
modeling_gpt_refact.py
CHANGED
@@ -337,9 +337,9 @@ class GPTRefactPreTrainedModel(PreTrainedModel):
|
|
337 |
elif isinstance(module, LayerNormNoBias):
|
338 |
module.weight.data.fill_(1.0)
|
339 |
|
340 |
-
def _set_gradient_checkpointing(self, module,
|
341 |
if isinstance(module, GPTRefactModel):
|
342 |
-
module.gradient_checkpointing =
|
343 |
|
344 |
|
345 |
class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
|
337 |
elif isinstance(module, LayerNormNoBias):
|
338 |
module.weight.data.fill_(1.0)
|
339 |
|
340 |
+
def _set_gradient_checkpointing(self, module, enable=False):
|
341 |
if isinstance(module, GPTRefactModel):
|
342 |
+
module.gradient_checkpointing = enable
|
343 |
|
344 |
|
345 |
class GPTRefactModel(GPTRefactPreTrainedModel):
|