Adding _set_gradient_checkpointing for compatibility (#22)

- Adding _set_gradient_checkpointing for compatibility (a30a931294ac0f344a0c1547877c692ceb17123c)


Co-authored-by: Vicente Rivera <vriveras@users.noreply.huggingface.co>
This commit is contained in:
Gustavo de Rosa 2023-10-17 12:11:30 +00:00 committed by system
parent b6a7e2fe15
commit 8091327f9e

@ -711,6 +711,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
"past_key_values": past_key_values,
"attention_mask": attention_mask,
}
def _set_gradient_checkpointing(self, module, value=False):
    """Toggle gradient checkpointing on *module*.

    Standard transformers hook called by ``PreTrainedModel`` when
    enabling/disabling gradient checkpointing. Only instances of
    ``MixFormerSequentialPreTrainedModel`` carry the flag; every other
    module is left untouched.
    """
    if not isinstance(module, MixFormerSequentialPreTrainedModel):
        return
    module.gradient_checkpointing = value
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):