diff --git a/kernels/portable/cpu/op_bitwise_and.cpp b/kernels/portable/cpu/op_bitwise_and.cpp
index b1078f780a4..fa840f7e5cc 100644
--- a/kernels/portable/cpu/op_bitwise_and.cpp
+++ b/kernels/portable/cpu/op_bitwise_and.cpp
@@ -6,8 +6,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include
+// patternlint-disable-next-line executorch-cpp-nostdinc
+#include <functional>
+#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
 #include
 #include
 #include
@@ -17,20 +19,6 @@
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-template <typename CTYPE>
-CTYPE bitwise_and(CTYPE a, CTYPE b) {
-  return a & b;
-}
-
-template <>
-bool bitwise_and(bool a, bool b) {
-  return a && b;
-}
-
-} // namespace
-
 using Tensor = exec_aten::Tensor;
 
 Tensor& bitwise_and_Tensor_out(
@@ -38,60 +26,7 @@ Tensor& bitwise_and_Tensor_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() {
-        ET_SWITCH_INT_TYPES_AND(
-            Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() {
-              ET_SWITCH_INT_TYPES_AND(
-                  Bool,
-                  common_type,
-                  ctx,
-                  "bitwise_and.Tensor_out",
-                  CTYPE_IN,
-                  [&]() {
-                    ET_SWITCH_REAL_TYPES_AND(
-                        Bool,
-                        out_type,
-                        ctx,
-                        "bitwise_and.Tensor_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value =
-                                    bitwise_and(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
-                  });
-            });
-      });
-
-  return out;
+  return bitwise_op_out(ctx, a, b, out, "bitwise_and.Tensor_out");
 }
 
 Tensor& bitwise_and_Scalar_out(
@@ -142,8 +77,8 @@
                       static_cast<CTYPE_IN>(val_a);
                   CTYPE_IN b_casted =
                       static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value =
-                      bitwise_and(a_casted, b_casted);
+                  CTYPE_IN value = std::bit_and<CTYPE_IN>()(
+                      a_casted, b_casted);
 
                   return static_cast<CTYPE_OUT>(value);
                 },
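Aside (illustration only, not part of the patch): the deleted bitwise_and helper special-cased bool to use logical AND, while the replacement path relies on std::bit_and. For bool operands the two agree, and the same holds for the or/xor kernels further down (a || b and a != b respectively). A small standalone check in plain C++ with no ExecuTorch dependencies:

```cpp
#include <cassert>
#include <functional>

// Sanity check that std::bit_and/bit_or/bit_xor on bool reproduce the
// logical-operator specializations this patch deletes from the bitwise_*
// kernels: a && b, a || b, and a != b.
int main() {
  for (bool a : {false, true}) {
    for (bool b : {false, true}) {
      assert(std::bit_and<bool>()(a, b) == (a && b));
      assert(std::bit_or<bool>()(a, b) == (a || b));
      assert(std::bit_xor<bool>()(a, b) == (a != b));
    }
  }
  return 0;
}
```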
diff --git a/kernels/portable/cpu/op_bitwise_or.cpp b/kernels/portable/cpu/op_bitwise_or.cpp
index c13c68d3db4..e42632b71f2 100644
--- a/kernels/portable/cpu/op_bitwise_or.cpp
+++ b/kernels/portable/cpu/op_bitwise_or.cpp
@@ -6,8 +6,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include
+// patternlint-disable-next-line executorch-cpp-nostdinc
+#include <functional>
+#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
 #include
 #include
 #include
@@ -17,20 +19,6 @@
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-template <typename CTYPE>
-CTYPE bitwise_or(CTYPE a, CTYPE b) {
-  return a | b;
-}
-
-template <>
-bool bitwise_or(bool a, bool b) {
-  return a || b;
-}
-
-} // namespace
-
 using Tensor = exec_aten::Tensor;
 
 Tensor& bitwise_or_Tensor_out(
@@ -38,59 +26,7 @@ Tensor& bitwise_or_Tensor_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() {
-        ET_SWITCH_INT_TYPES_AND(
-            Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() {
-              ET_SWITCH_INT_TYPES_AND(
-                  Bool,
-                  common_type,
-                  ctx,
-                  "bitwise_or.Tensor_out",
-                  CTYPE_IN,
-                  [&]() {
-                    ET_SWITCH_REAL_TYPES_AND(
-                        Bool,
-                        out_type,
-                        ctx,
-                        "bitwise_or.Tensor_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = bitwise_or(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
-                  });
-            });
-      });
-
-  return out;
+  return bitwise_op_out(ctx, a, b, out, "bitwise_or.Tensor_out");
 }
 
 Tensor& bitwise_or_Scalar_out(
@@ -141,7 +77,8 @@
                       static_cast<CTYPE_IN>(val_a);
                   CTYPE_IN b_casted =
                       static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value = bitwise_or(a_casted, b_casted);
+                  CTYPE_IN value =
+                      std::bit_or<CTYPE_IN>()(a_casted, b_casted);
 
                   return static_cast<CTYPE_OUT>(value);
                 },
diff --git a/kernels/portable/cpu/op_bitwise_xor.cpp b/kernels/portable/cpu/op_bitwise_xor.cpp
index d2ea8a81cfb..c59cdcc692e 100644
--- a/kernels/portable/cpu/op_bitwise_xor.cpp
+++ b/kernels/portable/cpu/op_bitwise_xor.cpp
@@ -6,8 +6,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include
+// patternlint-disable-next-line executorch-cpp-nostdinc
+#include <functional>
+#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
 #include
 #include
 #include
@@ -17,20 +19,6 @@
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-template <typename CTYPE>
-CTYPE bitwise_xor(CTYPE a, CTYPE b) {
-  return a ^ b;
-}
-
-template <>
-bool bitwise_xor(bool a, bool b) {
-  return a != b;
-}
-
-} // namespace
-
 using Tensor = exec_aten::Tensor;
 
 Tensor& bitwise_xor_Tensor_out(
@@ -38,61 +26,7 @@ Tensor& bitwise_xor_Tensor_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, a_type, ctx, "bitwise_xor.Tensor_out", CTYPE_A, [&]() {
-        ET_SWITCH_INT_TYPES_AND(
-            Bool, b_type, ctx, "bitwise_xor.Tensor_out", CTYPE_B, [&]() {
-              ET_SWITCH_INT_TYPES_AND(
-                  Bool,
-                  common_type,
-                  ctx,
-                  "bitwise_xor.Tensor_out",
-                  CTYPE_IN,
-                  [&]() {
-                    ET_SWITCH_REAL_TYPES_AND(
-                        Bool,
-                        out_type,
-                        ctx,
-                        "bitwise_xor.Tensor_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value =
-                                    bitwise_xor(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
-                  });
-            });
-      });
-
-  return out;
+  return bitwise_op_out(ctx, a, b, out, "bitwise_xor.Tensor_out");
 }
 
 Tensor& bitwise_xor_Scalar_out(
@@ -143,8 +77,8 @@
                       static_cast<CTYPE_IN>(val_a);
                   CTYPE_IN b_casted =
                       static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value =
-                      bitwise_xor(a_casted, b_casted);
+                  CTYPE_IN value = std::bit_xor<CTYPE_IN>()(
+                      a_casted, b_casted);
 
                   return static_cast<CTYPE_OUT>(value);
                 },
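Aside (illustration only): all three kernels above now defer to a shared bitwise_op_out helper declared in the new kernels/portable/cpu/pattern/bitwise_op.h. That header's body is cut off at the end of this diff, and any template arguments at the call sites were lost in this copy of the patch, so the following is only a rough sketch of what such a helper could look like, reusing the same checks and ET_SWITCH nesting that the deleted per-op code performed. The functor-template parameter and its exact signature are assumptions, not the actual header:

```cpp
// Hypothetical shape of the shared helper; not the real bitwise_op.h.
// BitwiseFn is assumed to be something like std::bit_and / bit_or / bit_xor.
template <template <typename> class BitwiseFn>
Tensor& bitwise_op_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const char* op_name) {
  // Same broadcast-resize and cast checks the deleted code performed.
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  ET_SWITCH_INT_TYPES_AND(Bool, a.scalar_type(), ctx, op_name, CTYPE_A, [&]() {
    ET_SWITCH_INT_TYPES_AND(
        Bool, b.scalar_type(), ctx, op_name, CTYPE_B, [&]() {
          ET_SWITCH_INT_TYPES_AND(
              Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
                ET_SWITCH_REAL_TYPES_AND(
                    Bool, out.scalar_type(), ctx, op_name, CTYPE_OUT, [&]() {
                      apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
                          [](const CTYPE_A val_a, const CTYPE_B val_b) {
                            CTYPE_IN value = BitwiseFn<CTYPE_IN>()(
                                static_cast<CTYPE_IN>(val_a),
                                static_cast<CTYPE_IN>(val_b));
                            return static_cast<CTYPE_OUT>(value);
                          },
                          a,
                          b,
                          out);
                    });
              });
        });
  });
  return out;
}
```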
diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index 06c87d03f2d..50d7e8c374d 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -53,7 +53,7 @@ __ET_NODISCARD bool check_bounds(
       }
     });
   } else if (isFloatingType(out_type)) {
-    ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
+    ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
       if (std::isfinite(val) &&
           is_out_of_bounds(val)) {
         ET_LOG(Error, "%s value out of bounds", val_name);
@@ -119,7 +119,7 @@ Tensor& clamp_out(
 
   ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
+  ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() {
     // Extract optional min value
     CTYPE_OUT min = 0;
     if (has_min) {
@@ -140,7 +140,7 @@ Tensor& clamp_out(
     });
   }
 
-  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() {
+  ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() {
     apply_unary_map_fn(
         [has_min, min, has_max, max](const CTYPE_IN val_in) {
           CTYPE_OUT val_out = static_cast<CTYPE_OUT>(val_in);
@@ -195,20 +195,20 @@ Tensor& clamp_tensor_out(
   ScalarType out_type = out.scalar_type();
 
   if (has_min) {
-    common_type = promoteTypes(common_type, min_type);
+    common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true);
   }
   if (has_max) {
-    common_type = promoteTypes(common_type, max_type);
+    common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true);
   }
 
   ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
 
   constexpr auto name = "clamp.Tensor_out";
 
-  ET_SWITCH_REALB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
-    ET_SWITCH_REALB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
-      ET_SWITCH_REALB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
-        ET_SWITCH_REALB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+  ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() {
+    ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() {
+      ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() {
+        ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
           apply_ternary_elementwise_fn<
               CTYPE_IN,
               CTYPE_MIN,
diff --git a/kernels/portable/cpu/op_floor_divide.cpp b/kernels/portable/cpu/op_floor_divide.cpp
index 261f77ce617..0514df0ca25 100644
--- a/kernels/portable/cpu/op_floor_divide.cpp
+++ b/kernels/portable/cpu/op_floor_divide.cpp
@@ -20,6 +20,60 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 using ScalarType = exec_aten::ScalarType;
 
+namespace {
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FloorDivideInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FloorDivideInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
+          if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return static_cast<CTYPE_OUT>(0);
+            }
+          }
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = utils::floor_divide(a_casted, b_casted);
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FloorDivideInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
+
+} // namespace
+
 Tensor& floor_divide_out(
     RuntimeContext& ctx,
     const Tensor& a,
@@ -46,36 +100,17 @@ Tensor& floor_divide_out(
       Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() {
         ET_SWITCH_REAL_TYPES_AND(
             Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() {
+              using CTYPE_IN = typename torch::executor::
+                  promote_types<CTYPE_A, CTYPE_B>::type;
+              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
               ET_SWITCH_REAL_TYPES(
-                  common_type, ctx, "floor_divide.out", CTYPE_IN, [&]() {
-                    ET_SWITCH_REAL_TYPES(
-                        out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [common_type, &div_by_zero_error](
-                                  const CTYPE_A val_a, const CTYPE_B val_b) {
-                                if (isIntegralType(
-                                        common_type, /*includeBool=*/true)) {
-                                  if (val_b == 0) {
-                                    div_by_zero_error = true;
-                                    return static_cast<CTYPE_OUT>(0);
-                                  }
-                                }
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = utils::floor_divide(
-                                    a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
+                  out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
+                    FloorDivideInner<
+                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
+                        CTYPE_A,
+                        CTYPE_B,
+                        CTYPE_IN,
+                        CTYPE_OUT>::run(a, b, out, div_by_zero_error);
                   });
             });
       });
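Aside (illustration only): FloorDivideInner, like FmodInner, MaximumInner, MinimumInner, and RemainderInner below, follows one idiom. A bool non-type template parameter, fed from can_cast<CTYPE_IN, CTYPE_OUT>::value, selects either the real implementation or a stub inherited from ReportCanCastBug, so type combinations that the runtime canCast() check already rejects still compile without ever instantiating the arithmetic or the casts. A minimal self-contained version of the idiom, in plain C++ with names invented for the example:

```cpp
#include <cstdio>
#include <type_traits>

// Primary template, selected on a compile-time bool, mirroring
// FooInner<can_cast<CTYPE_IN, CTYPE_OUT>::value, ...> in the kernels.
template <bool can_cast, typename IN, typename OUT>
struct Inner;

// Real implementation: only instantiated when the cast is allowed.
template <typename IN, typename OUT>
struct Inner<true, IN, OUT> {
  static OUT run(IN value) {
    return static_cast<OUT>(value);
  }
};

// Stub that every disallowed combination inherits; reaching it at runtime
// would mean the earlier runtime check was skipped.
struct ReportBug {
  template <typename IN>
  static int run(IN /*unused*/) {
    std::fprintf(stderr, "BUG: cast should have been rejected earlier\n");
    return 0;
  }
};

template <typename IN, typename OUT>
struct Inner<false, IN, OUT> : public ReportBug {};

int main() {
  constexpr bool ok = std::is_convertible<double, int>::value;
  std::printf("%d\n", Inner<ok, double, int>::run(3.7)); // prints 3
  return 0;
}
```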
diff --git a/kernels/portable/cpu/op_fmod.cpp b/kernels/portable/cpu/op_fmod.cpp
index 0083c1379d5..42f83731199 100644
--- a/kernels/portable/cpu/op_fmod.cpp
+++ b/kernels/portable/cpu/op_fmod.cpp
@@ -19,6 +19,60 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
+namespace {
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FmodInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FmodInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
+          if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return static_cast<CTYPE_OUT>(0);
+            }
+          }
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = std::fmod(a_casted, b_casted);
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FmodInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
+
+} // namespace
+
 Tensor& fmod_Tensor_out(
     RuntimeContext& ctx,
     const Tensor& a,
@@ -44,35 +98,18 @@ Tensor& fmod_Tensor_out(
       Bool, a_type, ctx, "fmod.Tensor_out", CTYPE_A, [&]() {
         ET_SWITCH_REAL_TYPES_AND(
             Bool, b_type, ctx, "fmod.Tensor_out", CTYPE_B, [&]() {
+              using CTYPE_IN = typename torch::executor::
+                  promote_types<CTYPE_A, CTYPE_B>::type;
+              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
               ET_SWITCH_REAL_TYPES(
-                  common_type, ctx, "fmod.Tensor_out", CTYPE_IN, [&]() {
-                    ET_SWITCH_REAL_TYPES(
-                        out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [common_type, &div_by_zero_error](
-                                  const CTYPE_A val_a, const CTYPE_B val_b) {
-                                if (isIntegralType(
-                                        common_type, /*includeBool=*/true)) {
-                                  if (val_b == 0) {
-                                    div_by_zero_error = true;
-                                    return static_cast<CTYPE_OUT>(0);
-                                  }
-                                }
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = std::fmod(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
+                  out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
+                    FmodInner<
+                        !std::is_same<CTYPE_IN, bool>::value &&
+                            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+                        CTYPE_A,
+                        CTYPE_B,
+                        CTYPE_IN,
+                        CTYPE_OUT>::run(a, b, out, div_by_zero_error);
                   });
             });
       });
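Aside (illustration only): the new `using CTYPE_IN = ... promote_types<...>::type` line together with `ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type)` asserts that the compile-time promotion agrees with the runtime promoteTypes() result, so the accumulation type can be fixed at compile time without changing behavior. (The extra `!std::is_same` condition on FmodInner, whose template arguments are reconstructed here as `<CTYPE_IN, bool>`, appears to keep the bool/bool case on the stub path rather than instantiating std::fmod on bool.) A plain-C++ sketch of the promotion invariant, with std::common_type standing in for the ExecuTorch promotion helpers:

```cpp
#include <cstdint>
#include <type_traits>

// Standalone sketch (not the ExecuTorch implementation) of the invariant the
// new ET_DCHECK lines express: the compile-time promoted type must match what
// the runtime promotion picked, or the kernel would compute in the wrong type.
static_assert(
    std::is_same<std::common_type<int32_t, float>::type, float>::value,
    "int32 op float should promote to float");
static_assert(
    std::is_same<std::common_type<bool, int64_t>::type, int64_t>::value,
    "bool op int64 should promote to int64");

int main() {
  return 0;
}
```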
diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp
index 3e34035d5f6..1353479b294 100644
--- a/kernels/portable/cpu/op_maximum.cpp
+++ b/kernels/portable/cpu/op_maximum.cpp
@@ -8,6 +8,7 @@
 
 #include
 #include
+#include
 #include
 
 namespace torch {
@@ -15,10 +16,49 @@
 namespace executor {
 namespace native {
 namespace {
 
-template <typename T>
-const T& max(const T& a, const T& b) {
-  return (b > a) ? b : a;
-}
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MaximumInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = utils::max_override(a_casted, b_casted);
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
 
 } // namespace
 
@@ -44,20 +84,16 @@ Tensor& maximum_out(
 
   ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() {
     ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() {
-      ET_SWITCH_REALB_TYPES(common_type, ctx, "maximum.out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = max(a_casted, b_casted);
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
-        });
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+      ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
+        MaximumInner<
+            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+            CTYPE_A,
+            CTYPE_B,
+            CTYPE_IN,
+            CTYPE_OUT>::run(a, b, out);
       });
     });
   });
diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp
index 767a2c4ca59..f18d1a6d368 100644
--- a/kernels/portable/cpu/op_minimum.cpp
+++ b/kernels/portable/cpu/op_minimum.cpp
@@ -8,6 +8,7 @@
 
 #include
 #include
+#include
 #include
 
 namespace torch {
@@ -15,10 +16,49 @@
 namespace executor {
 namespace native {
 namespace {
 
-template <typename T>
-const T& min(const T& a, const T& b) {
-  return (b < a) ? b : a;
-}
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MinimumInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MinimumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = utils::min_override(a_casted, b_casted);
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MinimumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
 
 } // namespace
 
@@ -37,30 +77,24 @@ Tensor& minimum_out(
 
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
+  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
   ScalarType out_type = out.scalar_type();
 
   ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
 
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, common_type, ctx, "minimum.out", CTYPE_IN, [&]() {
-            ET_SWITCH_REAL_TYPES_AND(
-                Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
-                  apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                      [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                        CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                        CTYPE_IN value = min(a_casted, b_casted);
-
-                        return static_cast<CTYPE_OUT>(value);
-                      },
-                      a,
-                      b,
-                      out);
-                });
-          });
+  ET_SWITCH_REALHB_TYPES(a_type, ctx, "minimum.out", CTYPE_A, [&]() {
+    ET_SWITCH_REALHB_TYPES(b_type, ctx, "minimum.out", CTYPE_B, [&]() {
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+      ET_SWITCH_REALHB_TYPES(out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
+        MinimumInner<
+            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+            CTYPE_A,
+            CTYPE_B,
+            CTYPE_IN,
+            CTYPE_OUT>::run(a, b, out);
+      });
     });
   });
diff --git a/kernels/portable/cpu/op_remainder.cpp b/kernels/portable/cpu/op_remainder.cpp
index 9e48374a81a..7c858c1c08a 100644
--- a/kernels/portable/cpu/op_remainder.cpp
+++ b/kernels/portable/cpu/op_remainder.cpp
@@ -20,6 +20,52 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
+namespace {
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct RemainderInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct RemainderInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = utils::remainder_override(a_casted, b_casted);
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct RemainderInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
+
+} // namespace
 Tensor& remainder_Tensor_out(
     RuntimeContext& ctx,
     const Tensor& a,
@@ -45,32 +91,17 @@ Tensor& remainder_Tensor_out(
       Bool, a_type, ctx, "remainder.Tensor_out", CTYPE_A, [&]() {
         ET_SWITCH_REAL_TYPES_AND(
             Bool, b_type, ctx, "remainder.Tensor_out", CTYPE_B, [&]() {
+              using CTYPE_IN = typename torch::executor::
+                  promote_types<CTYPE_A, CTYPE_B>::type;
+              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
               ET_SWITCH_REAL_TYPES(
-                  common_type, ctx, "remainder.Tensor_out", CTYPE_IN, [&]() {
-                    ET_SWITCH_REAL_TYPES(
-                        out_type,
-                        ctx,
-                        "remainder.Tensor_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = utils::remainder_override(
-                                    a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
+                  out_type, ctx, "remainder.Tensor_out", CTYPE_OUT, [&]() {
+                    RemainderInner<
+                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
+                        CTYPE_A,
+                        CTYPE_B,
+                        CTYPE_IN,
+                        CTYPE_OUT>::run(a, b, out);
                   });
             });
       });
diff --git a/kernels/portable/cpu/pattern/bitwise_op.h b/kernels/portable/cpu/pattern/bitwise_op.h
new file mode 100644
index 00000000000..4a338bbf020
--- /dev/null
+++ b/kernels/portable/cpu/pattern/bitwise_op.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace torch {
+namespace executor {
+namespace native {
+template