fix(modeling_phi): Fixes cached generation when above maximum context length.
This commit is contained in:
parent
5fd430c7bc
commit
37527ba0b8
@ -261,32 +261,30 @@ class RotaryEmbedding(nn.Module):
|
||||
seqlen_offset: int = 0,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
seq_start = seqlen_offset
|
||||
seq_end = seq_start + qkv.shape[1]
|
||||
|
||||
if (
|
||||
self._cos_cached.device != qkv.device
|
||||
self._seq_len_cached < qkv.shape[1] + seqlen_offset
|
||||
or self._cos_cached.device != qkv.device
|
||||
or self._cos_cached.dtype != qkv.dtype
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
):
|
||||
self._update_cos_sin_cache(self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype)
|
||||
self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
|
||||
|
||||
if kv is None:
|
||||
return _apply_rotary_emb_qkv(
|
||||
qkv,
|
||||
self._cos_cached[seq_start:seq_end],
|
||||
self._sin_cached[seq_start:seq_end],
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
)
|
||||
else:
|
||||
q = _apply_rotary_emb(
|
||||
qkv,
|
||||
self._cos_cached[seq_start:seq_end],
|
||||
self._sin_cached[seq_start:seq_end],
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
)
|
||||
kv = _apply_rotary_emb_kv(
|
||||
kv,
|
||||
self._cos_cached[seq_start:seq_end],
|
||||
self._sin_cached[seq_start:seq_end],
|
||||
self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:],
|
||||
)
|
||||
|
||||
return q, kv
|
||||
@ -498,9 +496,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
|
||||
sequence_end = sequence_start + kv.shape[1]
|
||||
|
||||
# When the current sequence length is equal to or larger than the maximum sequence length,
|
||||
# we need to roll the cache to the left and update it
|
||||
# we need to concatenate the current `kv` with the cached `kv` to expand its length
|
||||
if sequence_end >= inference_params.max_seqlen:
|
||||
inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
|
||||
inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
|
||||
|
||||
inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
||||
kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
|
||||
@ -864,13 +862,6 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
# Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
|
||||
# the maximum sequence length
|
||||
if input_ids.shape[1] > self.config.n_positions:
|
||||
input_ids = input_ids[:, -self.config.n_positions :]
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, -self.config.n_positions :]
|
||||
|
||||
if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
|
||||
past_key_values = InferenceParams(
|
||||
max_seqlen=self.config.n_positions,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user