diff --git a/configuration_stablelm_epoch.py b/configuration_stablelm_epoch.py
index ea24f38..7d8e249 100755
--- a/configuration_stablelm_epoch.py
+++ b/configuration_stablelm_epoch.py
@@ -65,6 +65,8 @@ class StableLMEpochConfig(PretrainedConfig):
             Whether or not the model should use bias for qkv layers.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
     """
     model_type = "stablelm_epoch"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -88,6 +90,7 @@ class StableLMEpochConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         tie_word_embeddings=False,
+        attention_dropout: float = 0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -105,6 +108,7 @@ class StableLMEpochConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.use_qkv_bias = use_qkv_bias
         self.tie_word_embeddings = tie_word_embeddings
+        self.attention_dropout = attention_dropout
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
diff --git a/modeling_stablelm_epoch.py b/modeling_stablelm_epoch.py
index 930b98f..e7fde9f 100755
--- a/modeling_stablelm_epoch.py
+++ b/modeling_stablelm_epoch.py
@@ -191,6 +191,7 @@ class Attention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.is_causal = True
+        self.attention_dropout = config.attention_dropout
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -275,6 +276,7 @@ class Attention(nn.Module):
 
         # Upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
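
The patch adds an `attention_dropout` field to `StableLMEpochConfig` and applies `nn.functional.dropout` to the post-softmax attention probabilities, gated on `self.training`. A minimal usage sketch follows, showing how the new field can be set when loading the model; it assumes this code is loaded as remote code via `transformers` (`trust_remote_code=True`), and the checkpoint name "stabilityai/stablelm-3b-4e1t" is illustrative only.

# Minimal sketch (assumptions: transformers with trust_remote_code support,
# and an illustrative checkpoint name that ships these modeling files).
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",   # illustrative checkpoint
    trust_remote_code=True,
    attention_dropout=0.1,            # new field added by this patch; defaults to 0.0
)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",
    config=config,
    trust_remote_code=True,
)

# Dropout on the attention probabilities is only active in training mode:
model.train()  # dropout is applied with p=config.attention_dropout
model.eval()   # inference path is unchanged, since training=False disables dropout

Because `nn.functional.dropout(..., training=self.training)` is a no-op when `p=0.0` or when the module is in eval mode, existing checkpoints and the default inference path behave exactly as before.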