1 change: 0 additions & 1 deletion benchmarks/cpp/nvfuser/batch_norm_backward.cpp
@@ -148,7 +148,6 @@ static void Baseline_BatchNorm_BWD(
   at::Tensor save_mean = at::zeros({input_shape[1]}, fp32_options);
   at::Tensor save_var = at::ones({input_shape[1]}, fp32_options);
 
-
   auto ato_weight = c10::optional<at::Tensor>(weight);
   auto ato_bias = c10::optional<at::Tensor>(bias);
   auto ato_run_mean = c10::optional<at::Tensor>(run_mean);
6 changes: 4 additions & 2 deletions benchmarks/cpp/nvfuser/layer_norm.cpp
@@ -58,7 +58,8 @@ static void NvFuserScheduler_LayerNorm(
     DataType dtype) {
   TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
 
-  std::vector<int64_t> input_shape{benchmark_state.range(0), benchmark_state.range(1)};
+  std::vector<int64_t> input_shape{
+      benchmark_state.range(0), benchmark_state.range(1)};
   const float kEps = 1e-5;
 
   // inputs
@@ -86,7 +87,8 @@ static void Baseline_LayerNorm(
     DataType dtype) {
   TORCH_INTERNAL_ASSERT(dtype == DataType::Float || dtype == DataType::Half);
 
-  std::vector<int64_t> input_shape{benchmark_state.range(0), benchmark_state.range(1)};
+  std::vector<int64_t> input_shape{
+      benchmark_state.range(0), benchmark_state.range(1)};
   const int kReductionAxis = 1;
   std::vector<int64_t> norm_shape;
   for (int idx = kReductionAxis; idx < input_shape.size(); ++idx) {
3 changes: 2 additions & 1 deletion benchmarks/cpp/nvfuser/layer_norm_backward.cpp
@@ -94,7 +94,8 @@ static void NvFuserScheduler_LayerNorm_BWD(
   at::Tensor mean = at::randn({input_shape[0], 1}, options);
   at::Tensor rstd = at::randn({input_shape[0], 1}, options);
 
-  std::vector<c10::IValue> aten_inputs({grad_out, input, weight, bias, mean, rstd});
+  std::vector<c10::IValue> aten_inputs(
+      {grad_out, input, weight, bias, mean, rstd});
 
   runBenchmarkIterations(benchmark_state, fusion_executor_cache, aten_inputs);
 
3 changes: 1 addition & 2 deletions benchmarks/cpp/nvfuser/softmax.cpp
@@ -57,7 +57,7 @@ static void NvFuserScheduler_Softmax(

   at::Tensor aten_input =
       (reduction_axis ? at::randn({iter_size, reduction_size}, options)
-                       : at::randn({reduction_size, iter_size}, options));
+                      : at::randn({reduction_size, iter_size}, options));
 
   std::vector<c10::IValue> aten_inputs({aten_input});

@@ -187,7 +187,6 @@ static void Baseline_Softmax(
     benchmark::State& benchmark_state,
     DataType dtype,
     const int reduction_axis) {
-
   at::manual_seed(0);
   auto options =
       at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
16 changes: 8 additions & 8 deletions benchmarks/cpp/nvfuser/softmax_backward.cpp
@@ -63,15 +63,15 @@ static void NvFuserScheduler_Softmax_BWD(

   at::Tensor input =
       (reduction_axis ? at::randn({iter_size, reduction_size}, options)
-                       : at::randn({reduction_size, iter_size}, options));
+                      : at::randn({reduction_size, iter_size}, options));
 
   at::Tensor grad_output =
       (reduction_axis ? at::randn({iter_size, reduction_size}, options)
-                       : at::randn({reduction_size, iter_size}, options));
+                      : at::randn({reduction_size, iter_size}, options));
 
   at::Tensor output =
       (reduction_axis ? at::randn({iter_size, reduction_size}, options)
-                       : at::randn({reduction_size, iter_size}, options));
+                      : at::randn({reduction_size, iter_size}, options));
 
   std::vector<c10::IValue> aten_inputs({grad_output, output, input});

@@ -88,7 +88,6 @@ static void Baseline_Softmax_BWD(
     benchmark::State& benchmark_state,
     DataType dtype,
     const int reduction_axis) {
-
   at::manual_seed(0);
   auto options =
       at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
@@ -98,20 +97,21 @@ static void Baseline_Softmax_BWD(

   at::Tensor input =
       (reduction_axis ? at::randn({iter_size, reduction_size}, options)
-                       : at::randn({reduction_size, iter_size}, options));
+                      : at::randn({reduction_size, iter_size}, options));
 
   at::Tensor grad_output =
       (reduction_axis ? at::randn({iter_size, reduction_size}, options)
-                       : at::randn({reduction_size, iter_size}, options));
+                      : at::randn({reduction_size, iter_size}, options));
 
   at::Tensor output =
       (reduction_axis ? at::randn({iter_size, reduction_size}, options)
-                       : at::randn({reduction_size, iter_size}, options));
+                      : at::randn({reduction_size, iter_size}, options));
 
   for (auto _ : benchmark_state) {
     clearL2Cache();
     CudaKernelTimer timer;
-    auto grad_input = at::_softmax_backward_data(grad_output, output, reduction_axis, data_type_to_aten(dtype));
+    auto grad_input = at::_softmax_backward_data(
+        grad_output, output, reduction_axis, data_type_to_aten(dtype));
     benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
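The Baseline_Softmax_BWD loop above is the manual-timing pattern these baselines share: each iteration clears the L2 cache, times the kernel with CudaKernelTimer, and reports the result in seconds via SetIterationTime. A minimal self-contained sketch of the same Google Benchmark pattern, with std::chrono standing in for the CUDA timer (which is defined outside this diff):

#include <benchmark/benchmark.h>
#include <chrono>

// Sketch only: time each iteration manually and report it in seconds.
// The nvfuser baselines use CudaKernelTimer and clearL2Cache() here.
static void BM_ManualTime(benchmark::State& state) {
  for (auto _ : state) {
    auto start = std::chrono::high_resolution_clock::now();
    // ... launch the kernel under test ...
    auto end = std::chrono::high_resolution_clock::now();
    state.SetIterationTime(std::chrono::duration<double>(end - start).count());
  }
}
// UseManualTime() makes Google Benchmark use the reported per-iteration
// times instead of timing the loop body itself.
BENCHMARK(BM_ManualTime)->UseManualTime();
BENCHMARK_MAIN();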
1 change: 0 additions & 1 deletion benchmarks/cpp/nvfuser/softmax_dropout.cpp
@@ -15,7 +15,6 @@

 using namespace torch::jit::fuser::cuda;
 
-
 //------------------------------------------------------------------------------
 
 static void setupSoftmaxDropout(
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -965,6 +965,7 @@ if(USE_CUDA OR USE_ROCM)
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensor.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/welford.cu
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/warp.cu
+    ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/runtime/tensor_core.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh
     ${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/cuda/detail/UnpackRaw.cuh
   )
1 change: 1 addition & 0 deletions test/cpp/jit/CMakeLists.txt
@@ -90,6 +90,7 @@ set(JIT_TEST_SRCS
 if(USE_CUDA)
   list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu.cpp)
   list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_shift.cpp)
+  list(APPEND JIT_TEST_SRCS ${JIT_TEST_ROOT}/test_gpu_tensorcore.cpp)
 endif()
 
 add_executable(test_jit
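Since test_gpu_tensorcore.cpp is appended to JIT_TEST_SRCS and built into the test_jit executable above, the new test should be runnable in isolation with a standard gtest filter once built, e.g. ./bin/test_jit --gtest_filter='NVFuserTest.FusionMMASwizzlePrimitiveFloatAcc_CUDA' (the binary path depends on your build directory).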
253 changes: 253 additions & 0 deletions test/cpp/jit/test_gpu_tensorcore.cpp
@@ -0,0 +1,253 @@
#if defined(USE_CUDA)
#include <gtest/gtest.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/codegen.h>
#include <torch/csrc/jit/codegen/cuda/disjoint_set.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/fusion_segmenter.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_graphviz.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/iter_visitor.h>
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/kernel_expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_builder.h>
#include <torch/csrc/jit/codegen/cuda/kernel_ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
#include <torch/csrc/jit/codegen/cuda/root_domain_map.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/utils.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>

// fuser and IR parser
#include "test_gpu_validator.h"

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <c10/cuda/CUDAStream.h>

#include <algorithm>
#include <iostream>

// Tests go in torch::jit
namespace torch {
namespace jit {

using namespace torch::jit::fuser::cuda;
using namespace at::indexing;

namespace {

// Make a tensor of dimensionality ndims that is known to be fully contiguous
// but has unknown sizes
TensorView* makeContigTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder()
.ndims(ndims)
.dtype(dtype)
.contiguity(std::vector<bool>(ndims, true))
.build();
}

// Make a tensor of dimensionality ndims that is known to be non-contiguous
// but has unknown sizes
TensorView* makeSymbolicTensor(size_t ndims, DataType dtype = DataType::Float) {
return TensorViewBuilder().ndims(ndims).dtype(dtype).build();
}

// Make a non-contiguous tensor of compile-time known sizes
TensorView* makeConcreteTensor(
std::vector<int64_t> shape,
DataType dtype = DataType::Float) {
return TensorViewBuilder().shape(shape).dtype(dtype).build();
}

void checkIntValue(
ExpressionEvaluator& evaluator,
Val* val,
Int::ScalarType expected_value) {
TORCH_CHECK(val->isAnInt());
const auto actual_value = evaluator.evaluate(val);
TORCH_CHECK(actual_value.has_value());
TORCH_CHECK(actual_value.value() == expected_value);
}

void checkIntValue(
kir::ExpressionEvaluator& evaluator,
const kir::Val* val,
kir::Int::ScalarType expected_value) {
const auto actual_value = evaluator.evaluate(val);
TORCH_CHECK(actual_value.has_value());
TORCH_CHECK(actual_value.value() == expected_value);
}

bool cudaArchGuardShouldSkip(int required_major, int required_minor) {
int capability_major = at::cuda::getCurrentDeviceProperties()->major;
int capability_minor = at::cuda::getCurrentDeviceProperties()->minor;

if (capability_major < required_major ||
(capability_major == required_major &&
capability_minor < required_minor)) {
return true;
}
return false;
}

#define NVFUSER_TEST_CUDA_ARCH_GUARD(REQUIRED_MAJOR, REQUIRED_MINOR) \
if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR)) { \
GTEST_SKIP() << "Requires GPU capability at least " << REQUIRED_MAJOR << "." \
<< REQUIRED_MINOR << " to run.\n"; \
}

} // namespace

TEST(NVFuserTest, FusionMMASwizzlePrimitiveFloatAcc_CUDA) {
NVFUSER_TEST_CUDA_ARCH_GUARD(7, 0);

FusionExecutor fe;

std::string kernel = R"(
__global__ void mma32x32x8_mini_gemm_float(
    Tensor<__half, 2> Ag,
    Tensor<__half, 2> Bg,
    Tensor<float, 2> Cg0,
    Tensor<float, 2> Cg1) {
// Accumulator
float C[32];
// initialize C:
for (int i = 0; i < 32; i++) {
C[i] = 0;
}
// Allocate Smem
__shared__ uint4 As_mem[32*8 / 8];
__shared__ uint4 Bs_mem[8*32 / 8];
auto As = reinterpret_cast<__half*>(As_mem);
auto Bs = reinterpret_cast<__half*>(Bs_mem);
// Compute lane_id
int lane_id = threadIdx.x % 32;

// Mini prolog: (single warp)
// global load:
uint4 Ag_buffer;
Ag_buffer = *reinterpret_cast<uint4*>(&Ag[threadIdx.x*8]);
uint4 Bg_buffer;
Bg_buffer = *reinterpret_cast<uint4*>(&Bg[threadIdx.x*8]);
// Write into smem: (should consider cleaner utility for vectorization)
*reinterpret_cast<uint4*>(
&As[
// Utility to compose a sequence of swizzles
mem_swizzle::swizzle_sequence<
// This is the instruction mem layout
// for mma32X32X8
warp_mma::mmaM32N32K8::SmemWriteASwizzle>(
// The coordinate is what the Ag_buffer data
// corresponds to in un-swizzled layout
// mostly compatible with nvfuser indexing
{(int)threadIdx.x, 0}, lane_id).linearize({8,1})])
= Ag_buffer;
*reinterpret_cast<uint4*>(
&Bs[
// Utility to compose a sequence of swizzles
mem_swizzle::swizzle_sequence<
// This is the instruction mem layout
// for mma32X32X8
warp_mma::mmaM32N32K8::SmemWriteBSwizzle>(
// The coordinate is what the Bg_buffer data
// corresponds to in un-swizzled layout
// mostly compatible with nvfuser indexing
{((int)threadIdx.x*8 / 32), ((int)threadIdx.x*8 % 32)}, lane_id)
.linearize({32,1})])
= Bg_buffer;

__syncthreads();

// Read from Smem:
__half A[16];
__half B[16];

// Two 128b reads per operand. Once this is formulated as a layout
// conversion, the generated code will emit two explicit loads here
// instead of a for loop.
for (int warptile_i = 0; warptile_i < 2; warptile_i++) {
*reinterpret_cast<uint4*>(&A[8*warptile_i]) =
*reinterpret_cast<uint4*>(
&As[
mem_swizzle::swizzle_sequence<
warp_mma::mmaM32N32K8::SmemReadASwizzle>(
// Need to skip iterdomains within
// instruction tile in this indexing
// since they will be determined by lane_id
{16*warptile_i+threadIdx.x, 0}, lane_id
).linearize({8,1})
]);

// This will be a separate loop in the generated code
// but both will be unrolled eventually.
*reinterpret_cast<uint4*>(&B[8*warptile_i]) =
*reinterpret_cast<uint4*>(
&Bs[
mem_swizzle::swizzle_sequence<
warp_mma::mmaM32N32K8::SmemReadBSwizzle>(
// Need to skip iterdomains within
// instruction tile in this indexing
// since they will be determined by lane_id
{4*warptile_i, 0}, lane_id
).linearize({32,1})
]);
}

// finally calling the mma:
warp_mma::mmaM32N32K8::run(&C[0], &A[0], &B[0]);
// Write back result:

// write back option 1: use un-swizzle helper
for (int write_c_i = 0; write_c_i < 32; write_c_i++) {
Cg1[
warp_mma::mmaM32N32K8::UnswizzleFloatC::run(
{0, write_c_i},
lane_id
).linearize({(int)Cg1.stride[0],(int)Cg1.stride[1]})
] = C[write_c_i];
}
// write back option 2: using the writeC helper
warp_mma::mmaM32N32K8::writeC(&Cg0[0], &C[0], lane_id, Cg0.stride[0]);
}
)";
fe.compileRtc(kernel, "CudaCodeGen::mma32x32x8_mini_gemm_float");
LaunchParams lp(
1, // gdimx
1, // gdimy
1, // gdimz
32, // bdimx
1, // bdimy
1 // bdimz
);
lp.setSmem(0);
const auto options =
at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);

// Use smaller values so fp16 rounding error stays small relative to
// the allclose tolerance for these small matrices.
auto inA = at::randn({32, 8}, options) / 2;
auto inB = at::randn({8, 32}, options) / 2;

auto outC0 = at::empty({32, 32}, options.dtype(at::kFloat));
auto outC1 = at::empty({32, 32}, options.dtype(at::kFloat));

fe.runRtc(lp, {inA, inB, outC0, outC1});

auto refC = inA.to(at::kDouble).matmul(inB.to(at::kDouble)).to(at::kFloat);

TORCH_CHECK(refC.allclose(outC0));
TORCH_CHECK(refC.allclose(outC1));
}

#undef NVFUSER_TEST_CUDA_ARCH_GUARD

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
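The kernel above relies on two helpers from the new runtime/tensor_core.cu, which is registered in the build files below but whose contents are not part of this diff: mem_swizzle::swizzle_sequence remaps a 2-D coordinate into the instruction's shared-memory layout, and .linearize({stride_row, stride_col}) flattens the resulting coordinate into a linear offset. A standalone C++ sketch of that coordinate-then-linearize idea, using a generic XOR bank-conflict swizzle as a stand-in for the unseen instruction-specific layouts:

#include <cstdio>

// Conceptual sketch only; the real swizzles live in runtime/tensor_core.cu
// and are specific to the mma32x32x8 operand layouts.
struct Coord2D {
  int row;
  int col;
  // Flatten the coordinate with the given strides, mirroring the
  // .linearize({stride_row, stride_col}) calls in the kernel above.
  int linearize(int stride_row, int stride_col) const {
    return row * stride_row + col * stride_col;
  }
};

// A typical XOR swizzle: shift the column (in 8-element vector units) by
// the low bits of the row so consecutive rows land in different banks.
Coord2D xorSwizzle(Coord2D c) {
  return {c.row, c.col ^ ((c.row % 4) * 8)};
}

int main() {
  for (int row = 0; row < 4; ++row) {
    Coord2D c = xorSwizzle({row, 0});
    std::printf("row %d writes at offset %d\n", row, c.linearize(32, 1));
  }
  return 0;
}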
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -42,6 +42,7 @@ libtorch_nvfuser_runtime_sources = [
"torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
"torch/csrc/jit/codegen/cuda/runtime/welford.cu",
"torch/csrc/jit/codegen/cuda/runtime/warp.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensor_core.cu",
"aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh",
"aten/src/ATen/cuda/detail/UnpackRaw.cuh",
]
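Note that tensor_core.cu is registered in both build systems: here in libtorch_nvfuser_runtime_sources and in caffe2/CMakeLists.txt above. These runtime .cu files back nvfuser's runtime-compilation path (the test above drives the new helpers through fe.compileRtc), which is presumably why the file has to appear in both source lists.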