11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- """
15
- distillation data processing test
16
- """
14
+
15
+ """Data processing tests for distillation."""
17
16
18
17
import argparse
19
18
import os
33
32
{"content" : "Why is the sky blue?" , "role" : "user" },
34
33
],
35
34
[
36
- {"content" : "How many days are in a week?" , "role" : "user" },
35
+ {"content" : "Can you tell me how many days are in a week?" , "role" : "user" },
37
36
],
38
37
]
39
38
55
54
{"content" : "The sky appears blue due a phenomemon called Rayleigh scattering." , "role" : "assistant" },
56
55
],
57
56
[
58
- {"content" : "How many days are in a week?" , "role" : "user" },
57
+ {"content" : "Can you tell me how many days are in a week?" , "role" : "user" },
59
58
{"content" : "There are 7 days in a week." , "role" : "assistant" },
60
59
],
61
60
]
64
63
def add_arguments_to_parser (parser ):
65
64
parser .add_argument ("--data-columns" , nargs = "+" , required = True , help = "Columns names that contain relevant data." )
66
65
parser .add_argument ("--use-chat-template" , action = "store_true" , help = "Enable tokenizer to apply a chat template." )
66
+ parser .add_argument ("--max-prefill-length" , type = int , default = 16 , help = "The maximum length for prompt tokens." )
67
67
parser .add_argument (
68
- "--max-output-length" , type = int , default = 8 , help = "The maximum completion tokens to generate for a prompt."
69
- )
70
- parser .add_argument (
71
- "--max-target-length" , type = int , default = 16 , help = "The maximum prompt length plus the output completion length."
68
+ "--max-target-length" , type = int , default = 32 , help = "The maximum prompt length plus the output completion length."
72
69
)
73
70
return parser
74
71
@@ -83,7 +80,7 @@ def setUpClass(cls):
83
80
"gsutil" ,
84
81
"cp" ,
85
82
"-r" ,
86
- "gs://maxtext-dataset/hf/llama2-tokenizer" ,
83
+ "gs://maxtext-dataset/hf/llama2-chat- tokenizer" ,
87
84
os .path .join (os .path .dirname (PKG_DIR ), "assets" , "" ),
88
85
]
89
86
)
@@ -93,7 +90,7 @@ def setUpClass(cls):
93
90
def setUp (self ):
94
91
super ().setUp ()
95
92
self .tokenizer = transformers .AutoTokenizer .from_pretrained (
96
- os .path .join (os .path .dirname (PKG_DIR ), "assets" , "llama2-tokenizer" ),
93
+ os .path .join (os .path .dirname (PKG_DIR ), "assets" , "llama2-chat- tokenizer" ),
97
94
)
98
95
self .parser = argparse .ArgumentParser ()
99
96
self .parser = add_arguments_to_parser (self .parser )
@@ -104,7 +101,7 @@ def test_data_processing_with_messages(self):
104
101
105
102
processed_dataset = _distillation_data_processing .process_dataset (config , dataset )
106
103
107
- expected_prompts = [["What color is the sky?" , "Why is the sky blue?" ], ["How many days are in a week?" ]]
104
+ expected_prompts = [["What color is the sky?" , "Why is the sky blue?" ], ["Can you tell me how many days are in a week?" ]]
108
105
expected_completions = [
109
106
["The sky is blue." , "The sky appears blue due a phenomemon called Rayleigh scattering." ],
110
107
["There are 7 days in a week." ],
@@ -121,7 +118,7 @@ def test_data_processing_with_messages(self):
121
118
self .assertEqual (data ["completion" ][c_idx ], completion )
122
119
123
120
def test_data_filtering_with_messages (self ):
124
- config = self .parser .parse_args (["--data-columns" , "messages" ])
121
+ config = self .parser .parse_args (["--data-columns" , "messages" , "--use-chat-template" ])
125
122
dataset = Dataset .from_dict ({"messages" : MESSAGES_DATA })
126
123
127
124
processed_dataset = _distillation_data_processing .process_dataset (config , dataset )
@@ -137,7 +134,7 @@ def test_data_processing_with_prompt_completion(self):
137
134
138
135
processed_dataset = _distillation_data_processing .process_dataset (config , dataset )
139
136
140
- expected_prompts = [["What color is the sky?" , "Why is the sky blue?" ], ["How many days are in a week?" ]]
137
+ expected_prompts = [["What color is the sky?" , "Why is the sky blue?" ], ["Can you tell me how many days are in a week?" ]]
141
138
expected_completions = [
142
139
["The sky is blue." , "The sky appears blue due a phenomemon called Rayleigh scattering." ],
143
140
["There are 7 days in a week." ],
@@ -154,7 +151,7 @@ def test_data_processing_with_prompt_completion(self):
154
151
self .assertEqual (data ["completion" ][c_idx ], completion )
155
152
156
153
def test_data_filtering_with_prompt_completion (self ):
157
- config = self .parser .parse_args (["--data-columns" , "prompt" , "completion" ])
154
+ config = self .parser .parse_args (["--data-columns" , "prompt" , "completion" , "--use-chat-template" ])
158
155
dataset = Dataset .from_dict ({"prompt" : PROMPT_DATA , "completion" : COMPLETION_DATA })
159
156
160
157
processed_dataset = _distillation_data_processing .process_dataset (config , dataset )
0 commit comments