diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py index 716cd1a..b4efc53 100644 --- a/modeling_mixformer_sequential.py +++ b/modeling_mixformer_sequential.py @@ -32,7 +32,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from __future__ import annotations -import importlib import math from typing import Any, Dict, Optional, Tuple, Union @@ -49,14 +48,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast from .configuration_mixformer_sequential import MixFormerSequentialConfig -def _is_flash_attn_available() -> bool: - return importlib.util.find_spec("flash_attn") is not None - - -if _is_flash_attn_available(): +try: from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding from flash_attn.ops.fused_dense import FusedDense -else: +except: FlashRotaryEmbedding = None FusedDense = None @@ -549,9 +544,6 @@ class MHA(nn.Module): bias: bool = True, causal: bool = True, softmax_scale: Optional[float] = None, - dropout: float = 0.0, - flash_rotary: bool = True, - fused_dense: bool = True, layer_idx: Optional[int] = None, return_residual: bool = False, checkpointing: bool = False, @@ -565,7 +557,7 @@ class MHA(nn.Module): if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0: rotary_kwargs["scale_base"] = rotary_emb_scale_base - rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding + rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding if rotary_cls is None: rotary_cls = RotaryEmbedding self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs) @@ -575,7 +567,7 @@ class MHA(nn.Module): op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv) hidden_size = config.n_embd - linear_cls = FusedDense if fused_dense else nn.Linear + linear_cls = FusedDense if config.fused_dense else nn.Linear if linear_cls is None: linear_cls = nn.Linear @@ -583,8 +575,8 @@ class MHA(nn.Module): self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype) # Attention - self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=config.attn_pdrop) self.layer_idx = layer_idx self.return_residual = return_residual