Commit f0aa672

[LLM INFER] add rope_theta for block_multihead_attention (#9334)
* add rope_theta for block_multihead_attention
1 parent 975d5c7 commit f0aa672
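
rope_theta is the base of the rotary position embedding (RoPE) frequency schedule. Before this change the fused block-attention path effectively assumed the common default of 10000.0; the commit threads a configurable value from each model class down to the block_multihead_attention call. As a minimal sketch of what the base controls (the rope_inv_freq / rope_angles helpers below are illustrative, not the fused kernel implementation):

def rope_inv_freq(head_dim: int, rope_theta: float = 10000.0):
    """Per-pair inverse frequencies used by rotary position embeddings.

    Pair i of a head of width head_dim rotates with angular frequency
    rope_theta ** (-2 * i / head_dim): a larger base slows the high-index
    pairs down, which is how long-context checkpoints stretch RoPE.
    """
    return [rope_theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]


def rope_angles(position: int, head_dim: int, rope_theta: float = 10000.0):
    # Rotation angle applied to pair i for the token at this position.
    return [position * f for f in rope_inv_freq(head_dim, rope_theta)]

For example, rope_angles(1024, 128, 10000.0) and rope_angles(1024, 128, 1000000.0) agree on pair 0 but diverge in the higher pairs, which is why a checkpoint trained with one base degrades when served with another.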

File tree

5 files changed: +15 -2 lines changed

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 8 additions & 2 deletions
@@ -41,9 +41,10 @@
         "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
         "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
     )
+
 if (
     paddle.device.get_all_custom_device_type() is not None and len(paddle.device.get_all_custom_device_type()) > 0
-) or core.is_compiled_with_cuda():
+) or paddle.is_compiled_with_cuda():
     from paddlenlp_ops import rebuild_padding_v2


@@ -147,6 +148,7 @@ def __init__(
         activation="gelu",
         norm_type="layernorm",
         use_neox_rotary_style=False,
+        rope_theta=10000.0,
         normalize_before=True,
         ln_scale_attrs=None,
         ln_bias_attrs=None,
@@ -210,7 +212,7 @@ def __init__(
         self.dropout_rate = dropout_rate
         self.activation = activation
         self.norm_type = norm_type
-
+        self.rope_theta = rope_theta
         self.use_neox_rotary_style = use_neox_rotary_style
         self.normalize_before = normalize_before
         self.ln_scale_attrs = ln_scale_attrs
@@ -2234,6 +2236,7 @@ def compute_attn(
                 quant_round_type=self.config.quant_round_type,
                 quant_max_bound=self.config.quant_max_bound,
                 quant_min_bound=self.config.quant_min_bound,
+                rope_theta=self.rope_theta,
             )[0]
         else:
             k_quant_scales = kwargs.get("k_quant_scales", None)
@@ -2275,6 +2278,7 @@ def compute_attn(
                 quant_round_type=self.config.quant_round_type,
                 quant_max_bound=self.config.quant_max_bound,
                 quant_min_bound=self.config.quant_min_bound,
+                rope_theta=self.rope_theta,
             )[0]

         out_linear_out = self.compute_out_linear(fmha_out, i)
@@ -2420,6 +2424,7 @@ def compute_attn(
                 quant_min_bound=self.quant_min_bound,
                 out_scale=self.act_scales["out_linear_in_scale"][i],
                 compute_dtype=self._fuse_kernel_compute_dtype,
+                rope_theta=self.rope_theta,
             )[0]

         out_linear_out = self.compute_out_linear(fmha_out, i)
@@ -2932,6 +2937,7 @@ def compute_attn(
                 quant_max_bound=self.config.quant_max_bound,
                 quant_min_bound=self.config.quant_min_bound,
                 out_scale=self.act_scales.scale["out_linear_in_scale"][i],
+                rope_theta=self.rope_theta,
             )[0]
         out_linear_out = self.compute_out_linear(fmha_out, i)
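
The new keyword defaults to 10000.0 in the fused layer configuration (first hunk above), so models that do not pass it keep the previous behavior; models trained with a larger base (long-context variants commonly use values on the order of 5e5 or 1e6) can now forward their own value instead of silently falling back to 10000.0. The sketch below is plain Python rather than PaddleNLP API and shows how the base sets the wavelength of each RoPE pair, which is what goes wrong if the serving-time theta does not match training:

import math

def rope_wavelengths(head_dim: int, rope_theta: float):
    # Wavelength (in tokens) of pair i: 2*pi / inv_freq = 2*pi * theta ** (2*i / head_dim).
    return [2.0 * math.pi * rope_theta ** (2.0 * i / head_dim) for i in range(head_dim // 2)]

if __name__ == "__main__":
    for theta in (10000.0, 1000000.0):
        wl = rope_wavelengths(128, theta)
        # The slowest-rotating pair bounds the range of positions the
        # embedding can separate before its phase wraps around.
        print(f"theta={theta:>9.0f}  longest wavelength ~ {max(wl):,.0f} tokens")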

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 3 additions & 0 deletions
@@ -214,6 +214,7 @@ def __init__(self, config: LlamaConfig):
             ffn2_bias_attrs=None,
             norm_type="rmsnorm",
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             nranks=config.tensor_parallel_degree,
             avx_config=avx_config,
         )
@@ -629,6 +630,7 @@ def __init__(self, config: LlamaConfig):
             ffn2_weight_attrs=ffn2_weight_attrs,
             ffn2_bias_attrs=ffn2_bias_attrs,
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             rank_id=config.tensor_parallel_rank,
@@ -675,6 +677,7 @@ def __init__(self, config: LlamaConfig):
             cache_k_out_scale_attrs=cache_k_out_scale_attrs,
             cache_v_out_scale_attrs=cache_v_out_scale_attrs,
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
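
In the model classes, self.rope_theta is expected to be populated from the checkpoint's config before these constructors run. A hedged sketch of the kind of lookup involved, with a fallback for configs that predate the field (resolve_rope_theta and the SimpleNamespace configs are illustrative stand-ins, not PaddleNLP code):

from types import SimpleNamespace

def resolve_rope_theta(config, default: float = 10000.0) -> float:
    # Older configs may not carry rope_theta at all; fall back to the
    # historical RoPE base in that case. (Illustrative helper only.)
    return float(getattr(config, "rope_theta", default))

# Hypothetical configs: one without the field, one from a long-context checkpoint.
legacy_cfg = SimpleNamespace(hidden_size=4096)
long_ctx_cfg = SimpleNamespace(hidden_size=4096, rope_theta=500000.0)

print(resolve_rope_theta(legacy_cfg))    # 10000.0
print(resolve_rope_theta(long_ctx_cfg))  # 500000.0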

paddlenlp/experimental/transformers/mixtral/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -334,6 +334,7 @@ def __init__(self, config: MixtralConfig):
             cache_k_out_scale_attrs=cache_k_out_scale_attrs,
             cache_v_out_scale_attrs=cache_v_out_scale_attrs,
             epsilon=self.epsilon,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,

paddlenlp/experimental/transformers/qwen2/modeling.py

Lines changed: 2 additions & 0 deletions
@@ -334,6 +334,7 @@ def __init__(self, config: Qwen2Config):
             ffn2_weight_attrs=ffn2_weight_attrs,
             ffn2_bias_attrs=ffn2_bias_attrs,
             epsilon=self.rms_norm_eps,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             rank_id=config.tensor_parallel_rank,
@@ -380,6 +381,7 @@ def __init__(self, config: Qwen2Config):
             cache_k_out_scale_attrs=cache_k_out_scale_attrs,
             cache_v_out_scale_attrs=cache_v_out_scale_attrs,
             epsilon=self.rms_norm_eps,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,

paddlenlp/experimental/transformers/qwen2_moe/modeling.py

Lines changed: 1 addition & 0 deletions
@@ -252,6 +252,7 @@ def __init__(self, config: Qwen2MoeConfig):
             ffn2_weight_scale_attrs=ffn2_weight_scale_attrs,
             qkv_bias_attrs=qkv_bias_attrs,
             epsilon=self.rms_norm_eps,
+            rope_theta=self.rope_theta,
             norm_type="rmsnorm",
             use_neox_rotary_style=self.use_neox,
             rank_id=config.tensor_parallel_rank,
