Skip to content

Commit 6df7b77

Browse files
authored
Mma operator and volta mma integration (#1439)
* initial volta support * mma parallel type && cleanup * cleanup * alignment * comment * change request * fix same parallel type * move validation pass * comment and cleanup * lint * comment and cleanup * comment and format
1 parent 5ba9343 commit 6df7b77

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2820
-28
lines changed

benchmarks/cpp/nvfuser/layer_norm_backward.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,8 +64,7 @@ static void setupLayerNorm_BWD(Fusion* fusion, DataType dtype) {
6464
if (dtype != DataType::Float) {
6565
layer_norm_results.grad_input =
6666
castOp(dtype, layer_norm_results.grad_input);
67-
layer_norm_results.grad_bias =
68-
castOp(dtype, layer_norm_results.grad_bias);
67+
layer_norm_results.grad_bias = castOp(dtype, layer_norm_results.grad_bias);
6968
layer_norm_results.grad_weight =
7069
castOp(dtype, layer_norm_results.grad_weight);
7170
}

benchmarks/cpp/nvfuser/rms_norm.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,9 @@ using namespace torch::jit::fuser::cuda;
1818
//------------------------------------------------------------------------------
1919

2020
static void setupRMSNorm(Fusion* fusion, DataType dtype) {
21-
TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16);
21+
TORCH_INTERNAL_ASSERT(
22+
dtype == DataType::Float || dtype == DataType::Half ||
23+
dtype == DataType::BFloat16);
2224

2325
FusionGuard fg(fusion);
2426

@@ -54,10 +56,11 @@ static void NvFuserScheduler_RMSNorm(
5456
benchmark::State& benchmark_state,
5557
FusionExecutorCache* fusion_executor_cache,
5658
DataType dtype) {
57-
TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16);
59+
TORCH_INTERNAL_ASSERT(
60+
dtype == DataType::Float || dtype == DataType::Half ||
61+
dtype == DataType::BFloat16);
5862

59-
std::vector<int64_t> input_shape{
60-
8, benchmark_state.range(0), 1024};
63+
std::vector<int64_t> input_shape{8, benchmark_state.range(0), 1024};
6164
const float kEps = 1e-6;
6265

6366
// inputs
@@ -73,8 +76,7 @@ static void NvFuserScheduler_RMSNorm(
7376

7477
benchmark_state.SetBytesProcessed(
7578
int64_t(benchmark_state.iterations()) *
76-
(2 * input.numel() + weight.numel()) *
77-
int64_t(dataTypeSize(dtype)));
79+
(2 * input.numel() + weight.numel()) * int64_t(dataTypeSize(dtype)));
7880
}
7981

8082
//------------------------------------------------------------------------------

benchmarks/cpp/nvfuser/rms_norm_backward.cpp

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,9 @@ using namespace torch::jit::fuser::cuda;
2020
static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) {
2121
FusionGuard fg(fusion);
2222

23-
TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16);
23+
TORCH_INTERNAL_ASSERT(
24+
dtype == DataType::Float || dtype == DataType::Half ||
25+
dtype == DataType::BFloat16);
2426

2527
const int kReductionAxis = 2;
2628
Double* eps_ptr = IrBuilder::create<Double>(1e-6);
@@ -47,14 +49,12 @@ static void setupRMSNorm_BWD(Fusion* fusion, DataType dtype) {
4749
rstd = castOp(DataType::Float, rstd);
4850
}
4951

50-
auto rms_norm_results = rms_norm_backward(
51-
grad_out, input, {1}, rstd, weight, {true, true, true});
52+
auto rms_norm_results =
53+
rms_norm_backward(grad_out, input, {1}, rstd, weight, {true, true, true});
5254

53-
if (dtype != DataType::Float ) {
54-
rms_norm_results.grad_input =
55-
castOp(dtype, rms_norm_results.grad_input);
56-
rms_norm_results.grad_weight =
57-
castOp(dtype, rms_norm_results.grad_weight);
55+
if (dtype != DataType::Float) {
56+
rms_norm_results.grad_input = castOp(dtype, rms_norm_results.grad_input);
57+
rms_norm_results.grad_weight = castOp(dtype, rms_norm_results.grad_weight);
5858
}
5959

6060
fusion->addOutput(rms_norm_results.grad_input);
@@ -65,10 +65,11 @@ static void NvFuserScheduler_RMSNorm_BWD(
6565
benchmark::State& benchmark_state,
6666
FusionExecutorCache* fusion_executor_cache,
6767
DataType dtype) {
68-
TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half || dtype == DataType::BFloat16);
68+
TORCH_INTERNAL_ASSERT(
69+
dtype == DataType::Float || dtype == DataType::Half ||
70+
dtype == DataType::BFloat16);
6971

70-
std::vector<int64_t> input_shape{
71-
8, benchmark_state.range(0), 1024};
72+
std::vector<int64_t> input_shape{8, benchmark_state.range(0), 1024};
7273

7374
// inputs
7475
at::manual_seed(0);
@@ -79,15 +80,13 @@ static void NvFuserScheduler_RMSNorm_BWD(
7980
at::Tensor weight = at::randn({input_shape[2]}, options);
8081
at::Tensor rstd = at::randn({input_shape[0], input_shape[1], 1}, options);
8182

82-
std::vector<c10::IValue> aten_inputs(
83-
{grad_out, input, weight, rstd});
83+
std::vector<c10::IValue> aten_inputs({grad_out, input, weight, rstd});
8484

8585
runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
8686

8787
benchmark_state.SetBytesProcessed(
8888
int64_t(benchmark_state.iterations()) *
89-
(3 * input.numel() + weight.numel() +
90-
rstd.numel()) *
89+
(3 * input.numel() + weight.numel() + rstd.numel()) *
9190
int64_t(dataTypeSize(dtype)));
9291
}
9392

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -943,6 +943,7 @@ if(USE_CUDA OR USE_ROCM)
943943
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/type_traits.cu
944944
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/welford.cu
945945
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/warp.cu
946+
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensorcore.cu
946947
${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
947948
${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/UnpackRaw.cuh
948949
)

test/cpp/jit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,7 @@ if(USE_CUDA)
9797
list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu.cpp)
9898
list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_fused_reduction.cpp)
9999
list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_shift.cpp)
100+
list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_tensorcore.cpp)
100101
endif()
101102

102103
add_executable(test_jit

test/cpp/jit/test_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21082,7 +21082,7 @@ TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) {
2108221082
}
2108321083
#endif
2108421084

21085-
TEST_F(NVFuserTest, FusionIssue1430) {
21085+
TEST_F(NVFuserTest, FusionIssue1430_CUDA) {
2108621086
// Derived from an expression sorting issue when using loop map, now expr
2108721087
// sorting uses parallel map.
2108821088
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();

0 commit comments

Comments (0)