@@ -477,6 +477,33 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    TORCH_INTERNAL_ASSERT(false, "Unreachable");
  }

+  //! Utility for generating vectorized pointer access in ldsm and
+  //! cpasync.
+  //! TODO: this access pattern could be merged with the existing
+  //!  vectorization handling logic, but this path will be updated in
+  //!  follow-ups to optimize the generated assembly, so it is kept as
+  //!  a separate path for now.
+  std::string genVectorPointer(Val* val, DataType dtype, int vec_size) {
+    std::stringstream ss;
+
+    ss << "reinterpret_cast<Array<" << dtype << "," << vec_size << ","
+       << vec_size << ">*>(&" << gen(val) << ")";
+
+    return ss.str();
+  }
+
+  void genLdMatrix(const LoadStoreOp* ldst, int vector_word_size) {
+    auto dtype = ldst->in()->getDataType().value();
+    indent() << "Turing::ldMatrix";
+    if (ldst->opType() == LoadStoreOpType::LdMatrixTranspose) {
+      code_ << "T";
+    }
+    code_ << " (";
+    code_ << "*" << genVectorPointer(ldst->out(), dtype, vector_word_size)
+          << ","
+          << "&" << gen(ldst->in()) << ");\n";
+  }
+
  void handle(const UnaryOp* uop) final {
    bool is_vector_op = false;
    size_t vector_word_size = 1;
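For reference, this is roughly the line the two helpers above print into the generated kernel. The tensor names and index expressions are hypothetical placeholders for whatever gen() produces, and __half with vector width 8 is just one plausible dtype/word-size combination:

    // Illustrative only: non-transposed ldmatrix, fp16 operands, vector word size 8.
    Turing::ldMatrix (*reinterpret_cast<Array<__half,8,8>*>(&T2[vec_off]),&T1[smem_off]);
    // An LdMatrixTranspose op would print Turing::ldMatrixT(...) instead.
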
@@ -918,7 +945,15 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    if (init) {
      ss << "init";
    }
-    ss << toString(options.macro) << toString(options.operand_layout);
+    ss << toString(options.macro);
+
+    if (isVolta(options.macro)) {
+      ss << toString(options.operand_layout);
+    } else if (isTuring(options.macro) || isAmpere(options.macro)) {
+      // mma on Turing and Ampere is TN only; transposes are handled either
+      // via ldmatrix for fp16 or explicitly for other types.
+      ss << "TN";
+    }
    // TODO: additional parameter could be removed by swizzling iterdomain
    auto acc_stride = mma->accStride();
    TORCH_INTERNAL_ASSERT(acc_stride > 0);
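To make the name assembly concrete, here is a small self-contained C++ sketch of the branching above. The enum values and the strings returned by the stand-in toString are hypothetical placeholders (the real macro types and the toString/isVolta/isTuring/isAmpere helpers live in nvFuser's mma type utilities); only the control flow mirrors the diff:

    #include <iostream>
    #include <sstream>
    #include <string>

    // Hypothetical stand-ins for the nvFuser helpers used in the hunk above.
    enum class Macro { Volta_16_16_4, Turing_16_8_16 };
    std::string toString(Macro m) {
      return m == Macro::Volta_16_16_4 ? "M16N16K4" : "M16N8K16";
    }
    bool isVolta(Macro m) { return m == Macro::Volta_16_16_4; }

    std::string mmaFuncName(Macro macro, const std::string& operand_layout, bool init) {
      std::stringstream ss;
      if (init) {
        ss << "init";
      }
      ss << toString(macro);
      if (isVolta(macro)) {
        ss << operand_layout; // e.g. "TT", "TN", "NT"
      } else {
        ss << "TN"; // Turing/Ampere mma is TN only
      }
      return ss.str();
    }

    int main() {
      std::cout << mmaFuncName(Macro::Volta_16_16_4, "TT", false) << "\n"; // M16N16K4TT
      std::cout << mmaFuncName(Macro::Turing_16_8_16, "", true) << "\n";   // initM16N8K16TN
    }
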
@@ -1123,6 +1158,49 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    }
  }

+  void handle(const LoadStoreOp* ldst) {
+    // TODO:
+    //  Need to gradually merge this code path with
+    //  UnaryOp::Set for vectorization.
+    //  There is quite a bit of possible cleanup.
+    bool vectorize_op = false;
+    size_t vector_word_size = 1;
+    auto ti = ldst->out()->as<kir::TensorIndex>();
+
+    // Check vectorization and set the vector word size
+    for (auto id : ti->view()->domain()->domain()) {
+      if (!isParallelTypeVectorize(id->getParallelType())) {
+        continue;
+      }
+
+      ExpressionEvaluator expr_eval(id->fusion());
+      auto vector_size_optional = expr_eval.evaluate(id->extent());
+
+      TORCH_INTERNAL_ASSERT(
+          vector_size_optional.has_value(),
+          "Could not evaluate constant value bound to vectorized dim.");
+
+      TORCH_INTERNAL_ASSERT(
+          id->getParallelType() != ParallelType::MisalignedVectorize,
+          "LoadStoreOp: no support yet for mis-aligned vectorization");
+      vector_word_size = vector_size_optional.value();
+      vectorize_op = true;
+      break;
+    }
+
+    // Dispatch instruction generation:
+    switch (ldst->opType()) {
+      case LoadStoreOpType::LdMatrix:
+      case LoadStoreOpType::LdMatrixTranspose:
+        TORCH_INTERNAL_ASSERT(
+            vectorize_op, "LdMatrix: Vectorization required: ", ldst);
+        genLdMatrix(ldst, vector_word_size);
+        break;
+      default:
+        TORCH_INTERNAL_ASSERT(false, "LoadStoreOp: Unknown op type");
+    }
+  }
+
  void handle(const WelfordOp* wop) final {
    TORCH_INTERNAL_ASSERT(wop->out()->isA<kir::TensorIndex>());
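As a self-contained illustration of the pattern in this handler (scan the domain for a Vectorize axis, require its extent to be a compile-time constant, then dispatch on the op type), here is a simplified sketch with hypothetical stand-in types; the real code queries kir::TensorIndex/IterDomain and uses ExpressionEvaluator for the extent:

    #include <cstddef>
    #include <optional>
    #include <stdexcept>
    #include <vector>

    // Hypothetical, simplified stand-ins for the IR the handler above queries.
    enum class ParallelKind { Serial, Vectorize };
    enum class OpKind { LdMatrix, LdMatrixTranspose };

    struct Axis {
      ParallelKind parallel;
      std::optional<size_t> constant_extent; // nullopt if not a compile-time constant
    };

    // Mirrors the loop above: the first Vectorize axis determines the vector
    // word size, and its extent must evaluate to a known constant.
    std::optional<size_t> vectorWordSize(const std::vector<Axis>& axes) {
      for (const auto& axis : axes) {
        if (axis.parallel != ParallelKind::Vectorize) {
          continue;
        }
        if (!axis.constant_extent.has_value()) {
          throw std::runtime_error(
              "Could not evaluate constant value bound to vectorized dim.");
        }
        return axis.constant_extent;
      }
      return std::nullopt; // no vectorized axis: fall back to word size 1
    }

    // Mirrors the switch above: ldmatrix variants require a vectorized access.
    size_t dispatchLoadStore(OpKind op, const std::vector<Axis>& axes) {
      auto vec = vectorWordSize(axes);
      switch (op) {
        case OpKind::LdMatrix:
        case OpKind::LdMatrixTranspose:
          if (!vec.has_value()) {
            throw std::runtime_error("LdMatrix: Vectorization required");
          }
          // The real handler calls genLdMatrix(ldst, *vec) at this point.
          return *vec;
        default:
          throw std::runtime_error("LoadStoreOp: Unknown op type");
      }
    }
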
@@ -2033,7 +2111,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
    }
  }

-  void handle(const kir::BlockSync*) final {
+  void handle(const kir::BlockSync* sync) final {
    // Use a custom synchronization method if enabled
    if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
      indent() << "block_sync::sync();\n";
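With that environment variable set, each block synchronization point in the generated kernel is emitted as the line below; the fallback branch is outside the hunk shown here, so it is not restated:

    block_sync::sync();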