Fixes flash-attn import with a try/except statement
This commit is contained in:
parent
0bbd68a176
commit
0254d42a95
@ -32,7 +32,6 @@
|
|||||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import importlib
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
|||||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
||||||
|
|
||||||
|
|
||||||
def _is_flash_attn_available() -> bool:
|
try:
|
||||||
return importlib.util.find_spec("flash_attn") is not None
|
|
||||||
|
|
||||||
|
|
||||||
if _is_flash_attn_available():
|
|
||||||
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
||||||
from flash_attn.ops.fused_dense import FusedDense
|
from flash_attn.ops.fused_dense import FusedDense
|
||||||
else:
|
except:
|
||||||
FlashRotaryEmbedding = None
|
FlashRotaryEmbedding = None
|
||||||
FusedDense = None
|
FusedDense = None
|
||||||
|
|
||||||
@ -549,9 +544,6 @@ class MHA(nn.Module):
|
|||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
softmax_scale: Optional[float] = None,
|
softmax_scale: Optional[float] = None,
|
||||||
dropout: float = 0.0,
|
|
||||||
flash_rotary: bool = True,
|
|
||||||
fused_dense: bool = True,
|
|
||||||
layer_idx: Optional[int] = None,
|
layer_idx: Optional[int] = None,
|
||||||
return_residual: bool = False,
|
return_residual: bool = False,
|
||||||
checkpointing: bool = False,
|
checkpointing: bool = False,
|
||||||
@ -565,7 +557,7 @@ class MHA(nn.Module):
|
|||||||
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
||||||
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
||||||
|
|
||||||
rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
|
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
||||||
if rotary_cls is None:
|
if rotary_cls is None:
|
||||||
rotary_cls = RotaryEmbedding
|
rotary_cls = RotaryEmbedding
|
||||||
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
||||||
@ -575,7 +567,7 @@ class MHA(nn.Module):
|
|||||||
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
||||||
hidden_size = config.n_embd
|
hidden_size = config.n_embd
|
||||||
|
|
||||||
linear_cls = FusedDense if fused_dense else nn.Linear
|
linear_cls = FusedDense if config.fused_dense else nn.Linear
|
||||||
if linear_cls is None:
|
if linear_cls is None:
|
||||||
linear_cls = nn.Linear
|
linear_cls = nn.Linear
|
||||||
|
|
||||||
@ -583,8 +575,8 @@ class MHA(nn.Module):
|
|||||||
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
# Attention
|
# Attention
|
||||||
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
||||||
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop)
|
||||||
|
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.return_residual = return_residual
|
self.return_residual = return_residual
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user