diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py
index d8ab760..716cd1a 100644
--- a/modeling_mixformer_sequential.py
+++ b/modeling_mixformer_sequential.py
@@ -32,6 +32,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from __future__ import annotations
 
+import importlib
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -48,6 +49,18 @@
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 from .configuration_mixformer_sequential import MixFormerSequentialConfig
 
+def _is_flash_attn_available() -> bool:
+    return importlib.util.find_spec("flash_attn") is not None
+
+
+if _is_flash_attn_available():
+    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+    from flash_attn.ops.fused_dense import FusedDense
+else:
+    FlashRotaryEmbedding = None
+    FusedDense = None
+
+
 @dataclass
 class InferenceParams:
     """Inference parameters passed to model to efficiently calculate
@@ -213,6 +226,7 @@ class RotaryEmbedding(nn.Module):
         dim: int,
         base: int = 10000,
         scale_base: Optional[float] = None,
+        pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
         **kwargs,
     ) -> None:
@@ -221,15 +235,17 @@
         if scale_base is not None:
             raise NotImplementedError
 
-        # Generate and save the inverse frequency buffer (non-trainable)
         self.dim = dim
-        self.base = base
+        self.base = float(base)
         self.scale_base = scale_base
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
         self.device = device
 
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+        # Generate and save the inverse frequency buffer (non-trainable)
+        inv_freq = self._compute_inv_freq(device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
+        # Generate and save the scale buffer (non-trainable)
         scale = (
             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
             if scale_base is not None
@@ -243,23 +259,37 @@
         self._cos_k_cached = None
         self._sin_k_cached = None
 
+    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
+        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+
     def _update_cos_sin_cache(
         self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
    ) -> None:
-        # Re-generate the inverse frequency buffer if it's not fp32
-        # (for instance if model.half() was called)
-        if self.inv_freq.dtype != "torch.float32":
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
-            )
-
-        if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
+        # Reset the tables if the sequence length has been changed, if we are on a
+        # new device or if we are switching from inference mode to training
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
             self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=device, dtype=torch.float32)
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
+            # fp32 is preferred since the output of `torch.arange` can be quite large
+            # and bf16 would lose a lot of precision
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+
+            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+            freqs = torch.outer(t, inv_freq)
             if self.scale is None:
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
@@ -269,7 +299,7 @@
                 ) / self.scale_base
                 scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-                # We want the multiplication by scale to happen in fp32
+                # Force the scale multiplication to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                 self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                 self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
@@ -520,6 +550,8 @@ class MHA(nn.Module):
         causal: bool = True,
         softmax_scale: Optional[float] = None,
         dropout: float = 0.0,
+        flash_rotary: bool = True,
+        fused_dense: bool = True,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
         checkpointing: bool = False,
@@ -532,15 +564,23 @@
             rotary_kwargs = {"device": device}
             if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                 rotary_kwargs["scale_base"] = rotary_emb_scale_base
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
+
+            rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
+            if rotary_cls is None:
+                rotary_cls = RotaryEmbedding
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd
 
-        self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
-        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
+        linear_cls = FusedDense if fused_dense else nn.Linear
+        if linear_cls is None:
+            linear_cls = nn.Linear
+
+        self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
+        self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
         self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
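
A note on the block removed from `_update_cos_sin_cache`: the old guard compared a `torch.dtype` against the string literal `"torch.float32"`. A dtype never compares equal to a string, so the condition was always true and `inv_freq` was rebuilt on every forward call. A minimal standalone sketch of the mismatch (illustrative only, not part of the patch):

```python
import torch

x = torch.zeros(4, dtype=torch.float32)

# Old guard: dtype-vs-string comparison never matches, so the buffer was
# regenerated unconditionally.
print(x.dtype != "torch.float32")  # True, even though x is fp32

# The patched code compares against the dtype object instead.
print(x.dtype != torch.float32)  # False
```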
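The new `pos_idx_in_fp32` flag defaults to `True` because of the precision issue called out in the added comment: bf16 carries only 8 mantissa bits, so integer position indices above 256 cannot all be represented exactly. A quick sketch of the effect, assuming plain PyTorch:

```python
import torch

# Position indices in fp32 versus bf16: bf16 rounds many of the larger
# integers, which would distort the rotary angles at long sequence lengths.
t_fp32 = torch.arange(4096, dtype=torch.float32)
t_bf16 = torch.arange(4096, dtype=torch.bfloat16).to(torch.float32)

num_rounded = (t_fp32 != t_bf16).sum().item()
print(f"{num_rounded} of 4096 position indices are inexact in bf16")
```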
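The `flash_rotary`/`fused_dense` flags pair with the guarded import at the top of the file: when `flash_attn` is missing, `FlashRotaryEmbedding` and `FusedDense` are `None`, and the `is None` checks in `MHA.__init__` silently fall back to the PyTorch implementations. A condensed sketch of that selection logic; `resolve_linear_cls` is a hypothetical helper, and `FusedDense` is stubbed to `None` to mimic an environment without `flash_attn`:

```python
import torch.nn as nn

FusedDense = None  # stand-in: the guarded import found no flash_attn


def resolve_linear_cls(fused_dense: bool = True) -> type:
    # Mirrors the two-step selection the patch adds to `MHA.__init__`.
    linear_cls = FusedDense if fused_dense else nn.Linear
    if linear_cls is None:  # fused kernel requested but unavailable
        linear_cls = nn.Linear
    return linear_cls


assert resolve_linear_cls(fused_dense=True) is nn.Linear   # graceful fallback
assert resolve_linear_cls(fused_dense=False) is nn.Linear  # explicit opt-out
```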