diff --git a/README.md b/README.md index 74b23c0..ded381e 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ text = tokenizer.batch_decode(outputs)[0] print(text) ``` -**Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1) and `attention_mask' parameters. +**Remark.** In the generation function, our model currently does not support beam search (`num_beams` >1). Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's). ### Citation diff --git a/config.json b/config.json index 4668344..c2b5ff8 100644 --- a/config.json +++ b/config.json @@ -1,13 +1,6 @@ { "_name_or_path": "phi-1.5-half", "activation_function": "gelu_new", - "architecture": { - "block_cls": "parallel", - "mixer": {}, - "mlp": { - "mlp_cls": "mlp" - } - }, "architectures": [ "MixFormerSequentialForCausalLM" ], @@ -15,7 +8,6 @@ "AutoConfig": "configuration_mixformer_sequential.MixFormerSequentialConfig", "AutoModelForCausalLM": "modeling_mixformer_sequential.MixFormerSequentialForCausalLM" }, - "embd_layer": "default", "embd_pdrop": 0.0, "initializer_range": 0.02, "layer_norm_epsilon": 1e-05, @@ -25,7 +17,6 @@ "n_inner": null, "n_layer": 24, "n_positions": 2048, - "phyagi_version": "0.0.4.dev", "resid_pdrop": 0.0, "rotary_dim": 32, "tie_word_embeddings": false, diff --git a/configuration_mixformer_sequential.py b/configuration_mixformer_sequential.py index 2d2e42c..8cc2d51 100644 --- a/configuration_mixformer_sequential.py +++ b/configuration_mixformer_sequential.py @@ -17,8 +17,6 @@ class MixFormerSequentialConfig(PretrainedConfig): "hidden_size": "n_embd", "num_attention_heads": "n_head", "num_hidden_layers": "n_layer", - "input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility - "blocks": "architecture", # `blocks` key is for backward compatibility } def __init__( @@ -31,8 +29,6 @@ class MixFormerSequentialConfig(PretrainedConfig): n_head: Optional[int] = 16, rotary_dim: Optional[int] = 32, activation_function: Optional[str] = "gelu_new", - embd_layer: Optional[str] = "default", - architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None, embd_pdrop: Optional[float] = 0.0, resid_pdrop: Optional[float] = 0.0, layer_norm_epsilon: Optional[float] = 1e-5, @@ -49,8 +45,6 @@ class MixFormerSequentialConfig(PretrainedConfig): self.n_head = n_head self.rotary_dim = min(rotary_dim, n_embd // n_head) self.activation_function = activation_function - self.embd_layer = embd_layer - self.architecture = architecture self.embd_pdrop = embd_pdrop self.resid_pdrop = resid_pdrop self.layer_norm_epsilon = layer_norm_epsilon diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py index 5e3db86..d22bbc2 100644 --- a/modeling_mixformer_sequential.py +++ b/modeling_mixformer_sequential.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. - +# # BSD 3-Clause License # # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. @@ -50,16 +50,38 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig @dataclass class InferenceParams: - """Inference parameters that are passed to the main model in order - to efficienly calculate and store the context during inference. 
- Adapted from https://github.com/Dao-AILab/flash-attention.""" - max_sequence_len: int - max_batch_size: int - sequence_len_offset: int = 0 - batch_size_offset: int = 0 - key_value_memory_dict: dict = field(default_factory=dict) - fused_ft_kernel: bool = False - lengths_per_sample: Optional[torch.Tensor] = None + """Inference parameters passed to model to efficiently calculate + and store context during inference. + + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py. + + Args: + max_sequence_len: Maximum sequence length. + max_batch_size: Maximum batch size. + sequence_len_offset: Sequence length offset. + batch_size_offset: Batch size offset. + key_value_memory_dict: Key value memory dictionary. + fused_ft_kernel: Whether to use fused kernel for fast inference. + lengths_per_sample: Lengths per sample. + + """ + + max_sequence_len: int = field(metadata={"help": "Maximum sequence length."}) + + max_batch_size: int = field(metadata={"help": "Maximum batch size."}) + + sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."}) + + batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."}) + + key_value_memory_dict: Dict[str, Any] = field( + default_factory=dict, metadata={"help": "Key value memory dictionary."} + ) + + fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."}) + + lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."}) class Embedding(nn.Module): @@ -80,14 +102,19 @@ class Embedding(nn.Module): return hidden_states + class RotaryEmbedding(nn.Module): - """PyTorch implementation of `flash-attn` RotaryEmbedding layer. - Adapted from https://github.com/Dao-AILab/flash-attention.""" + """Rotary embeddings. + + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py. + + """ def __init__( self, dim: int, - base: Optional[int] = 10000, + base: int = 10000, scale_base: Optional[float] = None, device: Optional[str] = None, **kwargs, @@ -119,7 +146,7 @@ class RotaryEmbedding(nn.Module): self._cos_k_cached = None self._sin_k_cached = None - def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0) -> None: + def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None: # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) seqlen = x.shape[1] + seqlen_offset @@ -153,7 +180,7 @@ class RotaryEmbedding(nn.Module): self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) - def apply_rotary_emb_qkv( + def _apply_rotary_emb_qkv( self, qkv: torch.FloatTensor, sin: torch.FloatTensor, @@ -189,7 +216,6 @@ class RotaryEmbedding(nn.Module): # Computes the new keys and queries, recasting to original dtype q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype) return torch.cat( @@ -202,47 +228,9 @@ class RotaryEmbedding(nn.Module): ) def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: - """Perform the forward pass. - - Args: - qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim). 
- seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch. - - Returns: - New `qkv` and the cached sinusoids. - - """ - + # `qkv` is of shape (batch, seqlen, 3, nheads, headdim) self._update_cos_sin_cache(qkv, seqlen_offset) - - return self.apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]) - -def _update_kv_cache(kv, inference_params, layer_idx): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) - Adapted from https://github.com/Dao-AILab/flash-attention.""" - # Pre-allocate memory for key-values for inference. - num_heads, head_dim = kv.shape[-2:] - if layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, inference_params.max_sequence_len, 2, - num_heads, head_dim, dtype=kv.dtype, device=kv.device - ) - inference_params.key_value_memory_dict[layer_idx] = kv_cache - else: - kv_cache = inference_params.key_value_memory_dict[layer_idx] - - # Adjust key and value for inference - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + kv.shape[1] - assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) - assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) - - assert kv_cache is not None - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = kv_cache[batch_start:batch_end, :sequence_end, ...] - return kv + return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]) class MLP(nn.Module): @@ -267,17 +255,6 @@ class MLP(nn.Module): self.fc2 = nn.Linear(n_inner, config.n_embd) self.act = ACT2FN[act_fn] - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - old_keys = [prefix + "fc_in.weight", prefix + "fc_out.weight", prefix + "fc_in.bias", prefix + "fc_out.bias"] - new_keys = [prefix + "fc1.weight", prefix + "fc2.weight", prefix + "fc1.bias", prefix + "fc2.bias"] - - if all(k in state_dict for k in old_keys) and not all(k in state_dict for k in new_keys): - # Older version of `MLP` saved with different key names. - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: hidden_states = self.fc1(hidden_states) hidden_states = self.act(hidden_states) @@ -286,132 +263,114 @@ class MLP(nn.Module): return hidden_states -class FusedMLP(nn.Module): - """Fused Multi-Layer Perceptron from `flash-attn`. - - Reference: - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py. - - """ - def __init__(self, config: PretrainedConfig, n_inner: Optional[int] = None, act_fn: Optional[str] = None, - raise_on_missing: bool = False) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}." 
- - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] - activation = "gelu_approx" if act_fn in gelu_activations else "relu" - - self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - return self.mlp(hidden_states) - class SelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Adapted from https://github.com/Dao-AILab/flash-attention. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) + """Self-attention layer (compatible with PyTorch). + + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. + """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + + def __init__( + self, + causal: bool = True, + softmax_scale: Optional[float] = None, + attention_dropout: float = 0.0, + ) -> None: super().__init__() + self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) - def forward(self, qkv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, S) - """ - batch_size, seqlen = qkv.shape[0], qkv.shape[1] + def forward( + self, + qkv: torch.FloatTensor, + causal: bool = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> torch.FloatTensor: causal = self.causal if causal is None else causal + batch_size, seq_len = qkv.shape[0], qkv.shape[1] q, k, v = qkv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, - device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + + if attention_mask is not None: + padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(attention_mask, 0.0) + + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1) scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum('bhts,bshd->bthd', attention_drop, v) + attention = self.drop(attention) + + output = torch.einsum("bhts,bshd->bthd", 
attention, v) + return output class CrossAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Adapted from https://github.com/Dao-AILab/flash-attention. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) + """Cross-attention layer (compatible with PyTorch). + + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. + """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + + def __init__( + self, + causal: bool = True, + softmax_scale: Optional[float] = None, + attention_dropout: float = 0.0, + ) -> None: super().__init__() + self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) - def forward(self, q, kv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, Sk) - """ - batch_size, seqlen_q = q.shape[0], q.shape[1] + def forward( + self, + q: torch.FloatTensor, + kv: torch.FloatTensor, + causal: bool = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> torch.FloatTensor: causal = self.causal if causal is None else causal - seqlen_k = kv.shape[1] + batch_size, seq_len_q = q.shape[0], q.shape[1] assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3] + + seq_len_k = kv.shape[1] k, v = kv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, - device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + + if attention_mask is not None: + padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(attention_mask, 0.0) + + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, - device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1) scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum('bhts,bshd->bthd', attention_drop, v) + attention = self.drop(attention) + + output = torch.einsum("bhts,bshd->bthd", attention, v) + return output + def find_mha_dims( config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None ) -> Tuple[int, int]: @@ -445,152 +404,163 @@ def 
find_mha_dims( return n_head, head_dim +def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor: + """Update the key-value cache for inference. + + Reference: + https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. + + Args: + kv: Key-value tensor. + inference_params: Inference parameters. + layer_idx: Layer index. + + Returns: + Updated key-value tensor. + + """ + + num_heads, head_dim = kv.shape[-2:] + + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_sequence_len, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) + + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] + + return kv + + assert inference_params.sequence_len_offset == 0 + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + + packsize = 4 if kv.dtype == torch.float32 else 8 + + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") + + return kv + + class MHA(nn.Module): - """Multi-head attention layer. 
- Adapted from https://github.com/Dao-AILab/flash-attention.""" + """Multi-head attention layer.""" def __init__( self, config: PretrainedConfig, + dtype: Optional[torch.dtype] = None, + device: Optional[str] = None, rotary_dim: Optional[int] = None, + rotary_emb_scale_base: Optional[float] = None, n_head: Optional[int] = None, head_dim: Optional[int] = None, - bias: Optional[bool] = True, - dropout: Optional[float] = 0.0, + bias: bool = True, + causal: bool = True, softmax_scale: Optional[float] = None, - causal: Optional[bool] = True, + dropout: float = 0.0, layer_idx: Optional[int] = None, - rotary_emb_scale_base: Optional[float] = None, - return_residual: Optional[bool] = False, - checkpointing: Optional[bool] = False, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - fused_dense: Optional[bool] = True, - flash_attn: Optional[bool] = True, - cutlass_attn: Optional[bool] = False, - flash_rotary: Optional[bool] = True, - raise_on_missing: Optional[bool] = False + return_residual: bool = False, + checkpointing: bool = False, ) -> None: super().__init__() - factory_kwargs = {"device": device, "dtype": dtype} - n_head, head_dim = find_mha_dims(config, n_head, head_dim) - - self.hidden_size = config.n_embd - self.n_head = n_head - self.head_dim = head_dim - self.op_size = n_head * head_dim - - self.causal = causal - self.layer_idx = layer_idx + # Rotary embedding self.rotary_emb_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) - self.fused_dense = fused_dense - self.flash_attn = flash_attn - self.cutlass_attn = cutlass_attn - self.flash_rotary = flash_rotary - self.return_residual = return_residual - self.checkpointing = checkpointing - if self.rotary_emb_dim > 0: rotary_kwargs = {"device": device} if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0: rotary_kwargs["scale_base"] = rotary_emb_scale_base - self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs) - else: - pass + + # MLP + self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim) + op_size = self.n_head * self.head_dim + hidden_size = config.n_embd - self.Wqkv = nn.Linear(self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs) - self.out_proj = nn.Linear(self.op_size, self.hidden_size, bias=bias, **factory_kwargs) + self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype) + self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype) + # Attention self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None: - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) - Adapted from https://github.com/Dao-AILab/flash-attention.""" - - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - - return _update_kv_cache(kv, inference_params, self.layer_idx) + self.layer_idx = layer_idx + self.return_residual = return_residual + self.checkpointing = checkpointing def forward( self, x: torch.FloatTensor, - x_kv: Optional[torch.FloatTensor] = None, - key_padding_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[InferenceParams] = None, + attention_mask: Optional[torch.BoolTensor] = None, cu_seqlens: Optional[torch.LongTensor] = None, max_seqlen: 
Optional[int] = None, - mixer_subset: Optional[torch.LongTensor] = None, - past_cache: Optional[InferenceParams] = None, - **kwargs + **kwargs, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - """Perform the forward pass. - - Args: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if - cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total - is the is the sum of the sequence lengths in the batch. - x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. - key_padding_mask: boolean mask, True means to keep, False means to mask out. - (batch, seqlen). Only applicable when not using FlashAttention. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into x. Only applicable when using - FlashAttention. - max_seqlen: int. Maximum sequence length in the batch. - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. - past_cache: For generation only. - - Returns: - (batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None, - else (total, hidden_dim) where total is the is the sum of the sequence lengths - in the batch. - - """ - - if cu_seqlens is not None: - assert max_seqlen is not None - assert key_padding_mask is None - assert self.flash_attn - assert self.rotary_emb_dim == 0 - - if key_padding_mask is not None: - assert cu_seqlens is None - assert max_seqlen is None - assert not self.flash_attn - - if past_cache is not None: - assert key_padding_mask is None - assert cu_seqlens is None and max_seqlen is None - - attn_kwargs = {"key_padding_mask": key_padding_mask} - - assert x_kv is None and mixer_subset is None - qkv = self.Wqkv(x) qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) - if past_cache is None: - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv) - context = self.inner_attn(qkv, **attn_kwargs) + seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0 + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset) + if past_key_values is not None: + kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx) + + if attention_mask is not None: + attention_mask, cu_seqlens, max_seqlen = attention_mask + attention_mask = attention_mask.to(qkv.device) + + attention_kwargs = {"attention_mask": attention_mask} + + if past_key_values is None or seqlen_offset == 0: + if self.checkpointing: + attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs) + else: + attn_output = self.inner_attn(qkv, **attention_kwargs) else: - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset) q = qkv[:, :, 0] - kv = self._update_kv_cache(qkv[:, :, 1:], past_cache) - # If we're processing the prompt, causal=None (use self.causal). - # If we're decoding, then causal=False. - causal = None if past_cache.sequence_len_offset == 0 else False - context = self.inner_cross_attn(q, kv, causal=causal) + causal = None if past_key_values.sequence_len_offset == 0 else False + attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs) - out = rearrange(context, "... h d -> ... (h d)") - out = self.out_proj(out) + output = rearrange(attn_output, "... h d -> ... 
(h d)") + output = self.out_proj(output) + + return output if not self.return_residual else (output, x) - return out if not self.return_residual else (out, x) class ParallelBlock(nn.Module): """Parallel block. @@ -602,8 +572,6 @@ class ParallelBlock(nn.Module): def __init__( self, config: PretrainedConfig, - mixer: Optional[Dict[str, Any]] = None, - mlp: Optional[Dict[str, Any]] = None, block_idx: Optional[int] = None, ) -> None: super().__init__() @@ -612,19 +580,20 @@ class ParallelBlock(nn.Module): self.resid_dropout = nn.Dropout(config.resid_pdrop) self.block_idx = block_idx - self.mixer = MHA(config=config, **mixer, layer_idx=block_idx) - mlp_cls = mlp.pop('mlp_cls') - if mlp_cls == 'fused_mlp': - self.mlp = FusedMLP(config=config, **mlp) - else: - self.mlp = MLP(config=config, **mlp) + self.mixer = MHA(config, layer_idx=block_idx) + self.mlp = MLP(config) - def forward(self, hidden_states: torch.FloatTensor, - past_cache: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + def forward( + self, + hidden_states: torch.FloatTensor, + past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> torch.FloatTensor: residual = hidden_states hidden_states = self.ln(hidden_states) - attn_outputs = self.mixer(hidden_states, past_cache=past_cache) + attn_outputs = self.mixer(hidden_states, past_key_values=past_key_values, attention_mask=attention_mask) if isinstance(attn_outputs, tuple): attn_outputs = attn_outputs[0] @@ -635,6 +604,7 @@ class ParallelBlock(nn.Module): return hidden_states + class CausalLMHead(nn.Module): """Causal Language Modeling head. @@ -666,7 +636,7 @@ class CausalLMLoss(nn.Module): """ - def __init__(self, shift_labels: Optional[bool] = True) -> None: + def __init__(self, shift_labels: bool = True) -> None: super().__init__() self.shift_labels = shift_labels @@ -681,6 +651,7 @@ class CausalLMLoss(nn.Module): return loss + class MixFormerSequentialPreTrainedModel(PreTrainedModel): """MixFormer (sequential for DeepSpeed) pre-trained model.""" @@ -691,9 +662,35 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs) -> None: super().__init__(*inputs, **kwargs) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs) -> Dict[str, Any]: - if "use_cache" in kwargs and not kwargs["use_cache"]: - return {"input_ids": input_ids} + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, (nn.Linear,)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, + attention_mask: Optional[torch.BoolTensor] = None, + **kwargs, + ) -> Dict[str, Any]: + if attention_mask is not None and torch.any(~attention_mask.bool()): + total_seq_len = torch.sum(attention_mask, dim=1) + max_seq_len = torch.max(total_seq_len) + + total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1) + cumulative_seq_len = torch.cumsum(total_seq_len, 
dim=0).squeeze(1).to(torch.int32) + attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item()) + else: + attention_mask = None if past_key_values is None or not (isinstance(past_key_values, InferenceParams)): past_key_values = InferenceParams( @@ -705,11 +702,15 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel): key_value_memory_dict={}, ) else: - # assume past_key_values has cached all but last token in input_ids + # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids` past_key_values.sequence_len_offset = len(input_ids[0]) - 1 input_ids = input_ids[:, -1].unsqueeze(-1) - return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs} + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + } class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel): @@ -723,23 +724,7 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel): super().__init__(config) modules = [Embedding(config)] - block_config = config.architecture - - if not isinstance(block_config, list): - block_config = [block_config for _ in range(config.n_layer)] - - if config.n_layer != len(block_config): - config.n_layer = len(block_config) - - for block_idx, block in enumerate(block_config): - # `block_cls` with `legacy` value is for backward compatibility - # `path` key is for backward compatibility - block = copy.deepcopy(block) or {"block_cls": "parallel"} - block_cls = block.pop("path", None) or block.pop("block_cls", None) - - block["block_idx"] = block_idx - modules.append(ParallelBlock(config, **block)) - + modules += [ParallelBlock(config, block_idx=i) for i in range(config.n_layer)] modules.append(CausalLMHead(config)) self.layers = nn.Sequential(*modules) @@ -760,20 +745,26 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel): self.layers[-1].linear = new_embeddings def forward( - self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, - past_key_values: Optional[torch.FloatTensor] = None, **kwargs + self, + input_ids: torch.LongTensor, + past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, + attention_mask: Optional[torch.BoolTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs, ) -> CausalLMOutputWithPast: + if attention_mask is not None and self.training: + raise ValueError("`attention_mask` is not supported during training.") - if not past_key_values: + if past_key_values is None and attention_mask is None: lm_logits = self.layers(input_ids) else: hidden_layer = self.layers[0](input_ids) for module in self.layers[1:-1]: - hidden_layer = module(hidden_layer, past_cache=past_key_values) + hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask) lm_logits = self.layers[-1](hidden_layer) loss = None if labels is not None: loss = self.loss(lm_logits, labels) - + return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
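
With this change, `attention_mask` can be passed through `generate` (beam search remains unsupported, and the forward pass rejects the mask during training). A minimal usage sketch; the checkpoint id `microsoft/phi-1_5`, the reuse of the EOS token for padding, and the prompts are illustrative assumptions, not part of this diff:

```python
# Hypothetical usage sketch of batched generation with the newly supported
# `attention_mask`. The checkpoint id and pad-token choice are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # assumed: no pad token is defined by default

model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

prompts = ["def fibonacci(n):", "Alice was so tired when she got back home"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# `generate` forwards the boolean mask to `prepare_inputs_for_generation`;
# when the batch actually contains padding it is converted into the
# (mask, cu_seqlens, max_seqlen) tuple consumed by `MHA.forward`, otherwise
# it is dropped and the plain sequential path is used.
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=64,
    )

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```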
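For reference, the additive masking scheme used by the rewritten `SelfAttention.forward` can be reproduced in isolation. The snippet below is a self-contained sketch with random tensors and illustrative shapes; it mirrors how the boolean key-padding mask and the causal mask become `-10000.0` biases added to the attention scores before the softmax:

```python
# Minimal sketch of the additive masking in `SelfAttention.forward` above.
# Tensors and shapes are illustrative only; the -10000.0 constant and the
# einsum/rearrange patterns follow the code in this diff.
import math
import torch
from einops import rearrange

batch_size, seq_len, n_head, head_dim = 2, 5, 4, 8
q = torch.randn(batch_size, seq_len, n_head, head_dim)
k = torch.randn(batch_size, seq_len, n_head, head_dim)

# Boolean mask: True = keep, False = padding (here the last two tokens of sample 0).
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
attention_mask[0, -2:] = False

softmax_scale = 1.0 / math.sqrt(head_dim)
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)

# Padding bias: 0 where a key is kept, -10000 where it is masked out.
padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype)
padding_mask.masked_fill_(attention_mask, 0.0)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

# Causal bias: upper-triangular -10000 so position t only attends to s <= t.
causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0), 1)
scores = scores + causal_mask

attention = torch.softmax(scores, dim=-1)
print(attention[0, 0, -1])  # padded keys receive (near-)zero attention weight
```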