fix(modeling_phi): Fixes initial generation with a length larger than the context length.
commit ca573e3fa3
parent 37527ba0b8
@@ -495,9 +495,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
 
-    # When the current sequence length is equal to or larger than the maximum sequence length,
+    # When the current sequence length is larger than the maximum sequence length,
     # we need to concatenate the current `kv` with the cached `kv` to expand its length
-    if sequence_end >= inference_params.max_seqlen:
+    if sequence_end > inference_params.max_seqlen:
         inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
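The cache-growth change above can be read in isolation: the pre-allocated buffer is only extended when the incoming chunk would actually run past its end, whereas the previous `>=` also concatenated when the new end exactly matched the capacity. A minimal sketch of the new behaviour, assuming the cached tensor is pre-allocated with `max_seqlen` positions along dim 1 (the helper name, the batch-wide slicing, and the tensor shapes are illustrative, not part of the diff):

    import torch

    def update_kv_cache_sketch(
        kv: torch.FloatTensor, cache: torch.FloatTensor, seqlen_offset: int, max_seqlen: int
    ) -> torch.FloatTensor:
        # `cache` stands in for key_value_memory_dict[layer_idx].
        sequence_start = seqlen_offset
        sequence_end = sequence_start + kv.shape[1]

        # Grow the buffer only when the new end exceeds its capacity (strict `>`).
        if sequence_end > max_seqlen:
            cache = torch.concatenate((cache, kv), dim=1)

        # Write the current chunk into its slot of the (possibly extended) buffer.
        cache[:, sequence_start:sequence_end, ...] = kv
        return cache

    # Example: a prompt that exactly fills the cache no longer triggers a concatenation.
    cache = torch.zeros(1, 16, 2, 4, 64)   # (batch, max_seqlen, 2, heads, head_dim) placeholder shape
    kv = torch.randn(1, 16, 2, 4, 64)
    cache = update_kv_cache_sketch(kv, cache, seqlen_offset=0, max_seqlen=16)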
@@ -863,9 +863,10 @@ class PhiPreTrainedModel(PreTrainedModel):
         **kwargs,
     ) -> Dict[str, Any]:
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+            max_batch_size, max_seqlen = input_ids.shape
             past_key_values = InferenceParams(
-                max_seqlen=self.config.n_positions,
-                max_batch_size=input_ids.shape[0],
+                max_seqlen=max(max_seqlen, self.config.n_positions),
+                max_batch_size=max_batch_size,
                 seqlen_offset=0,
                 batch_size_offset=0,
                 key_value_memory_dict={},
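With the second hunk, the cache passed to the model is sized for whichever is larger, the prompt itself or `config.n_positions`, so the very first forward pass over an over-length prompt no longer runs past the pre-allocated buffers. A minimal sketch of that sizing logic, using a stand-in dataclass that only reproduces the `InferenceParams` fields visible in this diff (the real class in modeling_phi.py has more fields):

    from dataclasses import dataclass, field
    from typing import Dict

    import torch

    @dataclass
    class InferenceParams:
        # Stand-in with only the fields touched by this diff.
        max_seqlen: int
        max_batch_size: int
        seqlen_offset: int = 0
        batch_size_offset: int = 0
        key_value_memory_dict: Dict[int, torch.FloatTensor] = field(default_factory=dict)

    def make_inference_params(input_ids: torch.LongTensor, n_positions: int) -> InferenceParams:
        # Mirror of the new prepare_inputs_for_generation sizing: the cache must
        # hold max(prompt length, configured context length) positions.
        max_batch_size, max_seqlen = input_ids.shape
        return InferenceParams(
            max_seqlen=max(max_seqlen, n_positions),
            max_batch_size=max_batch_size,
        )

    # Example: a 3000-token prompt with n_positions=2048 now gets a 3000-slot cache.
    params = make_inference_params(torch.zeros(1, 3000, dtype=torch.long), n_positions=2048)
    assert params.max_seqlen == 3000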