prototype of unblocking onnx export

titaiwang 2023-11-02 21:19:36 +00:00
parent 92557d03bb
commit 2066613d0b
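The hunks below remove data-dependent `assert`s and Python-level control flow from the modeling code so the graph can be captured for ONNX export. For context, a minimal export sketch follows; the checkpoint path, input shape, and the use of `torch.onnx.dynamo_export` are assumptions of this note, not part of the commit.

# Minimal sketch of the export this prototype aims to unblock.
# Assumptions: the patched modeling code is loaded via trust_remote_code and
# PyTorch's dynamo-based ONNX exporter is used; the checkpoint path is a placeholder.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/patched-checkpoint",  # placeholder for a checkpoint using this modeling file
    trust_remote_code=True,
).eval()
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))

with torch.no_grad():
    onnx_program = torch.onnx.dynamo_export(model, input_ids)
onnx_program.save("model.onnx")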

@@ -117,10 +117,6 @@ def _apply_rotary_emb(
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     x_rot = x[:, :, :, :rotary_dim]
     x_pass = x[:, :, :, rotary_dim:]
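The asserts dropped here (and in the later hunks) compare tensor shapes that become symbolic when the model is traced for export; a Python `assert` on such values either fails or forces a data-dependent guard, so the prototype simply removes them. A hedged alternative, not used by this commit, is to express the same constraints with `torch._check` (available in recent PyTorch, though semi-internal), which records them for the tracer instead of branching in Python:

# Sketch only: keep the rotary shape constraints visible to the tracer
# without Python asserts. torch._check is assumed available (PyTorch >= 2.1).
import torch

def check_rotary_shapes(seqlen: int, head_dim: int, cos: torch.Tensor, sin: torch.Tensor) -> None:
    rotary_seqlen, half_rotary_dim = cos.shape
    rotary_dim = 2 * half_rotary_dim
    torch._check(rotary_dim <= head_dim)
    torch._check(seqlen <= rotary_seqlen)
    torch._check(sin.shape[0] == rotary_seqlen and sin.shape[1] == half_rotary_dim)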
@@ -141,13 +137,9 @@ def _apply_rotary_emb_kv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, two, _, head_dim = kv.shape
-    assert two == 2
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     k_rot = kv[:, :, 0, :, :rotary_dim]
     k_pass = kv[:, :, 0, :, rotary_dim:]
@@ -175,13 +167,9 @@ def _apply_rotary_emb_qkv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, three, _, head_dim = qkv.shape
-    assert three == 3
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     q_rot = qkv[:, :, 0, :, :rotary_dim]
     q_pass = qkv[:, :, 0, :, rotary_dim:]
@@ -223,6 +211,7 @@ class RotaryEmbedding(nn.Module):
         scale_base: Optional[float] = None,
         pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
+        max_position_embeddings=2048,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -248,11 +237,8 @@ class RotaryEmbedding(nn.Module):
         )
         self.register_buffer("scale", scale, persistent=False)
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
+        # NOTE: initialize cached attributes
+        self._update_cos_sin_cache(seqlen=max_position_embeddings, device=device, dtype=torch.float32)
 
     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
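Rather than initializing `_cos_cached`/`_sin_cached` lazily to `None` and filling them on first use, the cache is now built once in `__init__`, sized by the new `max_position_embeddings` argument; this keeps the `is None`/device/dtype bookkeeping out of the traced forward path. A condensed sketch of what that precomputation amounts to, assuming the usual rotary cache formulation and ignoring the scaling branch handled by the real `_update_cos_sin_cache`:

# Condensed sketch of the eager cos/sin cache; names are illustrative.
import torch

def build_rotary_cache(dim: int, seqlen: int, base: float = 10000.0,
                       device=None, dtype=torch.float32):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    t = torch.arange(seqlen, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)  # (seqlen, dim // 2)
    return torch.cos(freqs).to(dtype), torch.sin(freqs).to(dtype)

# e.g. in __init__: cos_cached, sin_cached = build_rotary_cache(rotary_dim, max_position_embeddings)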
@@ -262,13 +248,6 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         # Reset the tables if sequence length has been chaned, if we are on a
         # new device or if we are switching from inference mode to training
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
             self._seq_len_cached = seqlen
 
             # fp32 is preferred since the output of `torch.arange` can be quite large
@@ -309,6 +288,7 @@ class RotaryEmbedding(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         seqlen = qkv.shape[1]
+        if seqlen > self._seq_len_cached:
         if max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
         else:
@@ -336,7 +316,6 @@ class MLP(nn.Module):
         super().__init__()
 
         act_fn = config.activation_function if act_fn is None else act_fn
-        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
 
         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
@@ -436,7 +415,6 @@ class CrossAttention(nn.Module):
     ) -> torch.FloatTensor:
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
 
         if kv.shape[3] != q.shape[2]:
             kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
@@ -474,14 +452,6 @@ def _find_mha_dims(
     n_head_kv: Optional[int] = None,
     head_dim: Optional[int] = None,
 ) -> Tuple[int, int]:
-    assert all(
-        hasattr(config, attr) for attr in ["n_embd", "n_head"]
-    ), "`config` must have `n_embd` and `n_head` attributes."
-    if head_dim is None:
-        assert (
-            config.n_embd % config.n_head == 0
-        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
 
     if n_head is None and head_dim is None:
         head_dim = config.n_embd // config.n_head
@@ -491,7 +461,6 @@ def _find_mha_dims(
     if n_head_kv is None:
         n_head_kv = getattr(config, "n_head_kv", None) or n_head
-    assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
 
     return n_head, n_head_kv, head_dim
@@ -515,13 +484,10 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    assert batch_end <= kv_cache.shape[0]
 
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
-    assert sequence_end <= kv_cache.shape[1]
 
-    assert kv_cache is not None
     kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
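With the bounds asserts gone, the remaining slice assignment is the whole cache update: write the new key/value block at the current offset and return everything seen so far. A standalone sketch of that pattern, with example shapes (not taken from the commit):

# Standalone illustration of the kv-cache slicing above; shapes are examples.
import torch

batch, max_seqlen, n_head_kv, head_dim = 1, 2048, 32, 64
kv_cache = torch.zeros(batch, max_seqlen, 2, n_head_kv, head_dim)

def update_kv_cache(kv: torch.Tensor, seqlen_offset: int) -> torch.Tensor:
    # kv: (batch, new_seqlen, 2, n_head_kv, head_dim)
    sequence_end = seqlen_offset + kv.shape[1]
    kv_cache[: kv.shape[0], seqlen_offset:sequence_end, ...] = kv
    return kv_cache[: kv.shape[0], :sequence_end, ...]  # full prefix seen so far

new_kv = torch.randn(batch, 16, 2, n_head_kv, head_dim)
kv_so_far = update_kv_cache(new_kv, seqlen_offset=0)  # shape (1, 16, 2, 32, 64)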
@@ -560,7 +526,7 @@ class MHA(nn.Module):
             rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
             if rotary_cls is None:
                 rotary_cls = RotaryEmbedding
-            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, max_position_embeddings=config.n_positions, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
@@ -632,7 +598,8 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        if attention_mask is not None and torch.any(~attention_mask.bool()):
+        # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
+        if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
             attention_mask = None
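The removed condition `torch.any(~attention_mask.bool())` converts a tensor into a Python boolean, which the exporter cannot trace symbolically; hence the TODO. The new guard only checks `is not None`, which is resolved at trace time, at the cost of keeping the mask even when it is all ones. One way to keep the original short-circuit semantics without a Python branch (a sketch of an alternative, not what the commit does) is to fold the mask into the attention scores unconditionally:

# Sketch of an export-friendly replacement for `if torch.any(~mask.bool()):`.
# An all-ones mask simply masks nothing; no data-dependent Python branch is needed.
import torch

def apply_padding_mask(scores: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # scores: (batch, n_head, seqlen_q, seqlen_k); attention_mask: (batch, seqlen_k), 1 = keep
    pad = (~attention_mask.bool())[:, None, None, :]
    return scores.masked_fill(pad, torch.finfo(scores.dtype).min)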