Adds support for MQA/GQA and attention mask during training.
This commit is contained in:
parent
d38e6f954e
commit
de35f900d3
@ -127,7 +127,7 @@ with torch.autocast(model.device.type, dtype=torch.float16, enabled=True):
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
|
**Remark.** In the generation function, our model currently does not support beam search (`num_beams` > 1).
|
||||||
Furthermore, in the forward pass of the model, we currently do not support attention mask during training, outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
|
Furthermore, in the forward pass of the model, we currently do not support outputting hidden states or attention values, or using custom input embeddings (instead of the model's).
|
||||||
|
|
||||||
### Citation
|
### Citation
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# Licensed under the MIT license.
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
@ -27,6 +27,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
|
|||||||
n_layer: Optional[int] = 20,
|
n_layer: Optional[int] = 20,
|
||||||
n_inner: Optional[int] = None,
|
n_inner: Optional[int] = None,
|
||||||
n_head: Optional[int] = 16,
|
n_head: Optional[int] = 16,
|
||||||
|
n_head_kv: Optional[int] = None,
|
||||||
rotary_dim: Optional[int] = 32,
|
rotary_dim: Optional[int] = 32,
|
||||||
activation_function: Optional[str] = "gelu_new",
|
activation_function: Optional[str] = "gelu_new",
|
||||||
embd_pdrop: Optional[float] = 0.0,
|
embd_pdrop: Optional[float] = 0.0,
|
||||||
@ -43,6 +44,7 @@ class MixFormerSequentialConfig(PretrainedConfig):
|
|||||||
self.n_layer = n_layer
|
self.n_layer = n_layer
|
||||||
self.n_inner = n_inner
|
self.n_inner = n_inner
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
|
self.n_head_kv = n_head_kv
|
||||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||||
self.activation_function = activation_function
|
self.activation_function = activation_function
|
||||||
self.embd_pdrop = embd_pdrop
|
self.embd_pdrop = embd_pdrop
|
||||||
|
|||||||
@ -34,20 +34,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import copy
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from einops import rearrange
|
from einops import rearrange, repeat
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
from .configuration_mixformer_sequential import MixFormerSequentialConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InferenceParams:
|
class InferenceParams:
|
||||||
"""Inference parameters passed to model to efficiently calculate
|
"""Inference parameters passed to model to efficiently calculate
|
||||||
@ -57,21 +57,20 @@ class InferenceParams:
|
|||||||
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
|
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_sequence_len: Maximum sequence length.
|
max_seqlen: Maximum sequence length.
|
||||||
max_batch_size: Maximum batch size.
|
max_batch_size: Maximum batch size.
|
||||||
sequence_len_offset: Sequence length offset.
|
seqlen_offset: Sequence length offset.
|
||||||
batch_size_offset: Batch size offset.
|
batch_size_offset: Batch size offset.
|
||||||
key_value_memory_dict: Key value memory dictionary.
|
key_value_memory_dict: Key value memory dictionary.
|
||||||
fused_ft_kernel: Whether to use fused kernel for fast inference.
|
|
||||||
lengths_per_sample: Lengths per sample.
|
lengths_per_sample: Lengths per sample.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_sequence_len: int = field(metadata={"help": "Maximum sequence length."})
|
max_seqlen: int = field(metadata={"help": "Maximum sequence length."})
|
||||||
|
|
||||||
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
|
max_batch_size: int = field(metadata={"help": "Maximum batch size."})
|
||||||
|
|
||||||
sequence_len_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
|
seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."})
|
||||||
|
|
||||||
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
|
batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."})
|
||||||
|
|
||||||
@ -79,8 +78,6 @@ class InferenceParams:
|
|||||||
default_factory=dict, metadata={"help": "Key value memory dictionary."}
|
default_factory=dict, metadata={"help": "Key value memory dictionary."}
|
||||||
)
|
)
|
||||||
|
|
||||||
fused_ft_kernel: bool = field(default=False, metadata={"help": "Whether to use fused kernel for fast inference."})
|
|
||||||
|
|
||||||
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
|
lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."})
|
||||||
|
|
||||||
|
|
||||||
@ -103,11 +100,111 @@ class Embedding(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rotary_emb(
|
||||||
|
x: torch.FloatTensor,
|
||||||
|
cos: torch.FloatTensor,
|
||||||
|
sin: torch.FloatTensor,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
_, seqlen, _, head_dim = x.shape
|
||||||
|
rotary_seqlen, rotary_dim = cos.shape
|
||||||
|
rotary_dim *= 2
|
||||||
|
|
||||||
|
assert rotary_dim <= head_dim
|
||||||
|
assert seqlen <= rotary_seqlen
|
||||||
|
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||||
|
|
||||||
|
x_rot = x[:, :, :, :rotary_dim]
|
||||||
|
x_pass = x[:, :, :, rotary_dim:]
|
||||||
|
|
||||||
|
x1, x2 = x_rot.chunk(2, dim=-1)
|
||||||
|
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
||||||
|
x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
|
||||||
|
|
||||||
|
x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
|
||||||
|
|
||||||
|
return torch.cat([x_rot, x_pass], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rotary_emb_kv(
|
||||||
|
kv: torch.FloatTensor,
|
||||||
|
cos: torch.FloatTensor,
|
||||||
|
sin: torch.FloatTensor,
|
||||||
|
cos_k: Optional[torch.FloatTensor] = None,
|
||||||
|
sin_k: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
_, seqlen, two, _, head_dim = kv.shape
|
||||||
|
assert two == 2
|
||||||
|
|
||||||
|
rotary_seqlen, rotary_dim = cos.shape
|
||||||
|
rotary_dim *= 2
|
||||||
|
assert rotary_dim <= head_dim
|
||||||
|
assert seqlen <= rotary_seqlen
|
||||||
|
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||||
|
|
||||||
|
k_rot = kv[:, :, 0, :, :rotary_dim]
|
||||||
|
k_pass = kv[:, :, 0, :, rotary_dim:]
|
||||||
|
|
||||||
|
k1, k2 = k_rot.chunk(2, dim=-1)
|
||||||
|
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
||||||
|
k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
|
||||||
|
|
||||||
|
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
|
||||||
|
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
||||||
|
kv[:, :, 1:2, :, :],
|
||||||
|
],
|
||||||
|
axis=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rotary_emb_qkv(
|
||||||
|
qkv: torch.FloatTensor,
|
||||||
|
cos: torch.FloatTensor,
|
||||||
|
sin: torch.FloatTensor,
|
||||||
|
cos_k: Optional[torch.FloatTensor] = None,
|
||||||
|
sin_k: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
_, seqlen, three, _, head_dim = qkv.shape
|
||||||
|
assert three == 3
|
||||||
|
|
||||||
|
rotary_seqlen, rotary_dim = cos.shape
|
||||||
|
rotary_dim *= 2
|
||||||
|
assert rotary_dim <= head_dim
|
||||||
|
assert seqlen <= rotary_seqlen
|
||||||
|
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||||
|
|
||||||
|
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
||||||
|
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
||||||
|
|
||||||
|
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
||||||
|
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
||||||
|
|
||||||
|
q1, q2 = q_rot.chunk(2, dim=-1)
|
||||||
|
k1, k2 = k_rot.chunk(2, dim=-1)
|
||||||
|
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
||||||
|
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
||||||
|
|
||||||
|
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
||||||
|
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
||||||
|
|
||||||
|
return torch.cat(
|
||||||
|
[
|
||||||
|
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
||||||
|
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
||||||
|
qkv[:, :, 2:3, :, :],
|
||||||
|
],
|
||||||
|
axis=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
"""Rotary embeddings.
|
"""Rotary positional embedding (RoPE).
|
||||||
|
|
||||||
Reference:
|
Reference:
|
||||||
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
|
RoFormer: Enhanced Transformer with Rotary Position Embedding.
|
||||||
|
https://arxiv.org/pdf/2104.09864.pdf.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -131,14 +228,14 @@ class RotaryEmbedding(nn.Module):
|
|||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
scale = (
|
scale = (
|
||||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
||||||
if scale_base is not None
|
if scale_base is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.register_buffer("scale", scale)
|
self.register_buffer("scale", scale, persistent=False)
|
||||||
|
|
||||||
self._seq_len_cached = 0
|
self._seq_len_cached = 0
|
||||||
self._cos_cached = None
|
self._cos_cached = None
|
||||||
@ -146,28 +243,26 @@ class RotaryEmbedding(nn.Module):
|
|||||||
self._cos_k_cached = None
|
self._cos_k_cached = None
|
||||||
self._sin_k_cached = None
|
self._sin_k_cached = None
|
||||||
|
|
||||||
def _update_cos_sin_cache(self, x: torch.FloatTensor, seqlen_offset: int = 0) -> None:
|
def _update_cos_sin_cache(
|
||||||
# Reset the tables if the sequence length has changed,
|
self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
|
||||||
# or if we're on a new device (possibly due to tracing for instance)
|
) -> None:
|
||||||
seqlen = x.shape[1] + seqlen_offset
|
|
||||||
|
|
||||||
# Re-generate the inverse frequency buffer if it's not fp32
|
# Re-generate the inverse frequency buffer if it's not fp32
|
||||||
# (for instance if model.half() was called)
|
# (for instance if model.half() was called)
|
||||||
if self.inv_freq.dtype != "torch.float32":
|
if self.inv_freq.dtype != "torch.float32":
|
||||||
self.inv_freq = 1.0 / (
|
self.inv_freq = 1.0 / (
|
||||||
self.base ** (torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) / self.dim)
|
self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
|
if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
|
||||||
self._seq_len_cached = seqlen
|
self._seq_len_cached = seqlen
|
||||||
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
|
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
# Don't do einsum, it converts fp32 to fp16
|
# Don't do einsum, it converts fp32 to fp16
|
||||||
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
|
freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
|
||||||
if self.scale is None:
|
if self.scale is None:
|
||||||
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||||
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||||
else:
|
else:
|
||||||
power = (
|
power = (
|
||||||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
||||||
@ -175,62 +270,32 @@ class RotaryEmbedding(nn.Module):
|
|||||||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
||||||
|
|
||||||
# We want the multiplication by scale to happen in fp32
|
# We want the multiplication by scale to happen in fp32
|
||||||
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||||
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
||||||
|
|
||||||
def _apply_rotary_emb_qkv(
|
def forward(
|
||||||
self,
|
self,
|
||||||
qkv: torch.FloatTensor,
|
qkv: torch.Tensor,
|
||||||
sin: torch.FloatTensor,
|
kv: Optional[torch.Tensor] = None,
|
||||||
cos: torch.FloatTensor,
|
seqlen_offset: int = 0,
|
||||||
sin_k: Optional[torch.FloatTensor] = None,
|
max_seqlen: Optional[int] = None,
|
||||||
cos_k: Optional[torch.FloatTensor] = None,
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
) -> torch.FloatTensor:
|
seqlen = qkv.shape[1]
|
||||||
_, seqlen, three, _, headdim = qkv.shape
|
|
||||||
assert three == 3
|
|
||||||
|
|
||||||
rotary_seqlen, rotary_dim = cos.shape
|
if max_seqlen is not None:
|
||||||
rotary_dim *= 2
|
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
||||||
assert rotary_dim <= headdim
|
else:
|
||||||
assert seqlen <= rotary_seqlen
|
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
||||||
|
|
||||||
cos_k = cos if cos_k is None else cos_k
|
if kv is None:
|
||||||
sin_k = sin if sin_k is None else sin_k
|
return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
||||||
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
else:
|
||||||
|
q = _apply_rotary_emb(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
||||||
|
kv = _apply_rotary_emb_kv(kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
||||||
|
|
||||||
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
return q, kv
|
||||||
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
|
||||||
|
|
||||||
k_rot = qkv[:, :, 1, :, :rotary_dim]
|
|
||||||
k_pass = qkv[:, :, 1, :, rotary_dim:]
|
|
||||||
|
|
||||||
# Splits the queries and keys in half
|
|
||||||
q1, q2 = q_rot.chunk(2, dim=-1)
|
|
||||||
k1, k2 = k_rot.chunk(2, dim=-1)
|
|
||||||
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
|
|
||||||
|
|
||||||
# Casts to fp32 are necessary to prevent fp16 overflow issues
|
|
||||||
q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
|
|
||||||
|
|
||||||
# Computes the new keys and queries, recasting to original dtype
|
|
||||||
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
|
|
||||||
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
|
|
||||||
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
|
|
||||||
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
|
|
||||||
qkv[:, :, 2:3, :, :],
|
|
||||||
],
|
|
||||||
axis=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
# `qkv` is of shape (batch, seqlen, 3, nheads, headdim)
|
|
||||||
self._update_cos_sin_cache(qkv, seqlen_offset)
|
|
||||||
return self._apply_rotary_emb_qkv(qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:])
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
@ -290,21 +355,22 @@ class SelfAttention(nn.Module):
|
|||||||
attention_mask: Optional[torch.BoolTensor] = None,
|
attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
causal = self.causal if causal is None else causal
|
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
||||||
batch_size, seq_len = qkv.shape[0], qkv.shape[1]
|
|
||||||
q, k, v = qkv.unbind(dim=2)
|
q, k, v = qkv.unbind(dim=2)
|
||||||
|
|
||||||
|
causal = self.causal if causal is None else causal
|
||||||
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
||||||
|
|
||||||
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
padding_mask = torch.full((batch_size, seq_len), -10000.0, dtype=scores.dtype, device=scores.device)
|
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
|
||||||
padding_mask.masked_fill_(attention_mask, 0.0)
|
padding_mask.masked_fill_(attention_mask, 0.0)
|
||||||
|
|
||||||
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
||||||
|
|
||||||
if causal:
|
if causal:
|
||||||
causal_mask = torch.triu(torch.full((seq_len, seq_len), -10000.0, device=scores.device), 1)
|
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
|
||||||
scores = scores + causal_mask.to(dtype=scores.dtype)
|
scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||||
|
|
||||||
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
||||||
@ -343,25 +409,31 @@ class CrossAttention(nn.Module):
|
|||||||
attention_mask: Optional[torch.BoolTensor] = None,
|
attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
causal = self.causal if causal is None else causal
|
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
||||||
batch_size, seq_len_q = q.shape[0], q.shape[1]
|
seqlen_k = kv.shape[1]
|
||||||
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
|
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
||||||
|
|
||||||
seq_len_k = kv.shape[1]
|
if kv.shape[3] != q.shape[2]:
|
||||||
|
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
||||||
k, v = kv.unbind(dim=2)
|
k, v = kv.unbind(dim=2)
|
||||||
|
|
||||||
|
causal = self.causal if causal is None else causal
|
||||||
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
||||||
|
|
||||||
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
padding_mask = torch.full((batch_size, seq_len_k), -10000.0, dtype=scores.dtype, device=scores.device)
|
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
|
||||||
padding_mask.masked_fill_(attention_mask, 0.0)
|
padding_mask.masked_fill_(attention_mask, 0.0)
|
||||||
|
|
||||||
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
||||||
|
|
||||||
if causal:
|
if causal:
|
||||||
causal_mask = torch.triu(torch.full((seq_len_q, seq_len_k), -10000.0, device=scores.device), 1)
|
rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
|
||||||
scores = scores + causal_mask.to(dtype=scores.dtype)
|
cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
|
||||||
|
causal_mask = cols > rows + seqlen_k - seqlen_q
|
||||||
|
|
||||||
|
scores = scores.masked_fill(causal_mask, -10000.0)
|
||||||
|
|
||||||
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
||||||
attention = self.drop(attention)
|
attention = self.drop(attention)
|
||||||
@ -371,21 +443,12 @@ class CrossAttention(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def find_mha_dims(
|
def _find_mha_dims(
|
||||||
config: PretrainedConfig, n_head: Optional[int] = None, head_dim: Optional[int] = None
|
config: PretrainedConfig,
|
||||||
|
n_head: Optional[int] = None,
|
||||||
|
n_head_kv: Optional[int] = None,
|
||||||
|
head_dim: Optional[int] = None,
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
"""Validate and return the number of heads and head dimension for multi-head attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Model configuration.
|
|
||||||
n_head: Number of heads.
|
|
||||||
head_dim: Head dimension.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of heads and head dimension.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert all(
|
assert all(
|
||||||
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
||||||
), "`config` must have `n_embd` and `n_head` attributes."
|
), "`config` must have `n_embd` and `n_head` attributes."
|
||||||
@ -401,31 +464,20 @@ def find_mha_dims(
|
|||||||
elif n_head is None or head_dim is None:
|
elif n_head is None or head_dim is None:
|
||||||
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
|
||||||
|
|
||||||
return n_head, head_dim
|
if n_head_kv is None:
|
||||||
|
n_head_kv = getattr(config, "n_head_kv", None) or n_head
|
||||||
|
assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
|
||||||
|
|
||||||
|
return n_head, n_head_kv, head_dim
|
||||||
|
|
||||||
|
|
||||||
def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
|
def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
|
||||||
"""Update the key-value cache for inference.
|
|
||||||
|
|
||||||
Reference:
|
|
||||||
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
kv: Key-value tensor.
|
|
||||||
inference_params: Inference parameters.
|
|
||||||
layer_idx: Layer index.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated key-value tensor.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_heads, head_dim = kv.shape[-2:]
|
num_heads, head_dim = kv.shape[-2:]
|
||||||
|
|
||||||
if layer_idx not in inference_params.key_value_memory_dict:
|
if layer_idx not in inference_params.key_value_memory_dict:
|
||||||
kv_cache = torch.empty(
|
kv_cache = torch.empty(
|
||||||
inference_params.max_batch_size,
|
inference_params.max_batch_size,
|
||||||
inference_params.max_sequence_len,
|
inference_params.max_seqlen,
|
||||||
2,
|
2,
|
||||||
num_heads,
|
num_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
@ -434,43 +486,19 @@ def update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, la
|
|||||||
)
|
)
|
||||||
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
||||||
else:
|
else:
|
||||||
if not inference_params.fused_ft_kernel:
|
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
||||||
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
|
||||||
else:
|
|
||||||
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
|
|
||||||
kv_cache = None
|
|
||||||
|
|
||||||
batch_start = inference_params.batch_size_offset
|
batch_start = inference_params.batch_size_offset
|
||||||
batch_end = batch_start + kv.shape[0]
|
batch_end = batch_start + kv.shape[0]
|
||||||
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
|
assert batch_end <= kv_cache.shape[0]
|
||||||
|
|
||||||
sequence_start = inference_params.sequence_len_offset
|
sequence_start = inference_params.seqlen_offset
|
||||||
sequence_end = sequence_start + kv.shape[1]
|
sequence_end = sequence_start + kv.shape[1]
|
||||||
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
|
assert sequence_end <= kv_cache.shape[1]
|
||||||
|
|
||||||
if not inference_params.fused_ft_kernel:
|
assert kv_cache is not None
|
||||||
assert kv_cache is not None
|
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||||
|
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
|
||||||
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
|
||||||
|
|
||||||
return kv
|
|
||||||
|
|
||||||
assert inference_params.sequence_len_offset == 0
|
|
||||||
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
|
||||||
|
|
||||||
packsize = 4 if kv.dtype == torch.float32 else 8
|
|
||||||
|
|
||||||
if kv_cache is not None:
|
|
||||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
|
||||||
k_cache = rearrange(kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize).contiguous()
|
|
||||||
v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
|
|
||||||
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
|
|
||||||
else:
|
|
||||||
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
|
|
||||||
kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
|
|
||||||
)
|
|
||||||
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d")
|
|
||||||
|
|
||||||
return kv
|
return kv
|
||||||
|
|
||||||
@ -486,6 +514,7 @@ class MHA(nn.Module):
|
|||||||
rotary_dim: Optional[int] = None,
|
rotary_dim: Optional[int] = None,
|
||||||
rotary_emb_scale_base: Optional[float] = None,
|
rotary_emb_scale_base: Optional[float] = None,
|
||||||
n_head: Optional[int] = None,
|
n_head: Optional[int] = None,
|
||||||
|
n_head_kv: Optional[int] = None,
|
||||||
head_dim: Optional[int] = None,
|
head_dim: Optional[int] = None,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
@ -506,12 +535,12 @@ class MHA(nn.Module):
|
|||||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
|
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
|
||||||
|
|
||||||
# MLP
|
# MLP
|
||||||
self.n_head, self.head_dim = find_mha_dims(config, n_head, head_dim)
|
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
|
||||||
op_size = self.n_head * self.head_dim
|
op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
|
||||||
hidden_size = config.n_embd
|
hidden_size = config.n_embd
|
||||||
|
|
||||||
self.Wqkv = nn.Linear(hidden_size, 3 * op_size, bias=bias, device=device, dtype=dtype)
|
self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
|
||||||
self.out_proj = nn.Linear(op_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
# Attention
|
# Attention
|
||||||
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
|
||||||
@ -521,40 +550,75 @@ class MHA(nn.Module):
|
|||||||
self.return_residual = return_residual
|
self.return_residual = return_residual
|
||||||
self.checkpointing = checkpointing
|
self.checkpointing = checkpointing
|
||||||
|
|
||||||
|
def _forward_self_attn(
|
||||||
|
self, x: torch.FloatTensor, attention_mask: Optional[torch.BoolTensor]
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
qkv = self.Wqkv(x)
|
||||||
|
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
||||||
|
|
||||||
|
if self.rotary_emb_dim > 0:
|
||||||
|
qkv = self.rotary_emb(qkv)
|
||||||
|
|
||||||
|
if self.checkpointing:
|
||||||
|
return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
return self.inner_attn(qkv, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
def _forward_cross_attn(
|
||||||
|
self,
|
||||||
|
x: torch.FloatTensor,
|
||||||
|
past_key_values: Optional[InferenceParams],
|
||||||
|
attention_mask: Optional[torch.BoolTensor],
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
qkv = self.Wqkv(x)
|
||||||
|
|
||||||
|
q = qkv[..., : self.n_head * self.head_dim]
|
||||||
|
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
||||||
|
|
||||||
|
kv = qkv[..., self.n_head * self.head_dim :]
|
||||||
|
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
||||||
|
|
||||||
|
seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0
|
||||||
|
causal = None if seqlen_offset == 0 else False
|
||||||
|
if self.rotary_emb_dim > 0:
|
||||||
|
q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset)
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
kv = _update_kv_cache(kv, past_key_values, self.layer_idx)
|
||||||
|
|
||||||
|
if self.checkpointing:
|
||||||
|
return torch.utils.checkpoint.checkpoint(
|
||||||
|
self.inner_cross_attn, q, kv, attention_mask=attention_mask, causal=causal
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.inner_cross_attn(q, kv, attention_mask=attention_mask, causal=causal)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.FloatTensor,
|
x: torch.FloatTensor,
|
||||||
past_key_values: Optional[InferenceParams] = None,
|
past_key_values: Optional[InferenceParams] = None,
|
||||||
attention_mask: Optional[torch.BoolTensor] = None,
|
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
||||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
|
||||||
max_seqlen: Optional[int] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||||
qkv = self.Wqkv(x)
|
if attention_mask is not None and torch.any(~attention_mask.bool()):
|
||||||
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
attention_mask = attention_mask.bool()
|
||||||
|
|
||||||
seqlen_offset = past_key_values.sequence_len_offset if past_key_values is not None else 0
|
|
||||||
if self.rotary_emb_dim > 0:
|
|
||||||
qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
|
|
||||||
|
|
||||||
if past_key_values is not None:
|
|
||||||
kv = update_kv_cache(qkv[:, :, 1:], past_key_values, self.layer_idx)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attention_mask[0] if isinstance(attention_mask, tuple) else attention_mask
|
|
||||||
attention_mask = attention_mask.bool().to(qkv.device)
|
|
||||||
|
|
||||||
attention_kwargs = {"attention_mask": attention_mask}
|
|
||||||
|
|
||||||
if past_key_values is None or seqlen_offset == 0:
|
|
||||||
if self.checkpointing:
|
|
||||||
attn_output = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **attention_kwargs)
|
|
||||||
else:
|
|
||||||
attn_output = self.inner_attn(qkv, **attention_kwargs)
|
|
||||||
else:
|
else:
|
||||||
q = qkv[:, :, 0]
|
attention_mask = None
|
||||||
causal = None if past_key_values.sequence_len_offset == 0 else False
|
|
||||||
attn_output = self.inner_cross_attn(q, kv, causal=causal, **attention_kwargs)
|
# MHA
|
||||||
|
if self.n_head == self.n_head_kv:
|
||||||
|
if past_key_values is None:
|
||||||
|
# If `past_key_values` are not supplied, we run self-attention
|
||||||
|
attn_output = self._forward_self_attn(x, attention_mask)
|
||||||
|
else:
|
||||||
|
# If `past_key_values` are supplied, it means that we might have cached values and
|
||||||
|
# could take advantage of cross-attention
|
||||||
|
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
||||||
|
# MQA / GQA
|
||||||
|
else:
|
||||||
|
# Regardless of `past_key_values` being supplied or not, it always use cross-attention
|
||||||
|
# because `q` and `kv` lengths might be different
|
||||||
|
attn_output = self._forward_cross_attn(x, past_key_values, attention_mask)
|
||||||
|
|
||||||
output = rearrange(attn_output, "... h d -> ... (h d)")
|
output = rearrange(attn_output, "... h d -> ... (h d)")
|
||||||
output = self.out_proj(output)
|
output = self.out_proj(output)
|
||||||
@ -672,38 +736,29 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
elif isinstance(module, nn.LayerNorm):
|
elif isinstance(module, nn.LayerNorm):
|
||||||
module.bias.data.zero_()
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
module.weight.data.fill_(1.0)
|
module.weight.data.fill_(1.0)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
||||||
attention_mask: Optional[torch.BoolTensor] = None,
|
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if attention_mask is not None and torch.any(~attention_mask.bool()):
|
|
||||||
total_seq_len = torch.sum(attention_mask, dim=1)
|
|
||||||
max_seq_len = torch.max(total_seq_len)
|
|
||||||
|
|
||||||
total_seq_len = torch.cat((torch.tensor([0], device=attention_mask.device), total_seq_len)).unsqueeze(1)
|
|
||||||
cumulative_seq_len = torch.cumsum(total_seq_len, dim=0).squeeze(1).to(torch.int32)
|
|
||||||
attention_mask = (attention_mask.bool(), cumulative_seq_len, max_seq_len.item())
|
|
||||||
else:
|
|
||||||
attention_mask = None
|
|
||||||
|
|
||||||
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
||||||
past_key_values = InferenceParams(
|
past_key_values = InferenceParams(
|
||||||
|
max_seqlen=self.config.n_positions,
|
||||||
max_batch_size=input_ids.shape[0],
|
max_batch_size=input_ids.shape[0],
|
||||||
max_sequence_len=self.config.n_positions,
|
seqlen_offset=0,
|
||||||
sequence_len_offset=0,
|
|
||||||
batch_size_offset=0,
|
batch_size_offset=0,
|
||||||
fused_ft_kernel=False,
|
|
||||||
key_value_memory_dict={},
|
key_value_memory_dict={},
|
||||||
|
lengths_per_sample=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
# Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
|
||||||
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
|
past_key_values.seqlen_offset = len(input_ids[0]) - 1
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -712,9 +767,9 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
|
||||||
if isinstance(module, MixFormerSequentialPreTrainedModel):
|
if isinstance(module, MixFormerSequentialPreTrainedModel):
|
||||||
module.gradient_checkpointing = value
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
|
||||||
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
||||||
@ -756,13 +811,10 @@ class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> CausalLMOutputWithPast:
|
) -> CausalLMOutputWithPast:
|
||||||
if past_key_values is None and attention_mask is None:
|
hidden_layer = self.layers[0](input_ids)
|
||||||
lm_logits = self.layers(input_ids)
|
for module in self.layers[1:-1]:
|
||||||
else:
|
hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
|
||||||
hidden_layer = self.layers[0](input_ids)
|
lm_logits = self.layers[-1](hidden_layer)
|
||||||
for module in self.layers[1:-1]:
|
|
||||||
hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
|
|
||||||
lm_logits = self.layers[-1](hidden_layer)
|
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user