diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index 63439cb09..795846efd 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -91,7 +91,7 @@ With the exception of `_build_input_queue`, submitters can call any of these fun
 def step_hint(self): -> int
 ```
 
-- The `step_hint` function gives the number of global steps the baseline algorithm was allowed to use to reach the targets for a workload. Note that the baseline algorithms may have reached the target in fewer steps than this, but these were the max number of steps the baseline algorithms used for their learning rate schedules. Submitters can use this to help specify learning rate (or other) schedules.
+- The `step_hint` function gives the number of global steps the baseline algorithm can perform within the `max_runtime` to reach the targets for a workload. The `step_hint` therefore depends on both the `max_runtime` and the workload. Note that the baseline algorithms may have reached the target in fewer steps than this, but this is the maximum number of steps the baseline algorithms used for their learning rate schedules. Submitters can use this hint to help specify learning rate (or other) schedules, as sketched below.
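+
+A minimal, purely illustrative sketch of how a submission might use this hint, e.g. to size a warmup-plus-cosine learning rate schedule (the `cosine_lr` helper below is hypothetical, not part of the API):
+
+```python
+import math
+
+
+def cosine_lr(step: int, base_lr: float, step_hint: int, warmup_frac: float = 0.05) -> float:
+  """Linear warmup followed by cosine decay, stretched over the step hint."""
+  warmup_steps = max(1, int(warmup_frac * step_hint))
+  if step < warmup_steps:
+    return base_lr * step / warmup_steps
+  progress = min(1.0, (step - warmup_steps) / max(1, step_hint - warmup_steps))
+  return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
+```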
 
 ###### Data augmentation and preprocessing
 
@@ -418,7 +418,7 @@ In each trial, the tuning trial with the fastest training time to achieve the *v
 
 Submissions to this ruleset are not allowed to have user-defined hyperparameters. This ruleset allows both submissions that use the same hyperparameters for all workloads, including the randomized ones (e.g. Adam with default parameters), as well as submissions that perform inner-loop tuning during their training run (e.g. SGD with line searches).
 
-Submissions will run on one instance of the [benchmarking hardware](#benchmarking-hardware). As always, submissions are allowed to perform inner-loop tuning (e.g. for their learning rate) but the tuning efforts will be part of their score. A submission will run *S=5* times and its score will be the median time to reach the target evaluation metric value on the validation set. To account for the lack of external tuning, submissions have a longer time budget to reach the target performance. Compared to the [external tuning ruleset](#external-tuning-ruleset), the `max_runtime` is tripled. Runs that do not reach the target performance of the evaluation metric within this allotted time budget have an infinite time.
+Submissions will run on one instance of the [benchmarking hardware](#benchmarking-hardware). As always, submissions are allowed to perform inner-loop tuning (e.g. for their learning rate) but the tuning efforts will be part of their score. A submission will run *S=5* times and its score will be the median time to reach the target evaluation metric value on the validation set. To account for the lack of external tuning, submissions have a longer time budget to reach the target performance. Compared to the [external tuning ruleset](#external-tuning-ruleset), the `max_runtime` is increased by a factor of $1.5$. Runs that do not reach the target performance of the evaluation metric within this allotted time budget have an infinite time.
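+
+As a purely illustrative sketch (not the official scoring code), the per-workload score under this ruleset could be computed as:
+
+```python
+import statistics
+
+
+def self_tuning_score(times_to_target_sec):
+  """Median over the S=5 study runs; a run that never reaches the target counts as infinite."""
+  return statistics.median(times_to_target_sec)
+
+
+# Example: three runs reach the validation target, two time out.
+print(self_tuning_score([3_600.0, 4_200.0, float('inf'), 3_900.0, float('inf')]))  # 4200.0
+```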
 
 ### Workloads
 
@@ -439,11 +439,11 @@ The currently eight fixed workloads are:
 |            | **Task**                      | **Dataset** | **Model**               | **Loss** | **Metric** | Validation<br>**Target** | Test<br>**Target**   | Maximum<br>**Runtime** <br>(in secs) |
 |------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------|
 | **1**      | Clickthrough rate prediction  | Criteo 1TB  | DLRMsmall               | CE       | CE         | 0.123735                 | 0.126041                |       7,703                 |
-| **2**      | MRI reconstruction            | fastMRI     | U-Net                   | L1       | SSIM       | 0.723653                   | 0.740633             |          8,859              |
-| **3<br>4** | Image classification          | ImageNet    | ResNet-50<br>ViT        | CE       | ER         | 0.22569<br>0.22691        | 0.3440<br>0.3481    |        63,008    <br> 77,520            |
-| **5<br>6** | Speech recognition            | LibriSpeech | Conformer<br>DeepSpeech | CTC      | WER        | 0.085884<br>0.119936     | 0.052981<br>0.074143 |       61,068<br>55,506                 |
-| **7**      | Molecular property prediction | OGBG        | GNN                     | CE       | mAP        | 0.28098                 | 0.268729             |       18,477                 |
-| **8**      | Translation                   | WMT         | Transformer             | CE       | BLEU       | 30.8491                  | 30.7219              |       48,151                 |
+| **2**      | MRI reconstruction            | fastMRI     | U-Net                   | L1       | SSIM       | 0.723653                   | 0.740633             |          4,430              |
+| **3<br>4** | Image classification          | ImageNet    | ResNet-50<br>ViT        | CE       | ER         | 0.22569<br>0.22691        | 0.3440<br>0.3481    |        66,159    <br> 69,768            |
+| **5<br>6** | Speech recognition            | LibriSpeech | Conformer<br>DeepSpeech | CTC      | WER        | 0.085884<br>0.119936     | 0.052981<br>0.074143 |       58,015<br>44,405                 |
+| **7**      | Molecular property prediction | OGBG        | GNN                     | CE       | mAP        | 0.28098                 | 0.268729             |       12,011                 |
+| **8**      | Translation                   | WMT         | Transformer             | CE       | BLEU       | 30.8491                  | 30.7219              |       43,336                 |
 
 Default Dropout Values for Different Workloads:
 
@@ -503,7 +503,7 @@ For self-reported results, it is acceptable to perform the tuning trials on hard
 Target performances on the validation and test sets will be defined for each [workload](#workloads) separately. For the [fixed workloads](#fixed-workloads), we take the best performance achievable by one of four standard algorithms (AdamW, NadamW, Nesterov Momentum, and Heavy Ball Momentum). These target-setting algorithms will follow the general process of the external tuning ruleset, with a significantly larger tuning budget of $200$ trials to guarantee competitive performance. Once the best algorithm and its hyperparameters are determined, training is repeated $20$ times. The median of the best achieved validation errors across seeds is used as the *validation* target. Out of the $10$ repeated runs that achieved this validation target, we took the worst achieved test error across seeds as our *test* target. Taking the median validation performance after rerunning the best hyperparameter point prevents our procedure from selecting a lucky outlier.
 To save computational resources, we only tuned two training algorithms instead of four, for the [randomized workloads](#randomized-workloads). For each workload variant, we used NadamW and the other best-performing training algorithm on the corresponding fixed workload the randomized workload is based on.
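+
+A purely illustrative sketch of this target-setting computation (assuming a lower-is-better metric; not the official tooling):
+
+```python
+import numpy as np
+
+
+def compute_targets(best_val_errors, best_test_errors):
+  """best_val_errors[i] / best_test_errors[i]: best errors achieved by seed i."""
+  val_target = float(np.median(best_val_errors))
+  # Worst test error among the runs that reached the validation target.
+  test_target = max(t for v, t in zip(best_val_errors, best_test_errors) if v <= val_target)
+  return val_target, test_target
+```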
 
-Both [tuning rulesets](#tuning) will use the same target performances. The runtime of the target-setting algorithms on each workload will be chosen to match published results and is constrained by the overall time budget of roughly a single week for all fixed workloads. The `max_runtime` for submissions on each workload is $\frac{1}{3}$ longer than the runtime of the target-setting algorithms (this `max_runtime` will be three times as much for the self-tuning ruleset, see the [Self-tuning ruleset](#self-tuning-ruleset) section).
+Both [tuning rulesets](#tuning) will use the same target performances. The runtime of the target-setting algorithms on each workload will be chosen to match published results and is constrained by the overall time budget of roughly a single week for all fixed workloads. The initial `max_runtime` for submissions on each workload was $\frac{1}{3}$ longer than the runtime of the target-setting algorithms (this `max_runtime` is increased by a factor of $1.5$ for the self-tuning ruleset, see the [Self-tuning ruleset](#self-tuning-ruleset) section). After the initial round of submissions, we adapted the `max_runtime` based on the performance of the submissions (see [this issue](https://github.com/mlcommons/algorithmic-efficiency/issues/836)).
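+
+For illustration with made-up numbers: if the target-setting algorithm needed $18{,}000$ seconds on a workload, the initial external-tuning `max_runtime` would have been $18{,}000 \cdot \frac{4}{3} = 24{,}000$ seconds, and the corresponding self-tuning budget $24{,}000 \cdot 1.5 = 36{,}000$ seconds.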
 
 #### Benchmark score using performance profiles
 
diff --git a/algoperf/spec.py b/algoperf/spec.py
index 381d52f32..cf4f1a14e 100644
--- a/algoperf/spec.py
+++ b/algoperf/spec.py
@@ -206,7 +206,7 @@ def eval_period_time_sec(self) -> int:
   @property
   @abc.abstractmethod
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
+    """Approx. steps the baseline can do in the allowed runtime budget."""
 
   @property
   def param_shapes(self):
diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py
index 80ec9d67a..617b2e987 100644
--- a/algoperf/workloads/criteo1tb/workload.py
+++ b/algoperf/workloads/criteo1tb/workload.py
@@ -93,7 +93,7 @@ def train_stddev(self):
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 7703  # ~2 hours.
+    return 7_703  # ~2.1 hours.
 
   @property
   def eval_period_time_sec(self) -> int:
@@ -123,7 +123,7 @@ def _build_input_queue(
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
+    """Approx. steps the baseline can do in the allowed runtime budget."""
     return 10_666
 
   def _eval_model_on_split(self,
diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py
index e9a2a313a..051749cc3 100644
--- a/algoperf/workloads/fastmri/workload.py
+++ b/algoperf/workloads/fastmri/workload.py
@@ -95,7 +95,7 @@ def accelerations(self):
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 8859  # ~2.5 hours
+    return 4_430  # ~1.2 hours
 
   @property
   def eval_period_time_sec(self) -> int:
@@ -103,8 +103,8 @@ def eval_period_time_sec(self) -> int:
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 36_189
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 18_094
 
   def _build_input_queue(self,
                          data_rng: spec.RandomState,
diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py
index 8b3393ded..83fe97108 100644
--- a/algoperf/workloads/imagenet_resnet/workload.py
+++ b/algoperf/workloads/imagenet_resnet/workload.py
@@ -102,7 +102,7 @@ def resize_size(self) -> int:
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 63_008  # ~17.5 hours
+    return 66_159  # ~18.4 hours
 
   @property
   def eval_period_time_sec(self) -> int:
@@ -144,5 +144,5 @@ def _build_input_queue(
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 186_666
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 195_999
diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py
index 7f06715a3..9c885ca7c 100644
--- a/algoperf/workloads/imagenet_vit/workload.py
+++ b/algoperf/workloads/imagenet_vit/workload.py
@@ -3,8 +3,7 @@
 from typing import Dict, Iterator, Optional
 
 from algoperf import spec
-from algoperf.workloads.imagenet_resnet.workload import \
-    BaseImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.workload import BaseImagenetResNetWorkload
 
 
 def decode_variant(variant: str) -> Dict[str, int]:
@@ -81,7 +80,7 @@ def eval_batch_size(self) -> int:
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 77_520  # ~22 hours
+    return 69_768  # ~19.4 hours
 
   @property
   def eval_period_time_sec(self) -> int:
@@ -110,5 +109,5 @@ def _build_dataset(
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 186_666
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 167_999
diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py
index c9f5a3c59..94f01dd97 100644
--- a/algoperf/workloads/librispeech_conformer/workload.py
+++ b/algoperf/workloads/librispeech_conformer/workload.py
@@ -79,7 +79,7 @@ def train_stddev(self):
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 61_068  # ~17 hours
+    return 58_015  # ~16.1 hours
 
   @property
   def eval_period_time_sec(self) -> int:
@@ -87,5 +87,5 @@ def eval_period_time_sec(self) -> int:
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 80_000
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 76_000
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 9fd0898b4..1cadebf45 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -1,15 +1,15 @@
 import functools
 from typing import Dict, Optional, Tuple
 
-from flax import jax_utils
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import jax_utils
 
-from algoperf import param_utils
-from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
-    LibriSpeechConformerWorkload
+from algoperf import param_utils, spec
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+  LibriSpeechConformerWorkload,
+)
 from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models
 
 
@@ -99,12 +99,12 @@ def test_target_value(self) -> float:
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 48_000
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 38_400
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 55_506  # ~15.4 hours
+    return 44_405  # ~12.3 hours
 
   @property
   def use_tanh(self) -> bool:
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index 4f8ad1974..c72c1daee 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -3,17 +3,18 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
 from algoperf.pytorch_utils import pytorch_setup
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
-    initialize
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
-    LibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
-    DeepspeechConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
-    DeepspeechEncoderDecoder
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+  initialize,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+  LibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
+  DeepspeechConfig,
+  DeepspeechEncoderDecoder,
+)
 
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
 
@@ -76,12 +77,12 @@ def test_target_value(self) -> float:
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 48_000
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 38_400
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 55_506  # ~15.4 hours
+    return 44_405  # ~12.3 hours
 
   @property
   def use_tanh(self) -> bool:
diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py
index c6a2162d7..ca123f885 100644
--- a/algoperf/workloads/ogbg/workload.py
+++ b/algoperf/workloads/ogbg/workload.py
@@ -9,8 +9,7 @@
 
 from algoperf import random_utils as prng
 from algoperf import spec
-from algoperf.workloads.ogbg import input_pipeline
-from algoperf.workloads.ogbg import metrics
+from algoperf.workloads.ogbg import input_pipeline, metrics
 
 
 class BaseOgbgWorkload(spec.Workload):
@@ -88,7 +87,7 @@ def train_stddev(self):
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 18_477  # ~5 hours
+    return 12_011  # ~3.3 hours
 
   @property
   def eval_period_time_sec(self) -> int:
@@ -140,8 +139,8 @@ def loss_fn(
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 80_000
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 52_000
 
   @abc.abstractmethod
   def _normalize_eval_metrics(
diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py
index e9a07d2b3..51b33373d 100644
--- a/algoperf/workloads/wmt/workload.py
+++ b/algoperf/workloads/wmt/workload.py
@@ -88,7 +88,7 @@ def train_stddev(self):
 
   @property
   def max_allowed_runtime_sec(self) -> int:
-    return 48_151  # ~13.5 hours
+    return 43_336  # ~12.0 hours
 
   @property
   def eval_period_time_sec(self) -> int:
@@ -96,8 +96,8 @@ def eval_period_time_sec(self) -> int:
 
   @property
   def step_hint(self) -> int:
-    """Max num steps the baseline algo was given to reach the target."""
-    return 133_333
+    """Approx. steps the baseline can do in the allowed runtime budget."""
+    return 120_000
 
   @property
   def pre_ln(self) -> bool:
diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py
index 5fb5f259d..d0e5bf70b 100644
--- a/scoring/compute_speedups.py
+++ b/scoring/compute_speedups.py
@@ -25,6 +25,7 @@
                      'Whether to save the results to disk.')
 FLAGS = flags.FLAGS
 
+# These are the old budgets, used in the first iteration of the competition.
 MAX_BUDGETS = {
     'criteo1tb': 7703,
     'fastmri': 8859,
diff --git a/submission_runner.py b/submission_runner.py
index 1be56aeab..a9d13e7cb 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -409,10 +409,10 @@ def train_once(
             prepare_for_eval_end_time - prepare_for_eval_start_time)
 
       # Check if time is remaining,
-      # use 3x the runtime budget for the self-tuning ruleset.
+      # use 1.5x the runtime budget for the self-tuning ruleset.
       max_allowed_runtime_sec = (
           workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
-          else 3 * workload.max_allowed_runtime_sec)
+          else 1.5 * workload.max_allowed_runtime_sec)
       train_state['is_time_remaining'] = (
           train_state['accumulated_submission_time'] < max_allowed_runtime_sec)