diff --git a/backends/transforms/fuse_dequant_linear.py b/backends/transforms/fuse_dequant_linear.py deleted file mode 100644 index 235715ac74f..00000000000 --- a/backends/transforms/fuse_dequant_linear.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -import torch - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - - -class FuseDequantLinearPass(ExportPass): - """ - Fuses weight dequantize_per_channel nodes with linear nodes into - weight_int8pack_mm nodes, for 8-bit weight-only quantization. - - Replaces dq(weight) -> linear(activation, dq) with weight_int8pack_mm - Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add - """ - - def fuse_dequant_with_linear( - self, - graph_module: torch.fx.GraphModule, - dequant_node: torch.fx.Node, - linear_node: torch.fx.Node, - ) -> None: - activations = linear_node.args[0] - bias = None - if len(linear_node.args) > 2: - bias = linear_node.args[2] - quant_weight = dequant_node.args[0] - scale = dequant_node.args[1] - - with graph_module.graph.inserting_before(linear_node): - weight_int8pack_mm_node = graph_module.graph.create_node( - "call_function", - exir_ops.edge.aten._weight_int8pack_mm.default, - (activations, quant_weight, scale), - ) - if bias: - add_node = graph_module.graph.create_node( - "call_function", - exir_ops.edge.aten.add.Tensor, - (weight_int8pack_mm_node, bias), - ) - linear_node.replace_all_uses_with(add_node) - else: - linear_node.replace_all_uses_with(weight_int8pack_mm_node) - graph_module.graph.erase_node(linear_node) - graph_module.graph.erase_node(dequant_node) - - def is_node_target( - self, node: torch.fx.Node, target: torch._ops.OperatorBase - ) -> bool: - return node.op == "call_function" and node.target == target - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for node in graph_module.graph.nodes: - if self.is_node_target(node, exir_ops.edge.aten.linear.default): - weight_node = node.args[1] - if self.is_node_target( - weight_node, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, - ): - # only fuse if weight tensor is int8 packed - quant_weight = weight_node.args[0] - if quant_weight.meta["val"].dtype != torch.int8: - continue - self.fuse_dequant_with_linear(graph_module, weight_node, node) - - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index 66ff9111f52..71980195962 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -77,21 +77,6 @@ def define_common_targets(): ], ) - runtime.python_library( - name = "fuse_dequant_linear", - srcs = ["fuse_dequant_linear.py"], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - ":utils", - "//caffe2:torch", - "//executorch/exir:pass_base", - "//executorch/exir:sym_util", - "//executorch/exir/dialects:lib", - ], - ) - runtime.python_library( name = "view_copy_to_squeeze_unsqueeze", srcs = ["view_copy_to_squeeze_unsqueeze.py"], diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 5478ad0eab6..cfe20892994 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -3,6 
+3,23 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") oncall("executorch") +runtime.python_library( + name = "fuse_quantized_ops", + srcs = ["fuse_quantized_ops.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/transforms:utils", + "//executorch/backends/vulkan:custom_ops_lib", + "//executorch/backends/vulkan:utils_lib", + "//executorch/exir:pass_base", + "//executorch/exir:sym_util", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "insert_prepack_nodes", srcs = ["insert_prepack_nodes.py"], @@ -13,6 +30,7 @@ runtime.python_library( "//caffe2:torch", "//executorch/exir:pass_base", "//executorch/backends/vulkan:utils_lib", + "//executorch/backends/vulkan:op_registry", ], ) @@ -110,6 +128,7 @@ runtime.python_library( "//executorch/examples/...", ], deps = [ + ":fuse_quantized_ops", ":insert_prepack_nodes", ":int4_weight_only_quantizer", ":remove_asserts", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 220afa6a35c..7ff93a6ee38 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -6,6 +6,9 @@ # pyre-strict +from executorch.backends.vulkan._passes.fuse_quantized_ops import ( + FuseQuantizedOpsTransform, +) from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan._passes.int4_weight_only_quantizer import ( VkInt4WeightOnlyQuantizer, @@ -26,6 +29,7 @@ from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ + "FuseQuantizedOpsTransform", "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", "remove_asserts", diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py new file mode 100644 index 00000000000..d510e1d4342 --- /dev/null +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -0,0 +1,229 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Optional, Tuple + +import executorch.backends.vulkan.utils as utils +import torch + +import torch.nn.functional as F + +from executorch.backends.transforms.utils import get_param_tensor, is_param_node +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +################# +## linear_qcnw ## +################# + + +def matches_linear_qcnw_pattern( # noqa: C901 + program: ExportedProgram, node: torch.fx.Node +) -> Optional[Tuple[torch.qscheme, int]]: + """ + Checks if the nodes surrounding a linear node matches the pattern for weight only + quantized linear, where the weight is quantized channelswise to n bits. + + If the graph pattern matches, then return a tuple of (quantization_method, nbits) + describing the type of quantization used for the weights. Otherwise, return None. 
+ """ + if not utils.is_linear_node(node): + return None + + input_node = node.args[0] + weight_node = node.args[1] + + # Type checking + if not isinstance(weight_node, torch.fx.Node): + return None + if not isinstance(input_node, torch.fx.Node): + return None + + # The input arg should not be a dequant node; if it is, then it is indicative that + # dynamically quantized linear should be used instead + if utils.is_dequant_node(input_node): + return None + + # The weight arg should be a dequant node dequantizing the quantized weight + # Furthermore, the op expects per channel quantization of the weight + if not utils.is_dequant_per_channel_node(weight_node): + return None + + orig_weight = weight_node.args[0] + zeros = weight_node.args[2] + + # Type checking + if not isinstance(orig_weight, torch.fx.Node): + return None + if not is_param_node(program, orig_weight): + return None + if not isinstance(zeros, torch.fx.Node): + return None + if not is_param_node(program, zeros): + return None + + zeros_tensor = get_param_tensor(program, zeros) + if not isinstance(zeros_tensor, torch.Tensor): + return None + + quant_method = torch.per_channel_affine + # Check for symmetric quantization, where the zeros used for dequantization will + # actually be all zeros. + if torch.all(zeros_tensor == 0): + quant_method = torch.per_channel_symmetric + + orig_weight_tensor = get_param_tensor(program, orig_weight) + if not isinstance(orig_weight_tensor, torch.Tensor): + return None + # Sanity check the dtype of the quantized weight + if orig_weight_tensor.dtype != torch.int8: + return None + + quant_min = orig_weight_tensor.min().item() + quant_max = orig_weight_tensor.max().item() + # Determine the number of bits the weight has been quantized to + if quant_min >= -8 and quant_max <= 7: + return quant_method, 4 + elif quant_min >= -128 and quant_max <= 127: + return quant_method, 8 + + return None + + +def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor: + """ + Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed + weight tensor by packing 2 4-bit values in one unsigned 8-bit value. + + An input weight tensor of shape (M, K) will produce a packed weight tensor of shape + (M, K / 2). + """ + + # Assert we got a properly quantized tensor. + min, max = inp.min().item(), inp.max().item() + assert ( + max <= 7 and min >= -8 + ), f"convert_to_qc4w: [min,max] out of [-8, 7] range, got [{min}, {max}]" + + # Assuming we have a 2d tensor + if inp.ndim != 2: + inp = inp.squeeze() + assert ( + inp.ndim == 2 + ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}" + + # pad ic + if inp.shape[-1] % 2 != 0: + inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0) + + # Shape after padding + oc, ic = inp.shape + assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even" + + # Adjust inp tensor for zp + inp = inp.to(dtype=torch.uint8) + 8 + + # Prepare the Result tensor + inp = inp.contiguous().view(-1) + return (inp[::2] << 4 | inp[1::2]).view(oc, int(ic / 2)) + + +def fuse_into_linear_qcnw_node( + program: ExportedProgram, + graph_module: torch.fx.GraphModule, + linear_node: torch.fx.Node, + quant_method: torch.qscheme, + nbits: int, +) -> None: + """ + The weight_int8pack_mm operator represents a weight only quantized linear operator, + where the weight tensor has been quantized channelswise to nbits bits. 
+ + After the PT2E quantization flow, the expected graph pattern is + + dq_weight = dequantize(weight, scales) + out = linear(activation, dq_weight, bias?) + + The goal of this function is to condense that sequence into + + out = quantized_linear(activation, dq_weight, scales) + out = out + bias + """ + activation = linear_node.args[0] + dq_weight_node = linear_node.args[1] + assert isinstance(activation, torch.fx.Node) + assert isinstance(dq_weight_node, torch.fx.Node) + + bias = None + if len(linear_node.args) > 2: + bias = linear_node.args[2] + assert isinstance(bias, torch.fx.Node) + + orig_weight = dq_weight_node.args[0] + scale = dq_weight_node.args[1] + + # For 4 bit quantization, pack the weight tensor + if nbits == 4: + assert isinstance(orig_weight, torch.fx.Node) + orig_weight_tensor = get_param_tensor(program, orig_weight) + assert isinstance(orig_weight_tensor, torch.Tensor) + packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor) + utils.update_program_state_dict( + program, + orig_weight.name, + packed_weight_tensor, + ) + orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8) + + if nbits == 8 and quant_method == torch.per_channel_symmetric: + op_target = exir_ops.edge.aten._weight_int8pack_mm.default + elif nbits == 4 and quant_method == torch.per_channel_symmetric: + op_target = exir_ops.edge.et_vk.linear_qcs4w.default + else: + raise NotImplementedError( + "only 4 and 8 bits per channel symmetric quant supported for linear_qcnw" + ) + + with graph_module.graph.inserting_before(linear_node): + weight_int8pack_mm_node = graph_module.graph.create_node( + "call_function", + op_target, + (activation, orig_weight, scale), + ) + if bias: + add_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.aten.add.Tensor, + (weight_int8pack_mm_node, bias), + ) + linear_node.replace_all_uses_with(add_node) + else: + linear_node.replace_all_uses_with(weight_int8pack_mm_node) + graph_module.graph.erase_node(linear_node) + graph_module.graph.erase_node(dq_weight_node) + + +class FuseQuantizedOpsTransform(ExportPass): + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.program = exported_program + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + qcnw_details = matches_linear_qcnw_pattern(self.program, node) + if qcnw_details is not None: + qcnw_method, qcnw_nbits = qcnw_details + fuse_into_linear_qcnw_node( + self.program, graph_module, node, qcnw_method, qcnw_nbits + ) + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 0275239a86a..af6fcbfbb14 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -184,6 +184,53 @@ def linear_weight_int4_impl( lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd") linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name) +################# +## linear_qcs4w ## +################# + + +def linear_qcs4w( + x: torch.Tensor, + weights_4x2: torch.Tensor, + scales: torch.Tensor, +): + original_x_shape = x.shape + x = x.reshape(-1, original_x_shape[-1]) + + unpacked_weights_shape = weights_4x2.shape + out_features = unpacked_weights_shape[0] + in_features = unpacked_weights_shape[1] + + weights_unpacked = torch.empty( + (out_features, in_features * 2), dtype=torch.int8, 
device=weights_4x2.device + ) + + weights_unpacked[:, ::2] = weights_4x2 >> 4 + weights_unpacked[:, 1::2] = weights_4x2 & 0x0F + + n_bit = 8 + quant_min = -(2 ** (n_bit - 1)) + quant_max = 2 ** (n_bit - 1) - 1 + dq_weights = torch.ops.quantized_decomposed.dequantize_per_channel( + weights_unpacked, + scales, + None, + 0, + quant_min, + quant_max, + torch.int8, + ) + + out = torch.nn.functional.linear(x, dq_weights) + out_shape = original_x_shape[:-1] + (out_features,) + return out.reshape(out_shape) + + +name = "linear_qcs4w" +lib.define(f"{name}(Tensor self, Tensor weight, Tensor scales) -> Tensor") +lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd") +linear_qc4w_op = getattr(getattr(torch.ops, namespace), name) + ###################### ## apply_rotary_emb ## ###################### diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index aa3cca5f384..8502e254ec5 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -377,7 +377,12 @@ def register_mm_op(features: OpFeatures): return features -@update_features(exir_ops.edge.aten._weight_int8pack_mm.default) +@update_features( + [ + exir_ops.edge.aten._weight_int8pack_mm.default, + exir_ops.edge.et_vk.linear_qcs4w.default, + ] +) def register_int8_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( uses_axis_map=False, diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index 2ea3e321dc3..b2f1a658040 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -9,7 +9,7 @@ from __future__ import annotations import functools -from typing import Any, Callable, Dict, Optional +from typing import Callable, Optional import torch from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import ( @@ -18,53 +18,60 @@ propagate_annotation, QuantizationConfig, ) -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor +from torch.ao.quantization.observer import PerChannelMinMaxObserver from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer from torch.fx import Node __all__ = [ "VulkanQuantizer", - "get_weight_quantization_config", + "get_linear_weight_qcs_qspec", + "get_linear_weight_only_qcs_xnn_qconfig", ] -@functools.lru_cache -def get_weight_quantization_config( - is_per_channel: bool = True, - weight_qmin: int = -128, - weight_qmax: int = 127, -) -> QuantizationConfig: - - weight_qscheme = ( - torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric - ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( - PerChannelMinMaxObserver if is_per_channel else MinMaxObserver - ) - extra_args: Dict[str, Any] = {"eps": 2**-12} +def get_linear_weight_qcs_qspec(quant_bits: int) -> QuantizationSpec: + """ + Return a QuantizationSpec to perform per-channel symmetric (i.e. "qcs") quantization + of weight tensors of linear layers to the number of bits specified by quant_bits. 
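+
+    For example, quant_bits=4 yields an int8 QuantizationSpec with quant_min=-8,
+    quant_max=7, qscheme=torch.per_channel_symmetric, ch_axis=0, observed with
+    PerChannelMinMaxObserver.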
+ """ + weight_observer = PerChannelMinMaxObserver + assert quant_bits in { + 8, + 4, + }, f"Unsupported weight quantization bits: {quant_bits}" - weight_quantization_spec = QuantizationSpec( + quant_min = -(2 ** (quant_bits - 1)) + quant_max = 2 ** (quant_bits - 1) - 1 + qscheme = torch.per_channel_symmetric + + return QuantizationSpec( dtype=torch.int8, - quant_min=weight_qmin, - quant_max=weight_qmax, - qscheme=weight_qscheme, + quant_min=quant_min, + quant_max=quant_max, + qscheme=qscheme, ch_axis=0, is_dynamic=False, - observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( - **extra_args - ), + observer_or_fake_quant_ctr=weight_observer, ) - quantization_config = QuantizationConfig( + +@functools.lru_cache +def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfig: + """ + Return a XNNPACKQuantizer QuantizationConfig class instance that specifies + quantizing the weight tensors of linear layers using per-channel symmetric (qcs) + quantization to the number of bits specified by quant_bits. + """ + weight_qspec = get_linear_weight_qcs_qspec(quant_bits) + + return QuantizationConfig( input_activation=None, output_activation=None, - weight=weight_quantization_spec, + weight=weight_qspec, bias=None, is_qat=False, ) - return quantization_config _SUPPORTED_OPS = [ diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 2126104430f..2b41d2b7e1a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -41,22 +41,32 @@ /* * Fast division by 4 using bit shifting */ -#define div4(x) (x >> 2) +#define div4(x) ((x) >> 2) + +/* + * Fast multiplication by 4 using bit shifting + */ +#define mul4(x) ((x) << 2) /* * Divides input and rounds up to 4 */ -#define divup4(x) ((x + 3) >> 2) +#define divup4(x) (((x) + 3) >> 2) + +/* + * Divides input by denominator and rounds up + */ +#define divup(x, d) (((x) + (d) - 1) / (d)) /* * Aligns input to the next multiple of 4 */ -#define alignup4(x) ((x + 3) & -4) +#define alignup4(x) (((x) + 3) & -4) /* * Fast modulo by 4 using bit masking */ -#define mod4(x) (x & 3) +#define mod4(x) ((x) & 3) /* * Find the packed dimension of a tensor given its strides. 
The packed dimension diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl index 3ad9e759910..c766a3cd7d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl @@ -14,6 +14,7 @@ #define VEC4_T ${buffer_gvec_type(DTYPE, 4)} #define TILE_ROWS ${TILE_ROWS} +#define TILE_TXCOLS ${TILE_TXCOLS} #define NGROUPS 8 #define NWORKERS 8 @@ -29,7 +30,10 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} +$if QUANT_NBITS == 4: + ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)} layout(push_constant) uniform restrict Block { @@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS]; +shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS]; void main() { - const uint out_width_ntexels = divup4(out_sizes.x); - const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2; - const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS; + // txcol stands for "texel column". One txcol corresponds to 4 scalar columns. + $if TILE_TXCOLS > 1: + const uint global_wg_x = uint(divup(out_sizes.x, 4 * TILE_TXCOLS)); + const uint out_txcol = uint( + (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS); + $else: + const uint global_wg_x = uint(divup4(out_sizes.x)); + const uint out_txcol = uint(gl_GlobalInvocationID.x % global_wg_x); + + const uint out_row = uint( + (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS); + + $if QUANT_NBITS == 4: + const uint weight_txcol = uint(out_txcol / 2); const int gid = int(gl_LocalInvocationID.x); // group id const int wid = int(gl_LocalInvocationID.z); // worker id @@ -56,46 +71,78 @@ void main() { return; } - VEC4_T a[TILE_ROWS]; - VEC4_T b[4]; - VEC4_T local_c[TILE_ROWS]; + VEC4_T mat1[TILE_ROWS]; + VEC4_T qmat2[4][TILE_TXCOLS]; + VEC4_T local_sums[TILE_ROWS][TILE_TXCOLS]; - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - local_c[i] = VEC4_T(0.0); + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + local_sums[r][${c}] = VEC4_T(0.0); } - $if SCALES_STORAGE == "buffer": - const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]); - $else: - const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0)); - - for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) { - // Preload t_weight - [[unroll]] for (int i = 0; i < 4; i++) { - $if WEIGHT_STORAGE == "buffer": - b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2]; + VEC4_T scales[TILE_TXCOLS]; + $for c in range(TILE_TXCOLS): + $if SCALES_STORAGE == "buffer": + scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]); + $else: + scales[${c}] = VEC4_T( + texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0)); + + for (int pos = (4 * wid), txpos = wid; + pos < in_sizes.x; + pos += (4 * NWORKERS), txpos += NWORKERS) { + $if WEIGHT_STORAGE == "buffer": + uint 
qmat2_bufi; + uint weight_row_txstride = div4(weight_sizes.x); + + // Preload weight tensor + [[unroll]] for (int r = 0; r < 4; r++) { + $if QUANT_NBITS == 4: + $for c in range(0, TILE_TXCOLS, 2): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; + const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + $else: + const uvec4 packed_weight_tex = texelFetch( + t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); + + qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); + qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); $else: - b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0)); + $for c in range(TILE_TXCOLS): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; + qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}]; + $else: + qmat2[r][${c}] = VEC4_T( + texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0)); } - // Preload t_in - for (int i = 0; i < TILE_ROWS; i++) { + + $if IN_STORAGE == "buffer": + uint in_row_txstride = div4(in_sizes.x); + + // Preload input tensor + [[unroll]] for (int i = 0; i < TILE_ROWS; i++) { $if IN_STORAGE == "buffer": - a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2]; + mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos]; $else: - a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0)); + mat1[i] = VEC4_T( + texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0)); } // Accumulate partial output - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - local_c[i] += a[i].x * b[0] + - a[i].y * b[1] + - a[i].z * b[2] + - a[i].w * b[3]; + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + local_sums[r][${c}] += mat1[r].x * qmat2[0][${c}] + + mat1[r].y * qmat2[1][${c}] + + mat1[r].z * qmat2[2][${c}] + + mat1[r].w * qmat2[3][${c}]; } } - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - partial_c[gid][wid][i] = local_c[i]; + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + partial_sums[gid][wid][r][${c}] = local_sums[r][${c}]; } memoryBarrierShared(); @@ -105,21 +152,33 @@ void main() { return; } - VEC4_T c[TILE_ROWS]; + VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; + + for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + sums[r][${c}] = VEC4_T(0.0); - for (int row = 0; row < TILE_ROWS; ++row) { - c[row] = VEC4_T(0.0); [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) { - c[row] += partial_c[gid][worker][row]; + $for c in range(TILE_TXCOLS): + sums[r][${c}] += partial_sums[gid][worker][r][${c}]; } } - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - $if OUT_STORAGE == "buffer": - if (out_row + i < out_sizes.y) { - t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales; - } - $else: - imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales); + $if OUT_STORAGE == "buffer": + uint out_bufi; + uint out_row_txstride = div4(out_sizes.x); + + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + $if OUT_STORAGE == "buffer": + if (out_row + r < out_sizes.y) { + out_bufi = (out_row + r) * out_row_txstride + out_txcol; + t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}]; + } + $else: + imageStore( + t_out, + ivec3(out_txcol + ${c}, out_row + r, 0), + sums[r][${c}] * scales[${c}]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml index e0477a3a3d1..3dff6855142 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml @@ -12,6 +12,8 @@ linear_qcsnw_coop: WEIGHT_STORAGE: texture2d SCALES_STORAGE: texture2d TILE_ROWS: 4 + TILE_TXCOLS: 1 + QUANT_NBITS: 8 generate_variant_forall: TILE_ROWS: - VALUE: 1 @@ -26,3 +28,11 @@ linear_qcsnw_coop: OUT_STORAGE: buffer WEIGHT_STORAGE: buffer SCALES_STORAGE: buffer + - NAME: linear_qcs4w_coop_texture3d_texture3d_texture2d_texture2d_float + TILE_TXCOLS: 2 + QUANT_NBITS: 4 + - NAME: linear_qcs4w_coop_buffer_buffer_texture2d_texture2d_float + IN_STORAGE: buffer + OUT_STORAGE: buffer + TILE_TXCOLS: 2 + QUANT_NBITS: 4 diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index 3ef952ea34d..f6f05aab7ca 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -14,6 +14,7 @@ #define VEC4_T ${buffer_gvec_type(DTYPE, 4)} #define TILE_ROWS ${TILE_ROWS} +#define TILE_TXCOLS ${TILE_TXCOLS} ${define_required_extensions(DTYPE)} @@ -26,7 +27,10 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} +$if QUANT_NBITS == 4: + ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)} @@ -43,57 +47,110 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require void main() { - const uint16_t out_width_ntexels = uint16_t(divup4(out_sizes.x)); - const uint16_t out_col = uint16_t((gl_GlobalInvocationID.x % out_width_ntexels) << 2); - const uint16_t out_row = uint16_t((gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS); + // txcol stands for "texel column". One txcol corresponds to 4 scalar columns. 
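+  // For the 4-bit variants (QUANT_NBITS == 4, TILE_TXCOLS == 2), each uint8
+  // weight texel packs two 4-bit columns, so one weight texel column
+  // (weight_txcol = out_txcol / 2) expands into two output texel columns.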
+ $if TILE_TXCOLS > 1: + const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS)); + const uint16_t out_txcol = uint16_t( + (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS); + $else: + const uint16_t global_wg_x = uint16_t(divup4(out_sizes.x)); + const uint16_t out_txcol = uint16_t(gl_GlobalInvocationID.x % global_wg_x); + + const uint16_t out_row = uint16_t( + (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS); + + $if QUANT_NBITS == 4: + const uint16_t weight_txcol = uint16_t(out_txcol / 2); if (out_row >= uint16_t(out_sizes.y)) { return; } - VEC4_T a[TILE_ROWS]; - VEC4_T b[4]; - VEC4_T c[TILE_ROWS]; + VEC4_T mat1[TILE_ROWS]; + VEC4_T qmat2[4][TILE_TXCOLS]; + VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; - $if SCALES_STORAGE == "buffer": - const VEC4_T scales = VEC4_T(t_scales[int(out_col >> 2)]); - $else: - const VEC4_T scales = VEC4_T(texelFetch(t_scales, u16vec2(out_col >> 2, 0), 0)); + VEC4_T scales[TILE_TXCOLS]; + $for c in range(TILE_TXCOLS): + $if SCALES_STORAGE == "buffer": + scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]); + $else: + scales[${c}] = VEC4_T( + texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0)); - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - c[i] = VEC4_T(0.0); + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + sums[r][${c}] = VEC4_T(0.0); } - for (uint16_t pos = uint16_t(0); pos < uint16_t(in_sizes.x); pos += uint16_t(4)) { + for (uint16_t pos = uint16_t(0), txpos = uint16_t(0); + pos < uint16_t(in_sizes.x); + pos += uint16_t(4), txpos += uint16_t(1)) { + $if WEIGHT_STORAGE == "buffer": + uint qmat2_bufi; + uint weight_row_txstride = div4(weight_sizes.x); + // Preload weight tensor - [[unroll]] for (int i = 0; i < 4; i++) { - $if WEIGHT_STORAGE == "buffer": - b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2]; + [[unroll]] for (int r = 0; r < 4; r++) { + $if QUANT_NBITS == 4: + $for c in range(0, TILE_TXCOLS, 2): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; + const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}] + $else: + const uvec4 packed_weight_tex = texelFetch( + t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0); + + qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0); + qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0); $else: - b[i] = VEC4_T(texelFetch(t_weight, u16vec2(out_col >> 2, pos + i), 0)); + $for c in range(TILE_TXCOLS): + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; + qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}]; + $else: + qmat2[r][${c}] = VEC4_T( + texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0)); } + $if IN_STORAGE == "buffer": + uint in_row_txstride = div4(in_sizes.x); + // Preload input tensor [[unroll]] for (int i = 0; i < TILE_ROWS; i++) { $if IN_STORAGE == "buffer": - a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2]; + mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos]; $else: - a[i] = VEC4_T(texelFetch(t_in, u16vec3(pos >> 2, out_row + i, 0), 0)); + mat1[i] = VEC4_T( + texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0)); } // Accumulate output - [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3]; + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + sums[r][${c}] += mat1[r].x * qmat2[0][${c}] + + mat1[r].y * qmat2[1][${c}] + + mat1[r].z * qmat2[2][${c}] + + mat1[r].w * qmat2[3][${c}]; } } // Store to output tensor - 
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { - $if OUT_STORAGE == "buffer": - if (out_row + i < out_sizes.y) { - t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales; - } - $else: - imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales); + $if OUT_STORAGE == "buffer": + uint out_bufi; + uint out_row_txstride = div4(out_sizes.x); + + [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { + $for c in range(TILE_TXCOLS): + $if OUT_STORAGE == "buffer": + if (out_row + r < out_sizes.y) { + out_bufi = (out_row + r) * out_row_txstride + out_txcol; + t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}]; + } + $else: + imageStore( + t_out, + ivec3(out_txcol + ${c}, out_row + r, 0), + sums[r][${c}] * scales[${c}]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index f9f0134d995..1c9ec4e524a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -12,6 +12,8 @@ linear_qcsnw_tiled: WEIGHT_STORAGE: texture2d SCALES_STORAGE: texture2d TILE_ROWS: 4 + TILE_TXCOLS: 1 + QUANT_NBITS: 8 generate_variant_forall: TILE_ROWS: - VALUE: 1 @@ -30,3 +32,11 @@ linear_qcsnw_tiled: OUT_STORAGE: buffer WEIGHT_STORAGE: buffer SCALES_STORAGE: buffer + - NAME: linear_qcs4w_tiled_texture3d_texture3d_texture2d_texture2d_float + TILE_TXCOLS: 2 + QUANT_NBITS: 4 + - NAME: linear_qcs4w_tiled_buffer_buffer_texture2d_texture2d_float + IN_STORAGE: buffer + OUT_STORAGE: buffer + TILE_TXCOLS: 2 + QUANT_NBITS: 4 diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 85695488dfc..6e101195e3f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -17,6 +17,7 @@ namespace vkcompute { void check_linear_qcsnw_args( const ComputeGraph& graph, + const int quant_nbits, const ValueRef mat1, const ValueRef qmat2_data, const ValueRef scales, @@ -31,13 +32,20 @@ void check_linear_qcsnw_args( VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out)); - VK_CHECK_COND( - utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes)); - VK_CHECK_COND( - utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); + if (quant_nbits == 4) { + VK_CHECK_COND( + utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes) * 2); + VK_CHECK_COND( + utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); + } else { + VK_CHECK_COND( + utils::val_at(-1, mat1_sizes) == utils::val_at(-1, qmat2_sizes)); + VK_CHECK_COND( + utils::val_at(-1, scales_sizes) == utils::val_at(-2, qmat2_sizes)); + } } -void resize_linear_qcs8w_node( +void resize_linear_qcsnw_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { @@ -48,7 +56,12 @@ void resize_linear_qcs8w_node( vTensorPtr qmat2 = graph->get_tensor(args[1].refs[1]); const int out_cols = utils::val_at(-2, mat1->sizes()); - const int out_rows = utils::val_at(-1, qmat2->sizes()); + int out_rows = utils::val_at(-1, qmat2->sizes()); + // Byte dtype suggests 4-bit quantization in which case the weight tensor is + // packed with 2 values per byte. 
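+  // (For example, the prepacked 4-bit weight produced by
+  // prepack_int4_linear_weight_transposed_interleaved has sizes {K, N / 2},
+  // so doubling the last dim recovers the N output features.)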
+ if (qmat2->dtype() == vkapi::kByte) { + out_rows *= 2; + } std::vector new_out_sizes(3); if (mat1->sizes().size() == 2) { @@ -135,34 +148,40 @@ void add_linear_qcs8w_node( // Resize Args {}, // Resizing Logic - resize_linear_qcs8w_node)); + resize_linear_qcsnw_node)); if (!graph.is_buffer_storage(out) && graph.packed_dim_of(out) != WHCN::kWidthDim) { viewFn(graph, {out_W_packed, graph.add_none(), out}); } } -void add_linear_qcs8w_tiled_node( +void add_linear_qcsnw_tiled_node( ComputeGraph& graph, const bool use_coop_algorithm, + const int quant_nbits, const ValueRef mat1, const ValueRef q_mat2_data, const ValueRef scales_data, const ValueRef out) { - utils::StorageType q_mat2_storage = utils::kTexture2D; - uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); std::vector qmat2_orig_sizes = graph.sizes_of(q_mat2_data); const int64_t ndim = graph.dim_of(q_mat2_data); const int64_t K = qmat2_orig_sizes.at(ndim - 1); const int64_t N = qmat2_orig_sizes.at(ndim - 2); - if (N > max_extent * 4 || K > max_extent) { - q_mat2_storage = utils::kBuffer; - } + ValueRef q_mat2; + if (quant_nbits == 4) { + q_mat2 = + prepack_int4_linear_weight_transposed_interleaved(graph, q_mat2_data); + } else { + utils::StorageType q_mat2_storage = utils::kTexture2D; + if (N > max_extent * 4 || K > max_extent) { + q_mat2_storage = utils::kBuffer; + } - ValueRef q_mat2 = prepack_standard_hw_transposed( - graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked); + q_mat2 = prepack_standard_hw_transposed( + graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked); + } utils::StorageType scales_storage = utils::kTexture2D; if (N > max_extent) { @@ -171,8 +190,14 @@ void add_linear_qcs8w_tiled_node( ValueRef scales = prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked); - std::string kernel_name = - use_coop_algorithm ? "linear_qcs8w_coop" : "linear_qcs8w_tiled"; + std::string kernel_name; + if (quant_nbits == 4) { + kernel_name = + use_coop_algorithm ? "linear_qcs4w_coop" : "linear_qcs4w_tiled"; + } else { + kernel_name = + use_coop_algorithm ? 
"linear_qcs8w_coop" : "linear_qcs8w_tiled"; + } kernel_name.reserve(kShaderNameReserve); add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1)); @@ -197,9 +222,16 @@ void add_linear_qcs8w_tiled_node( out_tile_nrows = 4; } + // Number of output texels in the output tile + uint32_t out_tile_ntxcols = 1; + if (quant_nbits == 4) { + out_tile_ntxcols = 2; + } + utils::uvec3 out_limits = graph.logical_limits_of(out); + uint32_t global_wg_x = utils::div_up(out_limits[0], out_tile_ntxcols); utils::uvec3 global_wg_size = { - out_limits[0] * (utils::div_up(out_limits[1], out_tile_nrows)), + global_wg_x * (utils::div_up(out_limits[1], out_tile_nrows)), 1, out_limits[2]}; @@ -224,7 +256,7 @@ void add_linear_qcs8w_tiled_node( // Resize Args {}, // Resizing Logic - resize_linear_qcs8w_node)); + resize_linear_qcsnw_node)); } bool can_use_tiled_impl( @@ -238,7 +270,7 @@ bool can_use_tiled_impl( // Check if mat1 is not a 3D tensor or that batches = 1 // TODO(ssjia): Add support for batches in the tiled impl - if (graph.dim_of(mat1) == 3 && graph.size_at(-1, mat1) != 1) { + if (graph.dim_of(mat1) == 3 && graph.size_at(0, mat1) != 1) { return false; } // Check that K is a multiple of 4 @@ -283,17 +315,27 @@ bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) { void weight_int8pack_mm( ComputeGraph& graph, const std::vector& args) { - check_linear_qcsnw_args(graph, args[0], args[1], args[2], args[3]); + check_linear_qcsnw_args(graph, 8, args[0], args[1], args[2], args[3]); if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) { bool use_coop_algorithm = can_use_coop_impl(graph, args[0]); - return add_linear_qcs8w_tiled_node( - graph, use_coop_algorithm, args[0], args[1], args[2], args[3]); + return add_linear_qcsnw_tiled_node( + graph, use_coop_algorithm, 8, args[0], args[1], args[2], args[3]); } return add_linear_qcs8w_node(graph, args[0], args[1], args[2], args[3]); } +void linear_qcs4w(ComputeGraph& graph, const std::vector& args) { + check_linear_qcsnw_args(graph, 4, args[0], args[1], args[2], args[3]); + + VK_CHECK_COND(can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])); + bool use_coop_algorithm = can_use_coop_impl(graph, args[0]); + return add_linear_qcsnw_tiled_node( + graph, use_coop_algorithm, 4, args[0], args[1], args[2], args[3]); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten._weight_int8pack_mm.default, weight_int8pack_mm); + VK_REGISTER_OP(et_vk.linear_qcs4w.default, linear_qcs4w); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index b3ead94d8ff..8c5cb0093d9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -70,54 +70,6 @@ void resize_linear_qga4w_node( out->virtual_resize(new_out_sizes); } -ValueRef prepack_int4_linear_weight_transposed_interleaved( - ComputeGraph& graph, - const ValueRef qmat2_data) { - std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); - const int64_t ndim = graph.dim_of(qmat2_data); - - const int64_t K = qmat2_orig_sizes.at(ndim - 1) * 2; - const int64_t N = qmat2_orig_sizes.at(ndim - 2); - const int64_t N_div2 = N / int64_t(2); - - utils::StorageType storage_type = utils::kTexture2D; - uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); - if (N_div2 > max_extent * 4 || K > max_extent) { - storage_type = 
utils::kBuffer; - } - - std::vector qmat2_sizes{K, N_div2}; - ValueRef qmat2 = graph.add_tensor( - qmat2_sizes, vkcompute::vkapi::kByte, storage_type, utils::kWidthPacked); - - utils::uvec3 global_wg_size; - global_wg_size = graph.logical_limits_of(qmat2); - global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(2)); - - std::string kernel_name = - graph.context()->adapter_ptr()->has_full_int8_buffers_support() - ? "pack_int4_linear_weight_transposed_interleaved" - : "pack_int4_linear_weight_transposed_interleaved_nobitw8buffer"; - add_storage_type_suffix(kernel_name, storage_type); - - graph.prepack_nodes().emplace_back(new PrepackNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), - // Inputs and Outputs - qmat2_data, - qmat2, - // UBOs - {}, - // Specialization Constants - {}, - // Push Constants - {graph.sizes_pc_of(qmat2)})); - - return qmat2; -} - void add_linear_qga4w_node( ComputeGraph& graph, const ValueRef mat1, diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 32e63baeafc..f39b0fc33ff 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -246,6 +246,54 @@ ValueRef prepack_direct_copy_buffer( return tensor; } +ValueRef prepack_int4_linear_weight_transposed_interleaved( + ComputeGraph& graph, + const ValueRef qmat2_data) { + std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); + const int64_t ndim = graph.dim_of(qmat2_data); + + const int64_t K = qmat2_orig_sizes.at(ndim - 1) * 2; + const int64_t N = qmat2_orig_sizes.at(ndim - 2); + const int64_t N_div2 = N / int64_t(2); + + utils::StorageType storage_type = utils::kTexture2D; + uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); + if (N_div2 > max_extent * 4 || K > max_extent) { + storage_type = utils::kBuffer; + } + + std::vector qmat2_sizes{K, N_div2}; + ValueRef qmat2 = graph.add_tensor( + qmat2_sizes, vkcompute::vkapi::kByte, storage_type, utils::kWidthPacked); + + utils::uvec3 global_wg_size; + global_wg_size = graph.logical_limits_of(qmat2); + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(2)); + + std::string kernel_name = + graph.context()->adapter_ptr()->has_full_int8_buffers_support() + ? 
"pack_int4_linear_weight_transposed_interleaved" + : "pack_int4_linear_weight_transposed_interleaved_nobitw8buffer"; + add_storage_type_suffix(kernel_name, storage_type); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_wg_size, + graph.create_local_wg_size(global_wg_size), + // Inputs and Outputs + qmat2_data, + qmat2, + // UBOs + {}, + // Specialization Constants + {}, + // Push Constants + {graph.sizes_pc_of(qmat2)})); + + return qmat2; +} + void prepack_op(ComputeGraph& graph, const std::vector& args) { return add_prepack_standard_node(graph, args[0], args[1]); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index 1b6f245bd34..090a3718295 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -87,4 +87,11 @@ ValueRef prepack_direct_copy_buffer( ComputeGraph& graph, const ValueRef tensor_data); +// +// Op specific prepack functions + +ValueRef prepack_int4_linear_weight_transposed_interleaved( + ComputeGraph& graph, + const ValueRef qmat2_data); + } // namespace vkcompute diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index aafc87ad2c3..665fde103fc 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -280,6 +280,7 @@ def define_common_targets(is_fbcode = False): deps = [ "//caffe2:torch", "//executorch/exir:tensor", + "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib", "//executorch/backends/vulkan/serialization:lib", ] ) @@ -332,7 +333,6 @@ def define_common_targets(is_fbcode = False): "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:fuse_batch_norm_with_conv", "//executorch/backends/transforms:fuse_conv_with_clamp", - "//executorch/backends/transforms:fuse_dequant_linear", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:remove_clone_ops", "//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze", diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS index 5ac87892762..8f07040d586 100644 --- a/backends/vulkan/test/TARGETS +++ b/backends/vulkan/test/TARGETS @@ -24,6 +24,19 @@ python_unittest( ], ) +python_unittest( + name = "test_vulkan_passes", + srcs = [ + "test_vulkan_passes.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/vulkan/_passes:vulkan_passes", + "//executorch/backends/vulkan/quantizer:vulkan_quantizer", + "//executorch/backends/vulkan:vulkan_preprocess", + ] +) + python_unittest( name = "test_vulkan_delegate_header", srcs = [ diff --git a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index 5d08ee57859..b95b7b3aa6d 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -62,7 +62,7 @@ at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) { return weights_unpacked; } -at::Tensor dequantize_and_linear( +at::Tensor dequantize_and_linear_qga4w( const at::Tensor& x, const at::Tensor& weights_4x2, const int64_t groupsize, @@ -97,6 +97,56 @@ at::Tensor dequantize_and_linear( return at::linear(x, weights_dequantized); } +at::Tensor dequantize_and_linear_qcs4w( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const at::Tensor& scales) { + std::vector weights_shape(weights_4x2.sizes().vec()); + weights_shape[1] *= 2; + + at::Tensor 
weights_dequantized = + at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat)); + + const int64_t N = weights_dequantized.size(0); + const int64_t K = weights_dequantized.size(1); + + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k += 2) { + // const int scale_idx = k_groups * n + group_idx; + const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); + const uint8_t second_val = packed_val & 0x0F; + const uint8_t first_val = (packed_val & 0xF0) >> 4; + + const float scale = scales[n].item().to(); + + weights_dequantized[n][k] = (float(first_val) - 8.0) * scale; + weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale; + } + } + + return at::linear(x, weights_dequantized); +} + +at::Tensor linear_qcs4w_reference_impl( + const at::Tensor& x, + const at::Tensor& weights_4x2, + const at::Tensor& scales) { + const std::vector original_x_size(x.sizes().vec()); + const size_t ndim = original_x_size.size(); + const int64_t out_features = weights_4x2.size(0); + const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]}); + + const at::Tensor weights_unpacked = + (unpack_weights_4x2(weights_4x2) - 8).to(at::kChar); + at::Tensor out = + at::_weight_int8pack_mm(x_flattened, weights_unpacked, scales); + + std::vector out_shape( + original_x_size.begin(), original_x_size.end()); + out_shape.at(ndim - 1) = out_features; + return out.reshape(out_shape); +} + // // Test functions // @@ -126,12 +176,31 @@ void test_reference_linear_qga4w( scales_and_zeros, inner_k_tiles); - at::Tensor out_ref = dequantize_and_linear( + at::Tensor out_ref = dequantize_and_linear_qga4w( x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); ASSERT_TRUE(at::allclose(out, out_ref)); } +void test_reference_linear_qcs4w( + const int B, + const int M, + const int K, + const int N) { + at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + at::Tensor weights_int = unpack_weights_4x2(weights_4x2); + + at::Tensor scales = at::rand({N}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out = linear_qcs4w_reference_impl(x, weights_4x2, scales); + + at::Tensor out_ref = dequantize_and_linear_qcs4w(x, weights_4x2, scales); + + ASSERT_TRUE(at::allclose(out, out_ref)); +} + vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { using namespace vkcompute; switch (at_scalartype) { @@ -265,6 +334,85 @@ void test_vulkan_linear_qga4w( vkcompute::utils::kTexture3D); } +void test_vulkan_linear_qcs4w_impl( + const int B, + const int M, + const int K, + const int N, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor weights_4x2 = + at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); + + at::Tensor scales = at::rand({N}, at::device(at::kCPU).dtype(at::kFloat)); + + at::Tensor out_ref = linear_qcs4w_reference_impl(x, weights_4x2, scales); + + // Build Vulkan graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(utils::kTexture3D); + ComputeGraph graph(config); + +#define MAKE_TENSORREF_FOR(x) \ + ValueRef r_##x = graph.add_tensorref( \ + x.sizes().vec(), \ + from_at_scalartype(x.scalar_type()), \ + x.const_data_ptr()); + + MAKE_TENSORREF_FOR(weights_4x2); + 
MAKE_TENSORREF_FOR(scales); + + IOValueRef r_x = graph.add_input_tensor( + x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); + + const ValueRef r_out = graph.add_tensor( + out_ref.sizes().vec(), + from_at_scalartype(out_ref.scalar_type()), + out_storage); + + VK_GET_OP_FN("et_vk.linear_qcs4w.default") + (graph, {r_x.value, r_weights_4x2, r_scales, r_out}); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // + // Run model + // + + graph.propagate_resize(); + graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(out_ref); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4)); +} + +void test_vulkan_linear_qcs4w( + const int B, + const int M, + const int K, + const int N) { + test_vulkan_linear_qcs4w_impl( + B, M, K, N, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + + test_vulkan_linear_qcs4w_impl( + B, M, K, N, vkcompute::utils::kTexture3D, vkcompute::utils::kTexture3D); +} + TEST(VulkanLinearQGA4WTest, test_reference_impl) { test_reference_linear_qga4w( /*B = */ 1, @@ -294,3 +442,33 @@ TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) { /*K = */ 256, /*N = */ 256); } + +TEST(VulkanLinearQCS4WTest, test_reference_impl) { + test_reference_linear_qcs4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); +} + +TEST(VulkanLinearQCS4WTest, test_vulkan_impl_small_m) { + test_vulkan_linear_qcs4w( + /*B = */ 1, + /*M = */ 4, + /*K = */ 128, + /*N = */ 32); + + test_vulkan_linear_qcs4w( + /*B = */ 1, + /*M = */ 1, + /*K = */ 256, + /*N = */ 256); +} + +TEST(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) { + test_vulkan_linear_qcs4w( + /*B = */ 1, + /*M = */ 32, + /*K = */ 32, + /*N = */ 32); +} diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 5fba5ed54cf..b57710974e8 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -15,10 +15,19 @@ from executorch.backends.transforms.convert_dtype_pass import I64toI32 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner + from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend -from executorch.exir import EdgeCompileConfig -from torch.export import Dim, export, ExportedProgram +from executorch.exir import ( + EdgeCompileConfig, + EdgeProgramManager, + ExecutorchProgramManager, +) + +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from torch.ao.quantization.quantizer import Quantizer +from torch.export import Dim, export, export_for_training, ExportedProgram ctypes.CDLL("libvulkan.so.1") @@ -30,11 +39,66 @@ from executorch.extension.pytree import tree_flatten -class TestBackends(unittest.TestCase): - _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( +def lower_module( + model: torch.nn.Module, sample_inputs: Tuple[torch.Tensor], dynamic_shapes=None +) -> EdgeProgramManager: + compile_options = {} + edge_compile_config = EdgeCompileConfig( + _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. 
+ ) + + program: ExportedProgram = export( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + ) + + return edge_program + + +def quantize_and_lower_module( + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + quantizer: Quantizer, + dynamic_shapes=None, +) -> EdgeProgramManager: + compile_options = {} + edge_compile_config = EdgeCompileConfig( _skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend. ) + program = export_for_training( + model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + ).module() + + program = prepare_pt2e(program, quantizer) # pyre-ignore + # Calibrate + program(*sample_inputs) + + program = convert_pt2e(program) + + program = export(program, sample_inputs, dynamic_shapes=dynamic_shapes) + + edge_program = to_edge_transform_and_lower( + program, + compile_config=edge_compile_config, + transform_passes=[ + I64toI32(edge_compile_config._skip_dim_order), + ], + partitioner=[VulkanPartitioner(compile_options)], + ) + + return edge_program + + +class TestVulkanBackend(unittest.TestCase): def assert_outputs_equal( self, model_output, @@ -88,6 +152,59 @@ def assert_outputs_equal( ) ) + def check_no_delegation(self, et_program: ExecutorchProgramManager): + self.assertEqual( + len(et_program.executorch_program.execution_plan[0].delegates), + 0, + ) + return + + def check_vk_delegation(self, et_program: ExecutorchProgramManager): + self.assertEqual( + et_program.executorch_program.execution_plan[0].delegates[0].id, + VulkanBackend.__name__, + ) + + def run_delegated_model_and_check_output( + self, + et_program: ExecutorchProgramManager, + model: torch.nn.Module, + sample_inputs: Tuple[torch.Tensor], + atol=1e-03, + rtol=1e-01, + test_inputs=None, + first_output_only=False, + ): + executorch_module = _load_for_executorch_from_buffer(et_program.buffer) + inputs_flattened, _ = tree_flatten(sample_inputs) + + model_output = executorch_module.run_method("forward", tuple(inputs_flattened)) + ref_output = model(*sample_inputs) + + self.assert_outputs_equal( + model_output, + ref_output, + atol=atol, + rtol=rtol, + first_output_only=first_output_only, + ) + + if test_inputs is not None: + for test_input in test_inputs: + test_inputs_flattened, _ = tree_flatten(test_input) + model_output = executorch_module.run_method( + "forward", tuple(test_inputs_flattened) + ) + ref_output = model(*test_input) + + self.assert_outputs_equal( + model_output, + ref_output, + atol=atol, + rtol=rtol, + first_output_only=first_output_only, + ) + def lower_module_and_test_output( self, model: torch.nn.Module, @@ -105,80 +222,29 @@ def lower_module_and_test_output( outputs with the outputs of the eager module. """ - def run_test(): - compile_options = {} + # Validate that the model can execute in eager mode + model.eval() + model(*sample_inputs) - # At least model should run in eager mode. 
- model.eval() - model(*sample_inputs) + edge_program = lower_module(model, sample_inputs, dynamic_shapes=dynamic_shapes) - program: ExportedProgram = export( - model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True - ) + et_program = edge_program.to_executorch() - edge_program = to_edge_transform_and_lower( - program, - compile_config=self._edge_compile_config, - transform_passes=[ - I64toI32(self._edge_compile_config._skip_dim_order), - ], - partitioner=[VulkanPartitioner(compile_options)], - ) - executorch_program = edge_program.to_executorch() - - if expect_no_delegates: - self.assertEqual( - len( - executorch_program.executorch_program.execution_plan[ - 0 - ].delegates - ), - 0, - ) - return - else: - self.assertEqual( - executorch_program.executorch_program.execution_plan[0] - .delegates[0] - .id, - VulkanBackend.__name__, - ) - - executorch_module = _load_for_executorch_from_buffer( - executorch_program.buffer - ) - inputs_flattened, _ = tree_flatten(sample_inputs) + if expect_no_delegates: + self.check_no_delegation(et_program) + return - model_output = executorch_module.run_method( - "forward", tuple(inputs_flattened) - ) - ref_output = model(*sample_inputs) - - self.assert_outputs_equal( - model_output, - ref_output, - atol=atol, - rtol=rtol, - first_output_only=first_output_only, - ) - - if test_inputs is not None: - for test_input in test_inputs: - test_inputs_flattened, _ = tree_flatten(test_input) - model_output = executorch_module.run_method( - "forward", tuple(test_inputs_flattened) - ) - ref_output = model(*test_input) + self.check_vk_delegation(et_program) - self.assert_outputs_equal( - model_output, - ref_output, - atol=atol, - rtol=rtol, - first_output_only=first_output_only, - ) - - run_test() + self.run_delegated_model_and_check_output( + et_program, + model, + sample_inputs, + atol, + rtol, + test_inputs=test_inputs, + first_output_only=first_output_only, + ) def test_vulkan_backend_add(self): # This test is the simplest test by manually lowering some submodules, we can use paritioner @@ -942,6 +1008,7 @@ def forward(self, x): sample_inputs, ) + @unittest.skip("layer norm compute shader not working with swiftshader") def test_vulkan_backend_native_layer_norm(self): class NativeLayerNormModule(torch.nn.Module): def __init__(self): diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py new file mode 100644 index 00000000000..7572ebd5a5a --- /dev/null +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -0,0 +1,151 @@ +import unittest +from typing import Optional, Tuple + +import torch + +from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform +from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform + +from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( + get_linear_weight_only_qcs_xnn_qconfig, + VulkanQuantizer, +) + +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge + +from executorch.exir.backend.canonical_partitioners.config_partitioner import ( + format_target_name, +) + +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer import Quantizer + +################### +## Common Models ## +################### + + +class SingleLinearModule(torch.nn.Module): + def __init__(self, K=256, N=128): + super().__init__() + self.K = K + self.N = N + self.linear = torch.nn.Linear(K, N, bias=False) + + def forward(self, x): + return self.linear(x) + + def get_sample_inputs(self): + 
+        sample_inputs = (torch.rand(size=(32, self.K), dtype=torch.float32),)
+        return sample_inputs
+
+
+###########
+## Tests ##
+###########
+
+
+def quantize_and_lower_module(
+    model: torch.nn.Module,
+    sample_inputs: Tuple[torch.Tensor],
+    quantizer: Quantizer,
+    dynamic_shapes=None,
+) -> EdgeProgramManager:
+    edge_compile_config = EdgeCompileConfig(
+        _skip_dim_order=False,  # TODO(T182928844): Delegate dim order op to backend.
+        _check_ir_validity=False,
+    )
+
+    program = torch.export.export_for_training(
+        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
+    ).module()
+
+    program = prepare_pt2e(program, quantizer)  # pyre-ignore
+    # Calibrate
+    program(*sample_inputs)
+
+    program = convert_pt2e(program)
+
+    program = torch.export.export(program, sample_inputs, dynamic_shapes=dynamic_shapes)
+
+    edge_program = to_edge(
+        program,
+        compile_config=edge_compile_config,
+    )
+
+    return edge_program
+
+
+def get_target_canonical_name(node: torch.fx.Node) -> Optional[str]:
+    if node.op != "call_function":
+        return None
+    node_name = format_target_name(node.target.__name__)  # pyre-ignore
+    return node_name
+
+
+def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) -> int:
+    count = 0
+    for node in graph_module.graph.nodes:
+        canonical_name = get_target_canonical_name(node)
+        if canonical_name is not None and canonical_name == canonical_op_name:
+            count += 1
+    return count
+
+
+class TestVulkanPasses(unittest.TestCase):
+
+    def test_fuse_int8pack_mm(self):
+        K = 256
+        N = 256
+        model = SingleLinearModule(K, N)
+        sample_inputs = model.get_sample_inputs()
+
+        quantizer = VulkanQuantizer()
+        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8))
+
+        edge_manager = quantize_and_lower_module(
+            model,
+            sample_inputs,
+            quantizer,
+        )
+
+        ep = edge_manager._edge_programs["forward"]
+        edge_manager.transform(
+            [
+                AddmmToLinearTransform(),
+                FuseQuantizedOpsTransform(ep),
+            ]
+        )
+
+        gm = ep.graph_module
+
+        self.assertEqual(op_node_count(gm, "_weight_int8pack_mm.default"), 1)
+        self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0)
+
+    def test_fuse_linear_qcs4w(self):
+        K = 256
+        N = 256
+        model = SingleLinearModule(K, N)
+        sample_inputs = model.get_sample_inputs()
+
+        quantizer = VulkanQuantizer()
+        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4))
+
+        edge_manager = quantize_and_lower_module(
+            model,
+            sample_inputs,
+            quantizer,
+        )
+
+        ep = edge_manager._edge_programs["forward"]
+        edge_manager.transform(
+            [
+                AddmmToLinearTransform(),
+                FuseQuantizedOpsTransform(ep),
+            ]
+        )
+
+        gm = ep.graph_module
+
+        self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1)
+        self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0)
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index fa032cd7b4f..eb949a6ace8 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -14,6 +14,10 @@
     VkStorageType,
 )
 
+from executorch.exir.backend.canonical_partitioners.config_partitioner import (
+    format_target_name,
+)
+
 from executorch.exir.tensor import TensorSpec
 
 from torch._export.utils import is_buffer, is_param
@@ -22,11 +26,44 @@
 
 from torch.export import ExportedProgram
 
+from torch.export.exported_program import InputKind
+from torch.export.graph_signature import TensorArgument
+
+_DQ_OPS = {
+    "dequantize_per_tensor.tensor",
+    "dequantize_per_tensor.default",
+    "dequantize_per_channel.default",
+    "dequantize_per_channel_group.default",
+    "dequantize_per_token.default",
"dequantize_affine.default", +} + ## ## Node type determination ## +def is_dequant_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name in _DQ_OPS + + +def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name == "dequantize_per_channel.default" + + +def is_linear_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return node_name == "linear.default" + + def is_get_attr_node(node: torch.fx.Node) -> bool: return isinstance(node, torch.fx.Node) and node.op == "get_attr" @@ -258,3 +295,35 @@ def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]: def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]: return get_node_spec_attr(node, "vk_memory_layout") + + +## +## Misc +## + + +def update_program_state_dict( + program: ExportedProgram, + buffer_name: str, + updated_tensor: torch.Tensor, +) -> None: + target_name = None + # Iterate over all the tensors in the graph signature, and find + # the one corresponding to the parameter/buffer name + for input_ in program.graph_signature.input_specs: + if ( + input_.kind in (InputKind.BUFFER, InputKind.PARAMETER) + and isinstance(input_.arg, TensorArgument) + and input_.arg.name == buffer_name + ): + target_name = input_.target + break + + # Assert that we found the parameter/buffer + assert ( + target_name is not None + ), f"could not find {buffer_name} in source program signature" + assert target_name in program.state_dict, f"could not find {target_name}" + + # Finally, overwrite the current tensor with updated tensor + program.state_dict[target_name] = updated_tensor diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 188311e5f2c..4200df3e131 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -17,12 +17,12 @@ FuseBatchNormWithConvPass, ) from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass -from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import ( ViewCopyToSqueezeUnsqueezePass, ) from executorch.backends.vulkan._passes import ( + FuseQuantizedOpsTransform, insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, RemoveRedundantOpsTransform, @@ -152,7 +152,7 @@ def preprocess( # noqa: C901 [ RemoveRedundantOpsTransform(), AddmmToLinearTransform(), - FuseDequantLinearPass(), + FuseQuantizedOpsTransform(program), SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(), diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 24c3be2e802..d7b8b3a92b1 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -266,16 +266,12 @@ def get_coreml_quantizer(pt2e_quantize: str): def get_vulkan_quantizer(pt2e_quantize: str): from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( - get_weight_quantization_config, + get_linear_weight_only_qcs_xnn_qconfig, VulkanQuantizer, ) if pt2e_quantize == "vulkan_8w": - config = get_weight_quantization_config( - 
-            is_per_channel=True,
-            weight_qmin=-128,
-            weight_qmax=127,
-        )
+        config = get_linear_weight_only_qcs_xnn_qconfig(8)
     else:
         raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")
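
For reference, the end-to-end flow this change sets up (PT2E weight-only quantization followed by the new fusion pass) can be exercised roughly as follows. This is a minimal sketch modeled on test_fuse_int8pack_mm in the new test_vulkan_passes.py; TinyLinear and its shapes are illustrative stand-ins, and only APIs already used in that test are assumed.

import torch

from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_linear_weight_only_qcs_xnn_qconfig,
    VulkanQuantizer,
)
from executorch.exir import EdgeCompileConfig, to_edge
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyLinear(torch.nn.Module):
    # Hypothetical toy model; any module whose weights hit the linear pattern works.
    def __init__(self, k=256, n=256):
        super().__init__()
        self.linear = torch.nn.Linear(k, n, bias=False)

    def forward(self, x):
        return self.linear(x)


model = TinyLinear()
sample_inputs = (torch.rand(32, 256),)

# 8-bit weight-only, channelwise-symmetric quantization of linear weights.
quantizer = VulkanQuantizer()
quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8))

# Standard PT2E flow: export for training, insert observers, calibrate, convert.
program = torch.export.export_for_training(model, sample_inputs, strict=True).module()
program = prepare_pt2e(program, quantizer)
program(*sample_inputs)  # calibration pass
program = convert_pt2e(program)

# Re-export and lower to the edge dialect.
program = torch.export.export(program, sample_inputs)
edge_manager = to_edge(
    program,
    compile_config=EdgeCompileConfig(_skip_dim_order=False, _check_ir_validity=False),
)

# Rewrite addmm/mm back to linear, then fuse dq(weight) -> linear into the
# packed quantized op (_weight_int8pack_mm for 8-bit weights).
ep = edge_manager._edge_programs["forward"]
edge_manager.transform(
    [
        AddmmToLinearTransform(),
        FuseQuantizedOpsTransform(ep),
    ]
)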
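
The update_program_state_dict helper added to backends/vulkan/utils.py is what a fusion pass would use to swap a repacked weight tensor into the exported program. Below is a hedged sketch of the intended call pattern; replace_packed_weight and the node/tensor names are hypothetical and not part of this diff.

import torch

from executorch.backends.vulkan.utils import update_program_state_dict
from executorch.exir import ExportedProgram


def replace_packed_weight(
    program: ExportedProgram,
    weight_node: torch.fx.Node,
    packed_weight: torch.Tensor,
) -> None:
    # weight_node is the placeholder feeding the fused op; its name matches the
    # graph-signature entry that update_program_state_dict resolves to the real
    # parameter/buffer target before overwriting the tensor in program.state_dict.
    update_program_state_dict(program, weight_node.name, packed_weight)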