diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 25a2e6d8b501f0..a2be85268ecf22 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -50,6 +50,8 @@ FILE(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp")
 FILE(GLOB cuda_cu "cuda/*.cu" "cuda/detail/*.cu")
 FILE(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh")
 FILE(GLOB cudnn_cpp "cudnn/*.cpp")
+FILE(GLOB miopen_h "miopen/*.h")
+FILE(GLOB miopen_cpp "miopen/*.cpp")
 FILE(GLOB mkl_cpp "mkl/*.cpp")
 FILE(GLOB mkldnn_cpp "mkldnn/*.cpp")
@@ -58,6 +60,7 @@ FILE(GLOB native_sparse_cpp "native/sparse/*.cpp")
 FILE(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu")
 FILE(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
 FILE(GLOB native_cudnn_cpp "native/cudnn/*.cpp")
+FILE(GLOB native_miopen_cpp "native/miopen/*.cpp")
 FILE(GLOB native_cuda_cu "native/cuda/*.cu")
 FILE(GLOB native_cuda_cpp "native/cuda/*.cpp")
 FILE(GLOB native_mkl_cpp "native/mkl/*.cpp")
@@ -74,9 +77,14 @@ endif()
 IF(USE_CUDA OR USE_ROCM)
   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/cuda)
   set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${cuda_cu} ${native_cuda_cu} ${native_sparse_cuda_cu})
-  set(all_cuda_cpp ${native_cudnn_cpp} ${native_sparse_cuda_cpp} ${cuda_cpp} ${native_cuda_cpp} ${cuda_generated_cpp} ${ATen_CUDA_SRCS})
-  IF(CUDNN_FOUND)
-    SET(all_cuda_cpp ${all_cuda_cpp} ${cudnn_cpp})
+  set(all_cuda_cpp ${native_sparse_cuda_cpp} ${cuda_cpp} ${native_cuda_cpp} ${cuda_generated_cpp} ${ATen_CUDA_SRCS})
+  IF(USE_CUDA)
+    SET(all_cuda_cpp ${native_cudnn_cpp} ${native_miopen_cpp} ${all_cuda_cpp})
+    IF(CUDNN_FOUND)
+      SET(all_cuda_cpp ${all_cuda_cpp} ${cudnn_cpp})
+    ENDIF()
+  ELSEIF(USE_ROCM)
+    SET(all_cuda_cpp ${native_cudnn_cpp} ${native_miopen_cpp} ${miopen_cpp} ${all_cuda_cpp})
   ENDIF()
 endif()
diff --git a/aten/src/ATen/cuda/CUDAConfig.h.in b/aten/src/ATen/cuda/CUDAConfig.h.in
index 72adee50cf84fb..9e4b3d35c09bec 100644
--- a/aten/src/ATen/cuda/CUDAConfig.h.in
+++ b/aten/src/ATen/cuda/CUDAConfig.h.in
@@ -5,3 +5,4 @@
 // c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined
 #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
+#define AT_MIOPEN_ENABLED() @AT_MIOPEN_ENABLED@
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
index 238362f90e1969..e1008342940aad 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -119,6 +119,10 @@ bool CUDAHooks::compiledWithCuDNN() const {
   return AT_CUDNN_ENABLED();
 }
 
+bool CUDAHooks::compiledWithMIOpen() const {
+  return AT_MIOPEN_ENABLED();
+}
+
 bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const {
 #if AT_CUDNN_ENABLED()
   cudaDeviceProp* prop =
diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h
index eae1a802a5cb07..766ab62b8ef79f 100644
--- a/aten/src/ATen/cuda/detail/CUDAHooks.h
+++ b/aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -18,6 +18,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   Allocator* getPinnedMemoryAllocator() const override;
   void registerCUDATypes(Context*) const override;
   bool compiledWithCuDNN() const override;
+  bool compiledWithMIOpen() const override;
   bool supportsDilatedConvolutionWithCuDNN() const override;
   long versionCuDNN() const override;
   double batchnormMinEpsilonCuDNN() const override;
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 7ce3da3c9e051c..b901313e9b070c 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -339,7 +339,7 @@ union Constant
   double d;
   Constant(cudnnDataType_t dataType, double value) {
     if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
-      f = (float) value;
+      f = static_cast<float>(value);
     } else {
       d = value;
     }
diff --git a/aten/src/ATen/cudnn/Handles.cpp b/aten/src/ATen/cudnn/Handle.cpp
similarity index 98%
rename from aten/src/ATen/cudnn/Handles.cpp
rename to aten/src/ATen/cudnn/Handle.cpp
index 4848d2aca2412c..3fad861c2bd611 100644
--- a/aten/src/ATen/cudnn/Handles.cpp
+++ b/aten/src/ATen/cudnn/Handle.cpp
@@ -1,4 +1,4 @@
-#include "Handles.h"
+#include "Handle.h"
 
 #include "ATen/cuda/Exceptions.h"
 
diff --git a/aten/src/ATen/cudnn/Handles.h b/aten/src/ATen/cudnn/Handle.h
similarity index 100%
rename from aten/src/ATen/cudnn/Handles.h
rename to aten/src/ATen/cudnn/Handle.h
diff --git a/aten/src/ATen/cudnn/Utils.h b/aten/src/ATen/cudnn/Utils.h
index 264bf580f8a4df..2ff93a9dc9f11c 100644
--- a/aten/src/ATen/cudnn/Utils.h
+++ b/aten/src/ATen/cudnn/Utils.h
@@ -4,7 +4,7 @@
 #include "ATen/cuda/Exceptions.h"
 #include "THC/THC.h"
 #include "cudnn-wrapper.h"
-#include "Handles.h"
+#include "Handle.h"
 
 namespace at { namespace native {
 
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h
index 1b2e4a43259b5e..17b8af6e0f6821 100644
--- a/aten/src/ATen/detail/CUDAHooksInterface.h
+++ b/aten/src/ATen/detail/CUDAHooksInterface.h
@@ -85,6 +85,10 @@ struct AT_API CUDAHooksInterface {
     return false;
   }
 
+  virtual bool compiledWithMIOpen() const {
+    return false;
+  }
+
   virtual bool supportsDilatedConvolutionWithCuDNN() const {
     return false;
   }
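The hooks indirection above is what keeps CPU-only builds free of any MIOpen dependency: callers ask at::detail::getCUDAHooks() for capabilities, and only the CUDA/ROCm build registers an implementation that overrides the pessimistic defaults. A minimal sketch of that pattern, using hypothetical names (BackendHooks, GpuHooks) rather than ATen's actual registry machinery:

#include <iostream>

// BackendHooks plays the role of CUDAHooksInterface: the base class
// answers "no" so code that links only the CPU library still works.
struct BackendHooks {
  virtual ~BackendHooks() = default;
  virtual bool compiledWithMIOpen() const { return false; }
};

// GpuHooks plays the role of CUDAHooks, registered by the GPU build.
struct GpuHooks : BackendHooks {
  bool compiledWithMIOpen() const override { return true; }
};

int main() {
  GpuHooks gpu;
  BackendHooks cpu;
  BackendHooks* hooks = &gpu;  // in ATen this pointer comes from a registry
  std::cout << hooks->compiledWithMIOpen() << ' ' << cpu.compiledWithMIOpen() << '\n';
}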
diff --git a/aten/src/ATen/miopen/Descriptors.cpp b/aten/src/ATen/miopen/Descriptors.cpp
new file mode 100644
index 00000000000000..d1744b9ddaf3b0
--- /dev/null
+++ b/aten/src/ATen/miopen/Descriptors.cpp
@@ -0,0 +1,116 @@
+#include "Descriptors.h"
+#include <ATen/ATen.h>
+
+namespace at { namespace native {
+
+namespace {
+
+inline miopenDataType_t getDataType(const at::Type& t) {
+  auto scalar_type = t.scalarType();
+  if (scalar_type == at::kFloat) {
+    return miopenFloat;
+  } else if (scalar_type == at::kHalf) {
+    return miopenHalf;
+  }
+  throw std::runtime_error("TensorDescriptor only supports float and half tensors");
+}
+
+inline miopenDataType_t getDataType(const at::Tensor& t) {
+  return getDataType(t.type());
+}
+
+} // anonymous namespace
+
+
+void TensorDescriptor::set(const at::Tensor &t, size_t pad) {
+  set(getDataType(t), t.sizes(), t.strides(), pad);
+}
+
+static int MIOPEN_DIM_MAX = 4;
+
+void TensorDescriptor::set(miopenDataType_t datatype, IntList t_sizes, IntList t_strides, size_t pad) {
+  size_t dim = t_sizes.size();
+  if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
+#define _STR(X) #X
+#define STR(X) _STR(X)
+    throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions");
+#undef _STR
+#undef STR
+  int size[MIOPEN_DIM_MAX];
+  int stride[MIOPEN_DIM_MAX];
+  for (size_t i = 0; i < dim; ++i) {
+    size[i] = static_cast<int>(t_sizes[i]);
+    stride[i] = static_cast<int>(t_strides[i]);
+  }
+  for (size_t i = dim; i < pad; ++i) {
+    size[i] = 1;
+    stride[i] = 1;
+  }
+  set(datatype, static_cast<int>(std::max(dim, pad)), size, stride);
+}
+
+std::string miopenTypeToString(miopenDataType_t dtype) {
+  switch (dtype) {
+    case miopenFloat:
+      return "miopenFloat";
+    case miopenHalf:
+      return "miopenHalf";
+    default:
+      std::ostringstream oss;
+      oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
+      return oss.str();
+  }
+}
+
+std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
+  out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
+  int nbDims = 4;
+  int dimA[MIOPEN_DIM_MAX];
+  int strideA[MIOPEN_DIM_MAX];
+  miopenDataType_t dtype;
+  miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
+  out << "    type = " << miopenTypeToString(dtype) << "\n";
+  out << "    nbDims = " << nbDims << "\n";
+  // Read out only nbDims of the arrays!
+  out << "    dimA = ";
+  for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
+    out << i << ", ";
+  }
+  out << "\n";
+  out << "    strideA = ";
+  for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
+    out << i << ", ";
+  }
+  out << "\n";
+  return out;
+}
+
+void TensorDescriptor::print() { std::cout << *this; }
+
+void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
+  auto dim = t.ndimension();
+  if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
+#define _STR(X) #X
+#define STR(X) _STR(X)
+    throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions");
+#undef _STR
+#undef STR
+  if (!t.is_contiguous()) {
+    throw std::runtime_error("MIOpen filters (a.k.a. weights) must be contiguous");
+  }
+  int size[MIOPEN_DIM_MAX];
+  int stride[MIOPEN_DIM_MAX];
+  for (int i = 0; i < dim; ++i) {
+    size[i] = (int) t.size(i);
+  }
+  for (int i = dim; i < pad; ++i) {
+    size[i] = (int) 1;
+  }
+  for (int i = dim - 1; i >= 0; --i) {
+    stride[i] = (i == dim - 1) ? 1 : stride[i+1] * size[i+1];
+  }
+  dim = std::max(dim, pad);
+  set(getDataType(t), (int) dim, size, stride);
+}
+
+}}
diff --git a/aten/src/ATen/miopen/Descriptors.h b/aten/src/ATen/miopen/Descriptors.h
new file mode 100644
index 00000000000000..f174144c291637
--- /dev/null
+++ b/aten/src/ATen/miopen/Descriptors.h
@@ -0,0 +1,144 @@
+#pragma once
+
+#include "Exceptions.h"
+
+#include "miopen-wrapper.h"
+#include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
+
+namespace at { namespace native {
+
+inline int dataSize(miopenDataType_t dataType)
+{
+  switch (dataType) {
+    case miopenHalf: return 2;
+    case miopenFloat: return 4;
+    default: return 8;
+  }
+}
+
+// This function modifies 'stride' in place so that the stride for
+// dim i is the product of the sizes of dims i+1 to the end.
+static inline void fixSizeOneDimStride(int dim, const int *size, int *stride) {
+  int64_t z = 1;
+  for(int d = dim-1; d >= 0; d--)
+  {
+    if (size[d] == 1) {
+      stride[d] = z;
+    } else {
+      z *= size[d];
+    }
+  }
+}
+
+template <typename T, miopenStatus_t (*dtor)(T*)>
+struct DescriptorDeleter {
+  void operator()(T* x) {
+    if (x != nullptr) {
+      MIOPEN_CHECK(dtor(x));
+    }
+  }
+};
+
+// A generic class for wrapping MIOpen descriptor types.  All you need
+// is to give the underlying type the Descriptor_t points to (usually,
+// if it's miopenTensorDescriptor_t it points to miopenTensorStruct),
+// the constructor and the destructor.  Subclasses are responsible
+// for defining a set() function to actually set the descriptor.
+//
+// Descriptors default construct to a nullptr, and have a descriptor
+// initialized the first time you call set() or any other initializing
+// function.
+template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
+class Descriptor
+{
+public:
+  // Use desc() to access the underlying descriptor pointer in
+  // a read-only fashion.  Most client code should use this.
+  // If the descriptor was never initialized, this will return
+  // nullptr.
+  T* desc() const { return desc_.get(); }
+  T* desc() { return desc_.get(); }
+
+  // Use mut_desc() to access the underlying descriptor pointer
+  // if you intend to modify what it points to (e.g., using
+  // miopenSetFooDescriptor).  This will ensure that the descriptor
+  // is initialized.  Code in this file will use this function.
+  T* mut_desc() { init(); return desc_.get(); }
+protected:
+  void init() {
+    if (desc_ == nullptr) {
+      T* raw_desc;
+      MIOPEN_CHECK(ctor(&raw_desc));
+      desc_.reset(raw_desc);
+    }
+  }
+private:
+  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
+};
+
+class TensorDescriptor
+  : public Descriptor<miopenTensorDescriptor,
+                      &miopenCreateTensorDescriptor,
+                      &miopenDestroyTensorDescriptor>
+{
+public:
+  TensorDescriptor() {}
+  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
+    set(t, pad);
+  }
+
+  void set(const at::Tensor &t, size_t pad = 0);
+  void set(miopenDataType_t dataType, IntList sizes, IntList strides, size_t pad = 0);
+
+  void print();
+
+private:
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
+    fixSizeOneDimStride(dim, size, stride);
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  }
+};
+
+std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
+
+class FilterDescriptor
+  : public Descriptor<miopenTensorDescriptor,
+                      &miopenCreateTensorDescriptor,
+                      &miopenDestroyTensorDescriptor>
+{
+public:
+  void set(const at::Tensor &t, int64_t pad = 0);
+
+private:
+  void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
+    MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
+  }
+};
+
+struct ConvolutionDescriptor
+  : public Descriptor<miopenConvolutionDescriptor,
+                      &miopenCreateConvolutionDescriptor,
+                      &miopenDestroyConvolutionDescriptor>
+{
+  void set(miopenDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups) {
+    miopenDataType_t mathType = dataType;
+    if (dataType == miopenHalf) mathType = miopenFloat;
+    MIOPEN_CHECK(miopenInitConvolutionDescriptor(mut_desc(), miopenConvolution, *pad, *pad, *stride, *stride, 1, 1));
+  }
+};
+
+union Constant
+{
+  float f;
+  double d;
+  Constant(miopenDataType_t dataType, double value) {
+    if (dataType == miopenHalf || dataType == miopenFloat) {
+      f = static_cast<float>(value);
+    } else {
+      d = value;
+    }
+  }
+};
+
+}}  // namespace
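The Descriptor template above binds a C-style create/destroy pair into std::unique_ptr with a custom deleter, so a descriptor is created lazily on first mut_desc() and destroyed exactly once. A self-contained sketch of the same idea against a stand-in C API (FakeDesc, fakeCreate, fakeDestroy are placeholders, not MIOpen symbols):

#include <memory>

struct FakeDesc { int id; };                          // opaque handle contents
int fakeCreate(FakeDesc** d) { *d = new FakeDesc{0}; return 0; }
int fakeDestroy(FakeDesc* d) { delete d; return 0; }

template <typename T, int (*ctor)(T**), int (*dtor)(T*)>
class LazyHandle {
 public:
  T* desc() const { return desc_.get(); }             // read-only; may be nullptr
  T* mut_desc() { init(); return desc_.get(); }       // creates on first use
 private:
  void init() {
    if (!desc_) {
      T* raw = nullptr;
      ctor(&raw);                                     // real code checks the status
      desc_.reset(raw);
    }
  }
  struct Deleter { void operator()(T* x) { if (x) dtor(x); } };
  std::unique_ptr<T, Deleter> desc_;
};

int main() {
  LazyHandle<FakeDesc, &fakeCreate, &fakeDestroy> h;  // nothing allocated yet
  h.mut_desc();                                       // allocates; freed automatically
}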
diff --git a/aten/src/ATen/miopen/Exceptions.h b/aten/src/ATen/miopen/Exceptions.h
new file mode 100644
index 00000000000000..65d905d7427a81
--- /dev/null
+++ b/aten/src/ATen/miopen/Exceptions.h
@@ -0,0 +1,43 @@
+#pragma once
+
+#include "miopen-wrapper.h"
+#include <string>
+#include <stdexcept>
+#include <sstream>
+
+struct THCState;
+
+namespace at { namespace native {
+
+class miopen_exception : public std::runtime_error {
+public:
+  miopenStatus_t status;
+  miopen_exception(miopenStatus_t status, const char* msg)
+      : std::runtime_error(msg)
+      , status(status) {}
+  miopen_exception(miopenStatus_t status, const std::string& msg)
+      : std::runtime_error(msg)
+      , status(status) {}
+};
+
+inline void MIOPEN_CHECK(miopenStatus_t status)
+{
+  if (status != miopenStatusSuccess) {
+    if (status == miopenStatusNotImplemented) {
+      throw miopen_exception(status, std::string(miopenGetErrorString(status)) +
+              ". This error may appear if you passed in a non-contiguous input.");
+    }
+    throw miopen_exception(status, miopenGetErrorString(status));
+  }
+}
+
+inline void HIP_CHECK(hipError_t error)
+{
+  if (error != hipSuccess) {
+    std::string msg("HIP error: ");
+    msg += hipGetErrorString(error);
+    throw std::runtime_error(msg);
+  }
+}
+
+}} // namespace at::native
diff --git a/aten/src/ATen/miopen/Handle.cpp b/aten/src/ATen/miopen/Handle.cpp
new file mode 100644
index 00000000000000..b04f094fb6660d
--- /dev/null
+++ b/aten/src/ATen/miopen/Handle.cpp
@@ -0,0 +1,39 @@
+#include "ATen/miopen/Handle.h"
+
+#include "Exceptions.h"
+
+#include <unordered_map>
+#include <mutex>
+
+namespace at { namespace native {
+
+namespace {
+
+struct Handle {
+  miopenHandle_t handle;
+  Handle() : handle(NULL) {
+    MIOPEN_CHECK(miopenCreate(&handle));
+  }
+  ~Handle() {
+    if (handle) {
+      miopenDestroy(handle);
+    }
+  }
+};
+
+std::mutex mutex;
+std::unordered_map<int, Handle> handles;
+
+}  // namespace
+
+
+miopenHandle_t getMiopenHandle()
+{
+  int device;
+  HIP_CHECK(hipGetDevice(&device));
+
+  std::lock_guard<std::mutex> guard(mutex);
+  return handles[device].handle;
+}
+
+}} // namespace at::native
diff --git a/aten/src/ATen/miopen/Handle.h b/aten/src/ATen/miopen/Handle.h
new file mode 100644
index 00000000000000..e8df69270f17f3
--- /dev/null
+++ b/aten/src/ATen/miopen/Handle.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "miopen-wrapper.h"
+
+namespace at { namespace native {
+
+miopenHandle_t getMiopenHandle();
+
+}} // namespace
diff --git a/aten/src/ATen/miopen/Types.cpp b/aten/src/ATen/miopen/Types.cpp
new file mode 100644
index 00000000000000..b954752a368247
--- /dev/null
+++ b/aten/src/ATen/miopen/Types.cpp
@@ -0,0 +1,23 @@
+#include "Types.h"
+
+#include <ATen/ATen.h>
+#include "miopen/version.h"
+
+namespace at { namespace native {
+
+miopenDataType_t getMiopenDataType(const at::Tensor& tensor) {
+  if (tensor.type().scalarType() == at::kFloat) {
+    return miopenFloat;
+  } else if (tensor.type().scalarType() == at::kHalf) {
+    return miopenHalf;
+  }
+  std::string msg("getMiopenDataType() not supported for ");
+  msg += at::toString(tensor.type().scalarType());
+  throw std::runtime_error(msg);
+}
+
+int64_t miopen_version() {
+  return (MIOPEN_VERSION_MAJOR<<8) + (MIOPEN_VERSION_MINOR<<4) + MIOPEN_VERSION_PATCH;
+}
+
+}}  // namespace at::native
diff --git a/aten/src/ATen/miopen/Types.h b/aten/src/ATen/miopen/Types.h
new file mode 100644
index 00000000000000..0034aa4d84a2b1
--- /dev/null
+++ b/aten/src/ATen/miopen/Types.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include "miopen-wrapper.h"
+#include <ATen/Tensor.h>
+
+namespace at { namespace native {
+
+miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
+
+int64_t miopen_version();
+
+}}  // namespace at::native
diff --git a/aten/src/ATen/miopen/Utils.h b/aten/src/ATen/miopen/Utils.h
new file mode 100644
index 00000000000000..310ef0e902d6c0
--- /dev/null
+++ b/aten/src/ATen/miopen/Utils.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "THC/THC.h"
+#include "miopen-wrapper.h"
+#include "ATen/miopen/Handle.h"
+
+namespace at { namespace native {
+
+inline void setMIOpenStreamToCurrent() {
+  MIOPEN_CHECK(miopenSetStream(getMiopenHandle(), THCState_getCurrentStream(globalContext().getTHCState())));
+}
+
+// This function makes tensors which have zero stride contiguous, by
+// setting the strides to 1.
+inline Tensor contiguousIfZeroInStrides(const Tensor& t) {
+  for (auto s : t.strides()) {
+    if (s == 0) return t.contiguous();
+  }
+  return t;
+}
+
+}}
diff --git a/aten/src/ATen/miopen/miopen-wrapper.h b/aten/src/ATen/miopen/miopen-wrapper.h
new file mode 100644
index 00000000000000..64243bc52d84da
--- /dev/null
+++ b/aten/src/ATen/miopen/miopen-wrapper.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#include <miopen/miopen.h>
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 4028e989b87022..257c8caf3e6fb3 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -24,6 +24,7 @@ struct ConvParams {
   bool is_padding_neg() const;
   void view1d_as_2d();
   bool use_cudnn(const at::Tensor& input) const;
+  bool use_miopen(const at::Tensor& input) const;
   bool use_mkldnn(const at::Tensor& input) const;
   bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
 };
@@ -118,6 +119,12 @@ auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool {
   return !is_output_padding_big();
 }
 
+auto ConvParams::use_miopen(const at::Tensor& input) const -> bool {
+  if (!detail::getCUDAHooks().compiledWithMIOpen() || !input.type().is_cuda() || !cudnn_enabled)
+    return false;
+  return true;
+}
+
 auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
 #if AT_MKLDNN_ENABLED()
   return input.type().backend() == kCPU &&
@@ -355,6 +362,27 @@ at::Tensor _convolution(
           input, weight, bias,
           params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
     }
+  } else if (params.use_miopen(input)) {
+    if (input.type() != weight.type()){
+      std::stringstream ss;
+      ss << "Input type (" << input.type().toString() << ") and weight type (" << weight.type().toString() << ") should be the same";
+      throw std::runtime_error(ss.str());
+    }
+    if (bias.defined() && input.type() != bias.type()){
+      std::stringstream ss;
+      ss << "Input type (" << input.type().toString() << ") and bias type (" << bias.type().toString() << ") should be the same";
+      throw std::runtime_error(ss.str());
+    }
+
+    if (params.transposed) {
+      output = at::miopen_convolution_transpose(
+          input, weight, bias,
+          params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
+    } else {
+      output = at::miopen_convolution(
+          input, weight, bias,
+          params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
+    }
   } else if (params.use_mkldnn(input)) {
 #if AT_MKLDNN_ENABLED()
     if (input.type() != weight.type()){
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index ded00828b4e63c..24d8a41fb50271 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -62,6 +62,22 @@ Tensor batch_norm(
                training, momentum, eps));
   }
 
+  bool use_miopen = (input.type().is_cuda()
+             && (input.type().scalarType() != at::kHalf
+               || weight.type().scalarType() == at::kFloat)
+             && weight.defined() && bias.defined()
+             && ((running_mean.defined() && running_var.defined())
+               || (!running_mean.defined() && !running_var.defined() && training))
+             && detail::getCUDAHooks().compiledWithMIOpen()
+             );
+
+  if (use_miopen) {
+    return std::get<0>(at::miopen_batch_norm(
+             input, weight, bias,
+             running_mean, running_var,
+             training, momentum, eps));
+  }
+
   return at::thnn_batch_norm(
             input.contiguous(), weight, bias,
             running_mean, running_var, training, momentum, eps);
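contiguousIfZeroInStrides() in the utilities above exists because broadcasting via expand() produces views whose stride is 0 in the expanded dimensions, and MIOpen kernels cannot consume such layouts; copying to a contiguous tensor restores honest strides. A standalone illustration of the test it performs, with std::vector<long> standing in for IntList:

#include <cassert>
#include <vector>

// A dimension with stride 0 means every index aliases the same memory,
// which is exactly what Tensor::expand() produces.
bool hasZeroStride(const std::vector<long>& strides) {
  for (long s : strides) {
    if (s == 0) return true;
  }
  return false;
}

int main() {
  assert(hasZeroStride({0, 1}));   // e.g. a row vector expanded to a matrix
  assert(!hasZeroStride({4, 1}));  // an ordinary contiguous 2-D layout
}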
diff --git a/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp b/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp
index 7f0a0f86524c66..6856c465e9e8ef 100644
--- a/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp
+++ b/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp
@@ -26,7 +26,7 @@ Tensor cudnn_affine_grid_generator_backward(
 #else // AT_CUDNN_ENABLED()
 
 #include <ATen/cudnn/cudnn-wrapper.h>
-#include <ATen/cudnn/Handles.h>
+#include <ATen/cudnn/Handle.h>
 #include <ATen/cudnn/Descriptors.h>
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>
diff --git a/aten/src/ATen/native/cudnn/Conv.cpp b/aten/src/ATen/native/cudnn/Conv.cpp
index a2be3e507cbd03..d9d97cc4921ba6 100644
--- a/aten/src/ATen/native/cudnn/Conv.cpp
+++ b/aten/src/ATen/native/cudnn/Conv.cpp
@@ -216,6 +216,8 @@ static void check_args(CheckedFrom c, IntList args, size_t expected_size, const
 }
 
+// NOTE [ Convolution checks ]
+//
 // NB: For many call sites, it is not strictly necessary to check all of
 // these relationships (for example, for forward convolution, we compute
 // the size of output ourselves, so we don't actually need to check
@@ -752,6 +754,8 @@ void cudnn_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const T
       &one, odesc.desc(), output->data_ptr()));
 }
 
+// NOTE [ Convolution design ]
+//
 // The general strategy:
 //
 //    - cudnn_convolution (Tensor)
@@ -792,7 +796,6 @@ void cudnn_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const T
 //    - It takes output as a parameter (this should be computed!)
 //    - It doesn't do input checking
 //    - It doesn't resize output (it is assumed to be correctly sized)
-//    - It takes a ConvolutionParams struct
 //
 void raw_cudnn_convolution_forward_out(
     const Tensor& output, const Tensor& input, const Tensor& weight,
@@ -956,6 +959,8 @@ void raw_cudnn_convolution_backward_input_out(
       &zero, args.idesc.desc(), grad_input.data_ptr()));
 }
 
+// NOTE [ Backward vs transpose convolutions ]
+//
 // Backward and transpose are algorithmically equivalent, but they
 // compute their geometry differently.  In a backwards, you knew what
 // the original size of the input tensor was, so you can cache that
diff --git a/aten/src/ATen/native/miopen/BatchNorm.cpp b/aten/src/ATen/native/miopen/BatchNorm.cpp
new file mode 100644
index 00000000000000..cee60bac032d51
--- /dev/null
+++ b/aten/src/ATen/native/miopen/BatchNorm.cpp
@@ -0,0 +1,210 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Config.h>
+
+#include <ATen/cuda/CUDAConfig.h>
+
+#if !AT_MIOPEN_ENABLED()
+
+namespace at { namespace native {
+
+// See Note [ATen preprocessor philosophy]
+
+std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
+    const Tensor& input, const Tensor& weight,
+    const Tensor& bias, const Tensor& running_mean, const Tensor& running_var,
+    bool training, double exponential_average_factor, double epsilon) {
+  throw std::runtime_error("miopen_batch_norm: ATen not compiled with MIOpen support");
+}
+
+std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
+    const Tensor& input, const Tensor& grad_output, const Tensor& weight,
+    const Tensor& running_mean, const Tensor& running_var,
+    const Tensor& save_mean, const Tensor& save_var,
+    double epsilon) {
+  throw std::runtime_error("miopen_batch_norm_backward: ATen not compiled with MIOpen support");
+}
+
+}}  // namespace at::native
+
+#else // AT_MIOPEN_ENABLED
+
+#include <ATen/miopen/Descriptors.h>
+#include <ATen/miopen/Types.h>
+#include <ATen/miopen/Utils.h>
+
+#include <ATen/TensorUtils.h>
+
+namespace at { namespace native {
+
+namespace {
+
+Tensor expandScale(const Tensor& t, int64_t dim) {
+  std::vector<int64_t> size{ 1, t.numel() };
+  while (static_cast<int64_t>(size.size()) < dim) {
+    size.emplace_back(1);
+  }
+  return t.view(size);
+}
+
+}  // namespace
+
+std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
+    const Tensor& input_t, const Tensor& weight_t,
+    const Tensor& bias_t, const Tensor& running_mean_t, const Tensor& running_var_t,
+    bool training, double exponential_average_factor, double epsilon)
+{
+  TensorArg input{ input_t, "input", 1 },
+            weight{ weight_t, "weight", 2 },
+            bias{ bias_t, "bias", 3 },
+            running_mean{ running_mean_t, "running_mean", 4 },
+            running_var{ running_var_t, "running_var", 5 };
+  CheckedFrom c = "miopen_batch_norm";
+  setMIOpenStreamToCurrent();
+
+  checkAllDefined(c, {input, weight, bias});
+  if (!training) {
+    checkAllDefined(c, {running_mean, running_var});
+  }
+  checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
+  if (input->type().scalarType() == ScalarType::Half) {
+    checkScalarType(c, weight, ScalarType::Float);
+  } else {
+    checkAllSameType(c, {input, weight});
+  }
+  checkAllSameType(c, {weight, bias, running_mean, running_var});
+  checkAllContiguous(c, {input, weight, bias, running_mean, running_var});
+  checkDimRange(c, input, 2, 6 /* exclusive */);
+  auto num_features = input->size(1);
+  for (auto t : {weight, bias, running_mean, running_var}) {
+    if (t->defined()) {
+      checkNumel(c, t, num_features);
+    }
+  }
+
+  miopenBatchNormMode_t mode;
+  if (input->dim() == 2) {
+    mode = miopenBNPerActivation;
+  } else {
+    mode = miopenBNSpatial;
+  }
+
+  auto output_t = input->type().tensor(input->sizes());
+  TensorArg output{ output_t, "output", 0 };
+
+  auto handle = getMiopenHandle();
+  auto dataType = getMiopenDataType(*input);
+  TensorDescriptor idesc{ *input, 4 };  // input descriptor
+  TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 };  // descriptor for weight, bias, running_mean, etc.
+
+  Constant one(dataType, 1);
+  Constant zero(dataType, 0);
+  Tensor save_mean, save_var;
+
+  if (training) {
+    int64_t num_features = input_t.size(1);
+    save_mean = weight_t.type().tensor({ num_features });
+    save_var = weight_t.type().tensor({ num_features });
+    MIOPEN_CHECK(miopenBatchNormalizationForwardTraining(
+      handle, mode, &one, &zero,
+      idesc.desc(), input->data_ptr(),
+      idesc.desc(), output->data_ptr(),
+      wdesc.desc(),
+      weight->data_ptr(),
+      bias->data_ptr(),
+      exponential_average_factor,
+      at::maybe_data_ptr(running_mean),
+      at::maybe_data_ptr(running_var),
+      epsilon,
+      save_mean.data_ptr(),
+      save_var.data_ptr()));
+  } else {
+    MIOPEN_CHECK(miopenBatchNormalizationForwardInference(
+      handle, mode, &one, &zero,
+      idesc.desc(), input->data_ptr(),
+      idesc.desc(), output->data_ptr(),
+      wdesc.desc(),
+      weight->data_ptr(),
+      bias->data_ptr(),
+      running_mean->data_ptr(),
+      running_var->data_ptr(),
+      epsilon));
+  }
+
+  // save_mean and save_var can be undefined
+  // If this causes problems, we can initialize them to empty tensors
+  // of the correct type
+  return std::tuple<Tensor, Tensor, Tensor>{output_t, save_mean, save_var};
+}
+
+std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
+    const Tensor& input_t, const Tensor& grad_output_t, const Tensor& weight_t,
+    // Unused: but we require them to be passed so that double backwards
+    // has access
+    const Tensor& running_mean, const Tensor& running_var,
+    const Tensor& save_mean_t, const Tensor& save_var_t,
+    double epsilon)
+{
+  TensorArg input{ input_t, "input", 1 },
+            grad_output{ grad_output_t, "grad_output", 2 },
+            weight{ weight_t, "weight", 3 },
+            save_mean{ save_mean_t, "save_mean", 4 },
+            save_var{ save_var_t, "save_var", 5 };
+  CheckedFrom c = "miopen_batch_norm_backward";
+  setMIOpenStreamToCurrent();
+
+  checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
+  checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
+  if (input->type().scalarType() == ScalarType::Half) {
+    checkScalarType(c, weight, ScalarType::Float);
+  } else {
+    checkAllSameType(c, {input, weight});
+  }
+  checkAllSameType(c, {input, grad_output});
+  checkAllSameType(c, {weight, save_mean, save_var});
+  checkAllContiguous(c, {input, grad_output, save_mean, save_var});
+  checkDimRange(c, input, 2, 6 /* exclusive */);
+  checkSameSize(c, input, grad_output);
+  auto num_features = input->size(1);
+  for (auto t : {weight, save_mean, save_var}) {
+    checkNumel(c, t, num_features);
+  }
+
+  miopenBatchNormMode_t mode;
+  if (input->dim() == 2) {
+    mode = miopenBNPerActivation;
+  } else {
+    mode = miopenBNSpatial;
+  }
+
+  auto grad_input_t = input->type().tensor(input->sizes());
+  auto grad_weight_t = weight->type().tensor(weight->sizes());
+  auto grad_bias_t = weight->type().tensor(weight->sizes());
+
+  auto handle = getMiopenHandle();
+  auto dataType = getMiopenDataType(*input);
+
+  TensorDescriptor idesc{ *input, 4 };  // input, output, grad_output descriptor
+  TensorDescriptor wdesc{ expandScale(*weight, input->dim()), 4 };  // descriptor for weight, bias, save_mean, etc.
+
+  Constant one(dataType, 1);
+  Constant zero(dataType, 0);
+
+  MIOPEN_CHECK(miopenBatchNormalizationBackward(
+    handle, mode, &one, &zero, &one, &zero,
+    idesc.desc(), input->data_ptr(),
+    idesc.desc(), grad_output->data_ptr(),
+    idesc.desc(), grad_input_t.data_ptr(),
+    wdesc.desc(), weight->data_ptr(),
+    grad_weight_t.data_ptr(),
+    grad_bias_t.data_ptr(),
+    epsilon,
+    save_mean->data_ptr(),
+    save_var->data_ptr()));
+
+  return std::tuple<Tensor, Tensor, Tensor>{grad_input_t, grad_weight_t, grad_bias_t};
+}
+
+}}  // namespace native
+
+#endif
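expandScale() in the file above reshapes a length-C parameter vector to the {1, C, 1, ...} layout that the batch-norm descriptor expects, padding with trailing ones up to the input rank. The shape computation in isolation:

#include <cstdint>
#include <vector>

// channels=64, dim=4 (an NCHW input) yields {1, 64, 1, 1}, which
// broadcasts one scale/bias value per channel across N, H and W.
std::vector<int64_t> scaleShape(int64_t channels, int64_t dim) {
  std::vector<int64_t> size{1, channels};
  while (static_cast<int64_t>(size.size()) < dim) {
    size.push_back(1);
  }
  return size;
}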
diff --git a/aten/src/ATen/native/miopen/Conv.cpp b/aten/src/ATen/native/miopen/Conv.cpp
new file mode 100644
index 00000000000000..97e22e7a1ec072
--- /dev/null
+++ b/aten/src/ATen/native/miopen/Conv.cpp
@@ -0,0 +1,946 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Config.h>
+#include <ATen/cuda/CUDAConfig.h>
+
+#if !AT_MIOPEN_ENABLED()
+
+namespace at { namespace native {
+
+// See Note [ATen preprocessor philosophy]
+
+at::Tensor miopen_convolution(
+    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */,
+    IntList padding, IntList stride, IntList dilation,
+    int64_t groups, bool benchmark, bool deterministic) {
+  throw std::runtime_error("miopen_convolution: ATen not compiled with MIOpen support");
+}
+
+at::Tensor miopen_convolution_backward_input(
+    IntList input_size, const at::Tensor& grad_output, const at::Tensor& weight,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic) {
+  throw std::runtime_error("miopen_convolution_backward_input: ATen not compiled with MIOpen support");
+}
+
+at::Tensor miopen_convolution_backward_weight(
+    IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic) {
+  throw std::runtime_error("miopen_convolution_backward_weight: ATen not compiled with MIOpen support");
+}
+
+at::Tensor miopen_convolution_backward_bias(
+    const at::Tensor& grad_output) {
+  throw std::runtime_error("miopen_convolution_backward_bias: ATen not compiled with MIOpen support");
+}
+
+std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_backward(
+    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
+  throw std::runtime_error("miopen_convolution_backward: ATen not compiled with MIOpen support");
+}
+
+at::Tensor miopen_convolution_transpose(
+    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */,
+    IntList padding, IntList output_padding, IntList stride, IntList dilation,
+    int64_t groups, bool benchmark, bool deterministic) {
+  throw std::runtime_error("miopen_convolution_transpose: ATen not compiled with MIOpen support");
+}
+
+at::Tensor miopen_convolution_transpose_backward_input(
+    const at::Tensor& grad_output, const at::Tensor& weight,
+    IntList padding, IntList stride, IntList dilation,
+    int64_t groups, bool benchmark, bool deterministic) {
+  throw std::runtime_error("miopen_convolution_transpose_backward: ATen not compiled with MIOpen support");
+}
+
+at::Tensor miopen_convolution_transpose_backward_weight(
+    IntList weight_size, const at::Tensor& grad_output, const at::Tensor& input,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic) {
+  throw std::runtime_error("miopen_convolution_transpose_backward_weight: ATen not compiled with MIOpen support");
+}
+
+std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_transpose_backward(
+    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
+    IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
+  throw std::runtime_error("miopen_convolution_transpose_backward: ATen not compiled with MIOpen support");
+}
+
+}}
+
+#else  // AT_MIOPEN_ENABLED
+
+#include "THC/THC.h"
+
+#include <ATen/miopen/miopen-wrapper.h>
+#include <ATen/miopen/Descriptors.h>
+#include <ATen/miopen/Types.h>
+#include <ATen/miopen/Utils.h>
+
+#include <ATen/TensorUtils.h>
+
+#include <functional>
+#include <iterator>
+#include <sstream>
+#include <algorithm>
+#include <memory>
+#include <mutex>
+#include <stdexcept>
+#include <unordered_map>
+
+namespace at { namespace native {
+
+// ---------------------------------------------------------------------
+//
+// Math
+//
+// ---------------------------------------------------------------------
+
+constexpr int input_batch_size_dim = 0;  // also grad_input
+constexpr int input_channels_dim = 1;
+constexpr int output_batch_size_dim = 0;  // also grad_output
+constexpr int output_channels_dim = 1;
+constexpr int weight_output_channels_dim = 0;
+constexpr int weight_input_channels_dim = 1;
+
+// Often written as 2 + max_dim (extra dims for batch size and channels)
+constexpr int max_dim = 3;
+
+// NB: conv_output_size and conv_input_size are not bijections,
+// as conv_output_size loses information; this is why conv_input_size
+// takes an extra output_padding argument to resolve the ambiguity.
+
+std::vector<int64_t> conv_output_size(
+    IntList input_size, IntList weight_size,
+    IntList padding, IntList stride, IntList dilation, int64_t groups
+) {
+  // ASSERT(input_size.size() > 2)
+  // ASSERT(input_size.size() == weight_size.size())
+  auto dim = input_size.size();
+  std::vector<int64_t> output_size(dim);
+  output_size[0] = input_size[input_batch_size_dim];
+  output_size[1] = weight_size[weight_output_channels_dim];
+  for (size_t d = 2; d < dim; ++d) {
+    auto kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
+    output_size[d] = (input_size[d] + (2 * padding[d - 2])
+                        - kernel) / stride[d - 2] + 1;
+  }
+  return output_size;
+}
+
+std::vector<int64_t> conv_input_size(
+    IntList output_size, IntList weight_size,
+    IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups
+) {
+  // ASSERT(output_size.size() > 2)
+  // ASSERT(output_size.size() == weight_size.size())
+  auto dim = output_size.size();
+  std::vector<int64_t> input_size(dim);
+  input_size[0] = output_size[output_batch_size_dim];
+  input_size[1] = weight_size[weight_input_channels_dim] * groups;
+  for (size_t d = 2; d < dim; ++d) {
+    int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
+    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) +
+                     kernel + output_padding[d - 2];
+  }
+  return input_size;
+}
+
+std::vector<int64_t> conv_weight_size(
+    IntList input_size, IntList output_size,
+    IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups
+) {
+  auto dim = input_size.size();
+  std::vector<int64_t> weight_size(dim);
+  weight_size[0] = output_size[1];
+  weight_size[1] = input_size[1] / groups;
+  for (size_t d = 2; d < dim; ++d) {
+    int kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
+               + 2 * padding[d - 2] - output_padding[d - 2];
+    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
+  }
+  return weight_size;
+}
+
+Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) {
+  auto group_size = t.size(dim) / groups;
+  return t.narrow(dim, group_idx * group_size, group_size);
+}
+
+// ---------------------------------------------------------------------
+//
+// Checking
+//
+// ---------------------------------------------------------------------
+
+// Used on pad, stride and dilation
+static void check_args(CheckedFrom c, IntList args, size_t expected_size, const char* arg_name)
+{
+  if (args.size() > expected_size){
+    std::stringstream ss;
+    ss << "Too many " << arg_name << " values (" << args.size() << ") supplied, expecting " << expected_size << " (while checking arguments for " << c << ")";
+    throw std::runtime_error(ss.str());
+  }
+  else if (args.size() < expected_size){
+    std::stringstream ss;
+    ss << "Not enough " << arg_name << " values (" << args.size() << ") supplied, expecting " << expected_size << " (while checking arguments for " << c << ")";
+    throw std::runtime_error(ss.str());
+  }
+
+  auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){return x < 0;});
+  if (num_negative_values > 0){
+    std::stringstream ss;
+    ss << arg_name << " should be greater than zero but got (";
+    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
+    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
+    throw std::runtime_error(ss.str());
+  }
+}
+
+// see NOTE [ Convolution checks ] in aten/src/ATen/native/cudnn/Conv.cpp
+static void convolution_shape_check(
+    CheckedFrom c,
+    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
+    IntList padding, IntList stride, IntList dilation, int64_t groups)
+{
+  check_args(c, padding, input->dim() - 2, "padding");
+  check_args(c, stride, padding.size(), "stride");
+  check_args(c, dilation, padding.size(), "dilation");
+
+  // Input
+  checkDimRange(c, input, 3, 6 /* exclusive */);
+  checkSize(c, input, input_channels_dim, weight->size(1) * groups);
+
+  // Weight
+  checkSameDim(c, input, weight);
+
+  checkSameDim(c, input, output);
+}
+
+// This POD struct is used to let us easily compute hashes of the
+// parameters
+struct ConvolutionParams
+{
+  miopenDataType_t dataType;
+  int input_size[2 + max_dim];
+  int input_stride[2 + max_dim];
+  int weight_size[2 + max_dim];
+  int padding[max_dim];
+  int stride[max_dim];
+  int dilation[max_dim];
+  int64_t groups;
+  bool deterministic;
+  // NB: transposed purposely omitted: transposed just swaps
+  // forward and backward, so you can reuse the benchmark entry.
+};
+// ConvolutionParams must be a POD because we read out its memory
+// contents as char* when hashing
+static_assert(std::is_pod<ConvolutionParams>::value, "ConvolutionParams not POD");
+
+void setConvolutionParams(
+    ConvolutionParams* params,
+    const at::Tensor& input, const at::Tensor& weight,
+    IntList padding, IntList stride, IntList dilation,
+    int64_t groups, bool deterministic) {
+
+  miopenDataType_t dataType = getMiopenDataType(input);
+  memset(params, 0, sizeof(ConvolutionParams));
+  params->dataType = dataType;
+  // ASSERT(weight.dim() == input.dim())
+  for (int i = 0; i != input.dim(); ++i) {
+    params->input_size[i] = (int) input.size(i);
+    params->input_stride[i] = (int) input.stride(i);
+    params->weight_size[i] = (int) weight.size(i);
+  }
+  // ASSERT(padding.size() == stride.size())
+  // ASSERT(padding.size() == dilation.size())
+  for (size_t i = 0; i != padding.size(); ++i) {
+    params->padding[i] = padding[i];
+    params->stride[i] = stride[i];
+    params->dilation[i] = dilation[i];
+  }
+  params->groups = groups;
+  params->deterministic = deterministic;
+}
+
+// Convenience struct for passing around descriptors and data
+// pointers
+struct ConvolutionArgs {
+  miopenHandle_t handle;
+  ConvolutionParams params;
+  TensorDescriptor idesc, odesc;
+  FilterDescriptor wdesc;
+  const Tensor& input, output, weight;
+  ConvolutionDescriptor cdesc;
+
+  ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) {
+  }
+};
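ParamsHash below hashes the raw bytes of ConvolutionParams with an FNV-1a-style loop, and ParamsEqual compares with memcmp; this is why setConvolutionParams memsets the struct first (padding bytes participate in both operations) and why the static_assert insists on POD. The hashing scheme in isolation, over an arbitrary POD key:

#include <cstddef>
#include <cstdint>
#include <type_traits>

// FNV-1a over the raw bytes of a POD key. The key must be fully
// zero-initialized before its fields are set, or padding bytes make
// logically equal keys hash (and compare) unequal.
template <typename Key>
std::size_t podHash(const Key& key) {
  static_assert(std::is_pod<Key>::value, "key must be POD");
  auto ptr = reinterpret_cast<const std::uint8_t*>(&key);
  std::uint32_t value = 0x811C9DC5;        // FNV offset basis
  for (std::size_t i = 0; i < sizeof(Key); ++i) {
    value ^= ptr[i];
    value *= 0x01000193;                   // FNV prime
  }
  return value;
}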
+// ---------------------------------------------------------------------
+//
+// Benchmarking
+//
+// ---------------------------------------------------------------------
+
+// Hashing machinery for ConvolutionParams
+struct ParamsHash {
+  std::size_t operator()(const ConvolutionParams& params) const {
+    auto ptr = reinterpret_cast<const uint8_t*>(&params);
+    uint32_t value = 0x811C9DC5;
+    for (int i = 0; i < (int)sizeof(ConvolutionParams); ++i) {
+      value ^= ptr[i];
+      value *= 0x01000193;
+    }
+    return (size_t)value;
+  }
+};
+
+struct ParamsEqual {
+  bool operator()(const ConvolutionParams& a, const ConvolutionParams& b) const {
+    auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
+    auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
+    return memcmp(ptr1, ptr2, sizeof(ConvolutionParams)) == 0;
+  }
+};
+
+template <typename T>
+struct BenchmarkCache {
+  std::mutex mutex;
+  std::unordered_map<ConvolutionParams, T, ParamsHash, ParamsEqual> map;
+
+  bool find(const ConvolutionParams& params, T* results) {
+    std::lock_guard<std::mutex> guard(mutex);
+    auto it = map.find(params);
+    if (it == map.end()) {
+      return false;
+    }
+    *results = it->second;
+    return true;
+  }
+
+  void insert(const ConvolutionParams& params, const T& results) {
+    std::lock_guard<std::mutex> guard(mutex);
+    map[params] = results;
+  }
+};
+
+BenchmarkCache<miopenConvFwdAlgorithm_t> fwd_algos;
+BenchmarkCache<miopenConvBwdDataAlgorithm_t> bwd_data_algos;
+BenchmarkCache<miopenConvBwdWeightsAlgorithm_t> bwd_filter_algos;
+
+struct Workspace {
+  Workspace(size_t size) : size(size), data(NULL) {
+    data = THCudaMalloc(globalContext().lazyInitCUDA(), size);
+  }
+  Workspace(const Workspace&) = delete;
+  Workspace(Workspace&&) = default;
+  Workspace& operator=(Workspace&&) = default;
+  ~Workspace() {
+    if (data) {
+      THCudaFree(globalContext().lazyInitCUDA(), data);
+    }
+  }
+
+  size_t size;
+  void* data;
+};
+
+template <typename algo_t>
+struct algorithm_search {
+};
+
+size_t getWorkspaceSize(
+    const ConvolutionArgs& args, const miopenConvFwdAlgorithm_t)
+{
+  size_t sz = 0;
+  miopenConvolutionForwardGetWorkSpaceSize(
+      args.handle,
+      args.wdesc.desc(),
+      args.idesc.desc(),
+      args.cdesc.desc(),
+      args.odesc.desc(),
+      &sz);
+  return sz;
+}
+size_t getWorkspaceSize(
+    const ConvolutionArgs& args, const miopenConvBwdDataAlgorithm_t)
+{
+  size_t sz = 0;
+  miopenConvolutionBackwardDataGetWorkSpaceSize(
+      args.handle,
+      args.odesc.desc(),
+      args.wdesc.desc(),
+      args.cdesc.desc(),
+      args.idesc.desc(),
+      &sz);
+  return sz;
+}
+size_t getWorkspaceSize(
+    const ConvolutionArgs& args, const miopenConvBwdWeightsAlgorithm_t)
+{
+  size_t sz = 0;
+  miopenConvolutionBackwardWeightsGetWorkSpaceSize(
+      args.handle,
+      args.odesc.desc(),
+      args.idesc.desc(),
+      args.cdesc.desc(),
+      args.wdesc.desc(),
+      &sz);
+  return sz;
+}
+
+template<typename perf_t>
+perf_t getBestAlgorithm(perf_t *perfResults, bool deterministic, int n_algo) {
+  return perfResults[0];
+}
+
+template<>
+struct algorithm_search<miopenConvFwdAlgorithm_t> {
+  using perf_t = miopenConvAlgoPerf_t;
+  using algo_t = miopenConvFwdAlgorithm_t;
+
+  static constexpr auto DEFAULT_ALGO = miopenConvolutionFwdAlgoGEMM;
+  static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
+
+  static perf_t findAlgorithm(const ConvolutionArgs& args) {
+    int perf_count;
+    perf_t perf_results;
+    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
+    Workspace ws(max_ws_size);
+    MIOPEN_CHECK(miopenFindConvolutionForwardAlgorithm(
+        args.handle,
+        args.idesc.desc(), args.input.data_ptr(),
+        args.wdesc.desc(), args.weight.data_ptr(),
+        args.cdesc.desc(),
+        args.odesc.desc(), args.output.data_ptr(),
+        1,        // just return the fastest
+        &perf_count,
+        &perf_results,
+        ws.data,
+        ws.size,
+        false));
+    return perf_results;
+  }
+};
+
+template<>
+struct algorithm_search<miopenConvBwdDataAlgorithm_t> {
+  using perf_t = miopenConvAlgoPerf_t;
+  using algo_t = miopenConvBwdDataAlgorithm_t;
+
+  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM;
+  static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
+
+  static perf_t findAlgorithm(const ConvolutionArgs& args) {
+    int perf_count;
+    perf_t perf_results;
+    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
+    Workspace ws(max_ws_size);
+    MIOPEN_CHECK(miopenFindConvolutionBackwardDataAlgorithm(
+        args.handle,
+        args.odesc.desc(), args.output.data_ptr(),
+        args.wdesc.desc(), args.weight.data_ptr(),
+        args.cdesc.desc(),
+        args.idesc.desc(), args.input.data_ptr(),
+        1,        // just return the fastest
+        &perf_count,
+        &perf_results,
+        ws.data,
+        ws.size,
+        false));
+    return perf_results;
+  }
+};
+
+template<>
+struct algorithm_search<miopenConvBwdWeightsAlgorithm_t> {
+  using perf_t = miopenConvAlgoPerf_t;
+  using algo_t = miopenConvBwdWeightsAlgorithm_t;
+
+  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;
+  static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
+
+  static perf_t findAlgorithm(const ConvolutionArgs& args) {
+    int perf_count;
+    perf_t perf_results;
+    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
+    Workspace ws(max_ws_size);
+    MIOPEN_CHECK(miopenFindConvolutionBackwardWeightsAlgorithm(
+        args.handle,
+        args.odesc.desc(), args.output.data_ptr(),
+        args.idesc.desc(), args.input.data_ptr(),
+        args.cdesc.desc(),
+        args.wdesc.desc(), args.weight.data_ptr(),
+        1,        // just return the fastest
+        &perf_count,
+        &perf_results,
+        ws.data,
+        ws.size,
+        false));
+    return perf_results;
+  }
+};
+
+template<typename algo_t>
+void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
+  using search = algorithm_search<algo_t>;
+  auto& cache = search::cache();
+
+  if (cache.find(args.params, algo)) {
+    return;
+  }
+
+  if (args.params.deterministic && !benchmark) {
+    *algo = search::DEFAULT_ALGO;
+    return;
+  }
+
+  if (cache.find(args.params, algo)) {
+    // re-check cache since another thread may have benchmarked the algorithm
+    return;
+  }
+
+  auto perfResults = search::findAlgorithm(args);
+  *algo = reinterpret_cast<algo_t&>(perfResults);
+
+  cache.insert(args.params, *algo);
+
+  THCCachingAllocator_emptyCache();
+}
+
+template<typename algo_t>
+Workspace chooseAlgorithm(
+    const ConvolutionArgs& args,
+    bool benchmark,
+    algo_t* algo)
+{
+  findAlgorithm(args, benchmark, algo);
+
+  using search = algorithm_search<algo_t>;
+  size_t workspace_size;
+  workspace_size = getWorkspaceSize(args, *algo);
+  try {
+    return Workspace(workspace_size);
+  } catch (std::runtime_error& e) {
+    hipGetLastError(); // clear OOM error
+
+    // switch to default algorithm and record it in the cache to prevent
+    // further OOM errors
+    *algo = search::DEFAULT_ALGO;
+    search::cache().insert(args.params, *algo);
+
+    workspace_size = getWorkspaceSize(args, *algo);
+    return Workspace(workspace_size);
+  }
+}
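chooseAlgorithm() above absorbs workspace out-of-memory errors by retrying with the default GEMM algorithm and caching that decision so the benchmark is not re-attempted. The shape of that fallback, with a std::vector standing in for the THCudaMalloc-backed Workspace:

#include <new>
#include <vector>

struct Buffer {
  explicit Buffer(std::size_t n) : bytes(n) {}  // throws std::bad_alloc on OOM
  std::vector<unsigned char> bytes;
};

// Try the benchmarked algorithm's workspace first; if the allocation
// fails, fall back to the (smaller) default-algorithm workspace.
Buffer allocWithFallback(std::size_t preferred, std::size_t fallback) {
  try {
    return Buffer(preferred);
  } catch (const std::bad_alloc&) {
    return Buffer(fallback);
  }
}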
+// ---------------------------------------------------------------------
+//
+// Bias addition
+//
+// ---------------------------------------------------------------------
+
+// In-place!
+void miopen_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const TensorArg& bias)
+{
+  checkAllSameType(c, {output, bias});
+  checkAllSameGPU(c, {output, bias});
+  checkSize(c, bias, { output->size(output_channels_dim) });
+
+  TensorDescriptor bdesc, odesc;
+  bdesc.set(bias->expand({1, bias->size(0)}), output->dim());
+  odesc.set(*output);
+
+  auto handle = getMiopenHandle();
+  auto dataType = getMiopenDataType(*bias);
+  Constant one(dataType, 1);
+
+  MIOPEN_CHECK(miopenConvolutionForwardBias(handle, &one, bdesc.desc(), bias->data_ptr(),
+                                     &one, odesc.desc(), output->data_ptr()));
+}
+
+// see NOTE [ Convolution design ] in aten/src/ATen/native/cudnn/Conv.cpp
+
+
+// ---------------------------------------------------------------------
+//
+// Convolution forward / Transposed convolution backward
+//
+// ---------------------------------------------------------------------
+
+// The raw API directly invokes MIOpen.
+//
+// There are a few reasons this should never be directly exposed
+// via ATen:
+//
+//    - It takes output as a parameter (this should be computed!)
+//    - It doesn't do input checking
+//    - It doesn't resize output (it is assumed to be correctly sized)
+//
+void raw_miopen_convolution_forward_out(
+    const Tensor& output, const Tensor& input, const Tensor& weight,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic) {
+
+  auto dataType = getMiopenDataType(input);
+
+  ConvolutionArgs args{ input, output, weight };
+  args.handle = getMiopenHandle();
+  setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic);
+  args.idesc.set(input);
+  args.wdesc.set(weight);
+  args.odesc.set(output);
+  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
+
+  miopenConvFwdAlgorithm_t fwdAlg;
+  Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);
+
+  Constant one(dataType, 1);
+  Constant zero(dataType, 0);
+
+  MIOPEN_CHECK(miopenConvolutionForward(
+    args.handle,
+    &one, args.idesc.desc(), input.data_ptr(),
+    args.wdesc.desc(), weight.data_ptr(),
+    args.cdesc.desc(), fwdAlg, &zero,
+    args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size));
+}
+
+Tensor miopen_convolution_forward(
+    CheckedFrom c,
+    const TensorArg& input, const TensorArg& weight,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic)
+{
+  checkAllSameType(c, {input, weight});
+  checkAllSameGPU(c, {input, weight});
+
+  auto output_t = input->type().tensor(
+                    conv_output_size(input->sizes(), weight->sizes(),
+                                     padding, stride, dilation, groups));
+
+  // Avoid ambiguity of "output" when this is being used as backwards
+  TensorArg output{ output_t, "result", 0 };
+  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);
+
+  // See #4500
+  Tensor weight_contig = weight->contiguous();
+
+  raw_miopen_convolution_forward_out(
+      *output, *input, weight_contig,
+      padding, stride, dilation, groups, benchmark, deterministic);
+
+  return *output;
+}
+
+Tensor miopen_convolution(
+    const Tensor& input_t, const Tensor& weight_t, const Tensor& bias_t,
+    IntList padding, IntList stride, IntList dilation,
+    int64_t groups, bool benchmark, bool deterministic)
+{
+  TensorArg input  { input_t,  "input",  1 },
+            weight { weight_t, "weight", 2 },
+            bias   { bias_t,   "bias",   3 };
+  setMIOpenStreamToCurrent();
+  CheckedFrom c = "miopen_convolution";
+  auto output_t = miopen_convolution_forward(
+    c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
+  if (bias->defined()) {
+    miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
+  }
+  return output_t;
+}
+
+Tensor miopen_convolution_transpose_backward_input(
+    const Tensor& grad_output_t, const Tensor& weight_t,
+    IntList padding, IntList stride, IntList dilation,
+    int64_t groups, bool benchmark, bool deterministic)
+{
+  TensorArg grad_output { grad_output_t,  "grad_output", 1 },
+            weight      { weight_t, "weight", 2 };
+  setMIOpenStreamToCurrent();
+  return miopen_convolution_forward(
+    "miopen_convolution_transpose_backward_input",
+    grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
+}
+
+std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_transpose_backward(
+    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
+    IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
+
+  Tensor grad_output = grad_output_t.contiguous();
+
+  Tensor grad_input, grad_weight, grad_bias;
+  if (output_mask[0]) {
+    grad_input = at::miopen_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
+  }
+  if (output_mask[1]) {
+    grad_weight = at::miopen_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
+  }
+  if (output_mask[2]) {
+    grad_bias = at::miopen_convolution_backward_bias(grad_output);
+  }
+
+  return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
+}
+
+// ---------------------------------------------------------------------
+//
+// Convolution backward / Transposed convolution forward
+//
+// ---------------------------------------------------------------------
+
+void raw_miopen_convolution_backward_input_out(
+    const at::Tensor& grad_input,
+    const at::Tensor& grad_output,
+    const at::Tensor& weight,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic) {
+
+  auto dataType = getMiopenDataType(grad_output);
+
+  ConvolutionArgs args{ grad_input, grad_output, weight };
+  args.handle = getMiopenHandle();
+  setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic);
+  args.idesc.set(grad_input);
+  args.wdesc.set(weight);
+  args.odesc.set(grad_output);
+  args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
+
+  miopenConvBwdDataAlgorithm_t bwdDataAlg;
+  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlg);
+
+  Constant one(dataType, 1);
+  Constant zero(dataType, 0);
+
+  MIOPEN_CHECK(miopenConvolutionBackwardData(
+      args.handle,
+      &one, args.odesc.desc(), grad_output.data_ptr(),
+      args.wdesc.desc(), weight.data_ptr(),
+      args.cdesc.desc(), bwdDataAlg, &zero,
+      args.idesc.desc(), grad_input.data_ptr(), workspace.data, workspace.size));
+}
+
+// see NOTE [ Backward vs transpose convolutions ] in aten/src/ATen/native/cudnn/Conv.cpp
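The symmetry in the note above is visible in the call graph here: miopen_convolution_transpose_backward_input simply reuses miopen_convolution_forward, and the transposed forward below reuses the backward-data path. The sizes confirm why that works; a transposed convolution followed by the matching ordinary convolution round-trips the spatial extent:

#include <cassert>

int main() {
  long in = 28, kernel = 3, pad = 1, out_pad = 1, stride = 2, dilation = 1;
  long k = dilation * (kernel - 1) + 1;
  // conv_input_size: what the transposed convolution produces
  long tconv_out = (in - 1) * stride - 2 * pad + k + out_pad;  // 56
  // conv_output_size: the ordinary convolution applied to that output
  long grad_in = (tconv_out + 2 * pad - k) / stride + 1;       // back to 28
  assert(tconv_out == 56 && grad_in == 28);
}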
+Tensor miopen_convolution_backward_input(
+    CheckedFrom c,
+    IntList input_size, const TensorArg& grad_output, const TensorArg& weight,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic)
+{
+  checkAllSameType(c, {grad_output, weight});
+  checkAllSameGPU(c, {grad_output, weight});
+
+  auto grad_input_t = grad_output->type().tensor(input_size);
+
+  // Avoid "grad_input" when this is being used as transposed convolution
+  TensorArg grad_input{ grad_input_t, "result", 0 };
+  convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);
+
+  // See #4500
+  Tensor weight_contig = weight->contiguous();
+
+  raw_miopen_convolution_backward_input_out(
+      *grad_input, *grad_output, weight_contig,
+      padding, stride, dilation, groups, benchmark, deterministic);
+
+  return *grad_input;
+}
+
+Tensor miopen_convolution_transpose_forward(
+    CheckedFrom c,
+    const TensorArg& grad_output, const TensorArg& weight,
+    IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic)
+{
+  auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(),
+                                    padding, output_padding, stride, dilation, groups);
+  return miopen_convolution_backward_input(c, input_size, grad_output, weight,
+                                    padding, stride, dilation, groups, benchmark, deterministic);
+}
+
+Tensor miopen_convolution_backward_input(
+    IntList input_size, const Tensor& grad_output_t, const Tensor& weight_t,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic)
+{
+  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
+            weight{ weight_t, "weight", 2 };
+  setMIOpenStreamToCurrent();
+  return miopen_convolution_backward_input(
+      "miopen_convolution_backward_input",
+      input_size, grad_output, weight,
+      padding, stride, dilation, groups, benchmark, deterministic);
+}
+
+std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_backward(
+    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
+
+  Tensor grad_output = grad_output_t.contiguous();
+
+  Tensor grad_input, grad_weight, grad_bias;
+  if (output_mask[0]) {
+    grad_input = at::miopen_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
+  }
+  if (output_mask[1]) {
+    grad_weight = at::miopen_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
+  }
+  if (output_mask[2]) {
+    grad_bias = at::miopen_convolution_backward_bias(grad_output);
+  }
+
+  return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
+}
+
+Tensor miopen_convolution_transpose(
+    const Tensor& input_t, const Tensor& weight_t, const Tensor& bias_t,
+    IntList padding, IntList output_padding, IntList stride, IntList dilation,
+    int64_t groups, bool benchmark, bool deterministic)
+{
+  TensorArg input  { input_t,  "input",  1 },
+            weight { weight_t, "weight", 2 },
+            bias   { bias_t,   "bias",   3 };
+  CheckedFrom c = "miopen_convolution_transpose";
+  auto output_t = miopen_convolution_transpose_forward(
+    c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic);
+  if (bias->defined()) {
+    miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
+  }
+  return output_t;
+}
+
+// ---------------------------------------------------------------------
+//
+// Convolution backward (weight)
+//
+// ---------------------------------------------------------------------
+
+void raw_miopen_convolution_backward_weight_out(
+    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic) {
+
+  auto dataType = getMiopenDataType(input);
+
+  ConvolutionArgs args{ input, grad_output, grad_weight };
+  args.handle = getMiopenHandle();
+  setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic);
+  args.idesc.set(input);
+  args.wdesc.set(grad_weight);
+  args.odesc.set(grad_output);
+  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
+
+  miopenConvBwdWeightsAlgorithm_t bwdFilterAlg;
+  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);
+
+  Constant one(dataType, 1);
+  Constant zero(dataType, 0);
+
+  MIOPEN_CHECK(miopenConvolutionBackwardWeights(
+      args.handle,
+      &one, args.odesc.desc(), grad_output.data_ptr(),
+      args.idesc.desc(), input.data_ptr(),
+      args.cdesc.desc(), bwdFilterAlg, &zero,
+      args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size));
+}
+
+Tensor miopen_convolution_backward_weight(
+    CheckedFrom c,
+    IntList weight_size, const TensorArg& grad_output, const TensorArg& input,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic)
+{
+
+  checkAllSameType(c, {grad_output, input});
+  checkAllSameGPU(c, {grad_output, input});
+
+  auto grad_weight_t = grad_output->type().tensor(weight_size);
+
+  // For uniformity with everything else, although it seems grad_weight
+  // would be unambiguous too.
+  TensorArg grad_weight{ grad_weight_t, "result", 0 };
+  convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);
+
+  raw_miopen_convolution_backward_weight_out(
+      *grad_weight, *grad_output, *input,
+      padding, stride, dilation, groups, benchmark, deterministic);
+
+  return grad_weight_t;
+}
+
+Tensor miopen_convolution_backward_weight(
+    IntList weight_size,
+    const Tensor& grad_output_t,
+    const Tensor& input_t,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic)
+{
+  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
+            input{ input_t, "input", 2 };
+  setMIOpenStreamToCurrent();
+  return miopen_convolution_backward_weight(
+      "miopen_convolution_backward_weight",
+      weight_size, grad_output, input,
+      padding, stride, dilation, groups, benchmark, deterministic);
+}
+
+Tensor miopen_convolution_transpose_backward_weight(
+    IntList weight_size,
+    const Tensor& grad_output_t,
+    const Tensor& input_t,
+    IntList padding, IntList stride, IntList dilation, int64_t groups,
+    bool benchmark, bool deterministic)
+{
+  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
+            input{ input_t, "input", 2 };
+  setMIOpenStreamToCurrent();
+  return miopen_convolution_backward_weight(
+      "miopen_convolution_backward_weight",
+      weight_size, input, grad_output,
+      padding, stride, dilation, groups, benchmark, deterministic);
+}
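Once the yaml entries further below are processed, each of these functions is exposed as a generated at:: free function. A hedged usage sketch (assumes a ROCm build with this patch applied; the argument order follows the miopen_convolution entry in native_functions.yaml):

#include <ATen/ATen.h>

// Drives the new op through its generated at:: entry point. runConv is
// an illustrative helper, not part of this patch.
at::Tensor runConv(const at::Tensor& input, const at::Tensor& weight,
                   const at::Tensor& bias) {
  return at::miopen_convolution(
      input, weight, bias,
      /*padding=*/{1, 1}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
      /*groups=*/1, /*benchmark=*/false, /*deterministic=*/false);
}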
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a86ab6cb41ba52..5ff653f8931e77 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1074,6 +1074,63 @@
 - func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
   variants: function
 
+- func: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double exponential_average_factor, double epsilon) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CUDA: miopen_batch_norm
+
+- func: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, double epsilon) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CUDA: miopen_batch_norm_backward
+
+- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution
+
+- func: miopen_convolution_backward_input(IntList self_size, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_backward_input
+
+- func: miopen_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_backward
+
+- func: miopen_convolution_backward_bias(Tensor grad_output) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_backward_bias
+
+- func: miopen_convolution_backward_weight(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_backward_weight
+
+- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_transpose
+
+# NB: output_padding is not strictly needed here, but it is helpful for the
+# double backward
+- func: miopen_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_transpose_backward
+
+- func: miopen_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_transpose_backward_input
+
+- func: miopen_convolution_transpose_backward_weight(IntList weight_size, Tensor grad_output, Tensor self, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: miopen_convolution_transpose_backward_weight
+
 - func: mm(Tensor self, Tensor mat2) -> Tensor
 
 - func: mm_out(Tensor result, Tensor self, Tensor mat2) -> Tensor
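For orientation, here is a hypothetical call into the two central schemas declared above. Shapes and hyperparameters are invented, and a ROCm build is assumed so that the CUDA dispatch entries resolve to the MIOpen kernels:

    #include <ATen/ATen.h>

    // Forward: bias may be an undefined Tensor, matching the `Tensor?` annotation.
    at::Tensor conv3x3(const at::Tensor& input, const at::Tensor& weight,
                       const at::Tensor& bias) {
      return at::miopen_convolution(
          input, weight, bias,
          /*padding=*/{1, 1}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
          /*groups=*/1, /*benchmark=*/false, /*deterministic=*/false);
    }

    // Backward: output_mask selects (grad_input, grad_weight, grad_bias);
    // masked-off entries come back as undefined tensors.
    std::tuple<at::Tensor, at::Tensor, at::Tensor> conv3x3_backward(
        const at::Tensor& input, const at::Tensor& grad_output,
        const at::Tensor& weight) {
      return at::miopen_convolution_backward(
          input, grad_output, weight,
          /*padding=*/{1, 1}, /*stride=*/{1, 1}, /*dilation=*/{1, 1},
          /*groups=*/1, /*benchmark=*/false, /*deterministic=*/false,
          /*output_mask=*/{{true, false, true}});
    }
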
diff --git a/aten/src/ATen/test/cudnn_test.cpp b/aten/src/ATen/test/cudnn_test.cpp
index 7c1bc96dc2d2db..7194c83c0be717 100644
--- a/aten/src/ATen/test/cudnn_test.cpp
+++ b/aten/src/ATen/test/cudnn_test.cpp
@@ -3,7 +3,7 @@
 #include "ATen/ATen.h"
 #include "ATen/cudnn/Descriptors.h"
-#include "ATen/cudnn/Handles.h"
+#include "ATen/cudnn/Handle.h"
 #include "test_seed.h"
 
 using namespace at;
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 1cbcf6486e9b43..8d65e90ac35c3f 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1216,6 +1216,14 @@ if (BUILD_ATEN)
     set(AT_CUDNN_ENABLED 1)
   ENDIF()
 
+  IF (NOT USE_ROCM)
+    MESSAGE(STATUS "Not using ROCm; compiling without MIOpen support")
+    set(AT_MIOPEN_ENABLED 0)
+  ELSE()
+    INCLUDE_DIRECTORIES(BEFORE ${MIOPEN_INCLUDE_DIRS})
+    set(AT_MIOPEN_ENABLED 1)
+  ENDIF()
+
   if (NO_MKLDNN)
     message("disabling MKLDNN because NO_MKLDNN is set")
     set(AT_MKLDNN_ENABLED 0)
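AT_MIOPEN_ENABLED feeds the generated ATen/cuda/CUDAConfig.h, mirroring the AT_CUDNN_ENABLED pattern, so native code can guard MIOpen-only paths at compile time. A minimal sketch:

    #include <ATen/cuda/CUDAConfig.h>

    // Compile-time availability check; pairs with the CMake switch above.
    bool miopen_enabled_at_compile_time() {
    #if AT_MIOPEN_ENABLED()
      return true;
    #else
      return false;
    #endif
    }
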
diff --git a/cmake/Modules/FindMIOpen.cmake b/cmake/Modules/FindMIOpen.cmake
new file mode 100644
index 00000000000000..6a047df7e52f3e
--- /dev/null
+++ b/cmake/Modules/FindMIOpen.cmake
@@ -0,0 +1,63 @@
+# - Try to find MIOpen
+#
+# The following variables are optionally searched for defaults
+#  MIOPEN_ROOT_DIR: Base directory where all MIOpen components are found
+#
+# The following are set after configuration is done:
+#  MIOPEN_FOUND
+#  MIOPEN_INCLUDE_DIRS
+#  MIOPEN_LIBRARIES
+#  MIOPEN_LIBRARY_DIRS
+#
+# Borrowed from https://github.com/caffe2/caffe2/blob/master/cmake/Modules/FindCuDNN.cmake
+
+include(FindPackageHandleStandardArgs)
+
+set(MIOPEN_ROOT_DIR "" CACHE PATH "Folder containing MIOpen")
+
+if(DEFINED ENV{MIOPEN_INCLUDE_DIR})
+  set(MIOPEN_INCLUDE_DIR $ENV{MIOPEN_INCLUDE_DIR})
+else()
+  find_path(MIOPEN_INCLUDE_DIR miopen.h
+    HINTS ${MIOPEN_ROOT_DIR}
+    PATH_SUFFIXES include include/miopen)
+endif()
+
+if(DEFINED ENV{MIOPEN_LIBRARY})
+  set(MIOPEN_LIBRARY $ENV{MIOPEN_LIBRARY})
+else()
+  find_library(MIOPEN_LIBRARY MIOpen
+    HINTS ${MIOPEN_LIB_DIR} ${MIOPEN_ROOT_DIR}
+    PATH_SUFFIXES lib)
+endif()
+
+find_package_handle_standard_args(
+    MIOPEN DEFAULT_MSG MIOPEN_INCLUDE_DIR MIOPEN_LIBRARY)
+
+if(MIOPEN_FOUND)
+  # Parse the MIOpen version from its version header
+  file(READ ${MIOPEN_INCLUDE_DIR}/version.h MIOPEN_HEADER_CONTENTS)
+  string(REGEX MATCH "define MIOPEN_VERSION_MAJOR +([0-9]+)"
+         MIOPEN_VERSION_MAJOR "${MIOPEN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define MIOPEN_VERSION_MAJOR +([0-9]+)" "\\1"
+         MIOPEN_VERSION_MAJOR "${MIOPEN_VERSION_MAJOR}")
+  string(REGEX MATCH "define MIOPEN_VERSION_MINOR +([0-9]+)"
+         MIOPEN_VERSION_MINOR "${MIOPEN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define MIOPEN_VERSION_MINOR +([0-9]+)" "\\1"
+         MIOPEN_VERSION_MINOR "${MIOPEN_VERSION_MINOR}")
+  string(REGEX MATCH "define MIOPEN_VERSION_PATCH +([0-9]+)"
+         MIOPEN_VERSION_PATCH "${MIOPEN_HEADER_CONTENTS}")
+  string(REGEX REPLACE "define MIOPEN_VERSION_PATCH +([0-9]+)" "\\1"
+         MIOPEN_VERSION_PATCH "${MIOPEN_VERSION_PATCH}")
+  # Assemble the full MIOpen version string
+  if(NOT MIOPEN_VERSION_MAJOR)
+    set(MIOPEN_VERSION "?")
+  else()
+    set(MIOPEN_VERSION "${MIOPEN_VERSION_MAJOR}.${MIOPEN_VERSION_MINOR}.${MIOPEN_VERSION_PATCH}")
+  endif()
+
+  set(MIOPEN_INCLUDE_DIRS ${MIOPEN_INCLUDE_DIR})
+  set(MIOPEN_LIBRARIES ${MIOPEN_LIBRARY})
+  message(STATUS "Found MIOpen: v${MIOPEN_VERSION} (include: ${MIOPEN_INCLUDE_DIR}, library: ${MIOPEN_LIBRARY})")
+  mark_as_advanced(MIOPEN_ROOT_DIR MIOPEN_LIBRARY MIOPEN_INCLUDE_DIR)
+endif()
diff --git a/setup.py b/setup.py
index 47815d19ae5c7c..16cda18cbb7803 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,9 @@
 #   NO_CUDNN
 #     disables the cuDNN build
 #
+#   NO_MIOPEN
+#     disables the MIOpen build
+#
 #   NO_MKLDNN
 #     disables the MKLDNN build
 #
@@ -69,6 +72,11 @@
 #   CUDNN_LIBRARY
 #     specify where cuDNN is installed
 #
+#   MIOPEN_LIB_DIR
+#   MIOPEN_INCLUDE_DIR
+#   MIOPEN_LIBRARY
+#     specify where MIOpen is installed
+#
 #   NCCL_ROOT_DIR
 #   NCCL_LIB_DIR
 #   NCCL_INCLUDE_DIR
@@ -109,7 +117,7 @@
 # Before we run the setup_helpers, let's look for NO_* and WITH_*
 # variables and hotpatch the environment with the USE_* equivalent
-config_env_vars = ['CUDA', 'CUDNN', 'MKLDNN', 'NNPACK', 'DISTRIBUTED', 'DISTRIBUTED_MW',
+config_env_vars = ['CUDA', 'CUDNN', 'MIOPEN', 'MKLDNN', 'NNPACK', 'DISTRIBUTED', 'DISTRIBUTED_MW',
                    'SYSTEM_NCCL', 'GLOO_IBVERBS']
@@ -129,6 +137,8 @@ def hotpatch_var(var):
 from tools.setup_helpers.rocm import USE_ROCM, ROCM_HOME, ROCM_VERSION
 from tools.setup_helpers.cudnn import (USE_CUDNN, CUDNN_LIBRARY,
                                        CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR)
+from tools.setup_helpers.miopen import (USE_MIOPEN, MIOPEN_LIBRARY,
+                                        MIOPEN_LIB_DIR, MIOPEN_INCLUDE_DIR)
 from tools.setup_helpers.nccl import USE_NCCL, USE_SYSTEM_NCCL, NCCL_LIB_DIR, \
     NCCL_INCLUDE_DIR, NCCL_ROOT_DIR, NCCL_SYSTEM_LIB
 from tools.setup_helpers.mkldnn import (USE_MKLDNN, MKLDNN_LIBRARY,
@@ -328,6 +338,10 @@ def build_libs(libs):
         my_env["CUDNN_LIB_DIR"] = CUDNN_LIB_DIR
         my_env["CUDNN_LIBRARY"] = CUDNN_LIBRARY
         my_env["CUDNN_INCLUDE_DIR"] = CUDNN_INCLUDE_DIR
+    if USE_MIOPEN:
+        my_env["MIOPEN_LIB_DIR"] = MIOPEN_LIB_DIR
+        my_env["MIOPEN_LIBRARY"] = MIOPEN_LIBRARY
+        my_env["MIOPEN_INCLUDE_DIR"] = MIOPEN_INCLUDE_DIR
     if USE_MKLDNN:
         my_env["MKLDNN_LIB_DIR"] = MKLDNN_LIB_DIR
         my_env["MKLDNN_LIBRARY"] = MKLDNN_LIBRARY
@@ -488,6 +502,10 @@ def run(self):
             print('-- Detected cuDNN at ' + CUDNN_LIBRARY + ', ' + CUDNN_INCLUDE_DIR)
         else:
             print('-- Not using cuDNN')
+        if USE_MIOPEN:
+            print('-- Detected MIOpen at ' + MIOPEN_LIBRARY + ', ' + MIOPEN_INCLUDE_DIR)
+        else:
+            print('-- Not using MIOpen')
         if USE_CUDA:
             print('-- Detected CUDA at ' + CUDA_HOME)
         else:
@@ -926,6 +944,14 @@ def run(self):
         extra_link_args.insert(0, '-Wl,-rpath,' + CUDNN_LIB_DIR)
     extra_compile_args += ['-DUSE_CUDNN']
 
+if USE_MIOPEN:
+    main_libraries += [MIOPEN_LIBRARY]
+    include_dirs.insert(0, MIOPEN_INCLUDE_DIR)
+    extra_link_args.append('-L' + MIOPEN_LIB_DIR)
+    if not IS_WINDOWS:
+        extra_link_args.insert(0, '-Wl,-rpath,' + MIOPEN_LIB_DIR)
+    extra_compile_args += ['-DWITH_MIOPEN']
+
 if DEBUG:
     if IS_WINDOWS:
         extra_link_args.append('/DEBUG:FULL')
diff --git a/test/cpp_extensions/cudnn_extension.cpp b/test/cpp_extensions/cudnn_extension.cpp
index dbb662425922d9..754c09c8179e75 100644
--- a/test/cpp_extensions/cudnn_extension.cpp
+++ b/test/cpp_extensions/cudnn_extension.cpp
@@ -14,7 +14,7 @@
 #include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
 #include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
-#include <ATen/cudnn/Handles.h> // for getCudnnHandle
+#include <ATen/cudnn/Handle.h> // for getCudnnHandle
 
 // Name of function in python module and name used for error messages by
 // at::check* functions.
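Out-of-tree extensions track the Handles.h to Handle.h rename as well. A minimal sketch of an extension translation unit after the rename (assumes a cuDNN build, and that getCudnnHandle sits in at::native as ATen's other cuDNN helpers do):

    #include <ATen/cudnn/Handle.h> // for getCudnnHandle

    // Fetch the shared cuDNN handle the same way the extension test does.
    cudnnHandle_t current_cudnn_handle() {
      return at::native::getCudnnHandle();
    }
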
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index a66cb77f8ce9dd..453cea9c970ce7 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1206,6 +1206,28 @@
 - name: _cudnn_rnn(Tensor input, TensorList weight, int64_t weight_stride0, Tensor weight_buf, Tensor hx, Tensor cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntList batch_sizes, Tensor dropout_state)
   input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"
 
+# miopen
+
+- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
+  self, weight, bias: miopen_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)
+
+- name: miopen_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
+  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask)
+
+- name: miopen_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
+  self, weight, bias: miopen_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)
+
+- name: miopen_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
+  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)
+
+- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
+  input, weight, bias: "training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : thnn_batch_norm_backward(grad.contiguous(), input, weight, running_mean, running_var, training, epsilon, result1, result2, grad_input_mask)"
+
+- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
+  save_mean: not_implemented("miopen_batch_norm_backward save_mean")
+  save_var: not_implemented("miopen_batch_norm_backward save_var")
+  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
+
 # mkldnn
 - name: mkldnn_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups)
   self, weight, bias: mkldnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask)
diff --git a/tools/setup_helpers/miopen.py b/tools/setup_helpers/miopen.py
new file mode 100644
index 00000000000000..59ca3b918990b2
--- /dev/null
+++ b/tools/setup_helpers/miopen.py
@@ -0,0 +1,17 @@
+import os
+
+from .env import check_env_flag
+from .rocm import USE_ROCM, ROCM_HOME
+
+
+USE_MIOPEN = False
+MIOPEN_FOUND = False
+MIOPEN_LIB_DIR = None
+MIOPEN_INCLUDE_DIR = None
+MIOPEN_LIBRARY = None
+if USE_ROCM and not check_env_flag('NO_MIOPEN'):
+    USE_MIOPEN = True
+    MIOPEN_LIB_DIR = os.path.join(ROCM_HOME, "miopen", "lib")
+    MIOPEN_INCLUDE_DIR = os.path.join(ROCM_HOME, "miopen", "include", "miopen")
+    MIOPEN_LIBRARY = "MIOpen"
+    MIOPEN_FOUND = True
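
Putting the pieces together, a hypothetical end-to-end use of the transposed-convolution entry points wired up in derivatives.yaml above (invented shapes and flags; assumes a ROCm build with USE_MIOPEN enabled):

    #include <ATen/ATen.h>

    // End-to-end sketch: transposed convolution forward plus full backward.
    void transpose_roundtrip(const at::Tensor& input, const at::Tensor& weight,
                             const at::Tensor& bias, const at::Tensor& grad_output) {
      at::Tensor output = at::miopen_convolution_transpose(
          input, weight, bias,
          /*padding=*/{0, 0}, /*output_padding=*/{0, 0}, /*stride=*/{1, 1},
          /*dilation=*/{1, 1}, /*groups=*/1,
          /*benchmark=*/false, /*deterministic=*/false);
      // output_mask requests all three gradients, matching grad_input_mask
      // in the derivative entry.
      auto grads = at::miopen_convolution_transpose_backward(
          input, grad_output, weight,
          /*padding=*/{0, 0}, /*output_padding=*/{0, 0}, /*stride=*/{1, 1},
          /*dilation=*/{1, 1}, /*groups=*/1,
          /*benchmark=*/false, /*deterministic=*/false,
          /*output_mask=*/{{true, true, true}});
      (void)output;
      (void)grads;
    }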