2 changes: 2 additions & 0 deletions caffe2/CMakeLists.txt
@@ -579,6 +579,8 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp
101 changes: 83 additions & 18 deletions test/cpp/jit/test_gpu.cpp
@@ -1,4 +1,4 @@
#if defined(USE_CUDA)
// #if defined(USE_CUDA)
#include <test/cpp/jit/test_base.h>

#include <torch/csrc/jit/codegen/cuda/arith.h>
@@ -28,7 +28,7 @@ using namespace torch::jit::fuser;
TensorView* makeDummyTensor(int nDims) {
std::vector<IterDomain*> dom;
for (int i = 0; i < nDims; i++)
dom.push_back(new IterDomain(new Int()));
dom.push_back(new IterDomain(new Int(0), new Int()));

return new TensorView(new TensorDomain(dom), DataType::Float);
}
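
For reference, a minimal sketch of the constructor change this helper exercises: IterDomain now takes an explicit start in addition to the extent. This is a sketch assuming the fuser headers touched by this PR, not code from the change itself:

// Sketch only; assumes <torch/csrc/jit/codegen/cuda/fusion.h> and
// <torch/csrc/jit/codegen/cuda/ir_internal_nodes.h> from this PR.
void iterDomainSketch() {
  Fusion fusion;
  FusionGuard fg(&fusion); // IR nodes register with the active fusion

  // Old form: new IterDomain(new Int()) -- extent only, start implied.
  // New form: the start is explicit.
  IterDomain* symbolic = new IterDomain(new Int(0), new Int());   // [0, ?)
  IterDomain* fixed = new IterDomain(new Int(0), new Int(16));    // [0, 16)
}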
@@ -385,22 +385,22 @@ void testGPU_FusionTVSplit() {

tv = tv->split(2, 2);
TORCH_CHECK(tv->nDims() == 4);
Expr* outer = tv->axis(2)->size()->getOrigin();
Expr* outer = tv->axis(2)->extent()->getOrigin();

TORCH_CHECK(
outer->getExprType().value() == ExprType::BinaryOp &&
static_cast<BinaryOp*>(outer)->getBinaryOpType() ==
BinaryOpType::CeilDiv &&
static_cast<BinaryOp*>(outer)->lhs()->sameAs(
tv->getRootDomain()->axis(2)->size()) &&
tv->getRootDomain()->axis(2)->extent()) &&
static_cast<Int*>(static_cast<BinaryOp*>(outer)->rhs())
->sameAs(new Int(2)));

IterDomain* inner = static_cast<IterDomain*>(tv->axis(3));
TORCH_CHECK(
inner->size()->isScalar() &&
static_cast<Int*>(inner->size())->isConst() &&
static_cast<Int*>(inner->size())->value().value() == 2);
inner->extent()->isScalar() &&
static_cast<Int*>(inner->extent())->isConst() &&
static_cast<Int*>(inner->extent())->value().value() == 2);
}

void testGPU_FusionTVMerge() {
@@ -410,15 +410,15 @@ void testGPU_FusionTVMerge() {
TensorView* tv = makeDummyTensor(3);

tv = tv->merge(1);
Expr* axisOp = tv->axis(1)->size()->getOrigin();
Expr* axisOp = tv->axis(1)->extent()->getOrigin();

TORCH_CHECK(
tv->nDims() == 2 && axisOp->getExprType() == ExprType::BinaryOp &&
static_cast<BinaryOp*>(axisOp)->getBinaryOpType() == BinaryOpType::Mul &&
static_cast<BinaryOp*>(axisOp)->lhs() ==
tv->getRootDomain()->axis(1)->size() &&
tv->getRootDomain()->axis(1)->extent() &&
static_cast<BinaryOp*>(axisOp)->rhs() ==
tv->getRootDomain()->axis(2)->size());
tv->getRootDomain()->axis(2)->extent());
}
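
Taken together, the two tests above pin down the extent arithmetic behind the rename: split(i, f) replaces axis i with an outer axis of extent ceilDiv(I_i, f) plus an inner axis of constant extent f, while merge(i) fuses axes i and i+1 into one axis whose extent is their product. A hedged sketch, reusing makeDummyTensor from this test file:

// Sketch of the extent bookkeeping checked above; assumes the helpers in
// test/cpp/jit/test_gpu.cpp from this PR.
void splitMergeSketch() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv = makeDummyTensor(3); // axes [I0, I1, I2]

  tv = tv->split(2, 2);
  // Now [I0, I1, ceilDiv(I2, 2), 2]: the outer extent is the CeilDiv
  // BinaryOp inspected in FusionTVSplit; the inner extent is the const 2.

  tv = tv->merge(0);
  // Now [I0 * I1, ceilDiv(I2, 2), 2]: merge multiplies the two extents,
  // which is the Mul BinaryOp inspected in FusionTVMerge.
}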

void testGPU_FusionTVReorder() {
@@ -857,7 +857,7 @@ void testGPU_FusionSimplePWise() {
// Set up symbolic sizes for the axes; should match the dimensionality of the problem
std::vector<IterDomain*> dom;
for (int i = 0; i < nDims; i++)
dom.push_back(new IterDomain(new Int()));
dom.push_back(new IterDomain(new Int(0), new Int()));

// Set up your input tensor views
TensorView* tv0 = new TensorView(new TensorDomain(dom), DataType::Float);
@@ -937,13 +937,18 @@ void testGPU_FusionExecKernel() {
// Register your outputs
fusion.addOutput(tv3);

tv3->split(0, 4);

// For all inputs, computeAt the output inline; temporaries should be squeezed
// between them
tv0->computeAt(tv3, -1);
tv1->computeAt(tv3, -1);
tv0->computeAt(tv3, 1);
tv1->computeAt(tv3, 1);

// Parallelize TV3
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::Unroll);
tv3->axis(1)->parallelize(ParallelType::Unroll);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);

torch::jit::fuser::cuda::CudaKernel prog;
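
The switch from computeAt(tv3, -1) to computeAt(tv3, 1), together with the new split and Unroll axes, changes the requested loop nest: producers are computed inside only the outermost (BIDx) loop of tv3 rather than fully inline. Roughly, assuming the same tv2/tv3 definitions as the FusionLoopUnroll test below (a hand sketch of the structure, not lowered output):

// Rough shape of the nest the schedule above requests (sketch only):
// for bid in ceilDiv(I0, 4):          <- tv3 axis(0), ParallelType::BIDx
//   for u in 4:                       <- tv2 axis(1), Unroll
//     for tid in I1:                  <- tv2 axis(-1), TIDx
//       tv2[...] = tv1[...] + 2.0     // tv2 materialized per block
//   for u in 4:                       <- tv3 axis(1), Unroll
//     for tid in I1:                  <- tv3 axis(-1), TIDx
//       tv3[...] = tv0[...] + tv2[...]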
@@ -973,14 +978,16 @@ void testGPU_FusionForLoop() {
FusionGuard fg(&fusion);

const auto TV0 = new TensorView(
new TensorDomain({new IterDomain(new Int(16))}), DataType::Float);
new TensorDomain({new IterDomain(new Int(0), new Int(16))}),
DataType::Float);
const auto TV1 = new TensorView(
new TensorDomain({new IterDomain(new Int(16))}), DataType::Float);
new TensorDomain({new IterDomain(new Int(0), new Int(16))}),
DataType::Float);

fusion.addInput(TV0);
fusion.addInput(TV1);

auto ID0 = new IterDomain(new Int(8));
auto ID0 = new IterDomain(new Int(0), new Int(8));

TensorView* TV2 = static_cast<TensorView*>(add(TV0, TV1));
BinaryOp* op = static_cast<BinaryOp*>(TV2->getOrigin());
@@ -1001,8 +1008,66 @@ void testGPU_FusionForLoop() {
}
}

void testGPU_Fusion() {}
void testGPU_FusionLoopUnroll() {
Fusion fusion;
FusionGuard fg(&fusion);

// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(1);
TensorView* tv1 = makeDummyTensor(1);

// Register your inputs
fusion.addInput(tv0);
fusion.addInput(tv1);

// Do math with it; it returns a `Val*` but can be static_cast back to
// TensorView
TensorView* tv2 = static_cast<TensorView*>(add(tv1, new Float(2.0)));
TensorView* tv3 = static_cast<TensorView*>(add(tv0, tv2));

// Register your outputs
fusion.addOutput(tv3);

tv3->split(0, 16);
tv3->split(0, 4);

// For all inputs, computeAt the output inline; temporaries should be squeezed
// between them
tv0->computeAt(tv3, 1);
tv1->computeAt(tv3, 1);

// Parallelize
tv2->axis(1)->parallelize(ParallelType::Unroll);
tv3->axis(1)->parallelize(ParallelType::Unroll);
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(0)->parallelize(ParallelType::BIDx);

torch::jit::fuser::cuda::CudaKernel prog;
prog.device_ = 0;
prog.grid(2);
prog.block(16);

// GPULower lower(&fusion);
// lower.printKernel(std::cout);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

at::Tensor input1 = at::ones({128}, options);
at::Tensor input2 = at::ones_like(input1);

at::Tensor output = at::empty_like(input1);
std::vector<at::Tensor> inputs{{input1, input2}};
std::vector<at::Tensor> outputs{{output}};

torch::jit::fuser::cuda::compileKernel(fusion, prog);
torch::jit::fuser::cuda::runTestKernel(prog, inputs, outputs);

at::Tensor check = at::full({128}, 4, options);

TORCH_CHECK(output.equal(check));
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)
// #endif // #if defined(USE_CUDA)
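
For the 128-element case above, the loop structure is fully determined: split(0, 16) then split(0, 4) yields axes of extents [2, 4, 16], bound to blockIdx.x, an unrolled loop, and threadIdx.x respectively, matching prog.grid(2) and prog.block(16). An illustrative CUDA sketch of the kernel the lowering should approximate (names hypothetical, not the generated code):

// Illustrative only: the generated kernel differs in naming and indexing,
// but 2 blocks * 4 unrolled iterations * 16 threads cover all 128 elements.
__global__ void loopUnrollSketch(const float* t0, const float* t1, float* t3) {
#pragma unroll
  for (int u = 0; u < 4; ++u) {                       // Unroll axis, extent 4
    int i = (blockIdx.x * 4 + u) * 16 + threadIdx.x;  // BIDx / TIDx binding
    float t2 = t1[i] + 2.0f;                          // tv2 = tv1 + 2.0
    t3[i] = t0[i] + t2;                               // tv3 = tv0 + tv2
  }
}
// Every element becomes 1 + (1 + 2) = 4, matching the at::full check.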
3 changes: 2 additions & 1 deletion test/cpp/jit/tests.h
@@ -116,7 +116,8 @@ namespace jit {
_(GPU_FusionCodeGen2) \
_(GPU_FusionSimplePWise) \
_(GPU_FusionExecKernel) \
_(GPU_FusionForLoop)
_(GPU_FusionForLoop) \
_(GPU_FusionLoopUnroll)
#else
#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
2 changes: 2 additions & 0 deletions tools/build_variables.bzl
@@ -238,6 +238,8 @@ libtorch_cuda_sources = [
"torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
"torch/csrc/jit/codegen/cuda/kernel.cpp",
"torch/csrc/jit/codegen/cuda/lower_loops.cpp",
"torch/csrc/jit/codegen/cuda/lower_utils.cpp",
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
"torch/csrc/jit/codegen/cuda/manager.cpp",
"torch/csrc/jit/codegen/cuda/mutator.cpp",
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -21,7 +21,7 @@ void IndexCompute::replayBackward(Merge* expr) {
ax >= 0 && ax < indices.size(),
"Hit an invalid MERGE transformation during IndexCompute, axis is not within bounds.");

Val* I = expr->in()->axis(ax + 1)->size();
Val* I = expr->in()->axis(ax + 1)->extent();
Val* ind = indices[ax];
indices[ax] = div(ind, I);
indices.insert(indices.begin() + ax + 1, mod(ind, I));
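
Here extent() is the inner extent used to undo a merge while computing indices: a flattened index ind over a merged axis of extent I_outer * I_inner splits back into ind / I_inner and ind % I_inner. A plain-integer sketch of that arithmetic (illustrative values):

// Backward replay of a Merge, in plain integers (sketch only).
#include <cassert>

void replayMergeBackwardSketch() {
  const int inner_extent = 8;   // I = expr->in()->axis(ax + 1)->extent()
  const int merged_index = 42;  // ind, over a merged axis of extent 4 * 8

  int outer = merged_index / inner_extent;  // replaces indices[ax]      -> 5
  int inner = merged_index % inner_extent;  // inserted at indices[ax+1] -> 2
  assert(outer * inner_extent + inner == merged_index);
}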
23 changes: 23 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp
@@ -41,6 +41,9 @@ Val::Val(ValType _vtype, DataType _dtype) : vtype_{_vtype}, dtype_{_dtype} {
}
}

// Traverse origin of all values involved in constructing the provided val.
// Check if all values involved are constant values, meaning the provided
// val is also a constant value.
namespace {

struct ConstCheck : OptOutConstDispatch {
@@ -88,6 +91,22 @@ bool Val::isConstScalar() const {
return ConstCheck::isConst(this);
}

bool Val::isZeroInt() const {
if (isConstScalar() && getValType().value() == ValType::Scalar &&
getDataType().value() == DataType::Int &&
static_cast<const Int*>(this)->value().value() == 0)
return true;
return false;
}

bool Val::isOneInt() const {
if (isConstScalar() && getValType().value() == ValType::Scalar &&
getDataType().value() == DataType::Int &&
static_cast<const Int*>(this)->value().value() == 1)
return true;
return false;
}

c10::optional<DataType> Val::getDataType() const {
TORCH_INTERNAL_ASSERT(
dtype_ != DataType::Null, "Value does not have a data type.");
@@ -147,6 +166,10 @@ bool Scope::sameAs(const Scope& other) const {
return true;
}

void Scope::clear() {
this->exprs_ = std::vector<Expr*>();
}

bool IRInputOutput::hasInput(const Val* const input) const {
for (auto val : inputs_)
if (val == input)
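
The new isZeroInt/isOneInt helpers let passes cheaply special-case constant zero and one integers; the new Unroll check in IterDomain::parallelize further down is one caller. A hedged usage sketch, assuming this PR's headers:

// Sketch of the checks the new helpers enable; assumes the fuser headers
// from this PR.
void constIntSketch() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  Val* zero = new Int(0);
  Val* one = new Int(1);
  Val* symbolic = new Int(); // no value, so not a constant scalar

  TORCH_CHECK(zero->isZeroInt());
  TORCH_CHECK(one->isOneInt());
  TORCH_CHECK(!symbolic->isZeroInt() && !symbolic->isOneInt());
}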
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
@@ -178,6 +178,9 @@ struct TORCH_CUDA_API Val : public Statement {
return isScalar() && dtype_ == DataType::Int;
}

bool isZeroInt() const;
bool isOneInt() const;

// Returns the Expr that this value is an output of, returns nullptr if none
// was found
Expr* getOrigin();
@@ -251,6 +254,8 @@ struct TORCH_CUDA_API Scope {

bool sameAs(const Scope& other) const;

void clear();

private:
std::vector<Expr*> exprs_;
};
36 changes: 24 additions & 12 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -99,17 +99,18 @@ struct TORCH_CUDA_API BinaryOp : public Expr {
};

/*
* Simply a representation of an iterable from 0 to size. TensorDomains which
* represent how to iterate over a tensor is made up of IterDomains. We directly
* set parallization strategies on IterDomains.
* Simply a representation of an iterable from start to extent. TensorDomains,
* which represent how to iterate over a tensor, are made up of IterDomains.
* We directly set parallelization strategies on IterDomains.
*/
struct TORCH_CUDA_API IterDomain : public Val {
~IterDomain() = default;

IterDomain() = delete;

IterDomain(
Val* int_size,
Val* _start,
Val* _extent,
ParallelType _parallel_method = ParallelType::Serial,
bool _reduction_domain = false);

@@ -157,15 +158,25 @@ struct TORCH_CUDA_API IterDomain : public Val {
TORCH_CHECK(
t != ParallelType::Vectorize, "Vectorization not yet supported.");
if (t == ParallelType::Unroll)
TORCH_CHECK(false, "Unrolling not yet supported.");
TORCH_CHECK(
start()->isZeroInt() && extent()->isConstScalar(),
"Unrolling only supported with start = 0 and extent as a const int, but got ",
"a start of ",
start(),
" and extent ",
extent(),
" .");
}
}

ParallelType parallel_method() const noexcept {
return parallel_method_;
}

Val* size() const;
Val* start() const noexcept {
return start_;
}
Val* extent() const;

IterDomain(const IterDomain& other) = delete;
IterDomain& operator=(const IterDomain& other) = delete;
@@ -174,7 +185,8 @@ IterDomain& operator=(IterDomain&& other) = delete;
IterDomain& operator=(IterDomain&& other) = delete;

private:
Val* const size_;
Val* const start_;
Val* const extent_;
ParallelType parallel_method_ = ParallelType::Serial;
bool is_reduction_domain_;
};
@@ -317,7 +329,7 @@ struct TORCH_CUDA_API Reorder : public Expr {
};

/*
* ForLoop provides scoping around an int iterator from 0 to range. Exprs placed
* ForLoop provides scoping around an index through an IterDomain. Exprs placed
* in its body are considered inside the scope of the for loop. In the future
* the implementation should look quite different so that we can do proper
* dependency analysis like in Fusion.
@@ -329,7 +341,7 @@ struct TORCH_API ForLoop : public Expr {
~ForLoop() = default;
ForLoop(
Val* _index,
IterDomain* _range,
IterDomain* _iter_domain,
const std::vector<Expr*>& _body = {},
Expr* parent_scope = nullptr);

@@ -343,8 +355,8 @@ struct TORCH_API ForLoop : public Expr {
return index_;
}

IterDomain* range() const noexcept {
return range_;
IterDomain* iter_domain() const noexcept {
return iter_domain_;
}

Scope& body() noexcept {
@@ -365,7 +377,7 @@

private:
Val* const index_;
IterDomain* const range_;
IterDomain* const iter_domain_;
Scope body_;
Expr* parent_scope_;
};