Adding _set_gradient_checkpointing for compatibility

Vicente Rivera 2023-09-18 22:17:40 -07:00
parent b6a7e2fe15
commit a30a931294

@@ -712,6 +712,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
             "attention_mask": attention_mask,
         }
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MixFormerSequentialPreTrainedModel):
+            module.gradient_checkpointing = value
+
 
 class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
     """MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""