feat: add dropout support

Author: jon-tow
Date: 2024-01-23 18:49:25 +00:00
parent 4c846d7114
commit 810b45c00e
2 changed files with 6 additions and 0 deletions
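
In short, the commit threads a new `attention_dropout` option (default `0.0`) through the config and applies it to the attention probabilities during training. A minimal usage sketch, assuming the model is loaded with `trust_remote_code` from a checkpoint that ships this code; the repo id below is illustrative, not part of the commit:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative checkpoint id; substitute the StableLM Epoch repo you actually use.
repo_id = "stabilityai/stablelm-3b-4e1t"

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.attention_dropout = 0.1  # default is 0.0, so existing behavior is unchanged

model = AutoModelForCausalLM.from_pretrained(repo_id, config=config, trust_remote_code=True)
model.train()  # dropout is only applied while the module is in training mode
```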

@@ -65,6 +65,8 @@ class StableLMEpochConfig(PretrainedConfig):
Whether or not the model should use bias for qkv layers.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
"""
model_type = "stablelm_epoch"
keys_to_ignore_at_inference = ["past_key_values"]
@@ -88,6 +90,7 @@ class StableLMEpochConfig(PretrainedConfig):
bos_token_id=0,
eos_token_id=2,
tie_word_embeddings=False,
attention_dropout: float = 0.0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -105,6 +108,7 @@ class StableLMEpochConfig(PretrainedConfig):
self.use_cache = use_cache
self.use_qkv_bias = use_qkv_bias
self.tie_word_embeddings = tie_word_embeddings
self.attention_dropout = attention_dropout
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,

@@ -191,6 +191,7 @@ class Attention(nn.Module):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.is_causal = True
self.attention_dropout = config.attention_dropout
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@@ -275,6 +276,7 @@ class Attention(nn.Module):
# Upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
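
For context, the hunk above sits in the eager attention path: scores are softmaxed in fp32, dropout is applied to the resulting probabilities, and the weights are then multiplied into the values. A minimal self-contained sketch under assumed shapes (batch, num_heads, seq_len, head_dim); masking and KV caching are omitted, and the function name is illustrative rather than taken from the repository:

```python
import math
import torch
import torch.nn as nn

def eager_attention(query_states, key_states, value_states, attention_dropout=0.0, training=True):
    # query/key/value: (batch, num_heads, seq_len, head_dim); causal mask omitted for brevity.
    head_dim = query_states.size(-1)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    # Upcast attention to fp32 for a numerically stable softmax, then cast back to the input dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    # The line added by this commit: dropout on the attention probabilities.
    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
    return torch.matmul(attn_weights, value_states)
```

With the default `attention_dropout=0.0` (or with the module in eval mode) the dropout call is a no-op, so inference and previously trained checkpoints behave exactly as before.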