
Conversation

@zasdfgbnm zasdfgbnm (Collaborator) commented Dec 13, 2022

This PR adds a few common computer algebra functions: factorize, divideFactorized, and greatestCommonDivisor. These are three more important building blocks for simplifying matmul indices.
In order to test these functions, I added a new pass, simplifyDivisibleDivMod. This pass already has an effect on matmul kernels:

[ RUN      ] NVFuserTest.FusionAmpereMatmulTNSwizzled_CUDA
Simplifying expression:
i1463 = ( ( ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) / 8 ) * 8 ) + ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) % 8 ) / 2 ) * 2 ) + ( ( ( ( threadIdx.z * 64 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) % 8 ) % 2 ) ) ) * 32 ) + ( ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) % 8 ) / 2 ) ^ ( ( ( i144 * 16 ) + ( ( ( threadIdx.x / 8 ) / ( ceilDiv(16, 8) ) ) * 8 ) ) / 8 ) ) * 8 ) + ( ( ( i144 * 16 ) + ( ( ( threadIdx.x / 8 ) / ( ceilDiv(16, 8) ) ) * 8 ) ) % 8 ) ) ) + ( ( i142 % 2 ) * ( ( ( ( ( ceilDiv(( ceilDiv(128, 8) ), 4) ) * ( ceilDiv(4, 2) ) ) * 2 ) * ( ( ( ceilDiv(8, 2) ) * 2 ) * ( ceilDiv(32, 8) ) ) ) * 8 ) ) )
assoc_comm::flatten:
i1518 = ( FlattenedAdd(( FlattenedMul(( FlattenedAdd(( FlattenedMul(( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) / 8 ), 8) ), ( FlattenedMul(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), 2) ), ( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) % 2 )) ), 32) ), ( FlattenedMul(( FlattenedXor(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), ( ( FlattenedAdd(( FlattenedMul(i144, 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) / 2 ), 8) )) ) / 8 )) ), 8) ), ( ( FlattenedAdd(( FlattenedMul(i144, 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) / 2 ), 8) )) ) % 8 ), ( FlattenedMul(( i142 % 2 ), 4096) )) )
simplifyDivisibleDivMod:
i1560 = ( FlattenedAdd(( FlattenedMul(( FlattenedAdd(( FlattenedMul(( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) / 8 ), 8) ), ( FlattenedMul(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), 2) ), ( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) % 2 )) ), 32) ), ( FlattenedMul(( FlattenedXor(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), ( FlattenedAdd(( FlattenedMul(i144, 2) ), ( ( threadIdx.x / 8 ) / 2 )) )) ), 8) ), 0, ( FlattenedMul(( i142 % 2 ), 4096) )) )
eliminateTrivialComputation:
i1561 = ( FlattenedAdd(( FlattenedMul(( FlattenedAdd(( FlattenedMul(( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) / 8 ), 8) ), ( FlattenedMul(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), 2) ), ( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) % 2 )) ), 32) ), ( FlattenedMul(( FlattenedXor(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), ( FlattenedAdd(( FlattenedMul(i144, 2) ), ( ( threadIdx.x / 8 ) / 2 )) )) ), 8) ), ( FlattenedMul(( i142 % 2 ), 4096) )) )
assoc_comm::unflatten:
i1636 = ( ( ( ( ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) / 8 ) * 8 ) + ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) % 8 ) / 2 ) * 2 ) ) + ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) % 8 ) % 2 ) ) * 32 ) + ( 4096 * ( i142 % 2 ) ) ) + ( 8 * ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) % 8 ) / 2 ) ^ ( ( ( threadIdx.x / 8 ) / 2 ) + ( 2 * i144 ) ) ) ) )
================================================================================
Simplifying expression:
i1742 = ( ( ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( i146 + 1 ) * 16 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) ) / 8 ) * 8 ) + ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( i146 + 1 ) * 16 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) ) % 8 ) / 2 ) * 2 ) + ( ( ( ( threadIdx.z * 64 ) + ( ( ( i146 + 1 ) * 16 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) ) % 8 ) % 2 ) ) ) * 32 ) + ( ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( i146 + 1 ) * 16 ) + ( ( ( ( threadIdx.x / 8 ) % ( ceilDiv(16, 8) ) ) * 8 ) + ( threadIdx.x % 8 ) ) ) ) % 8 ) / 2 ) ^ ( ( ( i144 * 16 ) + ( ( ( threadIdx.x / 8 ) / ( ceilDiv(16, 8) ) ) * 8 ) ) / 8 ) ) * 8 ) + ( ( ( i144 * 16 ) + ( ( ( threadIdx.x / 8 ) / ( ceilDiv(16, 8) ) ) * 8 ) ) % 8 ) ) ) + ( ( i142 % 2 ) * ( ( ( ( ( ceilDiv(( ceilDiv(128, 8) ), 4) ) * ( ceilDiv(4, 2) ) ) * 2 ) * ( ( ( ceilDiv(8, 2) ) * 2 ) * ( ceilDiv(32, 8) ) ) ) * 8 ) ) )
assoc_comm::flatten:
i1809 = ( FlattenedAdd(( FlattenedMul(( FlattenedAdd(( FlattenedMul(( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) / 8 ), 8) ), ( FlattenedMul(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), 2) ), ( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) % 2 )) ), 32) ), ( FlattenedMul(( FlattenedXor(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), ( ( FlattenedAdd(( FlattenedMul(i144, 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) / 2 ), 8) )) ) / 8 )) ), 8) ), ( ( FlattenedAdd(( FlattenedMul(i144, 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) / 2 ), 8) )) ) % 8 ), ( FlattenedMul(( i142 % 2 ), 4096) )) )
simplifyDivisibleDivMod:
i1863 = ( FlattenedAdd(( FlattenedMul(( FlattenedAdd(( FlattenedMul(( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) / 8 ), 8) ), ( FlattenedMul(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), 2) ), ( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) % 2 )) ), 32) ), ( FlattenedMul(( FlattenedXor(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), ( FlattenedAdd(( FlattenedMul(i144, 2) ), ( ( threadIdx.x / 8 ) / 2 )) )) ), 8) ), 0, ( FlattenedMul(( i142 % 2 ), 4096) )) )
eliminateTrivialComputation:
i1864 = ( FlattenedAdd(( FlattenedMul(( FlattenedAdd(( FlattenedMul(( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) / 8 ), 8) ), ( FlattenedMul(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), 2) ), ( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) % 2 )) ), 32) ), ( FlattenedMul(( FlattenedXor(( ( ( FlattenedAdd(( FlattenedMul(threadIdx.z, 64) ), ( FlattenedMul(( FlattenedAdd(i146, 1) ), 16) ), ( FlattenedMul(( ( threadIdx.x / 8 ) % 2 ), 8) ), ( threadIdx.x % 8 )) ) % 8 ) / 2 ), ( FlattenedAdd(( FlattenedMul(i144, 2) ), ( ( threadIdx.x / 8 ) / 2 )) )) ), 8) ), ( FlattenedMul(( i142 % 2 ), 4096) )) )
assoc_comm::unflatten:
i1975 = ( ( ( 4096 * ( i142 % 2 ) ) + ( 32 * ( ( ( 8 * ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) + ( 16 * ( 1 + i146 ) ) ) / 8 ) ) + ( 2 * ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) + ( 16 * ( 1 + i146 ) ) ) % 8 ) / 2 ) ) ) + ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) + ( 16 * ( 1 + i146 ) ) ) % 8 ) % 2 ) ) ) ) + ( 8 * ( ( ( ( threadIdx.x / 8 ) / 2 ) + ( 2 * i144 ) ) ^ ( ( ( ( ( ( threadIdx.z * 64 ) + ( ( ( threadIdx.x / 8 ) % 2 ) * 8 ) ) + ( threadIdx.x % 8 ) ) + ( 16 * ( 1 + i146 ) ) ) % 8 ) / 2 ) ) ) )
================================================================================
[       OK ] NVFuserTest.FusionAmpereMatmulTNSwizzled_CUDA (1411 ms)

However, this pass alone still has a very narrow application, so for now it is mostly useful for unit testing factorize, divideFactorized, and greatestCommonDivisor.

The good news is that, by combining this PR with #2273, I should be able to unlock more useful simplifications and hoist the swizzle out of the inner loop. I will write those more useful passes in a future PR.
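For intuition, here is a minimal standalone sketch of the divisibility rule that simplifyDivisibleDivMod applies. This is a toy model, not the nvfuser implementation: the real pass proves divisibility symbolically on factorized Val* expressions via factorize, divideFactorized, and greatestCommonDivisor, whereas this sketch uses plain integers.

```cpp
// Toy model of the rule behind simplifyDivisibleDivMod (illustrative only):
// when the divisor is proven to divide the dividend, x % d folds to 0 and
// x / d folds to the exact quotient.
#include <cassert>
#include <cstdlib>   // std::abs
#include <numeric>   // std::gcd (C++17)
#include <optional>

// x % d simplifies to 0 when d is proven to divide x.
std::optional<long> simplifyDivisibleMod(long x, long d) {
  if (d != 0 && std::gcd(x, d) == std::abs(d)) {
    return 0;
  }
  return std::nullopt;  // divisibility not proven; leave x % d untouched
}

// x / d simplifies to the exact quotient when d is proven to divide x.
std::optional<long> simplifyDivisibleDiv(long x, long d) {
  if (d != 0 && std::gcd(x, d) == std::abs(d)) {
    return x / d;
  }
  return std::nullopt;
}

int main() {
  assert(simplifyDivisibleMod(16 * 7, 8).value() == 0);   // 112 % 8 -> 0
  assert(simplifyDivisibleDiv(16 * 7, 8).value() == 14);  // 112 / 8 -> 14
  assert(!simplifyDivisibleMod(15, 8).has_value());       // cannot simplify
  return 0;
}
```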

@zasdfgbnm zasdfgbnm changed the title from "Expr simplifier: simplifyZeroMod" to "Expr simplifier: divisibility analysis" on Dec 17, 2022
// - x is zero

bool isNonZero(Val* value) {
bool isNonNegative(Val* value) {
zasdfgbnm (Collaborator, Author):

No need to read this; it is just a temporary solution that lets me create non-zero variables in order to trigger simplifyDivisibleDivMod in unit tests. It will be replaced by #2273.

return false;
}

bool isPositive(Val* value) {
zasdfgbnm (Collaborator, Author):

No need to read this; it is just a temporary solution that lets me create non-zero variables in order to trigger simplifyDivisibleDivMod in unit tests. It will be replaced by #2273.

return false;
}

bool isNonZero(Val* value) {
zasdfgbnm (Collaborator, Author):

No need to read this; it is just a temporary solution that lets me create non-zero variables in order to trigger simplifyDivisibleDivMod in unit tests. It will be replaced by #2273.

@zasdfgbnm zasdfgbnm marked this pull request as ready for review December 17, 2022 02:03
@zasdfgbnm zasdfgbnm requested a review from naoyam December 17, 2022 02:11
@csarofeen csarofeen (Owner) left a comment:

Would we be able to reuse any of this for https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/scheduler/vectorize_helper.h#L27-L57 ?

I wonder if we should also use this type of analysis for: https://github.com/csarofeen/pytorch/blob/devel/third_party/nvfuser/csrc/lower_divisible_split.cpp

LGTM, just please add comments to each function.

return toFlattenedMul(x->definition()) != nullptr;
}

std::pair<int64_t, std::list<Val*>> getConstAndSymbolicFactors(Val* x) {
csarofeen (Owner):

Could you add some descriptions to your functions (even if it's just a sentence)?

zasdfgbnm (Collaborator, Author):

Added.
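For readers without the full diff, here is a toy sketch of what a helper with this shape conceptually returns. This is illustrative only: the real getConstAndSymbolicFactors takes a Val* and works on the inputs of a flattened multiply, whereas this version uses strings. For example, the product 6*a*b would yield the constant factor 6 plus the symbolic factors {a, b}.

```cpp
// Illustrative only: split a flattened product's inputs into a folded
// constant factor and the remaining symbolic factors. The real nvfuser
// helper takes a Val*; strings are used here for brevity.
#include <cctype>
#include <cstdint>
#include <list>
#include <string>
#include <utility>
#include <vector>

std::pair<int64_t, std::list<std::string>> getConstAndSymbolicFactors(
    const std::vector<std::string>& flattened_mul_inputs) {
  int64_t const_factor = 1;
  std::list<std::string> symbolic_factors;
  for (const auto& term : flattened_mul_inputs) {
    if (!term.empty() && std::isdigit(static_cast<unsigned char>(term[0]))) {
      const_factor *= std::stoll(term);   // fold all constant inputs together
    } else {
      symbolic_factors.push_back(term);   // keep symbolic inputs as-is
    }
  }
  return {const_factor, symbolic_factors};
}

// getConstAndSymbolicFactors({"6", "a", "b"}) -> {6, {"a", "b"}}
```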

quoient_const_factor = x_factors.first / y_factors.first;
}

for (auto yf : y_factors.second) {
csarofeen (Owner):

zasdfgbnm (Collaborator, Author):

Yeah, I should do some refactoring in a separate PR.


// Symbolic gcd, for example: greatestCommonDivisor({6*a*b, 9*b*c}) -> 3*b
Val* greatestCommonDivisor(const std::vector<Val*>& inputs) {
// The gcd of the constant part. Because gcd(0, a) = gcd(a, 0) = 0, it is
Collaborator:

nit: Isn't gcd(0, a) = a?

zasdfgbnm (Collaborator, Author):

You are right, thanks for catching!
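As a side note, gcd(0, a) = a, so 0 is still a valid initial value when folding gcd over a list of terms; a minimal standalone check (illustrative, not nvfuser code):

```cpp
// Illustrative check: 0 is the identity element for gcd, which makes it a
// safe starting value when folding gcd over a list of terms.
#include <cassert>
#include <numeric>  // std::gcd (C++17)
#include <vector>

int main() {
  assert(std::gcd(0, 12) == 12);  // gcd(0, a) == a, not 0
  std::vector<int> values{6, 9, 12};
  int g = 0;  // identity element for the fold
  for (int v : values) {
    g = std::gcd(g, v);
  }
  assert(g == 3);
  return 0;
}
```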

IrBuilder::create<FOp>(
BinaryOpType::Mul, product, std::vector<Val*>{quotient, gcd});
// quotient might contain nested FlattenedAdd, so we need to reflatten it.
return assoc_comm::flatten(product);
Collaborator:

Why is it necessary to flatten product? Isn't it already flattened?

zasdfgbnm (Collaborator, Author):

Updated the comment in code:

  // Quotient might contain nested FlattenedAdd, for example, if we have:
  //   FlattenedAdd(a * FlattenedAdd(b, c), a * FlattenedAdd(d, e))
  // then the gcd will be a, and the quotient will be:
  //   FlattenedAdd(FlattenedAdd(b, c), FlattenedAdd(d, e))
  // So we need to reflatten to get rid of this nested FlattenedAdd.
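A toy sketch of the reflattening described above, using a hypothetical node type rather than the actual assoc_comm implementation: after dividing each addend by the gcd, children may themselves be flattened sums, and one more flatten pass lifts them back to a single level.

```cpp
// Illustrative only: why the quotient is reflattened. After factoring out
// the gcd, FlattenedAdd(FlattenedAdd(b, c), FlattenedAdd(d, e)) should
// become FlattenedAdd(b, c, d, e).
#include <memory>
#include <string>
#include <vector>

struct Node {
  bool is_flattened_add = false;
  std::string name;                           // leaf name when not an add
  std::vector<std::shared_ptr<Node>> inputs;  // children when it is an add
};

// Recursively hoist nested FlattenedAdd inputs into a single flat list.
void collectAddends(const std::shared_ptr<Node>& n,
                    std::vector<std::shared_ptr<Node>>& out) {
  if (n->is_flattened_add) {
    for (const auto& input : n->inputs) {
      collectAddends(input, out);
    }
  } else {
    out.push_back(n);
  }
}

std::shared_ptr<Node> reflatten(const std::shared_ptr<Node>& n) {
  if (!n->is_flattened_add) {
    return n;  // leaves are already flat
  }
  auto result = std::make_shared<Node>();
  result->is_flattened_add = true;
  collectAddends(n, result->inputs);  // single flat list of leaf addends
  return result;
}
```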

"$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}/nvfuser>"
)
set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 14)
set_property(TARGET ${NVFUSER_CODEGEN} PROPERTY CXX_STANDARD 17)
zasdfgbnm (Collaborator, Author):

cc @jjsjann123 I bumped this to C++17 to be able to use std::gcd. PyTorch is already using C++17, so I believe this is a safe change.
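For context, std::gcd is declared in <numeric> and is only available from C++17 on; it is also constexpr and accepts mixed integer types. A minimal illustration (not part of the PR):

```cpp
#include <cstdint>
#include <numeric>  // std::gcd: available since C++17

// std::gcd is constexpr and handles mixed integer types, which is what the
// C++17 bump buys compared to a hand-rolled helper.
static_assert(std::gcd(6, 9) == 3, "");
static_assert(std::gcd(std::int64_t{48}, 36) == 12, "");

int main() {
  return 0;
}
```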

Collaborator:

Yes it is. Thanks for cleaning this up~~

@zasdfgbnm zasdfgbnm merged commit 01d7545 into devel Jan 16, 2023
@zasdfgbnm zasdfgbnm deleted the simplify-trivial-mod branch January 16, 2023 23:19