Adds support for MQA/GQA and attention mask during training.

Gustavo de Rosa 2023-10-30 16:59:12 +00:00 committed by huggingface-web
parent d38e6f954e
commit de35f900d3
3 changed files with 262 additions and 208 deletions
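For context, multi-query (MQA) and grouped-query attention (GQA) keep fewer key/value heads than query heads and share each KV head across a group of query heads, which shrinks the KV cache and the QKV projection. A minimal, self-contained sketch of the idea (illustrative only, not code from this commit; all sizes are made up):

```python
import torch
from einops import repeat

batch, seqlen, head_dim = 2, 8, 64
n_head, n_head_kv = 16, 4                                  # 4 query heads share each KV head (GQA); n_head_kv=1 would be MQA

q = torch.randn(batch, seqlen, n_head, head_dim)
kv = torch.randn(batch, seqlen, 2, n_head_kv, head_dim)    # packed keys and values

# Expand the KV heads so every query head has a matching key/value head,
# mirroring the `repeat` call this commit adds to CrossAttention.
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=n_head // n_head_kv)
k, v = kv.unbind(dim=2)

scores = torch.einsum("bthd,bshd->bhts", q, k / head_dim**0.5)
out = torch.einsum("bhts,bshd->bthd", torch.softmax(scores, dim=-1), v)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```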

@@ -127,7 +127,7 @@ with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
```
**Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
-Furthermore, in the forward pass of the model, we currently do not support attention mask during training, outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
+Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
### Citation

@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import math
-from typing import Any, Dict, List, Optional, Union
+from typing import Optional
from transformers import PretrainedConfig
@@ -27,6 +27,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
        n_layer: Optional[int] = 20,
        n_inner: Optional[int] = None,
        n_head: Optional[int] = 16,
+       n_head_kv: Optional[int] = None,
        rotary_dim: Optional[int] = 32,
        activation_function: Optional[str] = "gelu_new",
        embd_pdrop: Optional[float] = 0.0,
@@ -43,6 +44,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
        self.n_layer = n_layer
        self.n_inner = n_inner
        self.n_head = n_head
+       self.n_head_kv = n_head_kv
        self.rotary_dim = min(rotary_dim, n_embd // n_head)
        self.activation_function = activation_function
        self.embd_pdrop = embd_pdrop
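A hedged usage sketch of the new field: a checkpoint's configuration can set `n_head_kv` to a divisor of `n_head` for GQA (or to 1 for MQA), while leaving it unset keeps standard multi-head attention. The sizes below are made up for illustration and assume the configuration file is importable locally:

```python
from configuration_mixformer_sequential import MixFormerSequentialConfig

# Standard MHA: n_head_kv stays None, so KV heads == query heads.
mha_config = MixFormerSequentialConfig(n_embd=2048, n_layer=24, n_head=32)

# GQA: 32 query heads grouped over 8 KV heads (n_head must be divisible by n_head_kv).
gqa_config = MixFormerSequentialConfig(n_embd=2048, n_layer=24, n_head=32, n_head_kv=8)

# MQA: a single KV head shared by all 32 query heads.
mqa_config = MixFormerSequentialConfig(n_embd=2048, n_layer=24, n_head=32, n_head_kv=1)
```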

@@ -34,20 +34,20 @@
from __future__ import annotations
import math
-import copy
from typing import Any, Dict, Optional, Tuple, Union
from dataclasses import dataclass, field
import torch
import torch.nn as nn
-from einops import rearrange
+from einops import rearrange, repeat
from transformers.activations import ACT2FN
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mixformer_sequential import MixFormerSequentialConfig

@dataclass
class InferenceParams:
    """Inference parameters passed to model to efficiently calculate
@@ -57,21 +57,20 @@ class InferenceParams:
        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.

    Args:
-       max_sequence_len: Maximum sequence length.
+       max_seqlen: Maximum sequence length.
        max_batch_size: Maximum batch size.
-       sequence_len_offset: Sequence length offset.
+       seqlen_offset: Sequence length offset.
        batch_size_offset: Batch size offset.
        key_value_memory_dict: Key value memory dictionary.
-       fused_ft_kernel: Whether to use fused kernel for fast inference.
        lengths_per_sample: Lengths per sample.
    """

-   max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
+   max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
    max_batch_size: int = field(metadata={"help": "Maximum batch size."})
-   sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
+   seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
    batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
@@ -79,8 +78,6 @@ class InferenceParams:
        default_factory=dict, metadata={"help": "Key value memory dictionary."}
    )
-   fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
    lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
@@ -103,11 +100,111 @@ class Embedding(nn.Module):
        return hidden_states

+def _apply_rotary_emb(
+    x: torch.FloatTensor,
+    cos: torch.FloatTensor,
+    sin: torch.FloatTensor,
+) -> torch.FloatTensor:
+    _, seqlen, _, head_dim = x.shape
+    rotary_seqlen, rotary_dim = cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= head_dim
+    assert seqlen <= rotary_seqlen
+    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
+    x_rot = x[:, :, :, :rotary_dim]
+    x_pass = x[:, :, :, rotary_dim:]
+    x1, x2 = x_rot.chunk(2, dim=-1)
+    c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+    x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
+    x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
+    return torch.cat([x_rot, x_pass], axis=-1)
+
+def _apply_rotary_emb_kv(
+    kv: torch.FloatTensor,
+    cos: torch.FloatTensor,
+    sin: torch.FloatTensor,
+    cos_k: Optional[torch.FloatTensor] = None,
+    sin_k: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+    _, seqlen, two, _, head_dim = kv.shape
+    assert two == 2
+    rotary_seqlen, rotary_dim = cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= head_dim
+    assert seqlen <= rotary_seqlen
+    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
+    k_rot = kv[:, :, 0, :, :rotary_dim]
+    k_pass = kv[:, :, 0, :, rotary_dim:]
+    k1, k2 = k_rot.chunk(2, dim=-1)
+    c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+    k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
+    k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
+    return torch.cat(
+        [
+            torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
+            kv[:, :, 1:2, :, :],
+        ],
+        axis=2,
+    )
+
+def _apply_rotary_emb_qkv(
+    qkv: torch.FloatTensor,
+    cos: torch.FloatTensor,
+    sin: torch.FloatTensor,
+    cos_k: Optional[torch.FloatTensor] = None,
+    sin_k: Optional[torch.FloatTensor] = None,
+) -> torch.FloatTensor:
+    _, seqlen, three, _, head_dim = qkv.shape
+    assert three == 3
+    rotary_seqlen, rotary_dim = cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= head_dim
+    assert seqlen <= rotary_seqlen
+    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
+    q_rot = qkv[:, :, 0, :, :rotary_dim]
+    q_pass = qkv[:, :, 0, :, rotary_dim:]
+    k_rot = qkv[:, :, 1, :, :rotary_dim]
+    k_pass = qkv[:, :, 1, :, rotary_dim:]
+    q1, q2 = q_rot.chunk(2, dim=-1)
+    k1, k2 = k_rot.chunk(2, dim=-1)
+    c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
+    q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
+    q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
+    k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
+    return torch.cat(
+        [
+            torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
+            torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
+            qkv[:, :, 2:3, :, :],
+        ],
+        axis=2,
+    )
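The three helpers added above apply the same pairwise RoPE rotation; hoisting them out of the class lets queries and the packed key/value tensor be rotated independently, which matters once `q` and `kv` no longer share a head count. A tiny standalone sketch of the underlying rotation (made-up sizes, not the repository's API):

```python
import torch

batch, seqlen, n_head, head_dim, rotary_dim = 1, 4, 2, 8, 4

x = torch.randn(batch, seqlen, n_head, head_dim)
inv_freq = 1.0 / 10000 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
cos, sin = torch.cos(angles), torch.sin(angles)          # each (seqlen, rotary_dim // 2)

# Rotate the first rotary_dim channels of every head by a position-dependent angle;
# the remaining channels pass through unchanged.
x1, x2 = x[..., : rotary_dim // 2], x[..., rotary_dim // 2 : rotary_dim]
c, s = cos[:, None, :], sin[:, None, :]                  # broadcast over heads
rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
out = torch.cat([rotated, x[..., rotary_dim:]], dim=-1)
print(out.shape)  # torch.Size([1, 4, 2, 8])
```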
class RotaryEmbedding(nn.Module):
-   """Rotary embeddings.
+   """Rotary positional embedding (RoPE).
    Reference:
-       https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
+       RoFormer: Enhanced Transformer with Rotary Position Embedding.
+       https://arxiv.org/pdf/2104.09864.pdf.
    """
@@ -131,14 +228,14 @@ class RotaryEmbedding(nn.Module):
        self.device = device

        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
-       self.register_buffer("inv_freq", inv_freq)
+       self.register_buffer("inv_freq", inv_freq, persistent=False)

        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
-       self.register_buffer("scale", scale)
+       self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
@@ -146,28 +243,26 @@ class RotaryEmbedding(nn.Module):
        self._cos_k_cached = None
        self._sin_k_cached = None

-   def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
-       # Reset the tables if the sequence length has changed,
-       # or if we're on a new device (possibly due to tracing for instance)
-       seqlen = x.shape[1] + seqlen_offset
+   def _update_cos_sin_cache(
+       self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
+   ) -> None:
        # Re-generate the inverse frequency buffer if it's not fp32
        # (for instance if model.half() was called)
        if self.inv_freq.dtype != "torch.float32":
            self.inv_freq = 1.0 / (
-               self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
+               self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
            )

-       if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
+       if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
            self._seq_len_cached = seqlen
-           t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
+           t = torch.arange(seqlen, device=device, dtype=torch.float32)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))

            if self.scale is None:
-               self._cos_cached = torch.cos(freqs).to(x.dtype)
-               self._sin_cached = torch.sin(freqs).to(x.dtype)
+               self._cos_cached = torch.cos(freqs).to(dtype)
+               self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
@@ -175,62 +270,32 @@ class RotaryEmbedding(nn.Module):
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
-               self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
-               self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
-               self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
-               self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
+               self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+               self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+               self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+               self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

-   def _apply_rotary_emb_qkv(
-       self,
-       qkv: torch.FloatTensor,
-       sin: torch.FloatTensor,
-       cos: torch.FloatTensor,
-       sin_k: Optional[torch.FloatTensor] = None,
-       cos_k: Optional[torch.FloatTensor] = None,
-   ) -> torch.FloatTensor:
-       _, seqlen, three, _, headdim = qkv.shape
-       assert three == 3
-       rotary_seqlen, rotary_dim = cos.shape
-       rotary_dim *= 2
-       assert rotary_dim <= headdim
-       assert seqlen <= rotary_seqlen
-       cos_k = cos if cos_k is None else cos_k
-       sin_k = sin if sin_k is None else sin_k
-       assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-       q_rot = qkv[:, :, 0, :, :rotary_dim]
-       q_pass = qkv[:, :, 0, :, rotary_dim:]
-       k_rot = qkv[:, :, 1, :, :rotary_dim]
-       k_pass = qkv[:, :, 1, :, rotary_dim:]
-       # Splits the queries and keys in half
-       q1, q2 = q_rot.chunk(2, dim=-1)
-       k1, k2 = k_rot.chunk(2, dim=-1)
-       c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
-       # Casts to fp32 are necessary to prevent fp16 overflow issues
-       q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
-       # Computes the new keys and queries, recasting to original dtype
-       q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
-       k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
-       return torch.cat(
-           [
-               torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
-               torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
-               qkv[:, :, 2:3, :, :],
-           ],
-           axis=2,
-       )
-
-   def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
-       # `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
-       self._update_cos_sin_cache(qkv, seqlen_offset)
-       return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
+   def forward(
+       self,
+       qkv: torch.Tensor,
+       kv: Optional[torch.Tensor] = None,
+       seqlen_offset: int = 0,
+       max_seqlen: Optional[int] = None,
+   ) -> Tuple[torch.Tensor, torch.Tensor]:
+       seqlen = qkv.shape[1]
+
+       if max_seqlen is not None:
+           self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+       else:
+           self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+
+       if kv is None:
+           return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
+       else:
+           q = _apply_rotary_emb(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
+           kv = _apply_rotary_emb_kv(kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
+
+           return q, kv

class MLP(nn.Module):
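The reworked `forward` now has two call patterns: a packed `qkv` tensor (MHA path) comes back rotated as a single tensor, while a separate `q` plus packed `kv` with fewer heads (MQA/GQA path) comes back as a rotated pair. A shape-level sketch of how it appears to be called (assumes the modeling file is importable locally; sizes are made up):

```python
import torch
from modeling_mixformer_sequential import RotaryEmbedding  # assumption: file is on the path

batch, seqlen, n_head, n_head_kv, head_dim = 2, 16, 8, 2, 64
rotary = RotaryEmbedding(32)  # rotate only the first 32 channels of each head

# MHA path: one packed (q, k, v) tensor in, one rotated tensor out.
qkv = torch.randn(batch, seqlen, 3, n_head, head_dim)
qkv = rotary(qkv)

# MQA/GQA path: separate q and packed (k, v) with fewer heads; both come back rotated.
q = torch.randn(batch, seqlen, n_head, head_dim)
kv = torch.randn(batch, seqlen, 2, n_head_kv, head_dim)
q, kv = rotary(q, kv=kv, seqlen_offset=0)
```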
@@ -290,21 +355,22 @@ class SelfAttention(nn.Module):
        attention_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
-       causal = self.causal if causal is None else causal
-       batch_size, seq_len = qkv.shape[0], qkv.shape[1]
+       batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        q, k, v = qkv.unbind(dim=2)

+       causal = self.causal if causal is None else causal
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])

        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

        if attention_mask is not None:
-           padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
+           padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
            padding_mask.masked_fill_(attention_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

        if causal:
-           causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
+           causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            scores = scores + causal_mask.to(dtype=scores.dtype)

        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
@@ -343,25 +409,31 @@ class CrossAttention(nn.Module):
        attention_mask: Optional[torch.BoolTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
-       causal = self.causal if causal is None else causal
-       batch_size, seq_len_q = q.shape[0], q.shape[1]
-       assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
-       seq_len_k = kv.shape[1]
+       batch_size, seqlen_q = q.shape[0], q.shape[1]
+       seqlen_k = kv.shape[1]
+
+       assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
+       if kv.shape[3] != q.shape[2]:
+           kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)

+       causal = self.causal if causal is None else causal
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])

        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

        if attention_mask is not None:
-           padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
+           padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
            padding_mask.masked_fill_(attention_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

        if causal:
-           causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
-           scores = scores + causal_mask.to(dtype=scores.dtype)
+           rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
+           cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
+           causal_mask = cols > rows + seqlen_k - seqlen_q
+           scores = scores.masked_fill(causal_mask, -10000.0)

        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention = self.drop(attention)
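The new causal mask deserves a worked example: with a KV cache, the current query block sits at the end of the key sequence, so query `i` may attend to keys up to `i + (seqlen_k - seqlen_q)`. A small standalone sketch of the masking rule used above (made-up sizes):

```python
import torch

seqlen_q, seqlen_k = 3, 7   # e.g. 3 new tokens attending over 4 cached + 3 new keys

rows = torch.arange(seqlen_q, dtype=torch.long).unsqueeze(-1)   # query positions, shape (3, 1)
cols = torch.arange(seqlen_k, dtype=torch.long)                 # key positions, shape (7,)

# True marks entries to hide: key j is masked for query i whenever j lies in i's future.
causal_mask = cols > rows + seqlen_k - seqlen_q
print(causal_mask.int())
# tensor([[0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0, 0]])
```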
@@ -371,21 +443,12 @@ class CrossAttention(nn.Module):
        return output

-def find_mha_dims(
-    config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
+def _find_mha_dims(
+    config: PretrainedConfig,
+    n_head: Optional[int] = None,
+    n_head_kv: Optional[int] = None,
+    head_dim: Optional[int] = None,
) -> Tuple[int, int]:
-    """Validate and return the number of heads and head dimension for multi-head attention.
-
-    Args:
-        config: Model configuration.
-        n_head: Number of heads.
-        head_dim: Head dimension.
-
-    Returns:
-        Number of heads and head dimension.
-
-    """
    assert all(
        hasattr(config, attr) for attr in ["n_embd", "n_head"]
    ), "`config` must have `n_embd` and `n_head` attributes."
@@ -401,31 +464,20 @@ def find_mha_dims(
    elif n_head is None or head_dim is None:
        raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")

-   return n_head, head_dim
+   if n_head_kv is None:
+       n_head_kv = getattr(config, "n_head_kv", None) or n_head
+   assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
+
+   return n_head, n_head_kv, head_dim
-def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
-    """Update the key-value cache for inference.
-
-    Reference:
-        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
-
-    Args:
-        kv: Key-value tensor.
-        inference_params: Inference parameters.
-        layer_idx: Layer index.
-
-    Returns:
-        Updated key-value tensor.
-    """
+def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
    num_heads, head_dim = kv.shape[-2:]

    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
-           inference_params.max_sequence_len,
+           inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
@@ -434,43 +486,19 @@ def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
-       if not inference_params.fused_ft_kernel:
-           kv_cache = inference_params.key_value_memory_dict[layer_idx]
-       else:
-           k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-           kv_cache = None
+       kv_cache = inference_params.key_value_memory_dict[layer_idx]

    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
-   assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+   assert batch_end <= kv_cache.shape[0]

-   sequence_start = inference_params.sequence_len_offset
+   sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
-   assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
+   assert sequence_end <= kv_cache.shape[1]

-   if not inference_params.fused_ft_kernel:
-       assert kv_cache is not None
-       kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-       kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-       return kv
-
-   assert inference_params.sequence_len_offset == 0
-   assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
-   packsize = 4 if kv.dtype == torch.float32 else 8
-
-   if kv_cache is not None:
-       kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-       k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
-       v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
-       inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
-   else:
-       k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
-           kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
-       )
-       v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
+   assert kv_cache is not None
+   kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+   kv = kv_cache[batch_start:batch_end, :sequence_end, ...]

    return kv
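The simplified cache update does two things: write the new keys/values at `[seqlen_offset, seqlen_offset + new_len)` in a preallocated buffer, then return the whole cached prefix so attention sees every previous position. A standalone sketch of the same bookkeeping (illustrative, not the repository's function):

```python
import torch

max_batch, max_seqlen, n_head_kv, head_dim = 1, 16, 2, 4
kv_cache = torch.empty(max_batch, max_seqlen, 2, n_head_kv, head_dim)

def update(kv: torch.Tensor, seqlen_offset: int) -> torch.Tensor:
    # Write the new slice into the preallocated cache...
    sequence_end = seqlen_offset + kv.shape[1]
    kv_cache[: kv.shape[0], seqlen_offset:sequence_end, ...] = kv
    # ...and hand back everything cached so far.
    return kv_cache[: kv.shape[0], :sequence_end, ...]

prompt_kv = torch.randn(1, 5, 2, n_head_kv, head_dim)   # prefill: 5 prompt tokens
step_kv = torch.randn(1, 1, 2, n_head_kv, head_dim)     # decode: 1 new token

print(update(prompt_kv, seqlen_offset=0).shape)   # torch.Size([1, 5, 2, 2, 4])
print(update(step_kv, seqlen_offset=5).shape)     # torch.Size([1, 6, 2, 2, 4])
```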
@@ -486,6 +514,7 @@ class MHA(nn.Module):
        rotary_dim: Optional[int] = None,
        rotary_emb_scale_base: Optional[float] = None,
        n_head: Optional[int] = None,
+       n_head_kv: Optional[int] = None,
        head_dim: Optional[int] = None,
        bias: bool = True,
        causal: bool = True,
@@ -506,12 +535,12 @@ class MHA(nn.Module):
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)

        # MLP
-       self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
-       op_size = self.n_head * self.head_dim
+       self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
+       op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
        hidden_size = config.n_embd

-       self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
-       self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
+       self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
+       self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)

        # Attention
        self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
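A quick arithmetic check of the new projection size: `Wqkv` now emits `n_head` query heads plus `2 * n_head_kv` key/value heads, each of width `head_dim`, so its output is `head_dim * (n_head + 2 * n_head_kv)` rather than `3 * n_embd`. With made-up sizes:

```python
n_embd, n_head, n_head_kv = 2048, 32, 8
head_dim = n_embd // n_head                           # 64

op_size_mha = head_dim * (n_head + 2 * n_head)        # n_head_kv defaults to n_head -> 3 * n_embd = 6144
op_size_gqa = head_dim * (n_head + 2 * n_head_kv)     # 2048 + 2 * 8 * 64 = 3072

print(op_size_mha, op_size_gqa)  # 6144 3072
```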
@@ -521,40 +550,75 @@ class MHA(nn.Module):
        self.return_residual = return_residual
        self.checkpointing = checkpointing

+   def _forward_self_attn(
+       self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
+   ) -> torch.FloatTensor:
+       qkv = self.Wqkv(x)
+       qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+
+       if self.rotary_emb_dim > 0:
+           qkv = self.rotary_emb(qkv)
+
+       if self.checkpointing:
+           return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
+
+       return self.inner_attn(qkv, attention_mask=attention_mask)
+
+   def _forward_cross_attn(
+       self,
+       x: torch.FloatTensor,
+       past_key_values: Optional[InferenceParams],
+       attention_mask: Optional[torch.BoolTensor],
+   ) -> torch.FloatTensor:
+       qkv = self.Wqkv(x)
+
+       q = qkv[..., : self.n_head * self.head_dim]
+       q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+
+       kv = qkv[..., self.n_head * self.head_dim :]
+       kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+
+       seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
+       causal = None if seqlen_offset == 0 else False
+       if self.rotary_emb_dim > 0:
+           q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
+
+       if past_key_values is not None:
+           kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
+
+       if self.checkpointing:
+           return torch.utils.checkpoint.checkpoint(
+               self.inner_cross_attn, q, kv, attention_mask=attention_mask, causal=causal
+           )
+
+       return self.inner_cross_attn(q, kv, attention_mask=attention_mask, causal=causal)
+
    def forward(
        self,
        x: torch.FloatTensor,
        past_key_values: Optional[InferenceParams] = None,
-       attention_mask: Optional[torch.BoolTensor] = None,
-       cu_seqlens: Optional[torch.LongTensor] = None,
-       max_seqlen: Optional[int] = None,
+       attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-       qkv = self.Wqkv(x)
-       qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
-
-       seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
-       if self.rotary_emb_dim > 0:
-           qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
-
-       if past_key_values is not None:
-           kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
-
-       if attention_mask is not None:
-           attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
-           attention_mask = attention_mask.bool().to(qkv.device)
-
-       attention_kwargs = {"attention_mask": attention_mask}
-
-       if past_key_values is None or seqlen_offset == 0:
-           if self.checkpointing:
-               attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
-           else:
-               attn_output = self.inner_attn(qkv, **attention_kwargs)
-       else:
-           q = qkv[:, :, 0]
-           causal = None if past_key_values.sequence_len_offset == 0 else False
-           attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
+       if attention_mask is not None and torch.any(~attention_mask.bool()):
+           attention_mask = attention_mask.bool()
+       else:
+           attention_mask = None
+
+       # MHA
+       if self.n_head == self.n_head_kv:
+           if past_key_values is None:
+               # If `past_key_values` are not supplied, we run self-attention
+               attn_output = self._forward_self_attn(x, attention_mask)
+           else:
+               # If `past_key_values` are supplied, it means that we might have cached values and
+               # could take advantage of cross-attention
+               attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
+       # MQA / GQA
+       else:
+           # Regardless of whether `past_key_values` are supplied, it always uses cross-attention
+           # because `q` and `kv` lengths might be different
+           attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)

        output = rearrange(attn_output, "... h d -> ... (h d)")
        output = self.out_proj(output)
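One detail of the new `forward` worth spelling out: the incoming HF-style `attention_mask` (1 = keep, 0 = pad) is kept, as a bool tensor, only when it actually hides something; an all-ones mask is dropped to `None` so the unpadded fast path is used. A minimal standalone sketch of that normalization (made-up inputs):

```python
import torch

def normalize_mask(attention_mask):
    # Mirrors the check at the top of MHA.forward: keep the mask only if it masks anything.
    if attention_mask is not None and torch.any(~attention_mask.bool()):
        return attention_mask.bool()
    return None

print(normalize_mask(torch.tensor([[1, 1, 1, 1]])))   # None -> no padding logic needed
print(normalize_mask(torch.tensor([[1, 1, 1, 0]])))   # tensor([[ True,  True,  True, False]])
```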
@@ -672,38 +736,29 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
-           module.bias.data.zero_()
+           if module.bias is not None:
+               module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
-       attention_mask: Optional[torch.BoolTensor] = None,
+       attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
-       if attention_mask is not None and torch.any(~attention_mask.bool()):
-           total_seq_len = torch.sum(attention_mask, dim=1)
-           max_seq_len = torch.max(total_seq_len)
-
-           total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
-           cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
-           attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
-       else:
-           attention_mask = None
-
        if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
            past_key_values = InferenceParams(
+               max_seqlen=self.config.n_positions,
                max_batch_size=input_ids.shape[0],
-               max_sequence_len=self.config.n_positions,
-               sequence_len_offset=0,
+               seqlen_offset=0,
                batch_size_offset=0,
-               fused_ft_kernel=False,
                key_value_memory_dict={},
-               lengths_per_sample=None,
            )
        else:
            # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
-           past_key_values.sequence_len_offset = len(input_ids[0]) - 1
+           past_key_values.seqlen_offset = len(input_ids[0]) - 1
            input_ids = input_ids[:, -1].unsqueeze(-1)

        return {
@@ -712,9 +767,9 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
            "attention_mask": attention_mask,
        }

-   def _set_gradient_checkpointing(self, module, value=False):
+   def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
        if isinstance(module, MixFormerSequentialPreTrainedModel):
            module.gradient_checkpointing = value

class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
@@ -756,13 +811,10 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
-       if past_key_values is None and attention_mask is None:
-           lm_logits = self.layers(input_ids)
-       else:
-           hidden_layer = self.layers[0](input_ids)
-           for module in self.layers[1:-1]:
-               hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
-           lm_logits = self.layers[-1](hidden_layer)
+       hidden_layer = self.layers[0](input_ids)
+       for module in self.layers[1:-1]:
+           hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
+       lm_logits = self.layers[-1](hidden_layer)

        loss = None
        if labels is not None:
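Because the blocks now always receive `past_key_values` and `attention_mask`, training on right-padded batches becomes possible. A hedged end-to-end usage sketch (the checkpoint name is illustrative; any model built on this modeling file with `trust_remote_code=True` should behave the same):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/phi-1_5"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(
    ["def add(a, b):", "print('hello world')"],
    return_tensors="pt",
    padding=True,
)

# With this commit the padding mask is honored during training; in practice you would
# also set label positions that fall on padding to -100 so the loss ignores them.
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["input_ids"],
)
outputs.loss.backward()
```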