158 changes: 94 additions & 64 deletions custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh

@@ -134,7 +134,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
} else {
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
@@ -154,7 +155,8 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
dim_head,
block_size,
elem_nums,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
}
} else {
if (qkv_out_scales) {
@@ -261,7 +263,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
} else {
append_decode_cache_int8_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -284,7 +287,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
}
} else {
if (qkv_out_scales) {
@@ -311,7 +315,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
} else {
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -334,7 +339,8 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
block_size,
127.0f,
-127.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
}
}
}
@@ -398,7 +404,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
} else {
append_decode_cache_int4_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -423,7 +430,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
}
} else {
if (qkv_out_scales) {
@@ -452,7 +460,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
} else {
append_decode_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
@@ -477,7 +486,8 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
block_size,
7.0f,
-8.0f,
- kv_num_heads);
+ kv_num_heads,
+ rope_3d);
}
}
}
1 change: 0 additions & 1 deletion custom_ops/gpu_ops/moe/moe_ffn.cu
@@ -304,7 +304,6 @@ paddle::Tensor MoeExpertFFNFunc(
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method, const bool used_in_ep_low_latency) {

- cudaCheckError();
const auto t_type = quant_method == "w4a8" ? up_gate_proj_scale.get().dtype() : permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input, t_type);

@@ -38,15 +38,15 @@ def __init__(self, quant_method=None):
"down_proj_weight_scale",
]

- def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
"""process_prequanted_weights"""
pass

def create_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE create weight process.
"""
- up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+ up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts
assert self.quant_method.name() == "wint8"
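The backend changes above and below all follow from one interface change: extract_moe_ffn_weights now returns four values, and callers that only need the raw FFN weights discard the trailing expert-id outputs. A minimal sketch of the consuming pattern, with the return order inferred from the call sites in this diff (the wrapper function below is illustrative, not part of the repository):

# Sketch only: how callers unpack the new 4-tuple from extract_moe_ffn_weights.
# Return order assumed from the call sites in this PR:
#   up_gate_proj weights, down_proj weights, logical expert ids, ep-rank -> expert-id list.
def load_ffn_weights(layer, state_dict):
    (
        up_gate_proj_weights,
        down_proj_weights,
        logical_expert_ids,         # ids of the experts this rank owns (possibly rearranged)
        ep_rank_to_expert_id_list,  # full per-rank mapping, used when ep_size > 1
    ) = layer.extract_moe_ffn_weights(state_dict)

    # Backends that do not need the placement info simply ignore it:
    #   up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
    assert len(up_gate_proj_weights) == layer.num_local_experts
    assert len(down_proj_weights) == layer.num_local_experts
    return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list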
@@ -49,7 +49,7 @@ def __init__(self, quant_config):
self.group_size = -1

def process_loaded_weights(self, layer: nn.Layer, state_dict):
- up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+ up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)
layer.up_gate_proj_weight.set_value(paddle.transpose(stacked_up_gate_proj_weights, [0, 2, 1]))
@@ -254,7 +254,7 @@ def __init__(self, quant_config):
self.quant_multi_process_group_size = int(os.getenv("FD_MOE_QUANT_MULTI_PROCESS_GROUP_SIZE", 8))
logger.info(f"GCUWeightOnlyMoEMethod quant_multi_process_group_size: {self.quant_multi_process_group_size}")

- def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle gcu process prequanted weights.
"""
@@ -299,7 +299,7 @@ def create_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass create weight process.
"""
- up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+ up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)

def quant_worker(p_group_idx, shared_dict, weights, moe_quant_type, group_size):
@@ -59,7 +59,7 @@ def create_weights(self, layer, **extra_weight_attrs):
is_bias=False,
)

- def process_prequanted_weights(self, layer, state_dict) -> None:
+ def process_prequanted_weights(self, layer, state_dict, is_rearrange: bool = False) -> None:
"""
Process pre-quantized weights before applying them to the model
Args:
@@ -41,15 +41,15 @@ def __init__(self, quant_config=None):
"down_proj_weight_scale",
]

- def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
"""process_prequanted_weights"""
pass

def create_weights(self, layer: nn.Layer, state_dict):
"""
Triton MoE create weight process.
"""
- up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+ up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
assert len(up_gate_proj_weights) == layer.num_local_experts
assert len(down_proj_weights) == layer.num_local_experts

109 changes: 87 additions & 22 deletions fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
@@ -71,7 +71,9 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
"""

def process_loaded_weights(self, layer: nn.Layer, state_dict):
- up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+ up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
+     layer.extract_moe_ffn_weights(state_dict)
+ )
stacked_up_gate_proj_weights = paddle.stack(up_gate_proj_weights, axis=0)
stacked_down_proj_weights = paddle.stack(down_proj_weights, axis=0)

@@ -325,7 +327,7 @@ def __init__(self, quant_config):
self.moe_quant_type = "w4a8"
self.pack_num = 2

- def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle cutlass process prequanted weights.
"""
@@ -341,6 +343,7 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
+ is_rearrange,
)
)
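The new is_rearrange keyword added to every process_prequanted_weights signature is forwarded here into layer.load_experts_weight, which selects the experts (and their on-disk keys) for this rank. A hedged sketch of the calling convention; the caller below and the quant_method attribute are assumptions for illustration, not code from the repository:

# Sketch only: forwarding is_rearrange from a loading path into the quant method.
# The default False preserves the previous behaviour; True is assumed to mean the
# checkpoint's expert layout has been rearranged (remapped or redundant experts),
# so load_experts_weight must resolve expert ids through that mapping.
def load_prequanted(layer, state_dict, rearranged=False):
    layer.quant_method.process_prequanted_weights(layer, state_dict, is_rearrange=rearranged)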

@@ -350,22 +353,62 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
up_gate_proj_in_scale = []
down_proj_in_scale = []

+ if isinstance(state_dict, list):
+     state_dict = dict(state_dict)

if layer.ep_size > 1:
for expert_idx in ep_rank_to_expert_id_list:
- scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
+ scale_tensor = get_tensor(
+     (
+         state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
+         if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
+         else up_gate_proj_expert_in_scale_key.format(expert_idx)
+     ),
+     layer.fd_config.model_config.model,
+ )
up_gate_proj_in_scale_all_experts.append(scale_tensor)

for expert_idx in logical_expert_ids:
up_gate_proj_weight_scale.append(
- get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
+ get_tensor(
+     (
+         state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
+         if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
+         else up_gate_proj_expert_weight_scale_key.format(expert_idx)
+     ),
+     layer.fd_config.model_config.model,
+ )
)
down_proj_weight_scale.append(
- get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
+ get_tensor(
+     (
+         state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
+         if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
+         else down_proj_expert_weight_scale_key.format(expert_idx)
+     ),
+     layer.fd_config.model_config.model,
+ )
)
up_gate_proj_in_scale.append(
- get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
+ get_tensor(
+     (
+         state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
+         if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
+         else up_gate_proj_expert_in_scale_key.format(expert_idx)
+     ),
+     layer.fd_config.model_config.model,
+ )
)
- down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
+ down_proj_in_scale.append(
+     get_tensor(
+         (
+             state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
+             if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
+             else down_proj_expert_in_scale_key.format(expert_idx)
+         ),
+         layer.fd_config.model_config.model,
+     )
+ )

up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)
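Every scale lookup in this hunk now uses the same fallback: if the formatted key is still present in state_dict the tensor is taken from there, otherwise the key string itself is passed to get_tensor together with layer.fd_config.model_config.model, which is inferred from these call sites (not shown in this diff) to load the tensor from the model files. The hunk also normalises a state_dict given as a list of key/value pairs into a dict first. A compact sketch of the idiom; the helper name below is illustrative:

# Sketch only: the key-or-load-from-checkpoint fallback used by the scale loaders above.
# get_tensor is the same helper these call sites use; it is passed in so the sketch
# does not assume its import path. Its behaviour when handed a key string plus the
# model path is an assumption inferred from this diff.
def fetch_scale(get_tensor, state_dict, key, model_path, pop=False):
    if key in state_dict:
        value = state_dict.pop(key) if pop else state_dict[key]
    else:
        value = key  # not preloaded: let get_tensor resolve it from the model files
    return get_tensor(value, model_path)

With a helper like this, each of the four per-expert loops above reduces to a single call per key template.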
@@ -427,7 +470,9 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass load weight process.
"""
- up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+ up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
+     layer.extract_moe_ffn_weights(state_dict)
+ )
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]
@@ -438,7 +483,9 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
quanted_weight = paddle.stack(weight_list, axis=0)
getattr(layer, weight_name).set_value(quanted_weight)

- self.load_w4a8_scale_weights(layer, layer.weight_key_map, state_dict)
+ self.load_w4a8_scale_weights(
+     layer, layer.weight_key_map, state_dict, logical_expert_ids, ep_rank_to_expert_id_list
+ )

def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
"""
@@ -492,7 +539,14 @@ def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
),
)

- def load_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_dict: dict):
+ def load_w4a8_scale_weights(
+     self,
+     layer: nn.Layer,
+     weight_key_map: dict,
+     state_dict: dict,
+     logical_expert_ids: paddle.Tensor,
+     ep_rank_to_expert_id_list: list,
+ ):
"""
Get w4a8 weights from state dict and process them.
Args:
@@ -501,8 +555,15 @@ def load_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict, state_d
state_dict (dict): The state dict.
"""

- def _extract_scale_tensor(state_dict, key_template, expert_idx):
-     return get_tensor(state_dict.pop(key_template.format(expert_idx)))
+ def _extract_scale_tensor(layer: nn.Layer, state_dict, key_template, expert_idx):
+     return get_tensor(
+         (
+             state_dict.pop(key_template.format(expert_idx))
+             if key_template.format(expert_idx) in state_dict
+             else key_template.format(expert_idx)
+         ),
+         layer.fd_config.model_config.model,
+     )

def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
processed_in_scale = 1 / paddle.concat(in_scales)
@@ -544,17 +605,23 @@ def _process_weight_scale(

# 2. Extract scale tensor from state dict
if layer.ep_size > 1:
- for expert_idx in range(layer.num_experts):
-     scale_tensor = get_tensor(state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)])
+ for expert_idx in ep_rank_to_expert_id_list:
+     scale_tensor = get_tensor(
+         (
+             state_dict[scale_key_map["up_gate_proj_in_scale"].format(expert_idx)]
+             if scale_key_map["up_gate_proj_in_scale"].format(expert_idx) in state_dict
+             else scale_key_map["up_gate_proj_in_scale"].format(expert_idx)
+         ),
+         layer.fd_config.model_config.model,
+     )
up_gate_proj_in_scales_all_experts.append(1 / scale_tensor)
getattr(layer, "up_gate_proj_in_scale_all_experts").set_value(
paddle.concat(up_gate_proj_in_scales_all_experts)
)

- for local_expert_idx in range(layer.num_local_experts):
-     expert_idx = local_expert_idx + layer.expert_id_offset
+ for expert_idx in logical_expert_ids:
for name, scale_key_template in scale_key_map.items():
- scale_tensor = _extract_scale_tensor(state_dict, scale_key_template, expert_idx)
+ scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx)
scale_weight_map[name].append(scale_tensor)
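The index bookkeeping also changes: instead of reconstructing expert ids as expert_id_offset plus a local index, the loader iterates ep_rank_to_expert_id_list when gathering the per-expert input scales across the whole expert-parallel group, and logical_expert_ids for the experts this rank owns. A small sketch of the distinction; the contiguous fallback is only an assumption about the non-rearranged case:

# Sketch only: which expert ids the two loops visit.
def expert_ids_for_rank(layer, logical_expert_ids=None):
    if logical_expert_ids is not None:
        # Rearranged case: ids come from the checkpoint mapping and need not be contiguous.
        return list(logical_expert_ids)
    # Assumed non-rearranged fallback, equivalent to the code this diff replaces.
    return [layer.expert_id_offset + i for i in range(layer.num_local_experts)]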

# 3. Process scale tensor and set to layer
Expand All @@ -581,7 +648,7 @@ def __init__(self, quant_config):
self.moe_quant_type = self.quant_config.algo
self.pack_num = 1

- def process_prequanted_weights(self, layer: nn.Layer, state_dict):
+ def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False):
"""
Paddle cutlass process prequanted weights.
"""
@@ -591,9 +658,7 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)

up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
-     state_dict,
-     up_gate_proj_expert_weight_key,
-     down_proj_expert_weight_key,
+     state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
)
# self.check(layer, up_gate_proj_weights, down_proj_weights)
up_gate_proj_weight_scale = []
Expand Down Expand Up @@ -695,7 +760,7 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
"""
Paddle cutlass load weight process.
"""
- up_gate_proj_weights, down_proj_weights = layer.extract_moe_ffn_weights(state_dict)
+ up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
self.check(layer, up_gate_proj_weights, down_proj_weights)
for idx, weight_tensor in enumerate([up_gate_proj_weights, down_proj_weights]):
weight_name = self.added_weight_attrs[idx]