Skip to content

[ET-VK][ez] Use standard quant naming scheme for quantized ops #10587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_8w_linear:
linear_qcsnw:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
Expand All @@ -18,6 +18,6 @@ q_8w_linear:
- VALUE: texture3d
- VALUE: buffer
shader_variants:
- NAME: q_8w_linear_W_packed_W_packed
- NAME: q_8w_linear_W_packed_H_packed
- NAME: linear_qcs8w_W_packed_W_packed
- NAME: linear_qcs8w_W_packed_H_packed
MAT2_PACKING: H_packed
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_8w_linear_coop:
linear_qcsnw_coop:
parameter_names_with_default_values:
DTYPE: float
IN_STORAGE: texture3d
Expand All @@ -17,11 +17,11 @@ q_8w_linear_coop:
- VALUE: 1
SUFFIX: o4x1
shader_variants:
- NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_texture2d_float
- NAME: q_8w_linear_coop_buffer_buffer_texture2d_texture2d_float
- NAME: linear_qcs8w_coop_texture3d_texture3d_texture2d_texture2d_float
- NAME: linear_qcs8w_coop_buffer_buffer_texture2d_texture2d_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
- NAME: q_8w_linear_coop_buffer_buffer_buffer_buffer_float
- NAME: linear_qcs8w_coop_buffer_buffer_buffer_buffer_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
WEIGHT_STORAGE: buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_8w_linear_tiled:
linear_qcsnw_tiled:
parameter_names_with_default_values:
DTYPE: float
IN_STORAGE: texture3d
Expand All @@ -21,11 +21,11 @@ q_8w_linear_tiled:
- VALUE: 4
SUFFIX: o4x4
shader_variants:
- NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
- NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
- NAME: linear_qcs8w_tiled_texture3d_texture3d_texture2d_texture2d_float
- NAME: linear_qcs8w_tiled_buffer_buffer_texture2d_texture2d_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
- NAME: q_8w_linear_tiled_buffer_buffer_buffer_buffer_float
- NAME: linear_qcs8w_tiled_buffer_buffer_buffer_buffer_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
WEIGHT_STORAGE: buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_4w_linear_coop:
linear_qga4w_coop:
parameter_names_with_default_values:
DTYPE: float
OUT_STORAGE: texture3d
Expand All @@ -13,11 +13,11 @@ q_4w_linear_coop:
PARAMS_STORAGE: buffer
TILE_ROWS: 1
shader_variants:
- NAME: q_4w_linear_coop_texture3d_texture3d_texture2d_float
- NAME: q_4w_linear_coop_buffer_buffer_texture2d_float
- NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float
- NAME: linear_qga4w_coop_buffer_buffer_texture2d_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
- NAME: q_4w_linear_coop_buffer_buffer_buffer_float
- NAME: linear_qga4w_coop_buffer_buffer_buffer_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
WEIGHT_STORAGE: buffer
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_4w_linear_tiled:
linear_qga4w_tiled:
parameter_names_with_default_values:
DTYPE: float
OUT_STORAGE: texture3d
Expand All @@ -13,11 +13,11 @@ q_4w_linear_tiled:
PARAMS_STORAGE: buffer
TILE_ROWS: 3
shader_variants:
- NAME: q_4w_linear_tiled_texture3d_texture3d_texture2d_float
- NAME: q_4w_linear_tiled_buffer_buffer_texture2d_float
- NAME: linear_qga4w_tiled_texture3d_texture3d_texture2d_float
- NAME: linear_qga4w_tiled_buffer_buffer_texture2d_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
- NAME: q_4w_linear_tiled_buffer_buffer_buffer_float
- NAME: linear_qga4w_tiled_buffer_buffer_buffer_float
OUT_STORAGE: buffer
IN_STORAGE: buffer
WEIGHT_STORAGE: buffer
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

