@@ -19,6 +19,59 @@ namespace native {
19
19
20
20
using Tensor = exec_aten::Tensor;
21
21
22
+ namespace {
23
+ template <
24
+ bool can_cast,
25
+ typename CTYPE_A,
26
+ typename CTYPE_B,
27
+ typename CTYPE_IN,
28
+ typename CTYPE_OUT>
29
+ struct FmodInner ;
30
+
31
+ template <
32
+ typename CTYPE_A,
33
+ typename CTYPE_B,
34
+ typename CTYPE_IN,
35
+ typename CTYPE_OUT>
36
+ struct FmodInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37
+ static void
38
+ run (const Tensor& a, const Tensor& b, Tensor& out, bool & div_by_zero_error) {
39
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
40
+ [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
41
+ if (is_integral_type<CTYPE_IN, /* includeBool=*/ true >::value) {
42
+ if (val_b == 0 ) {
43
+ div_by_zero_error = true ;
44
+ return static_cast <CTYPE_OUT>(0 );
45
+ }
46
+ }
47
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
48
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
49
+ CTYPE_IN value = std::fmod (a_casted, b_casted);
50
+
51
+ return static_cast <CTYPE_OUT>(value);
52
+ },
53
+ a,
54
+ b,
55
+ out);
56
+ }
57
+ };
58
+
59
+ struct ReportCanCastBug {
60
+ static void run (const Tensor&, const Tensor&, Tensor&, bool &) {
61
+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
62
+ }
63
+ };
64
+
65
+ template <
66
+ typename CTYPE_A,
67
+ typename CTYPE_B,
68
+ typename CTYPE_IN,
69
+ typename CTYPE_OUT>
70
+ struct FmodInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
71
+ : public ReportCanCastBug {};
72
+
73
+ } // namespace
74
+
22
75
Tensor& fmod_Tensor_out (
23
76
RuntimeContext& ctx,
24
77
const Tensor& a,
@@ -44,35 +97,18 @@ Tensor& fmod_Tensor_out(
44
97
Bool, a_type, ctx, " fmod.Tensor_out" , CTYPE_A, [&]() {
45
98
ET_SWITCH_REAL_TYPES_AND (
46
99
Bool, b_type, ctx, " fmod.Tensor_out" , CTYPE_B, [&]() {
100
+ using CTYPE_IN = typename torch::executor::
101
+ promote_types<CTYPE_A, CTYPE_B>::type;
102
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
47
103
ET_SWITCH_REAL_TYPES (
48
- common_type, ctx, " fmod.Tensor_out" , CTYPE_IN, [&]() {
49
- ET_SWITCH_REAL_TYPES (
50
- out_type, ctx, " fmod.Tensor_out" , CTYPE_OUT, [&]() {
51
- apply_binary_elementwise_fn<
52
- CTYPE_A,
53
- CTYPE_B,
54
- CTYPE_OUT>(
55
- [common_type, &div_by_zero_error](
56
- const CTYPE_A val_a, const CTYPE_B val_b) {
57
- if (isIntegralType (
58
- common_type, /* includeBool=*/ true )) {
59
- if (val_b == 0 ) {
60
- div_by_zero_error = true ;
61
- return static_cast <CTYPE_OUT>(0 );
62
- }
63
- }
64
- CTYPE_IN a_casted =
65
- static_cast <CTYPE_IN>(val_a);
66
- CTYPE_IN b_casted =
67
- static_cast <CTYPE_IN>(val_b);
68
- CTYPE_IN value = std::fmod (a_casted, b_casted);
69
-
70
- return static_cast <CTYPE_OUT>(value);
71
- },
72
- a,
73
- b,
74
- out);
75
- });
104
+ out_type, ctx, " fmod.Tensor_out" , CTYPE_OUT, [&]() {
105
+ FmodInner<
106
+ !std::is_same<CTYPE_IN, bool >::value &&
107
+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
108
+ CTYPE_A,
109
+ CTYPE_B,
110
+ CTYPE_IN,
111
+ CTYPE_OUT>::run (a, b, out, div_by_zero_error);
76
112
});
77
113
});
78
114
});
0 commit comments