fix(modeling_phi): Fixes cached generation when above maximum context length.

This commit is contained in:
Gustavo de Rosa 2023-12-05 21:09:53 +00:00 committed by huggingface-web
parent 5fd430c7bc
commit 37527ba0b8

@ -261,32 +261,30 @@ class RotaryEmbedding(nn.Module):
seqlen_offset: int = 0,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_start = seqlen_offset
seq_end = seq_start + qkv.shape[1]
if (
self._cos_cached.device != qkv.device
self._seq_len_cached < qkv.shape[1] + seqlen_offset
or self._cos_cached.device != qkv.device
or self._cos_cached.dtype != qkv.dtype
or (self.training and self._cos_cached.is_inference())
):
self._update_cos_sin_cache(self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype)
self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
if kv is None:
return _apply_rotary_emb_qkv(
qkv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
)
else:
q = _apply_rotary_emb(
qkv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
)
kv = _apply_rotary_emb_kv(
kv,
self._cos_cached[seq_start:seq_end],
self._sin_cached[seq_start:seq_end],
self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:],
)
return q, kv
@ -498,9 +496,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
sequence_end = sequence_start + kv.shape[1]
# When the current sequence length is equal to or larger than the maximum sequence length,
# we need to roll the cache to the left and update it
# we need to concatenate the current `kv` with the cached `kv` to expand its length
if sequence_end >= inference_params.max_seqlen:
inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
@ -864,13 +862,6 @@ class PhiPreTrainedModel(PreTrainedModel):
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
**kwargs,
) -> Dict[str, Any]:
# Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
# the maximum sequence length
if input_ids.shape[1] > self.config.n_positions:
input_ids = input_ids[:, -self.config.n_positions :]
if attention_mask is not None:
attention_mask = attention_mask[:, -self.config.n_positions :]
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
past_key_values = InferenceParams(
max_seqlen=self.config.n_positions,