Skip to content

[ValueTracking] Use assume to compute overflowResult. #121665

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7177,10 +7177,33 @@ llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
return CR1.intersectWith(CR2, RangeType);
}

static bool isKnownToNotOverflowFromAssume(const SimplifyQuery &Q) {
// Use of assumptions is context-sensitive. If we don't have a context, we
// cannot use them!
if (!Q.AC || !Q.CxtI)
return false;

for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(Q.CxtI)) {
if (!Elem.Assume)
continue;

AssumeInst *I = cast<AssumeInst>(Elem.Assume);
if (match(I->getArgOperand(0),
m_Not(m_ExtractValue<1>(m_Specific(Q.CxtI)))) &&
isValidAssumeForContext(I, Q.CxtI, /*DT=*/nullptr,
/*AllowEphemerals=*/true))
return true;
}
return false;
}

OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
const Value *RHS,
const SimplifyQuery &SQ,
bool IsNSW) {
if (isKnownToNotOverflowFromAssume(SQ))
return OverflowResult::NeverOverflows;

KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);

Expand All @@ -7196,6 +7219,9 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
const Value *RHS,
const SimplifyQuery &SQ) {
if (isKnownToNotOverflowFromAssume(SQ))
return OverflowResult::NeverOverflows;

// Multiplying n * m significant bits yields a result of n + m significant
// bits. If the total number of significant bits does not exceed the
// result bit width (minus 1), there is no overflow.
Expand Down Expand Up @@ -7236,6 +7262,9 @@ OverflowResult
llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
const WithCache<const Value *> &RHS,
const SimplifyQuery &SQ) {
if (isKnownToNotOverflowFromAssume(SQ))
return OverflowResult::NeverOverflows;

ConstantRange LHSRange =
computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
ConstantRange RHSRange =
Expand All @@ -7251,6 +7280,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
return OverflowResult::NeverOverflows;
}

if (isKnownToNotOverflowFromAssume(SQ))
return OverflowResult::NeverOverflows;

// If LHS and RHS each have at least two sign bits, the addition will look
// like
//
Expand Down Expand Up @@ -7305,6 +7337,9 @@ computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
const Value *RHS,
const SimplifyQuery &SQ) {
if (isKnownToNotOverflowFromAssume(SQ))
return OverflowResult::NeverOverflows;

// X - (X % ?)
// The remainder of a value can't have greater magnitude than itself,
// so the subtraction can't overflow.
Expand Down Expand Up @@ -7338,6 +7373,9 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
const Value *RHS,
const SimplifyQuery &SQ) {
if (isKnownToNotOverflowFromAssume(SQ))
return OverflowResult::NeverOverflows;

// X - (X % ?)
// The remainder of a value can't have greater magnitude than itself,
// so the subtraction can't overflow.
Expand Down Expand Up @@ -10180,6 +10218,9 @@ void llvm::findValuesAffectedByCondition(
m_Value()))) {
// Handle patterns that computeKnownFPClass() support.
AddAffected(A);
} else if (IsAssume && match(V, m_Not(m_ExtractValue<1>(m_Value(A)))) &&
isa<WithOverflowInst>(A)) {
AddAffected(A);
}
}
}
129 changes: 129 additions & 0 deletions llvm/test/Transforms/InstCombine/with_overflow.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1064,3 +1064,132 @@ define i8 @smul_7(i8 %x, ptr %p) {
store i1 %ov, ptr %p
ret i8 %r
}

define i8 @uadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @uadd_assume_no_overflow(
; CHECK-NEXT: [[CALL:%.*]] = add nuw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret i8 [[CALL]]
;
%call = call { i8, i1 } @llvm.uadd.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
ret i8 %ret
}

define i8 @sadd_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @sadd_assume_no_overflow(
; CHECK-NEXT: [[CALL:%.*]] = add nsw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret i8 [[CALL]]
;
%call = call { i8, i1 } @llvm.sadd.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
ret i8 %ret
}

define i8 @usub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @usub_assume_no_overflow(
; CHECK-NEXT: [[CALL:%.*]] = sub nuw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret i8 [[CALL]]
;
%call = call { i8, i1 } @llvm.usub.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
ret i8 %ret
}

define i8 @ssub_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @ssub_assume_no_overflow(
; CHECK-NEXT: [[CALL:%.*]] = sub nsw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret i8 [[CALL]]
;
%call = call { i8, i1 } @llvm.ssub.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
ret i8 %ret
}

define i8 @umul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @umul_assume_no_overflow(
; CHECK-NEXT: [[CALL:%.*]] = mul nuw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret i8 [[CALL]]
;
%call = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
ret i8 %ret
}

define i8 @smul_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @smul_assume_no_overflow(
; CHECK-NEXT: [[CALL:%.*]] = mul nsw i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret i8 [[CALL]]
;
%call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
ret i8 %ret
}

define i1 @ephemeral_call_assume_no_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @ephemeral_call_assume_no_overflow(
; CHECK-NEXT: ret i1 true
;
%call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
ret i1 %not
}

define i8 @neg_assume_overflow(i8 noundef %a, i8 noundef %b) {
; CHECK-LABEL: @neg_assume_overflow(
; CHECK-NEXT: [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
; CHECK-NEXT: [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
; CHECK-NEXT: call void @llvm.assume(i1 [[OVERFLOW]])
; CHECK-NEXT: ret i8 [[RET]]
;
%call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
call void @llvm.assume(i1 %overflow)
ret i8 %ret
}

define i8 @neg_assume_not_guaranteed_to_execute(i8 noundef %a, i8 noundef %b, i1 %cond) {
; CHECK-LABEL: @neg_assume_not_guaranteed_to_execute(
; CHECK-NEXT: [[CALL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[A:%.*]], i8 [[B:%.*]])
; CHECK-NEXT: br i1 [[COND:%.*]], label [[BB1:%.*]], label [[BB2:%.*]]
; CHECK: bb1:
; CHECK-NEXT: [[OVERFLOW:%.*]] = extractvalue { i8, i1 } [[CALL]], 1
; CHECK-NEXT: [[NOT:%.*]] = xor i1 [[OVERFLOW]], true
; CHECK-NEXT: call void @llvm.assume(i1 [[NOT]])
; CHECK-NEXT: br label [[BB2]]
; CHECK: bb2:
; CHECK-NEXT: [[RET:%.*]] = extractvalue { i8, i1 } [[CALL]], 0
; CHECK-NEXT: ret i8 [[RET]]
;
%call = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 %a, i8 %b)
%overflow = extractvalue { i8, i1 } %call, 1
%ret = extractvalue { i8, i1 } %call, 0
br i1 %cond, label %bb1, label %bb2
bb1:
%not = xor i1 %overflow, true
call void @llvm.assume(i1 %not)
br label %bb2
bb2:
ret i8 %ret
}
Loading