diff --git a/MaxText/input_pipeline/input_pipeline_interface.py b/MaxText/input_pipeline/input_pipeline_interface.py index 149213a5e..21b1c64dc 100644 --- a/MaxText/input_pipeline/input_pipeline_interface.py +++ b/MaxText/input_pipeline/input_pipeline_interface.py @@ -83,6 +83,7 @@ def __init__(self, config, mesh): self.mesh = mesh dataset = BadSyntheticDataIterator.get_bad_synthetic_data(config) self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh) + self.local_iterator = self.data_generator.local_iterator def __iter__(self): return self.data_generator