diff --git a/modeling_phi.py b/modeling_phi.py index 50b6fd3..5fad744 100644 --- a/modeling_phi.py +++ b/modeling_phi.py @@ -47,10 +47,13 @@ from transformers.utils import ( from .configuration_phi import PhiConfig -try: - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -except: +try: # noqa: SIM105 + if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +except ImportError: + # Workaround for https://github.com/huggingface/transformers/issues/28459, + # don't move to contextlib.suppress(ImportError) pass