diff --git a/modeling_phi.py b/modeling_phi.py
index 519907e..f0710a2 100644
--- a/modeling_phi.py
+++ b/modeling_phi.py
@@ -495,9 +495,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
 
-    # When the current sequence length is equal to or larger than the maximum sequence length,
+    # When the current sequence length is larger than the maximum sequence length,
     # we need to concatenate the current `kv` with the cached `kv` to expand its length
-    if sequence_end >= inference_params.max_seqlen:
+    if sequence_end > inference_params.max_seqlen:
         inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
@@ -863,9 +863,10 @@ class PhiPreTrainedModel(PreTrainedModel):
         **kwargs,
     ) -> Dict[str, Any]:
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+            max_batch_size, max_seqlen = input_ids.shape
             past_key_values = InferenceParams(
-                max_seqlen=self.config.n_positions,
-                max_batch_size=input_ids.shape[0],
+                max_seqlen=max(max_seqlen, self.config.n_positions),
+                max_batch_size=max_batch_size,
                 seqlen_offset=0,
                 batch_size_offset=0,
                 key_value_memory_dict={},
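
For context, a minimal sketch (not the upstream module) of the patched cache-growth condition: it mirrors the diff's `sequence_end > max_seqlen` check against a toy cache so the boundary case, a write that ends exactly at `max_seqlen`, is easy to verify. The helper name `update_kv_cache` and the tensor shapes are illustrative assumptions; only the condition itself comes from the patch.

# Minimal sketch, not the upstream modeling_phi.py: a toy cache mirroring the
# patched growth condition. Helper name and shapes are illustrative only.
import torch


def update_kv_cache(cache: torch.Tensor, kv: torch.Tensor, seqlen_offset: int, max_seqlen: int) -> torch.Tensor:
    sequence_start = seqlen_offset
    sequence_end = sequence_start + kv.shape[1]

    # Patched condition: grow the cache only when the write would run past it.
    # With the old `>=`, a write ending exactly at `max_seqlen` concatenated a
    # redundant block and left unused trailing slots in the cache.
    if sequence_end > max_seqlen:
        cache = torch.cat((cache, kv), dim=1)

    cache[:, sequence_start:sequence_end, ...] = kv
    return cache


# Boundary case: the write ends exactly at max_seqlen, so no growth is needed.
cache = torch.zeros(1, 8, 2, 4, 16)  # (batch, max_seqlen, 2, heads, head_dim)
kv = torch.ones(1, 3, 2, 4, 16)      # 3 new positions at offset 5 -> ends at 8
cache = update_kv_cache(cache, kv, seqlen_offset=5, max_seqlen=8)
print(cache.shape)  # torch.Size([1, 8, 2, 4, 16])

The second hunk is complementary: when the cache is first created, `max_seqlen` is seeded with `max(input length, config.n_positions)`, so prompts longer than `n_positions` get a cache that already fits them instead of relying on the growth path above.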