Fixes exceeding maximum sequence length when using generate().

Gustavo de Rosa 2023-11-20 18:11:04 +00:00 committed by huggingface-web
parent d212a78962
commit 5fd430c7bc

@@ -481,7 +481,7 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     num_heads, head_dim = kv.shape[-2:]
 
     if layer_idx not in inference_params.key_value_memory_dict:
-        kv_cache = torch.empty(
+        inference_params.key_value_memory_dict[layer_idx] = torch.empty(
             inference_params.max_batch_size,
             inference_params.max_seqlen,
             2,
@@ -490,9 +490,6 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
             dtype=kv.dtype,
             device=kv.device,
         )
-        inference_params.key_value_memory_dict[layer_idx] = kv_cache
-    else:
-        kv_cache = inference_params.key_value_memory_dict[layer_idx]
 
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
@@ -500,9 +497,14 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]
 
-    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+    # When the current sequence length is equal to or larger than the maximum sequence length,
+    # we need to roll the cache to the left and update it
+    if sequence_end >= inference_params.max_seqlen:
+        inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
+
+    inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+    kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
 
     return kv
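
For readers unfamiliar with the rolling update above, here is a minimal standalone sketch (not part of the commit) of the same idea on a toy cache: when a write would land at or past the last slot, the cache is shifted left along the sequence dimension so the oldest entry is dropped and the newest key/value pair fits.

import torch

# Standalone toy example (not part of the commit): a 4-slot cache for one sequence,
# one head, head_dim=1, mirroring the rolling update in _update_kv_cache above.
max_seqlen = 4
cache = torch.arange(max_seqlen, dtype=torch.float32).reshape(1, max_seqlen, 1)  # slots hold 0., 1., 2., 3.

sequence_start = 3          # offset of the incoming token
sequence_end = 4            # the write would touch the last valid slot
new_kv = torch.full((1, 1, 1), 42.0)

# When the write reaches the end of the cache, shift everything left by the number
# of incoming tokens so the oldest entry falls off and the final slot is reused.
if sequence_end >= max_seqlen:
    cache = cache.roll(-(sequence_end - sequence_start), 1)

cache[:, sequence_start:sequence_end, :] = new_kv
print(cache.flatten())      # tensor([ 1.,  2.,  3., 42.])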
@@ -710,7 +712,6 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
@@ -863,6 +864,13 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
+        # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
+        # the maximum sequence length
+        if input_ids.shape[1] > self.config.n_positions:
+            input_ids = input_ids[:, -self.config.n_positions :]
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, -self.config.n_positions :]
+
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
                 max_seqlen=self.config.n_positions,
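
The truncation added above keeps only the most recent `config.n_positions` tokens before the inference parameters are built. A small standalone sketch of that slicing (illustrative only; `n_positions` here is a stand-in for `self.config.n_positions`):

import torch

# Toy illustration (not from the commit): keep only the last `n_positions` tokens
# of `input_ids` and, when present, the matching columns of `attention_mask`.
n_positions = 4

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])   # 6 tokens, 2 too many
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])

if input_ids.shape[1] > n_positions:
    input_ids = input_ids[:, -n_positions:]
    if attention_mask is not None:
        attention_mask = attention_mask[:, -n_positions:]

print(input_ids)        # tensor([[13, 14, 15, 16]])
print(attention_mask)   # tensor([[1, 1, 1, 1]])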
@@ -874,7 +882,7 @@ class PhiPreTrainedModel(PreTrainedModel):
             )
         else:
             # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
-            past_key_values.seqlen_offset = len(input_ids[0]) - 1
+            past_key_values.seqlen_offset = input_ids.shape[1] - 1
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
         return {
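
Taken together, the input truncation in `prepare_inputs_for_generation` and the rolling KV cache are what allow `generate()` to run past the model's context window without indexing errors. A hedged usage sketch follows; the repository id is a placeholder, since this commit does not name it.

# Hypothetical end-to-end check (not part of the commit); the model id is a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<this-repo>"  # placeholder for the repository this commit belongs to
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_length=model.config.n_positions + 64,  # intentionally exceeds the context window
    do_sample=False,
)
print(tokenizer.decode(outputs[0]))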