From 5ed6ec1312a4254f3431463513ab1b996822ba3e Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 18 Jul 2020 15:09:46 -0400 Subject: [PATCH 1/3] Finish converting to FusionExecutor exclusively. --- caffe2/CMakeLists.txt | 1 - test/cpp/jit/test_gpu.cpp | 125 ++--- test/test_jit_cuda_fuser.py | 3 +- tools/build_variables.bzl | 1 - torch/csrc/jit/codegen/cuda/executor.cpp | 7 +- torch/csrc/jit/codegen/cuda/executor.h | 5 +- .../jit/codegen/cuda/executor_kernel_arg.cpp | 18 - .../jit/codegen/cuda/executor_kernel_arg.h | 6 - torch/csrc/jit/codegen/cuda/fusion.cpp | 15 - torch/csrc/jit/codegen/cuda/fusion.h | 27 - torch/csrc/jit/codegen/cuda/kernel.cpp | 510 ------------------ torch/csrc/jit/codegen/cuda/kernel.h | 52 -- torch/csrc/jit/codegen/cuda/kernel_cache.cpp | 47 +- torch/csrc/jit/codegen/cuda/kernel_cache.h | 115 +--- torch/csrc/jit/codegen/cuda/manager.cpp | 123 ++--- torch/csrc/jit/codegen/cuda/parser.h | 1 - torch/csrc/jit/codegen/cuda/scheduler.cpp | 48 +- torch/csrc/jit/codegen/cuda/scheduler.h | 1 - 18 files changed, 122 insertions(+), 983 deletions(-) delete mode 100644 torch/csrc/jit/codegen/cuda/kernel.cpp delete mode 100644 torch/csrc/jit/codegen/cuda/kernel.h diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 9772f5877a88..43402f265e16 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -464,7 +464,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_nodes.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_index.cpp ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp index c32e6e7c1d70..70a1c28f6aee 100644 --- a/test/cpp/jit/test_gpu.cpp +++ b/test/cpp/jit/test_gpu.cpp @@ -1,4 +1,4 @@ -#if defined(USE_CUDA) +// #if defined(USE_CUDA) #include @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -21,6 +20,9 @@ #include #include "torch/csrc/jit/ir/irparser.h" +#include +#include + #include // Tests go in torch::jit @@ -49,26 +51,6 @@ void checkIntValue( TORCH_CHECK(actual_value.value() == expected_value); } -void setupLaunchConfig( - Fusion* fusion, - int tid_x = 1, - int tid_y = 1, - int tid_z = 1, - int gid_x = 1, - int gid_y = 1, - int gid_z = 1, - int sm_size = 0) { - fusion->setLaunchConfig(LaunchConfigType::TIDx, new Int(tid_x)); - fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(tid_y)); - fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(tid_z)); - fusion->setLaunchConfig(LaunchConfigType::BIDx, new Int(gid_x)); - fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(gid_y)); - fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(gid_z)); - fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(sm_size)); - // no need to set LaunchConfigType::Compatible, as we are not trigger via - // kernel selection in cache -} - } // namespace // 1. Test cases are void() functions. @@ -399,8 +381,6 @@ void testGPU_FusionClear() { tv3->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(1)->parallelize(ParallelType::Unroll); tv3->axis(-1)->parallelize(ParallelType::TIDx); - - fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); } // 2. 
Clear the IR @@ -413,8 +393,6 @@ void testGPU_FusionClear() { TORCH_CHECK(fusion.inputs().empty()); TORCH_CHECK(fusion.outputs().empty()); - TORCH_CHECK(fusion.launch_configs().empty()); - TORCH_CHECK(!fusion.hasReduction()); TORCH_CHECK(!fusion.hasBlockReduction()); TORCH_CHECK(!fusion.hasGridReduction()); @@ -553,9 +531,6 @@ void testGPU_FusionMove() { tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(-1)->parallelize(ParallelType::TIDx); - - fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); - fusion.setLaunchConfig(LaunchConfigType::BIDx, tv3->axis(0)->rawExtent()); } std::stringstream original_ir; @@ -1038,15 +1013,12 @@ void testGPU_FusionParser() { } } - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(fuser::cuda::parseJitIR(g)); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); - prog.setDevice(0); + auto fusion = fuser::cuda::parseJitIR(g); + FusionGuard fg(fusion.get()); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input1 = at::randn({16}, options); at::Tensor input2 = at::randn({16}, options); - fuser::cuda::scheduleFusion(prog.fusion(), {input1, input2}); + fuser::cuda::scheduleFusion(fusion.get(), {input1, input2}); // CONSIDER: // 1. this can be moved to a dedicated "golden" file @@ -1055,57 +1027,55 @@ void testGPU_FusionParser() { __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){ float T2[4]; if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - for(size_t i27 = 0; i27 < 4; ++i27 ) { - T2[ i27 ] - = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] - * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + for(size_t i20 = 0; i20 < 4; ++i20 ) { + T2[ i20 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i20 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i20 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; } } else { - for(size_t i27 = 0; i27 < 4; ++i27 ) { - if ( ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - T2[ i27 ] - = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] - * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i27 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; + for(size_t i20 = 0; i20 < 4; ++i20 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) + i20 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T2[ i20 ] + = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i20 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ] + * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i20 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ]; } } } if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - for(size_t i28 = 0; i28 < 4; ++i28 ) { - T3[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] - = T2[ i28 ] - * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + for(size_t i21 = 0; i21 < 4; ++i21 ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i21 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i21 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i21 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; } } else { - for(size_t i28 = 0; i28 < 4; ++i28 ) { - if ( ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { - T3[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] - = T2[ i28 ] - * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i28 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; + for(size_t i21 = 0; i21 < 4; ++i21 ) { + if ( ( ( ( ( ( blockIdx.x * 4 ) 
+ i21 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) { + T3[ ( ( ( ( ( blockIdx.x * 4 ) + i21 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ] + = T2[ i21 ] + * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i21 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]; } } } } )"; - GPULower gpulw(fusion); - std::stringstream actual_kernel; - actual_kernel << "\n"; - gpulw.printKernel(actual_kernel); - if (expected_kernel.size() != actual_kernel.str().size() || - expected_kernel.compare(actual_kernel.str()) != 0) { + std::string actual_kernel = GPULower(fusion.get()).getKernel(); + actual_kernel = "\n" + actual_kernel; + if (expected_kernel.size() != actual_kernel.size() || + expected_kernel.compare(actual_kernel) != 0) { std::cerr << " Codegen mismatch, codegen possibly changed, or is incorrect. " << " \n ========= EXPECTED ========= \n" << expected_kernel << "\n========= ACTUAL ========== \n" - << actual_kernel.str() << "\n=================" << std::endl; + << actual_kernel << "\n=================" << std::endl; TORCH_CHECK(false); } - fuser::cuda::compileKernel(&prog); - at::Tensor output = at::empty_like(input1); + cuda::FusionExecutor fe; + fe.compileFusion(fusion.get()); // no broadcasting needed, omitting the last optional argument; - torch::jit::fuser::cuda::runKernel(&prog, {input1, input2}, {output}); + auto outputs = fe.runFusion({input1, input2}); at::Tensor output_ref = input1 * input2 * input1; - TORCH_CHECK(output_ref.equal(output)); + TORCH_CHECK(output_ref.equal(outputs[0])); } void testGPU_FusionForLoop() { @@ -3988,18 +3958,16 @@ void testGPU_FusionReductionScheduler() { constexpr int tid_x = 4096; constexpr int red_dim = 1; - torch::jit::fuser::cuda::CudaKernel prog; - prog.setFusionPtr(std::make_unique()); - Fusion* fusion = prog.fusion(); - FusionGuard fg(fusion); + Fusion fusion; + FusionGuard fg(&fusion); // Set up your input tensor views TensorView* tv0 = makeDummyTensor(2); - fusion->addInput(tv0); + fusion.addInput(tv0); TensorView* tv1 = reductionOp(BinaryOpType::Add, {red_dim}, new Float(0), tv0); - fusion->addOutput(tv1); + fusion.addOutput(tv1); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor input = at::rand({bid_x, tid_x}, options); @@ -4009,20 +3977,19 @@ void testGPU_FusionReductionScheduler() { const at::ArrayRef inputs({input}); TORCH_CHECK( - cuda::scheduleReduction(prog.fusion(), inputs), + cuda::scheduleReduction(&fusion, inputs), "Reduction schedule was not generated!"); - prog.setDevice(0); - - torch::jit::fuser::cuda::compileKernel(&prog); - torch::jit::fuser::cuda::runKernel(&prog, {input}, {cg_output}, c10::nullopt); - + cuda::FusionExecutor fe; + fe.compileFusion(&fusion); + // no broadcasting needed, omitting the last optional argument; + auto outputs = fe.runFusion({input}); auto aten_output = input.sum({red_dim}); TORCH_CHECK( - aten_output.allclose(cg_output), + aten_output.allclose(outputs[0]), "Error of: ", - aten_output.sub(cg_output).abs().max()); + aten_output.sub(outputs[0]).abs().max()); } // Simple reduction parallelized on a symbolic size. 
@@ -4079,4 +4046,4 @@ void testGPU_FusionSymbolicReduction() { } // namespace jit } // namespace torch -#endif // #if defined(USE_CUDA) +// #endif // #if defined(USE_CUDA) diff --git a/test/test_jit_cuda_fuser.py b/test/test_jit_cuda_fuser.py index 4fe73677c1ea..5338e0ad705f 100644 --- a/test/test_jit_cuda_fuser.py +++ b/test/test_jit_cuda_fuser.py @@ -167,7 +167,7 @@ def t(x: torch.Tensor, y: torch.Tensor, z: float): self.assertEqual(o, jit_o) self.assertGraphContains(t_jit.graph_for(x, y, 2.0), FUSION_GROUP) - @unittest.skipIf(True, "real broadcast with different output not supported yet") + @unittest.skipIf(True, "Broadcast with different output not supported yet") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") @@ -189,6 +189,7 @@ def t(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): # Currently cannot fuse this self.assertGraphContains(t_jit.graph_for(x, y, z), FUSION_GROUP) + @unittest.skipIf(True, "Broadcast with different output not supported yet") @unittest.skipIf(not RUN_CUDA, "requires CUDA") @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING and GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Requires fusion optimization pass to be effective") diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 766cb9524a96..2c04573bb797 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -347,7 +347,6 @@ libtorch_cuda_sources = [ "torch/csrc/jit/codegen/cuda/ir_nodes.cpp", "torch/csrc/jit/codegen/cuda/ir_iostream.cpp", "torch/csrc/jit/codegen/cuda/iter_visitor.cpp", - "torch/csrc/jit/codegen/cuda/kernel.cpp", "torch/csrc/jit/codegen/cuda/kernel_cache.cpp", "torch/csrc/jit/codegen/cuda/lower_index.cpp", "torch/csrc/jit/codegen/cuda/lower_loops.cpp", diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 4088fe7dcb54..2d807b1441ff 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -1,11 +1,13 @@ #include #include #include -#include #include +#include #include +#include +#include namespace torch { namespace jit { @@ -30,7 +32,7 @@ std::string FusionExecutor::getStructuredCode(std::string kernel) { return code; } -void FusionExecutor::compileFusion(Fusion* fusion) { +void FusionExecutor::compileFusion(Fusion* fusion, CompileOptions options) { TORCH_INTERNAL_ASSERT( !fusion->outputs().empty(), "No output found for this kernel, aborting."); @@ -41,6 +43,7 @@ void FusionExecutor::compileFusion(Fusion* fusion) { fusion_ = *fusion; FusionGuard fg(&fusion_); + options_ = options; fusion_id = ++fusion_id_counter; has_random = fusion->hasRNG(); diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h index 51260cf4f37a..7d988a07a1ac 100644 --- a/torch/csrc/jit/codegen/cuda/executor.h +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -6,7 +6,6 @@ #include #include #include -#include #include #include @@ -16,6 +15,7 @@ namespace jit { namespace fuser { namespace cuda { +// TODO: Should this actually be in launch params? 
struct TORCH_CUDA_API CompileOptions { c10::Device device = c10::Device(c10::DeviceType::CUDA, 0); }; @@ -23,9 +23,8 @@ struct TORCH_CUDA_API CompileOptions { class TORCH_CUDA_API FusionExecutor { public: FusionExecutor() {} - FusionExecutor(CompileOptions options) : options_(options) {} - void compileFusion(Fusion* fusion); + void compileFusion(Fusion* fusion, CompileOptions options = CompileOptions()); std::vector runFusion( const at::ArrayRef& inputs, diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp index c46ecd80f48e..6bbe48cb2a46 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.cpp @@ -46,24 +46,6 @@ void KernelArgumentHolder::push(const at::Tensor& tensor) { arguments_.push_back(std::move(tensor_arg)); } -// Push a tensor to the arguments -void KernelArgumentHolder::push( - const at::Tensor& val, - c10::optional broadcasted_size) { - changed_ = true; - ExtractSizeStride ess(val, std::move(broadcasted_size)); - int nDims = ess.sizes.size(); - - c10::ScalarType dtype = val.scalar_type(); - std::unique_ptr tensor_arg = getTensorArg(dtype, nDims); - tensor_arg->setPointer(val.data_ptr()); - for (int i = 0; i < nDims; i++) { - tensor_arg->setSize(i, ess.sizes[i]); - tensor_arg->setStride(i, ess.strides[i]); - } - arguments_.push_back(std::move(tensor_arg)); -} - // Push a scalar or integer to the arguments void KernelArgumentHolder::push(const IValue& val) { changed_ = true; diff --git a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h index 3ae9e3c55414..84755b85bf06 100644 --- a/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h +++ b/torch/csrc/jit/codegen/cuda/executor_kernel_arg.h @@ -143,12 +143,6 @@ class KernelArgumentHolder { // Push a tensor to the arguments void push(const at::Tensor& tensor); - // We want to get rid of this version, it's a hack for now because we don't - // have great broadcast support for translation. 
- void push( - const at::Tensor& tensor, - c10::optional broadcasted_size); - // Push a scalar or integer to the arguments void push(const IValue& val); diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp index bf0f488e6e94..d83e005b5ba8 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.cpp +++ b/torch/csrc/jit/codegen/cuda/fusion.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include namespace torch { @@ -17,11 +16,6 @@ FusionGuard::FusionGuard(Fusion* fusion) { ACTIVE_FUSION = fusion; } -FusionGuard::FusionGuard(cuda::CudaKernel* cuda_kernel) { - prev_fusion = ACTIVE_FUSION; - ACTIVE_FUSION = cuda_kernel->fusion(); -} - FusionGuard::~FusionGuard() { ACTIVE_FUSION = prev_fusion; } @@ -79,8 +73,6 @@ void swap(Fusion& a, Fusion& b) noexcept { swap(a.inputs_, b.inputs_); swap(a.outputs_, b.outputs_); - swap(a.launch_configs_, b.launch_configs_); - // Fixup the Statement::fusion_ links for a for (auto val : a.val_set_) { val->fusion_ = &a; @@ -140,11 +132,6 @@ Fusion::Fusion(const Fusion& other) { inputs_ = ir_cloner.clone(other.inputs_); outputs_ = ir_cloner.clone(other.outputs_); - - for (const auto& kv : other.launch_configs_) { - auto val = ir_cloner.clone(kv.second); - launch_configs_.insert({kv.first, val}); - } } Fusion::Fusion(Fusion&& other) noexcept { @@ -196,8 +183,6 @@ void Fusion::clear() noexcept { inputs_.clear(); outputs_.clear(); - - launch_configs_.clear(); } void Fusion::removeExpr(Expr* expr) { diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h index e00b8937c701..f882390b8584 100644 --- a/torch/csrc/jit/codegen/cuda/fusion.h +++ b/torch/csrc/jit/codegen/cuda/fusion.h @@ -51,10 +51,6 @@ struct TypeHash { class Fusion; class TensorView; -namespace cuda { -class CudaKernel; -} - // Fusion Guard is our "context manager". It holds the actrive fusion and allows // it to be accessed anywhere through FusionGuard::getCurFusion(). class TORCH_CUDA_API FusionGuard { @@ -63,7 +59,6 @@ class TORCH_CUDA_API FusionGuard { // Set the active fusion so it can be manipulated. explicit FusionGuard(Fusion* fusion); - explicit FusionGuard(cuda::CudaKernel* cuda_kernel); ~FusionGuard(); @@ -241,30 +236,12 @@ class TORCH_CUDA_API Fusion final { return outputs_; } - const auto& launch_configs() const { - return launch_configs_; - } - bool hasInput(const Val* val) const; bool hasOutput(const Val* val) const; void replaceInput(Val* replace, Val* with); void replaceOutput(Val* replace, Val* with); - // set new launch config value for given type; return the previous value; - const Val* setLaunchConfig(LaunchConfigType type, Val* val) { - TORCH_CHECK(val->fusion() == this); - const auto ret = getLaunchConfig(type); - launch_configs_[type] = val; - return ret; - } - - // retrieve launch config value for given type; - Val* getLaunchConfig(LaunchConfigType type) { - const auto it = launch_configs_.find(type); - return it != launch_configs_.end() ? it->second : nullptr; - } - private: // Return an int that monotonically increases for each val/expr, some are // explicitly incremented by type. 
@@ -299,10 +276,6 @@ class TORCH_CUDA_API Fusion final { // Fusion inputs and outputs std::vector inputs_; std::vector outputs_; - - // values for launch configuration: - // compatible_flag, BlockDim.x/y/z, GridDim.x/y/z, shared_memory, - std::unordered_map launch_configs_; }; } // namespace fuser diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp deleted file mode 100644 index 3843a967c45a..000000000000 --- a/torch/csrc/jit/codegen/cuda/kernel.cpp +++ /dev/null @@ -1,510 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -constexpr auto kCgNamespace = "CudaCodeGen"; -constexpr auto kKernelName = "kernel"; - -namespace { - -// See NOTE [ USE OF NVRTC AND DRIVER API ] -const at::cuda::NVRTC& nvrtc() { - return at::globalContext().getNVRTC(); -} - -int ceilDiv(const int a, const int b) { - return (a + b - 1) / b; -} - -std::pair codeGeneration(Fusion* fusion) { - std::stringstream str_stream; - str_stream << "namespace " << kCgNamespace << " {\n" - << code_template_tensor_struct << "\n" - << code_fp16_support << "\n" - << code_random_number_gen << "\n" - << code_helper_funcs << "\n" - << code_template_block_reduction << "\n" - << code_template_grid_reduction << "\n" - << code_template_block_broadcast << "\n"; - std::stringstream cdg; - GPULower gpulw(fusion); - gpulw.printKernel(str_stream, kKernelName); - str_stream << "\n} // namespace"; - - std::string func_name = std::string(kCgNamespace) + "::" + kKernelName; - return std::make_pair(func_name, str_stream.str()); -} - -bool validateKernelArgTensor( - const at::Tensor& arg, - const Val* param, - int device_index, - std::stringstream& msg) { - // Arg is a tensor. Param must be a tensor too. - if (*param->getValType() != ValType::TensorView) { - msg << "Argument is a tensor, but the parameter is not."; - return false; - } - - // Check the rank of the tensors. - size_t arg_dim = arg.dim(); - // Note: This requires current Fusion to be active. - size_t param_dim = TensorDomain::noReductions( - static_cast(param)->getRootDomain()) - .size(); - // see [Note - broadcast support in integration] - // Because of broadcasting support handled in integration, we relax the rank - // check as necessary. 
- if (arg_dim > param_dim) { - msg << "Argument tensor's rank is " << arg_dim << ", but the parameter is " - << param_dim; - return false; - } - - if (arg.device().index() != device_index) { - msg << "Argument is on device that is not compiled for"; - return false; - } - // Check element type - at::ScalarType arg_data_type = arg.scalar_type(); - DataType param_data_type = *param->getDataType(); - bool match = false; - switch (arg_data_type) { - case at::ScalarType::Half: - match = param_data_type == DataType::Half; - break; - case at::ScalarType::Float: - match = param_data_type == DataType::Float; - break; - case at::ScalarType::Bool: - match = param_data_type == DataType::Bool; - break; - default: - msg << "Argument element type, " << arg_data_type - << ", is not supported."; - return false; - } - if (!match) - msg << "Argument element type is " << arg_data_type - << ", but the parameter is " << param_data_type; - return match; -} - -bool validateKernelArgScalar( - const c10::TypePtr& arg_type, - const Val* param, - std::stringstream& msg) { - if (!param->isScalar()) { - msg << "Argument is a scalar, but the parameter is not."; - return false; - } - DataType param_type = *param->getDataType(); - bool match = false; - switch (arg_type->kind()) { - case c10::TypeKind::IntType: - match = param_type == DataType::Int; - break; - case c10::TypeKind::FloatType: - match = param_type == DataType::Float; - break; - case c10::TypeKind::BoolType: - match = param_type == DataType::Bool; - break; - default: - match = false; - } - if (!match) { - msg << "Argument type is " << *arg_type << ", but the parameter is " - << param_type; - } - return match; -} - -bool validateKernelArg( - const c10::IValue& arg, - const Val* param, - int device_index, - std::stringstream& msg) { - if (arg.type()->kind() != c10::TypeKind::TensorType) { - return validateKernelArgScalar(arg.type(), param, msg); - } else { - return validateKernelArgTensor(arg.toTensor(), param, device_index, msg); - } -} - -void validateKernelArgs( - CudaKernel* entry, - const at::ArrayRef& inputs, - const std::vector& outputs) { - // This is necessary as we were traversing the fusion graph later in the check - FusionGuard fg(entry); - // Check inputs - TORCH_INTERNAL_ASSERT( - inputs.size() == entry->fusion()->inputs().size(), - "Wrong number of kernel inputs."); - for (size_t i = 0; i < inputs.size(); ++i) { - const IValue& arg = inputs[i]; - const Val* param = entry->fusion()->inputs()[i]; - std::stringstream msg; - TORCH_INTERNAL_ASSERT( - validateKernelArg(arg, param, entry->device(), msg), - "Input argument at position ", - i, - " is invalid; ", - msg.str()); - } - - TORCH_INTERNAL_ASSERT( - entry->fusion()->outputs().size() != 0, - "Kernel should have at least one output tensor."); - - TORCH_INTERNAL_ASSERT( - outputs.size() == entry->fusion()->outputs().size(), - "Wrong number of kernel outputs."); - for (size_t i = 0; i < outputs.size(); ++i) { - const at::Tensor& arg = outputs[i]; - const Val* param = entry->fusion()->outputs()[i]; - std::stringstream msg; - TORCH_INTERNAL_ASSERT( - validateKernelArgTensor(arg, param, entry->device(), msg), - "Output argument at position ", - i, - " is invalid; ", - msg.str()); - } -} - -size_t size(const dim3& d) { - return (size_t)d.x * (size_t)d.y * (size_t)d.z; -} - -dim3 dimensionOfReductionBlock( - const dim3& block_dim, - bool x_thread, - bool y_thread, - bool z_thread) { - return dim3{x_thread ? block_dim.x : 1, - y_thread ? block_dim.y : 1, - z_thread ? 
block_dim.z : 1}; -} - -int sizeOfReductionBlock( - const dim3& block_dim, - bool x_thread, - bool y_thread, - bool z_thread) { - return size( - dimensionOfReductionBlock(block_dim, x_thread, y_thread, z_thread)); -} - -// Returns the total number of reduction segments. -size_t numberOfReductionSegments( - const dim3& grid_dim, - bool x_block, - bool y_block, - bool z_block) { - return (x_block ? 1 : grid_dim.x) * (y_block ? 1 : grid_dim.y) * - (z_block ? 1 : grid_dim.z); -} - -std::array gridReductionTempBufferSizes( - CudaKernel* entry, - const dim3& grid_dim, - const dim3& block_dim) { - size_t buffer_size = 0; - size_t sync_flag_size = 0; - for (auto expr : entry->fusion()->exprs(true)) { - if (expr->getExprType() != ExprType::ReductionOp) - continue; - ReductionOp* rop = static_cast(expr); - auto domains = rop->getParallelReductionDomains(); - bool x_block = domains.find(ParallelType::BIDx) != domains.end(); - bool y_block = domains.find(ParallelType::BIDy) != domains.end(); - bool z_block = domains.find(ParallelType::BIDz) != domains.end(); - // No buffer needed unless it's a grid reduction - if (!x_block && !y_block && !z_block) - continue; - // Assumption here is that reduction along the block-parallel - // domains is done prior to this grid reduction, so those domains - // do not need to participate in the grid reductions - bool x_thread = domains.find(ParallelType::TIDx) == domains.end(); - bool y_thread = domains.find(ParallelType::TIDy) == domains.end(); - bool z_thread = domains.find(ParallelType::TIDz) == domains.end(); - auto rb_size = - sizeOfReductionBlock(block_dim, x_thread, y_thread, z_thread); - auto num_blocks = size(grid_dim); - auto element_size = dataTypeSize(*(rop->out()->getDataType())); - auto required_temp_buffer_size = num_blocks * rb_size * element_size; - buffer_size = std::max(buffer_size, required_temp_buffer_size); - auto flag_size = sizeof(unsigned) * - numberOfReductionSegments(grid_dim, x_block, y_block, z_block); - sync_flag_size = std::max(sync_flag_size, flag_size); - } - return {{buffer_size, sync_flag_size}}; -} - -} // namespace - -void compileKernel(CudaKernel* entry) { - // generating cuda code; - std::string code; - std::string func_name; - std::tie(func_name, code) = codeGeneration(entry->fusion()); - - static int32_t compiled_kernel_id = 0; - // We increment the id here instead of at the end of the function to avoid - // error during jit-compilation that would make debug message confusing. 
- compiled_kernel_id++; - const char* debug_env = getenv("PYTORCH_CUDA_FUSER_DEBUG"); - if (debug_env && atoi(debug_env)) { - std::cout << "\n==== codegen output for kernel: " << compiled_kernel_id - << " ====" << std::endl - << code << std::endl - << "====================================" << std::endl; - } - - // vvv NVRTC COMPILATION vvv - - // lazily construct context if non-existing yet; - CUcontext pctx = nullptr; - AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); - if (!pctx) { - std::unique_lock cudaFreeMutexLock( - *(c10::cuda::CUDACachingAllocator::getFreeMutex())); - cudaFree(nullptr); - } - - // set device for the operation; - at::cuda::set_device(entry->device()); - - const auto prop = at::cuda::getCurrentDeviceProperties(); - int nvrtc_major, nvrtc_minor; - AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor)); - - // Short-circuits if NVRTC version too low - TORCH_INTERNAL_ASSERT(nvrtc_major >= 6); - // Major and minor is determined by device properties and - // possibly "downcompiled" to a lower (compatible) compute architecture - // based on the NVRTC version - int major, minor; - major = prop->major; - minor = prop->minor; - nvrtcProgram program; - AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( - &program, code.c_str(), nullptr, 0, nullptr, nullptr)); - ResourceGuard holdProgram( - [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); }); - - const std::string compute = "--gpu-architecture=compute_" + - std::to_string(major) + std::to_string(minor); - const std::vector args = { - "--std=c++14", compute.c_str(), "-default-device"}; - - nvrtc().nvrtcAddNameExpression(program, func_name.c_str()); - const auto result = - nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); - if (result != NVRTC_SUCCESS) { - size_t logsize; - nvrtc().nvrtcGetProgramLogSize(program, &logsize); - std::vector log(logsize); - nvrtc().nvrtcGetProgramLog(program, log.data()); - - TORCH_INTERNAL_ASSERT( - false, code.c_str(), "\nCUDA NVRTC compile error: ", log.data()); - } - const char* lowered_kernel_name; - nvrtc().nvrtcGetLoweredName(program, func_name.c_str(), &lowered_kernel_name); - - AT_CUDA_NVRTC_CHECK(result); - size_t ptx_size; - AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size)); - std::vector ptx; - ptx.resize(ptx_size); - AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data())); - - // TODO: We do go through different code path, should investigate whether this - // has an impact on generated binary. 
- const char* prefix_env = getenv("PYTORCH_CUDA_FUSER_CUBIN"); - if (prefix_env) { - // Output ptx file - std::stringstream ptx_file_name; - ptx_file_name << prefix_env << "_" << compiled_kernel_id << ".ptx"; - std::ofstream myPtxFile(ptx_file_name.str().c_str(), std::ios::out); - if (myPtxFile.is_open()) { - myPtxFile.write(ptx.data(), ptx.size()); - myPtxFile.close(); - } - - CUlinkState linkState; - - AT_CUDA_DRIVER_CHECK(nvrtc().cuLinkCreate(0, nullptr, nullptr, &linkState)); - AT_CUDA_DRIVER_CHECK(nvrtc().cuLinkAddData( - linkState, - CU_JIT_INPUT_PTX, - ptx.data(), - ptx_size, - "compiling PTX", - 0, - nullptr, - nullptr)); - size_t cubinSize; - void* cubin; - AT_CUDA_DRIVER_CHECK(nvrtc().cuLinkComplete(linkState, &cubin, &cubinSize)); - - // Output binary file - std::stringstream cubin_file_name; - cubin_file_name << prefix_env << "_" << compiled_kernel_id << ".cubin"; - std::ofstream myCubinFile( - cubin_file_name.str().c_str(), std::ios::out | std::ios::binary); - if (myCubinFile.is_open()) { - myCubinFile.write(static_cast(cubin), cubinSize); - myCubinFile.close(); - } - - // load compiled cubin - AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(entry->module(), cubin)); - } else { - // load ptx directly - AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(entry->module(), ptx.data())); - } - AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleGetFunction( - entry->function(), *entry->module(), lowered_kernel_name)); -} - -void runKernel( - CudaKernel* entry, - const at::ArrayRef inputs, - const std::vector& outputs, - const c10::optional& broadcasted_size) { - validateKernelArgs(entry, inputs, outputs); - - const auto prior_device = at::cuda::current_device(); - at::cuda::set_device(entry->device()); - auto stream = at::cuda::getCurrentCUDAStream(); - - TORCH_INTERNAL_ASSERT(!outputs.empty(), "No outputs set for test kernel."); - const size_t numel = outputs[0].numel(); - - KernelArgumentHolder kernel_args; - - // Naive I/O setup, I'm ignoring all the potential transformation (i.e. I/O - // allocated here from the subgraph could be, and very likely are, different - // from I/O expected by the generated CUDA kernel. 
- for (auto& input : inputs) { - if (input.isTensor()) { - kernel_args.push(input.toTensor(), broadcasted_size); - } else { - kernel_args.push(input); - } - } - - for (auto& output : outputs) { - kernel_args.push(output); - } - - Fusion* fusion = entry->fusion(); - FusionGuard fg(fusion); - EvaluationContext eval_context(fusion); - for (int i = 0; i < (int)inputs.size(); i++) { - if (inputs[i].isTensor()) { - ExtractSizeStride ess(inputs[i].toTensor(), broadcasted_size); - int nDims = ess.sizes.size(); - TensorView* tv = fusion->inputs()[i]->as(); - for (int j = 0; j < nDims; j++) { - eval_context.bind(tv->getRootDomain()[j]->extent(), ess.sizes[j]); - } - } - } - - auto expr_eval_fn = [&](LaunchConfigType type) { - const auto val = ExpressionEvaluator::evaluate( - fusion->getLaunchConfig(type), &eval_context); - TORCH_CHECK( - val.has_value(), "scheduler didn't bind launch configs properly"); - return val.value(); - }; - - const int nBlocks_x = expr_eval_fn(LaunchConfigType::BIDx); - const int nBlocks_y = expr_eval_fn(LaunchConfigType::BIDy); - const int nBlocks_z = expr_eval_fn(LaunchConfigType::BIDz); - const auto nThreadx = expr_eval_fn(LaunchConfigType::TIDx); - const auto nThready = expr_eval_fn(LaunchConfigType::TIDy); - const auto nThreadz = expr_eval_fn(LaunchConfigType::TIDz); - const auto shared_memory = expr_eval_fn(LaunchConfigType::SharedMemory); - - // TODO: this probably won't work for us. - if (entry->hasRNG()) { - std::pair philox_engine_inputs; - const auto rand_offset = - 4 * (std::ceil(numel / (4.0 * 128 * nBlocks_x)) + 1); - auto gen = at::cuda::detail::getDefaultCUDAGenerator(); - { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen.mutex()); - philox_engine_inputs = - at::check_generator(gen)->philox_engine_inputs( - rand_offset); - } - kernel_args.push(philox_engine_inputs.first); - kernel_args.push(philox_engine_inputs.second); - } - - dim3 grid_dim(nBlocks_x, nBlocks_y, nBlocks_z); - dim3 block_dim(nThreadx, nThready, nThreadz); - // When the kernel has global reductions, the kernel needs two - // additional temporary buffers, one for intermediate results and - // another for synchronization among thread blocks. 
- if (entry->fusion()->hasGridReduction()) { - auto temp_buf_type = at::kFloat; - auto temp_buf_sizes = - gridReductionTempBufferSizes(entry, grid_dim, block_dim); - auto options = - at::TensorOptions().dtype(temp_buf_type).device(at::kCUDA, 0); - at::Tensor reduction_work_buffer = at::empty( - {(long)(temp_buf_sizes[0] / c10::elementSize(temp_buf_type))}, options); - kernel_args.push(reduction_work_buffer); - at::Tensor sync_flags = at::zeros( - {(long)(temp_buf_sizes[1] / c10::elementSize(temp_buf_type))}, options); - kernel_args.push(sync_flags); - } - - // launch kernel; - AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( - *entry->function(), - nBlocks_x, - nBlocks_y, - nBlocks_z, - nThreadx, - nThready, - nThreadz, - shared_memory, - stream, - kernel_args.getBuffer(), - nullptr)); - - // Resets device (see at::DeviceGuard notes above) - at::cuda::set_device(prior_device); -} - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel.h b/torch/csrc/jit/codegen/cuda/kernel.h deleted file mode 100644 index c6c566d16e15..000000000000 --- a/torch/csrc/jit/codegen/cuda/kernel.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include - -/* - * The exposed APIs in this file is used by manager.h/cpp - * - * code here handles CUDA code generation and execution from Fusion IR. - * NVRTC is used for kernel compilation. CUDA Driver API is used to load and - * execute compiled kernel. - * - * A stringify trick is used to unify the IO data structure for kernel - * execution. We stringify the data structure and assert it direclty in the - * generated CUDA source to avoid runtime search of header files. - * The header file is included twice: one time as a c++ code to allow host code - * to prepare IO data; the other time for stringify. - */ - -namespace torch { -namespace jit { -namespace fuser { -namespace cuda { - -// compile Fusion to CUDA functions: -// 1. JIT compilation via nvrtc to generate CUDA c++ kernel code; -// 2. CUDA Drive API to load CUDA c++ kernel code as function_; -TORCH_CUDA_API void compileKernel(CudaKernel* entry); - -// run loaded kernel through Function. -// inputs/outputs is given in the sense of a PyTorch JIT ir node. This function -// wraps IO data structure for tensors on host. -TORCH_CUDA_API void runKernel( - CudaKernel* entry, - const at::ArrayRef inputs, - const std::vector& outputs, - const c10::optional& broadcasted_size = c10::nullopt); - -// Facility API to run kernel in tests. -TORCH_CUDA_API void runTestKernel( - CudaKernel* entry, - const at::ArrayRef inputs, - const std::vector& outputs); - -} // namespace cuda -} // namespace fuser -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index 4a245313fb3f..88384ac13cd2 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -1,48 +1,21 @@ #include -#include -#include -/* - */ +// TODO: This class is dead at the moment, but we need to figure out a generic +// cacheing system that will suite our needs. 
namespace torch { namespace jit { namespace fuser { namespace cuda { -at::optional CudaKernelCache::getKernelPtr( - const at::ArrayRef inputs) { - for (auto& cuda_kernel : kernels_) { - // bound input sizes - Fusion* fusion = cuda_kernel.fusion(); - FusionGuard fg(fusion); - EvaluationContext eval_context(fusion); - for (int i = 0; i < (int)inputs.size(); i++) { - if (inputs[i].isTensor()) { - ExtractSizeStride ess(inputs[i].toTensor()); - const int nDims = ess.sizes.size(); - TensorView* tv = fusion->inputs()[i]->as(); - for (int j = 0; j < nDims; j++) { - eval_context.bind(tv->getRootDomain()[j]->extent(), ess.sizes[j]); - } - } - } - - const auto val = ExpressionEvaluator::evaluate( - fusion->getLaunchConfig(LaunchConfigType::Compatible), &eval_context); - TORCH_INTERNAL_ASSERT( - val.has_value(), "scheduler didn't bind launch configs properly"); - if (val.value()) { - return &cuda_kernel; - } - } - return at::nullopt; -} - -CudaKernel* CudaKernelCache::allocateKernelInCache( - const at::ArrayRef inputs) { - kernels_.emplace_back(); - return &kernels_.back(); +FusionExecutorCache::FusionExecutorCache( + Fusion* fusion, + CompileOptions options) { + TORCH_INTERNAL_ASSERT( + entry == nullptr, + "At this time FusionExecutorCache only supports one entry."); + entry = new FusionExecutor(); + entry->compileFusion(fusion, options); } } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.h b/torch/csrc/jit/codegen/cuda/kernel_cache.h index 31c5ef52847a..6a73cadce34c 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.h +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.h @@ -1,119 +1,36 @@ #pragma once +#include +#include + #include #include -#include - namespace torch { namespace jit { namespace fuser { namespace cuda { -// Go through a tensor, and grab it's sizes/strides potentially broadcasted -struct ExtractSizeStride { - std::vector sizes; - std::vector strides; +// Given a particular torchscript string produced by std::string +// Graph::toString(bool print_source_locations) const; cache a kernel that can +// run it. Assume contiguity information is included in the string. - // TODO: broadcasted_size should be handled in codegen directly instead of at - // the integration limbo. - explicit ExtractSizeStride( - const at::Tensor& val, - c10::optional broadcasted_size = c10::nullopt) { - if (broadcasted_size) { - // [Note - broadcast support in integration] - // PyTorch follows numpy broadcasting rule. - // (https://numpy.org/doc/stable/user/basics.broadcasting.html) - // - // So in case where the rank of two operators differ, we align them on - // the higher dimensions, hence the offset o_dim-b_dim to the index here. - const int b_dim = static_cast(broadcasted_size->size()); - const int o_dim = static_cast(val.dim()); - TORCH_CHECK(b_dim >= o_dim); - for (int i = 0; i < b_dim; i++) { - sizes.push_back(broadcasted_size->at(i)); - const int index = i + o_dim - b_dim; - if (index < 0) { - strides.push_back(0); - } else if (val.sizes()[index] == sizes[i]) { - strides.push_back(val.strides()[index]); - } else { - TORCH_CHECK( - val.sizes()[index] == 1, - "Not compatible dimension size for broadcast"); - strides.push_back(0); - } - } - } else { - const auto o_dim = val.dim(); - for (decltype(val.dim()) i{0}; i < o_dim; i++) { - sizes.push_back(val.sizes()[i]); - strides.push_back(val.strides()[i]); - } - } - } -}; +// There are many things to figure out with this cacheing class, for now we will +// keep it very simple, and take in functionality as complexity grows. 
-class CudaKernel { +// TODO: Figure out how we want to cache based on heuristics, should probably +// use something similar. Heuristics may also return a LaunchParams object. +// TODO: Validate it is included in the string. +class FusionExecutorCache { public: - void setFusionPtr(std::unique_ptr fusion) { - fusion_ = std::move(fusion); - } - - Fusion* fusion() { - return fusion_.get(); - } - - const Fusion* fusion() const { - return fusion_.get(); - } - - CUmodule* module() { - return &module_; - } - - CUfunction* function() { - return &function_; + FusionExecutor* getExecutor() const { + return entry; } - int16_t device() const { - return device_; - } - - void setDevice(int16_t device) { - device_ = device; - } - - bool hasRNG() const { - if (fusion_) { - FusionGuard fg(fusion_.get()); - return fusion_->hasRNG(); - } - return false; - } - - private: - int16_t device_; - CUmodule module_; - CUfunction function_; - std::unique_ptr fusion_; -}; - -class CudaKernelCache { - public: - CudaKernelCache() = default; - - at::optional getKernelPtr( - const at::ArrayRef inputs); - CudaKernel* allocateKernelInCache(const at::ArrayRef inputs); + FusionExecutorCache(Fusion* fusion, CompileOptions options); private: - // TODO: In theory we should assume contiguity remain constant across runs - // (job for BailOut node from profiling executor). In reality we might - // want to be safe and cache on that as well. - // Assuming constant nDims. Cache of kernels targetting different tensor size; - // We should flatten - std::vector kernels_; + FusionExecutor* entry = nullptr; }; } // namespace cuda diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp index d26039adb749..7dc2d8cf5f47 100644 --- a/torch/csrc/jit/codegen/cuda/manager.cpp +++ b/torch/csrc/jit/codegen/cuda/manager.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -8,26 +8,43 @@ #include #include #include -#include #include +#include -#include #include +#include + +#include + namespace torch { namespace jit { namespace fuser { namespace cuda { namespace { -// CudaFusionManager holds compiled `CudaKernel` and handles all interfacing +c10::Device getDevice(const at::ArrayRef& inputs) { + // find device in inputs. + for (const auto& input : inputs) { + if (input.isTensor()) { + auto dev = input.toTensor().device(); + TORCH_INTERNAL_ASSERT( + dev.is_cuda(), "Could only fuser operations on cuda device"); + return dev; + } + } + TORCH_INTERNAL_ASSERT( + false, "Could not detect device of inputs to a fusion."); +} + +// CudaFusionManager holds a FusionExecutor and handles all interfacing // including compilation and execution. // // We cache two maps here: // a. string of graph -> kernel_id -// b. kernel_id -> CudaKernel +// b. kernel_id -> FusionExecutor // -// This allows CudaKernel reuse across nodes; +// This allows FusionExecutor reuse across nodes; class CudaFusionManager { public: static CudaFusionManager& getManager() { @@ -53,17 +70,7 @@ class CudaFusionManager { // create new graph_cache_ entry; if (graph_cache_.count(repr) == 0) { int32_t kernel_id = getNextUniqueID(); - graph_cache_[repr] = kernel_id; - - // create entry for cached kernel; - // Note: use make_pair instead of uniform initialization list here since - // it doesn't work under some env that we still support. - // eg. 
cuda9.2 + gcc5.4 - kernel_cache_.insert(std::make_pair(kernel_id, CudaKernelCache())); - - // TODO: we should compile here using profiled information: - // size (range) / stride (contiguity) } return graph_cache_[repr]; }; @@ -74,46 +81,33 @@ class CudaFusionManager { const at::ArrayRef inputs, const std::vector& outputs) { std::lock_guard guard(mutex_); - TORCH_CHECK( - kernel_cache_.count(kernel_id) != 0, "kernel id not recognized"); - if (auto cuda_kernel_opt = kernel_cache_[kernel_id].getKernelPtr(inputs)) { + if (kernel_cache_.find(kernel_id) != kernel_cache_.end()) { // TODO: update launch config for specific sizes; // maybe we should store it in CudaKernel and compute it later - runKernel(*cuda_kernel_opt, inputs, outputs); - } else { - // TODO: this should somehow be done after kernel compilation. - // we will want compileKernel to return a heuristic - auto cuda_kernel = kernel_cache_[kernel_id].allocateKernelInCache(inputs); + FusionExecutor* fe = kernel_cache_[kernel_id]; + fe->runFusion(inputs, outputs); + + } else { // lower torch::jit::Graph to torch::jit::fuser::cuda::fusion // TODO: pass contiguity infor as well as size req, so we can apply proper // transform to computation - cuda_kernel->setFusionPtr(parseJitIR(graph)); - TORCH_INTERNAL_ASSERT( - cuda_kernel->fusion() != nullptr, - "parser failed to construct a fusion from PyTorch JIT graph\n"); + auto fusion = parseJitIR(graph); // TODO: update the API to let `scheduleFusion` consume & return a fusion // magic scheduler updates fusion instance via transformation and setup // launch configurations; - scheduleFusion(cuda_kernel->fusion(), inputs); - - // find device in inputs. - for (const auto& input : inputs) { - if (input.isTensor()) { - const auto& device = input.toTensor().device(); - TORCH_INTERNAL_ASSERT( - device.is_cuda(), "Could only fuser operations on cuda device"); - cuda_kernel->setDevice(device.index()); - break; - } - } + Fusion fusion_copy = Fusion(*fusion); + scheduleFusion(&fusion_copy, inputs); - // NVRTC compile kernel - compileKernel(cuda_kernel); + CompileOptions options; + options.device = getDevice(inputs); - runKernel(cuda_kernel, inputs, outputs); + auto fe = new FusionExecutor(); + fe->compileFusion(fusion.get(), options); + kernel_cache_[kernel_id] = fe; + fe->runFusion(inputs, outputs); } } @@ -130,7 +124,7 @@ class CudaFusionManager { }; std::unordered_map graph_cache_; - std::unordered_map kernel_cache_; + std::unordered_map kernel_cache_; int32_t next_unique_id_ = 0; }; @@ -177,46 +171,7 @@ void runCudaFusionGroup(const Node* fusion_node, Stack& stack) { } ShapeTypePropagate(graph); } - /* - // TODO: Delete the shape inference here once we switch to - // ExpressionEvaluator to allocate outputs - // shape inference in graph to allocate outputs - // update shape information per the new inputs; - EraseShapeInformation(shape_inf_graph); - for (size_t i = 0; i < nInputs; i++) { - shape_inf_graph->inputs()[i]->setType(inputs[i].type()); - } - // shape inference - ShapeTypePropagate(shape_inf_graph); - - // we need to construct outputs; - std::vector outputs; - for (const auto* output : shape_inf_graph->outputs()) { - const auto type = output->type()->expect(); - // Expect output to be tensor; - TORCH_CHECK( - type && type->isComplete(), - "Complete TensorType for output is expected."); - - const auto device = *(type->device()); - const auto scalar_type = *(type->scalarType()); - - auto options = at::TensorOptions() - .dtype(scalar_type) - .layout(at::kStrided) - .device(device) - 
.requires_grad(type->requires_grad()); - - // TODO: We should infer output shape from `inputs` - const auto sizes = extractSizes(type); - const auto strides = extractStrides(type); - - const auto tensor = at::empty_strided(sizes, strides, options); - outputs.push_back(tensor); - } - CudaFusionManager::getManager().runFusionNode( - kernel_id, graph, inputs, outputs); - */ + FusionExecutor executor; auto fusion = parseJitIR(graph); scheduleFusion(fusion.get(), inputs); diff --git a/torch/csrc/jit/codegen/cuda/parser.h b/torch/csrc/jit/codegen/cuda/parser.h index 9ff8252d3d94..f83cb8e8808b 100644 --- a/torch/csrc/jit/codegen/cuda/parser.h +++ b/torch/csrc/jit/codegen/cuda/parser.h @@ -4,7 +4,6 @@ #include #include -#include /* * This file handles Parsing PyTorch jit ir; diff --git a/torch/csrc/jit/codegen/cuda/scheduler.cpp b/torch/csrc/jit/codegen/cuda/scheduler.cpp index 1a495f36bfd9..a85c0748604d 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler.cpp @@ -6,6 +6,8 @@ #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -145,27 +147,6 @@ bool scheduleFusion(Fusion* fusion, const at::ArrayRef inputs) { numel = mul(numel, out0->axis(i)->rawExtent()); } } - if (fcd_reduction) { - // assuming all output to be the same shape; - fusion->setLaunchConfig( - LaunchConfigType::TIDx, new Int(kFcdReductionThreadX)); - fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::BIDx, numel); - fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1)); - } else { - fusion->setLaunchConfig( - LaunchConfigType::TIDx, new Int(kNonFcdReductionThreadX)); - fusion->setLaunchConfig( - LaunchConfigType::TIDy, new Int(kNonFcdReductionThreadY)); - fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::BIDx, numel); - fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1)); - } - fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0)); } else { // Run through outputs, grab all inputs of outputs // squeeze with computeAt to set overall structure. 
@@ -219,16 +200,6 @@ bool scheduleFusion(Fusion* fusion, const at::ArrayRef inputs) { numel = mul(numel, out0->axis(i)->rawExtent()); } } - Val* tid_x = new Int(kPwThreadX); - Val* bid_x = numel; - fusion->setLaunchConfig(LaunchConfigType::TIDx, tid_x); - fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::BIDx, bid_x); - fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0)); } return true; } @@ -551,21 +522,6 @@ bool scheduleReduction(Fusion* fusion, const at::ArrayRef inputs) { red_tv->axis(-2)->parallelize(ParallelType::BIDx); } } - - // Communicate Blocking for Kernel Launch - // TODO: This will be replaced in favor of passing blocking - // args in the future - fusion->setLaunchConfig( - LaunchConfigType::TIDx, new Int(rparams.block_dim_x_)); - fusion->setLaunchConfig( - LaunchConfigType::TIDy, new Int(rparams.block_dim_y_)); - fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::BIDx, new Int(rparams.grid_dim_x_)); - fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(rparams.grid_dim_y_)); - fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1)); - fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0)); - fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1)); - return true; } diff --git a/torch/csrc/jit/codegen/cuda/scheduler.h b/torch/csrc/jit/codegen/cuda/scheduler.h index efe199fcf270..ae5c29e57e9e 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler.h +++ b/torch/csrc/jit/codegen/cuda/scheduler.h @@ -2,7 +2,6 @@ #include #include -#include namespace torch { namespace jit { From 5b340ee6ab91cf485e463f20c0c91a65e7b36b7a Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 20 Jul 2020 09:11:38 -0400 Subject: [PATCH 2/3] Minor fixes, clang-tidy re-enable type prop. 
---
 torch/csrc/jit/codegen/cuda/fusion.cpp  | 2 ++
 torch/csrc/jit/codegen/cuda/manager.cpp | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index d83e005b5ba8..9a8db4af747c 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -277,6 +277,7 @@ void Fusion::addOutput(Val* const output) {
 bool Fusion::inFusion(const Statement* stmt) const {
   bool infusion = stmt->fusion() == this;
+  // NOLINT
   Statement* nonconst_stmt = const_cast<Statement*>(stmt);
 
   if (stmt->isExpr())
@@ -471,6 +472,7 @@ Expr* Fusion::origin(Val* val) const {
 const Expr* Fusion::origin(const Val* val) const {
   assertInFusion(val, "Cannot dettect the origin of val, ");
+  // NOLINT
   auto it = origin_.find(const_cast<Val*>(val));
   if (it == origin_.end())
     return nullptr;
diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp
index 7dc2d8cf5f47..876294cb2827 100644
--- a/torch/csrc/jit/codegen/cuda/manager.cpp
+++ b/torch/csrc/jit/codegen/cuda/manager.cpp
@@ -94,6 +94,8 @@ class CudaFusionManager {
       // TODO: pass contiguity infor as well as size req, so we can apply proper
       // transform to computation
       auto fusion = parseJitIR(graph);
+      std::cout << "0" << std::endl;
+      fusion->printMath();
 
       // TODO: update the API to let `scheduleFusion` consume & return a fusion
       // magic scheduler updates fusion instance via transformation and setup
@@ -174,6 +176,7 @@ void runCudaFusionGroup(const Node* fusion_node, Stack& stack) {
 
   FusionExecutor executor;
   auto fusion = parseJitIR(graph);
+  fusion->printMath();
   scheduleFusion(fusion.get(), inputs);
   executor.compileFusion(fusion.get());
   auto outputs = executor.runFusion(inputs);

From efa9a9d0d19e3bd4adae0120902d38503779bf65 Mon Sep 17 00:00:00 2001
From: Christian Sarofeen
Date: Mon, 20 Jul 2020 09:31:40 -0400
Subject: [PATCH 3/3] Try to disable clang-tidy for const_cast.

---
 torch/csrc/jit/codegen/cuda/fusion.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index 9a8db4af747c..90682c298b80 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -277,8 +277,7 @@ void Fusion::addOutput(Val* const output) {
 bool Fusion::inFusion(const Statement* stmt) const {
   bool infusion = stmt->fusion() == this;
-  // NOLINT
-  Statement* nonconst_stmt = const_cast<Statement*>(stmt);
+  Statement* nonconst_stmt = const_cast<Statement*>(stmt); // NOLINT
 
   if (stmt->isExpr())
     infusion &=
@@ -472,8 +471,7 @@ Expr* Fusion::origin(Val* val) const {
 const Expr* Fusion::origin(const Val* val) const {
   assertInFusion(val, "Cannot dettect the origin of val, ");
-  // NOLINT
-  auto it = origin_.find(const_cast<Val*>(val));
+  auto it = origin_.find(const_cast<Val*>(val)); // NOLINT
   if (it == origin_.end())
     return nullptr;
   return it->second;
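
Taken together, the series replaces the CudaKernel / compileKernel / runKernel path with FusionExecutor, and removes the launch-config bookkeeping from Fusion. A minimal sketch of the resulting flow, condensed from the testGPU_FusionReductionScheduler changes in patch 1, is below. It is illustrative only: makeDummyTensor() is a helper local to test_gpu.cpp, header includes are omitted, and any template arguments elided by the diff (for example the element type of the ArrayRef of inputs) are assumptions rather than confirmed signatures.

void reductionExecutorSketch() {
  using namespace torch::jit::fuser;

  Fusion fusion;
  FusionGuard fg(&fusion);

  // Build a small sum-over-dim-1 fusion (makeDummyTensor is a test-only helper).
  TensorView* tv0 = makeDummyTensor(2);
  fusion.addInput(tv0);
  TensorView* tv1 = reductionOp(BinaryOpType::Add, {1}, new Float(0), tv0);
  fusion.addOutput(tv1);

  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input = at::rand({128, 4096}, opts);

  // The scheduler now only transforms the fusion; launch configuration is no
  // longer recorded on the Fusion via setLaunchConfig.
  TORCH_CHECK(cuda::scheduleReduction(&fusion, {input}));

  // FusionExecutor replaces CudaKernel + compileKernel/runKernel and
  // allocates its own outputs.
  cuda::CompileOptions options; // defaults to CUDA device 0
  cuda::FusionExecutor fe;
  fe.compileFusion(&fusion, options);
  auto outputs = fe.runFusion({input});

  TORCH_CHECK(input.sum({1}).allclose(outputs[0]));
}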