Fixes flash-attn import with a try/except statement

Gustavo de Rosa 2023-11-01 23:32:35 +00:00 committed by huggingface-web
parent 0bbd68a176
commit 0254d42a95

@@ -32,7 +32,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from __future__ import annotations
 
-import importlib
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from .configuration_mixformer_sequential import MixFormerSequentialConfig
 
-def _is_flash_attn_available() -> bool:
-    return importlib.util.find_spec("flash_attn") is not None
-
-
-if _is_flash_attn_available():
+try:
     from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
     from flash_attn.ops.fused_dense import FusedDense
-else:
+except:
     FlashRotaryEmbedding = None
     FusedDense = None
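For reference, the optional-dependency import pattern introduced by this hunk can be written in isolation as the sketch below; it narrows the handler to ImportError rather than the bare except used in the diff, which keeps unrelated failures visible.

# Minimal sketch of the optional flash-attn import pattern.
try:
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    # flash-attn is not installed: downstream code falls back to the stock layers.
    FlashRotaryEmbedding = None
    FusedDense = None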
@@ -549,9 +544,6 @@ class MHA(nn.Module):
         bias: bool = True,
         causal: bool = True,
         softmax_scale: Optional[float] = None,
-        dropout: float = 0.0,
-        flash_rotary: bool = True,
-        fused_dense: bool = True,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
         checkpointing: bool = False,
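The three removed keyword arguments are not lost: the hunks below read the same settings from the shared config object instead (config.flash_rotary, config.fused_dense, config.attn_pdrop). A minimal, hypothetical sketch of that constructor-to-config move follows; the class names are stand-ins, not the real MixFormer classes.

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class StandInConfig:
    # Hypothetical stand-in for MixFormerSequentialConfig; only the fields
    # touched by this commit are shown.
    n_embd: int = 2048
    flash_rotary: bool = False
    fused_dense: bool = False
    attn_pdrop: float = 0.0


class StandInMHA(nn.Module):
    # The per-layer keyword arguments (dropout, flash_rotary, fused_dense) are
    # replaced by attributes read once from the config object.
    def __init__(self, config: StandInConfig) -> None:
        super().__init__()
        self.use_flash_rotary = config.flash_rotary
        self.use_fused_dense = config.fused_dense
        self.attn_dropout_p = config.attn_pdrop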
@@ -565,7 +557,7 @@
             if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                 rotary_kwargs["scale_base"] = rotary_emb_scale_base
 
-            rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
+            rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
             if rotary_cls is None:
                 rotary_cls = RotaryEmbedding
             self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
@@ -575,7 +567,7 @@
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd
 
-        linear_cls = FusedDense if fused_dense else nn.Linear
+        linear_cls = FusedDense if config.fused_dense else nn.Linear
         if linear_cls is None:
             linear_cls = nn.Linear
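Both class-selection hunks rely on the same fallback: the config flag picks the flash-attn class, and the "is None" check covers the case where the try/except import at the top failed and left FlashRotaryEmbedding/FusedDense set to None. A standalone sketch of that selection logic (the function name is illustrative):

import torch.nn as nn

FusedDense = None  # what a failed optional import of flash-attn leaves behind


def pick_linear_cls(fused_dense_enabled: bool) -> type:
    # Prefer the fused flash-attn layer when requested, but degrade to nn.Linear
    # whenever flash-attn is unavailable (FusedDense is None in that case).
    linear_cls = FusedDense if fused_dense_enabled else nn.Linear
    if linear_cls is None:
        linear_cls = nn.Linear
    return linear_cls


print(pick_linear_cls(True))  # falls back to torch.nn.Linear when flash-attn is missing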
@@ -583,8 +575,8 @@
         self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
-        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
-        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
+        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
+        self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
 
         self.layer_idx = layer_idx
         self.return_residual = return_residual
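Taken together, the attention options are now driven entirely by the config. A hedged usage sketch, assuming the repository id below is a placeholder and that these config attributes exist as the attribute accesses in this diff imply:

from transformers import AutoConfig, AutoModelForCausalLM

model_id = "org/mixformer-checkpoint"  # hypothetical repository id

# trust_remote_code is required because the modeling code lives in the repo itself.
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
config.attn_pdrop = 0.1      # attention dropout, formerly the `dropout` kwarg
config.flash_rotary = False  # formerly the `flash_rotary` kwarg
config.fused_dense = False   # formerly the `fused_dense` kwarg

model = AutoModelForCausalLM.from_pretrained(model_id, config=config, trust_remote_code=True)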