diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index 4e434935356..fb0d2ee61bf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -11,6 +11,7 @@ dequantize_buffer: OUT_DTYPE: - VALUE: half - VALUE: float + - VALUE: double shader_variants: - NAME: dequantize_per_tensor_buffer MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index cfc61dd1816..801f4a2f6a2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -139,7 +139,10 @@ void dequantize_per_tensor() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); OUT_T value = dequantize_val(qvalue, scale, zero_point); - outtex[i] = value; + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; } write_texel(t_out, pos, outtex); } @@ -177,7 +180,10 @@ void dequantize_per_token() { [[unroll]] for (int i = 0; i < 4; ++i) { IN_T qvalue = IN_T(intex[i]); OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - outtex[i] = value; + $if OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; } write_texel(t_out, pos, outtex); diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index fc8c18468ed..7d19a543a03 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -11,6 +11,7 @@ dequantize_texture: OUT_DTYPE: - VALUE: half - VALUE: float + - VALUE: double shader_variants: - NAME: dequantize_per_tensor_texture3d MODE: per_tensor diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml index 90af2590936..4d95d610314 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml @@ -7,6 +7,7 @@ quantize_buffer: IN_DTYPE: - VALUE: half - VALUE: float + - VALUE: double OUT_DTYPE: - VALUE: uint8 - VALUE: int8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml index 042eb0f8196..65002ce26b6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml @@ -7,6 +7,7 @@ quantize_texture: IN_DTYPE: - VALUE: half - VALUE: float + - VALUE: double OUT_DTYPE: - VALUE: uint8 - VALUE: int8 diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp index 35712d59fb9..49277b4d718 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp @@ -188,6 +188,7 @@ void quantize_per_tensor_impl( // Verify input is a floating point type VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); @@ -214,6 +215,7 @@ void quantize_per_token_impl( // Verify input is a floating point type VK_CHECK_COND( + graph.dtype_of(input) == vkapi::kDouble || graph.dtype_of(input) == vkapi::kFloat || graph.dtype_of(input) == vkapi::kHalf); diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp index 1ec0602a4f2..6c604076c41 100644 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ b/backends/vulkan/test/op_tests/dequantize_test.cpp @@ -366,6 +366,12 @@ void test_vulkan_dequantize_per_tensor( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_dequantize_per_tensor_impl( input_sizes, @@ -400,6 +406,12 @@ void test_vulkan_dequantize_per_token( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // Telling the system to expect a float instead of a double + // since the shader can only return 32bit anyways + if (out_dtype == at::kDouble) { + out_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_dequantize_per_token_impl( input_sizes, @@ -793,6 +805,24 @@ TEST( at::kHalf); // output dtype } +TEST( + VulkanDequantizePerTensorTest, + test_vulkan_dequantize_per_tensor_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_dequantize_per_tensor( + {2, 3}, // input sizes + 0.05, // scale + 10, // zero_point + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +} + void test_reference_dequantize_per_token( const std::vector& input_sizes, const std::vector& scales, @@ -1288,3 +1318,24 @@ TEST( at::kInt, // input dtype at::kHalf); // output dtype } + +TEST( + VulkanDequantizePerTokenTest, + test_vulkan_dequantize_per_token_int8_to_double) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.05, 0.001}; + std::vector zero_points = {10, -5}; + + test_vulkan_dequantize_per_token( + {2, 2}, // input sizes (2 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kChar, // input dtype + at::kDouble); // output dtype +} diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp index 7ea98b14fb2..150bda6989e 100644 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_test.cpp @@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_quantize_per_tensor_impl( input_sizes, @@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token( vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); + // If the in_dtype is a double, convert to float for texture implementation + // since they don't support 64bit as inputs + if (in_dtype == at::kDouble) { + in_dtype = at::kFloat; + } + // Test with texture storage test_vulkan_quantize_per_token_impl( input_sizes, @@ -655,6 +667,24 @@ TEST( at::kChar); // output dtype } +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_tensor_double_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_vulkan_quantize_per_tensor( + {2, 3}, // input sizes + 0.01, // scale + 1, // zero_point + -128, // quant_min + 127, // quant_max + at::kDouble, // input dtype + at::kChar); // output dtype +} + void test_reference_quantize_per_token( const std::vector& input_sizes, const std::vector& pre_scales, @@ -1075,3 +1105,24 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) { at::kHalf, // input dtype at::kChar); // output dtype } + +TEST( + VulkanQuantizePerTensorTest, + test_vulkan_quantize_per_token_double_to_int8) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + std::vector scales = {0.1, 0.2}; + std::vector zero_points = {0, 5}; + + test_vulkan_quantize_per_token( + {2, 2}, // input sizes (2*2=4 tokens) + scales, + zero_points, + -128, // quant_min + 127, // quant_max + at::kDouble, // input dtype + at::kChar); // output dtype +}