135 changes: 68 additions & 67 deletions test/cpp/jit/test_gpu.cpp
@@ -1,4 +1,5 @@
#if defined(USE_CUDA)

#include <test/cpp/jit/test_base.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
@@ -26,17 +27,16 @@ namespace jit {

using namespace torch::jit::fuser;

-static TensorView* makeDummyTensor(
-    int nDims,
-    DataType dtype = DataType::Float) {
+namespace {
+
+TensorView* makeDummyTensor(int nDims, DataType dtype = DataType::Float) {
  std::vector<IterDomain*> dom;
  for (int i = 0; i < nDims; i++)
    dom.push_back(new IterDomain(new Int(0), new Int()));

  return new TensorView(new TensorDomain(dom), dtype);
}

-static void checkIntValue(
+void checkIntValue(
    const EvaluationContext* eval_context,
    Val* val,
    Int::ScalarType expected_value) {
@@ -46,6 +46,8 @@ static void checkIntValue(
  TORCH_CHECK(actual_value.value() == expected_value);
}

+} // namespace

// 1. Test cases are void() functions.
// 2. They start with the prefix `test`

@@ -2971,88 +2973,86 @@ void testGPU_FusionSimpleBCast() {
}

void testGPU_FusionSimpleGemm() {
-  {

Review comment (Collaborator Author): note the extra { ... } around the body of the function, which resulted in an indent update.

-    torch::jit::fuser::cuda::CudaKernel prog;
-    Fusion& fusion = *prog.fusion_;
-    FusionGuard fg(&fusion);
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);

-    // Set up your input tensor views
-    TensorView* tv0 = makeDummyTensor(2); // M, K
-    TensorView* tv1 = makeDummyTensor(2); // K, N
-    fusion.addInput(tv0);
-    fusion.addInput(tv1);
+  // Set up your input tensor views
+  TensorView* tv0 = makeDummyTensor(2); // M, K
+  TensorView* tv1 = makeDummyTensor(2); // K, N
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);

-    TensorView* tv2 = broadcast(tv0, {false, false, true});
-    // tv2[I0, I1, B] = tv0[I0, I1]
+  TensorView* tv2 = broadcast(tv0, {false, false, true});
+  // tv2[I0, I1, B] = tv0[I0, I1]

-    TensorView* tv3 = broadcast(tv1, {true, false, false});
-    // tv3[B, I1, I2] = tv1[I1, I2]
+  TensorView* tv3 = broadcast(tv1, {true, false, false});
+  // tv3[B, I1, I2] = tv1[I1, I2]

-    // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
-    TensorView* tv4 = mul(tv2, tv3);
-    // tv5[I0, R1, I2] = tv4[I0, I1, I2]
-    TensorView* tv5 = sum(tv4, {1});
-    fusion.addOutput(tv5);
+  // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
+  TensorView* tv4 = mul(tv2, tv3);
+  // tv5[I0, R1, I2] = tv4[I0, I1, I2]
+  TensorView* tv5 = sum(tv4, {1});
+  fusion.addOutput(tv5);

-    tv5->split(1, 32);
-    // tv5[I0, R1o, R1i{32}, I2]
+  tv5->split(1, 32);
+  // tv5[I0, R1o, R1i{32}, I2]

-    auto tv6 = tv5->rFactor({1});
-    // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
-    // tv5[I0,    , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
+  auto tv6 = tv5->rFactor({1});
+  // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
+  // tv5[I0,    , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]

-    tv5->split(0, 4);
-    tv5->split(-1, 4);
-    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
-    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+  tv5->split(0, 4);
+  tv5->split(-1, 4);
+  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]

-    tv0->computeAt(tv5, -1);
-    tv1->computeAt(tv5, -1);
+  tv0->computeAt(tv5, -1);
+  tv1->computeAt(tv5, -1);

-    // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
-    // tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
-    //--> (line symbolizes compute at location)
-    // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
-    // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
-    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+  // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
+  // tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
+  //--> (line symbolizes compute at location)
+  // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
+  // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
+  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]

-    tv0->computeAt(tv6, -1);
-    tv1->computeAt(tv6, -1);
-    // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
-    // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
-    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+  tv0->computeAt(tv6, -1);
+  tv1->computeAt(tv6, -1);
+  // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
+  // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
+  // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]

-    tv5->axis(0)->parallelize(ParallelType::BIDz);
-    tv5->axis(1)->parallelize(ParallelType::TIDz);
+  tv5->axis(0)->parallelize(ParallelType::BIDz);
+  tv5->axis(1)->parallelize(ParallelType::TIDz);

-    tv5->axis(-2)->parallelize(ParallelType::BIDy);
-    tv5->axis(-1)->parallelize(ParallelType::TIDy);
+  tv5->axis(-2)->parallelize(ParallelType::BIDy);
+  tv5->axis(-1)->parallelize(ParallelType::TIDy);

-    tv5->axis(2)->parallelize(ParallelType::TIDx);
-    tv6->axis(2)->parallelize(ParallelType::TIDx);
+  tv5->axis(2)->parallelize(ParallelType::TIDx);
+  tv6->axis(2)->parallelize(ParallelType::TIDx);

-    constexpr int M = 65, K = 33, N = 17;
+  constexpr int M = 65, K = 33, N = 17;

-    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

-    at::Tensor t0 = at::randn({M, K}, options);
-    at::Tensor t1 = at::randn({K, N}, options);
+  at::Tensor t0 = at::randn({M, K}, options);
+  at::Tensor t1 = at::randn({K, N}, options);

-    at::Tensor cg_output = at::empty({M, N}, options);
+  at::Tensor cg_output = at::empty({M, N}, options);

-    prog.device_ = 0;
-    prog.grid(1, ceilDiv_(N, 4), ceilDiv_(M, 4));
+  prog.device_ = 0;
+  prog.grid(1, ceilDiv_(N, 4), ceilDiv_(M, 4));

-    prog.block(32, 4, 4);
-    torch::jit::fuser::cuda::compileKernel(&prog);
-    torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+  prog.block(32, 4, 4);
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});

-    auto t2 = t0.matmul(t1);
-    TORCH_CHECK(
-        t2.allclose(cg_output, 1e-5, 1e-5),
-        "Error of: ",
-        t2.sub(cg_output).abs().max());
-  }
+  auto t2 = t0.matmul(t1);
+  TORCH_CHECK(
+      t2.allclose(cg_output, 1e-5, 1e-5),
+      "Error of: ",
+      t2.sub(cg_output).abs().max());
}

// Softmax with a 1D tensor. Parallelized only with a single thread block.
@@ -4211,4 +4211,5 @@ void testGPU_FusionReductionScheduler() {

} // namespace jit
} // namespace torch

#endif // #if defined(USE_CUDA)
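Aside: the broadcast/multiply/sum pattern scheduled in testGPU_FusionSimpleGemm is exactly a matrix product, which is why the test checks against t0.matmul(t1). A standalone sketch of that reference semantics (plain C++, independent of the fuser; the helper name referenceGemm is ours, not part of the PR):

#include <vector>

// out[m][n] = sum over k of t0[m][k] * t1[k][n], i.e. the reduction
// over axis I1 (R1 in the comments above) of tv2 * tv3.
std::vector<float> referenceGemm(
    const std::vector<float>& t0, // M x K, row-major
    const std::vector<float>& t1, // K x N, row-major
    int M, int K, int N) {
  std::vector<float> out(M * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int k = 0; k < K; ++k) { // k is the reduced axis
      for (int n = 0; n < N; ++n) {
        out[m * N + n] += t0[m * K + k] * t1[k * N + n];
      }
    }
  }
  return out;
}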
12 changes: 4 additions & 8 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -48,14 +48,10 @@ TORCH_CUDA_API TensorView* reductionOp(
TORCH_CUDA_API Val* neg(Val* v);
TORCH_CUDA_API TensorView* neg(TensorView* v);

-// BINARY OPERATIONS
-// add
-/*
- * Broadcasts v1 based on bool vector. Size of broadcast bool vector should be
- * the number of dims desired in the broadcasted tensor. This vector should be
- * true if output dim should be a broadcasted dim, and false if it is not a
- * broadcasted dim. Number of false entries must match the number of input dims.
- */
+// Broadcasts v1 based on bool vector. Size of broadcast bool vector should be
+// the number of dims desired in the broadcasted tensor. This vector should be
+// true if output dim should be a broadcasted dim, and false if it is not a
+// broadcasted dim. Number of false entries must match the number of input dims.
TORCH_CUDA_API TensorView* broadcast(
    TensorView* inp,
    const std::vector<bool>& is_broadcast_dim);
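A usage sketch of broadcast, mirroring the calls in test_gpu.cpp above (assumes the fuser headers and the test's makeDummyTensor helper are in scope):

TensorView* tv0 = makeDummyTensor(2); // [I0, I1]

// Three output dims requested; the two false entries line up with the two
// input dims, and the single true entry becomes the broadcast dim:
TensorView* tv2 = broadcast(tv0, {false, false, true}); // tv2[I0, I1, B]

// A leading broadcast dim instead:
TensorView* tv3 = broadcast(tv0, {true, false, false}); // tv3[B, I0, I1]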
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -27,11 +27,11 @@ constexpr auto kKernelName = "kernel";
namespace {

// See NOTE [ USE OF NVRTC AND DRIVER API ]
-static const at::cuda::NVRTC& nvrtc() {
+const at::cuda::NVRTC& nvrtc() {
  return at::globalContext().getNVRTC();
}

-static int ceilDiv(const int a, const int b) {
+int ceilDiv(const int a, const int b) {
  return (a + b - 1) / b;
}

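ceilDiv rounds the quotient up so every element is covered by a block. Worked through with the sizes from testGPU_FusionSimpleGemm above, as a self-contained check (plain C++ sketch, duplicating the one-line definition for standalone use):

#include <cassert>

int ceilDiv(const int a, const int b) {
  return (a + b - 1) / b;
}

int main() {
  // The test launches prog.grid(1, ceilDiv_(N, 4), ceilDiv_(M, 4))
  // with M = 65, N = 17:
  assert(ceilDiv(17, 4) == 5);  // 5 blocks cover 17 columns in chunks of 4
  assert(ceilDiv(65, 4) == 17); // 17 blocks cover 65 rows in chunks of 4
  return 0;
}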