From e0f03c4877e57ed1f16bca47a0910d4b7bb28452 Mon Sep 17 00:00:00 2001
From: Gustavo de Rosa
Date: Thu, 11 Jan 2024 16:40:17 +0000
Subject: [PATCH] Update modeling_phi.py

---
 modeling_phi.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modeling_phi.py b/modeling_phi.py
index 7f9c2ca..52f1d42 100644
--- a/modeling_phi.py
+++ b/modeling_phi.py
@@ -308,7 +308,6 @@ class PhiAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -358,6 +357,7 @@ class PhiAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
        attn_weights = torch.matmul(
            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
        ) / math.sqrt(self.head_dim)
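
For readers skimming the diff, a minimal standalone sketch of the upcast that the new comment documents: queries and keys are cast to fp32 before the matmul so the Q.K^T dot products do not overflow fp16's finite range, and the scores stay in fp32 through scaling. The helper name attention_scores_fp32 and the tensor shapes below are illustrative, not part of modeling_phi.py; only the upcast-matmul-scale pattern comes from the patch.

    import math
    import torch

    def attention_scores_fp32(query_states, key_states):
        # Upcast queries and keys to fp32 before the matmul, mirroring the
        # patched PhiAttention path: in fp16, accumulated Q.K^T dot products
        # can exceed fp16's max finite value (~65504) and overflow to inf,
        # which then poisons the softmax. Accumulating in fp32 avoids that.
        head_dim = query_states.size(-1)
        attn_weights = torch.matmul(
            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
        ) / math.sqrt(head_dim)
        return attn_weights

    # Illustrative shapes (batch, num_heads, seq_len, head_dim); any shapes
    # with a matching head_dim work for the sketch.
    q = torch.randn(1, 32, 16, 80, dtype=torch.float16)
    k = torch.randn(1, 32, 16, 80, dtype=torch.float16)
    scores = attention_scores_fp32(q, k)  # fp32 scores, safe to softmax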