diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index a16b99b70..09c0cf2de 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -223,7 +223,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
       bias_init=cfg.bias_init,
       use_bias=False,
       broadcast_dropout=False,
-      dropout_rate=0.0,
+      dropout_rate=0.0,  # The dropout is applied at the end of this layer
       deterministic=cfg.deterministic,
     )(cfg.attention_temp * x, x, mask=encoder_mask)
 
@@ -286,7 +286,7 @@ def __call__(
       bias_init=cfg.bias_init,
       use_bias=False,
       broadcast_dropout=False,
-      dropout_rate=dropout_rate,
+      dropout_rate=0.0,  # Dropout applied after MultiHeadDotProductAttention
      deterministic=cfg.deterministic,
       decode=cfg.decode,
     )(cfg.attention_temp * x, x, mask=decoder_mask)
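Both hunks leave the attention module's internal `dropout_rate` at 0.0 and instead rely on a single dropout applied to the layer output. Below is a minimal sketch of that pattern, assuming a generic Flax setup; `SelfAttentionBlock`, its fields, and the shapes in the usage lines are illustrative, not the workload's actual module or config.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class SelfAttentionBlock(nn.Module):
  """Illustrative block: internal attention dropout off, dropout on the output."""
  num_heads: int = 4
  dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, x, mask=None, train=True):
    y = nn.MultiHeadDotProductAttention(
      num_heads=self.num_heads,
      broadcast_dropout=False,
      dropout_rate=0.0,  # Disabled here, matching the comments in the diff.
      deterministic=not train,
    )(x, x, mask=mask)
    # Dropout is applied once, after the attention call, on the layer output.
    y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)
    return x + y  # Residual connection.


# Hypothetical usage: batch of 2, sequence length 16, feature dim 32.
block = SelfAttentionBlock()
x = jnp.ones((2, 16, 32))
variables = block.init(
  {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}, x)
out = block.apply(variables, x, rngs={'dropout': jax.random.PRNGKey(2)})
```

This mirrors the intent of the comments in both hunks: the module-internal dropout stays at 0.0, and the dropout applied after the attention call carries the configurable rate (passed in at call time, as in the first hunk's `dropout_rate=DROPOUT_RATE` signature).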