+ # coding=utf-8
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
import torch
import torch.distributed

from flash_attn.layers.rotary import RotaryEmbedding


+ class FastLayerNorm(nn.LayerNorm):
+     def forward(self, hidden_states, residual=None):
+         if hidden_states.shape[-1] > 6144:
+             if residual is not None:
+                 hidden_states += residual
+             residual = hidden_states
+
+             return super(FastLayerNorm, self).forward(hidden_states), residual
+         else:
+             (
+                 normed_hidden_states,
+                 residual,
+                 *rest,
+             ) = dropout_layer_norm.dropout_add_ln_fwd(
+                 hidden_states,
+                 residual,
+                 self.weight,
+                 self.bias,
+                 None,
+                 None,
+                 None,
+                 None,
+                 0.0,
+                 self.eps,
+                 1.0,
+                 0,
+                 None,
+                 False,
+                 False,
+             )
+             if residual is None:
+                 residual = hidden_states
+
+             return normed_hidden_states, residual
+
+
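# Illustrative reference (added for exposition, not taken from the diff): a plain
# PyTorch sketch of what FastLayerNorm computes when the fused
# dropout_add_ln_fwd kernel is not used and dropout is 0. The class name
# ReferenceLayerNorm is hypothetical. Both paths fold the incoming residual into
# the activations, keep that pre-norm sum as the next residual, and return the
# normalized tensor alongside it.
from torch import nn


class ReferenceLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states  # pre-norm sum, reused by the next block
        return super().forward(hidden_states), residual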
class FastLinear(nn.Linear):
    def __init__(
        self,
@@ -59,22 +115,21 @@ def __init__(
            dtype=dtype,
        )

-     def forward(self, input):
-         return super(TensorParallelColumnLinear, self).forward(input)
-

class TensorParallelRowLinear(FastLinear):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
+         reduce=True,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_world_size = process_group.size()
+         self.reduce = reduce
        assert in_features % self.tp_world_size == 0
        in_features = in_features // self.tp_world_size

@@ -88,7 +143,8 @@ def __init__(

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super(TensorParallelRowLinear, self).forward(input)
-         torch.distributed.all_reduce(out, group=self.process_group)
+         if self.reduce:
+             torch.distributed.all_reduce(out, group=self.process_group)

        return out

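# Illustrative check (added for exposition, not taken from the diff): deferring
# the all-reduce with reduce=False is safe because all_reduce sums over ranks
# and summation is linear, so reducing each branch and then adding equals
# reducing the sum once. Plain tensors stand in for two ranks' partial outputs.
import torch

attn_partials = [torch.randn(2, 4) for _ in range(2)]  # per-rank attention shard outputs
mlp_partials = [torch.randn(2, 4) for _ in range(2)]   # per-rank MLP shard outputs

two_reduces = sum(attn_partials) + sum(mlp_partials)                  # reduce each branch
one_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))  # reduce the sum once
torch.testing.assert_close(two_reduces, one_reduce)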
@@ -196,7 +252,13 @@ def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):

class FlashNeoxAttention(torch.nn.Module):
    def __init__(
-         self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+         self,
+         num_heads,
+         hidden_size,
+         rotary_pct,
+         rotary_emb_base,
+         process_group=None,
+         reduce=True,
    ):
        super().__init__()
        self.num_heads = num_heads
@@ -218,9 +280,7 @@ def __init__(
            process_group=process_group,
        )
        self.dense = TensorParallelRowLinear(
-             hidden_size,
-             hidden_size,
-             process_group=process_group,
+             hidden_size, hidden_size, process_group=process_group, reduce=reduce
        )

    def shuffle_qkv_dims(self):
@@ -309,7 +369,9 @@ def forward(


class FlashMLP(nn.Module):
-     def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+     def __init__(
+         self, act, hidden_size, intermediate_size, process_group=None, reduce=True
+     ):
        super().__init__()
        self.act = (
            ACT2FN[act]
@@ -330,6 +392,7 @@ def __init__(self, act, hidden_size, intermediate_size, process_group=None):
            intermediate_size,
            hidden_size,
            process_group=process_group,
+             reduce=reduce,
        )
        self.process_group = process_group

@@ -355,12 +418,24 @@ def __init__(
    ):
        super().__init__()
        self.use_parallel_residual = use_parallel_residual
-         self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-         self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+         self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+         self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = FlashNeoxAttention(
-             num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
+             num_heads,
+             hidden_size,
+             rotary_pct,
+             rotary_emb_base,
+             process_group,
+             reduce=not use_parallel_residual,
+         )
+         self.mlp = FlashMLP(
+             act,
+             hidden_size,
+             intermediate_size,
+             process_group,
+             reduce=not use_parallel_residual,
        )
-         self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
+         self.process_group = process_group

    def forward(
        self,
@@ -375,24 +450,7 @@ def forward(
        cu_seqlens_q,
    ):
        if self.use_parallel_residual:
-             # faster input layer norm
-             ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                 hidden_states,
-                 None,
-                 self.input_layernorm.weight,
-                 self.input_layernorm.bias,
-                 None,
-                 None,
-                 None,
-                 None,
-                 0.0,
-                 self.input_layernorm.eps,
-                 1.0,
-                 0,
-                 None,
-                 False,
-                 False,
-             )
+             ln1_hidden_states, _ = self.input_layernorm(hidden_states)

            attn_output = self.attention(
                ln1_hidden_states,
@@ -405,46 +463,18 @@ def forward(
                cu_seqlens_q,
            )

-             # faster post attention layer norm
-             ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                 hidden_states,
-                 None,
-                 self.post_attention_layernorm.weight,
-                 self.post_attention_layernorm.bias,
-                 None,
-                 None,
-                 None,
-                 None,
-                 0.0,
-                 self.post_attention_layernorm.eps,
-                 1.0,
-                 0,
-                 None,
-                 False,
-                 False,
-             )
+             ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)

            mlp_output = self.mlp(ln2_hidden_states)
-             return mlp_output + attn_output + hidden_states, None
+             intermediate = mlp_output + attn_output
+
+             # Only reduce once and after the addition instead of once per layer
+             if self.process_group is not None:
+                 torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+             return intermediate + hidden_states, None
        else:
-             # faster input layer norm
-             hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                 hidden_states,
-                 residual,
-                 self.input_layernorm.weight,
-                 self.input_layernorm.bias,
-                 None,
-                 None,
-                 None,
-                 None,
-                 0.0,
-                 self.input_layernorm.eps,
-                 1.0,
-                 0,
-                 None,
-                 False,
-                 False,
-             )
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)

            hidden_states = self.attention(
                hidden_states,
@@ -457,23 +487,8 @@ def forward(
                cu_seqlens_q,
            )

-             # faster post attention layer norm
-             hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                 hidden_states,
-                 residual,
-                 self.post_attention_layernorm.weight,
-                 self.post_attention_layernorm.bias,
-                 None,
-                 None,
-                 None,
-                 None,
-                 0.0,
-                 self.post_attention_layernorm.eps,
-                 1.0,
-                 0,
-                 None,
-                 False,
-                 False,
+             hidden_states, residual = self.post_attention_layernorm(
+                 hidden_states, residual
            )

            mlp_output = self.mlp(hidden_states)
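# Illustrative sketch (added for exposition, not taken from the diff): the two
# residual layouts the forward above implements, written as plain functions;
# attn, mlp, ln1 and ln2 stand in for the layer's sub-modules.
def parallel_residual_block(x, attn, mlp, ln1, ln2):
    # GPT-NeoX parallel residual: both branches read the same input, and the
    # residual is added once at the end, which is why a single all-reduce over
    # mlp_output + attn_output suffices.
    return x + attn(ln1(x)) + mlp(ln2(x))


def sequential_residual_block(x, attn, mlp, ln1, ln2):
    # Standard pre-norm layout: the MLP consumes the attention output, so each
    # row-parallel sub-module keeps its own all-reduce (reduce=True).
    x = x + attn(ln1(x))
    return x + mlp(ln2(x))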
@@ -523,7 +538,7 @@ def __init__(self, config, process_group=None):
                for _ in range(config.num_hidden_layers)
            ]
        )
-         self.final_layer_norm = nn.LayerNorm(
+         self.final_layer_norm = FastLayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
@@ -603,24 +618,7 @@ def forward(
                cu_seqlens_q,
            )

-         # Faster final layer norm
-         hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-             hidden_states,
-             residual,
-             self.final_layer_norm.weight,
-             self.final_layer_norm.bias,
-             None,
-             None,
-             None,
-             None,
-             0.0,
-             self.final_layer_norm.eps,
-             1.0,
-             0,
-             None,
-             False,
-             False,
-         )
+         hidden_states, _ = self.final_layer_norm(hidden_states, residual)

        return hidden_states, past_key_values