@@ -160,100 +160,111 @@ void add_q_8w_linear_node(
 }
 }
 
-void add_q_8w_linear_optimized_node(
+void add_q_8w_linear_tiled_node(
     ComputeGraph& graph,
     const ValueRef mat1,
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
-  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
-  ValueRef mat1_W_packed = mat1;
-  ValueRef out_W_packed = out;
-  if (!graph.is_buffer_storage(out) &&
-      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
-    // Ensure mat1 is width packed
-    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
-    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
-    // Ensure out is packed correctly
-    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
-  }
-
   utils::StorageType stype = graph.storage_type_of(out);
-  ValueRef q_mat2 =
-      prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked);
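+  // The weight is prepacked with its height and width dims transposed so that
+  // a single texel fetch yields 4 consecutive output channels for a given K
+  // index, matching the 4-wide output tiles computed by the shader.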
+  ValueRef q_mat2 = prepack_standard_hw_transposed(
+      graph, q_mat2_data, stype, utils::kWidthPacked);
   ValueRef scales =
       prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
 
-  std::string kernel_name = "q_8w_linear_optimized";
+  std::string kernel_name = "q_8w_linear_tiled";
   kernel_name.reserve(kShaderNameReserve);
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
-  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
-  const int mat1_dims = mat1_sizes.size();
-  if (mat1_dims == 3) {
-    kernel_name = "batch_" + kernel_name;
-  }
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    kernel_name += "_tile_row_2";
+  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
+  const int64_t M = utils::val_at(-2, mat1_sizes);
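+  // Each shader invocation computes a 4 (wide) x out_tile_nrows (high) tile
+  // of the output; use the 4x6 variant when the number of output rows divides
+  // evenly by 6, and fall back to the 4x4 variant otherwise.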
+  int out_tile_nrows = 4;
+  if (M % 6 == 0) {
+    kernel_name += "_o4x6";
+    out_tile_nrows = 6;
   } else {
-    kernel_name += "_tile_row_4";
+    kernel_name += "_o4x4";
+    out_tile_nrows = 4;
   }
 
-  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  vkapi::ParamsBindList ubos({});
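+  // One invocation per output tile: shrink the y extent by the tile height
+  // (the x extent is already in texels, each covering 4 output elements).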
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
 
-  utils::uvec3 global_size;
-  utils::uvec3 local_size;
-  if (graph.is_buffer_storage(out)) {
-    ubos.append(
-        {graph.sizes_ubo(out_W_packed),
-         graph.strides_ubo(out_W_packed),
-         graph.numel_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed),
-         graph.strides_ubo(mat1_W_packed),
-         graph.strides_ubo(q_mat2),
-         graph.strides_ubo(scales)});
-    global_size = graph.create_global_wg_size(out_W_packed);
-    local_size = graph.create_local_wg_size(out_W_packed);
-  } else {
-    global_size = graph.logical_limits_of(out_W_packed);
-    ubos.append(
-        {graph.logical_limits_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed)});
-    if (mat1_sizes.at(mat1_dims - 2) < 8) {
-      global_size = global_size = utils::divup_vec(global_size, {1, 2, 1});
-    } else {
-      global_size = utils::divup_vec(global_size, {1, 4, 1});
-    }
-    local_size = {16, 3, 1};
-  }
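+  // Launch with a flat 64-invocation local workgroup; each invocation works
+  // on an independent output tile.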
+  utils::uvec3 local_wg_size{64, 1, 1};
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      local_size,
+      global_wg_size,
+      local_wg_size,
       // Inputs and Outputs
-      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
-       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
+      {{out, vkapi::kWrite}, {{mat1, q_mat2, scales}, vkapi::kRead}},
       // Shader params buffers
-      ubos,
+      {},
       // Specialization Constants
-      {}, // spec_vars,
+      {},
       // Resizing Logic
-      resize_q_8w_linear_node));
+      resize_q_8w_linear_node,
+      {},
+      // Push Constants
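+      // (the out and mat1 sizes are small enough to pass as push constants,
+      // replacing the UBO bindings used by the non-tiled impl)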
+      {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}));
+}
 
-  if (!graph.is_buffer_storage(out)) {
-    viewFn(graph, {out_W_packed, graph.add_none(), out});
+bool can_use_tiled_impl(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef q_mat2_data,
+    const ValueRef scales_data,
+    const ValueRef out) {
+  (void)q_mat2_data;
+  (void)scales_data;
+
+  // Check if mat1 is not a 3D tensor or that batches = 1
+  // TODO(ssjia): Add support for batches in the tiled impl
+  if (graph.dim_of(mat1) == 3 && graph.size_at<int>(-3, mat1) != 1) {
+    return false;
+  }
+  // Check that K is a multiple of 4
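+  // (inputs are loaded one 4-element texel at a time along K)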
+  if (graph.size_at<int>(-1, mat1) % 4 != 0) {
+    return false;
   }
+  // Check that M is a multiple of 4 or 6
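+  // (M must be evenly divisible by one of the available tile heights)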
+  if (graph.size_at<int>(-2, mat1) % 4 != 0 &&
+      graph.size_at<int>(-2, mat1) % 6 != 0) {
+    return false;
+  }
+  // Check that the storage type is texture
+  // TODO(ssjia): Add support for buffer storage in the tiled impl
+  if (graph.storage_type_of(out) != utils::kTexture3D) {
+    return false;
+  }
+  // Check that the packed dim is the width dim
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the input
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(mat1)) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the output
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(out)) {
+    return false;
+  }
+
+  return true;
 }
 
 void weight_int8pack_mm(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
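+  // Dispatch the tiled impl whenever the shape and layout restrictions are
+  // met; otherwise fall back to the general shader.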
+  if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
+    return add_q_8w_linear_tiled_node(
+        graph, args[0], args[1], args[2], args[3]);
+  }
   return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
 }