diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 93717e3b02ef0..36208109e30c3 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" @@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern matchAndRewrite(arith::ConstantOp arithConst, arith::ConstantOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - arithConst, arithConst.getType(), adaptor.getValue()); + Type newTy = this->getTypeConverter()->convertType(arithConst.getType()); + if (!newTy) + return rewriter.notifyMatchFailure(arithConst, "type conversion failed"); + rewriter.replaceOpWithNewOp(arithConst, newTy, + adaptor.getValue()); return success(); } }; @@ -52,6 +56,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) { return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(), signedness); } + } else if (emitc::isPointerWideType(ty)) { + if (isa(ty) != needsUnsigned) { + if (needsUnsigned) + return emitc::SizeTType::get(ty.getContext()); + return emitc::PtrDiffTType::get(ty.getContext()); + } } return ty; } @@ -264,8 +274,9 @@ class CmpIOpConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type type = adaptor.getLhs().getType(); - if (!isa_and_nonnull(type)) { - return rewriter.notifyMatchFailure(op, "expected integer or index type"); + if (!type || !(isa(type) || emitc::isPointerWideType(type))) { + return rewriter.notifyMatchFailure( + op, "expected integer or size_t/ssize_t/ptrdiff_t type"); } bool needsUnsigned = needsUnsignedCmp(op.getPredicate()); @@ -318,8 +329,10 @@ class CastConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type opReturnType = this->getTypeConverter()->convertType(op.getType()); - if (!isa_and_nonnull(opReturnType)) - return rewriter.notifyMatchFailure(op, "expected integer result type"); + if (!opReturnType || !(isa(opReturnType) || + emitc::isPointerWideType(opReturnType))) + return rewriter.notifyMatchFailure( + op, "expected integer or size_t/ssize_t/ptrdiff_t result type"); if (adaptor.getOperands().size() != 1) { return rewriter.notifyMatchFailure( @@ -327,8 +340,10 @@ class CastConversion : public OpConversionPattern { } Type operandType = adaptor.getIn().getType(); - if (!isa_and_nonnull(operandType)) - return rewriter.notifyMatchFailure(op, "expected integer operand type"); + if (!operandType || !(isa(operandType) || + emitc::isPointerWideType(operandType))) + return rewriter.notifyMatchFailure( + op, "expected integer or size_t/ssize_t/ptrdiff_t operand type"); // Signed (sign-extending) casts from i1 are not supported. if (operandType.isInteger(1) && !castToUnsigned) @@ -339,8 +354,11 @@ class CastConversion : public OpConversionPattern { // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives // truncation. if (opReturnType.isInteger(1)) { + Type attrType = (emitc::isPointerWideType(operandType)) + ? rewriter.getIndexType() + : operandType; auto constOne = rewriter.create( - op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1)); + op.getLoc(), operandType, rewriter.getOneAttr(attrType)); auto oneAndOperand = rewriter.create( op.getLoc(), operandType, adaptor.getIn(), constOne); rewriter.replaceOpWithNewOp(op, opReturnType, @@ -393,7 +411,11 @@ class ArithOpConversion final : public OpConversionPattern { matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.template replaceOpWithNewOp(arithOp, arithOp.getType(), + Type newTy = this->getTypeConverter()->convertType(arithOp.getType()); + if (!newTy) + return rewriter.notifyMatchFailure(arithOp, + "converting result type failed"); + rewriter.template replaceOpWithNewOp(arithOp, newTy, adaptor.getOperands()); return success(); @@ -410,8 +432,9 @@ class IntegerOpConversion final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type type = this->getTypeConverter()->convertType(op.getType()); - if (!isa_and_nonnull(type)) { - return rewriter.notifyMatchFailure(op, "expected integer type"); + if (!type || !(isa(type) || emitc::isPointerWideType(type))) { + return rewriter.notifyMatchFailure( + op, "expected integer or size_t/ssize_t/ptrdiff_t type"); } if (type.isInteger(1)) { @@ -482,6 +505,89 @@ class BitwiseOpConversion : public OpConversionPattern { } }; +template +class ShiftOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type type = this->getTypeConverter()->convertType(op.getType()); + if (!type || !(isa(type) || emitc::isPointerWideType(type))) { + return rewriter.notifyMatchFailure( + op, "expected integer or size_t/ssize_t/ptrdiff_t type"); + } + + if (type.isInteger(1)) { + return rewriter.notifyMatchFailure(op, "i1 type is not implemented"); + } + + Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp); + + Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); + // Shift amount interpreted as unsigned per Arith dialect spec. + Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(), + /*needsUnsigned=*/true); + Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType); + + // Add a runtime check for overflow + Value width; + if (emitc::isPointerWideType(type)) { + Value eight = rewriter.create( + op.getLoc(), rhsType, rewriter.getIndexAttr(8)); + emitc::CallOpaqueOp sizeOfCall = rewriter.create( + op.getLoc(), rhsType, "sizeof", ArrayRef{eight}); + width = rewriter.create(op.getLoc(), rhsType, eight, + sizeOfCall.getResult(0)); + } else { + width = rewriter.create( + op.getLoc(), rhsType, + rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); + } + + Value excessCheck = rewriter.create( + op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); + + // Any concrete value is a valid refinement of poison. + Value poison = rewriter.create( + op.getLoc(), arithmeticType, + (isa(arithmeticType) + ? rewriter.getIntegerAttr(arithmeticType, 0) + : rewriter.getIndexAttr(0))); + + emitc::ExpressionOp ternary = rewriter.create( + op.getLoc(), arithmeticType, /*do_not_inline=*/false); + Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); + auto currentPoint = rewriter.getInsertionPoint(); + rewriter.setInsertionPointToStart(&bodyBlock); + Value arithmeticResult = + rewriter.create(op.getLoc(), arithmeticType, lhs, rhs); + Value resultOrPoison = rewriter.create( + op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); + rewriter.create(op.getLoc(), resultOrPoison); + rewriter.setInsertionPoint(op->getBlock(), currentPoint); + + Value result = adaptValueType(ternary, rewriter, type); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +template +class SignedShiftOpConversion final + : public ShiftOpConversion { + using ShiftOpConversion::ShiftOpConversion; +}; + +template +class UnsignedShiftOpConversion final + : public ShiftOpConversion { + using ShiftOpConversion::ShiftOpConversion; +}; + class SelectOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -606,6 +712,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); + mlir::populateEmitCSizeTTypeConversions(typeConverter); + // clang-format off patterns.add< ArithConstantOpConversionPattern, @@ -621,6 +729,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, BitwiseOpConversion, BitwiseOpConversion, BitwiseOpConversion, + UnsignedShiftOpConversion, + SignedShiftOpConversion, + UnsignedShiftOpConversion, CmpFOpConversion, CmpIOpConversion, NegFOpConversion, @@ -629,6 +740,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, UnsignedCastConversion, SignedCastConversion, UnsignedCastConversion, + SignedCastConversion, + UnsignedCastConversion, ItoFCastOpConversion, ItoFCastOpConversion, FtoICastOpConversion, diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt index a3784f47c3bc2..730a4b341673d 100644 --- a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt +++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRArithToEmitC LINK_LIBS PUBLIC MLIRArithDialect MLIREmitCDialect + MLIREmitCTransforms MLIRPass MLIRTransformUtils ) diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index caef04052aa8c..766ad4039335e 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -110,3 +110,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) { %idx = arith.extsi %arg0 : i1 to i32 return } + +// ----- + +func.func @arith_shli_i1(%arg0: i1, %arg1: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.shli'}} + %shli = arith.shli %arg0, %arg1 : i1 + return +} + +// ----- + +func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.shrsi'}} + %shrsi = arith.shrsi %arg0, %arg1 : i1 + return +} + +// ----- + +func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.shrui'}} + %shrui = arith.shrui %arg0, %arg1 : i1 + return +} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 0289b7dc0728f..858ccd1171445 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -3,7 +3,8 @@ // CHECK-LABEL: arith_constants func.func @arith_constants() { // CHECK: emitc.constant - // CHECK-SAME: value = 0 : index + // CHECK-SAME: value = 0 + // CHECK-SAME: () -> !emitc.size_t %c_index = arith.constant 0 : index // CHECK: emitc.constant // CHECK-SAME: value = 0 : i32 @@ -75,13 +76,18 @@ func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) { // ----- // CHECK-LABEL: arith_index -func.func @arith_index(%arg0: index, %arg1: index) { - // CHECK: emitc.add %arg0, %arg1 : (index, index) -> index - %0 = arith.addi %arg0, %arg1 : index - // CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index - %1 = arith.subi %arg0, %arg1 : index - // CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index - %2 = arith.muli %arg0, %arg1 : index +func.func @arith_index(%arg0: i32, %arg1: i32) { + // CHECK: %[[CST0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t + %cst0 = arith.index_cast %arg0 : i32 to index + // CHECK: %[[CST1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t + %cst1 = arith.index_cast %arg1 : i32 to index + + // CHECK: emitc.add %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + %0 = arith.addi %cst0, %cst1 : index + // CHECK: emitc.sub %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + %1 = arith.subi %cst0, %cst1 : index + // CHECK: emitc.mul %[[CST0]], %[[CST1]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + %2 = arith.muli %cst0, %cst1 : index return } @@ -138,6 +144,116 @@ func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) { // ----- +// CHECK-LABEL: arith_shift_left +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func.func @arith_shift_left(%arg0: i32, %arg1: i32) { + // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32 + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0 + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 + // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32 + // CHECK: emitc.yield %[[Ternary]] : ui32 + // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32 + %1 = arith.shli %arg0, %arg1 : i32 + return +} + +// ----- + +// CHECK-LABEL: arith_shift_right +// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 +func.func @arith_shift_right(%arg0: i32, %arg1: i32) { + // CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32 + // CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32 + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32 + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32 + // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32 + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32 + // CHECK: emitc.yield %[[Ternary]] : ui32 + // CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32 + %2 = arith.shrui %arg0, %arg1 : i32 + + // CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32 + // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32 + // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1 + // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32 + // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32 + // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32 + // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32 + // CHECK: emitc.yield %[[STernary]] : i32 + %3 = arith.shrsi %arg0, %arg1 : i32 + + return +} + +// ----- + +// CHECK-LABEL: arith_shift_left_index +// CHECK-SAME: %[[AMOUNT:.*]]: i32 +func.func @arith_shift_left_index(%amount: i32) { + %cst0 = "arith.constant"() {value = 42 : index} : () -> (index) + %cast1 = arith.index_cast %amount : i32 to index + // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t + // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t + // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t + // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index + // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0 + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t + // CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t + // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t + %1 = arith.shli %cst0, %cast1 : index + return +} + +// ----- + +// CHECK-LABEL: arith_shift_right_index +// CHECK-SAME: %[[AMOUNT:.*]]: i32 +func.func @arith_shift_right_index(%amount: i32) { + // CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t + // CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ptrdiff_t + // CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ptrdiff_t to !emitc.size_t + %arg0 = "arith.constant"() {value = 42 : index} : () -> (index) + %arg1 = arith.index_cast %amount : i32 to index + + // CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index + // CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 + // CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t + // CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t + // CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t + // CHECK: emitc.yield %[[Ternary]] : !emitc.size_t + %2 = arith.shrui %arg0, %arg1 : index + + // CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ptrdiff_t + // CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t + // CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1 + // CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ptrdiff_t + // CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ptrdiff_t + // CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ptrdiff_t, !emitc.size_t) -> !emitc.ptrdiff_t + // CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ptrdiff_t + // CHECK: emitc.yield %[[STernary]] : !emitc.ptrdiff_t + // CHECK: emitc.cast %[[SShiftRes]] : !emitc.ptrdiff_t to !emitc.size_t + %3 = arith.shrsi %arg0, %arg1 : index + + return +} + +// ----- + func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () { // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32> %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32> @@ -420,6 +536,27 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) { return } +func.func @arith_cmpi_index(%arg0: i32, %arg1: i32) -> i1 { + // CHECK-LABEL: arith_cmpi_index + + // CHECK: %[[Cst0:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t + %idx0 = arith.index_cast %arg0 : i32 to index + // CHECK: %[[Cst1:.*]] = emitc.cast %{{.*}} : {{.*}} to !emitc.size_t + %idx1 = arith.index_cast %arg0 : i32 to index + + // CHECK-DAG: [[ULT:[^ ]*]] = emitc.cmp lt, %[[Cst0]], %[[Cst1]] : (!emitc.size_t, !emitc.size_t) -> i1 + %ult = arith.cmpi ult, %idx0, %idx1 : index + + // CHECK-DAG: %[[CastArg0:[^ ]*]] = emitc.cast %[[Cst0]] : !emitc.size_t to !emitc.ptrdiff_t + // CHECK-DAG: %[[CastArg1:[^ ]*]] = emitc.cast %[[Cst1]] : !emitc.size_t to !emitc.ptrdiff_t + // CHECK-DAG: %[[SLT:[^ ]*]] = emitc.cmp lt, %[[CastArg0]], %[[CastArg1]] : (!emitc.ptrdiff_t, !emitc.ptrdiff_t) -> i1 + %slt = arith.cmpi slt, %idx0, %idx1 : index + + // CHECK: return %[[SLT]] + return %slt: i1 +} + + // ----- func.func @arith_negf(%arg0: f32) -> f32 { @@ -536,3 +673,47 @@ func.func @arith_extui_i1_to_i32(%arg0: i1) { %idx = arith.extui %arg0 : i1 to i32 return } + +// ----- + +func.func @arith_index_cast(%arg0: i32) -> i32 { + // CHECK-LABEL: arith_index_cast + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to !emitc.ptrdiff_t + // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : !emitc.ptrdiff_t to !emitc.size_t + %idx = arith.index_cast %arg0 : i32 to index + // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to !emitc.ptrdiff_t + // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : !emitc.ptrdiff_t to i32 + %int = arith.index_cast %idx : index to i32 + + // CHECK: %[[Const:.*]] = "emitc.constant" + // CHECK-SAME: value = 1 + // CHECK-SAME: () -> !emitc.size_t + // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1 + %bool = arith.index_cast %idx : index to i1 + + return %int : i32 +} + +// ----- + +func.func @arith_index_castui(%arg0: i32) -> i32 { + // CHECK-LABEL: arith_index_castui + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 + // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to !emitc.size_t + %idx = arith.index_castui %arg0 : i32 to index + // CHECK: %[[Conv2:.*]] = emitc.cast %[[Conv1]] : !emitc.size_t to ui32 + // CHECK: %[[Conv3:.*]] = emitc.cast %[[Conv2]] : ui32 to i32 + %int = arith.index_castui %idx : index to i32 + + // CHECK: %[[Const:.*]] = "emitc.constant" + // CHECK-SAME: value = 1 + // CHECK-SAME: () -> !emitc.size_t + // CHECK: %[[AndOne:.*]] = emitc.bitwise_and %[[Conv1]], %[[Const]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK: %[[Conv4:.*]] = emitc.cast %[[AndOne]] : !emitc.size_t to i1 + %bool = arith.index_castui %idx : index to i1 + + return %int : i32 +}