Skip to content

[mlir][SCF] ValueBoundsConstraintSet: Support scf.if (branches) #85895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,26 @@ class ValueBoundsConstraintSet
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);

/// Traverse the IR starting from the given value/dim and populate constraints
/// as long as the stop condition holds. Also process all values/dims that are
/// already on the worklist.
void populateConstraints(Value value, std::optional<int64_t> dim);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of this may clash with #83876. I'm going to rebase this PR when #83876 has been merged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I'll land it soon :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rebased all PRs. It would be great if you could also review #86097, then I can start merging PRs. You probably know the codebase best out of all reviewers that I added to that PR.


/// Comparison operator for `ValueBoundsConstraintSet::compare`.
enum ComparisonOperator { LT, LE, EQ, GT, GE };

/// Try to prove that, based on the current state of this constraint set
/// (i.e., without analyzing additional IR or adding new constraints), the
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
///
/// Return "true" if the specified relation between the two values/dims was
/// proven to hold. Return "false" if the specified relation could not be
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
Value rhs, std::optional<int64_t> rhsDim);

/// Compute whether the given values/dimensions are equal. Return "failure" if
/// equality could not be determined.
///
Expand Down Expand Up @@ -274,13 +294,13 @@ class ValueBoundsConstraintSet

ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);

/// Populates the constraint set for a value/map without actually computing
/// the bound. Returns the position for the value/map (via the return value
/// and `posOut` output parameter).
int64_t populateConstraintsSet(Value value,
std::optional<int64_t> dim = std::nullopt);
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
int64_t *posOut = nullptr);
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
/// as the stop condition is not satisfied). Also process all values/dims that
/// are already on the worklist. Return the position of the newly added
/// column.
int64_t populateConstraints(AffineMap map, ValueDimList mapOperands);

/// Iteratively process all elements on the worklist until an index-typed
/// value or shaped value meets `stopCondition`. Such values are not processed
Expand All @@ -295,14 +315,19 @@ class ValueBoundsConstraintSet
/// value/dimension exists in the constraint set.
int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;

/// Return an affine expression that represents column `pos` in the constraint
/// set.
AffineExpr getPosExpr(int64_t pos);

/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
/// "false", a dimension is added. The value/dimension is added to the
/// worklist.
/// worklist if `addToWorklist` is set.
///
/// Note: There are certain affine restrictions wrt. dimensions. E.g., they
/// cannot be multiplied. Furthermore, bounds can only be queried for
/// dimensions but not for symbols.
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
bool addToWorklist = true);

/// Insert an anonymous column into the constraint set. The column is not
/// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
Expand Down
61 changes: 61 additions & 0 deletions mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,66 @@ struct ForOpInterface
}
};

struct IfOpInterface
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {

static void populateBounds(scf::IfOp ifOp, Value value,
std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
Value thenValue = ifOp.thenYield().getResults()[resultNum];
Value elseValue = ifOp.elseYield().getResults()[resultNum];

// Populate constraints for the yielded value (and all values on the
// backward slice, as long as the current stop condition is not satisfied).
cstr.populateConstraints(thenValue, dim);
cstr.populateConstraints(elseValue, dim);
auto boundsBuilder = cstr.bound(value);
if (dim)
boundsBuilder[*dim];

// Compare yielded values.
// If thenValue <= elseValue:
// * result <= elseValue
// * result >= thenValue
if (cstr.compare(thenValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
elseValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
} else {
cstr.bound(value) >= thenValue;
cstr.bound(value) <= elseValue;
}
}
// If elseValue <= thenValue:
// * result <= thenValue
// * result >= elseValue
if (cstr.compare(elseValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
thenValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
} else {
cstr.bound(value) >= elseValue;
cstr.bound(value) <= thenValue;
}
}
}

void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
}

void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<IfOp>(op), value, dim, cstr);
}
};

} // namespace
} // namespace scf
} // namespace mlir
Expand All @@ -119,5 +179,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,24 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
ScalableValueBoundsConstraintSet scalableCstr(
value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
vscaleMin, vscaleMax);
int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
scalableCstr.processWorklist();

// Project out all variables apart from vscale.
// This should result in constraints in terms of vscale only.
// Project out all columns apart from vscale and the starting point
// (value/dim). This should result in constraints in terms of vscale only.
auto projectOutFn = [&](ValueDim p) {
return p.first != scalableCstr.getVscaleValue();
bool isStartingPoint =
p.first == value &&
p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
};
scalableCstr.projectOut(projectOutFn);

assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
scalableCstr.positionToValueDim.size() &&
"inconsistent mapping state");

// Check that the only symbols left are vscale.
// Check that the only columns left are vscale and the starting point.
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
if (i == pos)
continue;
Expand Down
143 changes: 113 additions & 30 deletions mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,47 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
assertValidValueDim(value, dim);
#endif // NDEBUG

// Check if the value/dim is statically known. In that case, an affine
// constant expression should be returned. This allows us to support
// multiplications with constants. (Multiplications of two columns in the
// constraint set is not supported.)
std::optional<int64_t> constSize = std::nullopt;
auto shapedType = dyn_cast<ShapedType>(value.getType());
if (shapedType) {
// Static dimension: return constant directly.
if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
} else {
// Constant index value: return directly.
if (auto constInt = ::getConstantIntValue(value))
return builder.getAffineConstantExpr(*constInt);
constSize = shapedType.getDimSize(*dim);
} else if (auto constInt = ::getConstantIntValue(value)) {
constSize = *constInt;
}

