diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 5c59f13fc24..a137a7d538f 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -56,52 +56,97 @@ TYPE_MAPPINGS: Dict[str, Any] = { "IMAGE_T": { 3: { + "double": "image3D", "float": "image3D", "half": "image3D", - "int": "iimage3D", - "uint": "uimage3D", + # integer dtypes "int8": "iimage3D", "uint8": "uimage3D", + "int16": "iimage3D", + "uint16": "uimage3D", + "int32": "iimage3D", + "uint32": "uimage3D", + "int64": "iimage3D", + "uint64": "uimage3D", + # common dtype aliases "bool": "uimage3D", + "int": "iimage3D", + "uint": "uimage3D", }, 2: { + "double": "image2D", "float": "image2D", "half": "image2D", - "int": "iimage2D", - "uint": "uimage2D", + # integer dtypes "int8": "iimage2D", "uint8": "uimage2D", + "int16": "iimage2D", + "uint16": "uimage2D", + "int32": "iimage2D", + "uint32": "uimage2D", + "int64": "iimage2D", + "uint64": "uimage2D", + # common dtype aliases "bool": "uimage2D", + "int": "iimage2D", + "uint": "uimage2D", }, }, "SAMPLER_T": { 3: { + "double": "sampler3D", "float": "sampler3D", "half": "sampler3D", - "int": "isampler3D", - "uint": "usampler3D", + # integer dtypes "int8": "isampler3D", "uint8": "usampler3D", + "int16": "isampler3D", + "uint16": "usampler3D", + "int32": "isampler3D", + "uint32": "usampler3D", + "int64": "isampler3D", + "uint64": "usampler3D", + # common dtype aliases "bool": "usampler3D", + "int": "isampler3D", + "uint": "usampler3D", }, 2: { + "double": "sampler2D", "float": "sampler2D", "half": "sampler2D", - "int": "isampler2D", - "uint": "usampler2D", + # integer dtypes "int8": "isampler2D", "uint8": "usampler2D", + "int16": "isampler2D", + "uint16": "usampler2D", + "int32": "isampler2D", + "uint32": "usampler2D", + "int64": "isampler2D", + "uint64": "usampler2D", + # common dtype aliases "bool": "usampler2D", + "int": "isampler2D", + "uint": "usampler2D", }, }, "IMAGE_FORMAT": { + "double": "rgba32f", "float": "rgba32f", "half": "rgba16f", - "int": "rgba32i", - "uint": "rgba32ui", + # integer dtypes "int8": "rgba8i", "uint8": "rgba8ui", + "int16": "rgba16i", + "uint16": "rgba16ui", + "int32": "rgba32i", + "uint32": "rgba32ui", + "int64": "rgba32i", + "uint64": "rgba32ui", + # common dtype aliases "bool": "rgba8ui", + "int": "rgba32i", + "uint": "rgba32ui", }, } @@ -118,10 +163,18 @@ def define_variable(name: str) -> str: def buffer_scalar_type(dtype: str) -> str: if dtype == "half": return "float16_t" - elif dtype[-1] == "8": - return dtype + "_t" + elif dtype == "float": + return "float" + elif dtype == "double": + return "float64_t" + # integer dtype alias conversion elif dtype == "bool": return "uint8_t" + # we don't want to append _t for int32 or uint32 as int is already 32bit + elif dtype == "int32" or dtype == "uint32": + return "int" if dtype == "int32" else "uint" + elif dtype[-1].isdigit(): + return dtype + "_t" return dtype @@ -129,22 +182,28 @@ def buffer_gvec_type(dtype: str, n: int) -> str: if n == 1: return buffer_scalar_type(dtype) - if dtype == "float": - return f"vec{n}" - if dtype == "uint": - return f"uvec{n}" - elif dtype == "half": - return f"f16vec{n}" - elif dtype == "int": - return f"ivec{n}" - elif dtype == "int8": - return f"i8vec{n}" - elif dtype == "uint8": - return f"u8vec{n}" - elif dtype == "bool": - return f"u8vec{n}" - - raise AssertionError(f"Invalid dtype: {dtype}") + dtype_map = { + "half": f"f16vec{n}", + "float": f"vec{n}", + "double": f"vec{n}", # No 
64bit image format support in GLSL + "int8": f"i8vec{n}", + "uint8": f"u8vec{n}", + "int16": f"i16vec{n}", + "uint16": f"u16vec{n}", + "int32": f"ivec{n}", + "int": f"ivec{n}", + "uint32": f"uvec{n}", + "uint": f"uvec{n}", + "int64": f"ivec{n}", # No 64bit image format support in GLSL + "uint64": f"uvec{n}", # No 64bit image format support in GLSL + "bool": f"u8vec{n}", + } + + vector_type = dtype_map.get(dtype) + if vector_type is None: + raise AssertionError(f"Invalid dtype: {dtype}") + + return vector_type def texel_type(dtype: str) -> str: @@ -365,15 +424,22 @@ def define_required_extensions(dtypes: Union[str, List[str]]): if dtype == "half": nbit = "16bit" glsl_type = "float16" - elif dtype == "int16" or dtype == "uint16": - nbit = "16bit" - glsl_type = "int16" - elif dtype == "int8" or dtype == "uint8" or dtype == "bool": + elif dtype == "double": + # We only need to allow float64_t type usage + glsl_type = "float64" + elif dtype in ["int8", "uint8", "bool"]: nbit = "8bit" glsl_type = "int8" + elif dtype in ["int16", "uint16"]: + nbit = "16bit" + glsl_type = "int16" + elif dtype in ["int64", "uint64"]: + # We only need to allow int64_t and uint64_t type usage + glsl_type = "int64" - if nbit is not None and glsl_type is not None: + if nbit is not None: out_str += f"#extension GL_EXT_shader_{nbit}_storage : require\n" + if glsl_type is not None: out_str += f"#extension GL_EXT_shader_explicit_arithmetic_types_{glsl_type} : require\n" return out_str @@ -629,6 +695,10 @@ def generateVariantCombinations( elif "VALUE" in value: suffix = value.get("SUFFIX", value["VALUE"]) + if value["VALUE"] in ["int", "uint"]: + raise ValueError( + f"Use int32 or uint32 instead of {value['VALUE']}" + ) param_values.append((param_name, suffix, value["VALUE"])) else: diff --git a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml index e3df8bf73a1..37b2027db85 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/arange.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/arange.yaml @@ -7,13 +7,13 @@ arange: parameter_names_with_default_values: NDIM: 3 - DTYPE: int + DTYPE: int32 STORAGE: texture3d PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: arange diff --git a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml index eddddec0d8d..b1e16dec8d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/avg_pool2d.yaml @@ -13,6 +13,6 @@ avg_pool2d: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: avg_pool2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index c0efdd81eb9..accfcf53599 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -17,7 +17,7 @@ binary_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: binary_add - NAME: binary_sub diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml index 9abd9c1deac..e8bb86dbf6a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_buffer.yaml @@ -12,8 +12,9 @@ buffer_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: 
int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index e48eab63a64..679e686dc2f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -13,9 +13,10 @@ buffer_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: buffer_to_nchw - NAME: buffer_to_nchw_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml index 414bf8191b9..984d9a09d43 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_channel_offset.yaml @@ -7,6 +7,6 @@ copy_channel_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_channel_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml index 87df7bf9dc1..09f5ca36ea4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml @@ -7,7 +7,7 @@ copy_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml index e872d64e3c3..6e55876cb28 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/copy_packed_dim_offset.yaml @@ -7,6 +7,6 @@ copy_packed_dim_offset: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: copy_packed_dim_offset diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml index 5ffe37265b1..0e7b491c433 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding.yaml @@ -7,6 +7,6 @@ embedding: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: embedding diff --git a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml index 646fd05e420..f5e7c874773 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/flip.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/flip.yaml @@ -6,8 +6,9 @@ flip: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: flip diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 804ce19bdb8..646d8f1be81 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -14,9 +14,10 @@ image_to_nchw: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index 5a6c525993e..abef2225cd9 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -7,6 +7,6 @@ index_select: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index 66cb7ec3f89..a306e3ce47d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -7,6 +7,6 @@ index_select_channel: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 486d710cf55..99e41a0ab6f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -13,9 +13,10 @@ nchw_to_buffer: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_buffer - NAME: nchw_to_buffer_no_pc diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 4674822ce6a..f3f604e10cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -87,5 +87,9 @@ void main() { return; } - write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); + $if DTYPE == "double" or DTYPE == "int64": + VEC4_T texel = read_texel(tidx); + write_texel(t_out, lpos_to_pos(lpos, axis_map), texel); + $else: + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 7e52ec10376..85119c8d508 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -14,9 +14,10 @@ nchw_to_image: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: double - VALUE: int8 - VALUE: uint8 + - VALUE: int32 shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml index e64e1bd260a..bfeaba2496b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/no_op.yaml @@ -12,7 +12,7 @@ no_op: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 - VALUE: uint8 STORAGE: diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml index f678aeedf6e..a90ddcb41ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.yaml @@ -7,6 +7,6 @@ permute: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: permute diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml index 526980a0f41..f40d94142e1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat.yaml @@ -7,7 +7,7 @@ repeat: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + -
VALUE: int32 - VALUE: int8 - VALUE: uint8 shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index f13393ce6c7..47f538aee6c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -15,9 +15,9 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) - - NAME: clamp_int + - NAME: clamp_int32 OPERATOR: clamp(X, A, B) - DTYPE: int + DTYPE: int32 - NAME: cos OPERATOR: cos(X) - NAME: exp diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index ba11a2496a0..33364a25225 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -7,6 +7,6 @@ view: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index e1ac4e9d40a..6388a8ad091 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -34,24 +34,42 @@ void add_storage_type_suffix( void add_dtype_suffix(std::string& kernel_name, const vkapi::ScalarType dtype) { switch (dtype) { + case vkapi::kDouble: + kernel_name += "_double"; + break; case vkapi::kFloat: kernel_name += "_float"; break; case vkapi::kHalf: kernel_name += "_half"; break; - case vkapi::kInt: - kernel_name += "_int"; - break; case vkapi::kChar: case vkapi::kQInt8: kernel_name += "_int8"; break; case vkapi::kByte: - case vkapi::kQUInt8: case vkapi::kBool: + case vkapi::kQUInt8: kernel_name += "_uint8"; break; + case vkapi::kShort: + kernel_name += "_int16"; + break; + case vkapi::kUInt16: + kernel_name += "_uint16"; + break; + case vkapi::kInt: + kernel_name += "_int32"; + break; + case vkapi::kUInt: + kernel_name += "_uint32"; + break; + case vkapi::kLong: + kernel_name += "_int64"; + break; + case vkapi::kUInt64: + kernel_name += "_uint64"; + break; default: break; } diff --git a/backends/vulkan/runtime/vk_api/Types.h b/backends/vulkan/runtime/vk_api/Types.h index f25fe95d72b..b3309aa6c69 100644 --- a/backends/vulkan/runtime/vk_api/Types.h +++ b/backends/vulkan/runtime/vk_api/Types.h @@ -30,11 +30,17 @@ #define VK_FORALL_SCALAR_TYPES(_) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Bool) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_UINT, UInt16) \ + _(int16_t, VK_FORMAT_R16G16B16A16_SINT, Short) \ + _(uint32_t, VK_FORMAT_R32G32B32A32_UINT, UInt) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ + _(uint64_t, VK_FORMAT_R64G64B64A64_UINT, UInt64) \ + _(int64_t, VK_FORMAT_R64G64B64A64_SINT, Long) \ _(float, VK_FORMAT_FLOAT4, Float) \ + _(double, VK_FORMAT_R64G64B64A64_SFLOAT, Double) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) @@ -86,17 +92,29 @@ inline VkFormat to_vkformat(const ScalarType t) { */ inline ScalarType element_scalartype(const VkFormat vkformat) { switch (vkformat) { + case VK_FORMAT_R64G64B64A64_SFLOAT: + return kDouble; + case VK_FORMAT_R32G32B32A32_SFLOAT: + return kFloat; + case VK_FORMAT_R16G16B16A16_SFLOAT: 
+ return kHalf; case VK_FORMAT_R8G8B8A8_SINT: return kChar; case VK_FORMAT_R8G8B8A8_UINT: case VK_FORMAT_R8G8B8A8_UNORM: return kByte; + case VK_FORMAT_R16G16B16A16_SINT: + return kShort; + case VK_FORMAT_R16G16B16A16_UINT: + return kUInt16; case VK_FORMAT_R32G32B32A32_SINT: return kInt; - case VK_FORMAT_R32G32B32A32_SFLOAT: - return kFloat; - case VK_FORMAT_R16G16B16A16_SFLOAT: - return kHalf; + case VK_FORMAT_R32G32B32A32_UINT: + return kUInt; + case VK_FORMAT_R64G64B64A64_SINT: + return kLong; + case VK_FORMAT_R64G64B64A64_UINT: + return kUInt64; default: VK_THROW("No corresponding scalar type for unknown VkFormat: ", vkformat); } diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index 37403c97ac8..4ef934eb105 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -51,7 +51,7 @@ idx_fill_texture: DTYPE: - VALUE: half - VALUE: float - - VALUE: int + - VALUE: int32 - VALUE: int8 shader_variants: - NAME: idx_fill_texture diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp new file mode 100644 index 00000000000..24c856e9d46 --- /dev/null +++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp @@ -0,0 +1,675 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <gtest/gtest.h> + +#include <ATen/ATen.h> + +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> +#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> + +#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> +#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> + +#include "test_utils.h" + +#include <cassert> +#include <iostream> + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out); + +// Wrapper function for choose_qparams_tensor_out without context +Tensor& choose_qparams_tensor_out_no_context( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ET_UNUSED double eps, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_tensor_out( + input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); + return scale_out; +} + +// Wrapper function for choose_qparams_per_token_asymmetric_out without context +Tensor& choose_qparams_per_token_asymmetric_out_no_context( + const Tensor& input, + ScalarType dtype, + Tensor& scale_out, + Tensor& zero_point_out) { + torch::executor::native::choose_qparams_per_token_asymmetric_out( + input, dtype, scale_out, zero_point_out); + return scale_out; +} + +// ATen wrapper for choose_qparams_tensor +std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_aten( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto scale_out = at::empty({}, 
at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + double eps = 1e-7; + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5) + (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +// ATen wrapper for choose_qparams_per_token_asymmetric +std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten( + const at::Tensor& input, + at::ScalarType dtype) { + // Calculate output sizes for scale and zero_point tensors + std::vector<int64_t> output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + auto scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + auto zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + // Use WRAP_TO_ATEN with the wrapper function + WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2) + (input, et_dtype, scale_out, zero_point_out); + + return {scale_out, zero_point_out}; +} + +} // namespace native +} // namespace executor +} // namespace torch + +// +// Reference Implementation +// + +/* + * Reference implementation of choose_qparams_tensor + */ +std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl( + const at::Tensor& input, + int64_t quant_min, + int64_t quant_max) { + // Create output tensors + at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty({}, at::device(at::kCPU).dtype(at::kLong)); + + // Find min and max values in the input tensor + float min_val = input.min().item<float>(); + float max_val = input.max().item<float>(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast<double>(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = quant_min - min_val / static_cast<double>(scale); + double zero_point_from_max = quant_max - max_val / static_cast<double>(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? 
zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = std::nearbyint(static_cast<float>(initial_zero_point)); + } + + // Set output values - use fill_() for scalar (zero-dim) tensors + scale_out.fill_(scale); + zero_point_out.fill_(nudged_zero_point); + + return std::make_tuple(scale_out, zero_point_out); +} + +/* + * Reference implementation of choose_qparams_per_token_asymmetric + */ +std::tuple<at::Tensor, at::Tensor> +choose_qparams_per_token_asymmetric_reference_impl( + const at::Tensor& input, + at::ScalarType dtype) { + // For per-token quantization, we need to compute scale and zero_point for + // each token + int64_t quant_min = -128; + int64_t quant_max = 127; + + // Calculate output sizes + std::vector<int64_t> output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Create output tensors + at::Tensor scale_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_out = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); + + // Calculate number of tokens + int64_t num_tokens = 1; + for (int64_t i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + + // Process each token + for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) { + at::Tensor token = reshaped_input[token_idx]; + + // Find min and max values for this token + float min_val = token.min().item<float>(); + float max_val = token.max().item<float>(); + + // Extend the [min, max] interval to ensure it contains 0 + min_val = std::min(min_val, 0.f); + max_val = std::max(max_val, 0.f); + + // Calculate scale + double scale = + (static_cast<double>(max_val) - min_val) / (quant_max - quant_min); + + // Handle small scale + constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + if (scale < SMALL_SCALE_THRESHOLD) { + float org_scale = scale; + scale = SMALL_SCALE_THRESHOLD; + // Adjust min and max based on new scale + if (min_val == 0.0f) { + max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else if (max_val == 0.0f) { + min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / org_scale; + min_val *= amplifier; + max_val *= amplifier; + } + } + + // Calculate zero point + double zero_point_from_min = + quant_min - min_val / static_cast<double>(scale); + double zero_point_from_max = + quant_max - max_val / static_cast<double>(scale); + double zero_point_from_min_error = + std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale)); + double zero_point_from_max_error = + std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale)); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ?
zero_point_from_min + : zero_point_from_max; + + // Nudge zero point to be an integer + int64_t nudged_zero_point = 0; + if (initial_zero_point < quant_min) { + nudged_zero_point = quant_min; + } else if (initial_zero_point > quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = + std::nearbyint(static_cast<float>(initial_zero_point)); + } + + // Set output values for this token - use index_put_ for safety + scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale); + zero_point_out.view({num_tokens, 1}) + .index_put_({token_idx, 0}, nudged_zero_point); + } + + return std::make_tuple(scale_out, zero_point_out); +} + +// Forward declaration of implementation functions +void test_vulkan_choose_qparams_tensor_impl( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector<int>& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_tensor( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_tensor_impl( + input_sizes, + quant_min, + quant_max, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_choose_qparams_per_token_asymmetric( + const std::vector<int>& input_sizes, + at::ScalarType dtype) { + // Test with buffer storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_choose_qparams_per_token_asymmetric_impl( + input_sizes, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_choose_qparams_tensor( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_tensor_reference_impl(input, quant_min, quant_max); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << 
"implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_tensor_impl( + const std::vector<int>& input_sizes, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_tensor_aten( + input, quant_min, quant_max, dtype); + + // Build Vulkan choose_qparams_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + // Output tensors + const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams.tensor") + (graph, + { + r_input.value, + r_quant_min, + r_quant_max, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous(); + at::Tensor vk_zero_point = + at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; + + // make sure that there arent a ton of elements in the input tensor + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { + test_reference_choose_qparams_tensor( + {2, 3, 4}, // input sizes + -128, // quant_min + 127, // quant_max + at::kChar); +} + +void test_reference_choose_qparams_per_token_asymmetric( + const std::vector<int>& input_sizes, + at::ScalarType dtype) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Get reference output + auto [reference_scale, reference_zero_point] = + choose_qparams_per_token_asymmetric_reference_impl(input, dtype); + + // Get implementation output + auto [impl_scale, impl_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale, impl_scale); + const bool zero_point_correct = + at::equal(reference_zero_point, impl_zero_point); + + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "implementation scale:" << std::endl; + std::cout << impl_scale << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "implementation zero_point:" << std::endl; + std::cout << impl_zero_point << std::endl; + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +void test_vulkan_choose_qparams_per_token_asymmetric_impl( + const std::vector<int>& input_sizes, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); + + // Calculate output sizes + std::vector<int64_t> output_sizes; + for (int64_t i = 0; i < input.dim() - 1; i++) { + output_sizes.push_back(input.size(i)); + } + output_sizes.push_back(1); + + // Get reference output + auto [reference_scale, reference_zero_point] = + torch::executor::native::choose_qparams_per_token_asymmetric_aten( + input, dtype); + + // Build Vulkan choose_qparams_per_token_asymmetric graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Output tensors + const ValueRef r_scale = + graph.add_tensor(output_sizes, vkapi::kFloat, out_storage); + const ValueRef r_zero_point = + graph.add_tensor(output_sizes, 
vkapi::kInt, out_storage); + + VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + }); + + ValueRef staging_scale = graph.set_output_tensor(r_scale); + ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan choose_qparams_per_token_asymmetric + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + // Create output tensors to hold the results - use types that match GPU output + at::Tensor vk_scale = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat)) + .contiguous(); + at::Tensor vk_zero_point = + at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt)) + .contiguous(); + + // Copy results from GPU to CPU + graph.copy_from_staging( + staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); + graph.copy_from_staging( + staging_zero_point, + vk_zero_point.mutable_data_ptr(), + vk_zero_point.numel()); + + // Convert reference values to match Vulkan output types for comparison + at::Tensor reference_scale_float = reference_scale.to(at::kFloat); + at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); + + // Compare outputs + const bool scale_correct = at::allclose(reference_scale_float, vk_scale); + const bool zero_point_correct = + at::equal(reference_zero_point_int, vk_zero_point); + if (!scale_correct || !zero_point_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + if (input.numel() < 100) { + std::cout << "input:" << std::endl; + std::cout << input << "\n" << std::endl; + std::cout << "reference scale:" << std::endl; + std::cout << reference_scale << std::endl; + std::cout << "vulkan scale:" << std::endl; + std::cout << vk_scale << "\n" << std::endl; + std::cout << "reference zero_point:" << std::endl; + std::cout << reference_zero_point << std::endl; + std::cout << "vulkan zero_point:" << std::endl; + std::cout << vk_zero_point << std::endl; + } + } + + ASSERT_TRUE(scale_correct && zero_point_correct); +} + +TEST( + VulkanChooseQparamsTest, + test_reference_choose_qparams_per_token_asymmetric_int8) { + test_reference_choose_qparams_per_token_asymmetric( + {2, 3, 4}, // input sizes (2*3=6 tokens) + at::kChar); +} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp new file mode 100644 index 00000000000..7b155c8f98b --- /dev/null +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -0,0 +1,1061 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include <gtest/gtest.h> + +#include <ATen/ATen.h> + +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> +#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> + +#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> +#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> + +#include "test_utils.h" + +#include <cassert> +#include <iostream> +#include <limits> + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional<ScalarType> out_dtype, + Tensor& out); + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out); + +// Wrapper function for dequantize_per_tensor_out without context +Tensor& dequantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + executorch::aten::optional<ScalarType> out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +// Wrapper function for dequantize_per_token_out without context +Tensor& dequantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) { + return torch::executor::native::dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +// ATen wrapper for dequantize_per_tensor +at::Tensor dequantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype); + + WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) + (input, + scale, + zero_point, + quant_min, + quant_max, + et_dtype, + opt_et_out_dtype, + out); + return out; +} + +// ATen wrapper for dequantize_per_token +at::Tensor dequantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + auto out = at::empty_like(input, out_dtype); + // Convert at::ScalarType to executorch::ScalarType + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); + + WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) + (input, + scale, + zero_points, + quant_min, + quant_max, + et_dtype, + et_out_dtype, + out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_dequantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType in_dtype, + c10::ScalarType out_dtype) { 
+ using namespace vkcompute; + + // Check that quant_min <= quant_max + VK_CHECK_COND( + quant_min <= quant_max, + "quant_min must be <= quant_max, got quant_min: ", + quant_min, + " quant_max: ", + quant_max); + + // Check that input dtype is a quantized type + switch (in_dtype) { + case c10::kByte: + case c10::kChar: + case c10::kShort: + case c10::kInt: + case c10::kLong: + break; + default: + VK_THROW( + "Unsupported input dtype: ", + scalar_type_name(in_dtype), + " (", + static_cast<int>(in_dtype), + ")"); + } + + // Check that output dtype is a floating point type + switch (out_dtype) { + case c10::kHalf: + case c10::kFloat: + case c10::kDouble: + break; + default: + VK_THROW( + "Unsupported output dtype: ", + scalar_type_name(out_dtype), + " (", + static_cast<int>(out_dtype), + ")"); + } +} + +// +// Reference Implementation +// + +/* + * Reference implementation of dequantize_per_tensor + */ +at::Tensor dequantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Dequantize the input tensor + at::Tensor flat_input = input.flatten(); + at::Tensor flat_out = out.flatten(); + + // Store casted values to avoid repeated casting + const int32_t zero_point_int32 = static_cast<int32_t>(zero_point); + const float scale_float = static_cast<float>(scale); + + for (int i = 0; i < flat_input.numel(); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = flat_input[i].item<uint8_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kChar) { + int8_t qvalue = flat_input[i].item<int8_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kShort) { + int16_t qvalue = flat_input[i].item<int16_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kInt) { + int32_t qvalue = flat_input[i].item<int32_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } else if (dtype == at::kLong) { + int64_t qvalue = flat_input[i].item<int64_t>(); + dequantized_value = (qvalue - zero_point_int32) * scale_float; + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + flat_out[i] = static_cast<float>(dequantized_value); + } else if (out_dtype == at::kDouble) { + flat_out[i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + flat_out[i] = static_cast<c10::Half>(dequantized_value); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of dequantize_per_token + */ +at::Tensor dequantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, out_dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == 
zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Dequantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Get scale and zero_point for this token + float token_scale = scale[token_idx].item<float>(); + int64_t token_zero_point = zero_point[token_idx].item<int64_t>(); + + // Store casted values to avoid repeated casting + const int32_t token_zero_point_int32 = + static_cast<int32_t>(token_zero_point); + + // Dequantize the token + for (int i = 0; i < input.size(-1); i++) { + double dequantized_value = 0.0; + + // Extract quantized value and dequantize based on input dtype + // Following the CPU implementation pattern: (input - zero_point) * scale + if (dtype == at::kByte) { + uint8_t qvalue = reshaped_input[token_idx][i].item<uint8_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kChar) { + int8_t qvalue = reshaped_input[token_idx][i].item<int8_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kShort) { + int16_t qvalue = reshaped_input[token_idx][i].item<int16_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kInt) { + int32_t qvalue = reshaped_input[token_idx][i].item<int32_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else if (dtype == at::kLong) { + int64_t qvalue = reshaped_input[token_idx][i].item<int64_t>(); + dequantized_value = (qvalue - token_zero_point_int32) * token_scale; + } else { + throw std::runtime_error("Unsupported input dtype"); + } + + // Store result based on output dtype + if (out_dtype == at::kFloat) { + reshaped_out[token_idx][i] = static_cast<float>(dequantized_value); + } else if (out_dtype == at::kDouble) { + reshaped_out[token_idx][i] = dequantized_value; + } else if (out_dtype == at::kHalf) { + reshaped_out[token_idx][i] = static_cast<c10::Half>(dequantized_value); + } + } + } + + return out; +} + +// Forward declaration of implementation functions +void test_vulkan_dequantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_dequantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + dtype, + out_dtype, + 
vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + // Test with buffer storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_dequantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast<int64_t>(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = dequantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + 
std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + + // Create a quantized input tensor with values from quant_min to quant_max + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + float step = 1.0f; + if (input.numel() > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1); + } + + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + flat_input[i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + flat_input[i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + flat_input[i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + flat_input[i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + flat_input[i] = static_cast<int64_t>(qvalue); + } + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = + torch::executor::native::dequantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); + + // Build Vulkan dequantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + + const ValueRef r_scale = graph.add_scalar<double>(scale); + const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point); + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan dequantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, 
vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, vk_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Test cases for dequantize_per_tensor +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_float) { + test_reference_dequantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 5, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int8_to_float) { + test_reference_dequantize_per_tensor( + {3, 4, 5}, // input sizes + 0.05, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_float) { + test_reference_dequantize_per_tensor( + {4, 6, 2}, // input sizes + 0.2, // scale + 2, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_uint8_to_half) { + test_reference_dequantize_per_tensor( + {7, 4}, // input sizes + 0.1, // scale + 10, // zero_point + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype (uint8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTensorTest, + test_reference_dequantize_per_tensor_int32_to_half) { + test_reference_dequantize_per_tensor( + {2, 6, 5}, // input sizes + 0.3, // scale + -10, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} + +void test_reference_dequantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if 
(dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast<int64_t>(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = dequantize_per_token_reference_impl( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, impl_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_dequantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype, + at::ScalarType out_dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage) { + check_dequantize_args(quant_min, quant_max, dtype, out_dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with quantized values + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input; + if (dtype == at::kByte) { + 
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); + } else if (dtype == at::kChar) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); + } else if (dtype == at::kShort) { + input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); + } else if (dtype == at::kInt) { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); + } else { + input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); + } + + // Fill with a simple pattern: values from quant_min to quant_max in steps + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + float step = 1.0f; + if (input.size(-1) > 1) { + step = static_cast<float>(quant_max - quant_min) / (input.size(-1) - 1); + } + + for (int i = 0; i < input.size(-1); i++) { + int64_t qvalue = quant_min + i * step; + if (dtype == at::kByte) { + reshaped_input[token_idx][i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + reshaped_input[token_idx][i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + reshaped_input[token_idx][i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + reshaped_input[token_idx][i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + reshaped_input[token_idx][i] = static_cast<int64_t>(qvalue); + } + } + } + + // Reshape back to original dimensions + input = reshaped_input.reshape(input_sizes_int64); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( + input, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + dtype, + out_dtype); + + // Build Vulkan dequantize_per_token graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(dtype), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + VK_GET_OP_FN("dequantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + 
zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = at::allclose(reference_out, vk_out); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_out << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Test cases for dequantize_per_token +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_uint8_to_float) { + std::vector<float> scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; + std::vector<int> zero_points = {5, 10, 15, 20, 25, 30}; + + test_reference_dequantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kByte, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_float) { + std::vector<float> scales = {0.05, 0.1, 0.15, 0.2}; + std::vector<int> zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 5}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_float) { + std::vector<float> scales = {0.05, 0.1, 0.15, 0.2}; + std::vector<int> zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {2, 2, 10}, // input sizes (2*2=4 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kFloat); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int8_to_half) { + std::vector<float> scales = {0.05, 0.1, 0.15, 0.2}; + std::vector<int> zero_points = {0, -5, 5, 10}; + + test_reference_dequantize_per_token( + {4, 1, 5}, // input sizes (4*1=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype (int8) + at::kHalf); // output dtype +} + +TEST( + VulkanDequantizePerTokenTest, + test_reference_dequantize_per_token_int32_to_half) { + std::vector<float> scales = {0.05, 0.1}; + std::vector<int> zero_points = {0, -5}; + + test_reference_dequantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kInt, // input dtype + at::kHalf); // output dtype +} diff --git 
a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp index b95b7b3aa6d..e48042c4620 100644 --- a/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp +++ b/backends/vulkan/test/op_tests/linear_weight_int4_test.cpp @@ -14,6 +14,8 @@ #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> +#include "test_utils.h" + #include <cassert> // @@ -201,26 +203,6 @@ void test_reference_linear_qcs4w( ASSERT_TRUE(at::allclose(out, out_ref)); } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_linear_qga4w_impl( const int B, const int M, diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp new file mode 100644 index 00000000000..8b79dc1ce6b --- /dev/null +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -0,0 +1,843 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <gtest/gtest.h> + +#include <ATen/ATen.h> + +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> +#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> + +#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h> +#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> + +#include "test_utils.h" + +#include <cassert> +#include <iostream> + +namespace torch { +namespace executor { +namespace native { + +// Forward declarations of the functions we're testing +Tensor& quantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +Tensor& quantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out); + +// Wrapper function for quantize_per_tensor_out without context +Tensor& quantize_per_tensor_out_no_context( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// Wrapper function for quantize_per_token_out without context +Tensor& quantize_per_token_out_no_context( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + return torch::executor::native::quantize_per_token_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +// ATen wrapper for quantize_per_tensor +at::Tensor quantize_per_tensor_aten( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = 
at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +// ATen wrapper for quantize_per_token +at::Tensor quantize_per_token_aten( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + auto out = at::empty_like(input, dtype); + ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); + + WRAP_TO_ATEN(quantize_per_token_out_no_context, 6) + (input, scale, zero_point, quant_min, quant_max, et_dtype, out); + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch + +void check_quantize_args( + int64_t quant_min, + int64_t quant_max, + c10::ScalarType out_dtype) { + using namespace vkcompute; + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + switch (out_dtype) { + case c10::kByte: + quant_min_lower_bound = + static_cast<int32_t>(std::numeric_limits<uint8_t>::min()); + quant_max_upper_bound = + static_cast<int32_t>(std::numeric_limits<uint8_t>::max()); + break; + case c10::kChar: + quant_min_lower_bound = + static_cast<int32_t>(std::numeric_limits<int8_t>::min()); + quant_max_upper_bound = + static_cast<int32_t>(std::numeric_limits<int8_t>::max()); + break; + case c10::kBits16: + case c10::kUInt16: + quant_min_lower_bound = std::numeric_limits<uint16_t>::min(); + quant_max_upper_bound = std::numeric_limits<uint16_t>::max(); + break; + case c10::kShort: + quant_min_lower_bound = std::numeric_limits<int16_t>::min(); + quant_max_upper_bound = std::numeric_limits<int16_t>::max(); + break; + case c10::kInt: + quant_min_lower_bound = std::numeric_limits<int32_t>::min(); + quant_max_upper_bound = std::numeric_limits<int32_t>::max(); + break; + default: + VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype)); + } + VK_CHECK_COND( + quant_min >= quant_min_lower_bound, + "quant_min out of bound for dtype, expected quant_min_lower_bound: ", + quant_min_lower_bound, + " actual quant_min: ", + quant_min); + + VK_CHECK_COND( + quant_max <= quant_max_upper_bound, + "quant_max out of bound for dtype, expected quant_max_upper_bound: ", + quant_max_upper_bound, + " actual quant_max: ", + quant_max); +} + +// +// Reference Implementation +// + +/* + * Reference implementation of quantize_per_tensor + */ +at::Tensor quantize_per_tensor_reference_impl( + const at::Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Quantize the input tensor + float inv_scale = 1.0 / scale; + + // Iterate through the tensor and quantize each element + at::Tensor float_input = input.to(at::kFloat); + at::Tensor float_values = float_input.flatten(); + + auto out_flat = out.flatten(); + + for (int i = 0; i < float_values.numel(); i++) { + float value = float_values[i].item<float>(); + int64_t qvalue = zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max<int64_t>(qvalue, quant_min); + qvalue = std::min<int64_t>(qvalue, quant_max); + + if (dtype == at::kByte) { + out_flat[i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + out_flat[i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + out_flat[i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + 
out_flat[i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + out_flat[i] = static_cast<int64_t>(qvalue); + } + } + + return out.reshape(input.sizes()); +} + +/* + * Reference implementation of quantize_per_token + */ +at::Tensor quantize_per_token_reference_impl( + const at::Tensor& input, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + // Create output tensor with the target dtype + at::Tensor out = at::empty_like(input, dtype); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scale and zero_point + // tensors + assert(num_tokens == scale.numel()); + assert(num_tokens == zero_point.numel()); + + // Reshape input to [num_tokens, last_dim] + at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); + at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); + + // Quantize each token separately + for (int token_idx = 0; token_idx < num_tokens; token_idx++) { + // Use float for scale since Vulkan doesn't support double + float token_scale = scale[token_idx].item<float>(); + // Use int for zero_point since Vulkan doesn't support int64_t + int token_zero_point = zero_point[token_idx].item<int>(); + + float inv_scale = 1.0 / token_scale; + + // Quantize the token + for (int i = 0; i < input.size(-1); i++) { + float value = reshaped_input[token_idx][i].item<float>(); + // Accumulate and clamp in int64_t (as in the per-tensor reference) so + // the value is not narrowed before it is clamped + int64_t qvalue = token_zero_point + std::nearbyint(inv_scale * value); + + qvalue = std::max<int64_t>(qvalue, quant_min); + qvalue = std::min<int64_t>(qvalue, quant_max); + + if (dtype == at::kByte) { + reshaped_out[token_idx][i] = static_cast<uint8_t>(qvalue); + } else if (dtype == at::kChar) { + reshaped_out[token_idx][i] = static_cast<int8_t>(qvalue); + } else if (dtype == at::kShort) { + reshaped_out[token_idx][i] = static_cast<int16_t>(qvalue); + } else if (dtype == at::kInt) { + reshaped_out[token_idx][i] = static_cast<int32_t>(qvalue); + } else if (dtype == at::kLong) { + reshaped_out[token_idx][i] = static_cast<int64_t>(qvalue); + } + } + } + + return out; +} + +// Forward declarations of implementation functions +void test_vulkan_quantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +void test_vulkan_quantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype, + at::ScalarType dtype, + const vkcompute::utils::StorageType in_storage, + const vkcompute::utils::StorageType out_storage); + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, + scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_per_tensor_impl( + input_sizes, +
scale, + zero_point, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_quantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + // Test with buffer storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_quantize_per_token_impl( + input_sizes, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +void test_reference_quantize_per_tensor( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f; + if (input.numel() > 1) { + step = 1.0f / (input.numel() - 1); + } + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Get reference output + at::Tensor reference_out = quantize_per_tensor_reference_impl( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Convert to int for consistent display regardless of underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - impl_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_tensor_impl( + const std::vector<int>& input_sizes, + float scale, + int zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + //
Get reference output + at::Tensor reference_out = torch::executor::native::quantize_per_tensor_aten( + input, scale, zero_point, quant_min, quant_max, dtype); + + // Build Vulkan quantize_per_tensor graph + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + const ValueRef r_scale = graph.add_scalar<double>(scale); + const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point); + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_tensor.default") + (graph, + { + r_input.value, + r_scale, + r_zero_point, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Run Vulkan quantize_per_tensor + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + graph.execute(); + + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + // For quantized types, we need to compare the actual integer values + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale: " << scale << std::endl; + std::cout << " zero_point: " << zero_point << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.1, // scale + 0, // zero_point + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_float_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.04, // scale + 5, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_uint8) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.2, // scale + 2, // zero_point + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_tensor_half_to_int32) { + test_reference_quantize_per_tensor( + {2, 3, 4}, // input sizes + 0.01, // scale + 1, // zero_point + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kHalf, + at::kInt); +} + +void 
test_reference_quantize_per_token( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt) { + check_quantize_args(quant_min, quant_max, dtype); + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Fill with a simple pattern: values from 0 to 1 in steps + float step = 1.0f; + if (input.numel() > 1) { + step = 1.0f / (input.numel() - 1); + } + auto flat_input = input.flatten(); + for (int i = 0; i < flat_input.numel(); i++) { + flat_input[i] = i * step; + } + + // Reshape back to original dimensions + input = flat_input.reshape(input_sizes_int64); + + // Calculate number of tokens + int num_tokens = 1; + for (int i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + + // Verify that the number of tokens matches the size of scales and zero_points + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output + at::Tensor reference_out = quantize_per_token_reference_impl( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Get implementation output + at::Tensor impl_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + // Convert to int for consistent comparison and display regardless of + // underlying type + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor impl_int = impl_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, impl_int); + if (!output_correct) { + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "implementation:" << std::endl; + std::cout << impl_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +void test_vulkan_quantize_per_token_impl( + const std::vector<int>& input_sizes, + const std::vector<float>& scales, + const std::vector<int>& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kFloat, + at::ScalarType dtype = at::kInt, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + check_quantize_args(quant_min, quant_max, dtype); + int num_tokens = 1; + for (int i = 0; i < input_sizes.size() - 1; i++) { + num_tokens *= input_sizes[i]; + } + + ASSERT_EQ(num_tokens, scales.size()); + ASSERT_EQ(num_tokens, zero_points.size()); + + // Create input tensor with random values + std::vector<int64_t> input_sizes_int64( + input_sizes.begin(),
input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); + + // Get reference output to show what we would compare against + at::Tensor reference_out = torch::executor::native::quantize_per_token_aten( + input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), vkapi::kFloat, in_storage); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage); + + const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min); + const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(dtype), out_storage); + + VK_GET_OP_FN("quantize_per_token.default") + (graph, + { + r_input.value, + r_scale.value, + r_zero_point.value, + r_quant_min, + r_quant_max, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Convert scale tensor to float and copy to GPU + at::Tensor scale_float = scale_tensor.to(at::kFloat); + graph.copy_into_staging( + r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); + + // Convert zero_point tensor to int and copy to GPU + at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); + graph.copy_into_staging( + r_zero_point.staging, + zero_point_int.const_data_ptr(), + zero_point_int.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + at::Tensor reference_int = reference_out.to(at::kInt); + at::Tensor vk_int = vk_out.to(at::kInt); + + const bool output_correct = at::equal(reference_int, vk_int); + if (!output_correct) { + at::Tensor diffs = at::abs(reference_int - vk_int); + + std::cout << "\n" + << "Failed with parameters: " << std::endl; + std::cout << " scale(s):"; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << " " << scales[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " zero_point(s):"; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << " " << zero_points[i] << " "; + } + std::cout << "" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl; + std::cout << input << std::endl; + std::cout << "reference:" << std::endl; + std::cout << reference_int << std::endl; + std::cout << "vulkan:" << std::endl; + std::cout << vk_int << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int8) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kFloat, + at::kChar); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_float_to_int32) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kFloat, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_int32) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + std::numeric_limits<int32_t>::min(), // quant_min + std::numeric_limits<int32_t>::max(), // quant_max + at::kHalf, + at::kInt); +} + +TEST( + VulkanQuantizePerTensorTest, + test_reference_quantize_per_token_half_to_uint8) { + std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; + std::vector<int> zero_points = {1, 2, 3, 0, -1, -2}; + + test_reference_quantize_per_token( + {2, 3, 4}, // input sizes (2*3=6 tokens) + scales, + zero_points, + 0, // quant_min + 255, // quant_max + at::kHalf, + at::kByte); +} diff --git a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp index 534bb577e7a..eebbb89ab40 100644 --- a/backends/vulkan/test/op_tests/rotary_embedding_test.cpp +++ b/backends/vulkan/test/op_tests/rotary_embedding_test.cpp @@ -14,6 +14,8 @@ #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h> +#include "test_utils.h" + #include <cassert> // @@ -55,26 +57,6 @@ std::pair<at::Tensor, at::Tensor> rotary_embedding_impl( // Test functions // -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - case c10::kByte: - return vkapi::kByte; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_reference( const int n_heads = 4, const int n_kv_heads = 2, diff --git a/backends/vulkan/test/op_tests/sdpa_test.cpp b/backends/vulkan/test/op_tests/sdpa_test.cpp index 772039eda6a..79b679674a5 100644 --- a/backends/vulkan/test/op_tests/sdpa_test.cpp +++ b/backends/vulkan/test/op_tests/sdpa_test.cpp @@ -18,6 +18,8 @@ #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h> #include <executorch/extension/llm/custom_ops/op_sdpa.h> +#include "test_utils.h" + #include <cassert> #include 
<iostream> @@ -261,24 +263,6 @@ void test_reference_sdpa( } } -vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { - using namespace vkcompute; - switch (at_scalartype) { - case c10::kFloat: - return vkapi::kFloat; - case c10::kHalf: - return vkapi::kHalf; - case c10::kInt: - return vkapi::kInt; - case c10::kLong: - return vkapi::kInt; - case c10::kChar: - return vkapi::kChar; - default: - VK_THROW("Unsupported at::ScalarType!"); - } -} - void test_vulkan_sdpa( const int start_input_pos, const int base_sequence_len, diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index 5c9afa40762..0d014c7ef29 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -142,6 +142,28 @@ def define_common_targets(is_fbcode = False): platforms = get_platforms(), ) + runtime.cxx_library( + name = "test_utils", + srcs = [ + "test_utils.cpp", + ], + headers = [ + "test_utils.h", + ], + exported_headers = [ + "test_utils.h", + ], + deps = [ + "//executorch/backends/vulkan:vulkan_graph_runtime", + "//executorch/runtime/core/exec_aten:lib", + runtime.external_dep_location("libtorch"), + ], + visibility = [ + "//executorch/backends/vulkan/test/op_tests/...", + "@EXECUTORCH_CLIENTS", + ], + ) + define_test_targets( "compute_graph_op_tests", src_file=":generated_op_correctness_tests_cpp[op_tests.cpp]" @@ -150,9 +172,47 @@ def define_common_targets(is_fbcode = False): define_test_targets( "sdpa_test", extra_deps = [ + ":test_utils", "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", "//executorch/extension/tensor:tensor", ] ) - define_test_targets("linear_weight_int4_test") - define_test_targets("rotary_embedding_test") + define_test_targets( + "quantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_quantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "dequantize_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_dequantize", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "choose_qparams_test", + extra_deps = [ + ":test_utils", + "//executorch/kernels/quantized/cpu:op_choose_qparams", + "//executorch/extension/tensor:tensor", + "//executorch/extension/aten_util:aten_bridge", + ] + ) + define_test_targets( + "linear_weight_int4_test", + extra_deps = [ + ":test_utils", + ] + ) + define_test_targets( + "rotary_embedding_test", + extra_deps = [ + ":test_utils", + ] + ) diff --git a/backends/vulkan/test/op_tests/test_utils.cpp b/backends/vulkan/test/op_tests/test_utils.cpp new file mode 100644 index 00000000000..196f079be2c --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "test_utils.h" + +#include <stdexcept> + +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype) { + using ScalarType = executorch::aten::ScalarType; + switch (dtype) { + case at::kByte: + return ScalarType::Byte; + case at::kChar: + return ScalarType::Char; + case at::kShort: + return ScalarType::Short; + case at::kInt: + return ScalarType::Int; + case at::kLong: + return ScalarType::Long; + case at::kHalf: + return ScalarType::Half; + case at::kFloat: + return ScalarType::Float; + case at::kDouble: + return ScalarType::Double; + default: + throw std::runtime_error("Unsupported dtype"); + } +} + +std::string scalar_type_name(c10::ScalarType dtype) { + switch (dtype) { + case c10::kLong: + return "c10::kLong"; + case c10::kShort: + return "c10::kShort"; + case c10::kComplexHalf: + return "c10::kComplexHalf"; + case c10::kComplexFloat: + return "c10::kComplexFloat"; + case c10::kComplexDouble: + return "c10::kComplexDouble"; + case c10::kBool: + return "c10::kBool"; + case c10::kQInt8: + return "c10::kQInt8"; + case c10::kQUInt8: + return "c10::kQUInt8"; + case c10::kQInt32: + return "c10::kQInt32"; + case c10::kBFloat16: + return "c10::kBFloat16"; + case c10::kQUInt4x2: + return "c10::kQUInt4x2"; + case c10::kQUInt2x4: + return "c10::kQUInt2x4"; + case c10::kFloat: + return "c10::kFloat"; + case c10::kHalf: + return "c10::kHalf"; + case c10::kInt: + return "c10::kInt"; + case c10::kChar: + return "c10::kChar"; + case c10::kByte: + return "c10::kByte"; + case c10::kDouble: + return "c10::kDouble"; + case c10::kUInt16: + return "c10::kUInt16"; + case c10::kBits16: + return "c10::kBits16"; + default: + return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")"; + } +} + +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { + using namespace vkcompute; + switch (at_scalartype) { + case c10::kHalf: + return vkapi::kHalf; + case c10::kFloat: + return vkapi::kFloat; + case c10::kDouble: + return vkapi::kDouble; + case c10::kInt: + return vkapi::kInt; + case c10::kLong: + return vkapi::kLong; + case c10::kChar: + return vkapi::kChar; + case c10::kByte: + return vkapi::kByte; + case c10::kShort: + return vkapi::kShort; + case c10::kUInt16: + return vkapi::kUInt16; + default: + VK_THROW( + "Unsupported at::ScalarType: ", + scalar_type_name(at_scalartype), + " (", + static_cast<int>(at_scalartype), + ")"); + } +} diff --git a/backends/vulkan/test/op_tests/test_utils.h b/backends/vulkan/test/op_tests/test_utils.h new file mode 100644 index 00000000000..369767007e0 --- /dev/null +++ b/backends/vulkan/test/op_tests/test_utils.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include <string> + +#include <ATen/ATen.h> +#include <c10/core/ScalarType.h> +#include <executorch/backends/vulkan/runtime/api/api.h> +#include <executorch/runtime/core/exec_aten/exec_aten.h> + +/** + * Convert at::ScalarType to executorch::ScalarType + */ +executorch::aten::ScalarType at_scalartype_to_et_scalartype( + at::ScalarType dtype); + +/** + * Get the string name of a c10::ScalarType for better error messages + */ +std::string scalar_type_name(c10::ScalarType dtype); + +/** + * Convert c10::ScalarType to vkcompute::vkapi::ScalarType + */ +vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype); diff --git a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py index 65bb959f6d1..a054fdf1a19 100644 --- a/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py @@ -177,6 +177,8 @@ def generate_benchmark_fixture(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {{ switch (at_scalartype) {{ + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: @@ -187,6 +189,8 @@ def generate_benchmark_fixture(self) -> str: return vkapi::kInt; case c10::kChar: return vkapi::kChar; + case c10::kBool: + return vkapi::kBool; default: VK_THROW("Unsupported at::ScalarType!"); }} diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py index 4f0d2ff11ef..e7cf5ba92a5 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_vk.py @@ -110,6 +110,8 @@ def gen_parameterization(self) -> str: vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) { switch (at_scalartype) { + case c10::kDouble: + return vkapi::kDouble; case c10::kFloat: return vkapi::kFloat; case c10::kHalf: diff --git a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml index a00bba2bc5a..69587bd38d0 100644 --- a/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml +++ b/backends/vulkan/tools/gpuinfo/glsl/warp_size.yaml @@ -6,7 +6,7 @@ warp_size: parameter_names_with_default_values: - DTYPE: int + DTYPE: int32 STORAGE: buffer generate_variant_forall: METHOD: diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index c1f2770d3d6..876099598dc 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out( static_cast<float>(scale)); \ } \ } break; -#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast<int8_t>(out.scalar_type())); \ - } \ +#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast<int8_t>(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { @@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out( } \ out_data_ptr[current_ix] = \ static_cast<CTYPE_OUT>( \ - input_data_ptr[current_ix] - zero_point) * \ + input_data_ptr[current_ix] 
- \ + static_cast<int32_t>(zero_point)) * \ _scale; \ } \ }, \ @@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out( apply_over_dim_list( \ [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ out_data_ptr[in_ix] = static_cast<CTYPE_OUT>( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ + (input_data_ptr[in_ix] - static_cast<int32_t>(_zero_point)) * \ + _scale); \ }, \ input, \ optional_dim_list, \ channel_ix); \ } \ break; -#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast<int8_t>(out.scalar_type())); \ - } \ +#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast<int8_t>(out.scalar_type())); \ + } \ break; switch (input.scalar_type()) { diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp index 4665c3d665b..d0b7c882f8e 100644 --- a/kernels/quantized/cpu/op_quantize.cpp +++ b/kernels/quantized/cpu/op_quantize.cpp @@ -150,7 +150,7 @@ Tensor& quantize_per_tensor_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, @@ -346,7 +346,7 @@ Tensor& quantize_per_channel_out( break; switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(CALCULATE_FLOAT_TYPE); + ET_FORALL_FLOATH_TYPES(CALCULATE_FLOAT_TYPE); default: ET_CHECK_MSG( false, diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index bbda1590a10..4a0c195e3ab 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) { test_dtype<ScalarType::Int>(); } +/// Test all supported output dtypes for dequantization +template <ScalarType OUT_DTYPE> +void test_output_dtype() { + TensorFactory<ScalarType::Byte> tf; + + Tensor input = tf.full({3, 5}, 100); + double scale = 0.5; + int64_t zero_point = 30; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<OUT_DTYPE> tfo; + Tensor out = tfo.zeros({3, 5}); + // (100 - 30) * 0.5 = 35 + Tensor expected = tfo.full({3, 5}, 35); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional<ScalarType>(OUT_DTYPE), + out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, AllOutputDtypesSupported) { + et_pal_init(); + test_output_dtype<ScalarType::Float>(); + test_output_dtype<ScalarType::Double>(); + test_output_dtype<ScalarType::Half>(); +} + +TEST(OpDequantizeOutTest, HalfOutput) { + et_pal_init(); + TensorFactory<ScalarType::Byte> tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<ScalarType::Half> tfo; + Tensor out = tfo.zeros({3, 5}); + // (10 - 100000) * 0.5 = -49995 + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional<ScalarType>(ScalarType::Half), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); 
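+ // Note: -49995 needs more precision than Half's 11-bit significand offers, + // so both tfo.full and the kernel store the nearest representable half + // value; both sides round identically, which is why exact equality holds.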
+ EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpDequantizeOutTest, DoubleOutput) { + et_pal_init(); + TensorFactory<ScalarType::Byte> tf; + + Tensor input = tf.full({3, 5}, 10); + double scale = 0.5; + int64_t zero_point = 100000; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<ScalarType::Double> tfo; + Tensor out = tfo.zeros({3, 5}); + dequantize_per_tensor_out( + input, + scale, + zero_point, + quant_min, + quant_max, + ScalarType::Byte, + optional<ScalarType>(ScalarType::Double), + out); + + // The expected result should be (10 - 100000) * 0.5 = -49995 + Tensor expected = tfo.full({3, 5}, -49995); + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpDequantizeOutTest, NonWholeNumbers) { et_pal_init(); TensorFactory<ScalarType::Byte> tf; diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp index 704d8d06c5c..5cd17223d80 100644 --- a/kernels/quantized/test/op_quantize_test.cpp +++ b/kernels/quantized/test/op_quantize_test.cpp @@ -49,6 +49,32 @@ void test_dtype() { EXPECT_TENSOR_EQ(out, expected); } +template <ScalarType INPUT_DTYPE> +void test_input_dtype() { + TensorFactory<INPUT_DTYPE> tf_input; + + Tensor input = tf_input.full({3, 5}, 4); + double scale = 0.5; + int64_t zero_point = 108; + int64_t quant_min = 0; + int64_t quant_max = 127; + + TensorFactory<ScalarType::Char> tfo; + Tensor out = tfo.zeros({3, 5}); + // 4 / 0.5 + 108 = 116 + Tensor expected = tfo.full({3, 5}, 116); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, AllInputDtypesSupported) { + test_input_dtype<ScalarType::Float>(); + test_input_dtype<ScalarType::Half>(); + test_input_dtype<ScalarType::Double>(); +} + TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype<ScalarType::Byte>(); test_dtype<ScalarType::Char>(); @@ -58,6 +84,45 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) { test_dtype<ScalarType::Int>(); } +TEST(OpQuantizeOutTest, DoubleInputTest) { + TensorFactory<ScalarType::Double> tf_double; + + // Test with a more complex value that might have precision differences + Tensor input = tf_double.full({2, 3}, 3.14159265359); + double scale = 0.01; + int64_t zero_point = -100; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory<ScalarType::Byte> tfo; + Tensor out = tfo.zeros({2, 3}); + // 3.14159265359 / 0.01 - 100 = 214.159265359 + Tensor expected = tfo.full({2, 3}, 214); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, HalfInputTest) { + TensorFactory<ScalarType::Half> tf_half; + + Tensor input = tf_half.full({2, 3}, 2.5); + double scale = 0.5; + int64_t zero_point = 10; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory<ScalarType::Char> tfo; + Tensor out = tfo.zeros({2, 3}); + // 2.5 / 0.5 + 10 = 15 + Tensor expected = tfo.full({2, 3}, 15); + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + TEST(OpQuantizeOutTest, TensorArgOverload) { TensorFactory<ScalarType::Float> tf_float; TensorFactory<ScalarType::Double> tf_double; diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 6f81146e925..d81b3ad4d0f 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ 
-199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT, float, Float) \ _(ANOTHER_INPUT, double, Double) +#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \ + _(ANOTHER_INPUT, float, Float) \ + _(ANOTHER_INPUT, double, Double) \ + _(ANOTHER_INPUT, ::executorch::aten::Half, Half) + #define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)
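For context on this last hunk: ET_FORALL_FLOATH_TYPES_WITH extends the existing float/double X-macro with Half, which is what lets the quantize/dequantize kernels above dispatch over half-precision dtypes. Below is a minimal, self-contained sketch of how such an X-macro expands into switch cases. It is illustrative only: the ScalarType enum, the Half stand-in, and dequantize_to are invented for the sketch and are not ExecuTorch APIs.

#include <cstdio>

// Stand-ins for the real types, for illustration only.
enum class ScalarType { Float, Double, Half };
struct Half { unsigned short bits; };  // 2-byte stand-in for aten::Half

// Mirrors the shape of ET_FORALL_FLOATH_TYPES_WITH: invoke the callback
// once per (extra input, C type, ScalarType name) triple.
#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \
  _(ANOTHER_INPUT, float, Float)                      \
  _(ANOTHER_INPUT, double, Double)                    \
  _(ANOTHER_INPUT, Half, Half)

// Hypothetical kernel: dequantize IN_CTYPE input into OUT_CTYPE output.
template <typename IN_CTYPE, typename OUT_CTYPE>
void dequantize_to() {
  std::printf(
      "dequantize %zu-byte input -> %zu-byte output\n",
      sizeof(IN_CTYPE),
      sizeof(OUT_CTYPE));
}

void dispatch_out_dtype(ScalarType out_dtype) {
// Each expansion becomes one case, so adding Half to the macro
// automatically adds a Half case at every switch that uses it.
#define DEQUANTIZE_CASE(IN_CTYPE, OUT_CTYPE, name) \
  case ScalarType::name:                           \
    dequantize_to<IN_CTYPE, OUT_CTYPE>();          \
    break;

  switch (out_dtype) {
    ET_FORALL_FLOATH_TYPES_WITH(signed char, DEQUANTIZE_CASE)
  }
#undef DEQUANTIZE_CASE
}

int main() {
  dispatch_out_dtype(ScalarType::Half);  // prints: dequantize 1-byte input -> 2-byte output
  return 0;
}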