Upload MixFormerSequentialForCausalLM

parent 07a048efa7
commit d655135ca1
@@ -1,6 +1,36 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+
+# BSD 3-Clause License
+#
+# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from __future__ import annotations
 
 import math
@@ -21,7 +51,8 @@ from .configuration_mixformer_sequential import MixFormerSequentialConfig
 @dataclass
 class InferenceParams:
     """Inference parameters that are passed to the main model in order
-    to efficiently calculate and store the context during inference."""
+    to efficiently calculate and store the context during inference.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
     max_sequence_len: int
     max_batch_size: int
     sequence_len_offset: int = 0
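A note for context: the dataclass in this hunk is pure bookkeeping for incremental decoding. It records how large the KV cache may grow (`max_sequence_len`, `max_batch_size`) and how far into the sequence generation currently is (`sequence_len_offset`). A trimmed-down sketch keeping only the fields visible in this diff (the `key_value_memory_dict` declaration is an assumption; the attribute is used, without its declaration, in the `_update_kv_cache` hunk further down):

from dataclasses import dataclass, field

@dataclass
class InferenceParams:
    max_sequence_len: int         # capacity of the pre-allocated KV cache
    max_batch_size: int
    sequence_len_offset: int = 0  # number of tokens already cached
    # Assumed declaration: maps layer_idx -> pre-allocated KV buffer.
    key_value_memory_dict: dict = field(default_factory=dict)

params = InferenceParams(max_sequence_len=2048, max_batch_size=1)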
@@ -50,7 +81,8 @@ class Embedding(nn.Module):
         return hidden_states
 
 class RotaryEmbedding(nn.Module):
-    """PyTorch implementation of `flash-attn` RotaryEmbedding layer."""
+    """PyTorch implementation of `flash-attn` RotaryEmbedding layer.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
 
     def __init__(
         self,
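For readers unfamiliar with the layer being credited here: rotary embeddings encode position by rotating each pair of query/key channels through a position-dependent angle, so that attention dot products depend only on relative offsets. A minimal, self-contained sketch of the core computation in the interleaved-pair formulation (illustrative only; the actual layer in this file adds angle caching and dtype handling):

import torch

def apply_rotary(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of x by position-dependent angles.
    x: (batch, seqlen, nheads, head_dim) with head_dim even."""
    b, s, h, d = x.shape
    # One frequency per channel pair, one angle per (position, pair).
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32, device=x.device) / d))
    t = torch.arange(s, dtype=torch.float32, device=x.device)
    freqs = torch.outer(t, inv_freq)        # (seqlen, head_dim / 2)
    cos = freqs.cos()[None, :, None, :]     # broadcast over batch and heads
    sin = freqs.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]     # interleaved channel pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin    # 2D rotation of each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out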
@@ -187,7 +219,7 @@ class RotaryEmbedding(nn.Module):
 
 def _update_kv_cache(kv, inference_params, layer_idx):
     """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-    """
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
     # Pre-allocate memory for key-values for inference.
     num_heads, head_dim = kv.shape[-2:]
     if layer_idx not in inference_params.key_value_memory_dict:
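The function touched here implements the standard pre-allocated KV cache: one large buffer per layer, allocated on first use and filled in place at the current sequence offset. A simplified standalone version of that pattern, usable with the `InferenceParams` sketch above (the real helper also handles batch offsets and extra validation):

import torch

def update_kv_cache_sketch(kv, inference_params, layer_idx):
    """Simplified pattern; kv: (batch_size, seqlen, 2, nheads, head_dim)."""
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        # One allocation per layer for the whole generation run.
        inference_params.key_value_memory_dict[layer_idx] = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_sequence_len,
            2, num_heads, head_dim,
            dtype=kv.dtype, device=kv.device,
        )
    cache = inference_params.key_value_memory_dict[layer_idx]
    batch, seqlen = kv.shape[:2]
    start = inference_params.sequence_len_offset
    cache[:batch, start:start + seqlen] = kv   # write the new step(s) in place
    return cache[:batch, :start + seqlen]      # full cached context so far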
@@ -281,6 +313,7 @@ class FusedMLP(nn.Module):
 
 class SelfAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
+    Adapted from https://github.com/Dao-AILab/flash-attention.
     Arguments
     ---------
     softmax_scale: The temperature to use for the softmax attention.
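The docstring describes exactly scaled dot-product attention, with `softmax_scale` acting as a temperature on the attention logits (conventionally defaulting to 1/sqrt(head_dim)). Its core, stripped of the dropout and key-padding-mask handling the real module supports:

import math
import torch

def scaled_dot_product_attention(qkv, softmax_scale=None, causal=True):
    """qkv: (batch, seqlen, 3, nheads, head_dim) packed queries/keys/values."""
    q, k, v = qkv.unbind(dim=2)
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bshd,bthd->bhst", q, k) * softmax_scale
    if causal:
        # Each position may only attend to itself and earlier positions.
        mask = torch.triu(
            torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", attn, v)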
@@ -329,6 +362,7 @@ class SelfAttention(nn.Module):
 
 class CrossAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
+    Adapted from https://github.com/Dao-AILab/flash-attention.
     Arguments
    ---------
     softmax_scale: The temperature to use for the softmax attention.
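`CrossAttention` shares this docstring because it computes the same quantity; the difference is that queries and key/values arrive as separate tensors, which is what lets it attend over a cached `kv` during decoding. The same core with a split signature (causal masking omitted for brevity):

import math
import torch

def cross_attention(q, kv, softmax_scale=None):
    """q: (batch, q_len, nheads, head_dim); kv: (batch, kv_len, 2, nheads, head_dim)."""
    k, v = kv.unbind(dim=2)
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bshd,bthd->bhst", q, k) * softmax_scale
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhst,bthd->bshd", attn, v)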
@@ -412,7 +446,8 @@ def find_mha_dims(
 
 
 class MHA(nn.Module):
-    """Multi-head attention layer."""
+    """Multi-head attention layer.
+    Adapted from https://github.com/Dao-AILab/flash-attention."""
 
     def __init__(
         self,
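At a shape level, a multi-head attention layer of this kind projects the input to a packed QKV tensor, applies rotary embeddings, runs one of the inner attention modules above, and projects back to the model dimension. A rough sketch under those assumptions (the layer names here are hypothetical, not this file's actual attributes; it reuses `scaled_dot_product_attention` from the earlier sketch):

import torch
import torch.nn as nn

class TinyMHA(nn.Module):
    """Shape-level sketch only, not the module's actual code."""

    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        self.n_head, self.head_dim = n_head, d_model // n_head
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)  # hypothetical name
        self.out_proj = nn.Linear(d_model, d_model)      # hypothetical name

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        # (batch, seqlen, 3, nheads, head_dim): packed queries, keys, values.
        qkv = self.qkv_proj(x).view(b, s, 3, self.n_head, self.head_dim)
        # Rotary embeddings and KV-cache updates would be applied to qkv here.
        ctx = scaled_dot_product_attention(qkv, causal=True)
        return self.out_proj(ctx.reshape(b, s, -1))      # merge heads back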
@@ -472,7 +507,8 @@ class MHA(nn.Module):
         self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
 
     def _update_kv_cache(self, kv: torch.FloatTensor, inference_params: InferenceParams) -> None:
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
+        Adapted from https://github.com/Dao-AILab/flash-attention."""
 
         assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
 
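Putting the pieces together: `InferenceParams` plus the per-layer KV cache supports the usual two-phase decoding loop, prefilling the whole prompt once and then feeding a single token per step while advancing `sequence_len_offset`. Schematically, with a stand-in forward pass (hypothetical, not this file's API), reusing the `InferenceParams` sketch from above:

import torch

def model(ids, inference_params):            # stand-in for the real forward pass
    batch, seqlen = ids.shape
    return torch.randn(batch, seqlen, 100)   # (batch, seqlen, vocab_size)

params = InferenceParams(max_sequence_len=2048, max_batch_size=1)
prompt_ids = torch.tensor([[1, 2, 3]])

logits = model(prompt_ids, inference_params=params)        # prefill: cache prompt KV
params.sequence_len_offset += prompt_ids.shape[1]

for _ in range(8):                                         # decode one token at a time
    next_id = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy next token
    logits = model(next_id, inference_params=params)       # reuses cached context
    params.sequence_len_offset += 1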