Minor update on cp.async code generation. #1901

Merged
54 commits merged on Oct 6, 2022
Changes from all commits
Commits (54)
6f5ba21
use custom propagator in ampere TN
shmsong Jul 12, 2022
2329caf
add tile ordering utilities
shmsong Jul 12, 2022
121af43
initial matmul scheduler implementation
shmsong Jul 13, 2022
f958c53
use matmul scheduler prototype on ampere and turing test cases
shmsong Jul 13, 2022
397f74c
extend to support Volta
shmsong Jul 13, 2022
00d9a57
minor cleanup
shmsong Jul 13, 2022
d7035aa
comment cleanup
shmsong Jul 13, 2022
9ffc61d
minor fix
shmsong Jul 13, 2022
ed0f525
add fragment iteration and use it in matmul scheduler
shmsong Jun 5, 2022
c972116
use scheduler params for tests
shmsong Jul 14, 2022
d12a90f
fragment support in double buffer
shmsong Jul 14, 2022
c306b9b
add register double buffering test cases
shmsong Jul 14, 2022
63f561f
clean up custom transform propagator
shmsong Jul 19, 2022
3d47c1f
Merge remote-tracking branch 'origin/devel' into matmul_propagator
shmsong Jul 19, 2022
29f88c7
rebase fix
shmsong Jul 19, 2022
d029b9f
comment
shmsong Jul 19, 2022
5ac053f
move bounded selector to common area
shmsong Jul 20, 2022
b51d247
Add logic to handle fake boundary tensors in selection.
shmsong Jul 20, 2022
aba5087
naming and comment
shmsong Jul 20, 2022
426c381
remove unused parameters from mma node
shmsong Jul 20, 2022
6d4f377
remove unnecessary parameters from mma ir node
shmsong Jul 20, 2022
5e1f41f
rename scheduling variables
shmsong Jul 20, 2022
1960da9
change accumulator tv interface
shmsong Jul 20, 2022
3a411c2
Update torch/csrc/jit/codegen/cuda/scheduler/utils.h
shmsong Jul 20, 2022
8f2e4da
PR feedback
shmsong Jul 20, 2022
eef3a97
Merge branch 'matmul_propagator' of https://github.com/csarofeen/pyto…
shmsong Jul 20, 2022
6ad2967
pipe through parallel type position
shmsong Jul 20, 2022
65c8f0a
Merge remote-tracking branch 'origin/devel' into matmul_propagator
shmsong Jul 20, 2022
cd03b00
Revert "fragment support in double buffer"
shmsong Jul 20, 2022
380dd66
Merge branch 'matmul_propagator' into fragment_iter
shmsong Jul 20, 2022
6ce6ff6
use cache op to handle double buffer input
shmsong Jul 20, 2022
62f09fc
add more comment in matmul scheduler
shmsong Jul 21, 2022
538aa8b
more comments
shmsong Jul 21, 2022
91f44fd
comment fix
shmsong Jul 21, 2022
75d51a5
Merge remote-tracking branch 'origin/devel' into fragment_iter
shmsong Jul 25, 2022
546844a
rebase fix
shmsong Jul 25, 2022
ca55194
add inline pred for cpasync
shmsong Jul 12, 2022
2b6f447
Merge remote-tracking branch 'origin/devel' into speculative_index
shmsong Aug 1, 2022
41c221a
minor cleanup
shmsong Aug 1, 2022
214f2a2
add inlining test in unit
shmsong Aug 1, 2022
99e4d4c
add option to dump ptx
shmsong Aug 1, 2022
da45d51
Merge remote-tracking branch 'origin/devel' into speculative_index
shmsong Aug 10, 2022
c4a8739
rebase fix
shmsong Aug 10, 2022
7f42537
Fix missing thread predicates
naoyam Sep 27, 2022
93124a3
Merge branch 'devel' of github.com:csarofeen/pytorch into speculative…
zasdfgbnm Sep 28, 2022
ebeb201
fix merge
zasdfgbnm Sep 28, 2022
cde6e4d
fix merge
zasdfgbnm Sep 28, 2022
022c443
format
zasdfgbnm Sep 28, 2022
c90b90f
Merge branch 'devel' of github.com:csarofeen/pytorch into speculative…
zasdfgbnm Sep 29, 2022
3417e8e
Merge branch 'speculative_index' of github.com:csarofeen/pytorch into…
zasdfgbnm Sep 29, 2022
52099e0
cleanup
zasdfgbnm Oct 3, 2022
7d6e28d
Merge branch 'devel' of github.com:csarofeen/pytorch into speculative…
zasdfgbnm Oct 6, 2022
1f8ecba
cleanup clone
zasdfgbnm Oct 6, 2022
0742e7e
fix
zasdfgbnm Oct 6, 2022
15 changes: 12 additions & 3 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -543,9 +543,18 @@ class CudaKernelGenerator : private OptOutConstDispatch {
void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
auto dtype = ldst->in()->getDataType().value();

indent() << "Ampere::cpAsync("
<< genVectorPointer(ldst->out(), dtype, vec_size) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
if (ldst->predicate() == nullptr) {
// Out of line predicate variant
indent() << "Ampere::cpAsync("
<< genVectorPointer(ldst->out(), dtype, vec_size) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
} else {
// Inline predicate variant
indent() << "Ampere::cpAsync("
<< genVectorPointer(ldst->out(), dtype, vec_size) << ","
<< genVectorPointer(ldst->in(), dtype, vec_size) << ","
<< genInline(ldst->predicate()) << ");\n";
}
}

void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
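Note: to make the two branches above concrete, the following is a minimal sketch of the code they emit, assuming the nvfuser runtime helpers from runtime/memory.cu are in scope. The function and variable names are hypothetical, not copied from a real generated kernel.

// Sketch of the two emitted call forms (illustrative names only).
__device__ void cp_async_codegen_sketch(
    Array<float, 4, 4>* T1_smem, // shared-memory destination tile
    const float* T0_gmem, // global-memory source
    bool pred) { // lowered thread/bounds predicate
  // (a) Out-of-line predicate: the copy is wrapped in an if-then-else.
  if (pred) {
    Ampere::cpAsync(T1_smem, T0_gmem);
  }
  // (b) Inline predicate: the guard is passed as a third argument and is
  // lowered to a predicated "@p cp.async" instruction, so no branch is
  // emitted around the copy.
  Ampere::cpAsync(T1_smem, T0_gmem, pred);
}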
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/executor_utils.cpp
@@ -1019,6 +1019,12 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
// compile to sass is not allowed prior to CUDA 11.1
compile_to_sass = false;
#endif

if (isOptionDisabled(DisableOption::CompileToSass)) {
// Allows manually disabling compilation to sass
// so the intermediate ptx could be checked.
compile_to_sass = false;
}
// CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
// which gives better backwards compatibility to work on older driver,
// (since older driver doesn't necessarily recognize PTX emitted by new
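Usage note: assuming this fork's usual disable-option mechanism, adding compile_to_sass to the PYTORCH_NVFUSER_DISABLE environment variable enables this path (the token is parsed in the utils.cpp change further down), so NVRTC stops at PTX and the intermediate code can be inspected.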
4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower_index.cpp
@@ -982,7 +982,9 @@ void IndexLowering::handleGroupedGridWelford(
void IndexLowering::handle(const LoadStoreOp* ldst) {
const auto in = lowerSrcIndex(ldst->in(), ldst->out());
const auto out = lowerDstIndex(ldst->out());
pushBack(IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in));
auto new_ldst = IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in)
->withPredicate(ldst->predicate());
pushBack(new_ldst);
GpuLower::current()->propagateExprInfo(ldst, back());
}

51 changes: 31 additions & 20 deletions torch/csrc/jit/codegen/cuda/lower_predicate.cpp
@@ -43,29 +43,40 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
// Replace expr predicate with bool conditional
auto conditional = generateConditional(expr->predicate());
if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
// TODO: This logic doesn't seem to fit well here, for unswitch the
// logic is in the unroll loop to set the thread predicate to the expr.
// I didn't have a quick way to do that so placing this here for now.
TORCH_INTERNAL_ASSERT(
expr->isA<kir::IfThenElse>(),
"Predicate handling expects ITE statement.");
auto ite = expr->as<kir::IfThenElse>();

TORCH_INTERNAL_ASSERT(
ite->thenBody().size() == 1,
"Expecting predicated body to only have one vectorized expression.");
auto vec_expr = ite->thenBody()[0];
TORCH_INTERNAL_ASSERT(
vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
"Vectorize predicate exprs only supported on set operations.");
TORCH_INTERNAL_ASSERT(
ir_utils::isTvOp(vec_expr),
"Vectorize predicate exprs only supported on tensor view operations.");
if (!vec_expr->inputs()[0]->isConstScalar()) {
if (expr->isA<kir::IfThenElse>()) {
// TODO: This logic doesn't seem to fit well here, for unswitch the
// logic is in the unroll loop to set the thread predicate to the
// expr. I didn't have a quick way to do that so placing this here for
// now.
auto ite = expr->as<kir::IfThenElse>();

TORCH_INTERNAL_ASSERT(
ite->thenBody().size() == 1,
"Expecting predicated body to only have one vectorized expression.");
auto vec_expr = ite->thenBody()[0];
TORCH_INTERNAL_ASSERT(
vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
"Vectorize predicate exprs only supported on set operations.");
TORCH_INTERNAL_ASSERT(
ir_utils::isTvOp(vec_expr),
"Vectorize predicate exprs only supported on tensor view operations.");
if (!vec_expr->inputs()[0]->isConstScalar()) {
conditional = SimplifyingIrBuilder::andExpr(
conditional,
GpuLower::current()->threadPredMap().getPredicate(
ir_utils::getTvOutput(vec_expr)))
->as<Bool>();
}
} else {
TORCH_INTERNAL_ASSERT(lower_utils::supportInlinePredicate(expr));
auto thread_pred = GpuLower::current()->threadPredMap().getPredicate(
ir_utils::getTvOutput(expr));
TORCH_INTERNAL_ASSERT(
thread_pred->isConst() && thread_pred->value().value());
conditional = SimplifyingIrBuilder::andExpr(
conditional,
GpuLower::current()->threadPredMap().getPredicate(
ir_utils::getTvOutput(vec_expr)))
ir_utils::getTvOutput(expr)))
->as<Bool>();
}
}
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_unroll.cpp
@@ -138,6 +138,12 @@ void UnrollPass::handle(Expr* expr) {
PredicateType::Inline, expr, thread_pred);
}

if (lower_utils::supportInlinePredicate(expr)) {
expr_with_predicate = expr_with_predicate->withPredicate(pred);
registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
return;
}

// If we need a predicate, put expr inside an if then else
kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred);
if (for_loops_.empty()) {
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_utils.cpp
@@ -727,6 +727,17 @@ BasicAllocInfo getAllocInformation(
return info;
}

//! Implementing this in here to avoid including too many headers
//! in type.cpp. Conceptually this should be a generic definition
//! rather than a util.
bool supportInlinePredicate(Expr* expr) {
if (ir_utils::isCpAsyncOp(expr)) {
return true;
}
// TODO: build out support.
return false;
}

} // namespace lower_utils

} // namespace cuda
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_utils.h
@@ -263,6 +263,11 @@ BasicAllocInfo getAllocInformation(
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, IterDomain*>& id_map = {},
bool use_id_map = false);

//! Returns true if the expression has a variant that takes a predicate
//! as an inline argument.
bool supportInlinePredicate(Expr* expr);

} // namespace lower_utils

} // namespace cuda
25 changes: 25 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/memory.cu
@@ -152,6 +152,31 @@ DEVICE_INLINE void cpAsync(
"n"(byte_size));
}

// Global to SMEM load that is asynchronous,
// not guaranteed to be completed until cpAsyncBarrier() is called.
template <typename dtype, int len>
DEVICE_INLINE void cpAsync(
Array<dtype, len, len>* smem_ptr,
void const* gmem_ptr,
bool predicate) {
unsigned smem_addr = util::toSmem(&(smem_ptr->array[0]));
constexpr int byte_size = sizeof(dtype) * len;

static_assert(
byte_size == 4 || byte_size == 8 || byte_size == 16,
"cp_async : unsupported byte size");

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %3, 0;\n"
"@p cp.async.ca.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem_addr),
"l"(gmem_ptr),
"n"(byte_size),
"r"((int)predicate));
}

// TODO: Might have a different category of sync if we want to build out this:
DEVICE_INLINE void cpAsyncBarrier() {
asm volatile("cp.async.wait_all;");
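A minimal usage sketch of the predicated overload added above, assuming both helpers live in the Ampere namespace that the code generator targets; the buffer names and the bounds check are hypothetical.

// Hypothetical device-side usage: each thread issues a guarded 16-byte
// cp.async copy, then waits for all outstanding async copies.
__device__ void guarded_tile_load(
    Array<float, 4, 4>* smem_tile, // shared-memory destination, one Array per thread
    const float* gmem_src, // global-memory source
    int valid_elements) {
  // Threads whose lane would read past the end get predicate = false and
  // skip the copy through the predicated cp.async form instead of branching.
  bool pred = static_cast<int>(threadIdx.x) * 4 < valid_elements;
  Ampere::cpAsync(smem_tile + threadIdx.x, gmem_src + threadIdx.x * 4, pred);
  Ampere::cpAsyncBarrier(); // cp.async.wait_all: block until all copies complete
}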
36 changes: 0 additions & 36 deletions torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp
@@ -1434,42 +1434,6 @@ TEST_F(NVFuserTest, FusionSimplePWise_CUDA) {
TORCH_CHECK(output_ref.equal(output));
}

TEST_F(NVFuserTest, FusionSimpleAmperePipeline_CUDA) {
(Review comment, collaborator: "moved to the new file to keep this file below 10k")

Fusion fusion;
FusionGuard fg(&fusion);

// requires ampere+ GPU
if (!deviceMajorMinorCheck(8)) {
GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs";
return;
}

auto tv0 = makeContigTensor(1);

fusion.addInput(tv0);

auto tv1 = set(tv0);

fusion.addOutput(tv1);

auto tv_cache = tv0->cacheAfter(LoadStoreOpType::CpAsync);
tv_cache->setMemoryType(MemoryType::Shared);

tv1->split(0, 16);
tv0->computeAt(tv1, 1);

tv_cache->circularBuffer(10);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input1 = at::randn({255}, options);

FusionExecutor fe;
fe.compileFusion(&fusion, {input1});
auto cg_outputs = fe.runFusion({input1});

testValidate(&fusion, cg_outputs, {input1}, {input1}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionSimplePWiseDtypeComplex_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
69 changes: 68 additions & 1 deletion torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp
@@ -4073,7 +4073,7 @@ TEST_F(NVFuserTest, FusionLoopSwizzleCheck1_CUDA) {
// Swizzle inner tile of tv2
tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop);

// Make tv2 swizzled and half-inlined (unsupported).
// Make tv2 swizzled and partially-inlined (unsupported).
tv0->computeAt(tv3, -2);

FusionExecutor fe;
@@ -6440,6 +6440,73 @@ TEST_F(NVFuserTest, FusionVectorizeStrideContiguitySelfOverlapping_CUDA) {
}
}

TEST_F(NVFuserTest, FusionSimpleAmperePipeline_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

// requires ampere+ GPU
if (!deviceMajorMinorCheck(8)) {
GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs";
return;
}

auto tv0 = makeContigTensor(1);

fusion.addInput(tv0);

auto tv1 = set(tv0);

fusion.addOutput(tv1);

auto tv_cache = tv0->cacheAfter(LoadStoreOpType::CpAsync);
tv_cache->setMemoryType(MemoryType::Shared);

tv1->split(0, 16);
tv0->computeAt(tv1, 1);

tv_cache->circularBuffer(10);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input1 = at::randn({255}, options);

// Add check that the cp async op has an inlined predicate.
class InlinedCpAsyncPredChecker : public kir::IrVisitor {
public:
using kir::IrVisitor::handle;

private:
void handle(kir::IfThenElse* ite) final {
auto prev_within_ite = within_ite_;
within_ite_ = true;
kir::IrVisitor::handle(ite);
within_ite_ = prev_within_ite;
}

void handle(LoadStoreOp* ldst) final {
if (ldst->opType() == LoadStoreOpType::CpAsync) {
TORCH_INTERNAL_ASSERT(!within_ite_, "CPASYNC predicate not inlined");
TORCH_INTERNAL_ASSERT(
ldst->predicate()->hasValue() &&
!ldst->predicate()->value()->isConst(),
"CPASYNC predicate is not generated");
}
}

private:
bool within_ite_ = false;
} pred_checker;

// Check that cp async is inlined:
GpuLower gpulw(&fusion);
pred_checker.handle(gpulw.kernel()->topLevelExprs());

FusionExecutor fe;
fe.compileFusion(&fusion, {input1});
auto cg_outputs = fe.runFusion({input1});

testValidate(&fusion, cg_outputs, {input1}, {input1}, __LINE__, __FILE__);
}

// Test file size should be up to 10K LoC. Create a new file for more tests.

} // namespace jit
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/utils.cpp
@@ -132,6 +132,7 @@ auto parseDebugDumpOptions() {
auto parseDisableOptions() {
std::unordered_map<DisableOption, bool> options_map = {
{DisableOption::ArchCheck, false},
{DisableOption::CompileToSass, false},
{DisableOption::Fallback, false},
{DisableOption::Fma, false},
{DisableOption::IndexHoist, false},
@@ -145,6 +146,8 @@ auto parseDisableOptions() {
const auto token = options_view.substr(0, end_pos);
if (token == "arch_check") {
options_map[DisableOption::ArchCheck] = true;
} else if (token == "compile_to_sass") {
options_map[DisableOption::CompileToSass] = true;
} else if (token == "fallback") {
options_map[DisableOption::Fallback] = true;
} else if (token == "fma") {
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/utils.h
@@ -70,6 +70,8 @@ TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
//!
enum class DisableOption {
ArchCheck, //! Disable hardware-specific checks to enable cross arch debug
CompileToSass, //! Disable direct compilation to sass so the ptx can be
//! examined
Fallback, //! Disable fallback
Fma, //! Disable FMA instructions
IndexHoist, //! Disable index hoisting