Commit d2ca7e3

shmsong, zasdfgbnm, and naoyam authored
Minor update on cp.async code generation. (#1901)
* use custom propagator in ampere TN
* add tile ordering utilities
* initial matmul scheduler implementation
* use matmul scheduler prototype on ampere and turing test cases
* extend to support Volta
* minor cleanup
* comment cleanup
* minor fix
* add fragment iteration and use it in matmul scheduler
* use scheduler params for tests
* fragment support in double buffer
* add register double buffering test cases
* clean up custom transform propagator
* rebase fix
* comment
* move bounded selector to common area
* Add logic to handle fake boundary tensors in selection.
* naming and comment
* remove unused parameters from mma node
* remove unnecessary parameters from mma ir node
* rename scheduling variables
* change accumulator tv interface
* Update torch/csrc/jit/codegen/cuda/scheduler/utils.h (Co-authored-by: Gao, Xiang <[email protected]>)
* PR feedback
* pipe through parallel type position
* Revert "fragment support in double buffer" (reverts commit d12a90f)
* use cache op to handle double buffer input
* add more comment in matmul scheduler
* more comments
* comment fix
* rebase fix
* add inline pred for cpasync
* minor cleanup
* add inlining test in unit
* add option to dump ptx
* rebase fix
* Fix missing thread predicates (unlikely to matter, but should be necessary)
* fix merge
* fix merge
* format
* cleanup
* cleanup clone
* fix

Co-authored-by: Gao, Xiang <[email protected]>
Co-authored-by: Naoya Maruyama <[email protected]>
1 parent d36cf61 commit d2ca7e3

12 files changed (+172, -61 lines)

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 12 additions & 3 deletions
@@ -543,9 +543,18 @@ class CudaKernelGenerator : private OptOutConstDispatch {
   void genCpAsync(const LoadStoreOp* ldst, int vec_size) {
     auto dtype = ldst->in()->getDataType().value();

-    indent() << "Ampere::cpAsync("
-             << genVectorPointer(ldst->out(), dtype, vec_size) << ","
-             << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
+    if (ldst->predicate() == nullptr) {
+      // Out of line predicate variant
+      indent() << "Ampere::cpAsync("
+               << genVectorPointer(ldst->out(), dtype, vec_size) << ","
+               << genVectorPointer(ldst->in(), dtype, vec_size) << ");\n";
+    } else {
+      // Inline predicate variant
+      indent() << "Ampere::cpAsync("
+               << genVectorPointer(ldst->out(), dtype, vec_size) << ","
+               << genVectorPointer(ldst->in(), dtype, vec_size) << ","
+               << genInline(ldst->predicate()) << ");\n";
+    }
   }

   void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
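For orientation, here is a rough sketch of what the two branches emit into the generated kernel. This is illustrative only: the tensor names (T0, T1), the loop index i0, the vector width, and the predicate expression below are invented for the example; the real strings come from genVectorPointer and genInline.

// Out-of-line predicate variant: the unroll pass has already wrapped the op
// in an if-then-else, so the generated call itself is unconditional.
if (threadIdx.x + 16 * i0 < T0.size[0]) {
  Ampere::cpAsync(
      reinterpret_cast<Array<float, 4, 4>*>(&T1[4 * threadIdx.x]),
      reinterpret_cast<Array<float, 4, 4>*>(&T0[4 * (threadIdx.x + 16 * i0)]));
}

// Inline predicate variant: no surrounding if-then-else; the boolean becomes
// a third argument and is lowered to a predicated cp.async instruction.
Ampere::cpAsync(
    reinterpret_cast<Array<float, 4, 4>*>(&T1[4 * threadIdx.x]),
    reinterpret_cast<Array<float, 4, 4>*>(&T0[4 * (threadIdx.x + 16 * i0)]),
    threadIdx.x + 16 * i0 < T0.size[0]);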

torch/csrc/jit/codegen/cuda/executor_utils.cpp

Lines changed: 6 additions & 0 deletions
@@ -1019,6 +1019,12 @@ std::pair<NvrtcFunction, std::string> nvrtcCompile(
   // compile to sass is not allowed prior to CUDA 11.1
   compile_to_sass = false;
 #endif
+
+  if (isOptionDisabled(DisableOption::CompileToSass)) {
+    // Allows manually disabling compilation to sass
+    // so the intermediate ptx could be checked.
+    compile_to_sass = false;
+  }
   // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
   // which gives better backwards compatibility to work on older driver,
   // (since older driver doesn't necessrily recognize PTX emitted by new
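Context for why this exposes the PTX: NVRTC can target either a real architecture (sm_XX, compiling all the way to SASS) or a virtual one (compute_XX, stopping at PTX). Below is a minimal sketch of that selection, not the actual nvrtcCompile code; the helper name and the major/minor arguments are assumptions for illustration.

#include <string>

// Sketch only: how a compile_to_sass flag typically selects the NVRTC target.
// sm_XX compiles straight to SASS; compute_XX stops at PTX so it can be dumped.
std::string archArg(int major, int minor, bool compile_to_sass) {
  return std::string("--gpu-architecture=") +
      (compile_to_sass ? "sm_" : "compute_") + std::to_string(major) +
      std::to_string(minor);
}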

torch/csrc/jit/codegen/cuda/lower_index.cpp

Lines changed: 3 additions & 1 deletion
@@ -982,7 +982,9 @@ void IndexLowering::handleGroupedGridWelford(
 void IndexLowering::handle(const LoadStoreOp* ldst) {
   const auto in = lowerSrcIndex(ldst->in(), ldst->out());
   const auto out = lowerDstIndex(ldst->out());
-  pushBack(IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in));
+  auto new_ldst = IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in)
+                      ->withPredicate(ldst->predicate());
+  pushBack(new_ldst);
   GpuLower::current()->propagateExprInfo(ldst, back());
 }


torch/csrc/jit/codegen/cuda/lower_predicate.cpp

Lines changed: 31 additions & 20 deletions
@@ -43,29 +43,40 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
       // Replace expr predicate with bool conditional
       auto conditional = generateConditional(expr->predicate());
       if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
-        // TODO: This logic doesn't seem to fit well here, for unswitch the
-        // logic is in the unroll loop to set the thread predicate to the expr.
-        // I didn't have a quick way to do that so placing this here for now.
-        TORCH_INTERNAL_ASSERT(
-            expr->isA<kir::IfThenElse>(),
-            "Predicate handling expects ITE statement.");
-        auto ite = expr->as<kir::IfThenElse>();
-
-        TORCH_INTERNAL_ASSERT(
-            ite->thenBody().size() == 1,
-            "Expecting predicated body to only have one vectorized expression.");
-        auto vec_expr = ite->thenBody()[0];
-        TORCH_INTERNAL_ASSERT(
-            vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
-            "Vectorize predicate exprs only supported on set operations.");
-        TORCH_INTERNAL_ASSERT(
-            ir_utils::isTvOp(vec_expr),
-            "Vectorize predicate exprs only supported on tensor view operations.");
-        if (!vec_expr->inputs()[0]->isConstScalar()) {
+        if (expr->isA<kir::IfThenElse>()) {
+          // TODO: This logic doesn't seem to fit well here, for unswitch the
+          // logic is in the unroll loop to set the thread predicate to the
+          // expr. I didn't have a quick way to do that so placing this here for
+          // now.
+          auto ite = expr->as<kir::IfThenElse>();
+
+          TORCH_INTERNAL_ASSERT(
+              ite->thenBody().size() == 1,
+              "Expecting predicated body to only have one vectorized expression.");
+          auto vec_expr = ite->thenBody()[0];
+          TORCH_INTERNAL_ASSERT(
+              vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
+              "Vectorize predicate exprs only supported on set operations.");
+          TORCH_INTERNAL_ASSERT(
+              ir_utils::isTvOp(vec_expr),
+              "Vectorize predicate exprs only supported on tensor view operations.");
+          if (!vec_expr->inputs()[0]->isConstScalar()) {
+            conditional = SimplifyingIrBuilder::andExpr(
+                conditional,
+                GpuLower::current()->threadPredMap().getPredicate(
+                    ir_utils::getTvOutput(vec_expr)))
+                ->as<Bool>();
+          }
+        } else {
+          TORCH_INTERNAL_ASSERT(lower_utils::supportInlinePredicate(expr));
+          auto thread_pred = GpuLower::current()->threadPredMap().getPredicate(
+              ir_utils::getTvOutput(expr));
+          TORCH_INTERNAL_ASSERT(
+              thread_pred->isConst() && thread_pred->value().value());
           conditional = SimplifyingIrBuilder::andExpr(
               conditional,
               GpuLower::current()->threadPredMap().getPredicate(
-                  ir_utils::getTvOutput(vec_expr)))
+                  ir_utils::getTvOutput(expr)))
               ->as<Bool>();
         }
       }

torch/csrc/jit/codegen/cuda/lower_unroll.cpp

Lines changed: 6 additions & 0 deletions
@@ -138,6 +138,12 @@ void UnrollPass::handle(Expr* expr) {
         PredicateType::Inline, expr, thread_pred);
   }

+  if (lower_utils::supportInlinePredicate(expr)) {
+    expr_with_predicate = expr_with_predicate->withPredicate(pred);
+    registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
+    return;
+  }
+
   // If we need a predicate, put expr inside an if then else
   kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred);
   if (for_loops_.empty()) {

torch/csrc/jit/codegen/cuda/lower_utils.cpp

Lines changed: 11 additions & 0 deletions
@@ -727,6 +727,17 @@ BasicAllocInfo getAllocInformation(
   return info;
 }

+//! Implementing this in here to avoid including too many headers
+//! in type.cpp. Conceptually this should be a generic definition
+//! rather than a util.
+bool supportInlinePredicate(Expr* expr) {
+  if (ir_utils::isCpAsyncOp(expr)) {
+    return true;
+  }
+  // TODO: build out support.
+  return false;
+}
+
 } // namespace lower_utils

 } // namespace cuda

torch/csrc/jit/codegen/cuda/lower_utils.h

Lines changed: 5 additions & 0 deletions
@@ -263,6 +263,11 @@ BasicAllocInfo getAllocInformation(
     const std::vector<kir::ForLoop*>& loops,
     const std::unordered_map<IterDomain*, IterDomain*>& id_map = {},
     bool use_id_map = false);
+
+//! Returns true if the expression has a variant that takes a predicate
+//! as an inline argument.
+bool supportInlinePredicate(Expr* expr);
+
 } // namespace lower_utils

 } // namespace cuda

torch/csrc/jit/codegen/cuda/runtime/memory.cu

Lines changed: 25 additions & 0 deletions
@@ -152,6 +152,31 @@ DEVICE_INLINE void cpAsync(
       "n"(byte_size));
 }

+// Global to SMEM load that is asynchronous,
+// not guaranteed to be completed until cpAsyncBarrier() is called.
+template <typename dtype, int len>
+DEVICE_INLINE void cpAsync(
+    Array<dtype, len, len>* smem_ptr,
+    void const* gmem_ptr,
+    bool predicate) {
+  unsigned smem_addr = util::toSmem(&(smem_ptr->array[0]));
+  constexpr int byte_size = sizeof(dtype) * len;
+
+  static_assert(
+      byte_size == 4 || byte_size == 8 || byte_size == 16,
+      "cp_async : unsupported byte size");
+
+  asm volatile(
+      "{\n"
+      "  .reg .pred p;\n"
+      "  setp.ne.b32 p, %3, 0;\n"
+      "@p cp.async.ca.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem_addr),
+      "l"(gmem_ptr),
+      "n"(byte_size),
+      "r"((int)predicate));
+}
+
 // TODO: Might have a different category of sync if we want to build out this:
 DEVICE_INLINE void cpAsyncBarrier() {
   asm volatile("cp.async.wait_all;");
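The new overload keeps all threads convergent: the boolean argument is materialized into a PTX predicate register with setp.ne.b32, and the @p guard skips the copy for threads where it is false instead of branching around it. Below is a self-contained sketch of the same idiom outside the nvfuser runtime; the kernel name, buffer sizes, and launch shape are made up for illustration, and it requires sm_80 or newer.

#include <cuda_runtime.h>

// Each thread copies one 16-byte vector from global to shared memory, but the
// cp.async only fires for in-bounds threads; out-of-bounds threads stay in the
// same instruction stream instead of branching around the copy.
__global__ void guarded_copy(const float4* gmem, int n_vec) {
  __shared__ float4 smem[128];
  int i = threadIdx.x;
  bool in_bounds = i < n_vec;

  unsigned smem_addr =
      static_cast<unsigned>(__cvta_generic_to_shared(&smem[i]));
  asm volatile(
      "{\n"
      "  .reg .pred p;\n"
      "  setp.ne.b32 p, %2, 0;\n"
      "  @p cp.async.ca.shared.global [%0], [%1], 16;\n"
      "}\n" ::"r"(smem_addr),
      "l"(gmem + i),
      "r"((int)in_bounds));

  asm volatile("cp.async.wait_all;");
  __syncthreads();
  // smem[i] now holds gmem[i] for every in-bounds thread.
}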

torch/csrc/jit/codegen/cuda/test/test_gpu1.cpp

Lines changed: 0 additions & 36 deletions
@@ -1434,42 +1434,6 @@ TEST_F(NVFuserTest, FusionSimplePWise_CUDA) {
   TORCH_CHECK(output_ref.equal(output));
 }

-TEST_F(NVFuserTest, FusionSimpleAmperePipeline_CUDA) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
-
-  // requires ampere+ GPU
-  if (!deviceMajorMinorCheck(8)) {
-    GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs";
-    return;
-  }
-
-  auto tv0 = makeContigTensor(1);
-
-  fusion.addInput(tv0);
-
-  auto tv1 = set(tv0);
-
-  fusion.addOutput(tv1);
-
-  auto tv_cache = tv0->cacheAfter(LoadStoreOpType::CpAsync);
-  tv_cache->setMemoryType(MemoryType::Shared);
-
-  tv1->split(0, 16);
-  tv0->computeAt(tv1, 1);
-
-  tv_cache->circularBuffer(10);
-
-  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor input1 = at::randn({255}, options);
-
-  FusionExecutor fe;
-  fe.compileFusion(&fusion, {input1});
-  auto cg_outputs = fe.runFusion({input1});
-
-  testValidate(&fusion, cg_outputs, {input1}, {input1}, __LINE__, __FILE__);
-}
-
 TEST_F(NVFuserTest, FusionSimplePWiseDtypeComplex_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);

torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp

Lines changed: 68 additions & 1 deletion
@@ -4073,7 +4073,7 @@ TEST_F(NVFuserTest, FusionLoopSwizzleCheck1_CUDA) {
   // Swizzle inner tile of tv2
   tv2->swizzle(Swizzle2DType::ZShape, -2, -1, SwizzleMode::Loop);

-  // Make tv2 swizzled and half-inlined (unsupported).
+  // Make tv2 swizzled and partially-inlined (unsupported).
   tv0->computeAt(tv3, -2);

   FusionExecutor fe;
@@ -6440,6 +6440,73 @@ TEST_F(NVFuserTest, FusionVectorizeStrideContiguitySelfOverlapping_CUDA) {
   }
 }

+TEST_F(NVFuserTest, FusionSimpleAmperePipeline_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  // requires ampere+ GPU
+  if (!deviceMajorMinorCheck(8)) {
+    GTEST_SKIP() << "skipping tests on pre-AMPERE GPUs";
+    return;
+  }
+
+  auto tv0 = makeContigTensor(1);
+
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+
+  fusion.addOutput(tv1);
+
+  auto tv_cache = tv0->cacheAfter(LoadStoreOpType::CpAsync);
+  tv_cache->setMemoryType(MemoryType::Shared);
+
+  tv1->split(0, 16);
+  tv0->computeAt(tv1, 1);
+
+  tv_cache->circularBuffer(10);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input1 = at::randn({255}, options);
+
+  // Add check that the cp async op has an inlined predicate.
+  class InlinedCpAsyncPredChecker : public kir::IrVisitor {
+   public:
+    using kir::IrVisitor::handle;
+
+   private:
+    void handle(kir::IfThenElse* ite) final {
+      auto prev_within_ite = within_ite_;
+      within_ite_ = true;
+      kir::IrVisitor::handle(ite);
+      within_ite_ = prev_within_ite;
+    }
+
+    void handle(LoadStoreOp* ldst) final {
+      if (ldst->opType() == LoadStoreOpType::CpAsync) {
+        TORCH_INTERNAL_ASSERT(!within_ite_, "CPASYNC predicate not inlined");
+        TORCH_INTERNAL_ASSERT(
+            ldst->predicate()->hasValue() &&
+                !ldst->predicate()->value()->isConst(),
+            "CPASYNC predicate is not generated");
+      }
+    }
+
+   private:
+    bool within_ite_ = false;
+  } pred_checker;
+
+  // Check that cp async is inlined:
+  GpuLower gpulw(&fusion);
+  pred_checker.handle(gpulw.kernel()->topLevelExprs());
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input1});
+  auto cg_outputs = fe.runFusion({input1});
+
+  testValidate(&fusion, cg_outputs, {input1}, {input1}, __LINE__, __FILE__);
+}
+
 // Test file size should be up to 10K LoC. Create a new file for more tests.

 } // namespace jit

torch/csrc/jit/codegen/cuda/utils.cpp

Lines changed: 3 additions & 0 deletions
@@ -132,6 +132,7 @@ auto parseDebugDumpOptions() {
 auto parseDisableOptions() {
   std::unordered_map<DisableOption, bool> options_map = {
       {DisableOption::ArchCheck, false},
+      {DisableOption::CompileToSass, false},
      {DisableOption::Fallback, false},
       {DisableOption::Fma, false},
       {DisableOption::IndexHoist, false},
@@ -145,6 +146,8 @@ auto parseDisableOptions() {
       const auto token = options_view.substr(0, end_pos);
       if (token == "arch_check") {
         options_map[DisableOption::ArchCheck] = true;
+      } else if (token == "compile_to_sass") {
+        options_map[DisableOption::CompileToSass] = true;
       } else if (token == "fallback") {
         options_map[DisableOption::Fallback] = true;
       } else if (token == "fma") {
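Usage note (an assumption, not shown in this diff): the disable tokens appear to be parsed from a single comma-separated environment variable, likely PYTORCH_NVFUSER_DISABLE, so a test harness could opt out of SASS compilation roughly as below. Both the variable name and the workflow are a hypothetical sketch.

#include <cstdlib>

int main() {
  // Hypothetical: keep NVRTC at the PTX stage for this process so the
  // intermediate PTX can be inspected. The env var name is an assumption.
  setenv("PYTORCH_NVFUSER_DISABLE", "compile_to_sass", /*overwrite=*/1);
  // ... build and run the fusion under test, then inspect the emitted PTX ...
  return 0;
}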

torch/csrc/jit/codegen/cuda/utils.h

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
 //!
 enum class DisableOption {
   ArchCheck, //! Disable hardware-specific checks to enable cross arch debug
+  CompileToSass, //! Disable direct compilation to sass so the ptx can be
+                 //! examined
   Fallback, //! Disable fallback
   Fma, //! Disable FMA instructions
   IndexHoist, //! Disable index hoisting
