Merged
73 commits
a449f49  initial volta support (shmsong, Feb 21, 2022)
edd43d9  mma parallel type && cleanup (shmsong, Feb 22, 2022)
ddac459  cleanup (shmsong, Feb 22, 2022)
2f08d09  alignment (shmsong, Feb 22, 2022)
ca77ff4  comment (shmsong, Feb 22, 2022)
9caeb18  change request (shmsong, Mar 16, 2022)
de1d3ec  fix same parallel type (shmsong, Mar 16, 2022)
b393cbe  Merge remote-tracking branch 'origin/devel' into volta_mma_op (shmsong, Mar 16, 2022)
74f8c12  move validation pass (shmsong, Mar 16, 2022)
db34181  comment and cleanup (shmsong, Mar 17, 2022)
4dec827  lint (shmsong, Mar 17, 2022)
5ecd102  comment and cleanup (shmsong, Mar 17, 2022)
adb19b2  Merge remote-tracking branch 'origin/devel' into volta_mma_op (shmsong, Mar 17, 2022)
c97c605  comment and format (shmsong, Mar 17, 2022)
264bc77  initial turing and ampere mma support (shmsong, Mar 21, 2022)
5bc918c  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, Mar 21, 2022)
2f0ae5e  fix rebase (shmsong, Mar 22, 2022)
b96462a  more rebase fix (shmsong, Mar 23, 2022)
5347fda  test comment (shmsong, Mar 28, 2022)
3f580d1  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, Mar 28, 2022)
4834995  submodule (shmsong, Mar 28, 2022)
9e18e04  cleanup and comments (shmsong, Mar 28, 2022)
1fb32ed  minor fix (shmsong, Mar 28, 2022)
43d0e72  test cleanup (shmsong, Mar 28, 2022)
c1f4374  format (shmsong, Mar 29, 2022)
7c91a0a  newline (shmsong, Mar 29, 2022)
ae2bd1f  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, Apr 21, 2022)
f4c6f12  move all implementation to Ampere space (shmsong, Apr 25, 2022)
d719bd8  comment and naming (shmsong, Apr 25, 2022)
a2c28c9  leave cp async in a separate PR (shmsong, Apr 25, 2022)
d28574c  cleanup (shmsong, Apr 25, 2022)
2fa3a92  move memory op to separate IR node (shmsong, Apr 26, 2022)
e0d1769  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, Apr 26, 2022)
d2274cd  code refactor bug fix (shmsong, Apr 26, 2022)
090752d  update test for shared mem initialization (shmsong, Apr 26, 2022)
640d3aa  refactor mma user interface (shmsong, Apr 26, 2022)
da40702  add turing mma support and test (shmsong, May 2, 2022)
36786b9  add shared mem in zero leaf detection (shmsong, May 2, 2022)
2edd1cc  move shared mem predicate (shmsong, May 2, 2022)
bcb18af  comment (shmsong, May 2, 2022)
fd5e178  adjust unused ldmatrix address for Turing (shmsong, May 3, 2022)
86f2d61  cleanup (shmsong, May 3, 2022)
484efd8  update comment (shmsong, May 3, 2022)
fbaeaf7  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, May 4, 2022)
5355e07  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, May 4, 2022)
9e298b0  cleanup;undo lower_unroll change; (shmsong, May 4, 2022)
9615196  rebase fix (shmsong, May 4, 2022)
a787b59  fix shared mem init (shmsong, May 4, 2022)
7cf98ab  clean up (shmsong, May 4, 2022)
e2e0707  comment (shmsong, May 4, 2022)
9c224dc  zero leaf update (shmsong, May 4, 2022)
584567c  comment (shmsong, May 4, 2022)
09f5ce0  use toString (shmsong, May 5, 2022)
998adb5  comment (shmsong, May 6, 2022)
13266ce  expand tensor filling op detect to include load store (shmsong, May 6, 2022)
5dc3d4b  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, May 6, 2022)
b460fb1  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, May 6, 2022)
fc35480  arch guard update (shmsong, May 6, 2022)
76563a5  fix rebase (shmsong, May 6, 2022)
119f0eb  WAR for buffer re-use (shmsong, May 7, 2022)
abcd1e8  clean up (shmsong, May 9, 2022)
ebc215a  update comment (shmsong, May 9, 2022)
5e8128f  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, May 10, 2022)
80e840c  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, May 17, 2022)
5458e4b  Merge remote-tracking branch 'origin/devel' into ampere_mma_op (shmsong, May 19, 2022)
b05ae79  Merge remote-tracking branch 'origin/ampere_mma_op' into turing_mma_op (shmsong, May 19, 2022)
9684acb  Merge remote-tracking branch 'origin/devel' into turing_mma_op (shmsong, May 23, 2022)
98f45dd  Merge remote-tracking branch 'origin/devel' into turing_mma_op (shmsong, May 24, 2022)
f0eb7b6  use relaxed arch guard (shmsong, May 24, 2022)
0e204aa  fix rebase (shmsong, May 24, 2022)
e92430e  rebase fix (shmsong, May 24, 2022)
37ce5aa  constexpr (shmsong, May 24, 2022)
3d8cb3a  rename swizzle enum (shmsong, May 24, 2022)
8 changes: 8 additions & 0 deletions torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -1025,6 +1025,14 @@ void validateMma(Fusion* fusion) {
case MmaOptions::MacroType::Volta_16_16_4:
validateMinimumArch(7, 0);
break;
case MmaOptions::MacroType::Turing_16_8_16:
validateMinimumArch(7, 5);

// Check that operands come from ldmatrix; this can be
// relaxed once swizzles can be labeled on iterdomains.
validateTuringMmaInput(mma->inA()->as<TensorView>());
validateTuringMmaInput(mma->inB()->as<TensorView>());
break;
case MmaOptions::MacroType::Ampere_16_8_16:
validateMinimumArch(8, 0);

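Note on the arch checks above: validateMinimumArch(7, 5) corresponds to compute capability sm_75 (Turing) and validateMinimumArch(8, 0) to sm_80 (Ampere). As a rough illustration only (not the PR's actual implementation; the body below is an assumption based on the function name and its (major, minor) arguments), such a check could be written as:

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>

// Hypothetical stand-in for validateMinimumArch: reject the fusion if the
// current device's compute capability is below the required (major, minor).
void validateMinimumArchSketch(int required_major, int required_minor) {
  const auto* prop = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK(
      prop->major > required_major ||
          (prop->major == required_major && prop->minor >= required_minor),
      "MMA macro requires compute capability ",
      required_major, ".", required_minor,
      " or newer, but the current device is ",
      prop->major, ".", prop->minor);
}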
8 changes: 7 additions & 1 deletion torch/csrc/jit/codegen/cuda/mma_type.cpp
@@ -18,6 +18,7 @@ MmaBuilder::MmaBuilder(
case MmaOptions::MacroType::Volta_16_16_4:
option_.accumulator_stride = outer_stride * 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
option_.accumulator_stride = outer_stride * 2;
break;
@@ -58,6 +59,7 @@ namespace {
LoadStoreOpType getLdMatrixType(MmaOptions options) {
bool transpose = false;
switch (options.macro) {
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
// Turing mma assumes TN as default
transpose = (options.operand == MmaOptions::Operand::A &&
@@ -84,7 +86,7 @@ bool isVolta(MmaOptions::MacroType macro) {
}

bool isTuring(MmaOptions::MacroType macro) {
return false;
return macro == MmaOptions::MacroType::Turing_16_8_16;
}

bool isAmpere(MmaOptions::MacroType macro) {
@@ -96,6 +98,7 @@ int getOutputRegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 8;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
break;
@@ -111,6 +114,7 @@ int getInputARegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 8;
break;
@@ -126,6 +130,7 @@ int getInputBRegisterSize(MmaOptions::MacroType macro) {
case MmaOptions::MacroType::Volta_16_16_4:
return 4;
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
return 4;
default:
@@ -176,6 +181,7 @@ std::string toString(MmaOptions::MacroType mt) {
case MmaOptions::MacroType::Volta_16_16_4:
ss << "M16N16K4";
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
ss << "M16N8K16";
break;
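The register counts returned above for the Turing/Ampere 16x8x16 macros follow from spreading each mma fragment evenly over the 32 lanes of a warp. A back-of-the-envelope check (illustration only, not code from the PR):

// Per-thread fragment sizes for the m16n8k16 fp16 macro.
constexpr int kWarpSize = 32;
constexpr int kOutRegs = (16 * 8) / kWarpSize;  // 4 floats per thread (C/D)
constexpr int kARegs = (16 * 16) / kWarpSize;   // 8 halves per thread (A)
constexpr int kBRegs = (8 * 16) / kWarpSize;    // 4 halves per thread (B)
static_assert(kOutRegs == 4 && kARegs == 8 && kBRegs == 4, "fragment sizes");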
5 changes: 3 additions & 2 deletions torch/csrc/jit/codegen/cuda/mma_type.h
@@ -58,6 +58,7 @@ struct MmaOptions {
NoMMA = 0,
Volta_16_16_4,
Ampere_16_8_16,
Turing_16_8_16,
Ampere_16_8_8 // place holder for tf32
};

@@ -73,7 +74,7 @@
enum class MmaInputLayout { NT = 0, TT, TN };

//! Utility to annotate which input of mma this option struct describes
enum class Operand { NotOperand = 0, A, B };
enum class Operand { Accumulator = 0, A, B };

//! Utility to annotate which mma macro this config uses.
MacroType macro = MacroType::NoMMA;
@@ -117,7 +118,7 @@ class TORCH_CUDA_CU_API MmaBuilder {
//! Specifies which element in the mma op this builder is generating
//! parameters for, i.e. A or B. This is useful when generating
//! data swizzles for different elements of mma.
//! - Operand::NotOperand means the parameters describe accumulator in mma
//! - Operand::Accumulator means the parameters describe accumulator in mma
//! op.
//! - This option is ignored when configuring the mma operator itself.
MmaBuilder& operand(MmaOptions::Operand a_or_b);
36 changes: 36 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/memory.cu
@@ -21,6 +21,40 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
return smem_ptr_uint;
}

// LdMatrix has .x1, .x2, and .x4 options; currently we actively use .x2 and
// .x4. In the .x2 option, the address registers of the upper half warp (lanes
// 16-31) are unused, but on the Turing [sm75, sm80) architecture these unused
// addresses still need to be valid, in the sense that:
//  1. The data they point to has to be within the allocated shared mem buffer.
//  2. The address needs to be aligned to 16 bytes.
// See also:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
// This function addresses 2. above by masking out the sub-16B component of the
// address in the upper half warp; 1. is guaranteed by the ldmatrix swizzle
// util.
// This does **not** affect any functionality; it only modifies unused pointers
// to satisfy the alignment requirement on Turing hardware.
// The alignment requirement is lifted on sm80+, so this function is a no-op on
// Ampere or above.
DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) {
#if (__CUDA_ARCH__ < 800)
const unsigned thread_id = threadIdx.x;
// In the .x2 option of ldmatrix, the upper half warp has an 8-byte offset from
// the aligned address. There is currently no support for .x1, so we assume we
// always adjust by a half warp.
//
// [Review thread on this line]
// Owner: Any specific reason not to add x1?
// shmsong (author): I could add .x1 in a follow up. I currently didn't yet add
// the smaller mma tiles that .x1 would pair with; they are not immediately
// useful yet in large CTA tile kernels.
constexpr unsigned half_warp = 16;
// Need to adjust to 16-byte alignment; mask out the unaligned component.
constexpr unsigned mask_out = 16 - 1;
// Adjust only in the upper half warp.
// Use bit math instead of a comparison (strength reduction).
if (thread_id & half_warp) {
// Mask out the bits where mask_out has 1.
addr_in_byte &= (~mask_out);
}
#endif //(__CUDA_ARCH__ < 800)
}
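// Worked example (illustration only, not part of the diff): in the .x2 form a
// lane in the upper half warp may carry an address such as 0x0238; mask_out is
// 0xF, so addr_in_byte &= ~0xF yields 0x0230, which is 16-byte aligned and
// still points into the shared memory buffer. Lanes 0-15 are never modified.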

} // namespace util

// Load Matrix (per-warp instruction): moves data from SMEM to local memory.
@@ -36,6 +70,7 @@ DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
uint2& val = reinterpret_cast<uint2&>(out);
unsigned addr = util::toSmem(ptr);
util::adjustPartialLdMatrixAddrInTuring(addr);
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "r"(addr));
@@ -47,6 +82,7 @@ DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) {
uint2& val = reinterpret_cast<uint2&>(out);
unsigned addr = util::toSmem(ptr);
util::adjustPartialLdMatrixAddrInTuring(addr);
asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];"
: "=r"(val.x), "=r"(val.y)
: "r"(addr));
68 changes: 68 additions & 0 deletions torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu
@@ -212,6 +212,74 @@ DEVICE_INLINE void initM16N16K4NT(Array<float, 8, 8>* accumulator) {

} // namespace Volta

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))

namespace Turing {

namespace util {
// MMA instruction wrappers (sm_75+):
DEVICE_INLINE void m16n8k16TN(
Array<float, 4, 4>* C,
Array<__half, 8, 8>* A,
Array<__half, 4, 4>* B) {
unsigned const* _A = reinterpret_cast<unsigned const*>(A);
unsigned const* _B = reinterpret_cast<unsigned const*>(B);
unsigned* _C = reinterpret_cast<unsigned*>(C);
const unsigned* _D = reinterpret_cast<const unsigned*>(C);

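// The m16n8k16 fp16 shape is not available as a single mma instruction on
// sm_75, so it is emulated here with two m16n8k8 instructions accumulating
// into the same C fragment.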
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
: "r"(_A[0]),
"r"(_A[1]),
"r"(_B[0]),
"r"(_D[0]),
"r"(_D[1]),
"r"(_D[2]),
"r"(_D[3]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(_C[0]), "=r"(_C[1]), "=r"(_C[2]), "=r"(_C[3])
: "r"(_A[2]),
"r"(_A[3]),
"r"(_B[1]),
"r"(_D[0]),
"r"(_D[1]),
"r"(_D[2]),
"r"(_D[3]));
}

} // namespace util

template <int acc_stride>
DEVICE_INLINE void initM16N8K16TN(Array<float, 4, 4>* accumulator) {
float* _C = reinterpret_cast<float*>(accumulator);
_C[0] = 0;
_C[1] = 0;
_C[acc_stride] = 0;
_C[acc_stride + 1] = 0;
}

template <int acc_stride = 2>
DEVICE_INLINE void M16N8K16TN(
Array<float, 4, 4>* C,
Array<__half, 8, 8>* A,
Array<__half, 4, 4>* B) {
// TODO: in a follow up,
// lift this fused swizzle onto iterdomain
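// The accumulator fragment is stored with a stride of acc_stride floats
// between its two halves, so gather it into a contiguous temporary, run the
// mma, then scatter the results back.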
float* _C = reinterpret_cast<float*>(C);
float C_data[4] = {_C[0], _C[1], _C[acc_stride], _C[acc_stride + 1]};

util::m16n8k16TN(reinterpret_cast<Array<float, 4, 4>*>(&C_data[0]), A, B);

_C[0] = C_data[0];
_C[1] = C_data[1];
_C[acc_stride] = C_data[2];
_C[acc_stride + 1] = C_data[3];
}

} // namespace Turing

#endif // Arch 75
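// Hedged usage sketch (illustration only; the fragment variable names below
// are assumptions, not part of the diff): a warp computing one m16n8k16 TN
// tile on sm_75 would roughly do
//
//   Array<float, 4, 4> acc;
//   Turing::initM16N8K16TN<2>(&acc);
//   // A (8 halves/thread) and B (4 halves/thread) fragments are loaded from
//   // shared memory via the ldMatrix utilities in memory.cu, then:
//   Turing::M16N8K16TN<2>(&acc, &a_frag, &b_frag);
//
// The default acc_stride of 2 appears to mirror accumulator_stride =
// outer_stride * 2 set in MmaBuilder for the 16_8_16 macros.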

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

namespace Ampere {
5 changes: 4 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp
@@ -219,6 +219,7 @@ void WarpMmaSwizzler::scheduleMmaWarpOutput(
setWarpMapped(tv, 5);
}
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
scheduleTuringM16N8K16MmaWarpOutput(tv, options);
if (tv->definition()->isA<MmaOp>()) {
@@ -240,6 +241,7 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOptions options) {
case MmaOptions::MacroType::Volta_16_16_4:
scheduleVoltaOperandRead(tv, options);
break;
case MmaOptions::MacroType::Turing_16_8_16:
case MmaOptions::MacroType::Ampere_16_8_16:
scheduleTuringOperandRead(tv, options);
break;
@@ -415,7 +417,8 @@ void scheduleLdMatrix(TensorView* tv, MmaOptions options) {
: isOperandTransposed(options);
// Check mma option is supported
TORCH_CHECK(
options.macro == MmaOptions::MacroType::Ampere_16_8_16,
options.macro == MmaOptions::MacroType::Ampere_16_8_16 ||
options.macro == MmaOptions::MacroType::Turing_16_8_16,
"scheduleLdMatrix: unknown macro for ldmatrix");

if (options.operand == MmaOptions::Operand::A) {
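As a design note, the macro check in scheduleLdMatrix could also be phrased in terms of the predicates from mma_type.cpp, though isAmpere would then also admit the Ampere_16_8_8 placeholder. A hypothetical helper (sketch only; supportsLdMatrix does not exist in the PR):

// Hypothetical helper, not part of the PR: the ldmatrix scheduling path is
// shared by the Turing and Ampere 16x8x16 macros.
bool supportsLdMatrix(MmaOptions::MacroType macro) {
  return isTuring(macro) || isAmpere(macro);
}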
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/tensor_view.cpp
@@ -1055,7 +1055,7 @@ bool TensorView::isEmptyTensor() const {

void TensorView::applyMmaSwizzle(MmaOptions options) {
switch (options.operand) {
case MmaOptions::Operand::NotOperand:
case MmaOptions::Operand::Accumulator:
mma_util::WarpMmaSwizzler::scheduleMmaWarpOutput(this, options);
break;
case MmaOptions::Operand::A:
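Putting the pieces together, a hedged sketch of how a matmul schedule might drive these interfaces (based only on what is visible in this diff; layout(), build(), configureMma(), and the tensor names are assumptions):

// Sketch only; not a verbatim excerpt from the PR or its tests.
MmaBuilder mma_builder(MmaOptions::MacroType::Turing_16_8_16, gemm_tile);
mma_builder.layout(MmaOptions::MmaInputLayout::TN);  // assumed setter
mma_builder.configureMma(mma_output_tv);             // assumed configuration hook

// Swizzle the register-resident operand copies and the accumulator.
a_frag_tv->applyMmaSwizzle(
    mma_builder.operand(MmaOptions::Operand::A).build());
b_frag_tv->applyMmaSwizzle(
    mma_builder.operand(MmaOptions::Operand::B).build());
acc_tv->applyMmaSwizzle(
    mma_builder.operand(MmaOptions::Operand::Accumulator).build());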