fix(modeling_phi): Fixes cached generation when above maximum context length.
parent 5fd430c7bc
commit 37527ba0b8
@@ -261,32 +261,30 @@ class RotaryEmbedding(nn.Module):
         seqlen_offset: int = 0,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        seq_start = seqlen_offset
-        seq_end = seq_start + qkv.shape[1]
-
         if (
-            self._cos_cached.device != qkv.device
+            self._seq_len_cached < qkv.shape[1] + seqlen_offset
+            or self._cos_cached.device != qkv.device
             or self._cos_cached.dtype != qkv.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
-            self._update_cos_sin_cache(self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype)
+            self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
         else:
             q = _apply_rotary_emb(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
             kv = _apply_rotary_emb_kv(
                 kv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
 
             return q, kv
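This hunk makes the rotary cos/sin cache grow on demand: instead of always rebuilding the tables at `max_position_embeddings`, they are rebuilt whenever `qkv.shape[1] + seqlen_offset` outruns `_seq_len_cached`, and are then sliced from `seqlen_offset` onward. A minimal standalone sketch of that caching policy (class and method names here are illustrative, and the rotation itself, `_apply_rotary_emb*`, is omitted):

import torch

class RotaryCache:
    # Illustrative reimplementation of the patched caching policy only.
    def __init__(self, dim: int, base: float = 10000.0):
        self.dim = dim
        self.base = base
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen: int, device, dtype):
        # Rebuild the cos/sin tables for `seqlen` positions.
        self._seq_len_cached = seqlen
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached = freqs.cos().to(dtype)
        self._sin_cached = freqs.sin().to(dtype)

    def get(self, seqlen: int, seqlen_offset: int, device, dtype):
        # The patched condition: extend the cache whenever the requested
        # end position (current length + offset) exceeds what is cached.
        if self._cos_cached is None or self._seq_len_cached < seqlen + seqlen_offset:
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=device, dtype=dtype)
        # Slice from the offset onward, as the patched forward() does.
        return self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]

During cached decoding, step t asks for one new position at offset t; once t reaches the previously cached length, the tables are rebuilt for t + 1 positions, so positions beyond the former maximum no longer index past the end of the cache.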
@@ -498,9 +496,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_end = sequence_start + kv.shape[1]
 
     # When the current sequence length is equal to or larger than the maximum sequence length,
-    # we need to roll the cache to the left and update it
+    # we need to concatenate the current `kv` with the cached `kv` to expand its length
     if sequence_end >= inference_params.max_seqlen:
-        inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
+        inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
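This hunk stops rolling the kv buffer once the pre-allocated window is full and concatenates the new entries instead, so earlier positions are kept rather than shifted out. A minimal sketch of the patched update, with names simplified from `_update_kv_cache` and the batch slicing reduced to the full batch:

import torch

def update_kv_cache(cache: torch.Tensor, kv: torch.Tensor, seq_start: int, max_seqlen: int):
    # cache: (batch, cached_len, ...) pre-allocated kv buffer
    # kv:    (batch, step_len, ...) keys/values produced by the current step
    seq_end = seq_start + kv.shape[1]

    # Past the pre-allocated maximum, extend the buffer to make room for the
    # new positions instead of rolling the oldest entries out of the window.
    if seq_end >= max_seqlen:
        cache = torch.cat((cache, kv), dim=1)

    cache[:, seq_start:seq_end] = kv   # write the new entries in place
    return cache, cache[:, :seq_end]   # grown buffer and the active view

Keeping every position in the buffer keeps the cached keys/values aligned with the on-demand rotary offsets from the previous hunk; the old roll shifted entries left while the offsets kept advancing, which is the failure mode in cached generation past the maximum context length that the commit title describes.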
@@ -864,13 +862,6 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
-        # the maximum sequence length
-        if input_ids.shape[1] > self.config.n_positions:
-            input_ids = input_ids[:, -self.config.n_positions :]
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, -self.config.n_positions :]
-
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
                 max_seqlen=self.config.n_positions,
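This hunk drops the truncation of `input_ids` and `attention_mask` in `prepare_inputs_for_generation`, since the two changes above let the rotary tables and the kv cache grow past `config.n_positions`. A hedged sketch of exercising the fix; the checkpoint id and token counts are illustrative assumptions, not taken from this repository:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: a checkpoint that ships this modeling_phi.py as remote code.
model_id = "microsoft/phi-1_5"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")

# Request enough new tokens for generation to cross config.n_positions.
# With the truncation removed and the caches growing on demand, decoding
# can continue past the maximum context length with use_cache=True.
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=2100, use_cache=True)

print(tokenizer.decode(output[0]))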