diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2f6e869ae7b73..51b3627056fbb 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7177,10 +7177,33 @@ llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
   return CR1.intersectWith(CR2, RangeType);
 }
 
+static bool isKnownToNotOverflowFromAssume(const SimplifyQuery &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return false;
+
+  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(Q.CxtI)) {
+    if (!Elem.Assume)
+      continue;
+
+    AssumeInst *I = cast<AssumeInst>(Elem.Assume);
+    if (match(I->getArgOperand(0),
+              m_Not(m_ExtractValue<1>(m_Specific(Q.CxtI)))) &&
+        isValidAssumeForContext(I, Q.CxtI, /*DT=*/nullptr,
+                                /*AllowEphemerals=*/true))
+      return true;
+  }
+  return false;
+}
+
 OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
                                                    const Value *RHS,
                                                    const SimplifyQuery &SQ,
                                                    bool IsNSW) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
   KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
 
@@ -7196,6 +7219,9 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
 OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
                                                  const Value *RHS,
                                                  const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -7236,6 +7262,9 @@ OverflowResult
 llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
                                     const WithCache<const Value *> &RHS,
                                     const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   ConstantRange LHSRange =
       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
   ConstantRange RHSRange =
@@ -7251,6 +7280,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
     return OverflowResult::NeverOverflows;
   }
 
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // If LHS and RHS each have at least two sign bits, the addition will look
   // like
   //
@@ -7305,6 +7337,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                    const Value *RHS,
                                                    const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -7338,6 +7373,9 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                  const Value *RHS,
                                                  const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -10180,6 +10218,9 @@ void llvm::findValuesAffectedByCondition(
                                m_Value()))) {
       // Handle patterns that computeKnownFPClass() support.
       AddAffected(A);
+    } else if (IsAssume && match(V, m_Not(m_ExtractValue<1>(m_Value(A)))) &&
+               isa<WithOverflowInst>(A)) {
+      AddAffected(A);
     }
   }
 }
diff --git a/llvm/test/Transforms/InstCombine/with_overflow.ll b/llvm/test/Transforms/InstCombine/with_overflow.ll
index fa810408730e1..dd64e993bb620 100644
--- a/llvm/test/Transforms/InstCombine/with_overflow.ll
+++ b/llvm/test/Transforms/InstCombine/with_overflow.ll
@@ -1064,3 +1064,132 @@ define i8 @smul_7(i8 %x, ptr %p) {
   store i1 %ov, ptr %p
   ret i8 %r
 }
+
+define i8 @uadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @uadd_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = add nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @sadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @sadd_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = add nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.sadd.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @usub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @usub_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = sub nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.usub.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @ssub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @ssub_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.ssub.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @umul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @umul_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = mul nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @smul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @smul_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = mul nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i1 @ephemeral_call_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @ephemeral_call_assume_no_overflow(
+; CHECK-NEXT:    ret i1 true
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i1 %not
+}
+
+define i8 @neg_assume_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @neg_assume_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
+; CHECK-NEXT:    [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
+; CHECK-NEXT:    call void @llvm.assume(i1 [[OVERFLOW]])
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  call void @llvm.assume(i1 %overflow)
+  ret i8 %ret
+}
+
+define i8 @neg_assume_not_guaranteed_to_execute(i8 noundef %a, i8 noundef %b, i1 %cond) {
+; CHECK-LABEL: @neg_assume_not_guaranteed_to_execute(
+; CHECK-NEXT:    [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
+; CHECK-NEXT:    [[NOT:%.*]] = xor i1 [[OVERFLOW]], true
+; CHECK-NEXT:    call void @llvm.assume(i1 [[NOT]])
+; CHECK-NEXT:    br label [[BB2]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  br i1 %cond, label %bb1, label %bb2
+bb1:
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  br label %bb2
+bb2:
+  ret i8 %ret
+}