From 2066613d0b11dd211b7559c0728fb2090e3ccbae Mon Sep 17 00:00:00 2001
From: titaiwang
Date: Thu, 2 Nov 2023 21:19:36 +0000
Subject: [PATCH] prototype of unblocking onnx export

---
 modeling_mixformer_sequential.py | 105 +++++++++++--------------------
 1 file changed, 36 insertions(+), 69 deletions(-)

diff --git a/modeling_mixformer_sequential.py b/modeling_mixformer_sequential.py
index b4efc53..fff9df6 100644
--- modeling_mixformer_sequential.py
+++ modeling_mixformer_sequential.py
@@ -117,10 +117,6 @@ def _apply_rotary_emb(
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
 
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
-
     x_rot = x[:, :, :, :rotary_dim]
     x_pass = x[:, :, :, rotary_dim:]
 
@@ -141,13 +137,9 @@ def _apply_rotary_emb_kv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, two, _, head_dim = kv.shape
-    assert two == 2
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
 
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     k_rot = kv[:, :, 0, :, :rotary_dim]
     k_pass = kv[:, :, 0, :, rotary_dim:]
@@ -175,13 +167,9 @@ def _apply_rotary_emb_qkv(
     sin_k: Optional[torch.FloatTensor] = None,
 ) -> torch.FloatTensor:
     _, seqlen, three, _, head_dim = qkv.shape
-    assert three == 3
     rotary_seqlen, rotary_dim = cos.shape
     rotary_dim *= 2
 
-    assert rotary_dim <= head_dim
-    assert seqlen <= rotary_seqlen
-    assert cos.shape == sin.shape == (rotary_seqlen, rotary_dim // 2)
 
     q_rot = qkv[:, :, 0, :, :rotary_dim]
     q_pass = qkv[:, :, 0, :, rotary_dim:]
@@ -223,6 +211,7 @@ class RotaryEmbedding(nn.Module):
         scale_base: Optional[float] = None,
         pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
+        max_position_embeddings=2048,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -248,11 +237,8 @@ class RotaryEmbedding(nn.Module):
         )
         self.register_buffer("scale", scale, persistent=False)
 
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-        self._cos_k_cached = None
-        self._sin_k_cached = None
+        # NOTE: initialize cached attributes
+        self._update_cos_sin_cache(seqlen=max_position_embeddings, device=device, dtype=torch.float32)
 
     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
@@ -262,43 +248,36 @@
     ) -> None:
         # Reset the tables if sequence length has been chaned, if we are on a
         # new device or if we are switching from inference mode to training
-        if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached is None
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
-            or (self.training and self._cos_cached.is_inference())
-        ):
-            self._seq_len_cached = seqlen
+        self._seq_len_cached = seqlen
 
-            # fp32 is preferred since the output of `torch.arange` can be quite large
-            # and bf16 would lose a lot of precision
-            if self.pos_idx_in_fp32:
-                t = torch.arange(seqlen, device=device, dtype=torch.float32)
-                if self.inv_freq.dtype != torch.float32:
-                    inv_freq = self._compute_inv_freq(device=device)
-                else:
-                    inv_freq = self.inv_freq
-            else:
-                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-                inv_freq = self.inv_freq
+        # fp32 is preferred since the output of `torch.arange` can be quite large
+        # and bf16 would lose a lot of precision
+        if self.pos_idx_in_fp32:
+            t = torch.arange(seqlen, device=device, dtype=torch.float32)
+            if self.inv_freq.dtype != torch.float32:
+                inv_freq = self._compute_inv_freq(device=device)
+            else:
+                inv_freq = self.inv_freq
+        else:
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            inv_freq = self.inv_freq
 
-            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-            freqs = torch.outer(t, inv_freq)
-            if self.scale is None:
-                self._cos_cached = torch.cos(freqs).to(dtype)
-                self._sin_cached = torch.sin(freqs).to(dtype)
-            else:
-                power = (
-                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-                ) / self.scale_base
-                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+        # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+        freqs = torch.outer(t, inv_freq)
+        if self.scale is None:
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+        else:
+            power = (
+                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+            ) / self.scale_base
+            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
-                # Force the scale multiplication to happen in fp32
-                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+            # Force the scale multiplication to happen in fp32
+            self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+            self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+            self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
     def forward(
         self,
@@ -309,10 +288,11 @@
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         seqlen = qkv.shape[1]
 
-        if max_seqlen is not None:
-            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
-        else:
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if seqlen > self._seq_len_cached:
+            if max_seqlen is not None:
+                self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+            else:
+                self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
@@ -336,7 +316,6 @@ class MLP(nn.Module):
         super().__init__()
 
         act_fn = config.activation_function if act_fn is None else act_fn
-        assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
 
         n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
         n_inner = n_inner if n_inner is not None else 4 * config.n_embd
@@ -436,7 +415,6 @@ class CrossAttention(nn.Module):
    ) -> torch.FloatTensor:
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
 
        if kv.shape[3] != q.shape[2]:
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
@@ -474,14 +452,6 @@ def _find_mha_dims(
     n_head_kv: Optional[int] = None,
     head_dim: Optional[int] = None,
 ) -> Tuple[int, int]:
-    assert all(
-        hasattr(config, attr) for attr in ["n_embd", "n_head"]
-    ), "`config` must have `n_embd` and `n_head` attributes."
-
-    if head_dim is None:
-        assert (
-            config.n_embd % config.n_head == 0
-        ), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
 
     if n_head is None and head_dim is None:
         head_dim = config.n_embd // config.n_head
 
@@ -491,7 +461,6 @@ def _find_mha_dims(
 
     if n_head_kv is None:
         n_head_kv = getattr(config, "n_head_kv", None) or n_head
-    assert n_head % n_head_kv == 0, "`n_head` must be divisible by `n_head_kv`."
 
     return n_head, n_head_kv, head_dim
 
@@ -515,13 +484,10 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
 
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
-    assert batch_end <= kv_cache.shape[0]
 
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
-    assert sequence_end <= kv_cache.shape[1]
 
-    assert kv_cache is not None
     kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
 
@@ -560,7 +526,7 @@ class MHA(nn.Module):
             rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding
             if rotary_cls is None:
                 rotary_cls = RotaryEmbedding
-            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, max_position_embeddings=config.n_positions, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
@@ -632,7 +598,8 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        if attention_mask is not None and torch.any(~attention_mask.bool()):
+        # TODO: find an alternative to the data-dependent check torch.any(~attention_mask.bool()); that dynamic control flow cannot be captured by ONNX export
+        if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
             attention_mask = None
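
Below is a minimal, hypothetical smoke test for the patch above; it is not part of the patch itself. It loads a checkpoint that ships this modeling file and runs the TorchScript-based torch.onnx.export on it. The checkpoint id microsoft/phi-1_5, the output path, the opset version, and the example prompt are illustrative assumptions, and details such as ModelOutput flattening or KV-cache handling may still need extra care.

# Hypothetical smoke test (not part of the patch): check that the patched modeling
# file traces through the TorchScript-based ONNX exporter once the asserts and the
# data-dependent attention_mask branch are gone.
# Checkpoint id, opset, output path, and prompt are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # assumed checkpoint bundling modeling_mixformer_sequential.py
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()

inputs = tokenizer("def hello_world():", return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# Dynamic axes only pay off if the traced graph is shape-agnostic, which is why the
# patch precomputes the rotary cos/sin cache up to config.n_positions and drops the
# runtime check on the attention mask.
torch.onnx.export(
    model,
    (input_ids, {"attention_mask": attention_mask}),  # trailing dict supplies keyword arguments
    "mixformer.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"},
    },
    opset_version=17,
)

If export still fails after applying the patch, the remaining TODO in MHA.forward is the first place to look: the exporter bakes the mask branch into the graph, so the patched code simply treats any provided attention_mask as boolean instead of inspecting its contents at trace time.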