prototype of unblocking onnx export
This commit is contained in:
parent
92557d03bb
commit
2066613d0b
@ -117,10 +117,6 @@ def _apply_rotary_emb(
|
||||
rotary_seqlen, rotary_dim = cos.shape
|
||||
rotary_dim *= 2
|
||||
|
||||
assert rotary_dim <= head_dim
|
||||
assert seqlen <= rotary_seqlen
|
||||
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
|
||||
x_rot = x[:, :, :, :rotary_dim]
|
||||
x_pass = x[:, :, :, rotary_dim:]
|
||||
|
||||
@ -141,13 +137,9 @@ def _apply_rotary_emb_kv(
|
||||
sin_k: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
_, seqlen, two, _, head_dim = kv.shape
|
||||
assert two == 2
|
||||
|
||||
rotary_seqlen, rotary_dim = cos.shape
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= head_dim
|
||||
assert seqlen <= rotary_seqlen
|
||||
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
|
||||
k_rot = kv[:, :, 0, :, :rotary_dim]
|
||||
k_pass = kv[:, :, 0, :, rotary_dim:]
|
||||
@ -175,13 +167,9 @@ def _apply_rotary_emb_qkv(
|
||||
sin_k: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
_, seqlen, three, _, head_dim = qkv.shape
|
||||
assert three == 3
|
||||
|
||||
rotary_seqlen, rotary_dim = cos.shape
|
||||
rotary_dim *= 2
|
||||
assert rotary_dim <= head_dim
|
||||
assert seqlen <= rotary_seqlen
|
||||
assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
|
||||
|
||||
q_rot = qkv[:, :, 0, :, :rotary_dim]
|
||||
q_pass = qkv[:, :, 0, :, rotary_dim:]
|
||||
@ -223,6 +211,7 @@ class RotaryEmbedding(nn.Module):
|
||||
scale_base: Optional[float] = None,
|
||||
pos_idx_in_fp32: bool = True,
|
||||
device: Optional[str] = None,
|
||||
max_position_embeddings=2048,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -248,11 +237,8 @@ class RotaryEmbedding(nn.Module):
|
||||
)
|
||||
self.register_buffer("scale", scale, persistent=False)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
# NOTE: initialize cached attributes
|
||||
self._update_cos_sin_cache(seqlen=max_position_embeddings, device=device, dtype=torch.float32)
|
||||
|
||||
def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
|
||||
return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
||||
@ -262,43 +248,36 @@ class RotaryEmbedding(nn.Module):
|
||||
) -> None:
|
||||
# Reset the tables if sequence length has been chaned, if we are on a
|
||||
# new device or if we are switching from inference mode to training
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached is None
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
self._seq_len_cached = seqlen
|
||||
|
||||
# fp32 is preferred since the output of `torch.arange` can be quite large
|
||||
# and bf16 would lose a lot of precision
|
||||
if self.pos_idx_in_fp32:
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
if self.inv_freq.dtype != torch.float32:
|
||||
inv_freq = self._compute_inv_freq(device=device)
|
||||
else:
|
||||
inv_freq = self.inv_freq
|
||||
# fp32 is preferred since the output of `torch.arange` can be quite large
|
||||
# and bf16 would lose a lot of precision
|
||||
if self.pos_idx_in_fp32:
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
if self.inv_freq.dtype != torch.float32:
|
||||
inv_freq = self._compute_inv_freq(device=device)
|
||||
else:
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
inv_freq = self.inv_freq
|
||||
else:
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
inv_freq = self.inv_freq
|
||||
|
||||
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
else:
|
||||
power = (
|
||||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
||||
# `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
else:
|
||||
power = (
|
||||
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
||||
) / self.scale_base
|
||||
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
||||
|
||||
# Force the scale multiplication to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
||||
# Force the scale multiplication to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -309,10 +288,11 @@ class RotaryEmbedding(nn.Module):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seqlen = qkv.shape[1]
|
||||
|
||||
if max_seqlen is not None:
|
||||
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
||||
else:
|
||||
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
||||
if seqlen > self._seq_len_cached:
|
||||
if max_seqlen is not None:
|
||||
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
||||
else:
|
||||
self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
||||
|
||||
if kv is None:
|
||||
return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
||||
@ -336,7 +316,6 @@ class MLP(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
act_fn = config.activation_function if act_fn is None else act_fn
|
||||
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
|
||||
|
||||
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
|
||||
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
|
||||
@ -436,7 +415,6 @@ class CrossAttention(nn.Module):
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
||||
seqlen_k = kv.shape[1]
|
||||
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
||||
|
||||
if kv.shape[3] != q.shape[2]:
|
||||
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
||||
@ -474,14 +452,6 @@ def _find_mha_dims(
|
||||
n_head_kv: Optional[int] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
assert all(
|
||||
hasattr(config, attr) for attr in ["n_embd", "n_head"]
|
||||
), "`config` must have `n_embd` and `n_head` attributes."
|
||||
|
||||
if head_dim is None:
|
||||
assert (
|
||||
config.n_embd % config.n_head == 0
|
||||
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
|
||||
|
||||
if n_head is None and head_dim is None:
|
||||
head_dim = config.n_embd // config.n_head
|
||||
@ -491,7 +461,6 @@ def _find_mha_dims(
|
||||
|
||||
if n_head_kv is None:
|
||||
n_head_kv = getattr(config, "n_head_kv", None) or n_head
|
||||
assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
|
||||
|
||||
return n_head, n_head_kv, head_dim
|
||||
|
||||
@ -515,13 +484,10 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
|
||||
|
||||
batch_start = inference_params.batch_size_offset
|
||||
batch_end = batch_start + kv.shape[0]
|
||||
assert batch_end <= kv_cache.shape[0]
|
||||
|
||||
sequence_start = inference_params.seqlen_offset
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
assert sequence_end <= kv_cache.shape[1]
|
||||
|
||||
assert kv_cache is not None
|
||||
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
|
||||
|
||||
@ -560,7 +526,7 @@ class MHA(nn.Module):
|
||||
rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
|
||||
if rotary_cls is None:
|
||||
rotary_cls = RotaryEmbedding
|
||||
self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
|
||||
self.rotary_emb = rotary_cls(self.rotary_emb_dim, max_position_embeddings=config.n_positions, **rotary_kwargs)
|
||||
|
||||
# MLP
|
||||
self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
|
||||
@ -632,7 +598,8 @@ class MHA(nn.Module):
|
||||
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
if attention_mask is not None and torch.any(~attention_mask.bool()):
|
||||
# TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.bool()
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
Loading…
Reference in New Issue
Block a user