From 426ea900b06746a7c907c5594fe85f16f5e1f3f8 Mon Sep 17 00:00:00 2001 From: Gustavo de Rosa Date: Mon, 15 Jan 2024 14:26:10 +0000 Subject: [PATCH] Update modeling_phi.py --- modeling_phi.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modeling_phi.py b/modeling_phi.py index 50b6fd3..5fad744 100644 --- a/modeling_phi.py +++ b/modeling_phi.py @@ -47,10 +47,13 @@ from transformers.utils import ( from .configuration_phi import PhiConfig -try: - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -except: +try: # noqa: SIM105 + if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +except ImportError: + # Workaround for https://github.com/huggingface/transformers/issues/28459, + # don't move to contextlib.suppress(ImportError) pass