From 8091327f9e72dee400973a48b008c3a9fd47b7f0 Mon Sep 17 00:00:00 2001 From: Gustavo de Rosa Date: Tue, 17 Oct 2023 12:11:30 +0000 Subject: [PATCH] Adding _set_gradient_checkpointing for compatibility (#22) - Adding _set_gradient_checkpointing for compatibility (a30a931294ac0f344a0c1547877c692ceb17123c) Co-authored-by: Vicente Rivera --- modeling_mixformer_sequential.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py index 7d4f722..01d2c97 100644 --- a/modeling_mixformer_sequential.py +++ b/modeling_mixformer_sequential.py @@ -711,6 +711,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel): "past_key_values": past_key_values, "attention_mask": attention_mask, } + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MixFormerSequentialPreTrainedModel): + module.gradient_checkpointing = value class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):