// Dynamic value: add to constraint set.
// If the value/dim is already mapped, return the corresponding expression
// directly.
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
if (!valueDimToPosition.contains(valueDim))
(void)insert(value, dim);
int64_t pos = getPos(value, dim);
return pos < cstr.getNumDimVars()
? builder.getAffineDimExpr(pos)
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
if (valueDimToPosition.contains(valueDim)) {
// If it is a constant, return an affine constant expression. Otherwise,
// return an affine expression that represents the respective column in the
// constraint set.
if (constSize)
return builder.getAffineConstantExpr(*constSize);
return getPosExpr(getPos(value, dim));
}

if (constSize) {
// Constant index value/dim: add column to the constraint set, add EQ bound
// and return an affine constant expression without pushing the newly added
// column to the worklist.
(void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
if (shapedType)
bound(value)[*dim] == *constSize;
else
bound(value) == *constSize;
return builder.getAffineConstantExpr(*constSize);
}

// Dynamic value/dim: insert column to the constraint set and put it on the
// worklist. Return an affine expression that represents the newly inserted
// column in the constraint set.
return getPosExpr(insert(value, dim, /*isSymbol=*/true));
}

AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
Expand All @@ -145,7 +167,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {

int64_t ValueBoundsConstraintSet::insert(Value value,
std::optional<int64_t> dim,
bool isSymbol) {
bool isSymbol, bool addToWorklist) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG
Expand All @@ -160,7 +182,12 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
if (positionToValueDim[i].has_value())
valueDimToPosition[*positionToValueDim[i]] = i;

worklist.push(pos);
if (addToWorklist) {
LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
<< " (dim: " << dim.value_or(kIndexValue) << ")\n");
worklist.push(pos);
}

return pos;
}

Expand Down Expand Up @@ -190,6 +217,13 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
return it->second;
}

AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
return pos < cstr.getNumDimVars()
? builder.getAffineDimExpr(pos)
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
}

static Operation *getOwnerOfValue(Value value) {
if (auto bbArg = dyn_cast<BlockArgument>(value))
return bbArg.getOwner()->getParentOp();
Expand Down Expand Up @@ -492,15 +526,16 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(

// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
int64_t pos;
int64_t pos = 0;
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
return cstr.cstr.getConstantBound64(type, pos).has_value();
};

ValueBoundsConstraintSet cstr(
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
cstr.populateConstraintsSet(map, operands, &pos);
pos = cstr.populateConstraints(map, operands);
assert(pos == 0 && "expected `map` is the first column");

// Compute constant bound for `valueDim`.
int64_t ubAdjustment = closedUB ? 0 : 1;
Expand All @@ -509,29 +544,28 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}

int64_t
ValueBoundsConstraintSet::populateConstraintsSet(Value value,
std::optional<int64_t> dim) {
void ValueBoundsConstraintSet::populateConstraints(Value value,
std::optional<int64_t> dim) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG

AffineMap map =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
Builder(value.getContext()).getAffineDimExpr(0));
return populateConstraintsSet(map, {{value, dim}});
// `getExpr` pushes the value/dim onto the worklist (unless it was already
// analyzed).
(void)getExpr(value, dim);
// Process all values/dims on the worklist. This may traverse and analyze
// additional IR, depending the current stop function.
processWorklist();
}

int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
ValueDimList operands,
int64_t *posOut) {
int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
ValueDimList operands) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
int64_t pos = insert(/*isSymbol=*/false);
if (posOut)
*posOut = pos;

// Add map and operands to the constraint set. Dimensions are converted to
// symbols. All operands are added to the worklist.
// symbols. All operands are added to the worklist (unless they were already
// processed).
auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
return getExpr(v.first, v.second);
};
Expand Down Expand Up @@ -566,6 +600,55 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
{{value1, dim1}, {value2, dim2}});
}

bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
ComparisonOperator cmp, Value rhs,
std::optional<int64_t> rhsDim) {
// This function returns "true" if "lhs CMP rhs" is proven to hold.
//
// Example for ComparisonOperator::LE and index-typed values: We would like to
// prove that lhs <= rhs. Proof by contradiction: add the inverse
// relation (lhs > rhs) to the constraint set and check if the resulting
// constraint set is "empty" (i.e. has no solution). In that case,
// lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.

// We cannot prove anything if the constraint set is already empty.
if (cstr.isEmpty()) {
LLVM_DEBUG(
llvm::dbgs()
<< "cannot compare value/dims: constraint system is already empty");
return false;
}

// EQ can be expressed as LE and GE.
if (cmp == EQ)
return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);

// Construct inequality. For the above example: lhs > rhs.
// `IntegerRelation` inequalities are expressed in the "flattened" form and
// with ">= 0". I.e., lhs - rhs - 1 >= 0.
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
if (cmp == LT || cmp == LE) {
++eq[getPos(lhs, lhsDim)];
--eq[getPos(rhs, rhsDim)];
} else if (cmp == GT || cmp == GE) {
--eq[getPos(lhs, lhsDim)];
++eq[getPos(rhs, rhsDim)];
} else {
llvm_unreachable("unsupported comparison operator");
}
if (cmp == LE || cmp == GE)
eq[cstr.getNumDimAndSymbolVars()] -= 1;

// Add inequality to the constraint set and check if it made the constraint
// set empty.
int64_t ineqPos = cstr.getNumInequalities();
cstr.addInequality(eq);
bool isEmpty = cstr.isEmpty();
cstr.removeInequality(ineqPos);
return isEmpty;
}

FailureOr<bool>
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
std::optional<int64_t> dim1,
Expand Down
Loading