Update modeling_phi.py
parent 051d15f1e7
commit e0f03c4877
@@ -308,7 +308,6 @@ class PhiAttention(nn.Module):
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -358,6 +357,7 @@ class PhiAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
         attn_weights = torch.matmul(
             query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
         ) / math.sqrt(self.head_dim)
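The comment added in this hunk is the substance of the change: Phi-2's pre-softmax attention logits can exceed the fp16 maximum (65504), so query_states and key_states are upcast to fp32 before the matmul. A minimal sketch of the failure mode, not part of this commit (head_dim = 80 matches Phi-2's hidden size 2560 / 32 heads; the 48.0 fill value is an illustrative assumption chosen to force the overflow):

import math

import torch

head_dim = 80
q = torch.full((head_dim,), 48.0, dtype=torch.float16)
k = torch.full((head_dim,), 48.0, dtype=torch.float16)

# True dot product is 48 * 48 * 80 = 184320, above the fp16 max of 65504,
# so the half-precision result saturates to inf before the scaling applies.
print((q * k).sum() / math.sqrt(head_dim))  # tensor(inf, dtype=torch.float16)

# The pattern the diff uses: upcast both operands to fp32 first.
print((q.to(torch.float32) * k.to(torch.float32)).sum() / math.sqrt(head_dim))  # ≈ tensor(20607.6)

The real forward pass applies the same upcast to the batched 4-D query/key matmul shown in the hunk above.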