Adding _set_gradient_checkpointing for compatibility

This commit is contained in:
Vicente Rivera 2023-09-18 22:17:40 -07:00
parent b6a7e2fe15
commit a30a931294

@ -711,6 +711,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
"past_key_values": past_key_values,
"attention_mask": attention_mask,
}
def _set_gradient_checkpointing(self, module, value=False):
    """Toggle gradient checkpointing on *module*.

    Legacy-style hook invoked by `PreTrainedModel.gradient_checkpointing_enable`
    / `..._disable`, which walk the module tree and call this per submodule.

    Args:
        module: Candidate submodule; only instances of
            `MixFormerSequentialPreTrainedModel` are affected.
        value (bool): New state for `gradient_checkpointing`. Defaults to False.
    """
    # Only the model classes that actually consult the flag get it set;
    # every other submodule is left untouched.
    targets_flag = isinstance(module, MixFormerSequentialPreTrainedModel)
    if targets_flag:
        module.gradient_checkpointing = value
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):