Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
df5f5fe
trivial mod
zasdfgbnm Dec 13, 2022
7ab0654
simplifyZeroMod
zasdfgbnm Dec 13, 2022
110b160
recurseDown doc
zasdfgbnm Dec 13, 2022
be528a2
fix
zasdfgbnm Dec 13, 2022
0cfa8cf
save
zasdfgbnm Dec 13, 2022
73e1cad
update
zasdfgbnm Dec 14, 2022
5f67a3d
save
zasdfgbnm Dec 14, 2022
1adf8a4
Merge branch 'devel' of github.com:csarofeen/pytorch into simplify-tr…
zasdfgbnm Dec 15, 2022
7ae1981
format
zasdfgbnm Dec 15, 2022
975b950
save
zasdfgbnm Dec 15, 2022
2e039b1
move
zasdfgbnm Dec 15, 2022
604ac82
rename
zasdfgbnm Dec 15, 2022
ee8867b
fix
zasdfgbnm Dec 15, 2022
6d9726f
Merge branch 'devel' of github.com:csarofeen/pytorch into simplify-tr…
zasdfgbnm Dec 15, 2022
222f884
save
zasdfgbnm Dec 15, 2022
024f6c2
fix
zasdfgbnm Dec 15, 2022
9aa3ab3
fix
zasdfgbnm Dec 16, 2022
9530702
save
zasdfgbnm Dec 16, 2022
2794f43
save
zasdfgbnm Dec 16, 2022
3a9c899
non zero check
zasdfgbnm Dec 16, 2022
0d739cd
prove::isPositive, prove::isNonNegative, prove::isNonZero
zasdfgbnm Dec 16, 2022
6bcff97
save
zasdfgbnm Dec 16, 2022
6655ea3
more
zasdfgbnm Dec 16, 2022
fd5c2f4
Merge branch 'devel' of github.com:csarofeen/pytorch into compatible-…
zasdfgbnm Dec 16, 2022
659df58
save
zasdfgbnm Dec 16, 2022
1053cb1
more fix
zasdfgbnm Dec 16, 2022
52a1bce
Merge branch 'devel' of github.com:csarofeen/pytorch into simplify-tr…
zasdfgbnm Dec 16, 2022
769cbc5
fix
zasdfgbnm Dec 16, 2022
b8f9c08
save
zasdfgbnm Dec 16, 2022
1d9d93a
save
zasdfgbnm Dec 17, 2022
fc7e897
same
zasdfgbnm Dec 17, 2022
847a053
cleanup
zasdfgbnm Dec 17, 2022
6f287b4
save
zasdfgbnm Dec 17, 2022
ddbf623
save
zasdfgbnm Dec 17, 2022
183e87d
cleanup
zasdfgbnm Dec 17, 2022
450aa29
save
zasdfgbnm Dec 17, 2022
96a43e5
save
zasdfgbnm Dec 17, 2022
a91c71e
fix
zasdfgbnm Dec 17, 2022
72b65a6
save
zasdfgbnm Dec 17, 2022
32a261c
greatestCommonDivisor
zasdfgbnm Dec 17, 2022
867bfdd
save
zasdfgbnm Dec 17, 2022
8f2e5aa
Merge branch 'compatible-sign-check' of github.com:csarofeen/pytorch …
zasdfgbnm Dec 17, 2022
443db1b
fix
zasdfgbnm Dec 17, 2022
07739be
distributeDivisibleDivMod
zasdfgbnm Dec 17, 2022
f77b9ee
save
zasdfgbnm Dec 17, 2022
e795507
save
zasdfgbnm Dec 17, 2022
8c9ac8d
distributeMul
zasdfgbnm Dec 17, 2022
2f27c08
save
zasdfgbnm Dec 17, 2022
d8624f4
save
zasdfgbnm Dec 17, 2022
edcdc84
save
zasdfgbnm Dec 17, 2022
84bea81
simplify equivalence
zasdfgbnm Dec 17, 2022
f2f6c16
Merge branch 'compatible-sign-check' into distribute-divmod
zasdfgbnm Dec 18, 2022
0718c6f
save
zasdfgbnm Dec 18, 2022
0ad7d4e
save
zasdfgbnm Dec 18, 2022
470d302
save
zasdfgbnm Dec 18, 2022
6758d77
FusionDistributeDivisibleDivMod_CUDA
zasdfgbnm Dec 18, 2022
f9e3f92
FusionDistributeMul_CUDA
zasdfgbnm Dec 18, 2022
d1020c9
Merge branch 'devel' of github.com:csarofeen/pytorch into compatible-…
zasdfgbnm Jan 16, 2023
ebb47f2
Merge branch 'compatible-sign-check' of github.com:csarofeen/pytorch …
zasdfgbnm Jan 16, 2023
c93b84a
Merge branch 'devel' of github.com:csarofeen/pytorch into distribute-…
zasdfgbnm Jan 16, 2023
e4aa3cb
cleanup
zasdfgbnm Jan 16, 2023
caa3857
ident
zasdfgbnm Jan 16, 2023
cbba2a8
Fusion::trueVal and Fusion::falseVal
zasdfgbnm Jan 16, 2023
c8bb61c
Merge branch 'devel' of github.com:csarofeen/pytorch into compatible-…
zasdfgbnm Jan 16, 2023
0fc3cd7
Merge branch 'compatible-sign-check' of github.com:csarofeen/pytorch …
zasdfgbnm Jan 16, 2023
68b2a62
no gcd
zasdfgbnm Jan 16, 2023
d570c65
cleanup include
zasdfgbnm Jan 16, 2023
6400b7d
Merge branch 'compatible-sign-check' of github.com:csarofeen/pytorch …
zasdfgbnm Jan 16, 2023
9ffa5a0
fix
zasdfgbnm Jan 16, 2023
39f0e02
Merge branch 'compatible-sign-check' of github.com:csarofeen/pytorch …
zasdfgbnm Jan 16, 2023
c824a1b
cleanup
zasdfgbnm Jan 16, 2023
eec6b5a
cleanup
zasdfgbnm Jan 17, 2023
6b35c43
rename
zasdfgbnm Jan 17, 2023
a31cc24
save
zasdfgbnm Jan 17, 2023
f6c1522
VarInfoMap comment
zasdfgbnm Jan 17, 2023
43c5ec5
Merge branch 'compatible-sign-check' of github.com:csarofeen/pytorch …
zasdfgbnm Jan 17, 2023
77252a1
fix
zasdfgbnm Jan 17, 2023
55b7f7b
fix
zasdfgbnm Jan 17, 2023
b4539d0
distributeMul restrict
zasdfgbnm Jan 17, 2023
8f7afbb
distribute all
zasdfgbnm Jan 17, 2023
26b3c92
fix
zasdfgbnm Jan 17, 2023
07097f1
FusionAssociativeAndCommutativeReordering_CUDA skip
zasdfgbnm Jan 17, 2023
0cb7472
FusionAssociativeAndCommutativeReordering_CUDA
zasdfgbnm Jan 17, 2023
72b4e9c
Merge branch 'devel' of github.com:csarofeen/pytorch into distribute-…
zasdfgbnm Jan 18, 2023
dbe16cd
comment
zasdfgbnm Jan 18, 2023
7aef8da
Merge branch 'devel' of github.com:csarofeen/pytorch into distribute-…
zasdfgbnm Jan 18, 2023
6250809
indent
zasdfgbnm Jan 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 252 additions & 74 deletions third_party/nvfuser/csrc/expr_simplifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,10 +728,12 @@ namespace {
using FOp = assoc_comm::FlattenedAssocCommOp;

// If `expr` is a flattened add (a FlattenedAssocCommOp whose op type is Add),
// return it as an FOp*; otherwise return nullptr.
//
// NOTE(review): this span was garbled by diff extraction (old and new bodies
// interleaved); reconstructed here as the early-return version.
FOp* toFlattenedAdd(Expr* expr) {
  auto fop = dynamic_cast<FOp*>(expr);
  if (!fop) {
    return nullptr;
  }
  if (fop->getOpType() == BinaryOpType::Add) {
    return fop;
  }
  return nullptr;
}
Expand All @@ -741,10 +743,12 @@ bool isFlattenedAdd(Val* x) {
}

// If `expr` is a flattened mul (a FlattenedAssocCommOp whose op type is Mul),
// return it as an FOp*; otherwise return nullptr.
//
// NOTE(review): this span was garbled by diff extraction (old and new bodies
// interleaved); reconstructed here as the early-return version.
FOp* toFlattenedMul(Expr* expr) {
  auto fop = dynamic_cast<FOp*>(expr);
  if (!fop) {
    return nullptr;
  }
  if (fop->getOpType() == BinaryOpType::Mul) {
    return fop;
  }
  return nullptr;
}
Expand All @@ -753,6 +757,19 @@ bool isFlattenedMul(Val* x) {
return toFlattenedMul(x->definition()) != nullptr;
}

// If `expr` is a binary Div or Mod op, return it as a BinaryOp*; otherwise
// return nullptr. Used by the div/mod simplification passes below to filter
// the expressions they apply to.
BinaryOp* toDivModOp(Expr* expr) {
  auto bop = dynamic_cast<BinaryOp*>(expr);
  if (!bop) {
    return nullptr;
  }
  if (bop->getBinaryOpType() == BinaryOpType::Div ||
      bop->getBinaryOpType() == BinaryOpType::Mod) {
    // TODO: Add CeilDiv as well? Need to mathematically prove its rules first
    return bop;
  }
  return nullptr;
}

// Classify terms of a FlattenedMul as (constant, symbolic), for example:
// a * 3 * b * 5 --> (15, {a, b})
// a * b --> (1, {a, b})
Expand Down Expand Up @@ -1101,6 +1118,10 @@ bool isMultipleOf(Val* x, Val* y) {
return sym_algebra::divideFactorized(lhs, rhs) != nullptr;
}

// Try to prove that x and y have "compatible" signs, as required by the
// distribution rules for div/mod over + (rules J and Q).
//
// NOTE(review): this is a conservative check — it only proves the case where
// both x and y are non-negative. The symmetric both-non-positive case would
// also be compatible but is not handled here; confirm whether that case is
// needed by callers.
bool hasCompatibleSign(Val* x, Val* y, const VarInfoMap& var_info) {
  return isNonNegative(x, var_info) && isNonNegative(y, var_info);
}

} // namespace prove

