Prototype of unblocking ONNX export
parent 92557d03bb
commit 2066613d0b
@@ -117,10 +117,6 @@ def _apply_rotary_emb(
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
 
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
-
     x_rot = x[:, :, :, :rotary_dim]
     x_pass = x[:, :, :, rotary_dim:]
 
@@ -141,13 +137,9 @@ def _apply_rotary_emb_kv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, two, _, head_dim = kv.shape
-    assert two == 2
 
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     k_rot = kv[:, :, 0, :, :rotary_dim]
     k_pass = kv[:, :, 0, :, rotary_dim:]
@@ -175,13 +167,9 @@ def _apply_rotary_emb_qkv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, three, _, head_dim = qkv.shape
-    assert three == 3
 
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     q_rot = qkv[:, :, 0, :, :rotary_dim]
     q_pass = qkv[:, :, 0, :, rotary_dim:]
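
Note: the asserts removed in the three hunks above compare Python ints derived from tensor shapes. Under the tracing-based torch.onnx.export path they run exactly once, on the example input, and leave nothing behind in the exported graph, so dropping them costs no runtime checking in the ONNX model. A minimal sketch of that behaviour, using a hypothetical module rather than code from this file:

    import torch
    import torch.nn as nn

    class RotarySliceSketch(nn.Module):
        def __init__(self, rotary_dim: int = 32):
            super().__init__()
            self.rotary_dim = rotary_dim

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            head_dim = x.shape[-1]              # a plain Python int while tracing
            assert self.rotary_dim <= head_dim  # runs once, never enters the graph
            return x[..., : self.rotary_dim]

    # Tracing records only the slice; the assert is not part of the export
    # (assumes the `onnx` package is available for torch.onnx.export).
    torch.onnx.export(RotarySliceSketch(), (torch.randn(1, 8, 4, 64),), "rotary_slice_sketch.onnx")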
@@ -223,6 +211,7 @@ class RotaryEmbedding(nn.Module):
         scale_base: Optional[float] = None,
         pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
+        max_position_embeddings=2048,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -248,11 +237,8 @@ class RotaryEmbedding(nn.Module):
         )
         self.register_buffer("scale", scale, persistent=False)
 
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
+        # NOTE: initialize cached attributes
+        self._update_cos_sin_cache(seqlen=max_position_embeddings, device=device, dtype=torch.float32)
 
     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
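
Note: together with the new max_position_embeddings argument above, this hunk replaces the lazily initialized _cos_cached/_sin_cached attributes with an eager call to _update_cos_sin_cache in __init__, so the rotary tables already exist as concrete tensors before any tracing happens. A simplified, self-contained sketch of the same idea (names are illustrative, not the repo's API):

    import torch
    import torch.nn as nn

    class EagerRotaryCacheSketch(nn.Module):
        def __init__(self, dim: int = 32, base: float = 10000.0, max_position_embeddings: int = 2048):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
            t = torch.arange(max_position_embeddings, dtype=torch.float32)
            freqs = torch.outer(t, inv_freq)
            # built once, up front, instead of on the first forward pass
            self.register_buffer("cos_cached", torch.cos(freqs), persistent=False)
            self.register_buffer("sin_cached", torch.sin(freqs), persistent=False)

        def forward(self, seqlen: int, seqlen_offset: int = 0):
            # forward only slices precomputed buffers
            return (
                self.cos_cached[seqlen_offset : seqlen_offset + seqlen],
                self.sin_cached[seqlen_offset : seqlen_offset + seqlen],
            )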
@@ -262,43 +248,36 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         # Reset the tables if sequence length has been chaned, if we are on a
         # new device or if we are switching from inference mode to training
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._seq_len_cached = seqlen
+        self._seq_len_cached = seqlen
 
-            # fp32 is preferred since the output of `torch.arange` can be quite large
-            # and bf16 would lose a lot of precision
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                if self.inv_freq.dtype != torch.float32:
-                    inv_freq = self._compute_inv_freq(device=device)
-                else:
-                    inv_freq = self.inv_freq
-            else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                inv_freq = self.inv_freq
+        # fp32 is preferred since the output of `torch.arange` can be quite large
+        # and bf16 would lose a lot of precision
+        if self.pos_idx_in_fp32:
+            t = torch.arange(seqlen, device=device, dtype=torch.float32)
+            if self.inv_freq.dtype != torch.float32:
+                inv_freq = self._compute_inv_freq(device=device)
+            else:
+                inv_freq = self.inv_freq
+        else:
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            inv_freq = self.inv_freq
 
-            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-            freqs = torch.outer(t, inv_freq)
-            if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
-            else:
-                power = (
-                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-                ) / self.scale_base
-                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+        # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+        freqs = torch.outer(t, inv_freq)
+        if self.scale is None:
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+        else:
+            power = (
+                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+            ) / self.scale_base
+            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
-                # Force the scale multiplication to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+            # Force the scale multiplication to happen in fp32
+            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
     def forward(
         self,
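
Note: with the cache built eagerly in __init__, the removed guard (which inspected self._cos_cached is None, its device and dtype, and is_inference()) no longer has anything to protect, so _update_cos_sin_cache now rebuilds unconditionally whenever it is called. Stripped of the xPos scaling branch, the computation it performs amounts to this standalone sketch (function name and defaults are illustrative):

    import torch

    def update_cos_sin_cache_sketch(seqlen: int, dim: int, base: float = 10000.0,
                                    device=None, dtype=torch.float32):
        # positions and frequencies are kept in fp32 for precision
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        # torch.outer keeps the product in fp32 even under AMP
        freqs = torch.outer(t, inv_freq)
        return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

    cos, sin = update_cos_sin_cache_sketch(seqlen=2048, dim=32)
    print(cos.shape)  # torch.Size([2048, 16])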
@@ -309,10 +288,11 @@ class RotaryEmbedding(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         seqlen = qkv.shape[1]
 
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-        else:
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if seqlen > self._seq_len_cached:
+            if max_seqlen is not None:
+                self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+            else:
+                self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
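
Note: the new seqlen > self._seq_len_cached check is a comparison between Python ints, and tensor shapes are plain ints inside a traced function, so the branch is resolved at export time; the exported graph simply slices the precomputed buffers. A small demonstration of that constant-folding (assumes the TorchScript tracer used by the default torch.onnx.export path):

    import torch

    def branch_on_seqlen(x: torch.Tensor) -> torch.Tensor:
        if x.shape[1] > 4:      # decided once, with the example input's shape
            return x * 2.0
        return x * 3.0

    traced = torch.jit.trace(branch_on_seqlen, torch.randn(1, 8, 2))
    print(traced.graph)         # only the `* 2.0` branch is recorded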
@@ -336,7 +316,6 @@ class MLP(nn.Module):
         super().__init__()
 
         act_fn = config.activation_function if act_fn is None else act_fn
-        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
 
         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
@@ -436,7 +415,6 @@ class CrossAttention(nn.Module):
     ) -> torch.FloatTensor:
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
 
         if kv.shape[3] != q.shape[2]:
             kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
@@ -474,14 +452,6 @@ def _find_mha_dims(
     n_head_kv: Optional[int] = None,
     head_dim: Optional[int] = None,
 ) -> Tuple[int, int]:
-    assert all(
-        hasattr(config, attr) for attr in ["n_embd", "n_head"]
-    ), "`config` must have `n_embd` and `n_head` attributes."
-
-    if head_dim is None:
-        assert (
-            config.n_embd % config.n_head == 0
-        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
 
     if n_head is None and head_dim is None:
         head_dim = config.n_embd // config.n_head
@@ -491,7 +461,6 @@ def _find_mha_dims(
 
     if n_head_kv is None:
         n_head_kv = getattr(config, "n_head_kv", None) or n_head
-    assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
 
     return n_head, n_head_kv, head_dim
 
@@ -515,13 +484,10 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
 
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    assert batch_end <= kv_cache.shape[0]
 
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
-    assert sequence_end <= kv_cache.shape[1]
 
-    assert kv_cache is not None
     kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
 
@@ -560,7 +526,7 @@ class MHA(nn.Module):
             rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
             if rotary_cls is None:
                 rotary_cls = RotaryEmbedding
-            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, max_position_embeddings=config.n_positions, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
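
Note: passing max_position_embeddings=config.n_positions here means the eager cache built in RotaryEmbedding.__init__ spans the model's full context window. A hedged usage sketch; the import path and the 32/2048 constants are assumptions, not taken from the diff:

    from modeling_phi import RotaryEmbedding   # hypothetical import path for this file

    rotary_emb = RotaryEmbedding(
        32,                                    # rotary_emb_dim
        max_position_embeddings=2048,          # i.e. config.n_positions
        device="cpu",
    )
    print(rotary_emb._cos_cached.shape)        # torch.Size([2048, 16]) right after __init__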
@@ -632,7 +598,8 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        if attention_mask is not None and torch.any(~attention_mask.bool()):
+        # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
+        if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
             attention_mask = None
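
Note: the old condition called torch.any(~attention_mask.bool()), which converts a tensor into a Python bool; the tracer evaluates that once on the example mask, emits a TracerWarning, and freezes the chosen branch into the graph. The replacement keeps only the attention_mask is not None check, which is decided at export time without inspecting tensor values, and leaves the data-dependent handling as a TODO. A small demonstration of the original problem, using a hypothetical helper traced for illustration:

    import torch

    def old_style(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        if bool(torch.any(~mask.bool())):      # TracerWarning: tensor -> Python bool
            return x.masked_fill(~mask.bool().unsqueeze(-1), 0.0)
        return x

    x = torch.randn(1, 8, 4)
    mask = torch.ones(1, 8, dtype=torch.long)  # example mask with no padding
    traced = torch.jit.trace(old_style, (x, mask))
    # Only the un-masked path survives in `traced`; padded inputs would be handled incorrectly.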