Update modeling_phi.py

Gustavo de Rosa 2024-01-11 16:40:17 +00:00 committed by system
parent 051d15f1e7
commit e0f03c4877

@@ -308,7 +308,6 @@ class PhiAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -358,6 +357,7 @@ class PhiAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
         attn_weights = torch.matmul(
             query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
         ) / math.sqrt(self.head_dim)
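
For context on the hunk above: on the eager attention path, queries and keys are upcast to fp32 before the score matmul because Phi-2 can overflow the fp16 range otherwise. Below is a minimal standalone sketch of that pattern; the shapes and variable names are illustrative assumptions, not taken from modeling_phi.py.

import math
import torch

# Illustrative shapes (assumed, not from the diff): batch, heads, sequence lengths, head dim.
bsz, num_heads, q_len, kv_len, head_dim = 1, 32, 16, 16, 80

# Half-precision activations, as they would appear under fp16 inference.
query_states = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float16)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim, dtype=torch.float16)

# Upcast queries and keys to fp32 before the matmul so the dot products
# cannot overflow fp16, then apply the usual 1/sqrt(head_dim) scaling.
attn_weights = torch.matmul(
    query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
) / math.sqrt(head_dim)

# Softmax stays in fp32; the result is cast back to the activation dtype afterwards.
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)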