diff --git a/modeling_phi.py b/modeling_phi.py
index 93d0bcf..519907e 100644
--- a/modeling_phi.py
+++ b/modeling_phi.py
@@ -261,32 +261,30 @@ class RotaryEmbedding(nn.Module):
         seqlen_offset: int = 0,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        seq_start = seqlen_offset
-        seq_end = seq_start + qkv.shape[1]
-
         if (
-            self._cos_cached.device != qkv.device
+            self._seq_len_cached < qkv.shape[1] + seqlen_offset
+            or self._cos_cached.device != qkv.device
             or self._cos_cached.dtype != qkv.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
-            self._update_cos_sin_cache(self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype)
+            self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
         else:
             q = _apply_rotary_emb(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
             kv = _apply_rotary_emb_kv(
                 kv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
 
             return q, kv
@@ -498,9 +496,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_end = sequence_start + kv.shape[1]
 
     # When the current sequence length is equal to or larger than the maximum sequence length,
-    # we need to roll the cache to the left and update it
+    # we need to concatenate the current `kv` with the cached `kv` to expand its length
    if sequence_end >= inference_params.max_seqlen:
-        inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
+        inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
@@ -864,13 +862,6 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
-        # the maximum sequence length
-        if input_ids.shape[1] > self.config.n_positions:
-            input_ids = input_ids[:, -self.config.n_positions :]
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, -self.config.n_positions :]
-
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
                 max_seqlen=self.config.n_positions,
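
Note for reviewers: the first hunk stops capping the cos/sin tables at `max_position_embeddings` and instead regrows them whenever `qkv.shape[1] + seqlen_offset` outruns `_seq_len_cached`, slicing from `seqlen_offset` so decode steps pick up the rows for their absolute positions. A minimal sketch of that grow-on-demand pattern, assuming standard RoPE frequencies (the `RotaryCacheSketch` class and its `get` helper are illustrative, not the module's actual API):

```python
from typing import Optional, Tuple

import torch


class RotaryCacheSketch:
    """Grow-on-demand cos/sin tables, mirroring the patched cache check."""

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        self.dim = dim
        self.base = base
        self._seq_len_cached = 0
        self._cos_cached: Optional[torch.Tensor] = None
        self._sin_cached: Optional[torch.Tensor] = None

    def _update_cos_sin_cache(self, seqlen: int, device, dtype) -> None:
        # Rebuild the tables up to `seqlen` positions (standard RoPE frequencies;
        # an assumption here, not taken from the diff).
        inv_freq = 1.0 / self.base ** (
            torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        )
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached = freqs.cos().to(dtype)
        self._sin_cached = freqs.sin().to(dtype)
        self._seq_len_cached = seqlen

    def get(self, seqlen: int, seqlen_offset: int, device, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        # Key change: grow the cache whenever the requested end position outruns it,
        # rather than building it once at a fixed maximum length.
        if (
            self._cos_cached is None
            or self._seq_len_cached < seqlen + seqlen_offset
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=device, dtype=dtype)
        # Slice from the offset, as the patch does; callers consume the first `seqlen` rows.
        return self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]


# Usage: a 16-token prompt followed by one decode step at position 16.
cache = RotaryCacheSketch(dim=32)
cos, sin = cache.get(16, 0, torch.device("cpu"), torch.float32)  # builds 16 rows
cos, sin = cache.get(1, 16, torch.device("cpu"), torch.float32)  # grows to 17 rows
assert cos.shape[0] == 1
```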
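The second hunk swaps the roll-and-overwrite strategy, which silently shifted the oldest entries out of the preallocated buffer once `max_seqlen` was reached, for concatenation, so the cache simply grows and no positions are lost; that is also why the third hunk can drop the `n_positions` truncation in `prepare_inputs_for_generation`. A self-contained sketch of the new behavior on plain tensors (the function name and shapes are hypothetical, standing in for `InferenceParams.key_value_memory_dict`):

```python
import torch


def update_kv_cache_sketch(
    cache: torch.Tensor, kv: torch.Tensor, sequence_start: int, max_seqlen: int
) -> torch.Tensor:
    """cache: (batch, cached_len, ...) buffer preallocated to `max_seqlen`;
    kv: (batch, new_len, ...) keys/values for the positions being decoded."""
    sequence_end = sequence_start + kv.shape[1]

    # Old behavior: cache.roll(-(sequence_end - sequence_start), 1), which shifts
    # the oldest entries out. New behavior: extend the buffer so nothing is lost.
    if sequence_end >= max_seqlen:
        cache = torch.concatenate((cache, kv), dim=1)

    cache[:, sequence_start:sequence_end, ...] = kv
    return cache


# Usage: a decode step that lands exactly at the preallocated limit.
cache = torch.zeros(1, 4, 2, 8)  # max_seqlen = 4
step = torch.randn(1, 1, 2, 8)   # one new token
cache = update_kv_cache_sketch(cache, step, sequence_start=4, max_seqlen=4)
assert cache.shape[1] == 5       # buffer grew instead of wrapping around
```

Note that, as in the patch, the buffer grows by `kv.shape[1]` on every call past the limit, trading memory for correctness during long `generate()` runs.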