8
8
#include < lower_magic_zero.h>
9
9
#include < utils.h>
10
10
11
+ #include < cmath>
11
12
#include < functional>
12
13
#include < list>
13
14
#include < memory>
@@ -134,6 +135,46 @@ std::unique_ptr<debug_print::NoOpLogger> createLogger(Val* value) {
134
135
135
136
namespace {
136
137
138
+ // An ordered mapping of variable -> VarInfo
139
+ class VarInfoMap {
140
+ public:
141
+ VarInfoMap (const std::list<VarInfo>& variables) {
142
+ var_info_map_.reserve (variables.size ());
143
+ var_order_.reserve (variables.size ());
144
+ set_.reserve (variables.size ());
145
+ for (const auto & info : variables) {
146
+ var_order_.emplace_back (info.variable );
147
+ set_.emplace (info.variable );
148
+ var_info_map_[info.variable ] = info;
149
+ }
150
+ }
151
+
152
+ // Get the order of variables
153
+ const std::vector<Val*>& order () const {
154
+ return var_order_;
155
+ }
156
+
157
+ // Get the info of a variable
158
+ const VarInfo& info (Val* var) const {
159
+ return var_info_map_.at (var);
160
+ }
161
+
162
+ // Check if this mapping has information about the given variable
163
+ bool has (Val* var) const {
164
+ return set_.count (var);
165
+ }
166
+
167
+ // Get the set of variables that has information
168
+ const std::unordered_set<Val*>& set () const {
169
+ return set_;
170
+ }
171
+
172
+ private:
173
+ std::unordered_map<Val*, VarInfo> var_info_map_;
174
+ std::vector<Val*> var_order_;
175
+ std::unordered_set<Val*> set_;
176
+ };
177
+
137
178
bool hasSimilarType (DataType t1, DataType t2) {
138
179
if (t1 == t2) {
139
180
return true ;
@@ -464,22 +505,17 @@ class FlattenedAssocCommOp : public Expr {
464
505
// and b < c. So in this example, this function will return [v2, v1].
465
506
// Tensors are always considered as variables and they are always considered
466
507
// as the rightmost.
467
- std::vector<Val*> sortedInputs (const std::list<ValInfo>& variables) {
468
- std::unordered_set<Val*> variables_set;
469
- variables_set.reserve (variables.size ());
470
- for (auto v : variables) {
471
- variables_set.emplace (v.variable );
472
- }
508
+ std::vector<Val*> sortedInputs (const VarInfoMap& var_info) {
473
509
std::vector<Val*> sorted_inputs (inputs ().begin (), inputs ().end ());
474
510
std::unordered_map<Val*, std::unordered_set<Val*>> dependency;
475
511
dependency.reserve (sorted_inputs.size ());
476
512
for (auto v : sorted_inputs) {
477
- dependency[v] = getSubexprDependency (v, variables_set );
513
+ dependency[v] = getSubexprDependency (v, var_info. set () );
478
514
}
479
515
auto compare = [&](Val* v1, Val* v2) {
480
- // Find all variables in variables_set that v1 and v2 depends on. The
481
- // input (v1 or v2) that exclusively has the right most variable in
482
- // variables_set will be to the right of the other input.
516
+ // Find all variables in var_info that v1 and v2 depends on. The input (v1
517
+ // or v2) that exclusively has the right most variable in var_info.order()
518
+ // will be to the right of the other input.
483
519
bool v1_is_left_of_v2 = false ;
484
520
auto deps1 = dependency.at (v1);
485
521
auto deps2 = dependency.at (v2);
@@ -489,11 +525,10 @@ class FlattenedAssocCommOp : public Expr {
489
525
if (hasTensor (deps1)) {
490
526
return false ;
491
527
}
492
- for (auto v : variables ) {
493
- if (deps1.count (v. variable ) > 0 && deps2.count (v. variable ) == 0 ) {
528
+ for (auto v : var_info. order () ) {
529
+ if (deps1.count (v) > 0 && deps2.count (v) == 0 ) {
494
530
v1_is_left_of_v2 = false ;
495
- } else if (
496
- deps2.count (v.variable ) > 0 && deps1.count (v.variable ) == 0 ) {
531
+ } else if (deps2.count (v) > 0 && deps1.count (v) == 0 ) {
497
532
v1_is_left_of_v2 = true ;
498
533
}
499
534
}
@@ -642,9 +677,9 @@ Val* flatten(Val* value) {
642
677
643
678
// Recursively convert expressions like FlattenedAdd(a, b, c, d) into
644
679
// AddOp(AddOp(AddOp(a, b), c), d))
645
- Val* unflatten (Val* value, const std::list<ValInfo>& variables );
680
+ Val* unflatten (Val* value, const VarInfoMap& var_info );
646
681
647
- Val* unflattenRule (Val* value, const std::list<ValInfo>& variables ) {
682
+ Val* unflattenRule (Val* value, const VarInfoMap& var_info ) {
648
683
auto def = value->definition ();
649
684
if (def == nullptr ) {
650
685
return value;
@@ -663,14 +698,14 @@ Val* unflattenRule(Val* value, const std::list<ValInfo>& variables) {
663
698
// Handle flattened op:
664
699
// Convert flattened op into original binary ops
665
700
TORCH_INTERNAL_ASSERT (fop->inputs ().size () >= 2 );
666
- auto sorted_inputs = fop->sortedInputs (variables );
701
+ auto sorted_inputs = fop->sortedInputs (var_info );
667
702
// We need to recursively unflatten all inputs, because we might have
668
703
// nested flattened expressions like
669
704
// FlattenedAdd(a, b, FlattenedMul(c, d, e))
670
- Val* lhs = unflatten (sorted_inputs.at (0 ), variables );
705
+ Val* lhs = unflatten (sorted_inputs.at (0 ), var_info );
671
706
int64_t next = 1 ;
672
- while (next < (int )sorted_inputs.size ()) {
673
- auto rhs = unflatten (sorted_inputs.at (next), variables );
707
+ while (next < (int64_t )sorted_inputs.size ()) {
708
+ auto rhs = unflatten (sorted_inputs.at (next), var_info );
674
709
auto output = IrBuilder::newScalar (*value->getDataType ());
675
710
IrBuilder::create<BinaryOp>(fop->getOpType (), output, lhs, rhs);
676
711
lhs = output;
@@ -681,10 +716,9 @@ Val* unflattenRule(Val* value, const std::list<ValInfo>& variables) {
681
716
return value;
682
717
}
683
718
684
- Val* unflatten (Val* value, const std::list<ValInfo>& variables) {
685
- using namespace std ::placeholders;
719
+ Val* unflatten (Val* value, const VarInfoMap& var_info) {
686
720
return recurseDown (
687
- value, [&variables ](Val* val) { return unflattenRule (val, variables ); });
721
+ value, [&var_info ](Val* val) { return unflattenRule (val, var_info ); });
688
722
}
689
723
690
724
} // namespace assoc_comm
@@ -954,35 +988,53 @@ namespace prove {
954
988
// - x can be either zero or non-zero, it is just a symbolic number that depends
955
989
// - x is zero
956
990
957
- bool isNonNegative (Val* value) {
991
+ bool isPositive (Val* value, const VarInfoMap& var_info);
992
+
993
+ bool isNonNegative (Val* value, const VarInfoMap& var_info) {
958
994
value = foldConstants (value);
959
995
if (value->getInt ().has_value () && *value->getInt () >= 0 ) {
960
996
return true ;
961
997
}
962
998
if (value->getDouble ().has_value () && *value->getDouble () >= 0.0 ) {
963
999
return true ;
964
1000
}
1001
+ if (isPositive (value, var_info)) {
1002
+ return true ;
1003
+ }
965
1004
if (auto ns = dynamic_cast <NamedScalar*>(value)) {
966
1005
if (ns->getParallelDim ().has_value () ||
967
- ns->getParallelIndex ().has_value ()) {
1006
+ ns->getParallelIndex ().has_value () || ns->isTensorSize () ||
1007
+ ns->isTensorStride ()) {
968
1008
return true ;
969
1009
}
970
1010
}
1011
+ value = maybeUnwrapMagicZero (value);
1012
+ if (var_info.has (value)) {
1013
+ return isNonNegative (var_info.info (value).start , var_info) &&
1014
+ isNonNegative (var_info.info (value).step , var_info);
1015
+ }
971
1016
if (auto fop = dynamic_cast <FOp*>(value->definition ())) {
972
1017
auto op = fop->getOpType ();
973
1018
if (op == BinaryOpType::Add || op == BinaryOpType::Mul) {
974
1019
for (auto inp : fop->inputs ()) {
975
- if (!isNonNegative (inp)) {
1020
+ if (!isNonNegative (inp, var_info )) {
976
1021
return false ;
977
1022
}
978
1023
}
979
1024
return true ;
980
1025
}
1026
+ } else if (auto bop = dynamic_cast <BinaryOp*>(value->definition ())) {
1027
+ auto op = bop->getBinaryOpType ();
1028
+ if (op == BinaryOpType::Mod || op == BinaryOpType::Div ||
1029
+ op == BinaryOpType::CeilDiv) {
1030
+ return isNonNegative (bop->lhs (), var_info) &&
1031
+ isPositive (bop->rhs (), var_info);
1032
+ }
981
1033
}
982
1034
return false ;
983
1035
}
984
1036
985
- bool isPositive (Val* value) {
1037
+ bool isPositive (Val* value, const VarInfoMap& var_info ) {
986
1038
value = foldConstants (value);
987
1039
if (value->getInt ().has_value () && *value->getInt () > 0 ) {
988
1040
return true ;
@@ -1000,16 +1052,16 @@ bool isPositive(Val* value) {
1000
1052
if (op == BinaryOpType::Add) {
1001
1053
bool has_positive = false ;
1002
1054
for (auto inp : fop->inputs ()) {
1003
- if (isPositive (inp)) {
1055
+ if (isPositive (inp, var_info )) {
1004
1056
has_positive = true ;
1005
- } else if (!isNonNegative (inp)) {
1057
+ } else if (!isNonNegative (inp, var_info )) {
1006
1058
return false ;
1007
1059
}
1008
1060
}
1009
1061
return has_positive;
1010
1062
} else if (op == BinaryOpType::Mul) {
1011
1063
for (auto inp : fop->inputs ()) {
1012
- if (!isPositive (inp)) {
1064
+ if (!isPositive (inp, var_info )) {
1013
1065
return false ;
1014
1066
}
1015
1067
}
@@ -1019,20 +1071,20 @@ bool isPositive(Val* value) {
1019
1071
return false ;
1020
1072
}
1021
1073
1022
- bool isNonZero (Val* value) {
1074
+ bool isNonZero (Val* value, const VarInfoMap& var_info ) {
1023
1075
value = foldConstants (value);
1024
1076
if (value->getInt ().has_value () && *value->getInt () != 0 ) {
1025
1077
return true ;
1026
1078
}
1027
1079
if (value->getDouble ().has_value () && *value->getDouble () != 0.0 ) {
1028
1080
return true ;
1029
1081
}
1030
- if (isPositive (value)) {
1082
+ if (isPositive (value, var_info )) {
1031
1083
return true ;
1032
1084
}
1033
1085
if (auto fop = toFlattenedMul (value->definition ())) {
1034
1086
for (auto inp : fop->inputs ()) {
1035
- if (!isNonZero (inp)) {
1087
+ if (!isNonZero (inp, var_info )) {
1036
1088
return false ;
1037
1089
}
1038
1090
}
@@ -1064,7 +1116,7 @@ namespace rules {
1064
1116
// a / 1 -> a
1065
1117
// x - x -> 0
1066
1118
// ...
1067
- Val* eliminateTrivialComputation (Val* value) {
1119
+ Val* eliminateTrivialComputation (Val* value, const VarInfoMap& var_info ) {
1068
1120
auto folded = foldConstants (value);
1069
1121
if (folded != value) {
1070
1122
return folded;
@@ -1133,7 +1185,8 @@ Val* eliminateTrivialComputation(Val* value) {
1133
1185
auto lhs = foldConstants (bop->lhs ());
1134
1186
auto rhs = foldConstants (bop->rhs ());
1135
1187
bool divisor_is_1 = (rhs->getInt () == 1 || rhs->getDouble () == 1.0 );
1136
- if (divisor_is_1 || (prove::isNonZero (rhs) && lhs->getInt () == 0 )) {
1188
+ if (divisor_is_1 ||
1189
+ (prove::isNonZero (rhs, var_info) && lhs->getInt () == 0 )) {
1137
1190
return lhs;
1138
1191
}
1139
1192
} else if (bop->getBinaryOpType () == BinaryOpType::Sub) {
@@ -1147,12 +1200,79 @@ Val* eliminateTrivialComputation(Val* value) {
1147
1200
return value;
1148
1201
}
1149
1202
1203
+ // If x can be proved to be non-negative, then replace x >= 0 as true, replace
1204
+ // x < 0 as false
1205
+ // If x can be proved to be positive, then replace x >= 0 and x > 0 as true,
1206
+ // replace x <= 0 and x < 0 as false
1207
+ // If x can be proved to be nonzero, then replace x != 0 as true, replace x == 0
1208
+ // as false
1209
+ // if x->sameAs(y), then replace x == y as true, replace x != y as false
1210
+ Val* eliminateTrivialPredicate (Val* value, const VarInfoMap& var_info) {
1211
+ if (!value->isABool ()) {
1212
+ return value;
1213
+ }
1214
+ if (auto bop = dynamic_cast <BinaryOp*>(value->definition ())) {
1215
+ auto op = bop->getBinaryOpType ();
1216
+ if (bop->lhs ()->sameAs (bop->rhs ())) {
1217
+ if (op == BinaryOpType::Eq) {
1218
+ return value->fusion ()->trueVal ();
1219
+ } else if (op == BinaryOpType::NE) {
1220
+ return value->fusion ()->falseVal ();
1221
+ }
1222
+ }
1223
+ if (bop->rhs ()->isZeroInt ()) {
1224
+ if (op == BinaryOpType::GE &&
1225
+ prove::isNonNegative (bop->lhs (), var_info)) {
1226
+ return value->fusion ()->trueVal ();
1227
+ } else if (
1228
+ op == BinaryOpType::GT && prove::isPositive (bop->lhs (), var_info)) {
1229
+ return value->fusion ()->trueVal ();
1230
+ } else if (
1231
+ op == BinaryOpType::NE && prove::isNonZero (bop->lhs (), var_info)) {
1232
+ return value->fusion ()->trueVal ();
1233
+ } else if (
1234
+ op == BinaryOpType::LT &&
1235
+ prove::isNonNegative (bop->lhs (), var_info)) {
1236
+ return value->fusion ()->falseVal ();
1237
+ } else if (
1238
+ op == BinaryOpType::LE && prove::isPositive (bop->lhs (), var_info)) {
1239
+ return value->fusion ()->falseVal ();
1240
+ } else if (
1241
+ op == BinaryOpType::Eq && prove::isNonZero (bop->lhs (), var_info)) {
1242
+ return value->fusion ()->falseVal ();
1243
+ }
1244
+ } else if (bop->lhs ()->isZeroInt ()) {
1245
+ if (op == BinaryOpType::LE &&
1246
+ prove::isNonNegative (bop->rhs (), var_info)) {
1247
+ return value->fusion ()->trueVal ();
1248
+ } else if (
1249
+ op == BinaryOpType::LT && prove::isPositive (bop->rhs (), var_info)) {
1250
+ return value->fusion ()->trueVal ();
1251
+ } else if (
1252
+ op == BinaryOpType::NE && prove::isNonZero (bop->rhs (), var_info)) {
1253
+ return value->fusion ()->trueVal ();
1254
+ } else if (
1255
+ op == BinaryOpType::GT &&
1256
+ prove::isNonNegative (bop->rhs (), var_info)) {
1257
+ return value->fusion ()->falseVal ();
1258
+ } else if (
1259
+ op == BinaryOpType::GE && prove::isPositive (bop->rhs (), var_info)) {
1260
+ return value->fusion ()->falseVal ();
1261
+ } else if (
1262
+ op == BinaryOpType::Eq && prove::isNonZero (bop->rhs (), var_info)) {
1263
+ return value->fusion ()->falseVal ();
1264
+ }
1265
+ }
1266
+ }
1267
+ return value;
1268
+ }
1269
+
1150
1270
// Apply rule L to replace x % y with 0 if x can be proved to be a multiple of y
1151
1271
// Also, according to rule M, if x can be factorized as x = k * y, then x / y
1152
1272
// can be simplified as x / y = (k * y) / y = k * (y / y) = k
1153
- Val* simplifyDivisibleDivMod (Val* value) {
1273
+ Val* simplifyDivisibleDivMod (Val* value, const VarInfoMap& var_info ) {
1154
1274
if (auto bop = dynamic_cast <BinaryOp*>(value->definition ())) {
1155
- if (prove::isNonZero (bop->rhs ())) {
1275
+ if (prove::isNonZero (bop->rhs (), var_info )) {
1156
1276
if (bop->getBinaryOpType () == BinaryOpType::Mod) {
1157
1277
if (prove::isMultipleOf (bop->lhs (), bop->rhs ())) {
1158
1278
return IrBuilder::newConstant (0 , *value->getDataType ());
@@ -1172,12 +1292,15 @@ Val* simplifyDivisibleDivMod(Val* value) {
1172
1292
1173
1293
} // namespace rules
1174
1294
1175
// Run one simplification pass over the whole expression tree and log the
// result under the pass name. The macro expands rules::pass_name, so it only
// works for rules taking (Val*, const VarInfoMap&).
// NOTE: relies on `simplified` (the Val* being rewritten), `var_info` (the
// VarInfoMap), and `logger` being in scope at the expansion site.
#define RUN_PASS(pass_name)                                      \
  simplified = recurseDown(simplified, [&var_info](Val* val) {   \
    return rules::pass_name(val, var_info);                      \
  });                                                            \
  logger->record(#pass_name, simplified)
1178
1300
1179
- Val* simplifyExpr(Val* value, const std::list<ValInfo >& variables) {
1301
+ Val* simplifyExpr(Val* value, const std::list<VarInfo >& variables) {
1180
1302
FusionGuard fg (value->fusion ());
1303
+ const VarInfoMap var_info (variables);
1181
1304
auto logger = debug_print::createLogger (value);
1182
1305
1183
1306
auto simplified = assoc_comm::flatten (value);
@@ -1187,6 +1310,7 @@ Val* simplifyExpr(Val* value, const std::list<ValInfo>& variables) {
1187
1310
while (old_simplified != simplified) {
1188
1311
old_simplified = simplified;
1189
1312
RUN_PASS (eliminateTrivialComputation);
1313
+ RUN_PASS (eliminateTrivialPredicate);
1190
1314
RUN_PASS (simplifyDivisibleDivMod);
1191
1315
}
1192
1316
0 commit comments