Skip to content

Commit b0febfe

Browse files
authored
Expr simplifier: implement prove::isPositive, prove::isNonNegative, prove::isNonZero (#2273)
1 parent 2ca39e3 commit b0febfe

10 files changed

+349
-52
lines changed

third_party/nvfuser/csrc/expr_simplifier.cpp

Lines changed: 164 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <lower_magic_zero.h>
99
#include <utils.h>
1010

11+
#include <cmath>
1112
#include <functional>
1213
#include <list>
1314
#include <memory>
@@ -134,6 +135,46 @@ std::unique_ptr<debug_print::NoOpLogger> createLogger(Val* value) {
134135

135136
namespace {
136137

138+
// An ordered mapping of variable -> VarInfo
139+
class VarInfoMap {
140+
public:
141+
VarInfoMap(const std::list<VarInfo>& variables) {
142+
var_info_map_.reserve(variables.size());
143+
var_order_.reserve(variables.size());
144+
set_.reserve(variables.size());
145+
for (const auto& info : variables) {
146+
var_order_.emplace_back(info.variable);
147+
set_.emplace(info.variable);
148+
var_info_map_[info.variable] = info;
149+
}
150+
}
151+
152+
// Get the order of variables
153+
const std::vector<Val*>& order() const {
154+
return var_order_;
155+
}
156+
157+
// Get the info of a variable
158+
const VarInfo& info(Val* var) const {
159+
return var_info_map_.at(var);
160+
}
161+
162+
// Check if this mapping has information about the given variable
163+
bool has(Val* var) const {
164+
return set_.count(var);
165+
}
166+
167+
// Get the set of variables that has information
168+
const std::unordered_set<Val*>& set() const {
169+
return set_;
170+
}
171+
172+
private:
173+
std::unordered_map<Val*, VarInfo> var_info_map_;
174+
std::vector<Val*> var_order_;
175+
std::unordered_set<Val*> set_;
176+
};
177+
137178
bool hasSimilarType(DataType t1, DataType t2) {
138179
if (t1 == t2) {
139180
return true;
@@ -464,22 +505,17 @@ class FlattenedAssocCommOp : public Expr {
464505
// and b < c. So in this example, this function will return [v2, v1].
465506
// Tensors are always considered as variables and they are always considered
466507
// as the rightmost.
467-
std::vector<Val*> sortedInputs(const std::list<ValInfo>& variables) {
468-
std::unordered_set<Val*> variables_set;
469-
variables_set.reserve(variables.size());
470-
for (auto v : variables) {
471-
variables_set.emplace(v.variable);
472-
}
508+
std::vector<Val*> sortedInputs(const VarInfoMap& var_info) {
473509
std::vector<Val*> sorted_inputs(inputs().begin(), inputs().end());
474510
std::unordered_map<Val*, std::unordered_set<Val*>> dependency;
475511
dependency.reserve(sorted_inputs.size());
476512
for (auto v : sorted_inputs) {
477-
dependency[v] = getSubexprDependency(v, variables_set);
513+
dependency[v] = getSubexprDependency(v, var_info.set());
478514
}
479515
auto compare = [&](Val* v1, Val* v2) {
480-
// Find all variables in variables_set that v1 and v2 depends on. The
481-
// input (v1 or v2) that exclusively has the right most variable in
482-
// variables_set will be to the right of the other input.
516+
// Find all variables in var_info that v1 and v2 depends on. The input (v1
517+
// or v2) that exclusively has the right most variable in var_info.order()
518+
// will be to the right of the other input.
483519
bool v1_is_left_of_v2 = false;
484520
auto deps1 = dependency.at(v1);
485521
auto deps2 = dependency.at(v2);
@@ -489,11 +525,10 @@ class FlattenedAssocCommOp : public Expr {
489525
if (hasTensor(deps1)) {
490526
return false;
491527
}
492-
for (auto v : variables) {
493-
if (deps1.count(v.variable) > 0 && deps2.count(v.variable) == 0) {
528+
for (auto v : var_info.order()) {
529+
if (deps1.count(v) > 0 && deps2.count(v) == 0) {
494530
v1_is_left_of_v2 = false;
495-
} else if (
496-
deps2.count(v.variable) > 0 && deps1.count(v.variable) == 0) {
531+
} else if (deps2.count(v) > 0 && deps1.count(v) == 0) {
497532
v1_is_left_of_v2 = true;
498533
}
499534
}
@@ -642,9 +677,9 @@ Val* flatten(Val* value) {
642677

643678
// Recursively convert expressions like FlattenedAdd(a, b, c, d) into
644679
// AddOp(AddOp(AddOp(a, b), c), d))
645-
Val* unflatten(Val* value, const std::list<ValInfo>& variables);
680+
Val* unflatten(Val* value, const VarInfoMap& var_info);
646681

647-
Val* unflattenRule(Val* value, const std::list<ValInfo>& variables) {
682+
Val* unflattenRule(Val* value, const VarInfoMap& var_info) {
648683
auto def = value->definition();
649684
if (def == nullptr) {
650685
return value;
@@ -663,14 +698,14 @@ Val* unflattenRule(Val* value, const std::list<ValInfo>& variables) {
663698
// Handle flattened op:
664699
// Convert flattened op into original binary ops
665700
TORCH_INTERNAL_ASSERT(fop->inputs().size() >= 2);
666-
auto sorted_inputs = fop->sortedInputs(variables);
701+
auto sorted_inputs = fop->sortedInputs(var_info);
667702
// We need to recursively unflatten all inputs, because we might have
668703
// nested flattened expressions like
669704
// FlattenedAdd(a, b, FlattenedMul(c, d, e))
670-
Val* lhs = unflatten(sorted_inputs.at(0), variables);
705+
Val* lhs = unflatten(sorted_inputs.at(0), var_info);
671706
int64_t next = 1;
672-
while (next < (int)sorted_inputs.size()) {
673-
auto rhs = unflatten(sorted_inputs.at(next), variables);
707+
while (next < (int64_t)sorted_inputs.size()) {
708+
auto rhs = unflatten(sorted_inputs.at(next), var_info);
674709
auto output = IrBuilder::newScalar(*value->getDataType());
675710
IrBuilder::create<BinaryOp>(fop->getOpType(), output, lhs, rhs);
676711
lhs = output;
@@ -681,10 +716,9 @@ Val* unflattenRule(Val* value, const std::list<ValInfo>& variables) {
681716
return value;
682717
}
683718

684-
Val* unflatten(Val* value, const std::list<ValInfo>& variables) {
685-
using namespace std::placeholders;
719+
Val* unflatten(Val* value, const VarInfoMap& var_info) {
686720
return recurseDown(
687-
value, [&variables](Val* val) { return unflattenRule(val, variables); });
721+
value, [&var_info](Val* val) { return unflattenRule(val, var_info); });
688722
}
689723

690724
} // namespace assoc_comm
@@ -954,35 +988,53 @@ namespace prove {
954988
// - x can be either zero or non-zero, it is just a symbolic number that depends
955989
// - x is zero
956990

957-
bool isNonNegative(Val* value) {
991+
bool isPositive(Val* value, const VarInfoMap& var_info);
992+
993+
bool isNonNegative(Val* value, const VarInfoMap& var_info) {
958994
value = foldConstants(value);
959995
if (value->getInt().has_value() && *value->getInt() >= 0) {
960996
return true;
961997
}
962998
if (value->getDouble().has_value() && *value->getDouble() >= 0.0) {
963999
return true;
9641000
}
1001+
if (isPositive(value, var_info)) {
1002+
return true;
1003+
}
9651004
if (auto ns = dynamic_cast<NamedScalar*>(value)) {
9661005
if (ns->getParallelDim().has_value() ||
967-
ns->getParallelIndex().has_value()) {
1006+
ns->getParallelIndex().has_value() || ns->isTensorSize() ||
1007+
ns->isTensorStride()) {
9681008
return true;
9691009
}
9701010
}
1011+
value = maybeUnwrapMagicZero(value);
1012+
if (var_info.has(value)) {
1013+
return isNonNegative(var_info.info(value).start, var_info) &&
1014+
isNonNegative(var_info.info(value).step, var_info);
1015+
}
9711016
if (auto fop = dynamic_cast<FOp*>(value->definition())) {
9721017
auto op = fop->getOpType();
9731018
if (op == BinaryOpType::Add || op == BinaryOpType::Mul) {
9741019
for (auto inp : fop->inputs()) {
975-
if (!isNonNegative(inp)) {
1020+
if (!isNonNegative(inp, var_info)) {
9761021
return false;
9771022
}
9781023
}
9791024
return true;
9801025
}
1026+
} else if (auto bop = dynamic_cast<BinaryOp*>(value->definition())) {
1027+
auto op = bop->getBinaryOpType();
1028+
if (op == BinaryOpType::Mod || op == BinaryOpType::Div ||
1029+
op == BinaryOpType::CeilDiv) {
1030+
return isNonNegative(bop->lhs(), var_info) &&
1031+
isPositive(bop->rhs(), var_info);
1032+
}
9811033
}
9821034
return false;
9831035
}
9841036

985-
bool isPositive(Val* value) {
1037+
bool isPositive(Val* value, const VarInfoMap& var_info) {
9861038
value = foldConstants(value);
9871039
if (value->getInt().has_value() && *value->getInt() > 0) {
9881040
return true;
@@ -1000,16 +1052,16 @@ bool isPositive(Val* value) {
10001052
if (op == BinaryOpType::Add) {
10011053
bool has_positive = false;
10021054
for (auto inp : fop->inputs()) {
1003-
if (isPositive(inp)) {
1055+
if (isPositive(inp, var_info)) {
10041056
has_positive = true;
1005-
} else if (!isNonNegative(inp)) {
1057+
} else if (!isNonNegative(inp, var_info)) {
10061058
return false;
10071059
}
10081060
}
10091061
return has_positive;
10101062
} else if (op == BinaryOpType::Mul) {
10111063
for (auto inp : fop->inputs()) {
1012-
if (!isPositive(inp)) {
1064+
if (!isPositive(inp, var_info)) {
10131065
return false;
10141066
}
10151067
}
@@ -1019,20 +1071,20 @@ bool isPositive(Val* value) {
10191071
return false;
10201072
}
10211073

1022-
bool isNonZero(Val* value) {
1074+
bool isNonZero(Val* value, const VarInfoMap& var_info) {
10231075
value = foldConstants(value);
10241076
if (value->getInt().has_value() && *value->getInt() != 0) {
10251077
return true;
10261078
}
10271079
if (value->getDouble().has_value() && *value->getDouble() != 0.0) {
10281080
return true;
10291081
}
1030-
if (isPositive(value)) {
1082+
if (isPositive(value, var_info)) {
10311083
return true;
10321084
}
10331085
if (auto fop = toFlattenedMul(value->definition())) {
10341086
for (auto inp : fop->inputs()) {
1035-
if (!isNonZero(inp)) {
1087+
if (!isNonZero(inp, var_info)) {
10361088
return false;
10371089
}
10381090
}
@@ -1064,7 +1116,7 @@ namespace rules {
10641116
// a / 1 -> a
10651117
// x - x -> 0
10661118
// ...
1067-
Val* eliminateTrivialComputation(Val* value) {
1119+
Val* eliminateTrivialComputation(Val* value, const VarInfoMap& var_info) {
10681120
auto folded = foldConstants(value);
10691121
if (folded != value) {
10701122
return folded;
@@ -1133,7 +1185,8 @@ Val* eliminateTrivialComputation(Val* value) {
11331185
auto lhs = foldConstants(bop->lhs());
11341186
auto rhs = foldConstants(bop->rhs());
11351187
bool divisor_is_1 = (rhs->getInt() == 1 || rhs->getDouble() == 1.0);
1136-
if (divisor_is_1 || (prove::isNonZero(rhs) && lhs->getInt() == 0)) {
1188+
if (divisor_is_1 ||
1189+
(prove::isNonZero(rhs, var_info) && lhs->getInt() == 0)) {
11371190
return lhs;
11381191
}
11391192
} else if (bop->getBinaryOpType() == BinaryOpType::Sub) {
@@ -1147,12 +1200,79 @@ Val* eliminateTrivialComputation(Val* value) {
11471200
return value;
11481201
}
11491202

1203+
// If x can be proved to be non-negative, then replace x >= 0 as true, replace
1204+
// x < 0 as false
1205+
// If x can be proved to be positive, then replace x >= 0 and x > 0 as true,
1206+
// replace x <= 0 and x < 0 as false
1207+
// If x can be proved to be nonzero, then replace x != 0 as true, replace x == 0
1208+
// as false
1209+
// if x->sameAs(y), then replace x == y as true, replace x != y as false
1210+
Val* eliminateTrivialPredicate(Val* value, const VarInfoMap& var_info) {
1211+
if (!value->isABool()) {
1212+
return value;
1213+
}
1214+
if (auto bop = dynamic_cast<BinaryOp*>(value->definition())) {
1215+
auto op = bop->getBinaryOpType();
1216+
if (bop->lhs()->sameAs(bop->rhs())) {
1217+
if (op == BinaryOpType::Eq) {
1218+
return value->fusion()->trueVal();
1219+
} else if (op == BinaryOpType::NE) {
1220+
return value->fusion()->falseVal();
1221+
}
1222+
}
1223+
if (bop->rhs()->isZeroInt()) {
1224+
if (op == BinaryOpType::GE &&
1225+
prove::isNonNegative(bop->lhs(), var_info)) {
1226+
return value->fusion()->trueVal();
1227+
} else if (
1228+
op == BinaryOpType::GT && prove::isPositive(bop->lhs(), var_info)) {
1229+
return value->fusion()->trueVal();
1230+
} else if (
1231+
op == BinaryOpType::NE && prove::isNonZero(bop->lhs(), var_info)) {
1232+
return value->fusion()->trueVal();
1233+
} else if (
1234+
op == BinaryOpType::LT &&
1235+
prove::isNonNegative(bop->lhs(), var_info)) {
1236+
return value->fusion()->falseVal();
1237+
} else if (
1238+
op == BinaryOpType::LE && prove::isPositive(bop->lhs(), var_info)) {
1239+
return value->fusion()->falseVal();
1240+
} else if (
1241+
op == BinaryOpType::Eq && prove::isNonZero(bop->lhs(), var_info)) {
1242+
return value->fusion()->falseVal();
1243+
}
1244+
} else if (bop->lhs()->isZeroInt()) {
1245+
if (op == BinaryOpType::LE &&
1246+
prove::isNonNegative(bop->rhs(), var_info)) {
1247+
return value->fusion()->trueVal();
1248+
} else if (
1249+
op == BinaryOpType::LT && prove::isPositive(bop->rhs(), var_info)) {
1250+
return value->fusion()->trueVal();
1251+
} else if (
1252+
op == BinaryOpType::NE && prove::isNonZero(bop->rhs(), var_info)) {
1253+
return value->fusion()->trueVal();
1254+
} else if (
1255+
op == BinaryOpType::GT &&
1256+
prove::isNonNegative(bop->rhs(), var_info)) {
1257+
return value->fusion()->falseVal();
1258+
} else if (
1259+
op == BinaryOpType::GE && prove::isPositive(bop->rhs(), var_info)) {
1260+
return value->fusion()->falseVal();
1261+
} else if (
1262+
op == BinaryOpType::Eq && prove::isNonZero(bop->rhs(), var_info)) {
1263+
return value->fusion()->falseVal();
1264+
}
1265+
}
1266+
}
1267+
return value;
1268+
}
1269+
11501270
// Apply rule L to replace x % y with 0 if x can be proved to be a multiple of y
11511271
// Also, according to rule M, if x can be factorized as x = k * y, then x / y
11521272
// can be simplified as x / y = (k * y) / y = k * (y / y) = k
1153-
Val* simplifyDivisibleDivMod(Val* value) {
1273+
Val* simplifyDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
11541274
if (auto bop = dynamic_cast<BinaryOp*>(value->definition())) {
1155-
if (prove::isNonZero(bop->rhs())) {
1275+
if (prove::isNonZero(bop->rhs(), var_info)) {
11561276
if (bop->getBinaryOpType() == BinaryOpType::Mod) {
11571277
if (prove::isMultipleOf(bop->lhs(), bop->rhs())) {
11581278
return IrBuilder::newConstant(0, *value->getDataType());
@@ -1172,12 +1292,15 @@ Val* simplifyDivisibleDivMod(Val* value) {
11721292

11731293
} // namespace rules
11741294

1175-
#define RUN_PASS(pass_name) \
1176-
simplified = recurseDown(simplified, rules::pass_name); \
1295+
#define RUN_PASS(pass_name) \
1296+
simplified = recurseDown(simplified, [&var_info](Val* val) { \
1297+
return rules::pass_name(val, var_info); \
1298+
}); \
11771299
logger->record(#pass_name, simplified)
11781300

1179-
Val* simplifyExpr(Val* value, const std::list<ValInfo>& variables) {
1301+
Val* simplifyExpr(Val* value, const std::list<VarInfo>& variables) {
11801302
FusionGuard fg(value->fusion());
1303+
const VarInfoMap var_info(variables);
11811304
auto logger = debug_print::createLogger(value);
11821305

11831306
auto simplified = assoc_comm::flatten(value);
@@ -1187,6 +1310,7 @@ Val* simplifyExpr(Val* value, const std::list<ValInfo>& variables) {
11871310
while (old_simplified != simplified) {
11881311
old_simplified = simplified;
11891312
RUN_PASS(eliminateTrivialComputation);
1313+
RUN_PASS(eliminateTrivialPredicate);
11901314
RUN_PASS(simplifyDivisibleDivMod);
11911315
}
11921316

0 commit comments

Comments
 (0)