diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp
index c1f2770d3d6..876099598dc 100644
--- a/kernels/quantized/cpu/op_dequantize.cpp
+++ b/kernels/quantized/cpu/op_dequantize.cpp
@@ -288,16 +288,16 @@ Tensor& dequantize_per_tensor_out(
           static_cast<CTYPE_OUT>(scale));                      \
     }                                                          \
   } break;
-#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype)                 \
-  case ScalarType::in_dtype:                                   \
-    switch (out.scalar_type()) {                               \
-      ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL);   \
-      default:                                                 \
-        ET_CHECK_MSG(                                          \
-            false,                                             \
-            "Unhandled output dtype %" PRId8,                  \
-            static_cast<int8_t>(out.scalar_type()));           \
-    }                                                          \
+#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype)                 \
+  case ScalarType::in_dtype:                                   \
+    switch (out.scalar_type()) {                               \
+      ET_FORALL_FLOATH_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL);  \
+      default:                                                 \
+        ET_CHECK_MSG(                                          \
+            false,                                             \
+            "Unhandled output dtype %" PRId8,                  \
+            static_cast<int8_t>(out.scalar_type()));           \
+    }                                                          \
     break;
 
   switch (input.scalar_type()) {
@@ -459,7 +459,8 @@ Tensor& dequantize_per_channel_out(
           }                                                    \
           out_data_ptr[current_ix] =                           \
               static_cast<CTYPE_OUT>(                          \
-                  input_data_ptr[current_ix] - zero_point) *   \
+                  input_data_ptr[current_ix] -                 \
+                  static_cast<int32_t>(zero_point)) *          \
              _scale;                                           \
         }                                                      \
       },                                                       \
@@ -478,23 +479,24 @@ Tensor& dequantize_per_channel_out(
       apply_over_dim_list(                                     \
           [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \
             out_data_ptr[in_ix] = static_cast<CTYPE_OUT>(      \
-                (input_data_ptr[in_ix] - _zero_point) * _scale); \
+                (input_data_ptr[in_ix] - static_cast<int32_t>(_zero_point)) * \
+                _scale);                                       \
           },                                                   \
           input,                                               \
           optional_dim_list,                                   \
           channel_ix);                                         \
     }                                                          \
     break;
-#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype)               \
-  case ScalarType::in_dtype:                                   \
-    switch (out.scalar_type()) {                               \
-      ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL);   \
-      default:                                                 \
-        ET_CHECK_MSG(                                          \
-            false,                                             \
-            "Unhandled output dtype %" PRId8,                  \
-            static_cast<int8_t>(out.scalar_type()));           \
-    }                                                          \
+#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype)               \
+  case ScalarType::in_dtype:                                   \
+    switch (out.scalar_type()) {                               \
+      ET_FORALL_FLOATH_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL);  \
+      default:                                                 \
+        ET_CHECK_MSG(                                          \
+            false,                                             \
+            "Unhandled output dtype %" PRId8,                  \
+            static_cast<int8_t>(out.scalar_type()));           \
+    }                                                          \
     break;
 
   switch (input.scalar_type()) {
diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp
index bbda1590a10..4a0c195e3ab 100644
--- a/kernels/quantized/test/op_dequantize_test.cpp
+++ b/kernels/quantized/test/op_dequantize_test.cpp
@@ -67,6 +67,96 @@ TEST(OpDequantizeOutTest, AllDtypesSupported) {
   test_dtype<ScalarType::Byte>();
 }
 
+/// Test all supported output dtypes for dequantization
+template <ScalarType OUT_DTYPE>
+void test_output_dtype() {
+  TensorFactory<ScalarType::Byte> tf;
+
+  Tensor input = tf.full({3, 5}, 100);
+  double scale = 0.5;
+  int64_t zero_point = 30;
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<OUT_DTYPE> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  // (100 - 30) * 0.5 = 35
+  Tensor expected = tfo.full({3, 5}, 35);
+  dequantize_per_tensor_out(
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(OUT_DTYPE),
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpDequantizeOutTest, AllOutputDtypesSupported) {
+  et_pal_init();
+  test_output_dtype<ScalarType::Float>();
+  test_output_dtype<ScalarType::Double>();
+  test_output_dtype<ScalarType::Half>();
+}
+
+TEST(OpDequantizeOutTest, HalfOutput) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tf;
+
+  Tensor input = tf.full({3, 5}, 10);
+  double scale = 0.5;
+  int64_t zero_point = 100000;
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Half> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  // (10 - 100000) * 0.5 = -49995
+  dequantize_per_tensor_out(
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(ScalarType::Half),
+      out);
+
+  // The expected result should be (10 - 100000) * 0.5 = -49995
+  Tensor expected = tfo.full({3, 5}, -49995);
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpDequantizeOutTest, DoubleOutput) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tf;
+
+  Tensor input = tf.full({3, 5}, 10);
+  double scale = 0.5;
+  int64_t zero_point = 100000;
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Double> tfo;
+  Tensor out = tfo.zeros({3, 5});
+  dequantize_per_tensor_out(
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(ScalarType::Double),
+      out);
+
+  // The expected result should be (10 - 100000) * 0.5 = -49995
+  Tensor expected = tfo.full({3, 5}, -49995);
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
 TEST(OpDequantizeOutTest, NonWholeNumbers) {
   et_pal_init();
   TensorFactory<ScalarType::Byte> tf;
diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h
index 6f81146e925..d81b3ad4d0f 100644
--- a/runtime/core/exec_aten/util/scalar_type_util.h
+++ b/runtime/core/exec_aten/util/scalar_type_util.h
@@ -199,6 +199,11 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
   _(ANOTHER_INPUT, float, Float)                    \
   _(ANOTHER_INPUT, double, Double)
 
+#define ET_FORALL_FLOATH_TYPES_WITH(ANOTHER_INPUT, _) \
+  _(ANOTHER_INPUT, float, Float)                      \
+  _(ANOTHER_INPUT, double, Double)                    \
+  _(ANOTHER_INPUT, ::executorch::aten::Half, Half)
+
 #define ET_FORALL_FLOAT_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \
   _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float)                      \
   _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double)