Update modeling_phi.py

This commit is contained in:
Gustavo de Rosa 2024-01-16 16:05:38 +00:00 committed by system
parent 59e722d14e
commit 34a1490e06
No known key found for this signature in database
GPG Key ID: 6A528E38E0733467

@ -509,7 +509,7 @@ class PhiFlashAttention2(PhiAttention):
value_states = value_states.to(target_dtype) value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0 query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None
) )
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()