Update modeling_phi.py
This commit is contained in:
parent
59e722d14e
commit
34a1490e06
@ -509,7 +509,7 @@ class PhiFlashAttention2(PhiAttention):
|
|||||||
value_states = value_states.to(target_dtype)
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
attn_output = self._flash_attention_forward(
|
attn_output = self._flash_attention_forward(
|
||||||
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0
|
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user