Commit c9bdaa8
feat(server): reduce mlp and attn in one op for flash neox (#145)
1 parent f000068 commit c9bdaa8

1 file changed

server/text_generation_server/models/flash_neox_modeling.py

Lines changed: 102 additions & 104 deletions
@@ -1,3 +1,23 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import torch.distributed
 
@@ -16,6 +36,42 @@
 from flash_attn.layers.rotary import RotaryEmbedding
 
 
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 6144:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+        else:
+            (
+                normed_hidden_states,
+                residual,
+                *rest,
+            ) = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.eps,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if residual is None:
+                residual = hidden_states
+
+            return normed_hidden_states, residual
+
+
 class FastLinear(nn.Linear):
     def __init__(
         self,
@@ -59,22 +115,21 @@ def __init__(
             dtype=dtype,
         )
 
-    def forward(self, input):
-        return super(TensorParallelColumnLinear, self).forward(input)
-
 
 class TensorParallelRowLinear(FastLinear):
     def __init__(
         self,
         in_features,
         out_features,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         bias=True,
         device=None,
         dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
+        self.reduce = reduce
         assert in_features % self.tp_world_size == 0
         in_features = in_features // self.tp_world_size
 
@@ -88,7 +143,8 @@ def __init__(
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super(TensorParallelRowLinear, self).forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
 
         return out
 
@@ -196,7 +252,13 @@ def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
 
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+        self,
+        num_heads,
+        hidden_size,
+        rotary_pct,
+        rotary_emb_base,
+        process_group=None,
+        reduce=True,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -218,9 +280,7 @@ def __init__(
             process_group=process_group,
         )
         self.dense = TensorParallelRowLinear(
-            hidden_size,
-            hidden_size,
-            process_group=process_group,
+            hidden_size, hidden_size, process_group=process_group, reduce=reduce
         )
 
     def shuffle_qkv_dims(self):
@@ -309,7 +369,9 @@ def forward(
 
 
 class FlashMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, reduce=True
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -330,6 +392,7 @@ def __init__(self, act, hidden_size, intermediate_size, process_group=None):
             intermediate_size,
             hidden_size,
             process_group=process_group,
+            reduce=reduce,
         )
         self.process_group = process_group
 
@@ -355,12 +418,24 @@ def __init__(
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attention = FlashNeoxAttention(
-            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
+            num_heads,
+            hidden_size,
+            rotary_pct,
+            rotary_emb_base,
+            process_group,
+            reduce=not use_parallel_residual,
+        )
+        self.mlp = FlashMLP(
+            act,
+            hidden_size,
+            intermediate_size,
+            process_group,
+            reduce=not use_parallel_residual,
         )
-        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
+        self.process_group = process_group
 
     def forward(
         self,
@@ -375,24 +450,7 @@ def forward(
         cu_seqlens_q,
     ):
         if self.use_parallel_residual:
-            # faster input layer norm
-            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln1_hidden_states, _ = self.input_layernorm(hidden_states)
 
             attn_output = self.attention(
                 ln1_hidden_states,
ln1_hidden_states,
@@ -405,46 +463,18 @@ def forward(
405463
cu_seqlens_q,
406464
)
407465

408-
# faster post attention layer norm
409-
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
410-
hidden_states,
411-
None,
412-
self.post_attention_layernorm.weight,
413-
self.post_attention_layernorm.bias,
414-
None,
415-
None,
416-
None,
417-
None,
418-
0.0,
419-
self.post_attention_layernorm.eps,
420-
1.0,
421-
0,
422-
None,
423-
False,
424-
False,
425-
)
466+
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
426467

427468
mlp_output = self.mlp(ln2_hidden_states)
428-
return mlp_output + attn_output + hidden_states, None
469+
intermediate = mlp_output + attn_output
470+
471+
# Only reduce once and after the addition instead of once per layer
472+
if self.process_group is not None:
473+
torch.distributed.all_reduce(intermediate, group=self.process_group)
474+
475+
return intermediate + hidden_states, None
429476
else:
430-
# faster input layer norm
431-
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
432-
hidden_states,
433-
residual,
434-
self.input_layernorm.weight,
435-
self.input_layernorm.bias,
436-
None,
437-
None,
438-
None,
439-
None,
440-
0.0,
441-
self.input_layernorm.eps,
442-
1.0,
443-
0,
444-
None,
445-
False,
446-
False,
447-
)
477+
hidden_states, residual = self.input_layernorm(hidden_states, residual)
448478

449479
hidden_states = self.attention(
450480
hidden_states,
@@ -457,23 +487,8 @@
                 cu_seqlens_q,
             )
 
-            # faster post attention layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
             )
 
             mlp_output = self.mlp(hidden_states)
@@ -523,7 +538,7 @@ def __init__(self, config, process_group=None):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.final_layer_norm = nn.LayerNorm(
+        self.final_layer_norm = FastLayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
 
@@ -603,24 +618,7 @@ def forward(
                 cu_seqlens_q,
             )
 
-        # Faster final layer norm
-        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.final_layer_norm.weight,
-            self.final_layer_norm.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.final_layer_norm.eps,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
+        hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
         return hidden_states, past_key_values
 
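Note on the core of this change: in the parallel-residual branch, the row-parallel output projections of attention and MLP are now built with reduce=False, their partial results are summed, and a single all_reduce runs on the sum, halving the number of reductions per layer. Below is a minimal standalone sketch of that communication pattern; the names RowParallelLinear and ParallelResidualBlock are illustrative stand-ins rather than code from this repository, and a real multi-GPU run would need an initialized process group.

import torch
import torch.distributed as dist
from torch import nn


class RowParallelLinear(nn.Linear):
    """Row-parallel linear whose all_reduce can be deferred by the caller."""

    def __init__(self, in_features, out_features, process_group=None, reduce=True):
        super().__init__(in_features, out_features)
        self.process_group = process_group
        self.reduce = reduce

    def forward(self, x):
        out = super().forward(x)
        # Reduce here only when the caller has not deferred the reduction.
        if self.reduce and self.process_group is not None:
            dist.all_reduce(out, group=self.process_group)
        return out


class ParallelResidualBlock(nn.Module):
    """GPT-NeoX-style parallel residual: hidden + attn(ln1(h)) + mlp(ln2(h))."""

    def __init__(self, hidden_size, process_group=None):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
        # Stand-ins for the attention and MLP output projections; both skip
        # their own all_reduce so it can be done once on the sum below.
        self.attn_out = RowParallelLinear(
            hidden_size, hidden_size, process_group, reduce=False
        )
        self.mlp_out = RowParallelLinear(
            hidden_size, hidden_size, process_group, reduce=False
        )
        self.process_group = process_group

    def forward(self, hidden_states):
        attn = self.attn_out(self.ln1(hidden_states))
        mlp = self.mlp_out(self.ln2(hidden_states))
        intermediate = attn + mlp
        # Only reduce once, after the addition, instead of once per projection.
        if self.process_group is not None:
            dist.all_reduce(intermediate, group=self.process_group)
        return intermediate + hidden_states


if __name__ == "__main__":
    # Single-process demo: with process_group=None no reduction runs at all.
    block = ParallelResidualBlock(64)
    print(block(torch.randn(2, 64)).shape)  # torch.Size([2, 64])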
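For reference, the new FastLayerNorm folds the residual add into the layer norm and hands the pre-norm sum back to the caller as the next residual (using flash-attn's dropout_add_ln_fwd kernel for hidden sizes up to 6144 and falling back to nn.LayerNorm above that). A plain-PyTorch sketch of the values it returns, ignoring the fused kernel, could look like the following; ReferenceAddLayerNorm is a hypothetical name used only for illustration.

import torch
from torch import nn


class ReferenceAddLayerNorm(nn.LayerNorm):
    """Unfused reference: add the residual, layer-norm the sum, and return
    both the normed output and the pre-norm sum as the next residual."""

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        return super().forward(hidden_states), residual


if __name__ == "__main__":
    ln = ReferenceAddLayerNorm(64)
    normed, residual = ln(torch.randn(2, 64), torch.randn(2, 64))
    print(normed.shape, residual.shape)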