diff --git a/modeling_phi.py b/modeling_phi.py
index b1ac5e8..93d0bcf 100644
--- a/modeling_phi.py
+++ b/modeling_phi.py
@@ -481,7 +481,7 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     num_heads, head_dim = kv.shape[-2:]
 
     if layer_idx not in inference_params.key_value_memory_dict:
-        kv_cache = torch.empty(
+        inference_params.key_value_memory_dict[layer_idx] = torch.empty(
             inference_params.max_batch_size,
             inference_params.max_seqlen,
             2,
@@ -490,9 +490,6 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
             dtype=kv.dtype,
             device=kv.device,
         )
-        inference_params.key_value_memory_dict[layer_idx] = kv_cache
-    else:
-        kv_cache = inference_params.key_value_memory_dict[layer_idx]
 
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
@@ -500,9 +497,14 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
 
-    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+    # When the current sequence length is equal to or larger than the maximum sequence length,
+    # we need to roll the cache to the left and update it
+    if sequence_end >= inference_params.max_seqlen:
+        inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
 
+    inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+    kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
+
     return kv
 
 
@@ -710,7 +712,6 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
@@ -863,6 +864,13 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
+        # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
+        # the maximum sequence length
+        if input_ids.shape[1] > self.config.n_positions:
+            input_ids = input_ids[:, -self.config.n_positions :]
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, -self.config.n_positions :]
+
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
                 max_seqlen=self.config.n_positions,
@@ -874,7 +882,7 @@ class PhiPreTrainedModel(PreTrainedModel):
             )
         else:
             # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
-            past_key_values.seqlen_offset = len(input_ids[0]) - 1
+            past_key_values.seqlen_offset = input_ids.shape[1] - 1
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
         return {
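
Not part of the patch, but for reviewers: a minimal standalone sketch (toy shapes, single incoming token, all variable names here hypothetical) of the torch.Tensor.roll primitive the _update_kv_cache change relies on. Once the preallocated window is full, the cache is shifted left along the sequence dimension and the new key/values are written into the freed tail slot; in the real code the offsets come from InferenceParams and prepare_inputs_for_generation rather than being hard-coded.

import torch

# Toy cache laid out like the Phi KV cache: (batch, max_seqlen, 2, n_heads, head_dim).
# All sizes below are made up for illustration.
batch, max_seqlen, n_heads, head_dim = 1, 4, 2, 3
cache = torch.arange(max_seqlen, dtype=torch.float).view(1, max_seqlen, 1, 1, 1)
cache = cache.expand(batch, max_seqlen, 2, n_heads, head_dim).clone()

# Suppose the window is already full and one new token arrives.
new_kv = torch.full((batch, 1, 2, n_heads, head_dim), 4.0)

# Shift the whole cache left along the sequence dimension (dim=1), mirroring
# `.roll(-(sequence_end - sequence_start), 1)` in the patch, then overwrite the
# freed tail positions with the new key/values.
cache = cache.roll(-new_kv.shape[1], 1)
cache[:, -new_kv.shape[1]:] = new_kv

print(cache[0, :, 0, 0, 0])  # tensor([1., 2., 3., 4.]) -- the oldest entry was dropped

The truncation added to prepare_inputs_for_generation is what keeps this write in bounds: with input_ids capped at n_positions tokens, seqlen_offset never exceeds n_positions - 1, so the slice written by _update_kv_cache stays inside the preallocated buffer.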