diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py index 7d4f722..01d2c97 100644 --- a/modeling_mixformer_sequential.py +++ b/modeling_mixformer_sequential.py @@ -711,6 +711,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel): "past_key_values": past_key_values, "attention_mask": attention_mask, } + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MixFormerSequentialPreTrainedModel): + module.gradient_checkpointing = value class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):