Adds support for flash-attn rotary embedding and fused dense layers.

This commit is contained in:
Gustavo de Rosa 2023-11-01 20:40:12 +00:00 committed by huggingface-web
parent de35f900d3
commit 0bbd68a176

@ -32,6 +32,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
import importlib
import math
from typing import Any, Dict, Optional, Tuple, Union
@ -48,6 +49,18 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mixformer_sequential import MixFormerSequentialConfig
def _is_flash_attn_available() -> bool:
return importlib.util.find_spec("flash_attn") is not None
if _is_flash_attn_available():
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from flash_attn.ops.fused_dense import FusedDense
else:
FlashRotaryEmbedding = None
FusedDense = None
@dataclass
class InferenceParams:
"""Inference parameters passed to model to efficiently calculate
@ -213,6 +226,7 @@ class RotaryEmbedding(nn.Module):
dim: int,
base: int = 10000,
scale_base: Optional[float] = None,
pos_idx_in_fp32: bool = True,
device: Optional[str] = None,
**kwargs,
) -> None:
@ -221,15 +235,17 @@ class RotaryEmbedding(nn.Module):
if scale_base is not None:
raise NotImplementedError
# Generate and save the inverse frequency buffer (non-trainable)
self.dim = dim
self.base = base
self.base = float(base)
self.scale_base = scale_base
self.pos_idx_in_fp32 = pos_idx_in_fp32
self.device = device
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
# Generate and save the inverse frequency buffer (non-trainable)
inv_freq = self._compute_inv_freq(device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Generate and save the scale buffer (non-trainable)
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
@ -243,23 +259,37 @@ class RotaryEmbedding(nn.Module):
self._cos_k_cached = None
self._sin_k_cached = None
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
def _update_cos_sin_cache(
self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
) -> None:
# Re-generate the inverse frequency buffer if it's not fp32
# (for instance if model.half() was called)
if self.inv_freq.dtype != "torch.float32":
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
# Reset the tables if sequence length has been chaned, if we are on a
# new device or if we are switching from inference mode to training
if (
seqlen > self._seq_len_cached
or self._cos_cached is None
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
# fp32 is preferred since the output of `torch.arange` can be quite large
# and bf16 would lose a lot of precision
if self.pos_idx_in_fp32:
t = torch.arange(seqlen, device=device, dtype=torch.float32)
if self.inv_freq.dtype != torch.float32:
inv_freq = self._compute_inv_freq(device=device)
else:
inv_freq = self.inv_freq
else:
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
inv_freq = self.inv_freq
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
freqs = torch.outer(t, inv_freq)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
@ -269,7 +299,7 @@ class RotaryEmbedding(nn.Module):
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
# Force the scale multiplication to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
@ -520,6 +550,8 @@ class MHA(nn.Module):
causal: bool = True,
softmax_scale: Optional[float] = None,
dropout: float = 0.0,
flash_rotary: bool = True,
fused_dense: bool = True,
layer_idx: Optional[int] = None,
return_residual: bool = False,
checkpointing: bool = False,
@ -532,15 +564,23 @@ class MHA(nn.Module):
rotary_kwargs = {"device": device}
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
rotary_kwargs["scale_base"] = rotary_emb_scale_base
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
if rotary_cls is None:
rotary_cls = RotaryEmbedding
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
# MLP
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
hidden_size = config.n_embd
self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
linear_cls = FusedDense if fused_dense else nn.Linear
if linear_cls is None:
linear_cls = nn.Linear
self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
# Attention
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)