Fixes flash-attn import with a try/except statement
This commit is contained in:
parent
0bbd68a176
commit
0254d42a95
@ -32,7 +32,6 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
||||
|
||||
|
||||
def _is_flash_attn_available() -> bool:
|
||||
return importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
|
||||
if _is_flash_attn_available():
|
||||
try:
|
||||
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
else:
|
||||
except:
|
||||
FlashRotaryEmbedding = None
|
||||
FusedDense = None
|
||||
|
||||
@ -549,9 +544,6 @@ class MHA(nn.Module):
|
||||
bias: bool = True,
|
||||
causal: bool = True,
|
||||
softmax_scale: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
flash_rotary: bool = True,
|
||||
fused_dense: bool = True,
|
||||
layer_idx: Optional[int] = None,
|
||||
return_residual: bool = False,
|
||||
checkpointing: bool = False,
|
||||
@ -565,7 +557,7 @@ class MHA(nn.Module):
|
||||
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
||||
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
||||
|
||||
rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
|
||||
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
||||
if rotary_cls is None:
|
||||
rotary_cls = RotaryEmbedding
|
||||
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
||||
@ -575,7 +567,7 @@ class MHA(nn.Module):
|
||||
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
||||
hidden_size = config.n_embd
|
||||
|
||||
linear_cls = FusedDense if fused_dense else nn.Linear
|
||||
linear_cls = FusedDense if config.fused_dense else nn.Linear
|
||||
if linear_cls is None:
|
||||
linear_cls = nn.Linear
|
||||
|
||||
@ -583,8 +575,8 @@ class MHA(nn.Module):
|
||||
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
# Attention
|
||||
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
||||
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
||||
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
||||
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.return_residual = return_residual
|
||||
|
||||
Loading…
Reference in New Issue
Block a user