diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 2e36a97e92b3..3450eea1bb6e 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -452,6 +452,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/dispatch.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/expr_evaluator.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/fusion.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/scheduler.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/graph_fuser.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/index_compute.cpp
     ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_base_nodes.cpp
diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 0edcac55a3df..0082e234ab46 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -370,6 +371,8 @@ void testGPU_FusionClear() {
     tv3->axis(0)->parallelize(ParallelType::BIDx);
     tv2->axis(1)->parallelize(ParallelType::Unroll);
     tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+    fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
   }
 
   // 2. Clear the IR
@@ -382,6 +385,8 @@ void testGPU_FusionClear() {
   TORCH_CHECK(fusion.inputs().empty());
   TORCH_CHECK(fusion.outputs().empty());
 
+  TORCH_CHECK(fusion.launch_configs().empty());
+
   TORCH_CHECK(!fusion.hasReduction());
   TORCH_CHECK(!fusion.hasBlockReduction());
   TORCH_CHECK(!fusion.hasGridReduction());
@@ -415,8 +420,11 @@ void testGPU_FusionClear() {
   at::Tensor input2 = at::randn_like(input1);
   at::Tensor output = at::empty_like(input1);
 
+  fuser::cuda::scheduleFusion(&fusion, {input1, input2});
   torch::jit::fuser::cuda::compileKernel(&prog);
-  torch::jit::fuser::cuda::runTestKernel(&prog, {input1, input2}, {output});
+  prog.device_ = 0;
+  torch::jit::fuser::cuda::runKernel(
+      &prog, {input1, input2}, {output}, output.sizes().vec());
 
   at::Tensor tv2_ref = input2 + 2.0;
   at::Tensor output_ref = input1 + tv2_ref;
@@ -449,6 +457,10 @@ void testGPU_FusionCopy() {
 
     tv3->axis(0)->parallelize(ParallelType::BIDx);
     tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+    original_fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
+    original_fusion.setLaunchConfig(
+        LaunchConfigType::BIDx, tv3->axis(0)->rawExtent());
   }
 
   // Test copy before lowering
@@ -519,6 +531,9 @@ void testGPU_FusionMove() {
 
     tv3->axis(0)->parallelize(ParallelType::BIDx);
     tv3->axis(-1)->parallelize(ParallelType::TIDx);
+
+    fusion.setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
+    fusion.setLaunchConfig(LaunchConfigType::BIDx, tv3->axis(0)->rawExtent());
   }
 
   std::stringstream original_ir;
@@ -1010,6 +1025,10 @@ void testGPU_FusionParser() {
   prog.block(32);
   prog.device_ = 0;
   fuser::cuda::parseJitIR(g, &prog);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input1 = at::randn({16}, options);
+  at::Tensor input2 = at::randn({16}, options);
+  fuser::cuda::scheduleFusion(prog.fusion_.get(), {input1, input2});
 
   // CONSIDER:
   // 1. this can be moved to a dedicated "golden" file
@@ -1018,32 +1037,32 @@ void testGPU_FusionParser() {
 __global__ void CUDAGeneratedKernel(Tensor T0, Tensor T1, Tensor T3){
   float T2[4];
   if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
-    for(size_t i40 = 0; i40 < 4; ++i40 ) {
-      T2[ i40 ]
-         = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
-         * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
+    for(size_t i49 = 0; i49 < 4; ++i49 ) {
+      T2[ i49 ]
+         = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
+         * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
     }
   } else {
-    for(size_t i40 = 0; i40 < 4; ++i40 ) {
-      if ( ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
-        T2[ i40 ]
-           = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
-           * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i40 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
+    for(size_t i49 = 0; i49 < 4; ++i49 ) {
+      if ( ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
+        T2[ i49 ]
+           = T0[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ]
+           * T1[ ( ( ( ( ( blockIdx.x * 4 ) + i49 ) * 128 ) + threadIdx.x ) * T1.stride[0] ) ];
       }
     }
   }
   if ( ( ( ( ( ( blockIdx.x * 4 ) + ( 4 - 1 ) ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
-    for(size_t i41 = 0; i41 < 4; ++i41 ) {
-      T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
-         = T2[ i41 ]
-         * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
+    for(size_t i50 = 0; i50 < 4; ++i50 ) {
+      T3[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
+         = T2[ i50 ]
+         * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
     }
   } else {
-    for(size_t i41 = 0; i41 < 4; ++i41 ) {
-      if ( ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
-        T3[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
-           = T2[ i41 ]
-           * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i41 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
+    for(size_t i50 = 0; i50 < 4; ++i50 ) {
+      if ( ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) < T3.size[0] ) ) {
+        T3[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T3.stride[0] ) ]
+           = T2[ i50 ]
+           * T0[ ( ( ( ( ( blockIdx.x * 4 ) + i50 ) * 128 ) + threadIdx.x ) * T0.stride[0] ) ];
       }
     }
   }
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index b194e4f16928..bb047cfdb20b 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -335,6 +335,7 @@ libtorch_cuda_sources = [
    "torch/csrc/jit/codegen/cuda/dispatch.cpp",
    "torch/csrc/jit/codegen/cuda/expr_evaluator.cpp",
    "torch/csrc/jit/codegen/cuda/fusion.cpp",
+    "torch/csrc/jit/codegen/cuda/scheduler.cpp",
    "torch/csrc/jit/codegen/cuda/graph_fuser.cpp",
    "torch/csrc/jit/codegen/cuda/index_compute.cpp",
    "torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index 141c12d1d7b1..8a0af3be5ae3 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -79,6 +79,8 @@ void swap(Fusion& a, Fusion& b) noexcept {
   swap(a.inputs_, b.inputs_);
   swap(a.outputs_, b.outputs_);
 
+  swap(a.launch_configs_, b.launch_configs_);
+
   // Fixup the Statement::fusion_ links for a
   for (auto val : a.val_set_) {
     val->fusion_ = &a;
@@ -138,6 +140,11 @@ Fusion::Fusion(const Fusion& other) {
 
   inputs_ = ir_cloner.clone(other.inputs_);
   outputs_ = ir_cloner.clone(other.outputs_);
+
+  for (const auto& kv : other.launch_configs_) {
+    auto val = ir_cloner.clone(kv.second);
+    launch_configs_.insert({kv.first, val});
+  }
 }
 
 Fusion::Fusion(Fusion&& other) noexcept {
@@ -189,6 +196,8 @@ void Fusion::clear() noexcept {
 
   inputs_.clear();
   outputs_.clear();
+
+  launch_configs_.clear();
 }
 
 void Fusion::removeExpr(Expr* expr) {
diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h
index 684dc915585e..7333c1003c8f 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.h
+++ b/torch/csrc/jit/codegen/cuda/fusion.h
@@ -241,12 +241,30 @@ class TORCH_CUDA_API Fusion final {
     return outputs_;
   }
 
+  const auto& launch_configs() const {
+    return launch_configs_;
+  }
+
   bool hasInput(const Val* val) const;
   bool hasOutput(const Val* val) const;
 
   void replaceInput(Val* replace, Val* with);
   void replaceOutput(Val* replace, Val* with);
 
+  // set new launch config value for given type; return the previous value;
+  const Val* setLaunchConfig(LaunchConfigType type, Val* val) {
+    TORCH_CHECK(val->fusion() == this);
+    const auto ret = getLaunchConfig(type);
+    launch_configs_[type] = val;
+    return ret;
+  }
+
+  // retrieve launch config value for given type;
+  Val* getLaunchConfig(LaunchConfigType type) {
+    const auto it = launch_configs_.find(type);
+    return it != launch_configs_.end() ? it->second : nullptr;
+  }
+
  private:
   // Return an int that monotonically increases for each val/expr, some are
   // explicitly incremented by type.
@@ -281,6 +299,10 @@ class TORCH_CUDA_API Fusion final {
   // Fusion inputs and outputs
   std::vector<Val*> inputs_;
   std::vector<Val*> outputs_;
+
+  // values for launch configuration:
+  // compatible_flag, BlockDim.x/y/z, GridDim.x/y/z, shared_memory,
+  std::unordered_map<LaunchConfigType, Val*> launch_configs_;
 };
 
 } // namespace fuser
diff --git a/torch/csrc/jit/codegen/cuda/kernel.cpp b/torch/csrc/jit/codegen/cuda/kernel.cpp
index 6d7a09163360..24763f5bef19 100644
--- a/torch/csrc/jit/codegen/cuda/kernel.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel.cpp
@@ -5,6 +5,7 @@
 #include
 
 #include
+#include
 #include
 #include
 #include
@@ -544,31 +545,6 @@ void runKernel(
   // Naive launch config;
   const size_t numel = outputs[0].numel();
 
-  int blocks = 1;
-  int thread_x = 1;
-  int thread_y = 1;
-  if (!entry->reduction_axes_.empty()) {
-    // TODO: MAJOR HACK! Expr evaluation makes launch configuration much easier
-    blocks = numel;
-    // Translated to `fcd_reduction`
-    if (entry->reduction_axes_.back() ==
-        outputs[0].dim() + ((int)entry->reduction_axes_.size()) - 1) {
-      thread_x = kFcdReductionThreadX;
-      thread_y = 1;
-    } else {
-      thread_x = kNonFcdReductionThreadX;
-      thread_y = kNonFcdReductionThreadY;
-    }
-  } else {
-    // TODO: we can't randomly clap down this until we got striding.
-    blocks = ceilDiv(numel, kPwThreadX * entry->unroll_factor_);
-    thread_x = kPwThreadX;
-    thread_y = 1;
-  }
-  const auto nBlocks = blocks;
-  const auto nThreadx = thread_x;
-  const auto nThready = thread_y;
-
   KernelArgumentHolder kernel_args;
 
   // Naive I/O setup, I'm ignoring all the potential transformation (i.e. I/O
@@ -586,10 +562,41 @@ void runKernel(
     kernel_args.push(output);
   }
 
+  Fusion* fusion = entry->fusion_.get();
+  FusionGuard fg(fusion);
+  EvaluationContext eval_context(fusion);
+  for (int i = 0; i < (int)inputs.size(); i++) {
+    if (inputs[i].isTensor()) {
+      ExtractSizeStride ess(inputs[i].toTensor(), broadcasted_shape);
+      int nDims = ess.sizes.size();
+      TensorView* tv = fusion->inputs()[i]->as<TensorView>();
+      for (int j = 0; j < nDims; j++) {
+        eval_context.bind(tv->getRootDomain()[j]->extent(), ess.sizes[j]);
+      }
+    }
+  }
+
+  auto expr_eval_fn = [&](LaunchConfigType type) {
+    const auto val = ExpressionEvaluator::evaluate(
+        fusion->getLaunchConfig(type), &eval_context);
+    TORCH_CHECK(
+        val.has_value(), "scheduler didn't bind launch configs properly");
+    return val.value();
+  };
+
+  const int nBlocks_x = expr_eval_fn(LaunchConfigType::BIDx);
+  const int nBlocks_y = expr_eval_fn(LaunchConfigType::BIDy);
+  const int nBlocks_z = expr_eval_fn(LaunchConfigType::BIDz);
+  const auto nThreadx = expr_eval_fn(LaunchConfigType::TIDx);
+  const auto nThready = expr_eval_fn(LaunchConfigType::TIDy);
+  const auto nThreadz = expr_eval_fn(LaunchConfigType::TIDz);
+  const auto shared_memory = expr_eval_fn(LaunchConfigType::SharedMemory);
+
   // TODO: this probably won't work for us.
   if (entry->has_random_) {
     std::pair<uint64_t, uint64_t> philox_engine_inputs;
-    const auto rand_offset = 4 * (std::ceil(numel / (4.0 * 128 * nBlocks)) + 1);
+    const auto rand_offset =
+        4 * (std::ceil(numel / (4.0 * 128 * nBlocks_x)) + 1);
     auto gen = at::cuda::detail::getDefaultCUDAGenerator();
     {
       // See Note [Acquire lock when using random generators]
@@ -605,13 +612,13 @@ void runKernel(
   // launch kernel;
   AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel(
       entry->function_,
-      nBlocks,
-      1,
-      1,
+      nBlocks_x,
+      nBlocks_y,
+      nBlocks_z,
       nThreadx,
       nThready,
-      1,
-      0,
+      nThreadz,
+      shared_memory,
      stream,
       kernel_args.getBuffer(),
       nullptr));
diff --git a/torch/csrc/jit/codegen/cuda/manager.cpp b/torch/csrc/jit/codegen/cuda/manager.cpp
index e8b029cd4835..888df5afcc87 100644
--- a/torch/csrc/jit/codegen/cuda/manager.cpp
+++ b/torch/csrc/jit/codegen/cuda/manager.cpp
@@ -1,7 +1,9 @@
 #include
 #include
+#include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -103,6 +105,7 @@ class CudaFusionManager {
     //  1. device;
     //  2. launch config;
     parseJitIR(graph, cuda_kernel.value());
+    scheduleFusion(cuda_kernel.value()->fusion_.get(), inputs);
 
     // find device in inputs.
     for (const auto& input : inputs) {
diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp
index e266f0a94097..23d1339a0b5d 100644
--- a/torch/csrc/jit/codegen/cuda/parser.cpp
+++ b/torch/csrc/jit/codegen/cuda/parser.cpp
@@ -86,9 +86,6 @@ class IrParser {
     MergeQueryFuncPtr merge_f_;
   };
 
- private:
-  static const int unroll_factor = 4;
-
  public:
   IrParser(std::shared_ptr<Graph> graph, CudaKernel* cuda_kernel)
       : graph_(std::move(graph)), cuda_kernel_(cuda_kernel) {
@@ -173,7 +170,6 @@ class IrParser {
     for (auto jit_output : block->outputs()) {
       TensorView* out =
           static_cast<TensorView*>(value_map_[jit_output->unique()]);
-
       // demote output dtype to be match PyTorch JIT graph.
       auto tensor_type = jit_output->type()->cast<TensorType>();
       TORCH_INTERNAL_ASSERT(
@@ -182,133 +178,7 @@ class IrParser {
         // No need to update value_map_ after this point.
         out = static_cast<TensorView*>(castOp(DataType::Half, out));
       }
-
       cuda_kernel_->fusion_->addOutput(out);
-
-      // TODO: has_reduction for scheudling should be done on a per output
-      // tensor basis.
-      if (has_reduction) {
-        // TODO: this scheduling only works for a single reduction operation in
-        //       the fusion, in this case we can coalesc all reduction axes and
-        //       merge them together. (same applies to iteration axes)
-        // TODO: does this work for multiple outputs?
-
-        // query if fastest changing dimension (FCD) is a reduction
-        fcd_reduction = out->axis((int)out->nDims() - 1)->isReduction();
-
-        // TODO: could really use evaluation here. Launch configuration is
-        //       imposed by transformation and the information should be
-        //       embedded in codegen IR.
-        cuda_kernel_->reduction_axes_ = reductionAxes(out);
-
-        // We coalesc all reduction axes to the right;
-        size_t num_reduction_axes = coalescReduction(out);
-
-        // Merge all iteration dimensions
-        while (out->nDims() > num_reduction_axes + 1) {
-          out->merge(0, 1);
-        }
-        // Merge all reduction dimensions
-        while (out->nDims() > 2) {
-          out->merge(1, 2);
-        }
-
-      } else {
-        // Merge all dimensions because we're only supporting pointwise
-        while (out->nDims() > 1)
-          out->merge(0, 1);
-        // Split into 128 which will be bockDim.x
-        out->split(0, kPwThreadX);
-        // Split by another 4 which will be our unroll factor
-        auto ur_factor = disable_unroll ? 1 : unroll_factor;
-        if (!disable_unroll) {
-          out->split(0, ur_factor);
-          cuda_kernel_->unroll_factor_ = ur_factor;
-        }
-      }
-    }
-
-    if (has_reduction) {
-      // Run through outputs, grab all inputs of outputs
-      // squeeze with computeAt to set overall structure.
-      for (auto output : cuda_kernel_->fusion_->outputs()) {
-        if (output->getValType() != ValType::TensorView)
-          continue;
-        TensorView* out_tv = static_cast<TensorView*>(output);
-
-        // fcd_reduction could be queried later via
-        // cuda_kernel_->reduction_axes_, which would ensure we have proper
-        // launch configuratoin.
-        TensorView* intermediate;
-        if (fcd_reduction) {
-          out_tv->split(-1, kFcdReductionThreadX);
-          // necessary to avoid dynamic allocation on intermediates;
-          intermediate = out_tv->rFactor({-2});
-        } else {
-          // TODO: we don't need a full warp here, this should be determined by
-          //       element data type
-          out_tv->split(0, kNonFcdReductionThreadX);
-          out_tv->split(
-              -1, kNonFcdReductionThreadY); // necessary to avoid dynamic
-                                            // allocation on intermediates;
-          intermediate = out_tv->rFactor({-2});
-        }
-        for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
-          // scheduling of inputs shouldn't change with different fcd_reduction
-          if (inp->getValType().value() == ValType::TensorView) {
-            static_cast<TensorView*>(inp)->computeAt(intermediate, -1);
-          }
-        }
-        // scheduling of inputs shouldn't change with different fcd_reduction
-        intermediate->computeAt(out_tv, -2);
-        if (fcd_reduction) {
-          out_tv->axis(0)->parallelize(ParallelType::BIDx);
-        } else {
-          out_tv->axis(0)->parallelize(ParallelType::BIDx);
-          out_tv->axis(1)->parallelize(ParallelType::TIDx);
-        }
-      }
-      // Run through all values, unroll, and bind their axes
-      for (auto val : cuda_kernel_->fusion_->vals()) {
-        if (val->getValType().value() != ValType::TensorView)
-          continue;
-        TensorView* tv = static_cast<TensorView*>(val);
-        if (fcd_reduction) {
-          tv->axis(-1)->parallelize(ParallelType::TIDx);
-        } else {
-          tv->axis(-1)->parallelize(ParallelType::TIDy);
-        }
-      }
-    } else {
-      // Run through outputs, grab all inputs of outputs
-      // squeeze with computeAt to set overall structure.
-      for (auto output : cuda_kernel_->fusion_->outputs()) {
-        if (output->getValType() != ValType::TensorView)
-          continue;
-        TensorView* out_tv = static_cast<TensorView*>(output);
-        for (Val* inp : cuda_kernel_->fusion_->inputsOf(output)) {
-          if (inp->getValType().value() == ValType::TensorView)
-            static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
-        }
-        out_tv->axis(0)->parallelize(ParallelType::BIDx);
-      }
-
-      // Run through all values, unroll, and bind their axes
-      for (auto val : cuda_kernel_->fusion_->vals()) {
-        if (val->getValType().value() != ValType::TensorView)
-          continue;
-        TensorView* tv = static_cast<TensorView*>(val);
-
-        // Should be true for all intermediates, but if one isn't hooked
-        // up right, skip it and hope for the best for now
-        if (!disable_unroll && tv->nDims() == 3) {
-          tv->axis(-2)->parallelize(ParallelType::Unroll);
-          tv->axis(-1)->parallelize(ParallelType::TIDx);
-        } else {
-          if (tv->nDims() == 2)
-            tv->axis(-1)->parallelize(ParallelType::TIDx);
-        }
-      }
     }
   }
diff --git a/torch/csrc/jit/codegen/cuda/scheduler.cpp b/torch/csrc/jit/codegen/cuda/scheduler.cpp
new file mode 100644
index 000000000000..3064bc6461f9
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/scheduler.cpp
@@ -0,0 +1,243 @@
+#include
+
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+constexpr int kUnrollFactor = 4;
+
+namespace {
+
+std::vector<int> reductionAxes(TensorView* tv) {
+  size_t n_dims = tv->nDims();
+  std::vector<int> reduction_axes;
+  for (size_t i = 0; i < n_dims; i++) {
+    if (tv->axis(i)->isReduction()) {
+      reduction_axes.emplace_back(i);
+    }
+  }
+  return reduction_axes;
+}
+
+// coalesces all reduction axes to the right side and returns the total number
+// of reduction axes
+size_t coalescReduction(TensorView* tv) {
+  auto reduction_axes = reductionAxes(tv);
+  size_t n_dims = tv->nDims();
+  std::unordered_map<int, int> coalesc_permute;
+  for (size_t i = 0; i < reduction_axes.size(); i++) {
+    size_t new_pos = i + n_dims - reduction_axes.size();
+    if (new_pos == reduction_axes[i]) {
+      break;
+    } else {
+      coalesc_permute[reduction_axes[i]] = new_pos;
+    }
+  }
+  if (!coalesc_permute.empty()) {
+    tv->reorder(coalesc_permute);
+  }
+  return reduction_axes.size();
+}
+
+} // namespace
+
+// This one is a total mess and it should go.
+bool scheduleFusion(Fusion* fusion, const at::ArrayRef<c10::IValue> inputs) {
+  FusionGuard fg(fusion);
+  // maybe has_reduction for scheduling should be done on a per output tensor
+  // basis.
+  const bool has_reduction = fusion->hasReduction();
+  const bool disable_unroll = fusion->hasRNG();
+  bool fcd_reduction = false;
+
+  for (auto out_val : fusion->outputs()) {
+    auto out = out_val->as<TensorView>();
+    if (has_reduction) {
+      // TODO: this scheduling only works for a single reduction operation in
+      //       the fusion; in this case we can coalesce all reduction axes and
+      //       merge them together. (same applies to iteration axes)
+      // TODO: does this work for multiple outputs?
+
+      // query if fastest changing dimension (FCD) is a reduction
+      fcd_reduction = out->axis((int)out->nDims() - 1)->isReduction();
+
+      // TODO: could really use evaluation here. Launch configuration is
+      //       imposed by transformation and the information should be
+      //       embedded in codegen IR.
+      // cuda_kernel_->reduction_axes_ = reductionAxes(out);
+
+      // We coalesce all reduction axes to the right;
+      size_t num_reduction_axes = coalescReduction(out);
+
+      // Merge all iteration dimensions
+      while (out->nDims() > num_reduction_axes + 1) {
+        out->merge(0, 1);
+      }
+      // Merge all reduction dimensions
+      while (out->nDims() > 2) {
+        out->merge(1, 2);
+      }
+    } else {
+      // Merge all dimensions because we're only supporting pointwise
+      while (out->nDims() > 1)
+        out->merge(0, 1);
+    }
+  }
+
+  if (has_reduction) {
+    // Run through outputs, grab all inputs of outputs
+    // squeeze with computeAt to set overall structure.
+    for (auto output : fusion->outputs()) {
+      if (output->getValType() != ValType::TensorView)
+        continue;
+      TensorView* out_tv = static_cast<TensorView*>(output);
+
+      // fcd_reduction could be queried later via
+      // cuda_kernel_->reduction_axes_, which would ensure we have proper
+      // launch configuration.
+      TensorView* intermediate = nullptr;
+      if (fcd_reduction) {
+        out_tv->split(-1, kFcdReductionThreadX);
+        // necessary to avoid dynamic allocation on intermediates;
+        intermediate = out_tv->rFactor({-2});
+      } else {
+        // TODO: we don't need a full warp here, this should be determined by
+        //       element data type
+        out_tv->split(0, kNonFcdReductionThreadX);
+        out_tv->split(
+            -1, kNonFcdReductionThreadY); // necessary to avoid dynamic
+                                          // allocation on intermediates;
+        intermediate = out_tv->rFactor({-2});
+      }
+      for (Val* inp : fusion->inputsOf(output)) {
+        // scheduling of inputs shouldn't change with different fcd_reduction
+        if (inp->getValType().value() == ValType::TensorView) {
+          static_cast<TensorView*>(inp)->computeAt(intermediate, -1);
+        }
+      }
+      // scheduling of inputs shouldn't change with different fcd_reduction
+      intermediate->computeAt(out_tv, -2);
+      if (fcd_reduction) {
+        out_tv->axis(0)->parallelize(ParallelType::BIDx);
+      } else {
+        out_tv->axis(0)->parallelize(ParallelType::BIDx);
+        out_tv->axis(1)->parallelize(ParallelType::TIDx);
+      }
+    }
+    // Run through all values, unroll, and bind their axes
+    for (auto val : fusion->vals()) {
+      if (val->getValType().value() != ValType::TensorView)
+        continue;
+      TensorView* tv = static_cast<TensorView*>(val);
+      if (fcd_reduction) {
+        tv->axis(-1)->parallelize(ParallelType::TIDx);
+      } else {
+        tv->axis(-1)->parallelize(ParallelType::TIDy);
+      }
+    }
+
+    TensorView* out0 = fusion->outputs()[0]->as<TensorView>();
+    int ndim = (int)out0->nDims();
+    Val* numel = new Int(1);
+    for (int i = 0; i < ndim; i++) {
+      if (out0->axis(i)->isBlockDim()) {
+        numel = mul(numel, out0->axis(i)->rawExtent());
+      }
+    }
+    if (fcd_reduction) {
+      // assuming all outputs have the same shape;
+      fusion->setLaunchConfig(
+          LaunchConfigType::TIDx, new Int(kFcdReductionThreadX));
+      fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1));
+      fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
+      fusion->setLaunchConfig(LaunchConfigType::BIDx, numel);
+      fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1));
+      fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
+    } else {
+      fusion->setLaunchConfig(
+          LaunchConfigType::TIDx, new Int(kNonFcdReductionThreadX));
+      fusion->setLaunchConfig(
+          LaunchConfigType::TIDy, new Int(kNonFcdReductionThreadY));
+      fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
+      fusion->setLaunchConfig(LaunchConfigType::BIDx, numel);
+      fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1));
+      fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
+    }
+    fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
+    fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
+  } else {
+    // Run through outputs, grab all inputs of outputs
+    // squeeze with computeAt to set overall structure.
+    for (auto output : fusion->outputs()) {
+      if (output->getValType() != ValType::TensorView)
+        continue;
+      TensorView* out_tv = static_cast<TensorView*>(output);
+
+      // Split into 128 which will be blockDim.x
+      out_tv->split(0, kPwThreadX);
+      // Split by another 4 which will be our unroll factor
+      auto ur_factor = disable_unroll ? 1 : kUnrollFactor;
+      if (!disable_unroll) {
+        out_tv->split(0, ur_factor);
+      }
+    }
+
+    for (auto output : fusion->outputs()) {
+      if (output->getValType() != ValType::TensorView)
+        continue;
+      TensorView* out_tv = static_cast<TensorView*>(output);
+      for (Val* inp : fusion->inputsOf(output)) {
+        if (inp->getValType().value() == ValType::TensorView)
+          static_cast<TensorView*>(inp)->computeAt(out_tv, 1);
+      }
+      out_tv->axis(0)->parallelize(ParallelType::BIDx);
+    }
+
+    // Run through all values, unroll, and bind their axes
+    for (auto val : fusion->vals()) {
+      if (val->getValType().value() != ValType::TensorView)
+        continue;
+      TensorView* tv = static_cast<TensorView*>(val);
+
+      // Should be true for all intermediates, but if one isn't hooked
+      // up right, skip it and hope for the best for now
+      if (!disable_unroll && tv->nDims() == 3) {
+        tv->axis(-2)->parallelize(ParallelType::Unroll);
+        tv->axis(-1)->parallelize(ParallelType::TIDx);
+      } else {
+        if (tv->nDims() == 2)
+          tv->axis(-1)->parallelize(ParallelType::TIDx);
+      }
+    }
+    TensorView* out0 = fusion->outputs()[0]->as<TensorView>();
+    int ndim = (int)out0->nDims();
+    Val* numel = new Int(1);
+    for (int i = 0; i < ndim; i++) {
+      if (out0->axis(i)->isBlockDim()) {
+        numel = mul(numel, out0->axis(i)->rawExtent());
+      }
+    }
+    Val* tid_x = new Int(kPwThreadX);
+    Val* bid_x = numel;
+    fusion->setLaunchConfig(LaunchConfigType::TIDx, tid_x);
+    fusion->setLaunchConfig(LaunchConfigType::TIDy, new Int(1));
+    fusion->setLaunchConfig(LaunchConfigType::TIDz, new Int(1));
+    fusion->setLaunchConfig(LaunchConfigType::BIDx, bid_x);
+    fusion->setLaunchConfig(LaunchConfigType::BIDy, new Int(1));
+    fusion->setLaunchConfig(LaunchConfigType::BIDz, new Int(1));
+    fusion->setLaunchConfig(LaunchConfigType::Compatible, new Int(1));
+    fusion->setLaunchConfig(LaunchConfigType::SharedMemory, new Int(0));
+  }
+  return true;
+}
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/scheduler.h b/torch/csrc/jit/codegen/cuda/scheduler.h
new file mode 100644
index 000000000000..89263e34df8e
--- /dev/null
+++ b/torch/csrc/jit/codegen/cuda/scheduler.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+// returns whether the given fusion could be scheduled;
+TORCH_CUDA_API bool scheduleFusion(
+    Fusion* fusion,
+    const at::ArrayRef<c10::IValue> inputs);
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/type.h b/torch/csrc/jit/codegen/cuda/type.h
index afabfa967220..f70ac2f8191f 100644
--- a/torch/csrc/jit/codegen/cuda/type.h
+++ b/torch/csrc/jit/codegen/cuda/type.h
@@ -147,6 +147,17 @@ TORCH_CUDA_API c10::optional<std::string> cast_func_str(
 
 size_t dataTypeSize(DataType type);
 
+enum class LaunchConfigType {
+  Compatible,
+  SharedMemory,
+  BIDz,
+  BIDy,
+  BIDx,
+  TIDz,
+  TIDy,
+  TIDx
+};
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
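
For orientation, the pieces above compose as follows: scheduleFusion() picks a schedule from the concrete inputs and records TIDx/y/z, BIDx/y/z, SharedMemory and the Compatible flag on the Fusion, and runKernel() later evaluates those recorded expressions (via ExpressionEvaluator) to build the actual launch. The sketch below mirrors the updated testGPU_FusionClear path; it is an illustration only, not code added by this patch, and it assumes a CudaKernel-style "prog" object with the fusion_ and device_ members used by the tests, plus a kernel.h header declaring compileKernel/runKernel as they are used in test_gpu.cpp.

  // Illustrative sketch only; CudaKernel's exact type/namespace and the
  // kernel.h header name are assumptions based on the tests in this patch.
  #include <ATen/ATen.h>
  #include <torch/csrc/jit/codegen/cuda/fusion.h>
  #include <torch/csrc/jit/codegen/cuda/kernel.h>
  #include <torch/csrc/jit/codegen/cuda/scheduler.h>

  using namespace torch::jit::fuser;

  void compileAndRun(
      cuda::CudaKernel& prog,
      at::Tensor input1,
      at::Tensor input2,
      at::Tensor output) {
    Fusion* fusion = prog.fusion_.get();

    // Bind the launch configuration (TIDx/y/z, BIDx/y/z, SharedMemory,
    // Compatible) on the fusion, based on the concrete inputs.
    cuda::scheduleFusion(fusion, {input1, input2});

    // Compile, then launch; runKernel() evaluates the stored launch-config
    // expressions instead of recomputing a naive configuration.
    cuda::compileKernel(&prog);
    prog.device_ = 0;
    cuda::runKernel(&prog, {input1, input2}, {output}, output.sizes().vec());

    // The recorded values can also be inspected directly.
    Val* bid_x = fusion->getLaunchConfig(LaunchConfigType::BIDx);
    TORCH_CHECK(bid_x != nullptr, "scheduler should have bound BIDx");
  }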