diff --git a/curated_transformers/layers/attention.py b/curated_transformers/layers/attention.py
index d72c9f60..0ea0875a 100644
--- a/curated_transformers/layers/attention.py
+++ b/curated_transformers/layers/attention.py
@@ -708,7 +708,7 @@ def forward(
             key=key,
             value=value,
             attn_mask=logit_mask,
-            dropout_p=self.dropout_prob if self.training else 0.0,
+            dropout_p=self.dropout.p if self.training else 0.0,
         )
 
         # Torch SDP returns NaNs for pieces where everything is masked out.
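For context, a minimal sketch of what the patched line relies on, assuming (as the diff suggests) that the attention layer stores an `nn.Dropout` module in `self.dropout` rather than a bare float; the class name and constructor below are hypothetical illustrations, not the library's actual API:

```python
import torch.nn.functional as F
from torch import nn


class SketchSelfAttention(nn.Module):
    """Hypothetical attention layer; only illustrates the dropout_p plumbing."""

    def __init__(self, *, dropout_prob: float = 0.1):
        super().__init__()
        # The layer keeps an nn.Dropout module rather than a plain probability,
        # so the rate is read back through its `.p` attribute.
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, query, key, value, logit_mask=None):
        return F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=logit_mask,
            # Mirror the patched line: apply dropout only in training mode.
            dropout_p=self.dropout.p if self.training else 0.0,
        )
```

Reading the rate from `self.dropout.p` keeps a single source of truth: if the module's dropout probability is ever changed, the value passed to `scaled_dot_product_attention` stays in sync instead of relying on a separately stored float.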