
Commit 8f3b16f

Committed Mar 31, 2025
[ET-VK] Efficient tiled int8 matmul
## Context

Introduce an optimized tiled implementation for computing the weight int8-quantized linear operation. This implementation takes advantage of the following principles to squeeze out performance:

* Compute an output tile with each thread, rather than a single output element. This allows for better re-use of loaded input tensor data.
* Compute the output tile by iteratively loading tiles of the input matrices, caching them in registers, and then performing the `fma` accumulations to obtain a partial output. By splitting the data loading and computation into distinct steps, the GPU can hide latency more effectively, i.e. switch to a warp that needs to perform compute while the current warp is waiting on a data load.
* Use a work group size of `{N, 1, 1}`. This way, all the threads in a work group load the same row of the input matrix and consecutive columns of the weight matrix: the input row stays hot in the cache, and accesses to the weight matrix can be coalesced thanks to the previous diff un-transposing the weight matrix.

Differential Revision: [D72066587](https://our.internmc.facebook.com/intern/diff/D72066587/)

[ghstack-poisoned]
1 parent 708ecb7 commit 8f3b16f
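As a rough illustration of the tiling scheme described in the commit message, the sketch below shows the same accumulation pattern on the CPU (hypothetical reference code, not part of this commit): each work item produces a `TILE_ROWS x 4` output tile, and each step of the K loop stages a `TILE_ROWS x 4` input tile and a `4 x 4` weight tile in registers before accumulating.

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference of the tiling scheme (not part of this commit).
// out[M x N] = in[M x K] * weight[K x N], with per-output-channel scales.
constexpr int TILE_ROWS = 4;

void compute_output_tile(
    const std::vector<float>& in,           // M x K, row-major
    const std::vector<std::int8_t>& weight, // K x N, row-major (un-transposed)
    const std::vector<float>& scales,       // N
    std::vector<float>& out,                // M x N, row-major
    int K,
    int N,
    int out_row,  // first row of this tile
    int out_col)  // first column of this tile
{
  // Accumulator tile, kept "in registers" for the whole K loop.
  std::array<std::array<float, 4>, TILE_ROWS> c{};

  for (int pos = 0; pos < K; pos += 4) {
    // Preload a 4 x 4 weight tile.
    float b[4][4];
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        b[i][j] = static_cast<float>(weight[(pos + i) * N + out_col + j]);
      }
    }
    // Preload a TILE_ROWS x 4 input tile.
    float a[TILE_ROWS][4];
    for (int i = 0; i < TILE_ROWS; ++i) {
      for (int j = 0; j < 4; ++j) {
        a[i][j] = in[(out_row + i) * K + pos + j];
      }
    }
    // Accumulate the partial output for this K step.
    for (int i = 0; i < TILE_ROWS; ++i) {
      for (int j = 0; j < 4; ++j) {
        c[i][j] += a[i][0] * b[0][j] + a[i][1] * b[1][j] +
                   a[i][2] * b[2][j] + a[i][3] * b[3][j];
      }
    }
  }

  // Apply per-column scales and store the tile.
  for (int i = 0; i < TILE_ROWS; ++i) {
    for (int j = 0; j < 4; ++j) {
      out[(out_row + i) * N + out_col + j] = c[i][j] * scales[out_col + j];
    }
  }
}
```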

File tree

5 files changed (+184, -310 lines)
 

‎backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl

Lines changed: 0 additions & 212 deletions
This file was deleted.

‎backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml

Lines changed: 0 additions & 35 deletions
This file was deleted.
backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl (new file)

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}

${define_required_extensions(DTYPE)}

$if STORAGE == "buffer":
  ${define_required_extensions("int8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}

layout(push_constant) uniform restrict Block {
  ivec4 out_sizes;
  ivec4 in_sizes;
  ivec4 weight_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
  const uint out_col = gl_GlobalInvocationID.x << 2;

  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
    return;
  }

  VEC4_T a[TILE_ROWS];
  VEC4_T b[4];
  VEC4_T c[TILE_ROWS];

  $if STORAGE == "buffer":
    const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
  $else:
    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));

  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
    c[i] = VEC4_T(0.0);
  }

  for (int pos = 0; pos < in_sizes.x; pos += 4) {
    // Preload weight tensor
    [[unroll]] for (int i = 0; i < 4; i++) {
      $if STORAGE == "buffer":
        b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
      $else:
        b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
    }

    // Preload input tensor
    [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
      $if STORAGE == "buffer":
        a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
      $else:
        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
    }

    // Compute partial output
    [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
      c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
    }
  }

  // Store output tensor
  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
    $if STORAGE == "buffer":
      t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
    $else:
      imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
  }
}
```
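For a concrete view of the shader's addressing, the sketch below (hypothetical, not part of the commit) prints the texels a single invocation touches, following the index math in the shader above; it assumes width-packed texture storage, where the texel x coordinate indexes groups of four columns.

```cpp
#include <cstdio>

// Hypothetical illustration of the shader's indexing for TILE_ROWS = 4.
// Assumes width-packed textures: texel x covers four consecutive columns.
int main() {
  const int TILE_ROWS = 4;
  const int gx = 3, gy = 2;            // example gl_GlobalInvocationID.xy
  const int out_row = gy * TILE_ROWS;  // first output row of the tile -> 8
  const int out_col = gx << 2;         // first output column of the tile -> 12

  // One K step of the inner loop (pos = 0): the 4 x 4 weight tile and the
  // TILE_ROWS x 4 input tile that get cached in registers.
  const int pos = 0;
  for (int i = 0; i < 4; ++i) {
    std::printf("weight texel: (x=%d, y=%d)\n", out_col >> 2, pos + i);
  }
  for (int i = 0; i < TILE_ROWS; ++i) {
    std::printf("input texel:  (x=%d, y=%d)\n", pos >> 2, out_row + i);
    std::printf("output texel: (x=%d, y=%d)\n", out_col >> 2, out_row + i);
  }
  return 0;
}
```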
backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml (new file)

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@

```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

q_8w_linear_tiled:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
    TILE_ROWS: 4
  shader_variants:
    - NAME: q_8w_linear_tiled_o4x4_texture3d_float
      STORAGE: texture3d
      TILE_ROWS: 4
    - NAME: q_8w_linear_tiled_o4x6_texture3d_float
      STORAGE: texture3d
      TILE_ROWS: 6
```
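The two variant names declared here match what the dispatch logic in QuantizedLinear.cpp (below) assembles at runtime. A rough sketch of that composition, assuming the storage and dtype suffix helpers append `_texture3d` and `_float` (inferred from the variant names rather than shown in this diff):

```cpp
#include <cstdint>
#include <string>

// Hedged sketch of how the dispatch code picks a shader variant name.
// The exact suffix helpers live in the ET-VK runtime; the suffix strings
// here are inferred from the YAML variant names above.
std::string pick_variant(std::int64_t M) {
  std::string kernel_name = "q_8w_linear_tiled";
  kernel_name += (M % 6 == 0) ? "_o4x6" : "_o4x4"; // output tile height 6 or 4
  kernel_name += "_texture3d";                     // storage type suffix
  kernel_name += "_float";                         // dtype suffix
  return kernel_name; // e.g. "q_8w_linear_tiled_o4x6_texture3d_float"
}
```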

‎backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 74 additions & 63 deletions

```diff
@@ -160,100 +160,111 @@ void add_q_8w_linear_node(
   }
 }
 
-void add_q_8w_linear_optimized_node(
+void add_q_8w_linear_tiled_node(
     ComputeGraph& graph,
     const ValueRef mat1,
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
-  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
-  ValueRef mat1_W_packed = mat1;
-  ValueRef out_W_packed = out;
-  if (!graph.is_buffer_storage(out) &&
-      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
-    // Ensure mat1 is width packed
-    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
-    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
-    // Ensure out is packed correctly
-    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
-  }
-
   utils::StorageType stype = graph.storage_type_of(out);
-  ValueRef q_mat2 =
-      prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked);
+  ValueRef q_mat2 = prepack_standard_hw_transposed(
+      graph, q_mat2_data, stype, utils::kWidthPacked);
   ValueRef scales =
       prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
 
-  std::string kernel_name = "q_8w_linear_optimized";
+  std::string kernel_name = "q_8w_linear_tiled";
   kernel_name.reserve(kShaderNameReserve);
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
-  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
-  const int mat1_dims = mat1_sizes.size();
-  if (mat1_dims == 3) {
-    kernel_name = "batch_" + kernel_name;
-  }
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    kernel_name += "_tile_row_2";
+  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
+  const int64_t M = utils::val_at(-2, mat1_sizes);
+  int out_tile_nrows = 4;
+  if (M % 6 == 0) {
+    kernel_name += "_o4x6";
+    out_tile_nrows = 6;
   } else {
-    kernel_name += "_tile_row_4";
+    kernel_name += "_o4x4";
+    out_tile_nrows = 4;
   }
 
-  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  vkapi::ParamsBindList ubos({});
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
 
-  utils::uvec3 global_size;
-  utils::uvec3 local_size;
-  if (graph.is_buffer_storage(out)) {
-    ubos.append(
-        {graph.sizes_ubo(out_W_packed),
-         graph.strides_ubo(out_W_packed),
-         graph.numel_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed),
-         graph.strides_ubo(mat1_W_packed),
-         graph.strides_ubo(q_mat2),
-         graph.strides_ubo(scales)});
-    global_size = graph.create_global_wg_size(out_W_packed);
-    local_size = graph.create_local_wg_size(out_W_packed);
-  } else {
-    global_size = graph.logical_limits_of(out_W_packed);
-    ubos.append(
-        {graph.logical_limits_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed)});
-    if (mat1_sizes.at(mat1_dims - 2) < 8) {
-      global_size = global_size = utils::divup_vec(global_size, {1, 2, 1});
-    } else {
-      global_size = utils::divup_vec(global_size, {1, 4, 1});
-    }
-    local_size = {16, 3, 1};
-  }
+  utils::uvec3 local_wg_size{64, 1, 1};
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      local_size,
+      global_wg_size,
+      local_wg_size,
       // Inputs and Outputs
-      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
-       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
+      {{out, vkapi::kWrite}, {{mat1, q_mat2, scales}, vkapi::kRead}},
       // Shader params buffers
-      ubos,
+      {},
       // Specialization Constants
-      {}, // spec_vars,
+      {},
       // Resizing Logic
-      resize_q_8w_linear_node));
+      resize_q_8w_linear_node,
+      {},
+      // Push Constants
+      {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}));
+}
 
-  if (!graph.is_buffer_storage(out)) {
-    viewFn(graph, {out_W_packed, graph.add_none(), out});
+bool can_use_tiled_impl(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef q_mat2_data,
+    const ValueRef scales_data,
+    const ValueRef out) {
+  (void)q_mat2_data;
+  (void)scales_data;
+
+  // Check if mat1 is not a 3D tensor or that batches = 1
+  // TODO(ssjia): Add support for batches in the tiled impl
+  if (graph.dim_of(mat1) == 3 && graph.size_at<int>(-1, mat1) != 1) {
+    return false;
+  }
+  // Check that K is a multiple of 4
+  if (graph.size_at<int>(-1, mat1) % 4 != 0) {
+    return false;
   }
+  // Check that M is a multiple of 4 or 6
+  if (graph.size_at<int>(-2, mat1) % 4 != 0 &&
+      graph.size_at<int>(-2, mat1) % 6 != 0) {
+    return false;
+  }
+  // Check that the storage type is texture
+  // TODO(ssjia): Add support for buffer storage in the tiled impl
+  if (graph.storage_type_of(out) != utils::kTexture3D) {
+    return false;
+  }
+  // Check that the packed dim is the width dim
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the input
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(mat1)) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the output
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(out)) {
+    return false;
+  }
+
+  return true;
 }
 
 void weight_int8pack_mm(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
+  if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
+    return add_q_8w_linear_tiled_node(
+        graph, args[0], args[1], args[2], args[3]);
+  }
   return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
 }
```
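As a worked example of the dispatch sizing above, consider a hypothetical output of shape [M, N] = [128, 2048], and assume the logical limits of a width-packed texture output are {N / 4, M, 1} texels (an assumption about the texture layout, not stated in this diff):

```cpp
#include <cstdio>

// Hypothetical worked example of the dispatch sizing for [M, N] = [128, 2048].
// Assumes logical limits of a width-packed texture output are {N / 4, M, 1}.
int main() {
  const int M = 128, N = 2048;
  // 128 % 6 != 0, so the _o4x4 variant is chosen and out_tile_nrows = 4.
  const int out_tile_nrows = (M % 6 == 0) ? 6 : 4;

  unsigned global_wg_size[3] = {N / 4u, (unsigned)M, 1u}; // {512, 128, 1}
  global_wg_size[1] /= out_tile_nrows;                    // {512, 32, 1}
  const unsigned local_wg_size[3] = {64, 1, 1};

  // Each invocation writes a 4-column x out_tile_nrows-row output tile.
  std::printf("global = {%u, %u, %u}, local = {%u, %u, %u}\n",
      global_wg_size[0], global_wg_size[1], global_wg_size[2],
      local_wg_size[0], local_wg_size[1], local_wg_size[2]);
  return 0;
}
```

With local_wg_size = {64, 1, 1}, the 64 invocations of a work group cover 64 consecutive column groups of the same output rows, so the shared input rows stay hot in the cache while the weight reads coalesce, as described in the commit message.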
