fix(modeling_phi): Fixes initial generation with a length larger than the context length.
parent 37527ba0b8
commit ca573e3fa3
@@ -495,9 +495,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
 
-    # When the current sequence length is equal to or larger than the maximum sequence length,
+    # When the current sequence length is larger than the maximum sequence length,
     # we need to concatenate the current `kv` with the cached `kv` to expand its length
-    if sequence_end >= inference_params.max_seqlen:
+    if sequence_end > inference_params.max_seqlen:
         inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
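For reference, a minimal standalone sketch (hypothetical shapes, not the model's real cache tensors) of the boundary case behind the change from `>=` to `>`: when the incoming `kv` ends exactly at `max_seqlen`, the preallocated buffer already has room for it, so concatenating would grow the cache unnecessarily.

```python
import torch

max_seqlen = 8
cache = torch.zeros(1, max_seqlen, 2, 4)      # hypothetical preallocated KV buffer
kv = torch.randn(1, 3, 2, 4)                  # new keys/values for 3 positions
sequence_start = 5
sequence_end = sequence_start + kv.shape[1]   # == 8 == max_seqlen (exact fit)

# Old check (`>=`): the exact-fit case also triggers concatenation, growing the
# buffer past its preallocated capacity even though positions [5:8) still fit.
if sequence_end >= max_seqlen:
    grown = torch.concatenate((cache, kv), dim=1)
    print(grown.shape)                        # torch.Size([1, 11, 2, 4])

# New check (`>`): concatenate only on a true overflow; the exact-fit case
# writes into the existing buffer in place.
if sequence_end > max_seqlen:
    cache = torch.concatenate((cache, kv), dim=1)
cache[:, sequence_start:sequence_end, ...] = kv
print(cache.shape)                            # torch.Size([1, 8, 2, 4])
```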
@@ -863,9 +863,10 @@ class PhiPreTrainedModel(PreTrainedModel):
         **kwargs,
     ) -> Dict[str, Any]:
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+            max_batch_size, max_seqlen = input_ids.shape
             past_key_values = InferenceParams(
-                max_seqlen=self.config.n_positions,
-                max_batch_size=input_ids.shape[0],
+                max_seqlen=max(max_seqlen, self.config.n_positions),
+                max_batch_size=max_batch_size,
                 seqlen_offset=0,
                 batch_size_offset=0,
                 key_value_memory_dict={},
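And a short sketch of what the second hunk changes, using a stand-in dataclass for `InferenceParams` with only the fields shown in the diff (all numbers are illustrative): if the prompt is longer than `config.n_positions`, the cache must be sized to the prompt, otherwise the very first cache write already exceeds `max_seqlen`.

```python
from dataclasses import dataclass, field
from typing import Dict

import torch


@dataclass
class FakeInferenceParams:
    """Stand-in limited to the fields that appear in the diff."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: Dict[int, torch.Tensor] = field(default_factory=dict)


n_positions = 2048                                   # hypothetical context length from the config
input_ids = torch.zeros(2, 3000, dtype=torch.long)   # prompt longer than the context length
max_batch_size, max_seqlen = input_ids.shape

# Old sizing: the cache capacity ignored the prompt length, so the first forward
# pass with a 3000-token prompt overflowed a 2048-slot cache.
old_params = FakeInferenceParams(max_seqlen=n_positions, max_batch_size=input_ids.shape[0])

# New sizing: the capacity covers whichever is larger, so the whole prompt fits
# on the very first cache update.
new_params = FakeInferenceParams(max_seqlen=max(max_seqlen, n_positions), max_batch_size=max_batch_size)

print(old_params.max_seqlen, new_params.max_seqlen)  # 2048 3000
```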