Skip to content

Commit 557a1a0

Browse files
committed
adapt CI tests to use compiled_rmsnorm
ghstack-source-id: fa38b22 Pull Request resolved: #451
1 parent 8f810ff commit 557a1a0

File tree

1 file changed

+20
-23
lines changed

1 file changed

+20
-23
lines changed

test_runner.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def build_test_list():
6969
"--experimental.pipeline_parallel_split_points layers.4",
7070
"--experimental.pipeline_parallel_schedule 1f1b",
7171
"--training.data_parallel_degree 1",
72-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
72+
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
7373
],
7474
],
7575
"PP 1D test 1f1b",
@@ -85,7 +85,7 @@ def build_test_list():
8585
"--experimental.pipeline_parallel_split_points layers.4",
8686
"--experimental.pipeline_parallel_schedule gpipe",
8787
"--training.data_parallel_degree 1",
88-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
88+
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
8989
],
9090
],
9191
"PP 1D test gpipe",
@@ -101,7 +101,7 @@ def build_test_list():
101101
"--experimental.pipeline_parallel_split_points layers.4",
102102
"--experimental.pipeline_parallel_schedule 1f1b",
103103
"--training.data_parallel_degree 2",
104-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
104+
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
105105
],
106106
],
107107
"PP+DP 1f1b 2D test",
@@ -116,7 +116,7 @@ def build_test_list():
116116
"--experimental.pipeline_parallel_split_points layers.4",
117117
"--experimental.pipeline_parallel_schedule gpipe",
118118
"--training.data_parallel_degree 2",
119-
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
119+
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
120120
],
121121
],
122122
"PP+DP gpipe 2D test",
@@ -130,7 +130,6 @@ def build_test_list():
130130
"--experimental.pipeline_parallel_degree 2",
131131
"--experimental.pipeline_parallel_split_points layers.4",
132132
"--training.tensor_parallel_degree 2",
133-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
134133
],
135134
],
136135
"PP+TP 2D test",
@@ -144,7 +143,6 @@ def build_test_list():
144143
"--experimental.pipeline_parallel_degree 2",
145144
"--experimental.pipeline_parallel_split_points layers.4",
146145
"--experimental.pipeline_parallel_split_mode tracer",
147-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
148146
],
149147
],
150148
"PP tracer frontend test",
@@ -162,7 +160,16 @@ def build_test_list():
162160
OverrideDefinitions(
163161
[
164162
[
165-
"--training.compile --model.norm_type=rmsnorm",
163+
"--training.tensor_parallel_degree 2",
164+
],
165+
],
166+
"2D eager",
167+
"2d_eager",
168+
),
169+
OverrideDefinitions(
170+
[
171+
[
172+
"--training.compile",
166173
],
167174
],
168175
"1D compile",
@@ -182,29 +189,20 @@ def build_test_list():
182189
OverrideDefinitions(
183190
[
184191
[
185-
"--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
192+
"--training.compile --training.tensor_parallel_degree 2",
186193
],
187194
],
188195
"2D compile",
189196
"2d_compile",
190197
),
191-
OverrideDefinitions(
192-
[
193-
[
194-
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
195-
],
196-
],
197-
"Eager mode 2DParallel with rmsnorm",
198-
"eager_2d_rmsnorm",
199-
),
200198
OverrideDefinitions(
201199
[
202200
[
203201
"--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
204202
],
205203
],
206-
"Eager mode 2DParallel with fused_rmsnorm",
207-
"eager_2d_fused_rmsnorm",
204+
"2D eager with fused_rmsnorm",
205+
"2d_eager_fused_rmsnorm",
208206
),
209207
OverrideDefinitions(
210208
[
@@ -248,7 +246,6 @@ def build_test_list():
248246
"--experimental.pipeline_parallel_split_points layers.4",
249247
"--training.data_parallel_degree 2",
250248
"--training.tensor_parallel_degree 2",
251-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
252249
],
253250
[
254251
"--training.steps 20",
@@ -257,7 +254,6 @@ def build_test_list():
257254
"--experimental.pipeline_parallel_split_points layers.4",
258255
"--training.data_parallel_degree 2",
259256
"--training.tensor_parallel_degree 2",
260-
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
261257
],
262258
],
263259
"PP+DP+TP 3D test with save/load resume ckpt",
@@ -272,7 +268,7 @@ def build_test_list():
272268
"--experimental.pipeline_parallel_degree 4",
273269
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
274270
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
275-
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
271+
"--model.norm_type rmsnorm", # compiled_rmsnorm / fused_rmsnorm crashes with PP
276272
],
277273
],
278274
"PP looped 1f1b test",
@@ -292,7 +288,8 @@ def build_test_list():
292288
OverrideDefinitions(
293289
[
294290
[
295-
"--memory_estimation.enabled --model.norm_type rmsnorm",
291+
"--memory_estimation.enabled",
292+
"--model.norm_type rmsnorm", # estimation mode does not support compiled_rmsnorm yet
296293
]
297294
],
298295
"FSDP2 Memory Tracking and Estimation",

0 commit comments

Comments
 (0)