103 changes: 51 additions & 52 deletions test/cpp/jit/test_gpu.cpp

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>

namespace torch {
namespace jit {
@@ -13,6 +14,11 @@ FusionGuard::FusionGuard(Fusion* fusion) {
ACTIVE_FUSION = fusion;
}

FusionGuard::FusionGuard(const cuda::CudaKernel* cuda_kernel) {
prev_fusion = ACTIVE_FUSION;
ACTIVE_FUSION = cuda_kernel->fusion_.get();
}

FusionGuard::~FusionGuard() {
ACTIVE_FUSION = prev_fusion;
}
5 changes: 5 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.h
@@ -51,6 +51,10 @@ struct TypeHash {
struct Fusion;
struct TensorView;

namespace cuda {
struct CudaKernel;
}

// Fusion Guard is our "context manager". It holds the active fusion and allows
// it to be accessed anywhere through FusionGuard::getCurFusion().
struct TORCH_CUDA_API FusionGuard {
@@ -59,6 +63,7 @@ struct TORCH_CUDA_API FusionGuard {

// Set the active fusion so it can be manipulated.
FusionGuard(Fusion* fusion);
FusionGuard(const cuda::CudaKernel* cuda_kernel);

~FusionGuard();

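The two constructors above make FusionGuard an RAII scope: construction saves the currently active fusion and installs the new one, and destruction restores the previous one. Below is a minimal usage sketch, assuming only the declarations visible in this diff (FusionGuard(const cuda::CudaKernel*) and FusionGuard::getCurFusion()); namespace qualifiers are elided and this is not code from the PR.

```cpp
// Sketch only: illustrates the RAII scoping of FusionGuard with the new
// CudaKernel-based overload. Assumes the declarations shown in this diff.
void lowerWithKernel(const cuda::CudaKernel* cuda_kernel) {
  FusionGuard fg(cuda_kernel);       // saves prev fusion, activates cuda_kernel->fusion_
  Fusion* active = FusionGuard::getCurFusion();
  // ... build or traverse IR against `active` here ...
}                                    // ~FusionGuard() restores the previously active fusion
```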
43 changes: 23 additions & 20 deletions torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -41,6 +41,7 @@ struct ExtractSizeStride {
const at::Tensor& val,
c10::optional<at::IntArrayRef> broadcasted_size = c10::nullopt) {
if (broadcasted_size) {
// [Note - broadcast support in integration]
// PyTorch follows numpy broadcasting rule.
// (https://numpy.org/doc/stable/user/basics.broadcasting.html)
//
@@ -144,7 +145,7 @@ struct KernelArgumentHolder {
}
};

std::pair<std::string, std::string> codeGeneration(Fusion& fusion) {
std::pair<std::string, std::string> codeGeneration(Fusion* fusion) {
std::stringstream str_stream;
str_stream << "namespace " << CG_NAMESPACE << " {\n"
<< code_template_tensor_struct << "\n"
@@ -153,7 +154,7 @@ std::pair<std::string, std::string> codeGeneration(Fusion& fusion) {
<< code_helper_funcs << "\n"
<< code_template_block_reduction << "\n";
std::stringstream cdg;
GPULower gpulw(&fusion);
GPULower gpulw(fusion);
gpulw.printKernel(str_stream, KERNEL_NAME);
str_stream << "\n} // namespace";

@@ -171,16 +172,22 @@ bool validateKernelArgTensor(
msg << "Argument is a tensor, but the parameter is not.";
return false;
}

// Check the rank of the tensors.
size_t arg_dim = arg.dim();
// Note: This requires the current Fusion to be active.
size_t param_dim = TensorDomain::noReductions(
static_cast<const TensorView*>(param)->getRootDomain())
.size();
if (arg_dim != param_dim) {
// see [Note - broadcast support in integration]
// Because broadcast support is handled in integration, we relax the rank
// check as necessary.
if (arg_dim > param_dim) {
msg << "Argument tensor's rank is " << arg_dim << ", but the parameter is "
<< param_dim;
return false;
}

if (arg.device().index() != device_index) {
msg << "Argument is on device that is not compiled for";
return false;
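
The relaxed check follows the numpy-style broadcasting rule cited in [Note - broadcast support in integration]: shapes are aligned from the right and missing leading dimensions count as size one, so an argument tensor may legitimately have a lower rank than the kernel parameter, but never a higher one. A small illustration of the intent (not part of the PR):

```cpp
// Illustration only: why the rank check changes from "!=" to ">".
//   parameter root domain (reductions removed): [B, H, W] -> param_dim = 3
//   argument tensor [H, W]                                 -> arg_dim = 2 (accepted, 2 <= 3)
//   argument tensor [N, B, H, W]                           -> arg_dim = 4 (rejected, 4 > 3)
bool rankCompatible(size_t arg_dim, size_t param_dim) {
  return arg_dim <= param_dim;
}
```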
@@ -256,12 +263,15 @@ void validateKernelArgs(
const CudaKernel& entry,
const at::ArrayRef<IValue>& inputs,
const std::vector<at::Tensor>& outputs) {
// This is necessary as we traverse the fusion graph later in this check
FusionGuard fg(&entry);
// Check inputs
TORCH_INTERNAL_ASSERT(
inputs.size() == entry.inputs.size(), "Wrong number of kernel inputs.");
inputs.size() == entry.fusion_->inputs().size(),
"Wrong number of kernel inputs.");
for (size_t i = 0; i < inputs.size(); ++i) {
const IValue& arg = inputs[i];
const Val* const param = entry.inputs[i];
const Val* const param = entry.fusion_->inputs()[i];
std::stringstream msg;
TORCH_INTERNAL_ASSERT(
validateKernelArg(arg, param, entry.device_, msg),
@@ -272,15 +282,15 @@ void validateKernelArgs(
}

TORCH_INTERNAL_ASSERT(
entry.outputs.size() != 0,
entry.fusion_->outputs().size() != 0,
"Kernel should have at least one output tensor.");

TORCH_INTERNAL_ASSERT(
outputs.size() == entry.outputs.size(),
outputs.size() == entry.fusion_->outputs().size(),
"Wrong number of kernel outputs.");
for (size_t i = 0; i < outputs.size(); ++i) {
const at::Tensor& arg = outputs[i];
const Val* const param = entry.outputs[i];
const Val* const param = entry.fusion_->outputs()[i];
std::stringstream msg;
TORCH_INTERNAL_ASSERT(
validateKernelArgTensor(arg, param, entry.device_, msg),
@@ -310,18 +320,11 @@ bool NaivePWKernelArgsReq::matchKernelSize(const at::ArrayRef<IValue> inputs) {
return true;
}

void compileKernel(Fusion& fusion, CudaKernel* entry) {
void compileKernel(CudaKernel* entry) {
// generating cuda code;
std::string code;
std::string func_name;
std::tie(func_name, code) = codeGeneration(fusion);

// Keep input and output reference to validate/line up arguments
for (auto inp : fusion.inputs())
entry->inputs.push_back(inp);

for (auto out : fusion.outputs())
entry->outputs.push_back(out);
std::tie(func_name, code) = codeGeneration(entry->fusion_.get());

static int32_t compiled_kernel_id = 0;

@@ -338,7 +341,7 @@ void compileKernel(Fusion& fusion, CudaKernel* entry) {

// set device for the operation;
at::cuda::set_device(entry->device_);
entry->has_random_ = fusion.hasRNG();
entry->has_random_ = entry->fusion_->hasRNG();

const auto prop = at::cuda::getCurrentDeviceProperties();
int nvrtc_major, nvrtc_minor;
@@ -528,7 +531,7 @@ void runTestKernel(

KernelArgumentHolder kernel_args;

auto exprs = entry->outputs[0]->fusion()->exprs(true);
auto exprs = entry->fusion_->exprs(true);
bool has_reduction = std::any_of(exprs.begin(), exprs.end(), [](Expr* expr) {
return expr->getExprType() == ExprType::ReductionOp;
});
@@ -542,7 +545,7 @@ void runTestKernel(
input.toTensor().device().index() == entry->device_,
"input to kernel on device that is not compiled for");
TORCH_INTERNAL_ASSERT(
!entry->outputs.empty(),
!entry->fusion_->outputs().empty(),
"No output found for this kernel, aborting.");
if (has_reduction) {
kernel_args.push(input.toTensor());
11 changes: 6 additions & 5 deletions torch/csrc/jit/codegen/cuda/kernel.h
@@ -41,10 +41,9 @@ struct NaivePWKernelArgsReq : KernelArgsReq {

class CudaKernel {
public:
std::deque<Val*> inputs;
std::deque<Val*> outputs;

CudaKernel() = default;
CudaKernel() {
fusion_ = std::make_unique<Fusion>();
}

CUmodule& getModule() {
return module_;
@@ -74,12 +73,14 @@ class CudaKernel {
dim3 block_;
dim3 grid_;
bool has_random_;

std::unique_ptr<Fusion> fusion_;
};

// compile Fusion to CUDA functions:
// 1. JIT compilation via nvrtc to generate CUDA c++ kernel code;
// 2. CUDA Driver API to load CUDA c++ kernel code as function_;
TORCH_CUDA_API void compileKernel(Fusion& fusion, CudaKernel* entry);
TORCH_CUDA_API void compileKernel(CudaKernel* entry);

// run loaded kernel through Function.
// inputs/outputs is given in the sense of a PyTorch JIT ir node. This function
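For reference, the two steps listed in the comment above map onto the public NVRTC and CUDA driver APIs. The sketch below is illustrative, not the PR's implementation: it omits error handling, the architecture flags, and the FusionGuard/device setup that compileKernel performs.

```cpp
#include <cuda.h>
#include <nvrtc.h>
#include <string>
#include <vector>

// Minimal sketch of "nvrtc compile, then driver-API load"; not the PR's code.
void compileAndLoad(const std::string& code, const std::string& func_name,
                    CUmodule* module, CUfunction* function) {
  // 1. JIT-compile the generated CUDA C++ kernel code with nvrtc.
  nvrtcProgram program;
  nvrtcCreateProgram(&program, code.c_str(), "fused_kernel.cu", 0, nullptr, nullptr);
  nvrtcCompileProgram(program, 0, nullptr);
  size_t ptx_size = 0;
  nvrtcGetPTXSize(program, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(program, ptx.data());
  nvrtcDestroyProgram(&program);

  // 2. Load the PTX through the CUDA driver API and resolve the kernel symbol.
  cuModuleLoadData(module, ptx.data());
  cuModuleGetFunction(function, *module, func_name.c_str());
}
```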
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/kernel_cache.cpp
@@ -20,7 +20,7 @@ at::optional<CudaKernel*> CudaKernelCache::getKernelPtr(

CudaKernel* CudaKernelCache::allocateKernelInCache(
std::unique_ptr<KernelArgsReq>&& args_req) {
kernels_.emplace_back(std::make_pair(std::move(args_req), CudaKernel()));
kernels_.emplace_back(std::move(args_req), CudaKernel());
return &(kernels_.back().second);
}

5 changes: 2 additions & 3 deletions torch/csrc/jit/codegen/cuda/manager.cpp
@@ -103,13 +103,12 @@ class CudaFusionManager {
makePWKernelSupport(inputs));

// lower torch::jit::Graph to torch::jit::fuser::cuda::fusion
Fusion fusion;
// TODO: pass contiguity info as well as size req, so we can apply proper
// transform to computation
// we should propagate more information back:
// 1. device;
// 2. launch config;
parseJitIR(graph, fusion, cuda_kernel.value());
parseJitIR(graph, cuda_kernel.value());

// find device in inputs.
for (const auto& input : inputs) {
@@ -123,7 +122,7 @@ class CudaFusionManager {
}

// NVRTC compile kernel
compileKernel(fusion, cuda_kernel.value());
compileKernel(cuda_kernel.value());

runKernel(*cuda_kernel, inputs, outputs);
}
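Taken together, the manager path this hunk touches now lowers the JIT graph straight into the Fusion owned by the cached CudaKernel instead of a stack-local Fusion. A condensed, illustrative sketch of that path, assuming the signatures visible in this PR (parseJitIR, compileKernel, runKernel) and a cuda_kernel pointer already allocated in the cache; the device-selection loop is paraphrased:

```cpp
// Sketch of the compile-and-run path after this change; not the PR's code.
void compileAndRun(std::shared_ptr<Graph>& graph,
                   const at::ArrayRef<IValue>& inputs,
                   const std::vector<at::Tensor>& outputs,
                   CudaKernel* cuda_kernel) {
  // Lower torch::jit::Graph directly into cuda_kernel->fusion_.
  parseJitIR(graph, cuda_kernel);

  // Find the target device among the tensor inputs (paraphrased).
  for (const auto& input : inputs) {
    if (input.isTensor()) {
      cuda_kernel->device_ = input.toTensor().device().index();
      break;
    }
  }

  // NVRTC compile, then launch.
  compileKernel(cuda_kernel);
  runKernel(cuda_kernel, inputs, outputs);
}
```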
27 changes: 10 additions & 17 deletions torch/csrc/jit/codegen/cuda/parser.cpp
@@ -39,11 +39,8 @@ class IrParser {
static const int unroll_factor = 4;

public:
IrParser(
std::shared_ptr<Graph> graph,
Fusion& fusion,
CudaKernel* cuda_kernel)
: graph_(std::move(graph)), fusion_(&fusion), cuda_kernel_(cuda_kernel) {
IrParser(std::shared_ptr<Graph> graph, CudaKernel* cuda_kernel)
: graph_(std::move(graph)), cuda_kernel_(cuda_kernel) {
if (init_registry_) {
registerJitOperator();
init_registry_ = false;
@@ -52,7 +49,7 @@ class IrParser {

// Fuses pointwise ops with loop unrolling (factor = 4).
void parse() {
FusionGuard fg(fusion_);
FusionGuard fg(cuda_kernel_->fusion_.get());
auto block = graph_->block();

// in case of broadcast, we don't support explicit broadcast, so we need to
@@ -67,7 +64,7 @@ class IrParser {
// we only explicitly register inputs in the graph.
for (auto val : block->inputs()) {
TORCH_CHECK(registerValue(val, broadcast_dim));
fusion_->addInput(value_map_[val->unique()]);
cuda_kernel_->fusion_->addInput(value_map_[val->unique()]);

auto opt_dtype = value_map_[val->unique()]->getDataType();
// computation promotion, we cast fp16 inputs to fp32 and use promoted
@@ -103,7 +100,7 @@ class IrParser {
out = static_cast<TensorView*>(castOp(DataType::Half, out));
}

fusion_->addOutput(out);
cuda_kernel_->fusion_->addOutput(out);

// Merge all dimensions because we're only supporting pointwise
while (out->nDims() > 1)
@@ -120,19 +117,19 @@ class IrParser {

// Run through outputs, grab all inputs of outputs
// squeeze with computeAt to set overall structure.
for (auto output : fusion_->outputs()) {
for (auto output : cuda_kernel_->fusion_->outputs()) {
if (output->getValType() != ValType::TensorView)
continue;
TensorView* out_tv = static_cast<TensorView*>(output);
for (Val* inp : fusion_->inputsOf(output)) {
for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
if (inp->getValType().value() == ValType::TensorView)
static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
}
out_tv->axis(0)->parallelize(ParallelType::BIDx);
}

// Run through intermediates, unroll, and bind their axes
for (auto val : fusion_->vals()) {
for (auto val : cuda_kernel_->fusion_->vals()) {
if (val->getValType().value() != ValType::TensorView)
continue;
TensorView* tv = static_cast<TensorView*>(val);
@@ -539,7 +536,6 @@ class IrParser {
}

std::shared_ptr<Graph> graph_;
Fusion* fusion_;
CudaKernel* cuda_kernel_;

// maps from JitValue::unique() to fusion Val;
@@ -564,11 +560,8 @@ bool isNodeParsible(const Node* const node) {
return IrParser::canParseNode(node);
}

void parseJitIR(
std::shared_ptr<Graph>& graph,
Fusion& fusion,
CudaKernel* cuda_kernel) {
IrParser parser(graph, fusion, cuda_kernel);
void parseJitIR(std::shared_ptr<Graph>& graph, CudaKernel* cuda_kernel) {
IrParser parser(graph, cuda_kernel);
parser.parse();
}

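The scheduling that parse() applies to pointwise fusions is spread across the hunks above: merge every output down to one dimension, computeAt producers into the outputs at position 1, bind the outermost axis to blockIdx.x, and unroll intermediates by the class-level unroll_factor of 4. A paraphrased sketch of that sequence, using only calls that appear in this diff (the merge and unroll steps are left as comments because their exact calls are not shown here); illustrative only, not the PR's code:

```cpp
// Paraphrased from IrParser::parse(); illustrative only.
void schedulePointwise(CudaKernel* cuda_kernel) {
  FusionGuard fg(cuda_kernel->fusion_.get());

  for (auto output : cuda_kernel->fusion_->outputs()) {
    if (output->getValType() != ValType::TensorView)
      continue;
    auto* out_tv = static_cast<TensorView*>(output);

    // 1. Flatten: merge all axes, since only pointwise ops are supported.
    //    while (out_tv->nDims() > 1) { /* merge adjacent axes */ }

    // 2. Inline producers into the output at position 1 to set the structure.
    for (Val* inp : cuda_kernel->fusion_->inputsOf(output)) {
      if (inp->getValType().value() == ValType::TensorView)
        static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
    }

    // 3. Bind the outermost axis to the CUDA grid.
    out_tv->axis(0)->parallelize(ParallelType::BIDx);
    // 4. Intermediates are then unrolled by unroll_factor (4) and bound to threads.
  }
}
```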
1 change: 0 additions & 1 deletion torch/csrc/jit/codegen/cuda/parser.h
@@ -32,7 +32,6 @@ TORCH_CUDA_API bool isNodeParsible(const Node* const node);
// lowers PyTorch jit graph to `Fusion`.
TORCH_CUDA_API void parseJitIR(
std::shared_ptr<Graph>& graph,
Fusion& fusion,
CudaKernel* cuda_kernel);

} // namespace cuda
10 changes: 10 additions & 0 deletions torch/csrc/jit/codegen/cuda/shape_inference.cpp
@@ -100,6 +100,7 @@ class NaiveShapeTypePropagator {
case aten::pow:
case aten::remainder:
case aten::fmod:
case aten::lerp:
// add/sub can be ternary ops whose third argument contributes to
// neither type promotion nor shape.
case aten::add:
@@ -131,6 +132,15 @@ class NaiveShapeTypePropagator {
node->output()->setType(promoted_type);
break;
}
case aten::addcmul: {
auto promoted_type = binary_broadcast_type(
node->input(1)->type()->cast<TensorType>(),
node->input(2)->type()->cast<TensorType>());
promoted_type = binary_broadcast_type(
promoted_type, node->input(0)->type()->cast<TensorType>());
node->output()->setType(promoted_type);
break;
}
default:
TORCH_CHECK(false, "shape/type inference failed.");
// TODO: generate a proper error log, as this probably means something
Expand Down