fix(modeling_phi): Fixes initial generation with length larger than context length.

Gustavo de Rosa 2023-12-08 17:40:16 +00:00 committed by huggingface-web
parent 37527ba0b8
commit ca573e3fa3

@@ -495,9 +495,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
 
-    # When the current sequence length is equal to or larger than the maximum sequence length,
+    # When the current sequence length is larger than the maximum sequence length,
     # we need to concatenate the current `kv` with the cached `kv` to expand its length
-    if sequence_end >= inference_params.max_seqlen:
+    if sequence_end > inference_params.max_seqlen:
         inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
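Why the comparison changed from `>=` to `>`: when `sequence_end` equals `max_seqlen`, the slice assignment below still fits inside the preallocated buffer, so growing the cache at that point appends a redundant copy of `kv`. A minimal standalone sketch of the growth check, using toy shapes rather than the model's real cache layout:

    import torch

    # Toy stand-in for the preallocated KV buffer: [batch, max_seqlen, head_dim].
    max_seqlen = 4
    cache = torch.zeros(1, max_seqlen, 2)

    # Incoming keys/values: a prompt that exactly fills the buffer.
    kv = torch.ones(1, 4, 2)
    sequence_start = 0                           # plays the role of seqlen_offset
    sequence_end = sequence_start + kv.shape[1]  # == max_seqlen == 4

    # The old `>=` check fired here and appended a second copy of `kv`,
    # even though the slice assignment below fits exactly.
    if sequence_end > max_seqlen:
        cache = torch.concatenate((cache, kv), dim=1)

    cache[:, sequence_start:sequence_end, ...] = kv
    assert cache.shape[1] == max_seqlen  # no spurious growth when the write fits

With the old `>=` check the buffer would grow to 8 slots here, leaving a stale duplicate of `kv` at positions 4-7 after the slice assignment overwrote positions 0-3.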
@@ -863,9 +863,10 @@ class PhiPreTrainedModel(PreTrainedModel):
         **kwargs,
     ) -> Dict[str, Any]:
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
+            max_batch_size, max_seqlen = input_ids.shape
             past_key_values = InferenceParams(
-                max_seqlen=self.config.n_positions,
-                max_batch_size=input_ids.shape[0],
+                max_seqlen=max(max_seqlen, self.config.n_positions),
+                max_batch_size=max_batch_size,
                 seqlen_offset=0,
                 batch_size_offset=0,
                 key_value_memory_dict={},
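The second hunk sizes the cache from the prompt itself. Previously the cache was always allocated with `max_seqlen=self.config.n_positions`, so a prompt longer than the configured context length overflowed the buffer on the very first forward pass; taking the max of the two lets the initial generation through. A minimal sketch of just the sizing rule (the helper name is hypothetical; field names follow the diff):

    # Hypothetical helper isolating the new sizing rule. `input_ids_shape` is
    # (batch, seq_len); `n_positions` is the model's configured context length.
    def initial_cache_size(input_ids_shape: tuple, n_positions: int) -> tuple:
        max_batch_size, max_seqlen = input_ids_shape
        # Allocate enough room for the prompt even when it exceeds the
        # configured context length, instead of clamping to n_positions.
        return max_batch_size, max(max_seqlen, n_positions)

    print(initial_cache_size((1, 3000), 2048))  # -> (1, 3000): long prompt fits
    print(initial_cache_size((1, 16), 2048))    # -> (1, 2048): usual allocation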