Adds support for MQA/GQA and attention mask during training.

This commit is contained in:
Gustavo de Rosa 2023-10-30 16:59:12 +00:00 committed by huggingface-web
parent d38e6f954e
commit de35f900d3
3 changed files with 262 additions and 208 deletions

@ -127,7 +127,7 @@ with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
```
**Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
Furthermore, in the forward pass of the model, we currently do not support attention mask during training, outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
### Citation

@ -2,7 +2,7 @@
# Licensed under the MIT license.
import math
from typing import Any, Dict, List, Optional, Union
from typing import Optional
from transformers import PretrainedConfig
@ -27,6 +27,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
n_layer: Optional[int] = 20,
n_inner: Optional[int] = None,
n_head: Optional[int] = 16,
n_head_kv: Optional[int] = None,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
embd_pdrop: Optional[float] = 0.0,
@ -43,6 +44,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.n_head_kv = n_head_kv
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.embd_pdrop = embd_pdrop

@ -34,20 +34,20 @@
from __future__ import annotations
import math
import copy
from typing import Any, Dict, Optional, Tuple, Union
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from einops import rearrange
from einops import rearrange, repeat
from transformers.activations import ACT2FN
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mixformer_sequential import MixFormerSequentialConfig
@dataclass
class InferenceParams:
"""Inference parameters passed to model to efficiently calculate
@ -57,21 +57,20 @@ class InferenceParams:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
Args:
max_sequence_len: Maximum sequence length.
max_seqlen: Maximum sequence length.
max_batch_size: Maximum batch size.
sequence_len_offset: Sequence length offset.
seqlen_offset: Sequence length offset.
batch_size_offset: Batch size offset.
key_value_memory_dict: Key value memory dictionary.
fused_ft_kernel: Whether to use fused kernel for fast inference.
lengths_per_sample: Lengths per sample.
"""
max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
@ -79,8 +78,6 @@ class InferenceParams:
default_factory=dict, metadata={"help": "Key value memory dictionary."}
)
fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
@ -103,12 +100,112 @@ class Embedding(nn.Module):
return hidden_states
def _apply_rotary_emb(
x: torch.FloatTensor,
cos: torch.FloatTensor,
sin: torch.FloatTensor,
) -> torch.FloatTensor:
_, seqlen, _, head_dim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= head_dim
assert seqlen <= rotary_seqlen
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
x_rot = x[:, :, :, :rotary_dim]
x_pass = x[:, :, :, rotary_dim:]
x1, x2 = x_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
return torch.cat([x_rot, x_pass], axis=-1)
def _apply_rotary_emb_kv(
kv: torch.FloatTensor,
cos: torch.FloatTensor,
sin: torch.FloatTensor,
cos_k: Optional[torch.FloatTensor] = None,
sin_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
_, seqlen, two, _, head_dim = kv.shape
assert two == 2
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= head_dim
assert seqlen <= rotary_seqlen
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
k_rot = kv[:, :, 0, :, :rotary_dim]
k_pass = kv[:, :, 0, :, rotary_dim:]
k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
return torch.cat(
[
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
kv[:, :, 1:2, :, :],
],
axis=2,
)
def _apply_rotary_emb_qkv(
qkv: torch.FloatTensor,
cos: torch.FloatTensor,
sin: torch.FloatTensor,
cos_k: Optional[torch.FloatTensor] = None,
sin_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
_, seqlen, three, _, head_dim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= head_dim
assert seqlen <= rotary_seqlen
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
q1, q2 = q_rot.chunk(2, dim=-1)
k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
return torch.cat(
[
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
qkv[:, :, 2:3, :, :],
],
axis=2,
)
class RotaryEmbedding(nn.Module):
"""Rotary embeddings.
"""Rotary positional embedding (RoPE).
Reference:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
RoFormer: Enhanced Transformer with Rotary Position Embedding.
https://arxiv.org/pdf/2104.09864.pdf.
"""
def __init__(
@ -131,14 +228,14 @@ class RotaryEmbedding(nn.Module):
self.device = device
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent=False)
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale)
self.register_buffer("scale", scale, persistent=False)
self._seq_len_cached = 0
self._cos_cached = None
@ -146,28 +243,26 @@ class RotaryEmbedding(nn.Module):
self._cos_k_cached = None
self._sin_k_cached = None
def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
seqlen = x.shape[1] + seqlen_offset
def _update_cos_sin_cache(
self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
) -> None:
# Re-generate the inverse frequency buffer if it's not fp32
# (for instance if model.half() was called)
if self.inv_freq.dtype != "torch.float32":
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
)
if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
t = torch.arange(seqlen, device=device, dtype=torch.float32)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(x.dtype)
self._sin_cached = torch.sin(freqs).to(x.dtype)
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
else:
power = (
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
@ -175,62 +270,32 @@ class RotaryEmbedding(nn.Module):
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
def _apply_rotary_emb_qkv(
def forward(
self,
qkv: torch.FloatTensor,
sin: torch.FloatTensor,
cos: torch.FloatTensor,
sin_k: Optional[torch.FloatTensor] = None,
cos_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
_, seqlen, three, _, headdim = qkv.shape
assert three == 3
qkv: torch.Tensor,
kv: Optional[torch.Tensor] = None,
seqlen_offset: int = 0,
max_seqlen: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
seqlen = qkv.shape[1]
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
if max_seqlen is not None:
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
else:
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
if kv is None:
return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
else:
q = _apply_rotary_emb(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
kv = _apply_rotary_emb_kv(kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
# Splits the queries and keys in half
q1, q2 = q_rot.chunk(2, dim=-1)
k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
# Casts to fp32 are necessary to prevent fp16 overflow issues
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
# Computes the new keys and queries, recasting to original dtype
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
return torch.cat(
[
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
qkv[:, :, 2:3, :, :],
],
axis=2,
)
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
# `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
self._update_cos_sin_cache(qkv, seqlen_offset)
return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
return q, kv
class MLP(nn.Module):
@ -290,21 +355,22 @@ class SelfAttention(nn.Module):
attention_mask: Optional[torch.BoolTensor] = None,
**kwargs,
) -> torch.FloatTensor:
causal = self.causal if causal is None else causal
batch_size, seq_len = qkv.shape[0], qkv.shape[1]
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
q, k, v = qkv.unbind(dim=2)
causal = self.causal if causal is None else causal
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
if attention_mask is not None:
padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask.masked_fill_(attention_mask, 0.0)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
if causal:
causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
@ -343,25 +409,31 @@ class CrossAttention(nn.Module):
attention_mask: Optional[torch.BoolTensor] = None,
**kwargs,
) -> torch.FloatTensor:
causal = self.causal if causal is None else causal
batch_size, seq_len_q = q.shape[0], q.shape[1]
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
seq_len_k = kv.shape[1]
if kv.shape[3] != q.shape[2]:
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
k, v = kv.unbind(dim=2)
causal = self.causal if causal is None else causal
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
if attention_mask is not None:
padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask.masked_fill_(attention_mask, 0.0)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
if causal:
causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
scores = scores + causal_mask.to(dtype=scores.dtype)
rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
causal_mask = cols > rows + seqlen_k - seqlen_q
scores = scores.masked_fill(causal_mask, -10000.0)
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
attention = self.drop(attention)
@ -371,21 +443,12 @@ class CrossAttention(nn.Module):
return output
def find_mha_dims(
config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
def _find_mha_dims(
config: PretrainedConfig,
n_head: Optional[int] = None,
n_head_kv: Optional[int] = None,
head_dim: Optional[int] = None,
) -> Tuple[int, int]:
"""Validate and return the number of heads and head dimension for multi-head attention.
Args:
config: Model configuration.
n_head: Number of heads.
head_dim: Head dimension.
Returns:
Number of heads and head dimension.
"""
assert all(
hasattr(config, attr) for attr in ["n_embd", "n_head"]
), "`config` must have `n_embd` and `n_head` attributes."
@ -401,31 +464,20 @@ def find_mha_dims(
elif n_head is None or head_dim is None:
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
return n_head, head_dim
if n_head_kv is None:
n_head_kv = getattr(config, "n_head_kv", None) or n_head
assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
return n_head, n_head_kv, head_dim
def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
"""Update the key-value cache for inference.
Reference:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
Args:
kv: Key-value tensor.
inference_params: Inference parameters.
layer_idx: Layer index.
Returns:
Updated key-value tensor.
"""
def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size,
inference_params.max_sequence_len,
inference_params.max_seqlen,
2,
num_heads,
head_dim,
@ -434,43 +486,19 @@ def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, la
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
if not inference_params.fused_ft_kernel:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
else:
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
kv_cache = None
kv_cache = inference_params.key_value_memory_dict[layer_idx]
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert batch_end <= kv_cache.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_start = inference_params.seqlen_offset
sequence_end = sequence_start + kv.shape[1]
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
assert sequence_end <= kv_cache.shape[1]
if not inference_params.fused_ft_kernel:
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
assert inference_params.sequence_len_offset == 0
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if kv.dtype == torch.float32 else 8
if kv_cache is not None:
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
else:
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
@ -486,6 +514,7 @@ class MHA(nn.Module):
rotary_dim: Optional[int] = None,
rotary_emb_scale_base: Optional[float] = None,
n_head: Optional[int] = None,
n_head_kv: Optional[int] = None,
head_dim: Optional[int] = None,
bias: bool = True,
causal: bool = True,
@ -506,12 +535,12 @@ class MHA(nn.Module):
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
# MLP
self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
op_size = self.n_head * self.head_dim
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
hidden_size = config.n_embd
self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
# Attention
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
@ -521,40 +550,75 @@ class MHA(nn.Module):
self.return_residual = return_residual
self.checkpointing = checkpointing
def _forward_self_attn(
self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
) -> torch.FloatTensor:
qkv = self.Wqkv(x)
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if self.checkpointing:
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
return self.inner_attn(qkv, attention_mask=attention_mask)
def _forward_cross_attn(
self,
x: torch.FloatTensor,
past_key_values: Optional[InferenceParams],
attention_mask: Optional[torch.BoolTensor],
) -> torch.FloatTensor:
qkv = self.Wqkv(x)
q = qkv[..., : self.n_head * self.head_dim]
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
kv = qkv[..., self.n_head * self.head_dim :]
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
causal = None if seqlen_offset == 0 else False
if self.rotary_emb_dim > 0:
q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
if past_key_values is not None:
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
if self.checkpointing:
return torch.utils.checkpoint.checkpoint(
self.inner_cross_attn, q, kv, attention_mask=attention_mask, causal=causal
)
return self.inner_cross_attn(q, kv, attention_mask=attention_mask, causal=causal)
def forward(
self,
x: torch.FloatTensor,
past_key_values: Optional[InferenceParams] = None,
attention_mask: Optional[torch.BoolTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
qkv = self.Wqkv(x)
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
if past_key_values is not None:
kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
if attention_mask is not None:
attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
attention_mask = attention_mask.bool().to(qkv.device)
attention_kwargs = {"attention_mask": attention_mask}
if past_key_values is None or seqlen_offset == 0:
if self.checkpointing:
attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
else:
attn_output = self.inner_attn(qkv, **attention_kwargs)
if attention_mask is not None and torch.any(~attention_mask.bool()):
attention_mask = attention_mask.bool()
else:
q = qkv[:, :, 0]
causal = None if past_key_values.sequence_len_offset == 0 else False
attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
attention_mask = None
# MHA
if self.n_head == self.n_head_kv:
if past_key_values is None:
# If `past_key_values` are not supplied, we run self-attention
attn_output = self._forward_self_attn(x, attention_mask)
else:
# If `past_key_values` are supplied, it means that we might have cached values and
# could take advantage of cross-attention
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
# MQA / GQA
else:
# Regardless of `past_key_values` being supplied or not, it always use cross-attention
# because `q` and `kv` lengths might be different
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
output = rearrange(attn_output, "... h d -> ... (h d)")
output = self.out_proj(output)
@ -672,38 +736,29 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
if module.bias is not None:
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
attention_mask: Optional[torch.BoolTensor] = None,
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
**kwargs,
) -> Dict[str, Any]:
if attention_mask is not None and torch.any(~attention_mask.bool()):
total_seq_len = torch.sum(attention_mask, dim=1)
max_seq_len = torch.max(total_seq_len)
total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
else:
attention_mask = None
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
past_key_values = InferenceParams(
max_seqlen=self.config.n_positions,
max_batch_size=input_ids.shape[0],
max_sequence_len=self.config.n_positions,
sequence_len_offset=0,
seqlen_offset=0,
batch_size_offset=0,
fused_ft_kernel=False,
key_value_memory_dict={},
lengths_per_sample=None,
)
else:
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
past_key_values.seqlen_offset = len(input_ids[0]) - 1
input_ids = input_ids[:, -1].unsqueeze(-1)
return {
@ -712,9 +767,9 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
"attention_mask": attention_mask,
}
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MixFormerSequentialPreTrainedModel):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
if isinstance(module, MixFormerSequentialPreTrainedModel):
module.gradient_checkpointing = value
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
@ -756,13 +811,10 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
if past_key_values is None and attention_mask is None:
lm_logits = self.layers(input_ids)
else:
hidden_layer = self.layers[0](input_ids)
for module in self.layers[1:-1]:
hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
lm_logits = self.layers[-1](hidden_layer)
hidden_layer = self.layers[0](input_ids)
for module in self.layers[1:-1]:
hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
lm_logits = self.layers[-1](hidden_layer)
loss = None
if labels is not None: