Update modeling_phi.py
This commit is contained in:
parent
051d15f1e7
commit
e0f03c4877
@ -308,7 +308,6 @@ class PhiAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -358,6 +357,7 @@ class PhiAttention(nn.Module):
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
# Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
|
||||
attn_weights = torch.matmul(
|
||||
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user