Skip to content

[ValueTracking] Use assume to compute overflowResult. #121665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from

Conversation

andjo403
Copy link
Contributor

@andjo403 andjo403 commented Jan 4, 2025

Rust code have a lot of assumes with the overflow flag from WithOverflow Instructions as condition, this instruction can due to the assume be folded to the operation without overflow check.

e.g. the pattern that this PR look for:

  %call = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %a, i8 %b)
  %overflow = extractvalue { i8, i1 } %call, 1
  %not = xor i1 %overflow, true
  call void @llvm.assume(i1 %not)

that can be folded to:
%call = add nuw i8 %a, %b

It is the hashbrown crate that have e.g. this match with use of hint::unreachable_unchecked() together with the use of checked instructions in calculate_layout_for that after inlining and simplifyCfg result in this pattern.

Prof: https://alive2.llvm.org/ce/z/oRu3oi

@llvmbot
Copy link
Member

llvmbot commented Jan 4, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Andreas Jonson (andjo403)

Changes

Rust code have a lot of assumes with the overflow flag from WithOverflow Instructions as condition, this instruction can due to the assume be folded to the operation without overflow check.

e.g. the pattern that this PR look for:

  %call = call { i8, i1 } @<!-- -->llvm.uadd.with.overflow.i8(i8 %a, i8 %b)
  %overflow = extractvalue { i8, i1 } %call, 1
  %not = xor i1 %overflow, true
  call void @<!-- -->llvm.assume(i1 %not)

that can be folded to:
%call = add nuw i8 %a, %b

It is the hashbrown crate that have e.g. this match with use of hint::unreachable_unchecked() together with the use of checked instructions in calculate_layout_for that after inlining and simplifyCfg result in this pattern.

Prof: https://alive2.llvm.org/ce/z/oRu3oi


Full diff: https://github.com/llvm/llvm-project/pull/121665.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+41)
  • (modified) llvm/test/Transforms/InstCombine/with_overflow.ll (+129)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2f6e869ae7b735..51b3627056fbb6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7177,10 +7177,33 @@ llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
   return CR1.intersectWith(CR2, RangeType);
 }
 
+static bool isKnownToNotOverflowFromAssume(const SimplifyQuery &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return false;
+
+  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(Q.CxtI)) {
+    if (!Elem.Assume)
+      continue;
+
+    AssumeInst *I = cast<AssumeInst>(Elem.Assume);
+    if (match(I->getArgOperand(0),
+              m_Not(m_ExtractValue<1>(m_Specific(Q.CxtI)))) &&
+        isValidAssumeForContext(I, Q.CxtI, /*DT=*/nullptr,
+                                /*AllowEphemerals=*/true))
+      return true;
+  }
+  return false;
+}
+
 OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
                                                    const Value *RHS,
                                                    const SimplifyQuery &SQ,
                                                    bool IsNSW) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
   KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
 
@@ -7196,6 +7219,9 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
 OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
                                                  const Value *RHS,
                                                  const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -7236,6 +7262,9 @@ OverflowResult
 llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
                                     const WithCache<const Value *> &RHS,
                                     const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   ConstantRange LHSRange =
       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
   ConstantRange RHSRange =
@@ -7251,6 +7280,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
     return OverflowResult::NeverOverflows;
   }
 
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // If LHS and RHS each have at least two sign bits, the addition will look
   // like
   //
@@ -7305,6 +7337,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                    const Value *RHS,
                                                    const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -7338,6 +7373,9 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                  const Value *RHS,
                                                  const SimplifyQuery &SQ) {
+  if (isKnownToNotOverflowFromAssume(SQ))
+    return OverflowResult::NeverOverflows;
+
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -10180,6 +10218,9 @@ void llvm::findValuesAffectedByCondition(
                                                            m_Value()))) {
       // Handle patterns that computeKnownFPClass() support.
       AddAffected(A);
+    } else if (IsAssume && match(V, m_Not(m_ExtractValue<1>(m_Value(A)))) &&
+               isa<WithOverflowInst>(A)) {
+      AddAffected(A);
     }
   }
 }
diff --git a/llvm/test/Transforms/InstCombine/with_overflow.ll b/llvm/test/Transforms/InstCombine/with_overflow.ll
index fa810408730e1b..dd64e993bb620c 100644
--- a/llvm/test/Transforms/InstCombine/with_overflow.ll
+++ b/llvm/test/Transforms/InstCombine/with_overflow.ll
@@ -1064,3 +1064,132 @@ define i8 @smul_7(i8 %x, ptr %p) {
   store i1 %ov, ptr %p
   ret i8 %r
 }
+
+define i8 @uadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @uadd_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = add nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @sadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @sadd_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = add nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.sadd.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @usub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @usub_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = sub nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.usub.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @ssub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @ssub_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.ssub.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @umul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @umul_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = mul nuw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i8 @smul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @smul_assume_no_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = mul nsw i8 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[CALL]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i8 %ret
+}
+
+define i1 @ephemeral_call_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @ephemeral_call_assume_no_overflow(
+; CHECK-NEXT:    ret i1 true
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  ret i1 %not
+}
+
+define i8 @neg_assume_overflow(i8 noundef %a, i8 noundef %b) {
+; CHECK-LABEL: @neg_assume_overflow(
+; CHECK-NEXT:    [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
+; CHECK-NEXT:    [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
+; CHECK-NEXT:    call void @llvm.assume(i1 [[OVERFLOW]])
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  call void @llvm.assume(i1 %overflow)
+  ret i8 %ret
+}
+
+define i8 @neg_assume_not_guaranteed_to_execute(i8 noundef %a, i8 noundef %b, i1 %cond) {
+; CHECK-LABEL: @neg_assume_not_guaranteed_to_execute(
+; CHECK-NEXT:    [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
+; CHECK-NEXT:    br i1 [[COND:%.*]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
+; CHECK-NEXT:    [[NOT:%.*]] = xor i1 [[OVERFLOW]], true
+; CHECK-NEXT:    call void @llvm.assume(i1 [[NOT]])
+; CHECK-NEXT:    br label [[BB2]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
+  %overflow = extractvalue { i8, i1 } %call, 1
+  %ret = extractvalue { i8, i1 } %call, 0
+  br i1 %cond, label %bb1, label %bb2
+bb1:
+  %not = xor i1 %overflow, true
+  call void @llvm.assume(i1 %not)
+  br label %bb2
+bb2:
+  ret i8 %ret
+}

@nikic
Copy link
Contributor

nikic commented Jan 4, 2025

I think this is trying to fix the same issue as #84016.

@andjo403
Copy link
Contributor Author

andjo403 commented Jan 4, 2025

yes it is the same did not find it before as I searched for assume not assumption :(

@andjo403 andjo403 closed this Jan 4, 2025
@andjo403 andjo403 deleted the overflowResultFromAssume branch January 4, 2025 21:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants