
Commit f7f18c7

Add max-prefill-length argument in distillation dataset generation script
1 parent 5c7df89 commit f7f18c7

3 files changed: +32 -29 lines changed

MaxText/generate_distillation_data.py

Lines changed: 15 additions & 11 deletions
@@ -24,20 +24,26 @@
   --dataset-path HuggingFaceH4/ultrachat_200k --data-split train_sft --data-columns messages \
   --tokenizer-path deepseek-ai/DeepSeek-V2-Lite-chat \
   --hf-access-token <access token> \
-  --batch-size 1024 --num-batches 100 \
+  --batch-size 1024 --num-batches 10 \
   --num-generations 2 \
-  --max-output-length 128 --max-target-length 256 \
+  --max-prefill-length 256 --max-target-length 2048 \
   --use-chat-template --remove-local-dataset-files \
   upload-to-hf --hf-repo-id <hf repository id>

-Running this command executes 100 processing steps.
-In each step, it generates completions for a batch of 40 prompts.
-This results in inference running on 4000 prompts overall, producing 2 samples per prompt.
+Running this command executes 10 processing steps.
+In each step, it generates completions for a batch of 1024 prompts.
+This results in inference running on 10240 prompts overall, producing 2 unique samples per prompt.
+Some prompts may be filtered out if prompt tokens are longer than `max-prefill-length`.
+`max-target-length` is the max length of prompt tokens and expected completion tokens.
+Set `--remove-local-dataset-files` to remove dataset files created locally after uploading to Hugging Face or GCS.
+`upload-to-hf` will upload the dataset to Hugging Face and `upload-to-gcs` will upload the dataset to GCS.
+For more information, check out `python3 -m MaxText.generate_distillation_data --help`.
 Note:
   Make sure to run maxengine server in a new terminal before executing this command. Example command to run maxengine server:
   python3 -m MaxText.maxengine_server MaxText/configs/base.yml \
     model_name=deepseek2-16b tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-chat tokenizer_type=huggingface \
     load_parameters_path=<unscanned checkpoint path> \
+    max_target_length=2048 max_prefill_predict_length=256 \
     per_device_batch_size=10 multi_sampling=True ici_tensor_parallelism=4 \
     decode_sampling_strategy=weighted scan_layers=False
 """
@@ -92,7 +98,7 @@ async def send_request(config, request, stub, tokenizer, progress_bar): # pylin

   outputs = []
   for tokens in completion_tokens:
-    completion = tokenizer.decode(tokens, skip_special_tokens=True)
+    completion = tokenizer.decode(tokens, skip_special_tokens=True).strip()
     outputs.append(
         {
             "prompt": [{"role": "user", "content": prompt}],
@@ -256,9 +262,7 @@ def generate_data(config): # pylint: disable=redefined-outer-name
   )
   parser.add_argument("--tokenizer-path", type=str, required=True, help="Path to Hugging Face tokenizer.")
   parser.add_argument("--use-chat-template", action="store_true", help="Enable tokenizer to apply a chat template.")
-  parser.add_argument(
-      "--max-output-length", type=int, required=True, help="The maximum completion tokens to generate for a prompt."
-  )
+  parser.add_argument("--max-prefill-length", type=int, default=256, help="The maximum prompt length.")
   parser.add_argument(
       "--max-target-length", type=int, default=2048, help="The maximum prompt length plus the output completion length."
   )
@@ -293,6 +297,6 @@ def generate_data(config): # pylint: disable=redefined-outer-name
   config = parser.parse_args()

   assert (
-      config.max_output_length < config.max_target_length
-  ), "Maximum output length of completion should be less than maximum target length."
+      config.max_prefill_length < config.max_target_length
+  ), "Maximum length of prompt should be less than maximum target length."
   generate_data(config)
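
Taken together, these hunks replace the required `--max-output-length` flag with an optional `--max-prefill-length` and validate it against `--max-target-length`. A self-contained sketch of the new CLI surface (an approximation for illustration, not the full MaxText parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max-prefill-length", type=int, default=256, help="The maximum prompt length.")
parser.add_argument(
    "--max-target-length", type=int, default=2048, help="The maximum prompt length plus the output completion length."
)
config = parser.parse_args([])  # defaults: prefill 256, target 2048

# Same validation as the commit: the prompt budget must leave room for output.
assert (
    config.max_prefill_length < config.max_target_length
), "Maximum length of prompt should be less than maximum target length."
print(config.max_target_length - config.max_prefill_length)  # implied completion budget: 1792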

MaxText/input_pipeline/_distillation_data_processing.py

Lines changed: 5 additions & 3 deletions
@@ -118,19 +118,21 @@ def filter_dataset(config, dataset, tokenizer):
     prompt = data["prompt"][0]
     actual_completion = data["completion"][0]

-    max_output_tokens = min(config.max_output_length, len(tokenizer.encode(actual_completion)))
+    max_output_length = config.max_target_length - config.max_prefill_length
+    max_output_tokens = min(max_output_length, len(tokenizer.encode(actual_completion)))
     if config.use_chat_template:
       message = [{"role": "user", "content": prompt}]
       prompt_token_ids = tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=True)
     else:
       prompt_token_ids = tokenizer.encode(prompt)

-    # Filter out long prompt sequences
-    if len(prompt_token_ids) + max_output_tokens > config.max_target_length:
+    # Filter out prompt sequences that are longer than max_prefill_length
+    if len(prompt_token_ids) > config.max_prefill_length:
       continue

     request = InputRequest(prompt, prompt_token_ids, actual_completion, max_output_tokens)
     filtered_dataset.append(request)
   if len(filtered_dataset) < len(dataset):
+    max_logging.log("Some prompts are longer than `max-prefill-length` and will be filtered out.")
     max_logging.log(f"Filtering reduced dataset batch from {len(dataset)} to {len(filtered_dataset)} samples.")
   return filtered_dataset
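
The new filtering rule is simple enough to reproduce standalone. The hedged sketch below re-creates it with a whitespace "tokenizer" standing in for the real Hugging Face tokenizer so it runs without any downloads; `Config` and `filter_prompts` are illustrative names, not MaxText APIs:

from dataclasses import dataclass

@dataclass
class Config:
  max_prefill_length: int = 256
  max_target_length: int = 2048

def filter_prompts(config, prompts):
  kept = []
  for prompt in prompts:
    prompt_token_ids = prompt.split()  # stand-in for tokenizer.encode(prompt)
    # Drop prompts whose token count exceeds the prefill budget.
    if len(prompt_token_ids) > config.max_prefill_length:
      continue
    kept.append(prompt)
  if len(kept) < len(prompts):
    print("Some prompts are longer than `max-prefill-length` and will be filtered out.")
  return kept

config = Config(max_prefill_length=4)
print(filter_prompts(config, ["short prompt", "a prompt that is definitely too long here"]))
# -> logs the warning, then keeps only ["short prompt"]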

MaxText/tests/distillation_data_processing_test.py

Lines changed: 12 additions & 15 deletions
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-distillation data processing test
-"""
+
+"""Data processing tests for distillation."""

 import argparse
 import os
@@ -33,7 +32,7 @@
         {"content": "Why is the sky blue?", "role": "user"},
     ],
     [
-        {"content": "How many days are in a week?", "role": "user"},
+        {"content": "Can you tell me how many days are in a week?", "role": "user"},
     ],
 ]

@@ -55,7 +54,7 @@
         {"content": "The sky appears blue due a phenomemon called Rayleigh scattering.", "role": "assistant"},
     ],
     [
-        {"content": "How many days are in a week?", "role": "user"},
+        {"content": "Can you tell me how many days are in a week?", "role": "user"},
         {"content": "There are 7 days in a week.", "role": "assistant"},
     ],
 ]
@@ -64,11 +63,9 @@
 def add_arguments_to_parser(parser):
   parser.add_argument("--data-columns", nargs="+", required=True, help="Columns names that contain relevant data.")
   parser.add_argument("--use-chat-template", action="store_true", help="Enable tokenizer to apply a chat template.")
+  parser.add_argument("--max-prefill-length", type=int, default=16, help="The maximum length for prompt tokens.")
   parser.add_argument(
-      "--max-output-length", type=int, default=8, help="The maximum completion tokens to generate for a prompt."
-  )
-  parser.add_argument(
-      "--max-target-length", type=int, default=16, help="The maximum prompt length plus the output completion length."
+      "--max-target-length", type=int, default=32, help="The maximum prompt length plus the output completion length."
   )
   return parser

@@ -83,7 +80,7 @@ def setUpClass(cls):
         "gsutil",
         "cp",
         "-r",
-        "gs://maxtext-dataset/hf/llama2-tokenizer",
+        "gs://maxtext-dataset/hf/llama2-chat-tokenizer",
         os.path.join(os.path.dirname(PKG_DIR), "assets", ""),
     ]
 )
@@ -93,7 +90,7 @@ def setUpClass(cls):
   def setUp(self):
     super().setUp()
     self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-        os.path.join(os.path.dirname(PKG_DIR), "assets", "llama2-tokenizer"),
+        os.path.join(os.path.dirname(PKG_DIR), "assets", "llama2-chat-tokenizer"),
     )
     self.parser = argparse.ArgumentParser()
     self.parser = add_arguments_to_parser(self.parser)
@@ -104,7 +101,7 @@ def test_data_processing_with_messages(self):

     processed_dataset = _distillation_data_processing.process_dataset(config, dataset)

-    expected_prompts = [["What color is the sky?", "Why is the sky blue?"], ["How many days are in a week?"]]
+    expected_prompts = [["What color is the sky?", "Why is the sky blue?"], ["Can you tell me how many days are in a week?"]]
     expected_completions = [
         ["The sky is blue.", "The sky appears blue due a phenomemon called Rayleigh scattering."],
         ["There are 7 days in a week."],
@@ -121,7 +118,7 @@ def test_data_processing_with_messages(self):
       self.assertEqual(data["completion"][c_idx], completion)

   def test_data_filtering_with_messages(self):
-    config = self.parser.parse_args(["--data-columns", "messages"])
+    config = self.parser.parse_args(["--data-columns", "messages", "--use-chat-template"])
     dataset = Dataset.from_dict({"messages": MESSAGES_DATA})

     processed_dataset = _distillation_data_processing.process_dataset(config, dataset)
@@ -137,7 +134,7 @@ def test_data_processing_with_prompt_completion(self):

     processed_dataset = _distillation_data_processing.process_dataset(config, dataset)

-    expected_prompts = [["What color is the sky?", "Why is the sky blue?"], ["How many days are in a week?"]]
+    expected_prompts = [["What color is the sky?", "Why is the sky blue?"], ["Can you tell me how many days are in a week?"]]
     expected_completions = [
         ["The sky is blue.", "The sky appears blue due a phenomemon called Rayleigh scattering."],
         ["There are 7 days in a week."],
@@ -154,7 +151,7 @@ def test_data_processing_with_prompt_completion(self):
       self.assertEqual(data["completion"][c_idx], completion)

   def test_data_filtering_with_prompt_completion(self):
-    config = self.parser.parse_args(["--data-columns", "prompt", "completion"])
+    config = self.parser.parse_args(["--data-columns", "prompt", "completion", "--use-chat-template"])
     dataset = Dataset.from_dict({"prompt": PROMPT_DATA, "completion": COMPLETION_DATA})

     processed_dataset = _distillation_data_processing.process_dataset(config, dataset)
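
For the tests, the defaults shift from (`max-output-length` 8, `max-target-length` 16) to (`max-prefill-length` 16, `max-target-length` 32), so the implied completion budget stays at 32 - 16 = 16 tokens. A quick hedged check of that, using the same parser shape as `add_arguments_to_parser` above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--data-columns", nargs="+", required=True)
parser.add_argument("--use-chat-template", action="store_true")
parser.add_argument("--max-prefill-length", type=int, default=16)
parser.add_argument("--max-target-length", type=int, default=32)

config = parser.parse_args(["--data-columns", "messages", "--use-chat-template"])
# Completion budget implied by the updated test defaults.
assert config.max_target_length - config.max_prefill_length == 16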
