From d655135ca1783378b889f3b52d5281c2119e85ed Mon Sep 17 00:00:00 2001
From: Gunasekar
Date: Mon, 11 Sep 2023 21:30:53 +0000
Subject: [PATCH] Upload MixFormerSequentialForCausalLM

---
 modeling_mixformer_sequential.py | 46 ++++++++++++++++++++++++++++----
 1 file changed, 41 insertions(+), 5 deletions(-)

diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py
index 356d213..7c10e73 100644
--- a/modeling_mixformer_sequential.py
+++ b/modeling_mixformer_sequential.py
@@ -1,6 +1,36 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+# BSD 3-Clause License
+#
+# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 from __future__ import annotations
 
 import math
@@ -21,7 +51,8 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig
 @dataclass
 class InferenceParams:
     """Inference parameters that are passed to the main model in order
-    to efficienly calculate and store the context during inference."""
+    to efficiently calculate and store the context during inference.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
     max_sequence_len: int
     max_batch_size: int
     sequence_len_offset: int = 0
@@ -50,7 +81,8 @@ class Embedding(nn.Module):
         return hidden_states
 
 class RotaryEmbedding(nn.Module):
-    """PyTorch implementation of `flash-attn` RotaryEmbedding layer."""
+    """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
 
     def __init__(
         self,
@@ -187,7 +219,7 @@
 
 def _update_kv_cache(kv, inference_params, layer_idx):
     """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-    """
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
     # Pre-allocate memory for key-values for inference.
     num_heads, head_dim = kv.shape[-2:]
     if layer_idx not in inference_params.key_value_memory_dict:
@@ -281,6 +313,7 @@ class FusedMLP(nn.Module):
 
 class SelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
+    Adapted from https://github.com/Dao-AILab/flash-attention.
     Arguments
     ---------
         softmax_scale: The temperature to use for the softmax attention.
@@ -329,6 +362,7 @@ class SelfAttention(nn.Module):
 
 class CrossAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
+    Adapted from https://github.com/Dao-AILab/flash-attention.
     Arguments
     ---------
         softmax_scale: The temperature to use for the softmax attention.
@@ -412,7 +446,8 @@ def find_mha_dims(
 
 
 class MHA(nn.Module):
-    """Multi-head attention layer."""
+    """Multi-head attention layer.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
 
     def __init__(
         self,
@@ -472,7 +507,8 @@ class MHA(nn.Module):
         self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
 
     def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
+        Adapted from https://github.com/Dao-AILab/flash-attention."""
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
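
Note for readers of this patch: the SelfAttention and CrossAttention docstrings touched above describe scaled dot-product attention in which softmax_scale acts as the softmax temperature, defaulting to 1/sqrt(head_dim) when unset. Below is a minimal PyTorch sketch of that computation for orientation only; the function name is hypothetical, it assumes a packed (batch_size, seqlen, 3, nheads, head_dim) qkv layout like SelfAttention's, and it omits the dropout and key-padding masking the actual flash-attn-derived classes handle.

import math
from typing import Optional

import torch


def scaled_dot_product_attention(
    qkv: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> torch.Tensor:
    # qkv: (batch_size, seqlen, 3, nheads, head_dim), packed as in SelfAttention.
    q, k, v = qkv.unbind(dim=2)
    # Default temperature: 1/sqrt(head_dim), the usual scaled-dot-product choice.
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    # (batch_size, nheads, seqlen_q, seqlen_k) attention scores.
    scores = torch.einsum("bshd,bthd->bhst", q, k) * softmax_scale
    if causal:
        # Upper-triangular -inf mask: position s may not attend to positions t > s.
        mask = torch.triu(
            torch.full(scores.shape[-2:], float("-inf"), device=scores.device),
            diagonal=1,
        )
        scores = scores + mask
    attention = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", attention, v)


# Example: qkv for batch=1, seqlen=8, nheads=4, head_dim=16 -> output (1, 8, 4, 16).
qkv = torch.randn(1, 8, 3, 4, 16)
out = scaled_dot_product_attention(qkv, causal=True)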