Skip to content

Commit cead7ad

Browse files
authored
Expr simplifier: simplification passes for matmul (#2275)
1 parent 7b1ab9a commit cead7ad

File tree

3 files changed

+419
-162
lines changed

3 files changed

+419
-162
lines changed

third_party/nvfuser/csrc/expr_simplifier.cpp

Lines changed: 252 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -728,10 +728,12 @@ namespace {
728728
using FOp = assoc_comm::FlattenedAssocCommOp;
729729

730730
FOp* toFlattenedAdd(Expr* expr) {
731-
if (auto fop = dynamic_cast<FOp*>(expr)) {
732-
if (fop->getOpType() == BinaryOpType::Add) {
733-
return fop;
734-
}
731+
auto fop = dynamic_cast<FOp*>(expr);
732+
if (!fop) {
733+
return nullptr;
734+
}
735+
if (fop->getOpType() == BinaryOpType::Add) {
736+
return fop;
735737
}
736738
return nullptr;
737739
}
@@ -741,10 +743,12 @@ bool isFlattenedAdd(Val* x) {
741743
}
742744

743745
FOp* toFlattenedMul(Expr* expr) {
744-
if (auto fop = dynamic_cast<FOp*>(expr)) {
745-
if (fop->getOpType() == BinaryOpType::Mul) {
746-
return fop;
747-
}
746+
auto fop = dynamic_cast<FOp*>(expr);
747+
if (!fop) {
748+
return nullptr;
749+
}
750+
if (fop->getOpType() == BinaryOpType::Mul) {
751+
return fop;
748752
}
749753
return nullptr;
750754
}
@@ -753,6 +757,19 @@ bool isFlattenedMul(Val* x) {
753757
return toFlattenedMul(x->definition()) != nullptr;
754758
}
755759

760+
BinaryOp* toDivModOp(Expr* expr) {
761+
auto bop = dynamic_cast<BinaryOp*>(expr);
762+
if (!bop) {
763+
return nullptr;
764+
}
765+
if (bop->getBinaryOpType() == BinaryOpType::Div ||
766+
bop->getBinaryOpType() == BinaryOpType::Mod) {
767+
// TODO: Add CeilDiv as well? Need to mathematically prove its rules first
768+
return bop;
769+
}
770+
return nullptr;
771+
}
772+
756773
// Classify terms of a FlattenedMul as (constant, symbolic), for example:
757774
// a * 3 * b * 5 --> (15, {a, b})
758775
// a * b --> (1, {a, b})
@@ -1101,6 +1118,10 @@ bool isMultipleOf(Val* x, Val* y) {
11011118
return sym_algebra::divideFactorized(lhs, rhs) != nullptr;
11021119
}
11031120

1121+
bool hasCompatibleSign(Val* x, Val* y, const VarInfoMap& var_info) {
1122+
return isNonNegative(x, var_info) && isNonNegative(y, var_info);
1123+
}
1124+
11041125
} // namespace prove
11051126

11061127
namespace rules {
@@ -1211,57 +1232,55 @@ Val* eliminateTrivialPredicate(Val* value, const VarInfoMap& var_info) {
12111232
if (!value->isABool()) {
12121233
return value;
12131234
}
1214-
if (auto bop = dynamic_cast<BinaryOp*>(value->definition())) {
1215-
auto op = bop->getBinaryOpType();
1216-
if (bop->lhs()->sameAs(bop->rhs())) {
1217-
if (op == BinaryOpType::Eq) {
1218-
return value->fusion()->trueVal();
1219-
} else if (op == BinaryOpType::NE) {
1220-
return value->fusion()->falseVal();
1221-
}
1235+
auto bop = dynamic_cast<BinaryOp*>(value->definition());
1236+
if (!bop) {
1237+
return value;
1238+
}
1239+
auto op = bop->getBinaryOpType();
1240+
if (bop->lhs()->sameAs(bop->rhs())) {
1241+
if (op == BinaryOpType::Eq) {
1242+
return value->fusion()->trueVal();
1243+
} else if (op == BinaryOpType::NE) {
1244+
return value->fusion()->falseVal();
12221245
}
1223-
if (bop->rhs()->isZeroInt()) {
1224-
if (op == BinaryOpType::GE &&
1225-
prove::isNonNegative(bop->lhs(), var_info)) {
1226-
return value->fusion()->trueVal();
1227-
} else if (
1228-
op == BinaryOpType::GT && prove::isPositive(bop->lhs(), var_info)) {
1229-
return value->fusion()->trueVal();
1230-
} else if (
1231-
op == BinaryOpType::NE && prove::isNonZero(bop->lhs(), var_info)) {
1232-
return value->fusion()->trueVal();
1233-
} else if (
1234-
op == BinaryOpType::LT &&
1235-
prove::isNonNegative(bop->lhs(), var_info)) {
1236-
return value->fusion()->falseVal();
1237-
} else if (
1238-
op == BinaryOpType::LE && prove::isPositive(bop->lhs(), var_info)) {
1239-
return value->fusion()->falseVal();
1240-
} else if (
1241-
op == BinaryOpType::Eq && prove::isNonZero(bop->lhs(), var_info)) {
1242-
return value->fusion()->falseVal();
1243-
}
1244-
} else if (bop->lhs()->isZeroInt()) {
1245-
if (op == BinaryOpType::LE &&
1246-
prove::isNonNegative(bop->rhs(), var_info)) {
1247-
return value->fusion()->trueVal();
1248-
} else if (
1249-
op == BinaryOpType::LT && prove::isPositive(bop->rhs(), var_info)) {
1250-
return value->fusion()->trueVal();
1251-
} else if (
1252-
op == BinaryOpType::NE && prove::isNonZero(bop->rhs(), var_info)) {
1253-
return value->fusion()->trueVal();
1254-
} else if (
1255-
op == BinaryOpType::GT &&
1256-
prove::isNonNegative(bop->rhs(), var_info)) {
1257-
return value->fusion()->falseVal();
1258-
} else if (
1259-
op == BinaryOpType::GE && prove::isPositive(bop->rhs(), var_info)) {
1260-
return value->fusion()->falseVal();
1261-
} else if (
1262-
op == BinaryOpType::Eq && prove::isNonZero(bop->rhs(), var_info)) {
1263-
return value->fusion()->falseVal();
1264-
}
1246+
}
1247+
if (bop->rhs()->isZeroInt()) {
1248+
if (op == BinaryOpType::GE && prove::isNonNegative(bop->lhs(), var_info)) {
1249+
return value->fusion()->trueVal();
1250+
} else if (
1251+
op == BinaryOpType::GT && prove::isPositive(bop->lhs(), var_info)) {
1252+
return value->fusion()->trueVal();
1253+
} else if (
1254+
op == BinaryOpType::NE && prove::isNonZero(bop->lhs(), var_info)) {
1255+
return value->fusion()->trueVal();
1256+
} else if (
1257+
op == BinaryOpType::LT && prove::isNonNegative(bop->lhs(), var_info)) {
1258+
return value->fusion()->falseVal();
1259+
} else if (
1260+
op == BinaryOpType::LE && prove::isPositive(bop->lhs(), var_info)) {
1261+
return value->fusion()->falseVal();
1262+
} else if (
1263+
op == BinaryOpType::Eq && prove::isNonZero(bop->lhs(), var_info)) {
1264+
return value->fusion()->falseVal();
1265+
}
1266+
} else if (bop->lhs()->isZeroInt()) {
1267+
if (op == BinaryOpType::LE && prove::isNonNegative(bop->rhs(), var_info)) {
1268+
return value->fusion()->trueVal();
1269+
} else if (
1270+
op == BinaryOpType::LT && prove::isPositive(bop->rhs(), var_info)) {
1271+
return value->fusion()->trueVal();
1272+
} else if (
1273+
op == BinaryOpType::NE && prove::isNonZero(bop->rhs(), var_info)) {
1274+
return value->fusion()->trueVal();
1275+
} else if (
1276+
op == BinaryOpType::GT && prove::isNonNegative(bop->rhs(), var_info)) {
1277+
return value->fusion()->falseVal();
1278+
} else if (
1279+
op == BinaryOpType::GE && prove::isPositive(bop->rhs(), var_info)) {
1280+
return value->fusion()->falseVal();
1281+
} else if (
1282+
op == BinaryOpType::Eq && prove::isNonZero(bop->rhs(), var_info)) {
1283+
return value->fusion()->falseVal();
12651284
}
12661285
}
12671286
return value;
@@ -1271,25 +1290,160 @@ Val* eliminateTrivialPredicate(Val* value, const VarInfoMap& var_info) {
12711290
// Also, according to rule M, if x can be factorized as x = k * y, then x / y
12721291
// can be simplified as x / y = (k * y) / y = k * (y / y) = k
12731292
Val* simplifyDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
1274-
if (auto bop = dynamic_cast<BinaryOp*>(value->definition())) {
1275-
if (prove::isNonZero(bop->rhs(), var_info)) {
1276-
if (bop->getBinaryOpType() == BinaryOpType::Mod) {
1277-
if (prove::isMultipleOf(bop->lhs(), bop->rhs())) {
1278-
return IrBuilder::newConstant(0, *value->getDataType());
1279-
}
1280-
} else if (bop->getBinaryOpType() == BinaryOpType::Div) {
1281-
auto lhs = sym_algebra::factorize(bop->lhs());
1282-
auto rhs = sym_algebra::factorize(bop->rhs());
1283-
auto quotient = sym_algebra::divideFactorized(lhs, rhs);
1284-
if (quotient != nullptr) {
1285-
return quotient;
1286-
}
1293+
auto bop = dynamic_cast<BinaryOp*>(value->definition());
1294+
if (!bop) {
1295+
return value;
1296+
}
1297+
if (!prove::isNonZero(bop->rhs(), var_info)) {
1298+
return value;
1299+
}
1300+
if (bop->getBinaryOpType() == BinaryOpType::Mod) {
1301+
if (prove::isMultipleOf(bop->lhs(), bop->rhs())) {
1302+
return IrBuilder::newConstant(0, *value->getDataType());
1303+
}
1304+
} else if (bop->getBinaryOpType() == BinaryOpType::Div) {
1305+
auto lhs = sym_algebra::factorize(bop->lhs());
1306+
auto rhs = sym_algebra::factorize(bop->rhs());
1307+
auto quotient = sym_algebra::divideFactorized(lhs, rhs);
1308+
if (quotient != nullptr) {
1309+
return quotient;
1310+
}
1311+
}
1312+
return value;
1313+
}
1314+
1315+
// Simplify div and mod by canceling common terms
1316+
//
1317+
// For div, use rule N): a / (b * c) = (a / b) / c to simplify division:
1318+
// Let y = gcd(x, y) * y' and x = gcd(x, y) * x'
1319+
// then we can simplify x / y as:
1320+
// x / y = x / (gcd(x, y) * y') = (x / gcd(x, y)) / y' = x' / y'
1321+
//
1322+
// For mod, use rule
1323+
// O) If d divides a and b, then a % b = ((a / d) % (b / d)) * d
1324+
// Let y = gcd(x, y) * y' and x = gcd(x, y) * x'
1325+
// If gcd is nonzero, then we can simplify x % y as:
1326+
// x' % y' * gcd(x, y)
1327+
Val* cancelDivMod(Val* value, const VarInfoMap& var_info) {
1328+
auto divmod = toDivModOp(value->definition());
1329+
if (!divmod) {
1330+
return value;
1331+
}
1332+
auto op = divmod->getBinaryOpType();
1333+
if (op != BinaryOpType::Div && op != BinaryOpType::Mod) {
1334+
return value;
1335+
}
1336+
auto lhs = sym_algebra::factorize(divmod->lhs());
1337+
auto rhs = sym_algebra::factorize(divmod->rhs());
1338+
auto gcd = sym_algebra::greatestCommonDivisor({lhs, rhs});
1339+
if (gcd->isOneInt() || !prove::isNonZero(gcd, var_info)) {
1340+
return value;
1341+
}
1342+
auto numerator = sym_algebra::divideFactorized(lhs, gcd);
1343+
auto denominator = sym_algebra::divideFactorized(rhs, gcd);
1344+
if (op == BinaryOpType::Div) {
1345+
return IrBuilder::divExpr(numerator, denominator);
1346+
} else {
1347+
TORCH_INTERNAL_ASSERT(op == BinaryOpType::Mod);
1348+
return assoc_comm::flatten(
1349+
IrBuilder::mulExpr(IrBuilder::modExpr(numerator, denominator), gcd));
1350+
}
1351+
}
1352+
1353+
// Use the following rule to simplify div and mod:
1354+
// J) Distributivity of % over +:
1355+
// If compatible_sign(a, b), then (a + b) % c = (a % c + b % c) % c
1356+
// Q) If compatible_sign(a, b) and -|c| < a % c + b % c < |c|, then
1357+
// (a+b)/c = a/c + b/c
1358+
// In this pass we distribute div and mod for a special case:
1359+
// If compatible_sign(a, b), and a is a multiple of c, then:
1360+
// (a+b)/c = a/c + b/c
1361+
// (a + b) % c = b % c
1362+
Val* distributeDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
1363+
auto divmod = toDivModOp(value->definition());
1364+
if (!divmod) {
1365+
return value;
1366+
}
1367+
auto lhs = divmod->lhs();
1368+
auto rhs = divmod->rhs();
1369+
if (!prove::isNonZero(rhs, var_info)) {
1370+
return value;
1371+
}
1372+
auto fop = toFlattenedAdd(lhs->definition());
1373+
if (!fop) {
1374+
return value;
1375+
}
1376+
for (auto i : c10::irange(fop->inputs().size())) {
1377+
Val* divisible_term = fop->input(i);
1378+
if (!prove::isMultipleOf(divisible_term, rhs)) {
1379+
continue;
1380+
}
1381+
std::vector<Val*> other_terms;
1382+
other_terms.reserve(fop->inputs().size() - 1);
1383+
for (auto j : c10::irange(fop->inputs().size())) {
1384+
if (j == i) {
1385+
continue;
12871386
}
1387+
other_terms.emplace_back(fop->input(j));
1388+
}
1389+
Val* sum_of_other_terms = nullptr;
1390+
if (other_terms.size() == 1) {
1391+
sum_of_other_terms = other_terms.at(0);
1392+
} else {
1393+
sum_of_other_terms = IrBuilder::newScalar(*value->getDataType());
1394+
IrBuilder::create<FOp>(
1395+
BinaryOpType::Add, sum_of_other_terms, std::move(other_terms));
1396+
}
1397+
if (prove::hasCompatibleSign(
1398+
divisible_term, sum_of_other_terms, var_info)) {
1399+
std::vector<Val*> new_inputs;
1400+
auto term1 = IrBuilder::newScalar(*value->getDataType());
1401+
IrBuilder::create<BinaryOp>(
1402+
divmod->getBinaryOpType(), term1, divisible_term, rhs);
1403+
new_inputs.emplace_back(simplifyDivisibleDivMod(term1, var_info));
1404+
new_inputs.emplace_back(IrBuilder::newScalar(*value->getDataType()));
1405+
IrBuilder::create<BinaryOp>(
1406+
divmod->getBinaryOpType(), new_inputs[1], sum_of_other_terms, rhs);
1407+
auto output = IrBuilder::newScalar(*value->getDataType());
1408+
IrBuilder::create<FOp>(BinaryOpType::Add, output, std::move(new_inputs));
1409+
return output;
12881410
}
12891411
}
12901412
return value;
12911413
}
12921414

1415+
// a * (b + c) -> a * b + a * c
1416+
Val* distributeMul(Val* value, const VarInfoMap& var_info) {
1417+
auto fop = toFlattenedMul(value->definition());
1418+
if (!fop) {
1419+
return value;
1420+
}
1421+
Val* flattened_add = nullptr;
1422+
std::vector<Val*> other_terms;
1423+
for (auto inp : fop->inputs()) {
1424+
if (flattened_add == nullptr && isFlattenedAdd(inp)) {
1425+
flattened_add = inp;
1426+
} else {
1427+
other_terms.emplace_back(inp);
1428+
}
1429+
}
1430+
if (flattened_add == nullptr) {
1431+
return value;
1432+
}
1433+
auto fadd_op = toFlattenedAdd(flattened_add->definition());
1434+
std::vector<Val*> add_terms;
1435+
for (auto inp : fadd_op->inputs()) {
1436+
std::vector<Val*> inputs = other_terms;
1437+
inputs.emplace_back(inp);
1438+
add_terms.emplace_back(IrBuilder::newScalar(*value->getDataType()));
1439+
IrBuilder::create<FOp>(
1440+
BinaryOpType::Mul, add_terms.back(), std::move(inputs));
1441+
}
1442+
auto output = IrBuilder::newScalar(*value->getDataType());
1443+
IrBuilder::create<FOp>(BinaryOpType::Add, output, std::move(add_terms));
1444+
return output;
1445+
}
1446+
12931447
} // namespace rules
12941448

12951449
#define RUN_PASS(pass_name) \
@@ -1298,20 +1452,44 @@ Val* simplifyDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
12981452
}); \
12991453
logger->record(#pass_name, simplified)
13001454

1455+
// Requires all the passes before the barrier to have converged before
1456+
// proceeding to the passes after the barrier.
1457+
#define PASS_BARRIER \
1458+
if (old_simplified != simplified) \
1459+
continue
1460+
13011461
Val* simplifyExpr(Val* value, const std::list<VarInfo>& variables) {
13021462
FusionGuard fg(value->fusion());
13031463
const VarInfoMap var_info(variables);
13041464
auto logger = debug_print::createLogger(value);
13051465

1306-
auto simplified = assoc_comm::flatten(value);
1307-
logger->record(debug_print::kFlattenName, simplified);
1308-
1466+
Val* simplified = value;
13091467
Val* old_simplified = nullptr;
13101468
while (old_simplified != simplified) {
13111469
old_simplified = simplified;
1470+
1471+
// Passes other than assoc_comm::flatten assume that all
1472+
// associative-and-commutative binary ops (such as + *) are flattened. So
1473+
// that they don't need to worry about things like
1474+
// (a + b) + c vs a + (b + c)
1475+
// So the first step before running all other passes is always flattening
1476+
//
1477+
// Note that some passes might create nested flattened ops, something like
1478+
// FlattenedAdd(FlattenedAdd(...), ...), so we should rerun flatten at the
1479+
// beginning of each round instead of flattening before the while loop.
1480+
simplified = assoc_comm::flatten(simplified);
1481+
logger->record(debug_print::kFlattenName, simplified);
1482+
13121483
RUN_PASS(eliminateTrivialComputation);
13131484
RUN_PASS(eliminateTrivialPredicate);
13141485
RUN_PASS(simplifyDivisibleDivMod);
1486+
RUN_PASS(eliminateTrivialPredicate);
1487+
RUN_PASS(simplifyDivisibleDivMod);
1488+
RUN_PASS(cancelDivMod);
1489+
PASS_BARRIER;
1490+
RUN_PASS(distributeDivisibleDivMod);
1491+
PASS_BARRIER;
1492+
RUN_PASS(distributeMul);
13151493
}
13161494

13171495
auto unflattened = assoc_comm::unflatten(simplified, variables);

0 commit comments

Comments
 (0)