Support for attention_mask in forward pass.

This commit implements the following:

- Cleans up unused arguments and definitions.
- Adds support for `attention_mask`.
- Adds support for cached inference (see the usage sketch below).
Gustavo de Rosa, 2023-09-26 18:17:08 +00:00, committed by huggingface-web
parent 4a426d8015
commit 3128bb636a
4 changed files with 300 additions and 324 deletions
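The new code paths can be exercised through `transformers` once this revision is loaded with `trust_remote_code=True`. Below is a minimal usage sketch, not part of the commit; the checkpoint id, prompt, and generation settings are assumptions for illustration only.

```
# Usage sketch (not part of this commit). The checkpoint id and prompt are assumed.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("def print_prime(n):", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # now forwarded into the attention layers
        max_new_tokens=64,
        use_cache=True,  # exercises InferenceParams and the per-layer key-value cache
    )

print(tokenizer.batch_decode(outputs)[0])
```

A padded multi-prompt batch, where the mask actually hides tokens, is sketched after the README change below.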

@@ -118,7 +118,7 @@ text = tokenizer.batch_decode(outputs)[0]
 print(text)
 ```
-**Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1) and `attention_mask` parameters.
+**Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1).
 Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
 
 ### Citation
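With `attention_mask` removed from the list of unsupported arguments, batched generation over padded prompts becomes possible. A hedged sketch follows; the prompts, pad-token choice, and left-padding setup are assumptions, not part of the commit.

```
# Batched generation sketch (assumptions: checkpoint id, prompts, pad token, padding side).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # assumed: the tokenizer ships without a pad token
tokenizer.padding_side = "left"            # keep the last position a real token for decoding

model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

prompts = ["def fibonacci(n):", "def is_palindrome(s):"]
batch = tokenizer(prompts, return_tensors="pt", padding=True)

outputs = model.generate(**batch, max_new_tokens=32)  # attention_mask hides the pad positions
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```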

@@ -1,13 +1,6 @@
 {
   "_name_or_path": "phi-1.5-half",
   "activation_function": "gelu_new",
-  "architecture": {
-    "block_cls": "parallel",
-    "mixer": {},
-    "mlp": {
-      "mlp_cls": "mlp"
-    }
-  },
   "architectures": [
     "MixFormerSequentialForCausalLM"
   ],
@@ -15,7 +8,6 @@
     "AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig",
     "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM"
   },
-  "embd_layer": "default",
   "embd_pdrop": 0.0,
   "initializer_range": 0.02,
   "layer_norm_epsilon": 1e-05,
@@ -25,7 +17,6 @@
   "n_inner": null,
   "n_layer": 24,
   "n_positions": 2048,
-  "phyagi_version": "0.0.4.dev",
   "resid_pdrop": 0.0,
   "rotary_dim": 32,
   "tie_word_embeddings": false,

@@ -17,8 +17,6 @@ class MixFormerSequentialConfig(PretrainedConfig):
         "hidden_size": "n_embd",
         "num_attention_heads": "n_head",
         "num_hidden_layers": "n_layer",
-        "input_emb_layer": "embd_layer",  # `input_emb_layer` key is for backward compatibility
-        "blocks": "architecture",  # `blocks` key is for backward compatibility
     }
 
     def __init__(
@@ -31,8 +29,6 @@ class MixFormerSequentialConfig(PretrainedConfig):
         n_head: Optional[int] = 16,
         rotary_dim: Optional[int] = 32,
         activation_function: Optional[str] = "gelu_new",
-        embd_layer: Optional[str] = "default",
-        architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
         embd_pdrop: Optional[float] = 0.0,
         resid_pdrop: Optional[float] = 0.0,
         layer_norm_epsilon: Optional[float] = 1e-5,
@@ -49,8 +45,6 @@ class MixFormerSequentialConfig(PretrainedConfig):
         self.n_head = n_head
         self.rotary_dim = min(rotary_dim, n_embd // n_head)
         self.activation_function = activation_function
-        self.embd_layer = embd_layer
-        self.architecture = architecture
         self.embd_pdrop = embd_pdrop
         self.resid_pdrop = resid_pdrop
         self.layer_norm_epsilon = layer_norm_epsilon

@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
#
# BSD 3-Clause License # BSD 3-Clause License
# #
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
@@ -50,16 +50,38 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig
 @dataclass
 class InferenceParams:
-    """Inference parameters that are passed to the main model in order
-    to efficienly calculate and store the context during inference.
-    Adapted from https://github.com/Dao-AILab/flash-attention."""
-
-    max_sequence_len: int
-    max_batch_size: int
-    sequence_len_offset: int = 0
-    batch_size_offset: int = 0
-    key_value_memory_dict: dict = field(default_factory=dict)
-    fused_ft_kernel: bool = False
-    lengths_per_sample: Optional[torch.Tensor] = None
+    """Inference parameters passed to model to efficiently calculate
+    and store context during inference.
+
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
+
+    Args:
+        max_sequence_len: Maximum sequence length.
+        max_batch_size: Maximum batch size.
+        sequence_len_offset: Sequence length offset.
+        batch_size_offset: Batch size offset.
+        key_value_memory_dict: Key value memory dictionary.
+        fused_ft_kernel: Whether to use fused kernel for fast inference.
+        lengths_per_sample: Lengths per sample.
+
+    """
+
+    max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
+    max_batch_size: int = field(metadata={"help": "Maximum batch size."})
+    sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
+    batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
+    key_value_memory_dict: Dict[str, Any] = field(
+        default_factory=dict, metadata={"help": "Key value memory dictionary."}
+    )
+    fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
+    lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
 
 
 class Embedding(nn.Module):
@@ -80,14 +102,19 @@ class Embedding(nn.Module):
         return hidden_states
 
 
 class RotaryEmbedding(nn.Module):
-    """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
-    Adapted from https://github.com/Dao-AILab/flash-attention."""
+    """Rotary embeddings.
+
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
+
+    """
 
     def __init__(
         self,
         dim: int,
-        base: Optional[int] = 10000,
+        base: int = 10000,
         scale_base: Optional[float] = None,
         device: Optional[str] = None,
         **kwargs,
@@ -119,7 +146,7 @@ class RotaryEmbedding(nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
-    def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0) -> None:
+    def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         seqlen = x.shape[1] + seqlen_offset
@@ -153,7 +180,7 @@ class RotaryEmbedding(nn.Module):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def apply_rotary_emb_qkv(
+    def _apply_rotary_emb_qkv(
         self,
         qkv: torch.FloatTensor,
         sin: torch.FloatTensor,
@@ -189,7 +216,6 @@ class RotaryEmbedding(nn.Module):
         # Computes the new keys and queries, recasting to original dtype
         q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
         k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
-
         return torch.cat(
@@ -202,47 +228,9 @@ class RotaryEmbedding(nn.Module):
         )
 
     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Perform the forward pass.
-
-        Args:
-            qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
-            seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
-
-        Returns:
-            New `qkv` and the cached sinusoids.
-
-        """
-
+        # `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
         self._update_cos_sin_cache(qkv, seqlen_offset)
-
-        return self.apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
-
-
-def _update_kv_cache(kv, inference_params, layer_idx):
-    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-    Adapted from https://github.com/Dao-AILab/flash-attention."""
-    # Pre-allocate memory for key-values for inference.
-    num_heads, head_dim = kv.shape[-2:]
-    if layer_idx not in inference_params.key_value_memory_dict:
-        kv_cache = torch.empty(
-            inference_params.max_batch_size, inference_params.max_sequence_len, 2,
-            num_heads, head_dim, dtype=kv.dtype, device=kv.device
-        )
-        inference_params.key_value_memory_dict[layer_idx] = kv_cache
-    else:
-        kv_cache = inference_params.key_value_memory_dict[layer_idx]
-
-    # Adjust key and value for inference
-    batch_start = inference_params.batch_size_offset
-    batch_end = batch_start + kv.shape[0]
-    sequence_start = inference_params.sequence_len_offset
-    sequence_end = sequence_start + kv.shape[1]
-
-    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
-    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
-
-    assert kv_cache is not None
-    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
-    return kv
+        return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
 
 
 class MLP(nn.Module):
@@ -267,17 +255,6 @@ class MLP(nn.Module):
         self.fc2 = nn.Linear(n_inner, config.n_embd)
         self.act = ACT2FN[act_fn]
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        old_keys = [prefix + "fc_in.weight", prefix + "fc_out.weight", prefix + "fc_in.bias", prefix + "fc_out.bias"]
-        new_keys = [prefix + "fc1.weight", prefix + "fc2.weight", prefix + "fc1.bias", prefix + "fc2.bias"]
-
-        if all(k in state_dict for k in old_keys) and not all(k in state_dict for k in new_keys):
-            # Older version of `MLP` saved with different key names.
-            for old_key, new_key in zip(old_keys, new_keys):
-                state_dict[new_key] = state_dict.pop(old_key)
-
-        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         hidden_states = self.fc1(hidden_states)
         hidden_states = self.act(hidden_states)
@@ -286,132 +263,114 @@ class MLP(nn.Module):
         return hidden_states
 
 
-class FusedMLP(nn.Module):
-    """Fused Multi-Layer Perceptron from `flash-attn`.
-
-    Reference:
-        https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
-
-    """
-
-    def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None,
-                 raise_on_missing: bool = False) -> None:
-        super().__init__()
-
-        act_fn = config.activation_function if act_fn is None else act_fn
-        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
-
-        n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
-        n_inner = n_inner if n_inner is not None else 4 * config.n_embd
-
-        gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"]
-        activation = "gelu_approx" if act_fn in gelu_activations else "relu"
-
-        self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
-
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
-        return self.mlp(hidden_states)
-
-
 class SelfAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Adapted from https://github.com/Dao-AILab/flash-attention.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-                      (default: 1/sqrt(d_keys) where d_keys is computed at
-                      runtime)
-        attention_dropout: The dropout rate to apply to the attention
-                           (default: 0.0)
-    """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+    """Self-attention layer (compatible with PyTorch).
+
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
+
+    """
+
+    def __init__(
+        self,
+        causal: bool = True,
+        softmax_scale: Optional[float] = None,
+        attention_dropout: float = 0.0,
+    ) -> None:
         super().__init__()
+
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
 
-    def forward(self, qkv, causal=None, key_padding_mask=None):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
-            causal: if passed, will override self.causal
-            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
-                False means to mask out. (B, S)
-        """
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+    def forward(
+        self,
+        qkv: torch.FloatTensor,
+        causal: bool = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
         causal = self.causal if causal is None else causal
+        batch_size, seq_len = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
+
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
-        if key_padding_mask is not None:
-            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
-                                      device=scores.device)
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+
+        if attention_mask is not None:
+            padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
+            padding_mask.masked_fill_(attention_mask, 0.0)
+
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
             scores = scores + causal_mask.to(dtype=scores.dtype)
+
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = self.drop(attention)
-        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        attention = self.drop(attention)
+
+        output = torch.einsum("bhts,bshd->bthd", attention, v)
+
         return output
 
 
 class CrossAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Adapted from https://github.com/Dao-AILab/flash-attention.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-                      (default: 1/sqrt(d_keys) where d_keys is computed at
-                      runtime)
-        attention_dropout: The dropout rate to apply to the attention
-                           (default: 0.0)
-    """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
+    """Cross-attention layer (compatible with PyTorch).
+
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
+
+    """
+
+    def __init__(
+        self,
+        causal: bool = True,
+        softmax_scale: Optional[float] = None,
+        attention_dropout: float = 0.0,
+    ) -> None:
         super().__init__()
+
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
 
-    def forward(self, q, kv, causal=None, key_padding_mask=None):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            q: The tensor containing the query. (B, Sq, H, D)
-            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
-            causal: if passed, will override self.causal
-            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
-                False means to mask out. (B, Sk)
-        """
-        batch_size, seqlen_q = q.shape[0], q.shape[1]
+    def forward(
+        self,
+        q: torch.FloatTensor,
+        kv: torch.FloatTensor,
+        causal: bool = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
         causal = self.causal if causal is None else causal
-        seqlen_k = kv.shape[1]
+        batch_size, seq_len_q = q.shape[0], q.shape[1]
         assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+
+        seq_len_k = kv.shape[1]
         k, v = kv.unbind(dim=2)
+
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
-        if key_padding_mask is not None:
-            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
-                                      device=scores.device)
-            padding_mask.masked_fill_(key_padding_mask, 0.0)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
+
+        if attention_mask is not None:
+            padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
+            padding_mask.masked_fill_(attention_mask, 0.0)
+
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
+
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
-                                                device=scores.device), 1)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
             scores = scores + causal_mask.to(dtype=scores.dtype)
+
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
-        attention_drop = self.drop(attention)
-        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        attention = self.drop(attention)
+
+        output = torch.einsum("bhts,bshd->bthd", attention, v)
+
         return output
 
 
 def find_mha_dims(
     config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
 ) -> Tuple[int, int]:
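The rewritten `SelfAttention`/`CrossAttention` above expect a boolean `attention_mask` where `True` means keep and `False` means mask out, applied as a large negative bias before the softmax. A tiny standalone check of that mechanism on toy tensors (illustrative values only, not the module's API):

```
# Toy check of the padding-bias trick used by the new attention classes: positions where
# the mask is False get a -10000.0 bias, so softmax drives their weight to ~0.
import torch

scores = torch.zeros(1, 1, 4, 4)                    # (batch, heads, query, key) toy scores
mask = torch.tensor([[True, True, False, False]])   # True = keep, False = padding

padding_bias = torch.full((1, 4), -10000.0).masked_fill(mask, 0.0)
scores = scores + padding_bias[:, None, None, :]    # same broadcast as "b s -> b 1 1 s"

attention = torch.softmax(scores, dim=-1)
print(attention[0, 0, 0])  # ~[0.5, 0.5, 0.0, 0.0]: padded keys receive no weight
```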
@@ -445,152 +404,163 @@ def find_mha_dims(
     return n_head, head_dim
 
 
+def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
+    """Update the key-value cache for inference.
+
+    Reference:
+        https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
+
+    Args:
+        kv: Key-value tensor.
+        inference_params: Inference parameters.
+        layer_idx: Layer index.
+
+    Returns:
+        Updated key-value tensor.
+
+    """
+
+    num_heads, head_dim = kv.shape[-2:]
+
+    if layer_idx not in inference_params.key_value_memory_dict:
+        kv_cache = torch.empty(
+            inference_params.max_batch_size,
+            inference_params.max_sequence_len,
+            2,
+            num_heads,
+            head_dim,
+            dtype=kv.dtype,
+            device=kv.device,
+        )
+        inference_params.key_value_memory_dict[layer_idx] = kv_cache
+    else:
+        if not inference_params.fused_ft_kernel:
+            kv_cache = inference_params.key_value_memory_dict[layer_idx]
+        else:
+            k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
+            kv_cache = None
+
+    batch_start = inference_params.batch_size_offset
+    batch_end = batch_start + kv.shape[0]
+    assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
+
+    sequence_start = inference_params.sequence_len_offset
+    sequence_end = sequence_start + kv.shape[1]
+    assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
+
+    if not inference_params.fused_ft_kernel:
+        assert kv_cache is not None
+
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+
+        return kv
+
+    assert inference_params.sequence_len_offset == 0
+    assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
+
+    packsize = 4 if kv.dtype == torch.float32 else 8
+
+    if kv_cache is not None:
+        kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+        k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
+        v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
+        inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
+    else:
+        k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
+            kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
+        )
+        v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
+
+    return kv
+
+
 class MHA(nn.Module):
-    """Multi-head attention layer.
-    Adapted from https://github.com/Dao-AILab/flash-attention."""
+    """Multi-head attention layer."""
 
     def __init__(
         self,
         config: PretrainedConfig,
-        rotary_dim: Optional[int] = None,
-        n_head: Optional[int] = None,
-        head_dim: Optional[int] = None,
-        bias: Optional[bool] = True,
-        dropout: Optional[float] = 0.0,
-        softmax_scale: Optional[float] = None,
-        causal: Optional[bool] = True,
-        layer_idx: Optional[int] = None,
-        rotary_emb_scale_base: Optional[float] = None,
-        return_residual: Optional[bool] = False,
-        checkpointing: Optional[bool] = False,
-        device: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        fused_dense: Optional[bool] = True,
-        flash_attn: Optional[bool] = True,
-        cutlass_attn: Optional[bool] = False,
-        flash_rotary: Optional[bool] = True,
-        raise_on_missing: Optional[bool] = False
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[str] = None,
+        rotary_dim: Optional[int] = None,
+        rotary_emb_scale_base: Optional[float] = None,
+        n_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        bias: bool = True,
+        causal: bool = True,
+        softmax_scale: Optional[float] = None,
+        dropout: float = 0.0,
+        layer_idx: Optional[int] = None,
+        return_residual: bool = False,
+        checkpointing: bool = False,
     ) -> None:
         super().__init__()
 
-        factory_kwargs = {"device": device, "dtype": dtype}
-        n_head, head_dim = find_mha_dims(config, n_head, head_dim)
-
-        self.hidden_size = config.n_embd
-        self.n_head = n_head
-        self.head_dim = head_dim
-        self.op_size = n_head * head_dim
-
-        self.causal = causal
-        self.layer_idx = layer_idx
+        # Rotary embedding
         self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
-        self.fused_dense = fused_dense
-        self.flash_attn = flash_attn
-        self.cutlass_attn = cutlass_attn
-        self.flash_rotary = flash_rotary
-        self.return_residual = return_residual
-        self.checkpointing = checkpointing
-
         if self.rotary_emb_dim > 0:
             rotary_kwargs = {"device": device}
             if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                 rotary_kwargs["scale_base"] = rotary_emb_scale_base
             self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
-        else:
-            pass
 
-        self.Wqkv = nn.Linear(self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs)
-        self.out_proj = nn.Linear(self.op_size, self.hidden_size, bias=bias, **factory_kwargs)
+        # MLP
+        self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
+        op_size = self.n_head * self.head_dim
+        hidden_size = config.n_embd
+
+        self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
+        self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
+        # Attention
         self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
         self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
 
-    def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-        Adapted from https://github.com/Dao-AILab/flash-attention."""
-
-        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-
-        return _update_kv_cache(kv, inference_params, self.layer_idx)
+        self.layer_idx = layer_idx
+        self.return_residual = return_residual
+        self.checkpointing = checkpointing
 
     def forward(
         self,
         x: torch.FloatTensor,
-        x_kv: Optional[torch.FloatTensor] = None,
-        key_padding_mask: Optional[torch.BoolTensor] = None,
+        past_key_values: Optional[InferenceParams] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
         cu_seqlens: Optional[torch.LongTensor] = None,
         max_seqlen: Optional[int] = None,
-        mixer_subset: Optional[torch.LongTensor] = None,
-        past_cache: Optional[InferenceParams] = None,
-        **kwargs
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        """Perform the forward pass.
-
-        Args:
-            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
-                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
-                is the is the sum of the sequence lengths in the batch.
-            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
-            key_padding_mask: boolean mask, True means to keep, False means to mask out.
-                (batch, seqlen). Only applicable when not using FlashAttention.
-            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
-                of the sequences in the batch, used to index into x. Only applicable when using
-                FlashAttention.
-            max_seqlen: int. Maximum sequence length in the batch.
-            mixer_subset: for cross-attention only. If not None, will take a subset of x
-                before applying the query projection. Useful for e.g., ViT where we only care
-                about the CLS token in the last layer.
-            past_cache: For generation only.
-
-        Returns:
-            (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
-            else (total, hidden_dim) where total is the is the sum of the sequence lengths
-            in the batch.
-
-        """
-
-        if cu_seqlens is not None:
-            assert max_seqlen is not None
-            assert key_padding_mask is None
-            assert self.flash_attn
-            assert self.rotary_emb_dim == 0
-
-        if key_padding_mask is not None:
-            assert cu_seqlens is None
-            assert max_seqlen is None
-            assert not self.flash_attn
-
-        if past_cache is not None:
-            assert key_padding_mask is None
-            assert cu_seqlens is None and max_seqlen is None
-
-        attn_kwargs = {"key_padding_mask": key_padding_mask}
-
-        assert x_kv is None and mixer_subset is None
-
         qkv = self.Wqkv(x)
         qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
 
-        if past_cache is None:
-            if self.rotary_emb_dim > 0:
-                qkv = self.rotary_emb(qkv)
-            context = self.inner_attn(qkv, **attn_kwargs)
+        seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
+        if self.rotary_emb_dim > 0:
+            qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
 
+        if past_key_values is not None:
+            kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
+
+        if attention_mask is not None:
+            attention_mask, cu_seqlens, max_seqlen = attention_mask
+            attention_mask = attention_mask.to(qkv.device)
+
+        attention_kwargs = {"attention_mask": attention_mask}
+
+        if past_key_values is None or seqlen_offset == 0:
+            if self.checkpointing:
+                attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
+            else:
+                attn_output = self.inner_attn(qkv, **attention_kwargs)
         else:
-            if self.rotary_emb_dim > 0:
-                qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
             q = qkv[:, :, 0]
-            kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
-            # If we're processing the prompt, causal=None (use self.causal).
-            # If we're decoding, then causal=False.
-            causal = None if past_cache.sequence_len_offset == 0 else False
-            context = self.inner_cross_attn(q, kv, causal=causal)
+            causal = None if past_key_values.sequence_len_offset == 0 else False
+            attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
 
-        out = rearrange(context, "... h d -> ... (h d)")
-        out = self.out_proj(out)
+        output = rearrange(attn_output, "... h d -> ... (h d)")
+        output = self.out_proj(output)
 
-        return out if not self.return_residual else (out, x)
+        return output if not self.return_residual else (output, x)
 
 
 class ParallelBlock(nn.Module):
     """Parallel block.
@@ -602,8 +572,6 @@ class ParallelBlock(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        mixer: Optional[Dict[str, Any]] = None,
-        mlp: Optional[Dict[str, Any]] = None,
         block_idx: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -612,19 +580,20 @@ class ParallelBlock(nn.Module):
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
         self.block_idx = block_idx
 
-        self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
-        mlp_cls = mlp.pop('mlp_cls')
-        if mlp_cls == 'fused_mlp':
-            self.mlp = FusedMLP(config=config, **mlp)
-        else:
-            self.mlp = MLP(config=config, **mlp)
+        self.mixer = MHA(config, layer_idx=block_idx)
+        self.mlp = MLP(config)
 
-    def forward(self, hidden_states: torch.FloatTensor,
-                past_cache: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> torch.FloatTensor:
         residual = hidden_states
         hidden_states = self.ln(hidden_states)
 
-        attn_outputs = self.mixer(hidden_states, past_cache=past_cache)
+        attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask)
         if isinstance(attn_outputs, tuple):
             attn_outputs = attn_outputs[0]
@@ -635,6 +604,7 @@ class ParallelBlock(nn.Module):
 
         return hidden_states
 
+
 class CausalLMHead(nn.Module):
     """Causal Language Modeling head.
@@ -666,7 +636,7 @@ class CausalLMLoss(nn.Module):
     """
 
-    def __init__(self, shift_labels: Optional[bool] = True) -> None:
+    def __init__(self, shift_labels: bool = True) -> None:
         super().__init__()
 
         self.shift_labels = shift_labels
@@ -681,6 +651,7 @@ class CausalLMLoss(nn.Module):
 
         return loss
 
+
 class MixFormerSequentialPreTrainedModel(PreTrainedModel):
     """MixFormer (sequential for DeepSpeed) pre-trained model."""
@@ -691,9 +662,35 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs) -> None:
         super().__init__(*inputs, **kwargs)
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs) -> Dict[str, Any]:
-        if "use_cache" in kwargs and not kwargs["use_cache"]:
-            return {"input_ids": input_ids}
+    def _init_weights(self, module: nn.Module) -> None:
+        if isinstance(module, (nn.Linear,)):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        if attention_mask is not None and torch.any(~attention_mask.bool()):
+            total_seq_len = torch.sum(attention_mask, dim=1)
+            max_seq_len = torch.max(total_seq_len)
+
+            total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
+            cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
+            attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
+        else:
+            attention_mask = None
 
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
@@ -705,11 +702,15 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
                 key_value_memory_dict={},
             )
         else:
-            # assume past_key_values has cached all but last token in input_ids
+            # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
             past_key_values.sequence_len_offset = len(input_ids[0]) - 1
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
-        return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "attention_mask": attention_mask,
+        }
 
 
 class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
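When padding is present, `prepare_inputs_for_generation` above packs the 2D `attention_mask` into a `(bool_mask, cu_seqlens, max_seqlen)` tuple, which `MHA.forward` later unpacks. A standalone sketch of that packing on a toy mask (values are illustrative only):

```
# Standalone sketch of the mask packing done in prepare_inputs_for_generation above.
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])  # toy batch: 3 and 5 real tokens

total_seq_len = torch.sum(attention_mask, dim=1)  # tensor([3, 5])
max_seq_len = torch.max(total_seq_len)            # tensor(5)

total_seq_len = torch.cat((torch.tensor([0]), total_seq_len)).unsqueeze(1)
cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)

packed = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
print(packed[1])  # tensor([0, 3, 8], dtype=torch.int32): cumulative sequence lengths
print(packed[2])  # 5
```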
@@ -723,23 +724,7 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
         super().__init__(config)
 
         modules = [Embedding(config)]
-
-        block_config = config.architecture
-        if not isinstance(block_config, list):
-            block_config = [block_config for _ in range(config.n_layer)]
-
-        if config.n_layer != len(block_config):
-            config.n_layer = len(block_config)
-
-        for block_idx, block in enumerate(block_config):
-            # `block_cls` with `legacy` value is for backward compatibility
-            # `path` key is for backward compatibility
-            block = copy.deepcopy(block) or {"block_cls": "parallel"}
-            block_cls = block.pop("path", None) or block.pop("block_cls", None)
-
-            block["block_idx"] = block_idx
-
-            modules.append(ParallelBlock(config, **block))
+        modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]
         modules.append(CausalLMHead(config))
 
         self.layers = nn.Sequential(*modules)
@@ -760,16 +745,22 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
         self.layers[-1].linear = new_embeddings
 
     def forward(
-        self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[torch.FloatTensor] = None, **kwargs
+        self,
+        input_ids: torch.LongTensor,
+        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> CausalLMOutputWithPast:
-        if not past_key_values:
+        if attention_mask is not None and self.training:
+            raise ValueError("`attention_mask` is not supported during training.")
+
+        if past_key_values is None and attention_mask is None:
             lm_logits = self.layers(input_ids)
         else:
             hidden_layer = self.layers[0](input_ids)
             for module in self.layers[1:-1]:
-                hidden_layer = module(hidden_layer, past_cache=past_key_values)
+                hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
             lm_logits = self.layers[-1](hidden_layer)
 
         loss = None