Adds support for flash-attn rotary embedding and fused dense layers.
This commit is contained in:
parent
de35f900d3
commit
0bbd68a176
@ -32,6 +32,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
@ -48,6 +49,18 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
||||
|
||||
|
||||
def _is_flash_attn_available() -> bool:
|
||||
return importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
|
||||
if _is_flash_attn_available():
|
||||
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
else:
|
||||
FlashRotaryEmbedding = None
|
||||
FusedDense = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceParams:
|
||||
"""Inference parameters passed to model to efficiently calculate
|
||||
@ -213,6 +226,7 @@ class RotaryEmbedding(nn.Module):
|
||||
dim: int,
|
||||
base: int = 10000,
|
||||
scale_base: Optional[float] = None,
|
||||
pos_idx_in_fp32: bool = True,
|
||||
device: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -221,15 +235,17 @@ class RotaryEmbedding(nn.Module):
|
||||
if scale_base is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Generate and save the inverse frequency buffer (non-trainable)
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
self.base = float(base)
|
||||
self.scale_base = scale_base
|
||||
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
||||
self.device = device
|
||||
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
||||
# Generate and save the inverse frequency buffer (non-trainable)
|
||||
inv_freq = self._compute_inv_freq(device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
# Generate and save the scale buffer (non-trainable)
|
||||
scale = (
|
||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
||||
if scale_base is not None
|
||||
@ -243,23 +259,37 @@ class RotaryEmbedding(nn.Module):
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
||||
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
||||
|
||||
def _update_cos_sin_cache(
|
||||
self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
|
||||
) -> None:
|
||||
# Re-generate the inverse frequency buffer if it's not fp32
|
||||
# (for instance if model.half() was called)
|
||||
if self.inv_freq.dtype != "torch.float32":
|
||||
self.inv_freq = 1.0 / (
|
||||
self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
|
||||
)
|
||||
|
||||
if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
|
||||
# Reset the tables if sequence length has been chaned, if we are on a
|
||||
# new device or if we are switching from inference mode to training
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
|
||||
# Don't do einsum, it converts fp32 to fp16
|
||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
|
||||
# fp32 is preferred since the output of `torch.arange` can be quite large
|
||||
# and bf16 would lose a lot of precision
|
||||
if self.pos_idx_in_fp32:
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
if self.inv_freq.dtype != torch.float32:
|
||||
inv_freq = self._compute_inv_freq(device=device)
|
||||
else:
|
||||
inv_freq = self.inv_freq
|
||||
else:
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
inv_freq = self.inv_freq
|
||||
|
||||
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
@ -269,7 +299,7 @@ class RotaryEmbedding(nn.Module):
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
||||
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
# Force the scale multiplication to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||
@ -520,6 +550,8 @@ class MHA(nn.Module):
|
||||
causal: bool = True,
|
||||
softmax_scale: Optional[float] = None,
|
||||
dropout: float = 0.0,
|
||||
flash_rotary: bool = True,
|
||||
fused_dense: bool = True,
|
||||
layer_idx: Optional[int] = None,
|
||||
return_residual: bool = False,
|
||||
checkpointing: bool = False,
|
||||
@ -532,15 +564,23 @@ class MHA(nn.Module):
|
||||
rotary_kwargs = {"device": device}
|
||||
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
|
||||
rotary_kwargs["scale_base"] = rotary_emb_scale_base
|
||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
|
||||
|
||||
rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
|
||||
if rotary_cls is None:
|
||||
rotary_cls = RotaryEmbedding
|
||||
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
||||
|
||||
# MLP
|
||||
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
|
||||
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
||||
hidden_size = config.n_embd
|
||||
|
||||
self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
|
||||
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
||||
linear_cls = FusedDense if fused_dense else nn.Linear
|
||||
if linear_cls is None:
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
|
||||
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
# Attention
|
||||
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user