From 914c8fb3c681ebe3cacbe3c748858a572283ddde Mon Sep 17 00:00:00 2001 From: Gustavo de Rosa Date: Wed, 10 Jan 2024 13:54:40 +0000 Subject: [PATCH] Upload modeling_phi.py --- modeling_phi.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/modeling_phi.py b/modeling_phi.py index 9e0bcdb..7f9c2ca 100644 --- a/modeling_phi.py +++ b/modeling_phi.py @@ -358,7 +358,9 @@ class PhiAttention(nn.Module): key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul( + query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) + ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -374,7 +376,7 @@ class PhiAttention(nn.Module): attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) @@ -483,8 +485,10 @@ class PhiFlashAttention2(PhiAttention): # in fp32. if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): + elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype @@ -1093,7 +1097,7 @@ class PhiForCausalLM(PhiPreTrainedModel): # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] @@ -1225,9 +1229,10 @@ class PhiForSequenceClassification(PhiPreTrainedModel): sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( - logits.device - ) + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1