1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -452,6 +452,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/dispatch.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/expr_evaluator.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/fusion.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/scheduler.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/graph_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/index_compute.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_base_nodes.cpp
57 changes: 38 additions & 19 deletions test/cpp/jit/test_gpu.cpp
@@ -10,6 +10,7 @@
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/lower2device.h>
#include <torch/csrc/jit/codegen/cuda/mutator.h>
#include <torch/csrc/jit/codegen/cuda/scheduler.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
#include <torch/csrc/jit/codegen/cuda/transform_rfactor.h>

@@ -370,6 +371,8 @@ void testGPU_FusionClear() {
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::Unroll);
tv3->axis(-1)->parallelize(ParallelType::TIDx);

fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
}

// 2. Clear the IR
@@ -382,6 +385,8 @@ void testGPU_FusionClear() {
TORCH_CHECK(fusion.inputs().empty());
TORCH_CHECK(fusion.outputs().empty());

TORCH_CHECK(fusion.launch_configs().empty());

TORCH_CHECK(!fusion.hasReduction());
TORCH_CHECK(!fusion.hasBlockReduction());
TORCH_CHECK(!fusion.hasGridReduction());
@@ -415,8 +420,11 @@ void testGPU_FusionClear() {
at::Tensor input2 = at::randn_like(input1);
at::Tensor output = at::empty_like(input1);

fuser::cuda::scheduleFusion(&fusion, {input1, input2});
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {input1, input2}, {output});
prog.device_ = 0;
torch::jit::fuser::cuda::runKernel(
&prog, {input1, input2}, {output}, output.sizes().vec());

at::Tensor tv2_ref = input2 + 2.0;
at::Tensor output_ref = input1 + tv2_ref;
@@ -449,6 +457,10 @@ void testGPU_FusionCopy() {

tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);

original_fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
original_fusion.setLaunchConfig(
LaunchConfigType::BIDx, tv3->axis(0)->rawExtent());
}

// Test copy before lowering
@@ -519,6 +531,9 @@ void testGPU_FusionMove() {

tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);

fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
fusion.setLaunchConfig(LaunchConfigType::BIDx, tv3->axis(0)->rawExtent());
}

std::stringstream original_ir;
@@ -1010,6 +1025,10 @@ void testGPU_FusionParser() {
prog.block(32);
prog.device_ = 0;
fuser::cuda::parseJitIR(g, &prog);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input1 = at::randn({16}, options);
at::Tensor input2 = at::randn({16}, options);
fuser::cuda::scheduleFusion(prog.fusion_.get(), {input1, input2});

// CONSIDER:
// 1. this can be moved to a dedicated "golden" file
@@ -1018,32 +1037,32 @@
__global__ void CUDAGeneratedKernel(Tensor<float, 1> T0, Tensor<float, 1> T1, Tensor<float, 1> T3){
float T2[4];
if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
for(size_t i40 = 0; i40 < 4; ++i40 ) {
T2[ i40 ]
= T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
* T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
for(size_t i49 = 0; i49 < 4; ++i49 ) {
T2[ i49 ]
= T0[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
* T1[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
}
} else {
for(size_t i40 = 0; i40 < 4; ++i40 ) {
if ( ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
T2[ i40 ]
= T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
* T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
for(size_t i49 = 0; i49 < 4; ++i49 ) {
if ( ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
T2[ i49 ]
= T0[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
* T1[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
}
}
}
if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
for(size_t i41 = 0; i41 < 4; ++i41 ) {
T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
= T2[ i41 ]
* T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
for(size_t i50 = 0; i50 < 4; ++i50 ) {
T3[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
= T2[ i50 ]
* T0[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
}
} else {
for(size_t i41 = 0; i41 < 4; ++i41 ) {
if ( ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
= T2[ i41 ]
* T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
for(size_t i50 = 0; i50 < 4; ++i50 ) {
if ( ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
T3[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
= T2[ i50 ]
* T0[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
}
}
}
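Taken together, the `test_gpu.cpp` changes exercise the new end-to-end path: the scheduler records launch-config values on the fusion, and the kernel launcher later evaluates them against the concrete inputs. A condensed sketch of that flow, assuming the `CudaKernel prog`, fusion, and tensor setup used in the surrounding tests:

```cpp
// Condensed sketch of the updated test flow (fusion/prog/tensor setup assumed).
fuser::cuda::scheduleFusion(&fusion, {input1, input2}); // records launch configs
torch::jit::fuser::cuda::compileKernel(&prog);          // lower + compile
prog.device_ = 0;
torch::jit::fuser::cuda::runKernel(                     // evaluates configs, launches
    &prog, {input1, input2}, {output}, output.sizes().vec());
```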
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -335,6 +335,7 @@ libtorch_cuda_sources = [
"torch/csrc/jit/codegen/cuda/dispatch.cpp",
"torch/csrc/jit/codegen/cuda/expr_evaluator.cpp",
"torch/csrc/jit/codegen/cuda/fusion.cpp",
"torch/csrc/jit/codegen/cuda/scheduler.cpp",
"torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
"torch/csrc/jit/codegen/cuda/index_compute.cpp",
"torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
9 changes: 9 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -79,6 +79,8 @@ void swap(Fusion& a, Fusion& b) noexcept {
swap(a.inputs_, b.inputs_);
swap(a.outputs_, b.outputs_);

swap(a.launch_configs_, b.launch_configs_);

// Fixup the Statement::fusion_ links for a
for (auto val : a.val_set_) {
val->fusion_ = &a;
@@ -138,6 +140,11 @@ Fusion::Fusion(const Fusion& other) {

inputs_ = ir_cloner.clone(other.inputs_);
outputs_ = ir_cloner.clone(other.outputs_);

for (const auto& kv : other.launch_configs_) {
auto val = ir_cloner.clone(kv.second);
launch_configs_.insert({kv.first, val});
}
}

Fusion::Fusion(Fusion&& other) noexcept {
@@ -189,6 +196,8 @@ void Fusion::clear() noexcept {

inputs_.clear();
outputs_.clear();

launch_configs_.clear();
}

void Fusion::removeExpr(Expr* expr) {
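These hunks give launch configs the same lifetime semantics as the rest of the fusion state: `swap` exchanges them, the copy constructor clones the underlying `Val`s through `IrCloner`, and `clear()` drops them. A minimal sketch of the observable behavior, assuming the IR-building helpers used in the tests:

```cpp
// Sketch only: launch configs travel with the Fusion they belong to.
Fusion a;
{
  FusionGuard fg(&a);
  // ... build IR here, then record a config (the Int is owned by `a`).
  a.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
}
Fusion b(a);                              // copy ctor clones the config Vals
TORCH_CHECK(!b.launch_configs().empty());
a.clear();                                // clear() resets them on `a` only
TORCH_CHECK(a.launch_configs().empty());
TORCH_CHECK(!b.launch_configs().empty()); // the copy keeps its own clones
```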
22 changes: 22 additions & 0 deletions torch/csrc/jit/codegen/cuda/fusion.h
@@ -241,12 +241,30 @@ class TORCH_CUDA_API Fusion final {
return outputs_;
}

const auto& launch_configs() const {
return launch_configs_;
}

bool hasInput(const Val* val) const;
bool hasOutput(const Val* val) const;

void replaceInput(Val* replace, Val* with);
void replaceOutput(Val* replace, Val* with);

// Set a new launch config value for the given type; returns the previous value.
const Val* setLaunchConfig(LaunchConfigType type, Val* val) {
TORCH_CHECK(val->fusion() == this);
const auto ret = getLaunchConfig(type);
launch_configs_[type] = val;
return ret;
}

// Retrieve the launch config value for the given type; returns nullptr if unset.
Val* getLaunchConfig(LaunchConfigType type) {
const auto it = launch_configs_.find(type);
return it != launch_configs_.end() ? it->second : nullptr;
}

private:
// Return an int that monotonically increases for each val/expr, some are
// explicitly incremented by type.
@@ -281,6 +299,10 @@ class TORCH_CUDA_API Fusion final {
// Fusion inputs and outputs
std::vector<Val*> inputs_;
std::vector<Val*> outputs_;

// Values for the launch configuration:
// compatible flag, BlockDim.x/y/z, GridDim.x/y/z, shared memory.
std::unordered_map<LaunchConfigType, Val*> launch_configs_;
};

} // namespace fuser
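Because a launch config is stored as an IR `Val*`, it can be a literal or a symbolic expression (e.g. an axis extent) that is only resolved at launch time. A minimal usage sketch of the new API, assuming an active fusion and a scheduled `TensorView* tv` (the variable names are illustrative, not part of this diff):

```cpp
// Illustrative sketch of the Fusion launch-config API added above.
FusionGuard fg(&fusion);

// Fixed value: flag the recorded configuration as compatible.
fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));

// Symbolic value: gridDim.x tracks the extent of the BIDx-parallelized axis
// and is only evaluated once concrete input sizes are bound.
fusion.setLaunchConfig(LaunchConfigType::BIDx, tv->axis(0)->rawExtent());

// setLaunchConfig returns the previously bound value (nullptr the first time);
// getLaunchConfig returns nullptr for types that were never set.
const Val* prev = fusion.setLaunchConfig(LaunchConfigType::BIDx, new Int(128));
Val* tidx = fusion.getLaunchConfig(LaunchConfigType::TIDx); // may be nullptr
```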
69 changes: 38 additions & 31 deletions torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -5,6 +5,7 @@
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/kernel.h>
#include <torch/csrc/jit/codegen/cuda/kernel_arg.h>
#include <torch/csrc/jit/codegen/cuda/kernel_resource_strings.h>
@@ -544,31 +545,6 @@ void runKernel(
// Naive launch config;
const size_t numel = outputs[0].numel();

int blocks = 1;
int thread_x = 1;
int thread_y = 1;
if (!entry->reduction_axes_.empty()) {
// TODO: MAJOR HACK! Expr evaluation makes launch configuration much easier
blocks = numel;
// Translated to `fcd_reduction`
if (entry->reduction_axes_.back() ==
outputs[0].dim() + ((int)entry->reduction_axes_.size()) - 1) {
thread_x = kFcdReductionThreadX;
thread_y = 1;
} else {
thread_x = kNonFcdReductionThreadX;
thread_y = kNonFcdReductionThreadY;
}
} else {
// TODO: we can't randomly clap down this until we got striding.
blocks = ceilDiv(numel, kPwThreadX * entry->unroll_factor_);
thread_x = kPwThreadX;
thread_y = 1;
}
const auto nBlocks = blocks;
const auto nThreadx = thread_x;
const auto nThready = thread_y;

KernelArgumentHolder kernel_args;

// Naive I/O setup, I'm ignoring all the potential transformation (i.e. I/O
@@ -586,10 +562,41 @@
kernel_args.push(output);
}

Fusion* fusion = entry->fusion_.get();
FusionGuard fg(fusion);
EvaluationContext eval_context(fusion);
for (int i = 0; i < (int)inputs.size(); i++) {
if (inputs[i].isTensor()) {
ExtractSizeStride ess(inputs[i].toTensor(), broadcasted_shape);
int nDims = ess.sizes.size();
TensorView* tv = fusion->inputs()[i]->as<TensorView>();
for (int j = 0; j < nDims; j++) {
eval_context.bind(tv->getRootDomain()[j]->extent(), ess.sizes[j]);
}
}
}

auto expr_eval_fn = [&](LaunchConfigType type) {
const auto val = ExpressionEvaluator::evaluate(
fusion->getLaunchConfig(type), &eval_context);
TORCH_CHECK(
val.has_value(), "scheduler didn't bind launch configs properly");
return val.value();
};

const int nBlocks_x = expr_eval_fn(LaunchConfigType::BIDx);
const int nBlocks_y = expr_eval_fn(LaunchConfigType::BIDy);
const int nBlocks_z = expr_eval_fn(LaunchConfigType::BIDz);
const auto nThreadx = expr_eval_fn(LaunchConfigType::TIDx);
const auto nThready = expr_eval_fn(LaunchConfigType::TIDy);
const auto nThreadz = expr_eval_fn(LaunchConfigType::TIDz);
const auto shared_memory = expr_eval_fn(LaunchConfigType::SharedMemory);

// TODO: this probably won't work for us.
if (entry->has_random_) {
std::pair<uint64_t, uint64_t> philox_engine_inputs;
const auto rand_offset = 4 * (std::ceil(numel / (4.0 * 128 * nBlocks)) + 1);
const auto rand_offset =
4 * (std::ceil(numel / (4.0 * 128 * nBlocks_x)) + 1);
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
{
// See Note [Acquire lock when using random generators]
Expand All @@ -605,13 +612,13 @@ void runKernel(
// launch kernel;
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
entry->function_,
nBlocks,
1,
1,
nBlocks_x,
nBlocks_y,
nBlocks_z,
nThreadx,
nThready,
1,
0,
nThreadz,
shared_memory,
stream,
kernel_args.getBuffer(),
nullptr));
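With the naive block/thread computation removed, `runKernel` derives every launch dimension by evaluating the fusion's recorded launch configs against the actual input sizes. The sketch below condenses the hunk above (`entry`, `inputs`, and `broadcasted_shape` as in `runKernel`); the design point is that the scheduler decides the launch configuration once, symbolically, and the launcher only evaluates it per run:

```cpp
// Condensed sketch of the launch-config evaluation now done in runKernel().
Fusion* fusion = entry->fusion_.get();
FusionGuard fg(fusion);
EvaluationContext eval_context(fusion);

// Bind each tensor input's root-domain extents to its concrete sizes.
for (size_t i = 0; i < inputs.size(); i++) {
  if (inputs[i].isTensor()) {
    ExtractSizeStride ess(inputs[i].toTensor(), broadcasted_shape);
    auto* tv = fusion->inputs()[i]->as<TensorView>();
    for (size_t j = 0; j < ess.sizes.size(); j++) {
      eval_context.bind(tv->getRootDomain()[j]->extent(), ess.sizes[j]);
    }
  }
}

// Each launch dimension (BIDx/y/z, TIDx/y/z, shared memory) is then just an
// expression evaluation; the scheduler must have bound all of them.
auto resolve = [&](LaunchConfigType type) {
  const auto v = ExpressionEvaluator::evaluate(
      fusion->getLaunchConfig(type), &eval_context);
  TORCH_CHECK(v.has_value(), "scheduler didn't bind launch configs properly");
  return v.value();
};
const int grid_x = resolve(LaunchConfigType::BIDx);   // and BIDy/BIDz ...
const int block_x = resolve(LaunchConfigType::TIDx);  // and TIDy/TIDz ...
```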
3 changes: 3 additions & 0 deletions torch/csrc/jit/codegen/cuda/manager.cpp
@@ -1,7 +1,9 @@
#include <torch/csrc/jit/codegen/cuda/manager.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/kernel_cache.h>
#include <torch/csrc/jit/codegen/cuda/parser.h>
#include <torch/csrc/jit/codegen/cuda/scheduler.h>
#include <torch/csrc/jit/codegen/cuda/shape_inference.h>
#include <torch/csrc/jit/codegen/cuda/utils.h>
#include <torch/csrc/jit/passes/canonicalize.h>
Expand Down Expand Up @@ -103,6 +105,7 @@ class CudaFusionManager {
// 1. device;
// 2. launch config;
parseJitIR(graph, cuda_kernel.value());
scheduleFusion(cuda_kernel.value()->fusion_.get(), inputs);

// find device in inputs.
for (const auto& input : inputs) {