@@ -35,45 +35,8 @@ Tensor& ge_scalar_out(
35
35
const Tensor& a,
36
36
const Scalar& b,
37
37
Tensor& out) {
38
- (void )ctx;
39
-
40
- // Resize for dynamic shape
41
- ET_KERNEL_CHECK_MSG (
42
- ctx,
43
- resize_tensor (out, a.sizes ()) == Error::Ok,
44
- InvalidArgument,
45
- out,
46
- " Failed to resize output tensor." );
47
-
48
- ScalarType a_type = a.scalar_type ();
49
- ScalarType b_type = utils::get_scalar_dtype (b);
50
- ScalarType common_type = utils::promote_type_with_scalar (a_type, b);
51
- ScalarType out_type = out.scalar_type ();
52
-
53
- ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " ge.Scalar_out" , CTYPE_A, [&]() {
54
- ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, " ge.Scalar_out" , CTYPE_B, [&]() {
55
- ET_SWITCH_REAL_TYPES_AND (
56
- Bool, common_type, ctx, " ge.Scalar_out" , CTYPE_IN, [&]() {
57
- ET_SWITCH_REAL_TYPES_AND (
58
- Bool, out_type, ctx, " ge.Scalar_out" , CTYPE_OUT, [&]() {
59
- CTYPE_B val_b = 0 ;
60
- utils::extract_scalar (b, &val_b);
61
- apply_unary_map_fn (
62
- [val_b](const CTYPE_A val_a) {
63
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
64
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
65
- bool value = a_casted >= b_casted;
66
- return static_cast <CTYPE_OUT>(value);
67
- },
68
- a.const_data_ptr <CTYPE_A>(),
69
- out.mutable_data_ptr <CTYPE_OUT>(),
70
- out.numel ());
71
- });
72
- });
73
- });
74
- });
75
-
76
- return out;
38
+ return scalar_comparison_op_with_scalar_promotion_out<std::greater_equal>(
39
+ ctx, a, b, out, " ge.Scalar_out" );
77
40
}
78
41
79
42
} // namespace native
0 commit comments