Fix bad synthetic dataloader with per device batch size < 1. #1706


Open · wants to merge 1 commit into base: main
Conversation

@wang2yn84 (Collaborator) commented May 9, 2025

Description

Fix BadSyntheticDataIterator for Grain. The local iterator attribute is missing, so the workload errors out when using a Grain dataset together with per_device_batch_size < 1.

Tests

Manually ran the following workload:

```shell
python -m MaxText.train MaxText/configs/base.yml \
  skip_jax_distributed_system=True \
  run_name=lance_test \
  attention=dot_product \
  dataset_type=grain \
  tokenizer_path=assets/tokenizer.llama2 \
  hardware=gpu \
  logits_dot_in_fp32=false \
  enable_goodput_recording=false \
  monitor_goodput=false \
  remat_policy=full \
  weight_dtype=bfloat16 \
  save_config_to_gcs=false \
  scan_layers=false \
  per_device_batch_size=0.25 \
  dcn_fsdp_parallelism=-1 \
  dcn_data_parallelism=1 \
  ici_fsdp_parallelism=1 \
  ici_tensor_parallelism=8 \
  packing=false \
  enable_checkpoint_cloud_logger=true \
  dataset_path=/scratch/lancewang/dataset_pvc/ \
  grain_train_files=/scratch/lancewang/dataset_pvc/array-record/c4/en/3.0.1/c4-train.array_record* \
  grain_worker_count=1 \
  enable_checkpointing=false \
  async_checkpointing=true \
  checkpoint_period=10 \
  save_config_to_gcs=false \
  base_output_directory=/scratch/lancewang/outputs
```

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@wang2yn84 wang2yn84 changed the title Fix the way we use bad synthetic dataloader. Fix bad synthetic dataloader with per device batch size < 1. May 9, 2025
```diff
@@ -83,6 +83,7 @@ def __init__(self, config, mesh):
     self.mesh = mesh
     dataset = BadSyntheticDataIterator.get_bad_synthetic_data(config)
     self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh)
+    self.local_iterator = self.data_generator.local_iterator
```
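The pattern the one-line fix follows can be sketched in isolation. This is a minimal, self-contained toy (not the MaxText code; both class bodies here are illustrative stand-ins): a multi-host loader wrapper holds a per-host iterator, and the outer data iterator forwards it as a `local_iterator` attribute so callers that expect that attribute keep working.

```python
# Toy sketch of the fixed structure. Names mirror the diff but the
# implementations are hypothetical simplifications.
class MultiHostDataLoadIterator:
    """Stand-in for multihost_dataloading.MultiHostDataLoadIterator."""

    def __init__(self, dataset):
        self.local_iterator = iter(dataset)  # the per-host iterator

    def __next__(self):
        return next(self.local_iterator)


class BadSyntheticDataIterator:
    """Synthetic-data iterator that now also exposes local_iterator."""

    def __init__(self, dataset):
        self.data_generator = MultiHostDataLoadIterator(dataset)
        # The line added by this PR: without it, code that reads
        # `data_iterator.local_iterator` raises AttributeError.
        self.local_iterator = self.data_generator.local_iterator

    def __next__(self):
        return next(self.data_generator)


it = BadSyntheticDataIterator([{"inputs": 1}, {"inputs": 2}])
assert next(it) == {"inputs": 1}
assert next(it.local_iterator) == {"inputs": 2}
```

Both names refer to the same underlying iterator, so advancing one advances the other, matching how a checkpoint handler and the training loop must share state.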
Collaborator:
Does it need to be used somewhere?

Collaborator (author):

Hi Branden, yes, it's used here:

```python
iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator),
```
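To make the failure mode concrete, here is a toy sketch (hypothetical classes, not the MaxText/Grain code): `checkpoint_save` mimics the quoted call site reaching for `data_iterator.local_iterator`, which raises `AttributeError` unless the attribute is set as in this PR.

```python
# Hypothetical before/after iterator classes illustrating the bug this PR
# fixes; checkpoint_save stands in for the
# grain.PyGrainCheckpointSave(data_iterator.local_iterator) call above.
class SyntheticIteratorWithoutFix:
    def __init__(self, data):
        self.data_generator = iter(data)  # no local_iterator attribute


class SyntheticIteratorWithFix:
    def __init__(self, data):
        self.data_generator = iter(data)
        self.local_iterator = self.data_generator  # attribute added by the fix


def checkpoint_save(data_iterator):
    # The checkpoint path reads local_iterator off the data iterator.
    return data_iterator.local_iterator


try:
    checkpoint_save(SyntheticIteratorWithoutFix([1, 2]))
except AttributeError:
    print("without the fix: AttributeError")

print(next(checkpoint_save(SyntheticIteratorWithFix([1, 2]))))  # prints 1
```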
