@@ -728,10 +728,12 @@ namespace {
728
728
using FOp = assoc_comm::FlattenedAssocCommOp;
729
729
730
730
FOp* toFlattenedAdd (Expr* expr) {
731
- if (auto fop = dynamic_cast <FOp*>(expr)) {
732
- if (fop->getOpType () == BinaryOpType::Add) {
733
- return fop;
734
- }
731
+ auto fop = dynamic_cast <FOp*>(expr);
732
+ if (!fop) {
733
+ return nullptr ;
734
+ }
735
+ if (fop->getOpType () == BinaryOpType::Add) {
736
+ return fop;
735
737
}
736
738
return nullptr ;
737
739
}
@@ -741,10 +743,12 @@ bool isFlattenedAdd(Val* x) {
741
743
}
742
744
743
745
FOp* toFlattenedMul (Expr* expr) {
744
- if (auto fop = dynamic_cast <FOp*>(expr)) {
745
- if (fop->getOpType () == BinaryOpType::Mul) {
746
- return fop;
747
- }
746
+ auto fop = dynamic_cast <FOp*>(expr);
747
+ if (!fop) {
748
+ return nullptr ;
749
+ }
750
+ if (fop->getOpType () == BinaryOpType::Mul) {
751
+ return fop;
748
752
}
749
753
return nullptr ;
750
754
}
@@ -753,6 +757,19 @@ bool isFlattenedMul(Val* x) {
753
757
return toFlattenedMul (x->definition ()) != nullptr ;
754
758
}
755
759
760
+ BinaryOp* toDivModOp (Expr* expr) {
761
+ auto bop = dynamic_cast <BinaryOp*>(expr);
762
+ if (!bop) {
763
+ return nullptr ;
764
+ }
765
+ if (bop->getBinaryOpType () == BinaryOpType::Div ||
766
+ bop->getBinaryOpType () == BinaryOpType::Mod) {
767
+ // TODO: Add CeilDiv as well? Need to mathematically prove its rules first
768
+ return bop;
769
+ }
770
+ return nullptr ;
771
+ }
772
+
756
773
// Classify terms of a FlattenedMul as (constant, symbolic), for example:
757
774
// a * 3 * b * 5 --> (15, {a, b})
758
775
// a * b --> (1, {a, b})
@@ -1101,6 +1118,10 @@ bool isMultipleOf(Val* x, Val* y) {
1101
1118
return sym_algebra::divideFactorized (lhs, rhs) != nullptr ;
1102
1119
}
1103
1120
1121
+ bool hasCompatibleSign (Val* x, Val* y, const VarInfoMap& var_info) {
1122
+ return isNonNegative (x, var_info) && isNonNegative (y, var_info);
1123
+ }
1124
+
1104
1125
} // namespace prove
1105
1126
1106
1127
namespace rules {
@@ -1211,57 +1232,55 @@ Val* eliminateTrivialPredicate(Val* value, const VarInfoMap& var_info) {
1211
1232
if (!value->isABool ()) {
1212
1233
return value;
1213
1234
}
1214
- if (auto bop = dynamic_cast <BinaryOp*>(value->definition ())) {
1215
- auto op = bop->getBinaryOpType ();
1216
- if (bop->lhs ()->sameAs (bop->rhs ())) {
1217
- if (op == BinaryOpType::Eq) {
1218
- return value->fusion ()->trueVal ();
1219
- } else if (op == BinaryOpType::NE) {
1220
- return value->fusion ()->falseVal ();
1221
- }
1235
+ auto bop = dynamic_cast <BinaryOp*>(value->definition ());
1236
+ if (!bop) {
1237
+ return value;
1238
+ }
1239
+ auto op = bop->getBinaryOpType ();
1240
+ if (bop->lhs ()->sameAs (bop->rhs ())) {
1241
+ if (op == BinaryOpType::Eq) {
1242
+ return value->fusion ()->trueVal ();
1243
+ } else if (op == BinaryOpType::NE) {
1244
+ return value->fusion ()->falseVal ();
1222
1245
}
1223
- if (bop->rhs ()->isZeroInt ()) {
1224
- if (op == BinaryOpType::GE &&
1225
- prove::isNonNegative (bop->lhs (), var_info)) {
1226
- return value->fusion ()->trueVal ();
1227
- } else if (
1228
- op == BinaryOpType::GT && prove::isPositive (bop->lhs (), var_info)) {
1229
- return value->fusion ()->trueVal ();
1230
- } else if (
1231
- op == BinaryOpType::NE && prove::isNonZero (bop->lhs (), var_info)) {
1232
- return value->fusion ()->trueVal ();
1233
- } else if (
1234
- op == BinaryOpType::LT &&
1235
- prove::isNonNegative (bop->lhs (), var_info)) {
1236
- return value->fusion ()->falseVal ();
1237
- } else if (
1238
- op == BinaryOpType::LE && prove::isPositive (bop->lhs (), var_info)) {
1239
- return value->fusion ()->falseVal ();
1240
- } else if (
1241
- op == BinaryOpType::Eq && prove::isNonZero (bop->lhs (), var_info)) {
1242
- return value->fusion ()->falseVal ();
1243
- }
1244
- } else if (bop->lhs ()->isZeroInt ()) {
1245
- if (op == BinaryOpType::LE &&
1246
- prove::isNonNegative (bop->rhs (), var_info)) {
1247
- return value->fusion ()->trueVal ();
1248
- } else if (
1249
- op == BinaryOpType::LT && prove::isPositive (bop->rhs (), var_info)) {
1250
- return value->fusion ()->trueVal ();
1251
- } else if (
1252
- op == BinaryOpType::NE && prove::isNonZero (bop->rhs (), var_info)) {
1253
- return value->fusion ()->trueVal ();
1254
- } else if (
1255
- op == BinaryOpType::GT &&
1256
- prove::isNonNegative (bop->rhs (), var_info)) {
1257
- return value->fusion ()->falseVal ();
1258
- } else if (
1259
- op == BinaryOpType::GE && prove::isPositive (bop->rhs (), var_info)) {
1260
- return value->fusion ()->falseVal ();
1261
- } else if (
1262
- op == BinaryOpType::Eq && prove::isNonZero (bop->rhs (), var_info)) {
1263
- return value->fusion ()->falseVal ();
1264
- }
1246
+ }
1247
+ if (bop->rhs ()->isZeroInt ()) {
1248
+ if (op == BinaryOpType::GE && prove::isNonNegative (bop->lhs (), var_info)) {
1249
+ return value->fusion ()->trueVal ();
1250
+ } else if (
1251
+ op == BinaryOpType::GT && prove::isPositive (bop->lhs (), var_info)) {
1252
+ return value->fusion ()->trueVal ();
1253
+ } else if (
1254
+ op == BinaryOpType::NE && prove::isNonZero (bop->lhs (), var_info)) {
1255
+ return value->fusion ()->trueVal ();
1256
+ } else if (
1257
+ op == BinaryOpType::LT && prove::isNonNegative (bop->lhs (), var_info)) {
1258
+ return value->fusion ()->falseVal ();
1259
+ } else if (
1260
+ op == BinaryOpType::LE && prove::isPositive (bop->lhs (), var_info)) {
1261
+ return value->fusion ()->falseVal ();
1262
+ } else if (
1263
+ op == BinaryOpType::Eq && prove::isNonZero (bop->lhs (), var_info)) {
1264
+ return value->fusion ()->falseVal ();
1265
+ }
1266
+ } else if (bop->lhs ()->isZeroInt ()) {
1267
+ if (op == BinaryOpType::LE && prove::isNonNegative (bop->rhs (), var_info)) {
1268
+ return value->fusion ()->trueVal ();
1269
+ } else if (
1270
+ op == BinaryOpType::LT && prove::isPositive (bop->rhs (), var_info)) {
1271
+ return value->fusion ()->trueVal ();
1272
+ } else if (
1273
+ op == BinaryOpType::NE && prove::isNonZero (bop->rhs (), var_info)) {
1274
+ return value->fusion ()->trueVal ();
1275
+ } else if (
1276
+ op == BinaryOpType::GT && prove::isNonNegative (bop->rhs (), var_info)) {
1277
+ return value->fusion ()->falseVal ();
1278
+ } else if (
1279
+ op == BinaryOpType::GE && prove::isPositive (bop->rhs (), var_info)) {
1280
+ return value->fusion ()->falseVal ();
1281
+ } else if (
1282
+ op == BinaryOpType::Eq && prove::isNonZero (bop->rhs (), var_info)) {
1283
+ return value->fusion ()->falseVal ();
1265
1284
}
1266
1285
}
1267
1286
return value;
@@ -1271,25 +1290,160 @@ Val* eliminateTrivialPredicate(Val* value, const VarInfoMap& var_info) {
1271
1290
// Also, according to rule M, if x can be factorized as x = k * y, then x / y
1272
1291
// can be simplified as x / y = (k * y) / y = k * (y / y) = k
1273
1292
Val* simplifyDivisibleDivMod (Val* value, const VarInfoMap& var_info) {
1274
- if (auto bop = dynamic_cast <BinaryOp*>(value->definition ())) {
1275
- if (prove::isNonZero (bop->rhs (), var_info)) {
1276
- if (bop->getBinaryOpType () == BinaryOpType::Mod) {
1277
- if (prove::isMultipleOf (bop->lhs (), bop->rhs ())) {
1278
- return IrBuilder::newConstant (0 , *value->getDataType ());
1279
- }
1280
- } else if (bop->getBinaryOpType () == BinaryOpType::Div) {
1281
- auto lhs = sym_algebra::factorize (bop->lhs ());
1282
- auto rhs = sym_algebra::factorize (bop->rhs ());
1283
- auto quotient = sym_algebra::divideFactorized (lhs, rhs);
1284
- if (quotient != nullptr ) {
1285
- return quotient;
1286
- }
1293
+ auto bop = dynamic_cast <BinaryOp*>(value->definition ());
1294
+ if (!bop) {
1295
+ return value;
1296
+ }
1297
+ if (!prove::isNonZero (bop->rhs (), var_info)) {
1298
+ return value;
1299
+ }
1300
+ if (bop->getBinaryOpType () == BinaryOpType::Mod) {
1301
+ if (prove::isMultipleOf (bop->lhs (), bop->rhs ())) {
1302
+ return IrBuilder::newConstant (0 , *value->getDataType ());
1303
+ }
1304
+ } else if (bop->getBinaryOpType () == BinaryOpType::Div) {
1305
+ auto lhs = sym_algebra::factorize (bop->lhs ());
1306
+ auto rhs = sym_algebra::factorize (bop->rhs ());
1307
+ auto quotient = sym_algebra::divideFactorized (lhs, rhs);
1308
+ if (quotient != nullptr ) {
1309
+ return quotient;
1310
+ }
1311
+ }
1312
+ return value;
1313
+ }
1314
+
1315
+ // Simplify div and mod by canceling common terms
1316
+ //
1317
+ // For div, use rule N): a / (b * c) = (a / b) / c to simplify division:
1318
+ // Let y = gcd(x, y) * y' and x = gcd(x, y) * x'
1319
+ // then we can simplify x / y as:
1320
+ // x / y = x / (gcd(x, y) * y') = (x / gcd(x, y)) / y' = x' / y'
1321
+ //
1322
+ // For mod, use rule
1323
+ // O) If d divides a and b, then a % b = ((a / d) % (b / d)) * d
1324
+ // Let y = gcd(x, y) * y' and x = gcd(x, y) * x'
1325
+ // If gcd is nonzero, then we can simplify x % y as:
1326
+ // x' % y' * gcd(x, y)
1327
+ Val* cancelDivMod (Val* value, const VarInfoMap& var_info) {
1328
+ auto divmod = toDivModOp (value->definition ());
1329
+ if (!divmod) {
1330
+ return value;
1331
+ }
1332
+ auto op = divmod->getBinaryOpType ();
1333
+ if (op != BinaryOpType::Div && op != BinaryOpType::Mod) {
1334
+ return value;
1335
+ }
1336
+ auto lhs = sym_algebra::factorize (divmod->lhs ());
1337
+ auto rhs = sym_algebra::factorize (divmod->rhs ());
1338
+ auto gcd = sym_algebra::greatestCommonDivisor ({lhs, rhs});
1339
+ if (gcd->isOneInt () || !prove::isNonZero (gcd, var_info)) {
1340
+ return value;
1341
+ }
1342
+ auto numerator = sym_algebra::divideFactorized (lhs, gcd);
1343
+ auto denominator = sym_algebra::divideFactorized (rhs, gcd);
1344
+ if (op == BinaryOpType::Div) {
1345
+ return IrBuilder::divExpr (numerator, denominator);
1346
+ } else {
1347
+ TORCH_INTERNAL_ASSERT (op == BinaryOpType::Mod);
1348
+ return assoc_comm::flatten (
1349
+ IrBuilder::mulExpr (IrBuilder::modExpr (numerator, denominator), gcd));
1350
+ }
1351
+ }
1352
+
1353
+ // Use the following rules to simplify div and mod:
1354
+ // J) Distributivity of % over +:
1355
+ // If compatible_sign(a, b), then (a + b) % c = (a % c + b % c) % c
1356
+ // Q) If compatible_sign(a, b) and -|c| < a % c + b % c < |c|, then
1357
+ // (a+b)/c = a/c + b/c
1358
+ // In this pass we distribute div and mod for a special case:
1359
+ // If compatible_sign(a, b), and a is a multiple of c, then:
1360
+ // (a+b)/c = a/c + b/c
1361
+ // (a + b) % c = b % c
1362
+ Val* distributeDivisibleDivMod (Val* value, const VarInfoMap& var_info) {
1363
+ auto divmod = toDivModOp (value->definition ());
1364
+ if (!divmod) {
1365
+ return value;
1366
+ }
1367
+ auto lhs = divmod->lhs ();
1368
+ auto rhs = divmod->rhs ();
1369
+ if (!prove::isNonZero (rhs, var_info)) {
1370
+ return value;
1371
+ }
1372
+ auto fop = toFlattenedAdd (lhs->definition ());
1373
+ if (!fop) {
1374
+ return value;
1375
+ }
1376
+ for (auto i : c10::irange (fop->inputs ().size ())) {
1377
+ Val* divisible_term = fop->input (i);
1378
+ if (!prove::isMultipleOf (divisible_term, rhs)) {
1379
+ continue ;
1380
+ }
1381
+ std::vector<Val*> other_terms;
1382
+ other_terms.reserve (fop->inputs ().size () - 1 );
1383
+ for (auto j : c10::irange (fop->inputs ().size ())) {
1384
+ if (j == i) {
1385
+ continue ;
1287
1386
}
1387
+ other_terms.emplace_back (fop->input (j));
1388
+ }
1389
+ Val* sum_of_other_terms = nullptr ;
1390
+ if (other_terms.size () == 1 ) {
1391
+ sum_of_other_terms = other_terms.at (0 );
1392
+ } else {
1393
+ sum_of_other_terms = IrBuilder::newScalar (*value->getDataType ());
1394
+ IrBuilder::create<FOp>(
1395
+ BinaryOpType::Add, sum_of_other_terms, std::move (other_terms));
1396
+ }
1397
+ if (prove::hasCompatibleSign (
1398
+ divisible_term, sum_of_other_terms, var_info)) {
1399
+ std::vector<Val*> new_inputs;
1400
+ auto term1 = IrBuilder::newScalar (*value->getDataType ());
1401
+ IrBuilder::create<BinaryOp>(
1402
+ divmod->getBinaryOpType (), term1, divisible_term, rhs);
1403
+ new_inputs.emplace_back (simplifyDivisibleDivMod (term1, var_info));
1404
+ new_inputs.emplace_back (IrBuilder::newScalar (*value->getDataType ()));
1405
+ IrBuilder::create<BinaryOp>(
1406
+ divmod->getBinaryOpType (), new_inputs[1 ], sum_of_other_terms, rhs);
1407
+ auto output = IrBuilder::newScalar (*value->getDataType ());
1408
+ IrBuilder::create<FOp>(BinaryOpType::Add, output, std::move (new_inputs));
1409
+ return output;
1288
1410
}
1289
1411
}
1290
1412
return value;
1291
1413
}
1292
1414
1415
+ // a * (b + c) -> a * b + a * c
1416
+ Val* distributeMul (Val* value, const VarInfoMap& var_info) {
1417
+ auto fop = toFlattenedMul (value->definition ());
1418
+ if (!fop) {
1419
+ return value;
1420
+ }
1421
+ Val* flattened_add = nullptr ;
1422
+ std::vector<Val*> other_terms;
1423
+ for (auto inp : fop->inputs ()) {
1424
+ if (flattened_add == nullptr && isFlattenedAdd (inp)) {
1425
+ flattened_add = inp;
1426
+ } else {
1427
+ other_terms.emplace_back (inp);
1428
+ }
1429
+ }
1430
+ if (flattened_add == nullptr ) {
1431
+ return value;
1432
+ }
1433
+ auto fadd_op = toFlattenedAdd (flattened_add->definition ());
1434
+ std::vector<Val*> add_terms;
1435
+ for (auto inp : fadd_op->inputs ()) {
1436
+ std::vector<Val*> inputs = other_terms;
1437
+ inputs.emplace_back (inp);
1438
+ add_terms.emplace_back (IrBuilder::newScalar (*value->getDataType ()));
1439
+ IrBuilder::create<FOp>(
1440
+ BinaryOpType::Mul, add_terms.back (), std::move (inputs));
1441
+ }
1442
+ auto output = IrBuilder::newScalar (*value->getDataType ());
1443
+ IrBuilder::create<FOp>(BinaryOpType::Add, output, std::move (add_terms));
1444
+ return output;
1445
+ }
1446
+
1293
1447
} // namespace rules
1294
1448
1295
1449
#define RUN_PASS (pass_name ) \
@@ -1298,20 +1452,44 @@ Val* simplifyDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
1298
1452
}); \
1299
1453
logger->record (#pass_name, simplified)
1300
1454
1455
+ // Requires all the passes before the barrier to have converged before
1456
+ // proceeding to the passes after the barrier.
1457
+ #define PASS_BARRIER \
1458
+ if (old_simplified != simplified) \
1459
+ continue
1460
+
1301
1461
Val* simplifyExpr (Val* value, const std::list<VarInfo>& variables) {
1302
1462
FusionGuard fg (value->fusion ());
1303
1463
const VarInfoMap var_info (variables);
1304
1464
auto logger = debug_print::createLogger (value);
1305
1465
1306
- auto simplified = assoc_comm::flatten (value);
1307
- logger->record (debug_print::kFlattenName , simplified);
1308
-
1466
+ Val* simplified = value;
1309
1467
Val* old_simplified = nullptr ;
1310
1468
while (old_simplified != simplified) {
1311
1469
old_simplified = simplified;
1470
+
1471
+ // Passes other than assoc_comm::flatten assume that all
1472
+ // associative-and-commutative binary ops (such as + *) are flattened. So
1473
+ // that they don't need to worry about things like
1474
+ // (a + b) + c vs a + (b + c)
1475
+ // So the first step before running all other passes is always flattening
1476
+ //
1477
+ // Note that, some passes might create nested flattened ops, something like
1478
+ // FlattenedAdd(FlattenedAdd(...), ...), so we should rerun flatten at the
1479
+ // beginning of each round instead of flattening before the while loop.
1480
+ simplified = assoc_comm::flatten (simplified);
1481
+ logger->record (debug_print::kFlattenName , simplified);
1482
+
1312
1483
RUN_PASS (eliminateTrivialComputation);
1313
1484
RUN_PASS (eliminateTrivialPredicate);
1314
1485
RUN_PASS (simplifyDivisibleDivMod);
1486
+ RUN_PASS (eliminateTrivialPredicate);
1487
+ RUN_PASS (simplifyDivisibleDivMod);
1488
+ RUN_PASS (cancelDivMod);
1489
+ PASS_BARRIER;
1490
+ RUN_PASS (distributeDivisibleDivMod);
1491
+ PASS_BARRIER;
1492
+ RUN_PASS (distributeMul);
1315
1493
}
1316
1494
1317
1495
auto unflattened = assoc_comm::unflatten (simplified, variables);
0 commit comments