diff --git a/backends/transforms/fuse_dequant_linear.py b/backends/transforms/fuse_dequant_linear.py
deleted file mode 100644
index 235715ac74f..00000000000
--- a/backends/transforms/fuse_dequant_linear.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-import torch
-
-from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
-
-
-class FuseDequantLinearPass(ExportPass):
-    """
-    Fuses weight dequantize_per_channel nodes with linear nodes into
-    weight_int8pack_mm nodes, for 8-bit weight-only quantization.
-
-    Replaces dq(weight) -> linear(activation, dq) with weight_int8pack_mm
-    Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add
-    """
-
-    def fuse_dequant_with_linear(
-        self,
-        graph_module: torch.fx.GraphModule,
-        dequant_node: torch.fx.Node,
-        linear_node: torch.fx.Node,
-    ) -> None:
-        activations = linear_node.args[0]
-        bias = None
-        if len(linear_node.args) > 2:
-            bias = linear_node.args[2]
-        quant_weight = dequant_node.args[0]
-        scale = dequant_node.args[1]
-
-        with graph_module.graph.inserting_before(linear_node):
-            weight_int8pack_mm_node = graph_module.graph.create_node(
-                "call_function",
-                exir_ops.edge.aten._weight_int8pack_mm.default,
-                (activations, quant_weight, scale),
-            )
-            if bias:
-                add_node = graph_module.graph.create_node(
-                    "call_function",
-                    exir_ops.edge.aten.add.Tensor,
-                    (weight_int8pack_mm_node, bias),
-                )
-                linear_node.replace_all_uses_with(add_node)
-            else:
-                linear_node.replace_all_uses_with(weight_int8pack_mm_node)
-            graph_module.graph.erase_node(linear_node)
-            graph_module.graph.erase_node(dequant_node)
-
-    def is_node_target(
-        self, node: torch.fx.Node, target: torch._ops.OperatorBase
-    ) -> bool:
-        return node.op == "call_function" and node.target == target
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            if self.is_node_target(node, exir_ops.edge.aten.linear.default):
-                weight_node = node.args[1]
-                if self.is_node_target(
-                    weight_node,
-                    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
-                ):
-                    # only fuse if weight tensor is int8 packed
-                    quant_weight = weight_node.args[0]
-                    if quant_weight.meta["val"].dtype != torch.int8:
-                        continue
-                    self.fuse_dequant_with_linear(graph_module, weight_node, node)
-
-        graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
-
-        return PassResult(graph_module, True)
diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl
index 66ff9111f52..71980195962 100644
--- a/backends/transforms/targets.bzl
+++ b/backends/transforms/targets.bzl
@@ -77,21 +77,6 @@ def define_common_targets():
         ],
     )
 
-    runtime.python_library(
-        name = "fuse_dequant_linear",
-        srcs = ["fuse_dequant_linear.py"],
-        visibility = [
-            "//executorch/backends/...",
-        ],
-        deps = [
-            ":utils",
-            "//caffe2:torch",
-            "//executorch/exir:pass_base",
-            "//executorch/exir:sym_util",
-            "//executorch/exir/dialects:lib",
-        ],
-    )
-
     runtime.python_library(
         name = "view_copy_to_squeeze_unsqueeze",
         srcs = ["view_copy_to_squeeze_unsqueeze.py"],
diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 5478ad0eab6..cfe20892994 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -3,6 +3,23 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "fuse_quantized_ops",
+    srcs = ["fuse_quantized_ops.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/transforms:utils",
+        "//executorch/backends/vulkan:custom_ops_lib",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/exir:pass_base",
+        "//executorch/exir:sym_util",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "insert_prepack_nodes",
     srcs = ["insert_prepack_nodes.py"],
@@ -13,6 +30,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/backends/vulkan:utils_lib",
+        "//executorch/backends/vulkan:op_registry",
     ],
 )
 
@@ -110,6 +128,7 @@ runtime.python_library(
         "//executorch/examples/...",
     ],
     deps = [
+        ":fuse_quantized_ops",
        ":insert_prepack_nodes",
        ":int4_weight_only_quantizer",
        ":remove_asserts",
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 220afa6a35c..7ff93a6ee38 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -6,6 +6,9 @@
 
 # pyre-strict
 
+from executorch.backends.vulkan._passes.fuse_quantized_ops import (
+    FuseQuantizedOpsTransform,
+)
 from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
@@ -26,6 +29,7 @@ from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
+    "FuseQuantizedOpsTransform",
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "remove_asserts",
diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py
new file mode 100644
index 00000000000..d510e1d4342
--- /dev/null
+++ b/backends/vulkan/_passes/fuse_quantized_ops.py
@@ -0,0 +1,229 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Optional, Tuple
+
+import executorch.backends.vulkan.utils as utils
+import torch
+
+import torch.nn.functional as F
+
+from executorch.backends.transforms.utils import get_param_tensor, is_param_node
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+#################
+## linear_qcnw ##
+#################
+
+
+def matches_linear_qcnw_pattern(  # noqa: C901
+    program: ExportedProgram, node: torch.fx.Node
+) -> Optional[Tuple[torch.qscheme, int]]:
+    """
+    Checks if the nodes surrounding a linear node match the pattern for weight-only
+    quantized linear, where the weight is quantized channel-wise to n bits.
+
+    If the graph pattern matches, then return a tuple of (quantization_method, nbits)
+    describing the type of quantization used for the weights. Otherwise, return None.
+    """
+    if not utils.is_linear_node(node):
+        return None
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Type checking
+    if not isinstance(weight_node, torch.fx.Node):
+        return None
+    if not isinstance(input_node, torch.fx.Node):
+        return None
+
+    # The input arg should not be a dequant node; if it is, then it is indicative that
+    # dynamically quantized linear should be used instead
+    if utils.is_dequant_node(input_node):
+        return None
+
+    # The weight arg should be a dequant node dequantizing the quantized weight
+    # Furthermore, the op expects per channel quantization of the weight
+    if not utils.is_dequant_per_channel_node(weight_node):
+        return None
+
+    orig_weight = weight_node.args[0]
+    zeros = weight_node.args[2]
+
+    # Type checking
+    if not isinstance(orig_weight, torch.fx.Node):
+        return None
+    if not is_param_node(program, orig_weight):
+        return None
+    if not isinstance(zeros, torch.fx.Node):
+        return None
+    if not is_param_node(program, zeros):
+        return None
+
+    zeros_tensor = get_param_tensor(program, zeros)
+    if not isinstance(zeros_tensor, torch.Tensor):
+        return None
+
+    quant_method = torch.per_channel_affine
+    # Check for symmetric quantization, where the zeros used for dequantization will
+    # actually be all zeros.
+    if torch.all(zeros_tensor == 0):
+        quant_method = torch.per_channel_symmetric
+
+    orig_weight_tensor = get_param_tensor(program, orig_weight)
+    if not isinstance(orig_weight_tensor, torch.Tensor):
+        return None
+    # Sanity check the dtype of the quantized weight
+    if orig_weight_tensor.dtype != torch.int8:
+        return None
+
+    quant_min = orig_weight_tensor.min().item()
+    quant_max = orig_weight_tensor.max().item()
+    # Determine the number of bits the weight has been quantized to
+    if quant_min >= -8 and quant_max <= 7:
+        return quant_method, 4
+    elif quant_min >= -128 and quant_max <= 127:
+        return quant_method, 8
+
+    return None
+
+
+def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor:
+    """
+    Given an 8-bit weight tensor containing values quantized to 4 bits, create a packed
+    weight tensor by packing two 4-bit values into one unsigned 8-bit value.
+
+    An input weight tensor of shape (M, K) will produce a packed weight tensor of shape
+    (M, K / 2).
+    """
+
+    # Assert we got a properly quantized tensor.
+    min, max = inp.min().item(), inp.max().item()
+    assert (
+        max <= 7 and min >= -8
+    ), f"convert_to_qc4w: [min,max] out of [-8, 7] range, got [{min}, {max}]"
+
+    # Assuming we have a 2d tensor
+    if inp.ndim != 2:
+        inp = inp.squeeze()
+    assert (
+        inp.ndim == 2
+    ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
+
+    # pad ic
+    if inp.shape[-1] % 2 != 0:
+        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
+
+    # Shape after padding
+    oc, ic = inp.shape
+    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
+
+    # Adjust inp tensor for zp
+    inp = inp.to(dtype=torch.uint8) + 8
+
+    # Prepare the Result tensor
+    inp = inp.contiguous().view(-1)
+    return (inp[::2] << 4 | inp[1::2]).view(oc, int(ic / 2))
+
+
+def fuse_into_linear_qcnw_node(
+    program: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    linear_node: torch.fx.Node,
+    quant_method: torch.qscheme,
+    nbits: int,
+) -> None:
+    """
+    The _weight_int8pack_mm and linear_qcs4w operators represent a weight-only
+    quantized linear operator, where the weight tensor has been quantized
+    channel-wise to nbits bits.
+
+    After the PT2E quantization flow, the expected graph pattern is
+
+        dq_weight = dequantize(weight, scales)
+        out = linear(activation, dq_weight, bias?)
+
+    The goal of this function is to condense that sequence into
+
+        out = quantized_linear(activation, weight, scales)
+        out = out + bias
+    """
+    activation = linear_node.args[0]
+    dq_weight_node = linear_node.args[1]
+    assert isinstance(activation, torch.fx.Node)
+    assert isinstance(dq_weight_node, torch.fx.Node)
+
+    bias = None
+    if len(linear_node.args) > 2:
+        bias = linear_node.args[2]
+        assert isinstance(bias, torch.fx.Node)
+
+    orig_weight = dq_weight_node.args[0]
+    scale = dq_weight_node.args[1]
+
+    # For 4 bit quantization, pack the weight tensor
+    if nbits == 4:
+        assert isinstance(orig_weight, torch.fx.Node)
+        orig_weight_tensor = get_param_tensor(program, orig_weight)
+        assert isinstance(orig_weight_tensor, torch.Tensor)
+        packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
+        utils.update_program_state_dict(
+            program,
+            orig_weight.name,
+            packed_weight_tensor,
+        )
+        orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)
+
+    if nbits == 8 and quant_method == torch.per_channel_symmetric:
+        op_target = exir_ops.edge.aten._weight_int8pack_mm.default
+    elif nbits == 4 and quant_method == torch.per_channel_symmetric:
+        op_target = exir_ops.edge.et_vk.linear_qcs4w.default
+    else:
+        raise NotImplementedError(
+            "only 4 and 8 bits per channel symmetric quant supported for linear_qcnw"
+        )
+
+    with graph_module.graph.inserting_before(linear_node):
+        weight_int8pack_mm_node = graph_module.graph.create_node(
+            "call_function",
+            op_target,
+            (activation, orig_weight, scale),
+        )
+        if bias:
+            add_node = graph_module.graph.create_node(
+                "call_function",
+                exir_ops.edge.aten.add.Tensor,
+                (weight_int8pack_mm_node, bias),
+            )
+            linear_node.replace_all_uses_with(add_node)
+        else:
+            linear_node.replace_all_uses_with(weight_int8pack_mm_node)
+        graph_module.graph.erase_node(linear_node)
+        graph_module.graph.erase_node(dq_weight_node)
+
+
+class FuseQuantizedOpsTransform(ExportPass):
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        super().__init__()
+        self.program = exported_program
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
+            if qcnw_details is not None:
+                qcnw_method, qcnw_nbits = qcnw_details
+                fuse_into_linear_qcnw_node(
+                    self.program, graph_module, node, qcnw_method, qcnw_nbits
+                )
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
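
The packing convention above (two 4-bit values per byte, stored with a +8 offset) and the numerics the fusion relies on can be sanity-checked in isolation. A minimal sketch, assuming a build of ExecuTorch that contains this new module; shapes and values are arbitrary:

    import torch
    import torch.nn.functional as F

    from executorch.backends.vulkan._passes.fuse_quantized_ops import (
        pack_4bit_weight_tensor,
    )

    # Fake channel-wise symmetric 4-bit quantized weight: int8 storage, values in [-8, 7].
    weight_q = torch.randint(-8, 8, (16, 32), dtype=torch.int8)
    scales = torch.rand(16)

    # Packing halves the inner dim: two 4-bit values (offset by +8) per uint8 byte.
    packed = pack_4bit_weight_tensor(weight_q)
    assert packed.shape == (16, 16)

    # Unpack the nibbles and undo the +8 offset to recover the original weights.
    unpacked = torch.empty((16, 32), dtype=torch.int8)
    unpacked[:, ::2] = (packed >> 4).to(torch.int8) - 8
    unpacked[:, 1::2] = (packed & 0x0F).to(torch.int8) - 8
    assert torch.equal(unpacked, weight_q)

    # The fused op is expected to reproduce this reference: a plain linear against
    # the channel-wise dequantized weight (weight_q * scales along dim 0).
    x = torch.rand(4, 32)
    ref = F.linear(x, weight_q.to(torch.float32) * scales.unsqueeze(-1))
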
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index 0275239a86a..af6fcbfbb14 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -184,6 +184,53 @@ def linear_weight_int4_impl(
 lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
 linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
 
+##################
+## linear_qcs4w ##
+##################
+
+
+def linear_qcs4w(
+    x: torch.Tensor,
+    weights_4x2: torch.Tensor,
+    scales: torch.Tensor,
+):
+    original_x_shape = x.shape
+    x = x.reshape(-1, original_x_shape[-1])
+
+    unpacked_weights_shape = weights_4x2.shape
+    out_features = unpacked_weights_shape[0]
+    in_features = unpacked_weights_shape[1]
+
+    weights_unpacked = torch.empty(
+        (out_features, in_features * 2), dtype=torch.int8, device=weights_4x2.device
+    )
+
+    weights_unpacked[:, ::2] = weights_4x2 >> 4
+    weights_unpacked[:, 1::2] = weights_4x2 & 0x0F
+
+    n_bit = 8
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    dq_weights = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weights_unpacked,
+        scales,
+        None,
+        0,
+        quant_min,
+        quant_max,
+        torch.int8,
+    )
+
+    out = torch.nn.functional.linear(x, dq_weights)
+    out_shape = original_x_shape[:-1] + (out_features,)
+    return out.reshape(out_shape)
+
+
+name = "linear_qcs4w"
+lib.define(f"{name}(Tensor self, Tensor weight, Tensor scales) -> Tensor")
+lib.impl(name, linear_qcs4w, "CompositeExplicitAutograd")
+linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)
+
 ######################
 ## apply_rotary_emb ##
 ######################
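
Importing this module registers the reference implementation with the PyTorch dispatcher, so the new op can be exercised directly, for example as a quick shape check. A small sketch, assuming the library namespace is `et_vk` (as the `exir_ops.edge.et_vk.linear_qcs4w` reference in the fusion pass suggests) and that a build containing this patch is importable:

    import torch

    import executorch.backends.vulkan.custom_ops_lib  # noqa: F401  (registers the et_vk ops)

    M, K, N = 4, 32, 16
    x = torch.rand(M, K)
    # Two 4-bit values per byte, so the packed weight has K // 2 columns.
    packed_weight = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)
    scales = torch.rand(N)

    # Runs the CompositeExplicitAutograd reference implementation defined above.
    out = torch.ops.et_vk.linear_qcs4w(x, packed_weight, scales)
    assert out.shape == (M, N)
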
diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py
index 2ea3e321dc3..b2f1a658040 100644
--- a/backends/vulkan/quantizer/vulkan_quantizer.py
+++ b/backends/vulkan/quantizer/vulkan_quantizer.py
@@ -9,7 +9,7 @@ from __future__ import annotations
 
 import functools
-from typing import Any, Callable, Dict, Optional
+from typing import Callable, Optional
 
 import torch
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
@@ -18,53 +18,60 @@
     propagate_annotation,
     QuantizationConfig,
 )
-from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
-from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
+from torch.ao.quantization.observer import PerChannelMinMaxObserver
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 from torch.fx import Node
 
 
 __all__ = [
     "VulkanQuantizer",
-    "get_weight_quantization_config",
+    "get_linear_weight_qcs_qspec",
+    "get_linear_weight_only_qcs_xnn_qconfig",
 ]
 
 
-@functools.lru_cache
-def get_weight_quantization_config(
-    is_per_channel: bool = True,
-    weight_qmin: int = -128,
-    weight_qmax: int = 127,
-) -> QuantizationConfig:
-
-    weight_qscheme = (
-        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
-    )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
-        PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
-    )
-    extra_args: Dict[str, Any] = {"eps": 2**-12}
+def get_linear_weight_qcs_qspec(quant_bits: int) -> QuantizationSpec:
+    """
+    Return a QuantizationSpec to perform per-channel symmetric (i.e. "qcs") quantization
+    of the weight tensors of linear layers to the number of bits specified by quant_bits.
+    """
+    weight_observer = PerChannelMinMaxObserver
+    assert quant_bits in {
+        8,
+        4,
+    }, f"Unsupported weight quantization bits: {quant_bits}"
 
-    weight_quantization_spec = QuantizationSpec(
+    quant_min = -(2 ** (quant_bits - 1))
+    quant_max = 2 ** (quant_bits - 1) - 1
+    qscheme = torch.per_channel_symmetric
+
+    return QuantizationSpec(
         dtype=torch.int8,
-        quant_min=weight_qmin,
-        quant_max=weight_qmax,
-        qscheme=weight_qscheme,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        qscheme=qscheme,
         ch_axis=0,
         is_dynamic=False,
-        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
-            **extra_args
-        ),
+        observer_or_fake_quant_ctr=weight_observer,
     )
-    quantization_config = QuantizationConfig(
+
+
+@functools.lru_cache
+def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfig:
+    """
+    Return an XNNPACKQuantizer QuantizationConfig instance that specifies quantizing
+    the weight tensors of linear layers using per-channel symmetric (qcs) quantization
+    to the number of bits specified by quant_bits.
+    """
+    weight_qspec = get_linear_weight_qcs_qspec(quant_bits)
+
+    return QuantizationConfig(
         input_activation=None,
         output_activation=None,
-        weight=weight_quantization_spec,
+        weight=weight_qspec,
         bias=None,
         is_qat=False,
     )
-    return quantization_config
 
 
 _SUPPORTED_OPS = [
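
The two helpers slot into the standard PT2E prepare/convert flow. A minimal sketch that mirrors the new tests added further down; the toy model and shapes are illustrative only:

    import torch

    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
        get_linear_weight_only_qcs_xnn_qconfig,
        VulkanQuantizer,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    model = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False)).eval()
    sample_inputs = (torch.rand(8, 64),)

    quantizer = VulkanQuantizer()
    quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4))  # or 8 for int8 weights

    exported = torch.export.export_for_training(model, sample_inputs, strict=True).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*sample_inputs)  # one forward pass lets the weight observers record min/max
    quantized = convert_pt2e(prepared)
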
diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl
index aafc87ad2c3..665fde103fc 100644
--- a/backends/vulkan/targets.bzl
+++ b/backends/vulkan/targets.bzl
@@ -280,6 +280,7 @@ def define_common_targets(is_fbcode = False):
         deps = [
             "//caffe2:torch",
             "//executorch/exir:tensor",
+            "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib",
             "//executorch/backends/vulkan/serialization:lib",
         ]
     )
@@ -332,7 +333,6 @@ def define_common_targets(is_fbcode = False):
         "//executorch/backends/transforms:addmm_mm_to_linear",
         "//executorch/backends/transforms:fuse_batch_norm_with_conv",
         "//executorch/backends/transforms:fuse_conv_with_clamp",
-        "//executorch/backends/transforms:fuse_dequant_linear",
         "//executorch/backends/transforms:fuse_view_copy",
         "//executorch/backends/transforms:remove_clone_ops",
         "//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze",
diff --git a/backends/vulkan/test/TARGETS b/backends/vulkan/test/TARGETS
index 5ac87892762..8f07040d586 100644
--- a/backends/vulkan/test/TARGETS
+++ b/backends/vulkan/test/TARGETS
@@ -24,6 +24,19 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_vulkan_passes",
+    srcs = [
+        "test_vulkan_passes.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan/_passes:vulkan_passes",
+        "//executorch/backends/vulkan/quantizer:vulkan_quantizer",
+        "//executorch/backends/vulkan:vulkan_preprocess",
+    ]
+)
+
 python_unittest(
     name = "test_vulkan_delegate_header",
     srcs = [
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 5fba5ed54cf..b57710974e8 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -15,10 +15,19 @@
 from executorch.backends.transforms.convert_dtype_pass import I64toI32
 
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
+
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
 
-from executorch.exir import EdgeCompileConfig
-from torch.export import Dim, export, ExportedProgram
+from executorch.exir import (
+    EdgeCompileConfig,
+    EdgeProgramManager,
+    ExecutorchProgramManager,
+)
+
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+from torch.ao.quantization.quantizer import Quantizer
+from torch.export import Dim, export, export_for_training, ExportedProgram
 
 ctypes.CDLL("libvulkan.so.1")
 
@@ -30,11 +39,66 @@
 from executorch.extension.pytree import tree_flatten
 
 
-class TestBackends(unittest.TestCase):
-    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
+def lower_module(
+    model: torch.nn.Module, sample_inputs: Tuple[torch.Tensor], dynamic_shapes=None
+) -> EdgeProgramManager:
+    compile_options = {}
+    edge_compile_config = EdgeCompileConfig(
+        _skip_dim_order=False,  # TODO(T182928844): Delegate dim order op to backend.
+    )
+
+    program: ExportedProgram = export(
+        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
+    )
+
+    edge_program = to_edge_transform_and_lower(
+        program,
+        compile_config=edge_compile_config,
+        transform_passes=[
+            I64toI32(edge_compile_config._skip_dim_order),
+        ],
+        partitioner=[VulkanPartitioner(compile_options)],
+    )
+
+    return edge_program
+
+
+def quantize_and_lower_module(
+    model: torch.nn.Module,
+    sample_inputs: Tuple[torch.Tensor],
+    quantizer: Quantizer,
+    dynamic_shapes=None,
+) -> EdgeProgramManager:
+    compile_options = {}
+    edge_compile_config = EdgeCompileConfig(
         _skip_dim_order=False,  # TODO(T182928844): Delegate dim order op to backend.
     )
 
+    program = export_for_training(
+        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
+    ).module()
+
+    program = prepare_pt2e(program, quantizer)  # pyre-ignore
+    # Calibrate
+    program(*sample_inputs)
+
+    program = convert_pt2e(program)
+
+    program = export(program, sample_inputs, dynamic_shapes=dynamic_shapes)
+
+    edge_program = to_edge_transform_and_lower(
+        program,
+        compile_config=edge_compile_config,
+        transform_passes=[
+            I64toI32(edge_compile_config._skip_dim_order),
+        ],
+        partitioner=[VulkanPartitioner(compile_options)],
+    )
+
+    return edge_program
+
+
+class TestVulkanBackend(unittest.TestCase):
     def assert_outputs_equal(
         self,
         model_output,
@@ -88,6 +152,59 @@ def assert_outputs_equal(
                 )
             )
 
+    def check_no_delegation(self, et_program: ExecutorchProgramManager):
+        self.assertEqual(
+            len(et_program.executorch_program.execution_plan[0].delegates),
+            0,
+        )
+        return
+
+    def check_vk_delegation(self, et_program: ExecutorchProgramManager):
+        self.assertEqual(
+            et_program.executorch_program.execution_plan[0].delegates[0].id,
+            VulkanBackend.__name__,
+        )
+
+    def run_delegated_model_and_check_output(
+        self,
+        et_program: ExecutorchProgramManager,
+        model: torch.nn.Module,
+        sample_inputs: Tuple[torch.Tensor],
+        atol=1e-03,
+        rtol=1e-01,
+        test_inputs=None,
+        first_output_only=False,
+    ):
+        executorch_module = _load_for_executorch_from_buffer(et_program.buffer)
+        inputs_flattened, _ = tree_flatten(sample_inputs)
+
+        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
+        ref_output = model(*sample_inputs)
+
+        self.assert_outputs_equal(
+            model_output,
+            ref_output,
+            atol=atol,
+            rtol=rtol,
+            first_output_only=first_output_only,
+        )
+
+        if test_inputs is not None:
+            for test_input in test_inputs:
+                test_inputs_flattened, _ = tree_flatten(test_input)
+                model_output = executorch_module.run_method(
+                    "forward", tuple(test_inputs_flattened)
+                )
+                ref_output = model(*test_input)
+
+                self.assert_outputs_equal(
+                    model_output,
+                    ref_output,
+                    atol=atol,
+                    rtol=rtol,
+                    first_output_only=first_output_only,
+                )
+
     def lower_module_and_test_output(
         self,
         model: torch.nn.Module,
@@ -105,80 +222,29 @@ def lower_module_and_test_output(
         outputs with the outputs of the eager module.
         """
 
-        def run_test():
-            compile_options = {}
+        # Validate that the model can execute in eager mode
+        model.eval()
+        model(*sample_inputs)
 
-            # At least model should run in eager mode.
-            model.eval()
-            model(*sample_inputs)
+        edge_program = lower_module(model, sample_inputs, dynamic_shapes=dynamic_shapes)
 
-            program: ExportedProgram = export(
-                model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
-            )
+        et_program = edge_program.to_executorch()
 
-            edge_program = to_edge_transform_and_lower(
-                program,
-                compile_config=self._edge_compile_config,
-                transform_passes=[
-                    I64toI32(self._edge_compile_config._skip_dim_order),
-                ],
-                partitioner=[VulkanPartitioner(compile_options)],
-            )
-            executorch_program = edge_program.to_executorch()
-
-            if expect_no_delegates:
-                self.assertEqual(
-                    len(
-                        executorch_program.executorch_program.execution_plan[
-                            0
-                        ].delegates
-                    ),
-                    0,
-                )
-                return
-            else:
-                self.assertEqual(
-                    executorch_program.executorch_program.execution_plan[0]
-                    .delegates[0]
-                    .id,
-                    VulkanBackend.__name__,
-                )
-
-            executorch_module = _load_for_executorch_from_buffer(
-                executorch_program.buffer
-            )
-            inputs_flattened, _ = tree_flatten(sample_inputs)
+        if expect_no_delegates:
+            self.check_no_delegation(et_program)
+            return
 
-            model_output = executorch_module.run_method(
-                "forward", tuple(inputs_flattened)
-            )
-            ref_output = model(*sample_inputs)
-
-            self.assert_outputs_equal(
-                model_output,
-                ref_output,
-                atol=atol,
-                rtol=rtol,
-                first_output_only=first_output_only,
-            )
-
-            if test_inputs is not None:
-                for test_input in test_inputs:
-                    test_inputs_flattened, _ = tree_flatten(test_input)
-                    model_output = executorch_module.run_method(
-                        "forward", tuple(test_inputs_flattened)
-                    )
-                    ref_output = model(*test_input)
+        self.check_vk_delegation(et_program)
 
-                    self.assert_outputs_equal(
-                        model_output,
-                        ref_output,
-                        atol=atol,
-                        rtol=rtol,
-                        first_output_only=first_output_only,
-                    )
-
-        run_test()
+        self.run_delegated_model_and_check_output(
+            et_program,
+            model,
+            sample_inputs,
+            atol,
+            rtol,
+            test_inputs=test_inputs,
+            first_output_only=first_output_only,
+        )
 
     def test_vulkan_backend_add(self):
         # This test is the simplest test by manually lowering some submodules, we can use paritioner
@@ -942,6 +1008,7 @@ def forward(self, x):
                 sample_inputs,
             )
 
+    @unittest.skip("layer norm compute shader not working with swiftshader")
     def test_vulkan_backend_native_layer_norm(self):
         class NativeLayerNormModule(torch.nn.Module):
             def __init__(self):
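
Outside the unittest harness, the same helpers chain straight into delegation. A sketch continuing from the quantization example after the Vulkan quantizer changes above (`quantized` and `sample_inputs` come from that sketch); during `vulkan_preprocess`, the `FuseQuantizedOpsTransform(program)` pass then rewrites the dequantize -> linear pattern into the fused weight-only quantized ops:

    import torch

    from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
    from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

    # Re-export the converted module and lower it to the Vulkan delegate.
    program = torch.export.export(quantized, sample_inputs)
    edge_program = to_edge_transform_and_lower(
        program,
        compile_config=EdgeCompileConfig(_skip_dim_order=False),
        partitioner=[VulkanPartitioner({})],
    )
    et_program = edge_program.to_executorch()
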
diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py
new file mode 100644
index 00000000000..7572ebd5a5a
--- /dev/null
+++ b/backends/vulkan/test/test_vulkan_passes.py
@@ -0,0 +1,151 @@
+import unittest
+from typing import Optional, Tuple
+
+import torch
+
+from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
+from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform
+
+from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
+    get_linear_weight_only_qcs_xnn_qconfig,
+    VulkanQuantizer,
+)
+
+from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
+
+from executorch.exir.backend.canonical_partitioners.config_partitioner import (
+    format_target_name,
+)
+
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer import Quantizer
+
+###################
+## Common Models ##
+###################
+
+
+class SingleLinearModule(torch.nn.Module):
+    def __init__(self, K=256, N=128):
+        super().__init__()
+        self.K = K
+        self.N = N
+        self.linear = torch.nn.Linear(K, N, bias=False)
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def get_sample_inputs(self):
+        sample_inputs = (torch.rand(size=(32, self.K), dtype=torch.float32),)
+        return sample_inputs
+
+
+###########
+## Tests ##
+###########
+
+
+def quantize_and_lower_module(
+    model: torch.nn.Module,
+    sample_inputs: Tuple[torch.Tensor],
+    quantizer: Quantizer,
+    dynamic_shapes=None,
+) -> EdgeProgramManager:
+    edge_compile_config = EdgeCompileConfig(
+        _skip_dim_order=False,  # TODO(T182928844): Delegate dim order op to backend.
+        _check_ir_validity=False,
+    )
+
+    program = torch.export.export_for_training(
+        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
+    ).module()
+
+    program = prepare_pt2e(program, quantizer)  # pyre-ignore
+    # Calibrate
+    program(*sample_inputs)
+
+    program = convert_pt2e(program)
+
+    program = torch.export.export(program, sample_inputs, dynamic_shapes=dynamic_shapes)
+
+    edge_program = to_edge(
+        program,
+        compile_config=edge_compile_config,
+    )
+
+    return edge_program
+
+
+def get_target_canonical_name(node: torch.fx.Node) -> Optional[str]:
+    if node.op != "call_function":
+        return None
+    node_name = format_target_name(node.target.__name__)  # pyre-ignore
+    return node_name
+
+
+def op_node_count(graph_module: torch.fx.GraphModule, canonical_op_name: str) -> int:
+    count = 0
+    for node in graph_module.graph.nodes:
+        canonical_name = get_target_canonical_name(node)
+        if canonical_name is not None and canonical_name == canonical_op_name:
+            count += 1
+    return count
+
+
+class TestVulkanPasses(unittest.TestCase):
+
+    def test_fuse_int8pack_mm(self):
+        K = 256
+        N = 256
+        model = SingleLinearModule(K, N)
+        sample_inputs = model.get_sample_inputs()
+
+        quantizer = VulkanQuantizer()
+        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(8))
+
+        edge_manager = quantize_and_lower_module(
+            model,
+            sample_inputs,
+            quantizer,
+        )
+
+        ep = edge_manager._edge_programs["forward"]
+        edge_manager.transform(
+            [
+                AddmmToLinearTransform(),
+                FuseQuantizedOpsTransform(ep),
+            ]
+        )
+
+        gm = ep.graph_module
+
+        self.assertEqual(op_node_count(gm, "_weight_int8pack_mm.default"), 1)
+        self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0)
+
+    def test_fuse_linear_qcs4w(self):
+        K = 256
+        N = 256
+        model = SingleLinearModule(K, N)
+        sample_inputs = model.get_sample_inputs()
+
+        quantizer = VulkanQuantizer()
+        quantizer.set_global(get_linear_weight_only_qcs_xnn_qconfig(4))
+
+        edge_manager = quantize_and_lower_module(
+            model,
+            sample_inputs,
+            quantizer,
+        )
+
+        ep = edge_manager._edge_programs["forward"]
+        edge_manager.transform(
+            [
+                AddmmToLinearTransform(),
+                FuseQuantizedOpsTransform(ep),
+            ]
+        )
+
+        gm = ep.graph_module
+
+        self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1)
+        self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0)
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index fa032cd7b4f..eb949a6ace8 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -14,6 +14,10 @@
     VkStorageType,
 )
 
+from executorch.exir.backend.canonical_partitioners.config_partitioner import (
+    format_target_name,
+)
+
 from executorch.exir.tensor import TensorSpec
 
 from torch._export.utils import is_buffer, is_param
@@ -22,11 +26,44 @@
 
 from torch.export import ExportedProgram
 
+from torch.export.exported_program import InputKind
+from torch.export.graph_signature import TensorArgument
+
+_DQ_OPS = {
+    "dequantize_per_tensor.tensor",
+    "dequantize_per_tensor.default",
+    "dequantize_per_channel.default",
+    "dequantize_per_channel_group.default",
+    "dequantize_per_token.default",
+    "dequantize_affine.default",
+}
+
 ##
 ## Node type determination
 ##
 
 
+def is_dequant_node(node: torch.fx.Node) -> bool:
+    if node.op != "call_function":
+        return False
+    node_name = format_target_name(node.target.__name__)  # pyre-ignore
+    return node_name in _DQ_OPS
+
+
+def is_dequant_per_channel_node(node: torch.fx.Node) -> bool:
+    if node.op != "call_function":
+        return False
+    node_name = format_target_name(node.target.__name__)  # pyre-ignore
+    return node_name == "dequantize_per_channel.default"
+
+
+def is_linear_node(node: torch.fx.Node) -> bool:
+    if node.op != "call_function":
+        return False
+    node_name = format_target_name(node.target.__name__)  # pyre-ignore
+    return node_name == "linear.default"
+
+
 def is_get_attr_node(node: torch.fx.Node) -> bool:
     return isinstance(node, torch.fx.Node) and node.op == "get_attr"
 
@@ -258,3 +295,35 @@ def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
 
 def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
     return get_node_spec_attr(node, "vk_memory_layout")
+
+
+##
+## Misc
+##
+
+
+def update_program_state_dict(
+    program: ExportedProgram,
+    buffer_name: str,
+    updated_tensor: torch.Tensor,
+) -> None:
+    target_name = None
+    # Iterate over all the tensors in the graph signature, and find
+    # the one corresponding to the parameter/buffer name
+    for input_ in program.graph_signature.input_specs:
+        if (
+            input_.kind in (InputKind.BUFFER, InputKind.PARAMETER)
+            and isinstance(input_.arg, TensorArgument)
+            and input_.arg.name == buffer_name
+        ):
+            target_name = input_.target
+            break
+
+    # Assert that we found the parameter/buffer
+    assert (
+        target_name is not None
+    ), f"could not find {buffer_name} in source program signature"
+    assert target_name in program.state_dict, f"could not find {target_name}"
+
+    # Finally, overwrite the current tensor with updated tensor
+    program.state_dict[target_name] = updated_tensor
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 188311e5f2c..4200df3e131 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -17,12 +17,12 @@
     FuseBatchNormWithConvPass,
 )
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
-from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import (
     ViewCopyToSqueezeUnsqueezePass,
 )
 from executorch.backends.vulkan._passes import (
+    FuseQuantizedOpsTransform,
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
     RemoveRedundantOpsTransform,
@@ -152,7 +152,7 @@ def preprocess(  # noqa: C901
             [
                 RemoveRedundantOpsTransform(),
                 AddmmToLinearTransform(),
-                FuseDequantLinearPass(),
+                FuseQuantizedOpsTransform(program),
                 SqueezeUnsqueezeInputs(),
                 FuseViewCopyTransform(),
                 ViewCopyToSqueezeUnsqueezePass(),
diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py
index 24c3be2e802..d7b8b3a92b1 100644
--- a/extension/llm/export/quantizer_lib.py
+++ b/extension/llm/export/quantizer_lib.py
@@ -266,16 +266,12 @@ def get_coreml_quantizer(pt2e_quantize: str):
 
 def get_vulkan_quantizer(pt2e_quantize: str):
     from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
-        get_weight_quantization_config,
+        get_linear_weight_only_qcs_xnn_qconfig,
         VulkanQuantizer,
     )
 
     if pt2e_quantize == "vulkan_8w":
-        config = get_weight_quantization_config(
-            is_per_channel=True,
-            weight_qmin=-128,
-            weight_qmax=127,
-        )
+        config = get_linear_weight_only_qcs_xnn_qconfig(8)
     else:
        raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")
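
The LLM export path keeps working through the same `vulkan_8w` flag. A short usage sketch, assuming `get_vulkan_quantizer` returns the configured `VulkanQuantizer` as the neighboring quantizer helpers in this file do:

    from executorch.extension.llm.export.quantizer_lib import get_vulkan_quantizer

    # "vulkan_8w" now maps to 8-bit per-channel symmetric weight-only quantization.
    quantizer = get_vulkan_quantizer("vulkan_8w")
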