namespace rules {
Expand Down Expand Up @@ -1211,57 +1232,55 @@ Val* eliminateTrivialPredicate(Val* value, const VarInfoMap& var_info) {
if (!value->isABool()) {
return value;
}
if (auto bop = dynamic_cast<BinaryOp*>(value->definition())) {
auto op = bop->getBinaryOpType();
if (bop->lhs()->sameAs(bop->rhs())) {
if (op == BinaryOpType::Eq) {
return value->fusion()->trueVal();
} else if (op == BinaryOpType::NE) {
return value->fusion()->falseVal();
}
auto bop = dynamic_cast<BinaryOp*>(value->definition());
if (!bop) {
return value;
}
auto op = bop->getBinaryOpType();
if (bop->lhs()->sameAs(bop->rhs())) {
if (op == BinaryOpType::Eq) {
return value->fusion()->trueVal();
} else if (op == BinaryOpType::NE) {
return value->fusion()->falseVal();
}
if (bop->rhs()->isZeroInt()) {
if (op == BinaryOpType::GE &&
prove::isNonNegative(bop->lhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::GT && prove::isPositive(bop->lhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::NE && prove::isNonZero(bop->lhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::LT &&
prove::isNonNegative(bop->lhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::LE && prove::isPositive(bop->lhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::Eq && prove::isNonZero(bop->lhs(), var_info)) {
return value->fusion()->falseVal();
}
} else if (bop->lhs()->isZeroInt()) {
if (op == BinaryOpType::LE &&
prove::isNonNegative(bop->rhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::LT && prove::isPositive(bop->rhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::NE && prove::isNonZero(bop->rhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::GT &&
prove::isNonNegative(bop->rhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::GE && prove::isPositive(bop->rhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::Eq && prove::isNonZero(bop->rhs(), var_info)) {
return value->fusion()->falseVal();
}
}
if (bop->rhs()->isZeroInt()) {
if (op == BinaryOpType::GE && prove::isNonNegative(bop->lhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::GT && prove::isPositive(bop->lhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::NE && prove::isNonZero(bop->lhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::LT && prove::isNonNegative(bop->lhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::LE && prove::isPositive(bop->lhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::Eq && prove::isNonZero(bop->lhs(), var_info)) {
return value->fusion()->falseVal();
}
} else if (bop->lhs()->isZeroInt()) {
if (op == BinaryOpType::LE && prove::isNonNegative(bop->rhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::LT && prove::isPositive(bop->rhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::NE && prove::isNonZero(bop->rhs(), var_info)) {
return value->fusion()->trueVal();
} else if (
op == BinaryOpType::GT && prove::isNonNegative(bop->rhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::GE && prove::isPositive(bop->rhs(), var_info)) {
return value->fusion()->falseVal();
} else if (
op == BinaryOpType::Eq && prove::isNonZero(bop->rhs(), var_info)) {
return value->fusion()->falseVal();
}
}
return value;
Expand All @@ -1271,25 +1290,160 @@ Val* eliminateTrivialPredicate(Val* value, const VarInfoMap& var_info) {
// Also, according to rule M, if x can be factorized as x = k * y, then x / y
// can be simplified as x / y = (k * y) / y = k * (y / y) = k
// If the dividend is a multiple of the divisor, fold the mod to 0 or the div
// to the symbolic quotient.
//
// NOTE(review): this span was garbled by diff extraction (old and new bodies
// interleaved); reconstructed here as the early-return version.
Val* simplifyDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
  auto bop = dynamic_cast<BinaryOp*>(value->definition());
  if (!bop) {
    return value;
  }
  // Rewriting a division/modulo is only sound when the divisor is provably
  // nonzero.
  if (!prove::isNonZero(bop->rhs(), var_info)) {
    return value;
  }
  if (bop->getBinaryOpType() == BinaryOpType::Mod) {
    // lhs % rhs == 0 when lhs is a multiple of rhs.
    if (prove::isMultipleOf(bop->lhs(), bop->rhs())) {
      return IrBuilder::newConstant(0, *value->getDataType());
    }
  } else if (bop->getBinaryOpType() == BinaryOpType::Div) {
    // If rhs divides lhs symbolically, replace lhs / rhs by the quotient.
    auto lhs = sym_algebra::factorize(bop->lhs());
    auto rhs = sym_algebra::factorize(bop->rhs());
    auto quotient = sym_algebra::divideFactorized(lhs, rhs);
    if (quotient != nullptr) {
      return quotient;
    }
  }
  return value;
}

// Simplify div and mod by canceling common terms
//
// For div, use rule N): a / (b * c) = (a / b) / c to simplify division:
// Let y = gcd(x, y) * y' and x = gcd(x, y) * x'
// then we can simplify x / y as:
// x / y = x / (gcd(x, y) * y') = (x / gcd(x, y)) / y' = x' / y'
//
// For mod, use rule
// O) If d divides a and b, then a % b = ((a / d) % (b / d)) * d
// Let y = gcd(x, y) * y' and x = gcd(x, y) * x'
// If gcd is nonzero, then we can simplify x % y as:
// x' % y' * gcd(x, y)
Val* cancelDivMod(Val* value, const VarInfoMap& var_info) {
  auto divmod = toDivModOp(value->definition());
  if (!divmod) {
    return value;
  }
  // toDivModOp only ever returns Div or Mod ops, so no further op-type check
  // is needed (a redundant re-check was removed here).
  auto op = divmod->getBinaryOpType();
  auto lhs = sym_algebra::factorize(divmod->lhs());
  auto rhs = sym_algebra::factorize(divmod->rhs());
  auto gcd = sym_algebra::greatestCommonDivisor({lhs, rhs});
  // A gcd of 1 cancels nothing, and dividing both operands by the gcd is only
  // sound when the gcd is provably nonzero.
  if (gcd->isOneInt() || !prove::isNonZero(gcd, var_info)) {
    return value;
  }
  auto numerator = sym_algebra::divideFactorized(lhs, gcd);
  auto denominator = sym_algebra::divideFactorized(rhs, gcd);
  if (op == BinaryOpType::Div) {
    // x / y -> x' / y'
    return IrBuilder::divExpr(numerator, denominator);
  }
  TORCH_INTERNAL_ASSERT(op == BinaryOpType::Mod);
  // x % y -> (x' % y') * gcd; flatten so later passes see the canonical
  // flattened form of the new multiplication.
  return assoc_comm::flatten(
      IrBuilder::mulExpr(IrBuilder::modExpr(numerator, denominator), gcd));
}

// Use the following rule to simplify div and mod:
// J) Distributivity of % over +:
// If compatible_sign(a, b), then (a + b) % c = (a % c + b % c) % c
// Q) If compatible_sign(a, b) and -|c| < a % c + b % c < |c|, then
// (a+b)/c = a/c + b/c
// In this pass we distribute div and mod for a special case:
// If compatible_sign(a, b), and a is a multiple of c, then:
// (a+b)/c = a/c + b/c
// (a + b) % c = b % c
// Distribute div/mod over a flattened sum when one term of the sum is a
// provable multiple of the divisor and the signs are compatible (see the rule
// comment above). Returns `value` unchanged when no rewrite applies.
Val* distributeDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
  auto divmod = toDivModOp(value->definition());
  if (!divmod) {
    return value;
  }
  auto lhs = divmod->lhs();
  auto rhs = divmod->rhs();
  // The rewrite is only sound when the divisor is provably nonzero.
  if (!prove::isNonZero(rhs, var_info)) {
    return value;
  }
  // The dividend must be a flattened sum so a term can be split off.
  auto fop = toFlattenedAdd(lhs->definition());
  if (!fop) {
    return value;
  }
  // Search for a term of the sum that is divisible by the divisor.
  for (auto i : c10::irange(fop->inputs().size())) {
    Val* divisible_term = fop->input(i);
    if (!prove::isMultipleOf(divisible_term, rhs)) {
      continue;
    }
    // Gather all remaining terms of the sum.
    std::vector<Val*> other_terms;
    other_terms.reserve(fop->inputs().size() - 1);
    for (auto j : c10::irange(fop->inputs().size())) {
      if (j == i) {
        continue;
      }
      other_terms.emplace_back(fop->input(j));
    }
    Val* sum_of_other_terms = nullptr;
    if (other_terms.size() == 1) {
      // A single remaining term needs no wrapping flattened add.
      sum_of_other_terms = other_terms.at(0);
    } else {
      sum_of_other_terms = IrBuilder::newScalar(*value->getDataType());
      IrBuilder::create<FOp>(
          BinaryOpType::Add, sum_of_other_terms, std::move(other_terms));
    }
    // Rules J and Q only hold when the two parts have compatible signs.
    if (prove::hasCompatibleSign(
            divisible_term, sum_of_other_terms, var_info)) {
      std::vector<Val*> new_inputs;
      // term1 = divisible_term op rhs. Since divisible_term is a multiple of
      // rhs, simplifyDivisibleDivMod folds it (0 for mod, quotient for div).
      auto term1 = IrBuilder::newScalar(*value->getDataType());
      IrBuilder::create<BinaryOp>(
          divmod->getBinaryOpType(), term1, divisible_term, rhs);
      new_inputs.emplace_back(simplifyDivisibleDivMod(term1, var_info));
      // Second term: sum_of_other_terms op rhs, left for later passes.
      new_inputs.emplace_back(IrBuilder::newScalar(*value->getDataType()));
      IrBuilder::create<BinaryOp>(
          divmod->getBinaryOpType(), new_inputs[1], sum_of_other_terms, rhs);
      auto output = IrBuilder::newScalar(*value->getDataType());
      IrBuilder::create<FOp>(BinaryOpType::Add, output, std::move(new_inputs));
      return output;
    }
  }
  return value;
}

// Distribute multiplication over addition:
//   a * (b + c) -> a * b + a * c
// Applies to a flattened mul whose factors contain at least one flattened
// add; returns `value` unchanged otherwise.
Val* distributeMul(Val* value, const VarInfoMap& var_info) {
  auto mul_op = toFlattenedMul(value->definition());
  if (mul_op == nullptr) {
    return value;
  }
  // Partition the factors: the first flattened-add factor is the one we
  // distribute over; every other factor multiplies each of its terms.
  Val* add_factor = nullptr;
  std::vector<Val*> remaining_factors;
  for (auto factor : mul_op->inputs()) {
    if (add_factor == nullptr && isFlattenedAdd(factor)) {
      add_factor = factor;
    } else {
      remaining_factors.emplace_back(factor);
    }
  }
  if (add_factor == nullptr) {
    // No addition among the factors; nothing to distribute.
    return value;
  }
  auto add_op = toFlattenedAdd(add_factor->definition());
  // Build one product per term of the addition.
  std::vector<Val*> products;
  for (auto term : add_op->inputs()) {
    auto factors = remaining_factors;
    factors.emplace_back(term);
    products.emplace_back(IrBuilder::newScalar(*value->getDataType()));
    IrBuilder::create<FOp>(
        BinaryOpType::Mul, products.back(), std::move(factors));
  }
  // Sum the products into the result.
  auto result = IrBuilder::newScalar(*value->getDataType());
  IrBuilder::create<FOp>(BinaryOpType::Add, result, std::move(products));
  return result;
}

} // namespace rules

#define RUN_PASS(pass_name) \
Expand All @@ -1298,20 +1452,44 @@ Val* simplifyDivisibleDivMod(Val* value, const VarInfoMap& var_info) {
}); \
logger->record(#pass_name, simplified)

// Requires that all the passes before the barrier converge before
// proceeding to the passes after the barrier.
#define PASS_BARRIER \
if (old_simplified != simplified) \
continue

Val* simplifyExpr(Val* value, const std::list<VarInfo>& variables) {
FusionGuard fg(value->fusion());
const VarInfoMap var_info(variables);
auto logger = debug_print::createLogger(value);

auto simplified = assoc_comm::flatten(value);
logger->record(debug_print::kFlattenName, simplified);

Val* simplified = value;
Val* old_simplified = nullptr;
while (old_simplified != simplified) {
old_simplified = simplified;

// Passes other than assoc_comm::flatten assumes that all
// associative-and-commutative binary ops (such as + *) are flattened. So
// that they don't need to worry about things like
// (a + b) + c vs a + (b + c)
// So the first step before running all other passes is always flattening
//
// Note that, some passes might create nested flattened ops, something like
// FlattenedAdd(FlattenedAdd(...), ...), so we should rerun flatten at the
// beginning of each round instead of flattening before the while loop.
simplified = assoc_comm::flatten(simplified);
logger->record(debug_print::kFlattenName, simplified);

RUN_PASS(eliminateTrivialComputation);
RUN_PASS(eliminateTrivialPredicate);
RUN_PASS(simplifyDivisibleDivMod);
RUN_PASS(eliminateTrivialPredicate);
RUN_PASS(simplifyDivisibleDivMod);
RUN_PASS(cancelDivMod);
PASS_BARRIER;
RUN_PASS(distributeDivisibleDivMod);
PASS_BARRIER;
RUN_PASS(distributeMul);
}

auto unflattened = assoc_comm::unflatten(simplified, variables);
Expand Down
Loading