Adding _set_gradient_checkpointing for compatibility
parent b6a7e2fe15
commit a30a931294
@@ -712,6 +712,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
             "attention_mask": attention_mask,
         }
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MixFormerSequentialPreTrainedModel):
+            module.gradient_checkpointing = value
+
 
 class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
     """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
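For context, a minimal sketch of how this hook is exercised: on Transformers versions that dispatch gradient checkpointing through _set_gradient_checkpointing, calling gradient_checkpointing_enable() walks the module tree and invokes the hook with value=True, so a remote-code model that lacks the method fails when checkpointing is requested. The checkpoint id below is illustrative, assuming a repo that ships this MixFormer modeling code and is loaded with trust_remote_code.

from transformers import AutoModelForCausalLM

# Illustrative checkpoint id; any repo that bundles this MixFormerSequential
# modeling code behaves the same way.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5",
    trust_remote_code=True,
)

# gradient_checkpointing_enable() dispatches to _set_gradient_checkpointing()
# on the model and its submodules; with this commit the flag is simply
# propagated to every MixFormerSequentialPreTrainedModel instance.
model.gradient_checkpointing_enable()
print(model.gradient_checkpointing)  # True on versions that use this hook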