Adding _set_gradient_checkpointing for compatibility (#22)
- Adding _set_gradient_checkpointing for compatibility (a30a931294ac0f344a0c1547877c692ceb17123c) Co-authored-by: Vicente Rivera <vriveras@users.noreply.huggingface.co>
This commit is contained in:
parent
b6a7e2fe15
commit
8091327f9e
@ -712,6 +712,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
    """Set the ``gradient_checkpointing`` attribute on ``module``.

    The attribute is only written when ``module`` is an instance of
    ``MixFormerSequentialPreTrainedModel``; any other object is left
    untouched.

    Args:
        module: Object to update.
        value: Value stored in ``module.gradient_checkpointing``.
            Defaults to ``False``.
    """
    # Guard clause: ignore anything that is not one of this model family.
    if not isinstance(module, MixFormerSequentialPreTrainedModel):
        return
    module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
|
||||||
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
||||||
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
|
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user