Fixes flash-attn import with a try/except statement

This commit is contained in:
Gustavo de Rosa 2023-11-01 23:32:35 +00:00 committed by huggingface-web
parent 0bbd68a176
commit 0254d42a95

@ -32,7 +32,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
import importlib
import math
from typing import Any, Dict, Optional, Tuple, Union
@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mixformer_sequential import MixFormerSequentialConfig
def _is_flash_attn_available() -> bool:
return importlib.util.find_spec("flash_attn") is not None
if _is_flash_attn_available():
try:
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from flash_attn.ops.fused_dense import FusedDense
else:
except:
FlashRotaryEmbedding = None
FusedDense = None
@ -549,9 +544,6 @@ class MHA(nn.Module):
bias: bool = True,
causal: bool = True,
softmax_scale: Optional[float] = None,
dropout: float = 0.0,
flash_rotary: bool = True,
fused_dense: bool = True,
layer_idx: Optional[int] = None,
return_residual: bool = False,
checkpointing: bool = False,
@ -565,7 +557,7 @@ class MHA(nn.Module):
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
rotary_kwargs["scale_base"] = rotary_emb_scale_base
rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
if rotary_cls is None:
rotary_cls = RotaryEmbedding
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
@ -575,7 +567,7 @@ class MHA(nn.Module):
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
hidden_size = config.n_embd
linear_cls = FusedDense if fused_dense else nn.Linear
linear_cls = FusedDense if config.fused_dense else nn.Linear
if linear_cls is None:
linear_cls = nn.Linear
@ -583,8 +575,8 @@ class MHA(nn.Module):
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
# Attention
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
self.layer_idx = layer_idx
self.return_residual = return_residual