namespace vkcompute {

void check_q_8w_linear_args(
void check_linear_qcsnw_args(
const ComputeGraph& graph,
const ValueRef mat1,
const ValueRef qmat2_data,
Expand All @@ -37,7 +37,7 @@ void check_q_8w_linear_args(
utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes));
}

void resize_q_8w_linear_node(
void resize_linear_qcs8w_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
Expand All @@ -64,7 +64,7 @@ void resize_q_8w_linear_node(
out->virtual_resize(new_out_sizes);
}

void add_q_8w_linear_node(
void add_linear_qcs8w_node(
ComputeGraph& graph,
const ValueRef mat1,
const ValueRef q_mat2_data,
Expand All @@ -91,7 +91,7 @@ void add_q_8w_linear_node(
ValueRef scales = prepack_standard(
graph, scales_data, graph.storage_type_of(out), utils::kWidthPacked);

std::string kernel_name = "q_8w_linear";
std::string kernel_name = "linear_qcs8w";
kernel_name.reserve(kShaderNameReserve);
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
Expand Down Expand Up @@ -131,7 +131,7 @@ void add_q_8w_linear_node(
// Specialization Constants
{},
// Resizing Logic
resize_q_8w_linear_node,
resize_linear_qcs8w_node,
{},
pcs));
if (!graph.is_buffer_storage(out) &&
Expand All @@ -140,7 +140,7 @@ void add_q_8w_linear_node(
}
}

void add_q_8w_linear_tiled_node(
void add_linear_qcs8w_tiled_node(
ComputeGraph& graph,
const bool use_coop_algorithm,
const ValueRef mat1,
Expand Down Expand Up @@ -170,7 +170,7 @@ void add_q_8w_linear_tiled_node(
prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);

std::string kernel_name =
use_coop_algorithm ? "q_8w_linear_coop" : "q_8w_linear_tiled";
use_coop_algorithm ? "linear_qcs8w_coop" : "linear_qcs8w_tiled";
kernel_name.reserve(kShaderNameReserve);
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
Expand Down Expand Up @@ -218,7 +218,7 @@ void add_q_8w_linear_tiled_node(
// Specialization Constants
{},
// Resizing Logic
resize_q_8w_linear_node,
resize_linear_qcs8w_node,
{},
// Push Constants
{{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}));
Expand Down Expand Up @@ -280,13 +280,13 @@ bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) {
void weight_int8pack_mm(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
check_linear_qcsnw_args(graph, args[0], args[1], args[2], args[3]);
if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
bool use_coop_algorithm = can_use_coop_impl(graph, args[0]);
return add_q_8w_linear_tiled_node(
return add_linear_qcs8w_tiled_node(
graph, use_coop_algorithm, args[0], args[1], args[2], args[3]);
}
return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
return add_linear_qcs8w_node(graph, args[0], args[1], args[2], args[3]);
}

REGISTER_OPERATORS {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

namespace vkcompute {

void check_q_4w_linear_args(
void check_linear_qga4w_args(
ComputeGraph& graph,
const ValueRef mat1,
const ValueRef mat2_data,
Expand Down Expand Up @@ -43,7 +43,7 @@ void check_q_4w_linear_args(
VK_CHECK_COND(graph.has_standard_axis_map(out));
}

void resize_q_4w_linear_node(
void resize_linear_qga4w_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& extra_args) {
Expand Down Expand Up @@ -118,14 +118,14 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
return qmat2;
}

void add_q_4w_linear_node(
void add_linear_qga4w_node(
ComputeGraph& graph,
const ValueRef mat1,
const ValueRef mat2_data,
const ValueRef group_size,
const ValueRef scales_and_zeros_data,
const ValueRef out) {
check_q_4w_linear_args(
check_linear_qga4w_args(
graph, mat1, mat2_data, group_size, scales_and_zeros_data, out);

const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
Expand All @@ -143,7 +143,7 @@ void add_q_4w_linear_node(
ValueRef scales_and_zeros = prepack_standard_hw_transposed(
graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked);

std::string kernel_name = "q_4w_linear";
std::string kernel_name = "linear_qga4w";
if (use_coop_algorithm) {
kernel_name += "_coop";
} else {
Expand Down Expand Up @@ -176,7 +176,7 @@ void add_q_4w_linear_node(
// Specialization Constants
{SV(group_size_val)},
// Resizing Logic
resize_q_4w_linear_node,
resize_linear_qga4w_node,
{},
// Push Constants
{graph.sizes_pc_of(out),
Expand All @@ -187,7 +187,7 @@ void add_q_4w_linear_node(
void linear_weight_int4(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
return add_q_4w_linear_node(
return add_linear_qga4w_node(
graph,
args[0], // mat1
args[1], // mat2
Expand Down
30 changes: 15 additions & 15 deletions backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// Reference Implementations
//

at::Tensor linear_weight_int4_reference_impl(
at::Tensor linear_qga4w_reference_impl(
const at::Tensor& x,
const at::Tensor& weights_4x2,
const int64_t groupsize,
Expand Down Expand Up @@ -101,7 +101,7 @@ at::Tensor dequantize_and_linear(
// Test functions
//

void test_reference_linear_int4(
void test_reference_linear_qga4w(
const int B,
const int M,
const int K,
Expand All @@ -119,7 +119,7 @@ void test_reference_linear_int4(
at::Tensor scales_and_zeros =
at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));

at::Tensor out = linear_weight_int4_reference_impl(
at::Tensor out = linear_qga4w_reference_impl(
x,
at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
group_size,
Expand Down Expand Up @@ -152,7 +152,7 @@ vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
}
}

void test_vulkan_linear_int4_impl(
void test_vulkan_linear_qga4w_impl(
const int B,
const int M,
const int K,
Expand All @@ -174,7 +174,7 @@ void test_vulkan_linear_int4_impl(
at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));

at::Tensor weights_int = unpack_weights_4x2(weights_4x2);
at::Tensor out_ref = linear_weight_int4_reference_impl(
at::Tensor out_ref = linear_qga4w_reference_impl(
x,
at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
group_size,
Expand Down Expand Up @@ -237,14 +237,14 @@ void test_vulkan_linear_int4_impl(
ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4));
}

void test_vulkan_linear_int4(
void test_vulkan_linear_qga4w(
const int B,
const int M,
const int K,
const int N,
const int group_size = 32,
const int inner_k_tiles = 8) {
test_vulkan_linear_int4_impl(
test_vulkan_linear_qga4w_impl(
B,
M,
K,
Expand All @@ -254,7 +254,7 @@ void test_vulkan_linear_int4(
vkcompute::utils::kBuffer,
vkcompute::utils::kBuffer);

test_vulkan_linear_int4_impl(
test_vulkan_linear_qga4w_impl(
B,
M,
K,
Expand All @@ -265,30 +265,30 @@ void test_vulkan_linear_int4(
vkcompute::utils::kTexture3D);
}

TEST(VulkanInt4LinearTest, test_reference_impl) {
test_reference_linear_int4(
TEST(VulkanLinearQGA4WTest, test_reference_impl) {
test_reference_linear_qga4w(
/*B = */ 1,
/*M = */ 4,
/*K = */ 128,
/*N = */ 32);
}

TEST(VulkanInt4LinearTest, test_vulkan_impl_small_m) {
test_vulkan_linear_int4(
TEST(VulkanLinearQGA4WTest, test_vulkan_impl_small_m) {
test_vulkan_linear_qga4w(
/*B = */ 1,
/*M = */ 4,
/*K = */ 128,
/*N = */ 32);

test_vulkan_linear_int4(
test_vulkan_linear_qga4w(
/*B = */ 1,
/*M = */ 1,
/*K = */ 256,
/*N = */ 256);
}

TEST(VulkanInt4LinearTest, test_vulkan_impl_gemm) {
test_vulkan_linear_int4(
TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) {
test_vulkan_linear_qga4w(
/*B = */ 1,
/*M = */ 256,
/*K = */ 256,
Expand Down
Loading