From 308d9570a929469e1d155afd317848cfea90a354 Mon Sep 17 00:00:00 2001
From: morelos
Date: Mon, 9 Jun 2025 08:04:59 -0700
Subject: [PATCH] [ET-VK][Ops] dequantize_per_token.default test setup

Creating the dequantize_per_token testing framework along with a reference
implementation for testing.

Differential Revision: [D76267037](https://our.internmc.facebook.com/intern/diff/D76267037/)

[ghstack-poisoned]
---
 .../vulkan/test/op_tests/dequantize_test.cpp  | 423 ++++++++++++++++++
 1 file changed, 423 insertions(+)

diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp
index 594c53500f5..eb0d430ccc3 100644
--- a/backends/vulkan/test/op_tests/dequantize_test.cpp
+++ b/backends/vulkan/test/op_tests/dequantize_test.cpp
@@ -332,6 +332,71 @@ at::Tensor dequantize_per_tensor_reference_impl(
   return out.reshape(input.sizes());
 }
 
+/*
+ * Reference implementation of dequantize_per_token
+ */
+at::Tensor dequantize_per_token_reference_impl(
+    const at::Tensor& input,
+    const at::Tensor& scale,
+    const at::Tensor& zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  // Create output tensor with the target dtype
+  at::Tensor out = at::empty_like(input, out_dtype);
+
+  // Calculate number of tokens
+  int num_tokens = 1;
+  for (int i = 0; i < input.dim() - 1; i++) {
+    num_tokens *= input.size(i);
+  }
+
+  // Verify that the number of tokens matches the size of scale and zero_point tensors
+  assert(num_tokens == scale.numel());
+  assert(num_tokens == zero_point.numel());
+
+  // Reshape input to [num_tokens, last_dim]
+  at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)});
+  at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)});
+
+  // Dequantize each token separately
+  for (int token_idx = 0; token_idx < num_tokens; token_idx++) {
+    // Use float for scale since Vulkan doesn't support double
+    float token_scale = scale[token_idx].item<float>();
+    // Use int for zero_point since Vulkan doesn't support int64_t
+    int token_zero_point = zero_point[token_idx].item<int>();
+
+    // Dequantize the token
+    for (int i = 0; i < input.size(-1); i++) {
+      int qvalue;
+      if (dtype == at::kByte) {
+        qvalue = static_cast<int>(reshaped_input[token_idx][i].item<uint8_t>());
+      } else if (dtype == at::kChar) {
+        qvalue = static_cast<int>(reshaped_input[token_idx][i].item<int8_t>());
+      } else if (dtype == at::kShort) {
+        qvalue = static_cast<int>(reshaped_input[token_idx][i].item<int16_t>());
+      } else if (dtype == at::kInt) {
+        qvalue = reshaped_input[token_idx][i].item<int>();
+      } else if (dtype == at::kLong) {
+        qvalue = static_cast<int>(reshaped_input[token_idx][i].item<int64_t>());
+      } else {
+        throw std::runtime_error("Unsupported input dtype");
+      }
+
+      float value = (qvalue - token_zero_point) * token_scale;
+
+      if (out_dtype == at::kFloat) {
+        reshaped_out[token_idx][i] = value;
+      } else if (out_dtype == at::kDouble) {
+        reshaped_out[token_idx][i] = static_cast<double>(value);
+      }
+    }
+  }
+
+  return out;
+}
+
 // Forward declaration of implementation functions
 void test_vulkan_dequantize_per_tensor_impl(
     const std::vector<int>& input_sizes,
@@ -344,6 +409,17 @@ void test_vulkan_dequantize_per_tensor_impl(
     const vkcompute::utils::StorageType in_storage,
     const vkcompute::utils::StorageType out_storage);
 
+void test_vulkan_dequantize_per_token_impl(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage);
+
 // Wrapper function to test both buffer and texture storage types
 void test_vulkan_dequantize_per_tensor(
     const std::vector<int>& input_sizes,
@@ -378,6 +454,40 @@ void test_vulkan_dequantize_per_tensor(
       vkcompute::utils::kTexture3D);
 }
 
+// Wrapper function to test both buffer and texture storage types
+void test_vulkan_dequantize_per_token(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  // Test with buffer storage
+  test_vulkan_dequantize_per_token_impl(
+      input_sizes,
+      scales,
+      zero_points,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      vkcompute::utils::kBuffer,
+      vkcompute::utils::kBuffer);
+
+  // Test with texture storage
+  test_vulkan_dequantize_per_token_impl(
+      input_sizes,
+      scales,
+      zero_points,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      vkcompute::utils::kTexture3D,
+      vkcompute::utils::kTexture3D);
+}
+
 void test_reference_dequantize_per_tensor(
     const std::vector<int>& input_sizes,
     float scale,
@@ -619,3 +729,316 @@ TEST(VulkanDequantizePerTensorTest, test_reference_dequantize_per_tensor_int16_t
       at::kShort, // input dtype
       at::kFloat); // output dtype
 }
+
+void test_reference_dequantize_per_token(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype) {
+  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
+  int num_tokens = 1;
+  for (int i = 0; i < input_sizes.size() - 1; i++) {
+    num_tokens *= input_sizes[i];
+  }
+
+  ASSERT_EQ(num_tokens, scales.size());
+  ASSERT_EQ(num_tokens, zero_points.size());
+
+  // Create input tensor with quantized values
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+  at::Tensor input;
+  if (dtype == at::kByte) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
+  } else if (dtype == at::kChar) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
+  } else if (dtype == at::kShort) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
+  } else if (dtype == at::kInt) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
+  } else {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
+  }
+
+  // Fill with a simple pattern: values from quant_min to quant_max in steps
+  at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)});
+  for (int token_idx = 0; token_idx < num_tokens; token_idx++) {
+    float step = 1.0f;
+    if (input.size(-1) > 1) {
+      step = static_cast<float>(quant_max - quant_min) / (input.size(-1) - 1);
+    }
+
+    for (int i = 0; i < input.size(-1); i++) {
+      int64_t qvalue = quant_min + i * step;
+      if (dtype == at::kByte) {
+        reshaped_input[token_idx][i] = static_cast<uint8_t>(qvalue);
+      } else if (dtype == at::kChar) {
+        reshaped_input[token_idx][i] = static_cast<int8_t>(qvalue);
+      } else if (dtype == at::kShort) {
+        reshaped_input[token_idx][i] = static_cast<int16_t>(qvalue);
+      } else if (dtype == at::kInt) {
+        reshaped_input[token_idx][i] = static_cast<int>(qvalue);
+      } else if (dtype == at::kLong) {
+        reshaped_input[token_idx][i] = static_cast<int64_t>(qvalue);
+      }
+    }
+  }
+
+  // Reshape back to original dimensions
+  input = reshaped_input.reshape(input_sizes_int64);
+
+  // Create scale and zero_point tensors
+  at::Tensor scale_tensor =
+      at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat));
+  at::Tensor zero_point_tensor =
+      at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong));
+
+  // Get reference output
+  at::Tensor reference_out = dequantize_per_token_reference_impl(
+      input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype, out_dtype);
+
+  // Get implementation output
+  at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten(
+      input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype, out_dtype);
+
+  // Compare outputs
+  const bool output_correct = at::allclose(reference_out, impl_out, 1e-5, 1e-5);
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  scale(s):";
+    for (size_t i = 0; i < scales.size(); i++) {
+      std::cout << " " << scales[i] << " ";
+    }
+    std::cout << "" << std::endl;
+    std::cout << "  zero_point(s):";
+    for (size_t i = 0; i < zero_points.size(); i++) {
+      std::cout << " " << zero_points[i] << " ";
+    }
+    std::cout << "" << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+
+    std::cout << "input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << "reference:" << std::endl;
+    std::cout << reference_out << std::endl;
+    std::cout << "implementation:" << std::endl;
+    std::cout << impl_out << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
+void test_vulkan_dequantize_per_token_impl(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype,
+    at::ScalarType out_dtype,
+    const vkcompute::utils::StorageType in_storage,
+    const vkcompute::utils::StorageType out_storage) {
+  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
+  int num_tokens = 1;
+  for (int i = 0; i < input_sizes.size() - 1; i++) {
+    num_tokens *= input_sizes[i];
+  }
+
+  ASSERT_EQ(num_tokens, scales.size());
+  ASSERT_EQ(num_tokens, zero_points.size());
+
+  // Create input tensor with quantized values
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+  at::Tensor input;
+  if (dtype == at::kByte) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
+  } else if (dtype == at::kChar) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
+  } else if (dtype == at::kShort) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
+  } else if (dtype == at::kInt) {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
+  } else {
+    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
+  }
+
+  // Fill with a simple pattern: values from quant_min to quant_max in steps
+  at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)});
+  for (int token_idx = 0; token_idx < num_tokens; token_idx++) {
+    float step = 1.0f;
+    if (input.size(-1) > 1) {
+      step = static_cast<float>(quant_max - quant_min) / (input.size(-1) - 1);
+    }
+
+    for (int i = 0; i < input.size(-1); i++) {
+      int64_t qvalue = quant_min + i * step;
+      if (dtype == at::kByte) {
+        reshaped_input[token_idx][i] = static_cast<uint8_t>(qvalue);
+      } else if (dtype == at::kChar) {
+        reshaped_input[token_idx][i] = static_cast<int8_t>(qvalue);
+      } else if (dtype == at::kShort) {
+        reshaped_input[token_idx][i] = static_cast<int16_t>(qvalue);
+      } else if (dtype == at::kInt) {
+        reshaped_input[token_idx][i] = static_cast<int>(qvalue);
+      } else if (dtype == at::kLong) {
+        reshaped_input[token_idx][i] = static_cast<int64_t>(qvalue);
+      }
+    }
+  }
+
+  // Reshape back to original dimensions
+  input = reshaped_input.reshape(input_sizes_int64);
+
+  // Create scale and zero_point tensors
+  at::Tensor scale_tensor =
+      at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat));
+  at::Tensor zero_point_tensor =
+      at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong));
+
+  // Get reference output
+  at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten(
+      input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype, out_dtype);
+
+  // Build Vulkan dequantize_per_token graph
+  using namespace vkcompute;
+
+  GraphConfig config;
+  config.set_storage_type_override(in_storage);
+  ComputeGraph graph(config);
+
+  IOValueRef r_input = graph.add_input_tensor(
+      input.sizes().vec(), from_at_scalartype(dtype), in_storage);
+  IOValueRef r_scale = graph.add_input_tensor(
+      scale_tensor.sizes().vec(), vkapi::kFloat, in_storage);
+  IOValueRef r_zero_point = graph.add_input_tensor(
+      zero_point_tensor.sizes().vec(), vkapi::kInt, in_storage);
+
+  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
+  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
+
+  const ValueRef r_out = graph.add_tensor(
+      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
+
+  VK_GET_OP_FN("dequantize_per_token.default")
+  (graph,
+   {
+       r_input.value,
+       r_scale.value,
+       r_zero_point.value,
+       r_quant_min,
+       r_quant_max,
+       r_out,
+   });
+
+  ValueRef staging_out = graph.set_output_tensor(r_out);
+
+  graph.prepare();
+  graph.encode_prepack();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Copy input data to GPU
+  graph.copy_into_staging(
+      r_input.staging, input.const_data_ptr(), input.numel());
+
+  // Convert scale tensor to float and copy to GPU
+  at::Tensor scale_float = scale_tensor.to(at::kFloat);
+  graph.copy_into_staging(
+      r_scale.staging, scale_float.const_data_ptr(), scale_float.numel());
+
+  // Convert zero_point tensor to int and copy to GPU
+  at::Tensor zero_point_int = zero_point_tensor.to(at::kInt);
+  graph.copy_into_staging(
+      r_zero_point.staging,
+      zero_point_int.const_data_ptr(),
+      zero_point_int.numel());
+
+  // Execute the graph
+  graph.execute();
+
+  // Copy output data back to CPU
+  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
+  graph.copy_from_staging(
+      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
+
+  // Compare outputs
+  const bool output_correct = at::allclose(reference_out, vk_out, 1e-5, 1e-5);
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  scale(s):";
+    for (size_t i = 0; i < scales.size(); i++) {
+      std::cout << " " << scales[i] << " ";
+    }
+    std::cout << "" << std::endl;
+    std::cout << "  zero_point(s):";
+    for (size_t i = 0; i < zero_points.size(); i++) {
+      std::cout << " " << zero_points[i] << " ";
+    }
+    std::cout << "" << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+    std::cout << "  storage type: "
+              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
+                                                          : "texture")
+              << std::endl;
+
+    std::cout << "input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << "reference:" << std::endl;
+    std::cout << reference_out << std::endl;
+    std::cout << "vulkan:" << std::endl;
+    std::cout << vk_out << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
+// Test cases for dequantize_per_token
+TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_uint8_to_float) {
+  std::vector<float> scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
+  std::vector<int> zero_points = {5, 10, 15, 20, 25, 30};
+
+  test_reference_dequantize_per_token(
+      {2, 3, 4}, // input sizes (2*3=6 tokens)
+      scales,
+      zero_points,
+      0, // quant_min
+      255, // quant_max
+      at::kByte, // input dtype
+      at::kFloat); // output dtype
+}
+
+TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_int8_to_float) {
+  std::vector<float> scales = {0.05, 0.1, 0.15, 0.2};
+  std::vector<int> zero_points = {0, -5, 5, 10};
+
+  test_reference_dequantize_per_token(
+      {2, 2, 5}, // input sizes (2*2=4 tokens)
+      scales,
+      zero_points,
+      -128, // quant_min
+      127, // quant_max
+      at::kChar, // input dtype
+      at::kFloat); // output dtype
+}
+
+TEST(VulkanDequantizePerTokenTest, test_reference_dequantize_per_token_int16_to_float) {
+  std::vector<float> scales = {0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007, 0.008};
+  std::vector<int> zero_points = {-10, 0, 10, 20, -20, -15, 15, 25};
+
+  test_reference_dequantize_per_token(
+      {2, 4, 6}, // input sizes (2*4=8 tokens)
+      scales,
+      zero_points,
+      -32768, // quant_min
+      32767, // quant_max
+      at::kShort, // input dtype
+      at::kFloat); // output dtype
+}