Expr simplifier: divisibility analysis #2258
// - x is zero

bool isNonZero(Val* value) {
bool isNonNegative(Val* value) {
No need to read this; it is just a temporary solution that lets me create non-zero variables in order to trigger simplifyDivisibleDivMod in unit tests. This will be replaced by #2273.
  return false;
}

bool isPositive(Val* value) {
No need to read this; it is just a temporary solution that lets me create non-zero variables in order to trigger simplifyDivisibleDivMod in unit tests. This will be replaced by #2273.
  return false;
}

bool isNonZero(Val* value) {
No need to read this; it is just a temporary solution that lets me create non-zero variables in order to trigger simplifyDivisibleDivMod in unit tests. This will be replaced by #2273.
Would we be able to reuse any of this for https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/scheduler/vectorize_helper.h#L27-L57 ?
I wonder if we should also use this type of analysis for: https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/lower_divisible_split.cpp
LGTM, just please add comments to each function.
  return toFlattenedMul(x->definition()) != nullptr;
}

std::pair<int64_t, std::list<Val*>> getConstAndSymbolicFactors(Val* x) {
Could you add some descriptions to your functions (even if it's just a sentence)?
added
  quoient_const_factor = x_factors.first / y_factors.first;
}

for (auto yf : y_factors.second) {
Yeah, I should do some refactoring in a separate PR.

// Symbolic gcd, for example: greatestCommonDivisor({6*a*b, 9*b*c}) -> 3*b
Val* greatestCommonDivisor(const std::vector<Val*>& inputs) {
  // The gcd of the constant part. Because gcd(0, a) = gcd(a, 0) = 0, it is
nit: Isn't gcd(0, a) = a?
You are right, thanks for catching!
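The identity in question can be checked directly against the C++17 `std::gcd`. The following is a minimal sketch, assuming the constant factors are plain `int64_t` values; `gcdOfConstants` is a hypothetical helper for illustration, not a function from this PR. Because gcd(0, a) = a, 0 works as the starting value when folding a gcd over a list.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// gcd(0, a) == a and gcd(a, 0) == a, so 0 is the identity for a gcd fold.
static_assert(std::gcd(0, 12) == 12, "gcd(0, a) is a");
static_assert(std::gcd(12, 0) == 12, "gcd(a, 0) is a");

// Hypothetical helper: fold std::gcd over constant factors, starting from 0.
inline int64_t gcdOfConstants(const std::vector<int64_t>& values) {
  int64_t result = 0;
  for (int64_t v : values) {
    result = std::gcd(result, v);
  }
  return result;
}
```

With this starting value, an empty input yields 0 and a single input yields itself (for non-negative inputs).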
IrBuilder::create<FOp>(
    BinaryOpType::Mul, product, std::vector<Val*>{quotient, gcd});
// quotient might contain nested FlattenedAdd, so we need to reflatten it.
return assoc_comm::flatten(product);
Why is it necessary to flatten product? Isn't it already flattened?
Updated the comment in code:
// Quotient might contain nested FlattenedAdd, for example, if we have:
// FlattenedAdd(a * FlattenedAdd(b, c), a * FlattenedAdd(d, e))
// then the gcd will be a, and the quotient will be:
// FlattenedAdd(FlattenedAdd(b, c), FlattenedAdd(d, e))
// So we need to reflatten to get rid of this nested FlattenedAdd.
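The nesting described in that comment can be illustrated on a toy expression tree. This is only a sketch under stated assumptions: `Node`, `leaf`, `sum`, and `reflatten` are hypothetical stand-ins and do not reflect nvfuser's IR or the actual `assoc_comm::flatten` implementation.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a flattened add: a node is either a leaf symbol
// (non-empty `symbol`) or an n-ary sum of child nodes.
struct Node {
  std::string symbol;
  std::vector<std::shared_ptr<Node>> terms;
  bool isSum() const { return symbol.empty(); }
};
using NodePtr = std::shared_ptr<Node>;

inline NodePtr leaf(const std::string& name) {
  auto n = std::make_shared<Node>();
  n->symbol = name;
  return n;
}

inline NodePtr sum(std::vector<NodePtr> terms) {
  auto n = std::make_shared<Node>();
  n->terms = std::move(terms);
  return n;
}

// Recursively collect the leaves under nested sums, in order.
inline void collectTerms(const NodePtr& node, std::vector<NodePtr>& out) {
  if (!node->isSum()) {
    out.push_back(node);
    return;
  }
  for (const auto& t : node->terms) {
    collectTerms(t, out);
  }
}

// Turn Sum(Sum(b, c), Sum(d, e)) into Sum(b, c, d, e).
inline NodePtr reflatten(const NodePtr& node) {
  if (!node->isSum()) {
    return node;
  }
  std::vector<NodePtr> flat;
  collectTerms(node, flat);
  return sum(std::move(flat));
}
```

Applied to the quotient from the comment, `reflatten(sum({sum({leaf("b"), leaf("c")}), sum({leaf("d"), leaf("e")})}))` gives a single sum with the four leaf terms b, c, d, e.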
"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/nvfuser>"
)
- set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 14)
+ set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 17)
cc @jjsjann123 I bumped this to C++17 to be able to use std::gcd. PyTorch is already using C++17, so I believe this is a safe change.
Yes it is. Thanks for cleaning this up~~
This PR adds a few common symbolic algebra functions: factorize, divideFactorized, and greatestCommonDivisor. These three functions are further important building blocks for simplifying matmul indices.

In order to test these functions, I added a new pass, simplifyDivisibleDivMod. This pass has an effect on matmul kernels, but on its own it still has a very narrow application, so it is mostly useful for unit testing factorize, divideFactorized, and greatestCommonDivisor.

The good news is that, combining this PR with #2273, I should be able to unlock some more useful simplifications and hoist the swizzle out of the inner loop. I will write these more useful passes in a future PR.
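To make the idea concrete, the example from the code comment, greatestCommonDivisor({6*a*b, 9*b*c}) -> 3*b, can be reproduced on a toy representation. This is a rough sketch, not the PR's implementation: `Term`, `symbolicGcd`, and `divideExact` are hypothetical names, and each product is modeled as a non-zero integer constant times a multiset of symbol names rather than nvfuser `Val*` nodes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <numeric>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy model of a product: constant * symbol * symbol * ...
// Assumption: constants are non-zero.
struct Term {
  int64_t constant;
  std::multiset<std::string> symbols;
};

// Symbolic gcd: gcd of the constants times the symbols shared by every term.
inline Term symbolicGcd(const std::vector<Term>& terms) {
  Term result{0, {}};  // gcd(0, a) == a, so 0 is the neutral start value
  bool first = true;
  for (const Term& t : terms) {
    result.constant = std::gcd(result.constant, t.constant);
    if (first) {
      result.symbols = t.symbols;
      first = false;
    } else {
      std::multiset<std::string> common;
      std::set_intersection(
          result.symbols.begin(), result.symbols.end(),
          t.symbols.begin(), t.symbols.end(),
          std::inserter(common, common.begin()));
      result.symbols = std::move(common);
    }
  }
  return result;
}

// Exact division x / y, or std::nullopt when y is not known to divide x.
inline std::optional<Term> divideExact(const Term& x, const Term& y) {
  if (y.constant == 0 || x.constant % y.constant != 0) {
    return std::nullopt;
  }
  if (!std::includes(
          x.symbols.begin(), x.symbols.end(),
          y.symbols.begin(), y.symbols.end())) {
    return std::nullopt;
  }
  Term quotient{x.constant / y.constant, {}};
  std::set_difference(
      x.symbols.begin(), x.symbols.end(),
      y.symbols.begin(), y.symbols.end(),
      std::inserter(quotient.symbols, quotient.symbols.begin()));
  return quotient;
}
```

When `divideExact(x, y)` succeeds and `y` is known to be non-zero, `x / y` can be replaced by the quotient and `x % y` by 0. Judging by its name and by the review comments about needing non-zero variables, this is presumably the kind of rewrite simplifyDivisibleDivMod performs, though the sketch makes no claim about how the pass is actually written.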