diff --git a/.jenkins/pytorch/build-asan.sh b/.jenkins/pytorch/build-asan.sh
index b3ac091541b314..4ece2aee66a65f 100755
--- a/.jenkins/pytorch/build-asan.sh
+++ b/.jenkins/pytorch/build-asan.sh
@@ -16,6 +16,6 @@ export ASAN_OPTIONS=detect_leaks=0:symbolize=1
 # TODO: Make the ASAN flags a more unified env var
 CC="clang" CXX="clang++" LDSHARED="clang --shared" \
-  CFLAGS="-fsanitize=address -shared-libasan" \
+  CFLAGS="-fsanitize=address -fsanitize=undefined -fno-sanitize-recover=all -shared-libasan" \
   NO_CUDA=1 DEBUG=1 \
   python setup.py install
diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh
index 01fee9b40d63a4..f5aac680a6cbbc 100755
--- a/.jenkins/pytorch/test.sh
+++ b/.jenkins/pytorch/test.sh
@@ -21,17 +21,27 @@ popd
 # ASAN test is not working
 if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
   export ASAN_OPTIONS=detect_leaks=0:symbolize=1
+  export UBSAN_OPTIONS=print_stacktrace=1
   export PYTORCH_TEST_WITH_ASAN=1
+  export PYTORCH_TEST_WITH_UBSAN=1
   # TODO: Figure out how to avoid hard-coding these paths
   export ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-5.0/bin/llvm-symbolizer
   export LD_PRELOAD=/usr/lib/llvm-5.0/lib/clang/5.0.0/lib/linux/libclang_rt.asan-x86_64.so
   # Increase stack size, because ASAN red zones use more stack
   ulimit -s 81920
+  function get_exit_code() {
+    set +e
+    "$@"
+    retcode=$?
+    set -e
+    return $retcode
+  }
   (cd test && python -c "import torch")
-  echo "The next two invocations are expected to crash; if they don't that means ASAN is misconfigured"
-  (cd test && ! python -c "import torch; torch._C._crash_if_csrc_asan(3)")
-  (cd test && ! python -c "import torch; torch._C._crash_if_aten_asan(3)")
+  echo "The next three invocations are expected to crash; if they don't that means ASAN/UBSAN is misconfigured"
+  (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_csrc_asan(3)")
+  (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_csrc_ubsan(0)")
+  (cd test && !
get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)") fi export ATEN_DISABLE_AVX= diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index debc2b9cb1f7e7..0b0b44a5e6799b 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -828,6 +828,13 @@ - arg: THTensor* self broadcast: other fallback - THTensor* other +]] +[[ + name: _th_min + variants: + - method + - function + options: - cname: min return: argument 0,1 scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1) @@ -860,6 +867,13 @@ - arg: THTensor* self broadcast: other fallback - THTensor* other +]] +[[ + name: _th_max + variants: + - method + - function + options: - cname: max return: argument 0,1 scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1) @@ -875,12 +889,13 @@ default: "false" ]] [[ - name: kthvalue + name: _th_kthvalue backends: - CPU variants: - method - function + cname: kthvalue return: argument 0,1 scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1) arguments: @@ -897,10 +912,11 @@ default: "false" ]] [[ - name: mode + name: _th_mode variants: - method - function + cname: mode return: argument 0,1 scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1) arguments: @@ -926,6 +942,15 @@ return: real arguments: - THTensor* self +]] +[[ + name: _th_median + variants: + - method + - function + cname: median + return: argument 0,1 + options: - cname: median scalar_check: self_->isScalar() || (keepdim == false && self_->dim() == 1) arguments: diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 28dd3206ea0852..fd07f8088b8b24 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -62,9 +62,9 @@ void SparseTensorImpl::set_indices_and_values(const Tensor& indices, const Tenso // dimensions at the moment bool empty = values.numel() == 0; AT_CHECK(values.type().toSparse() == type(), "values type must match sparse tensor type"); - AT_CHECK(indices.type().scalarType() == kLong); - AT_CHECK(indices.type().backend() == values.type().backend()); - AT_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device()); + AT_CHECK(indices.type().scalarType() == kLong, "indices must be an int64 tensor"); + AT_CHECK(indices.type().backend() == values.type().backend(), "backend of indices (", indices.type().backend(), ") must match backend of values (", values.type().backend(), ")"); + AT_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")"); if (!empty) { AT_CHECK(indices.dim() == 2, "indices must be nDim x nnz"); AT_CHECK(indices.size(1) == values.size(0), "indices and values must have same nnz"); diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index cd5bc1103b9c81..26522125700d13 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -20,32 +20,24 @@ std::ostream& operator<<(std::ostream & out, TensorGeometryArg t) { } void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim) { - if (t->dim() != dim) { - std::ostringstream oss; - oss << "Expected " << dim << "-dimensional tensor, but got " - << t->dim() << "-dimensional tensor for " << t - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK(t->dim() == dim, + "Expected ", dim, "-dimensional tensor, but got ", 
t->dim(), + "-dimensional tensor for ", t," (while checking arguments for ", c, ")"); } void checkDimRange(CheckedFrom c, const TensorGeometryArg& t, int64_t dim_start, int64_t dim_end) { - if (t->dim() < dim_start || t->dim() >= dim_end) { - std::ostringstream oss; - oss << "Expected " << dim_start << " to " << (dim_end - 1) << " dimensions, but got " - << t->dim() << "-dimensional tensor for " << t - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t->dim() >= dim_start && t->dim() < dim_end, + "Expected ", dim_start, " to ", (dim_end - 1), " dimensions, but got ", + t->dim(), "-dimensional tensor for ", t, " (while checking arguments for ", + c, ")"); } void checkContiguous(CheckedFrom c, const TensorGeometryArg& t) { - if (!t->is_contiguous()) { - std::ostringstream oss; - oss << "Expected contiguous tensor, but got non-contiguous tensor for " << t - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t->is_contiguous(), + "Expected contiguous tensor, but got non-contiguous tensor for ", t, + " (while checking arguments for ", c, ")"); } void checkAllContiguous(CheckedFrom c, at::ArrayRef ts) { @@ -57,23 +49,18 @@ void checkAllContiguous(CheckedFrom c, at::ArrayRef ts) { void checkSize(CheckedFrom c, const TensorGeometryArg& t, IntList sizes) { checkDim(c, t, sizes.size()); - if (!t->sizes().equals(sizes)) { - std::ostringstream oss; - oss << "Expected tensor of size " << sizes << ", but got tensor of size " - << t->sizes() << " for " << t - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t->sizes().equals(sizes), + "Expected tensor of size ", sizes, ", but got tensor of size ", t->sizes(), + " for ", t, " (while checking arguments for ", c, ")"); } void checkSize(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, int64_t size) { - if (t->size(dim) != size) { - std::ostringstream oss; - oss << "Expected tensor to have size " << size << " at dimension " << dim - << ", but got size " << t->size(dim) << " for " << t - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t->size(dim) == size, + "Expected tensor to have size ", size, " at dimension ", dim, + ", but got size ", t->size(dim), " for ", t, + " (while checking arguments for ", c, ")"); } void checkAllSame(CheckedFrom c, ArrayRef tensors, void(*fn)(CheckedFrom, const TensorArg&, const TensorArg&)) { @@ -89,13 +76,11 @@ void checkAllSame(CheckedFrom c, ArrayRef tensors, void(*fn)(CheckedF } void checkSameSize(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) { - if (!t1->sizes().equals(t2->sizes())) { - std::ostringstream oss; - oss << "Expected tensor for " << t1 << " to have same size as tensor for " - << t2 << "; but " << t1->sizes() << " does not equal " << t2->sizes() - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t1->sizes().equals(t2->sizes()), + "Expected tensor for ", t1, " to have same size as tensor for ", t2, + "; but ", t1->sizes(), " does not equal ", t2->sizes(), + " (while checking arguments for ", c, ")"); } void checkAllSameSize(CheckedFrom c, ArrayRef tensors) { @@ -103,23 +88,20 @@ void checkAllSameSize(CheckedFrom c, ArrayRef tensors) { } void checkNumel(CheckedFrom c, const TensorGeometryArg& t, int64_t numel) { - if (t->numel() != numel) { - std::ostringstream oss; - oss << "Expected tensor for " << t << " to have 
" - << numel << " elements; but it actually has " << t->numel() << " elements" - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t->numel() == numel, + "Expected tensor for ", t, " to have ", numel, + " elements; but it actually has ", t->numel(), " elements", + " (while checking arguments for ", c, ")"); } void checkSameNumel(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) { - if (t1->numel() != t2->numel()) { - std::ostringstream oss; - oss << "Expected tensor for " << t1 << " to have same number of elements as tensor for " - << t2 << "; but " << t1->numel() << " does not equal " << t2->numel() - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t1->numel() == t2->numel(), + "Expected tensor for ", t1, + " to have same number of elements as tensor for ", t2, "; but ", + t1->numel(), " does not equal ", t2->numel(), + " (while checking arguments for ", c, ")"); } void checkAllSameNumel(CheckedFrom c, ArrayRef tensors) { @@ -136,17 +118,14 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) { oss << "Tensor for " << t2 << " is on CPU, "; } oss << "but expected " << ((!(t1->is_cuda() || t2->is_cuda())) ? "them" : "it") - << " to be on GPU (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } - if (t1->get_device() != t2->get_device()) { - std::ostringstream oss; - oss << "Expected tensor for " << t1 << " to have the same device as " - << "tensor for " << t2 << "; but device " << t1->get_device() << " " - << "does not equal " << t2->get_device() - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); + << " to be on GPU (while checking arguments for " << c << ")"; + AT_ERROR(oss.str()); } + AT_CHECK( + t1->get_device() == t2->get_device(), + "Expected tensor for ", t1, " to have the same device as tensor for ", t2, + "; but device ", t1->get_device(), " does not equal ", t2->get_device(), + " (while checking arguments for ", c, ")"); } void checkAllSameGPU(CheckedFrom c, ArrayRef tensors) { @@ -154,24 +133,19 @@ void checkAllSameGPU(CheckedFrom c, ArrayRef tensors) { } void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) { - if (t1->type() != t2->type()) { - std::ostringstream oss; - oss << "Expected tensor for " << t1 << " to have the same type as " - << "tensor for " << t2 << "; but type " << t1->toString() << " " - << "does not equal " << t2->toString() - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t1->type() == t2->type(), + "Expected tensor for ", t1, " to have the same type as tensor for ", t2, + "; but type ", t1->toString(), " does not equal ", t2->toString(), + " (while checking arguments for ", c, ")"); } void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType ty) { - if (t->type().scalarType() != ty) { - std::ostringstream oss; - oss << "Expected tensor for " << t << " to have scalar type " - << toString(ty) << "; but got " << t->toString() - << " instead (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t->type().scalarType() == ty, + "Expected tensor for ", t, " to have scalar type ", toString(ty), + "; but got ", t->toString(), " instead (while checking arguments for ", c, + ")"); } void checkScalarTypes(CheckedFrom c, const TensorArg& t, @@ -190,7 +164,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t, } 
oss << "; but got " << t->toString() << " instead (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); + AT_ERROR(oss.str()); } } @@ -199,24 +173,18 @@ void checkAllSameType(CheckedFrom c, ArrayRef tensors) { } void checkSameDim(CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2) { - if (t1->dim() != t2->dim()) { - std::ostringstream oss; - oss << "Expected tensor for " << t1 << " to have the same dimension as " - << "tensor for " << t2 << "; but " << t1->dim() << " " - << "does not equal " << t2->dim() - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t1->dim() == t2->dim(), + "Expected tensor for ", t1, " to have the same dimension as tensor for ", + t2, "; but ", t1->dim(), " does not equal ", t2->dim(), + " (while checking arguments for ", c, ")"); } void checkDefined(CheckedFrom c, const TensorArg& t) { - if (!t->defined()) { - std::ostringstream oss; - oss << "Expected tensor for " << t << " to be non-null, " - << "but it was undefined " - << " (while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t->defined(), + "Expected tensor for ", t, " to be non-null, but it was undefined ", + " (while checking arguments for ", c, ")"); } void checkAllDefined(CheckedFrom c, ArrayRef ts) { @@ -227,13 +195,11 @@ void checkAllDefined(CheckedFrom c, ArrayRef ts) { } void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) { - if (t.type().backend() != backend) { - std::ostringstream oss; - oss << "Expected tensor to have " << toString(backend) << " Backend, but got tensor with " - << toString(t.type().backend()) << " Backend " - << "(while checking arguments for " << c << ")"; - throw std::runtime_error(oss.str()); - } + AT_CHECK( + t.type().backend() == backend, + "Expected tensor to have ", toString(backend), + " Backend, but got tensor with ", toString(t.type().backend()), " Backend ", + "(while checking arguments for ", c, ")"); } void checkBackend(CheckedFrom c, ArrayRef tensors, at::Backend backend) { diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index 88800f6863a39c..18b562130ce9d4 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -10,6 +10,16 @@ #include #include +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#define __ubsan_ignore_vptr__ __attribute__((no_sanitize("vptr"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_function__ +#define __ubsan_ignore_vptr__ +#endif + namespace at { AT_API int _crash_if_asan(int); diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index 9820b57746eef5..4e119bd79f72a0 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -4,6 +4,8 @@ #include #include +#include "ATen/Utils.h" + #if defined(__GNUC__) #define __at_align32__ __attribute__((aligned(32))) #elif defined(_WIN32) @@ -173,7 +175,7 @@ template Vec256 operator*(const Vec256 &a, const Vec256 &b) { return c; } -template Vec256 operator/(const Vec256 &a, const Vec256 &b) { +template Vec256 operator/(const Vec256 &a, const Vec256 &b) __ubsan_ignore_float_divide_by_zero__ { Vec256 c = Vec256(); for (int i = 0; i != Vec256::size; i++) { c.values[i] = a.values[i] / b.values[i]; diff --git a/aten/src/ATen/function_wrapper.py 
b/aten/src/ATen/function_wrapper.py index ceb24c646d96ae..7a5e6e40760ec8 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -234,6 +234,11 @@ def __init__(self, reason): 'long': 'int64_t', } +NATIVE_DYNAMIC_TYPE = { + 'Tensor &': 'Tensor', + 'const Tensor &': 'Tensor', +} + TYPE_RETURN = { 'THTensor*': 'Tensor', 'THIndexTensor*': 'Tensor', @@ -871,13 +876,13 @@ def insert(argument): # not clear we need dynamic_type translation as we can specify the correct type # directly in native functions - def add_type_as_dynamic_type(argument, option): + def add_dynamic_type(argument, option): # type: (AtFormal, FunctionOption) -> AtFormal - argument['dynamic_type'] = argument['type'] + argument['dynamic_type'] = NATIVE_DYNAMIC_TYPE.get(argument['type'], argument['type']) return argument result = pos_args + kwd_args - result = [add_type_as_dynamic_type(argument, option) for argument in result] + result = [add_dynamic_type(argument, option) for argument in result] # ensure we get reference-type formals when appropriate def native_translate_formals(argument, option): @@ -928,7 +933,7 @@ def native_get_return_types(option): rtype = { 'type': actual_return_type, - 'dynamic_type': t, + 'dynamic_type': NATIVE_DYNAMIC_TYPE.get(t, t), } # type: ReturnType if name is not None: rtype['name'] = name diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp index 9be5060cc826a4..7eb1501f35c5e2 100644 --- a/aten/src/ATen/native/Embedding.cpp +++ b/aten/src/ATen/native/Embedding.cpp @@ -18,7 +18,6 @@ Tensor embedding(const Tensor & weight, const Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding", indices_arg, kLong); - checkContiguous("embedding", indices_arg); // TODO: use tensor.index() after improving perf if (indices.dim() == 1) { @@ -29,7 +28,7 @@ Tensor embedding(const Tensor & weight, const Tensor & indices, for (auto d : weight.sizes().slice(1)) { size.push_back(d); } - return weight.index_select(0, indices.view(-1)).view(size); + return weight.index_select(0, indices.reshape(-1)).view(size); } Tensor embedding_backward( @@ -50,7 +49,6 @@ Tensor embedding_sparse_backward( auto indices_arg = TensorArg(indices_, "indices", 2); checkScalarType("embedding_backward", indices_arg, kLong); - checkContiguous("embedding_backward", indices_arg); // TODO: implement scale_grad_by_freq if (scale_grad_by_freq) { @@ -77,20 +75,20 @@ Tensor embedding_sparse_backward( dense_type.tensor(), weight_size); } - auto index = indices.view({1, -1}); - auto values = grad.contiguous().view({-1, num_features}); + auto index = indices.reshape({1, -1}); + auto values = grad.reshape({-1, num_features}); return sparse_type._sparse_coo_tensor_unsafe(index, values, weight_size); } -Tensor embedding_backward_cpu( +Tensor embedding_dense_backward_cpu( const Tensor & grad_, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { auto indices_arg = TensorArg(indices, "indices", 2); checkScalarType("embedding_backward", indices_arg, kLong); - checkContiguous("embedding_backward", indices_arg); - auto indices_data = indices.data(); + auto indices_contig = indices.contiguous(); + auto indices_data = indices_contig.data(); int64_t numel = indices.numel(); std::unique_ptr counts; @@ -105,7 +103,7 @@ Tensor embedding_backward_cpu( } auto grad = grad_.contiguous().view({numel, grad_.size(-1)}); - auto grad_weight = at::zeros({num_weights, 
grad_.size(-1)}, grad_.type()); + auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options()); #ifdef _OPENMP if (numel > 1000) { @@ -154,13 +152,13 @@ Tensor & embedding_renorm_cpu_( Tensor & self, const Tensor & indices, double max_norm, double norm_type) { auto self_arg = TensorArg(self, "self", 1); auto indices_arg = TensorArg(indices, "indices", 2); - checkContiguous("embedding_renorm_", self_arg); checkDim("embedding_renorm_", self_arg, 2); - checkContiguous("embedding_renorm_", indices_arg); checkScalarType("embedding_renorm_", indices_arg, kLong); + auto indices_contig = indices.contiguous(); + auto num_indices = indices.numel(); - auto data_ptr = indices.data(); + auto data_ptr = indices_contig.data(); auto sorted_indices = std::vector(data_ptr, data_ptr + num_indices); std::sort(sorted_indices.begin(), sorted_indices.end(), std::less()); diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index 77d8f220cc75d6..d1718931ba00cf 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -45,15 +45,20 @@ static void index_select_add(const Tensor &select_indices, auto output_data = output.data(); auto numel = add_indices.numel(); int64_t ddim = src.size(1); + auto src_stride0 = src.stride(0); + auto src_stride1 = src.stride(1); + auto output_stride0 = output.stride(0); + auto output_stride1 = output.stride(1); for (int64_t i = 0; i < numel; i++) { - THBlas_axpy(ddim, 1, src_data + ddim * select_indices_data[i], 1, - output_data + ddim * add_indices_data[i], 1); + THBlas_axpy(ddim, 1, + src_data + src_stride0 * select_indices_data[i], src_stride1, + output_data + output_stride0 * add_indices_data[i], output_stride1); } } static void make_bag_size(const Tensor &offsets, const Tensor &indices, const int64_t mode, Tensor &bag_size) { - if (mode == 1 || mode == 2) { + if (mode == MODE_MEAN || mode == MODE_MAX) { // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards) if (offsets.size(0) != 1) { bag_size.slice(0, 0, bag_size.size(0) - 1, 1) = @@ -67,7 +72,7 @@ static void make_bag_size(const Tensor &offsets, const Tensor &indices, static Tensor apply_bag_size(const Tensor &offsets, const Tensor &indices, const int64_t mode, Tensor &output, const Tensor &bag_size) { - if (mode == 1) { // MODE_MEAN + if (mode == MODE_MEAN) { if (offsets.size(0) == 1) { auto bag_size_ = indices.size(0); output /= bag_size_; @@ -88,7 +93,7 @@ static Tensor apply_bag_size_backward(const Tensor &offsets, const Tensor &indices, const int64_t mode, Tensor &output, const Tensor &offset2bag, const Tensor &bag_size) { - if (mode == 1) { // MODE_MEAN + if (mode == MODE_MEAN) { if (offsets.size(0) == 1) { auto bag_size_ = indices.size(0); output /= bag_size_; @@ -119,7 +124,8 @@ std::tuple embedding_bag_cpu_max( auto weight_data = weight.data(); auto output_data = output.data(); - auto weight_stride = weight.stride(0); + auto weight_stride0 = weight.stride(0); + auto weight_stride1 = weight.stride(1); auto output_stride = output.stride(0); for (int i = 0; i < numel; i++) { @@ -129,7 +135,7 @@ std::tuple embedding_bag_cpu_max( for (int dim = 0; dim < dims; dim++) { auto& current_item = output_data[output_stride * bag + dim]; - auto weight_item = weight_data[weight_stride * word_idx + dim]; + auto weight_item = weight_data[weight_stride0 * word_idx + dim * weight_stride1]; bool is_first_for_bag = (i == 0) || offset2bag_data[i - 1] != bag; @@ -143,16 +149,27 @@ std::tuple embedding_bag_cpu_max( return 
std::tuple(output, offset2bag, bag_size, max_indices); } +// embedding_bag wrapper to enforce contiguity in tensors other than `weight`. +// This is created to save extra `.contiguous()` call in backward. +// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details std::tuple -embedding_bag_cpu(const Tensor &weight, const Tensor &indices__, - const Tensor &offsets__, const bool scale_grad_by_freq, +embedding_bag(const Tensor &weight, const Tensor &indices, + const Tensor &offsets, const bool scale_grad_by_freq, + const int64_t mode, bool sparse) { + return at::_embedding_bag(weight, indices.contiguous(), offsets.contiguous(), + scale_grad_by_freq, mode, sparse); + }; + +// Assumes all input tensors except for `weight` are contiguous. +// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details +std::tuple +_embedding_bag_cpu(const Tensor &weight, const Tensor &indices, + const Tensor &offsets, const bool scale_grad_by_freq, const int64_t mode, bool sparse) { - auto indices_arg = TensorArg(indices__, "indices__", 1); + auto indices_arg = TensorArg(indices, "indices", 1); + checkScalarType("embedding_bag", indices_arg, kLong); + auto offsets_arg = TensorArg(offsets, "offsets", 1); checkScalarType("embedding_bag", indices_arg, kLong); - auto offsets_arg = TensorArg(offsets__, "offsets__", 1); - checkScalarType("embedding_bag", offsets_arg, kLong); - Tensor indices = indices__.contiguous(); - Tensor offsets = offsets__.contiguous(); auto weight_arg = TensorArg(weight, "weight", 1); checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble}); @@ -164,13 +181,13 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__, // throw out of bounds error. So to keep it simple we just add one more // entry to the end then get rid of it after make_offset2bag. auto offset2bag = at::zeros( - {indices.sizes()[0] + 1}, indices__.type()); // offset2bag = [0 0 0 0 0] + {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0] make_offset2bag(offsets, indices, offset2bag); offset2bag.resize_({indices.sizes()[0]}); - auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.type()); + auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.options()); if (mode == MODE_MEAN || mode == MODE_SUM) { if (weight.type().scalarType() == kFloat) { @@ -189,53 +206,51 @@ embedding_bag_cpu(const Tensor &weight, const Tensor &indices__, } } -Tensor embedding_bag_backward(const Tensor &grad_, const Tensor &indices__, - const Tensor &offsets__, - const Tensor &offset2bag__, +// Assumes all input tensors are contiguous. 
+// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details +Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices, + const Tensor &offsets, + const Tensor &offset2bag, const Tensor &bag_size_, const Tensor &max_indices_, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) { - auto indices_arg = TensorArg(indices__, "indices__", 1); + auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding_bag", indices_arg, kLong); - auto offsets_arg = TensorArg(offsets__, "offsets__", 1); + checkContiguous("embedding_bag", indices_arg); + auto offsets_arg = TensorArg(offsets, "offsets", 1); checkScalarType("embedding_bag", offsets_arg, kLong); - auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1); + checkContiguous("embedding_bag", offsets_arg); + auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); checkScalarType("embedding_bag", offset2bag_arg, kLong); checkContiguous("embedding_bag", offset2bag_arg); - Tensor indices = indices__.contiguous(); - Tensor offsets = offsets__.contiguous(); if (sparse) { - return at::embedding_bag_sparse_backward( - grad_, indices, offsets, offset2bag__, bag_size_, num_weights, + return at::_embedding_bag_sparse_backward( + grad, indices, offsets, offset2bag, bag_size_, num_weights, scale_grad_by_freq, mode); } else { - return at::embedding_bag_dense_backward( - grad_, indices, offsets, offset2bag__, bag_size_, max_indices_, num_weights, + return at::_embedding_bag_dense_backward( + grad, indices, offsets, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode); } } -Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__, - const Tensor &offsets__, +Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_, + const Tensor &offsets_, const Tensor &offset2bag__, const Tensor &bag_size_, const Tensor& max_indices_, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) { + // indices_, offsets_ and offset2bag__ are assumed having correct dtypes and + // contiguous here due to the checks in _embedding_bag_backward above. + // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml + // for more details. 
+ auto grad = grad_.contiguous(); auto grad_arg = TensorArg(grad, "grad_", 1); checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble}); - auto indices_arg = TensorArg(indices__, "indices__", 1); - checkScalarType("embedding_bag", indices_arg, kLong); - auto offsets_arg = TensorArg(offsets__, "offsets__", 1); - checkScalarType("embedding_bag", offsets_arg, kLong); - auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1); - checkScalarType("embedding_bag", offset2bag_arg, kLong); - checkContiguous("embedding_bag", offset2bag_arg); - Tensor indices_ = indices__.contiguous(); - Tensor offsets_ = offsets__.contiguous(); Tensor &offset2bag_ = const_cast(offset2bag__); @@ -320,19 +335,15 @@ Tensor embedding_bag_backward_cpu(const Tensor &grad_, const Tensor &indices__, return index_grad_weight; } -Tensor embedding_bag_sparse_backward( - const Tensor &grad_, const Tensor &indices__, const Tensor &offsets__, - const Tensor &offset2bag__, const Tensor &bag_size_, int64_t num_weights, + +Tensor _embedding_bag_sparse_backward( + const Tensor &grad_, const Tensor &indices, const Tensor &offsets, + const Tensor &offset2bag, const Tensor &bag_size_, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) { - auto indices_arg = TensorArg(indices__, "indices__", 1); - checkScalarType("embedding_bag", indices_arg, kLong); - auto offsets_arg = TensorArg(offsets__, "offsets__", 1); - checkScalarType("embedding_bag", offsets_arg, kLong); - auto offset2bag_arg = TensorArg(offset2bag__, "offset2bag__", 1); - checkScalarType("embedding_bag", offset2bag_arg, kLong); - Tensor indices = indices__.contiguous(); - Tensor offsets = offsets__.contiguous(); - Tensor offset2bag = offset2bag__.contiguous(); + // indices, offsets and offset2bag are assumed having correct dtypes and + // contiguous here due to the checks in _embedding_bag_backward above. + // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml + // for more details. 
Tensor grad = grad_; Tensor index_grad = grad_.index_select(0, offset2bag); diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index ff7b20bde3c17e..ea87d42dfa58f0 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -19,11 +19,7 @@ static inline std::tuple _lu_det_P_diag_U_info(const Tensor p.squeeze_(0); lu.squeeze_(0); int int_info = info.squeeze_().toCInt(); - if (int_info < 0) { - std::ostringstream ss; - ss << "LU factorization (getrf) failed with info = " << int_info; - throw std::runtime_error(ss.str()); - } + AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info); auto n = self.size(0); auto num_exchanges = (at::arange(1, n + 1, p.type()) != p).nonzero().size(0); if (num_exchanges % 2 == 1) { @@ -34,13 +30,10 @@ static inline std::tuple _lu_det_P_diag_U_info(const Tensor } Tensor det(const Tensor& self) { - if (!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1)) { - std::ostringstream ss; - ss << "det(" << self.type() << "{" << self.sizes() << "}): expected a 2D " - << "square tensor of floating types"; - throw std::runtime_error(ss.str()); - } + AT_CHECK(at::isFloatingType(self.type().scalarType()) && + self.dim() == 2 && self.size(0) == self.size(1), + "det(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " + "of floating types"); double det_P; Tensor diag_U; int info; @@ -53,13 +46,10 @@ Tensor det(const Tensor& self) { } Tensor logdet(const Tensor& self) { - if (!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1)) { - std::ostringstream ss; - ss << "logdet(" << self.type() << "{" << self.sizes() << "}): expected a " - << "2D square tensor of floating types"; - throw std::runtime_error(ss.str()); - } + AT_CHECK(at::isFloatingType(self.type().scalarType()) && + self.dim() == 2 && self.size(0) == self.size(1), + "logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " + "of floating types"); double det_P; Tensor diag_U, det; int info; @@ -77,13 +67,10 @@ Tensor logdet(const Tensor& self) { } std::tuple slogdet(const Tensor& self) { - if (!at::isFloatingType(self.type().scalarType()) || - self.dim() != 2 || self.size(0) != self.size(1)) { - std::ostringstream ss; - ss << "slogdet(" << self.type() << "{" << self.sizes() << "}): expected a " - << "2D square tensor of floating types"; - throw std::runtime_error(ss.str()); - } + AT_CHECK(at::isFloatingType(self.type().scalarType()) && + self.dim() == 2 && self.size(0) == self.size(1), + "slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor " + "of floating types"); double det_P; Tensor diag_U, det; int info; @@ -96,10 +83,19 @@ std::tuple slogdet(const Tensor& self) { return std::make_tuple(det.sign(), diag_U.abs_().log_().sum()); } +Tensor pinverse(const Tensor& self, double rcond) { + AT_CHECK(at::isFloatingType(self.type().scalarType()) && self.dim() == 2, + "pinverse(", self.type(), "{", self.sizes(), "}): expected a 2D tensor " + "of floating types"); + Tensor U, S, V; + std::tie(U, S, V) = self.svd(); + double max_val = S[0].toCDouble(); + Tensor S_pseudoinv = at::where(S > rcond * max_val, S.reciprocal(), at::zeros({}, self.options())); + return V.mm(S_pseudoinv.diag().mm(U.t())); +} + static void check_1d(const Tensor& t, const char* arg, const char* fn) { - if (t.dim() != 1) { - AT_ERROR(fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); - } 
+ AT_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D"); } Tensor ger(const Tensor& self, const Tensor& vec2) { diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 12569cee7cd7a6..ded00828b4e63c 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -55,13 +55,15 @@ Tensor batch_norm( if (use_cudnn && eps >= detail::getCUDAHooks().batchnormMinEpsilonCuDNN()) { return std::get<0>(at::cudnn_batch_norm( - input, weight, bias, - running_mean, running_var, + input.contiguous(), weight.contiguous(), + bias.contiguous(), + running_mean.defined() ? running_mean.contiguous() : running_mean, + running_var.defined() ? running_var.contiguous() : running_var, training, momentum, eps)); } return at::thnn_batch_norm( - input, weight, bias, + input.contiguous(), weight, bias, running_mean, running_var, training, momentum, eps); } diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 51e319288218f8..9202a0f9f7d802 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -85,10 +85,80 @@ Tensor _s_where_cpu(const Tensor& condition, const Tensor& self, const Tensor& o return ret; } +std::tuple kthvalue(const Tensor& self, int64_t k, int64_t dim, bool keepdim) { + Tensor values = self.type().tensor(); + Tensor indices = self.type().toScalarType(kLong).tensor(); + return at::native::kthvalue_out(values, indices, self, k, dim, keepdim); +} + +std::tuple kthvalue_out(Tensor& values, Tensor& indices, + const Tensor& self, int64_t k, int64_t dim, bool keepdim) { + AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, + "kthvalue only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend())); + dim = maybe_wrap_dim(dim, self.dim()); + return at::_th_kthvalue_out(values, indices, self, k, dim, keepdim); +} + +std::tuple median(const Tensor& self, int64_t dim, bool keepdim) { + Tensor values = self.type().tensor(); + Tensor indices = self.type().toScalarType(kLong).tensor(); + return at::native::median_out(values, indices, self, dim, keepdim); +} + +std::tuple median_out(Tensor& values, Tensor& indices, + const Tensor& self, int64_t dim, bool keepdim) { + AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, + "median only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend())); + dim = maybe_wrap_dim(dim, self.dim()); + return at::_th_median_out(values, indices, self, dim, keepdim); +} + +std::tuple mode(const Tensor& self, int64_t dim, bool keepdim) { + Tensor values = self.type().tensor(); + Tensor indices = self.type().toScalarType(kLong).tensor(); + return at::native::mode_out(values, indices, self, dim, keepdim); +} + +std::tuple mode_out(Tensor& values, Tensor& indices, + const Tensor& self, int64_t dim, bool keepdim) { + AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, + "mode only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend())); + dim = maybe_wrap_dim(dim, self.dim()); + return at::_th_mode_out(values, indices, self, dim, keepdim); +} + +std::tuple max(const Tensor& self, int64_t dim, bool keepdim) { + Tensor max = self.type().tensor(); + Tensor max_indices = self.type().toScalarType(kLong).tensor(); + return at::native::max_out(max, max_indices, self, dim, keepdim); +} + +std::tuple max_out(Tensor& max, Tensor& 
max_indices, + const Tensor& self, int64_t dim, bool keepdim) { + AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, + "max only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend())); + dim = maybe_wrap_dim(dim, self.dim()); + return at::_th_max_out(max, max_indices, self, dim, keepdim); +} + Tensor max_values(const Tensor& self, int64_t dim, bool keepdim) { return std::get<0>(self.max(dim, keepdim)); } +std::tuple min(const Tensor& self, int64_t dim, bool keepdim) { + Tensor min = self.type().tensor(); + Tensor min_indices = self.type().toScalarType(kLong).tensor(); + return at::native::min_out(min, min_indices, self, dim, keepdim); +} + +std::tuple min_out(Tensor& min, Tensor& min_indices, + const Tensor& self, int64_t dim, bool keepdim) { + AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, + "min only supports CPU AND CUDA backend, got: ", at::toString(self.type().backend())); + dim = maybe_wrap_dim(dim, self.dim()); + return at::_th_min_out(min, min_indices, self, dim, keepdim); +} + Tensor min_values(const Tensor& self, int64_t dim, bool keepdim) { return std::get<0>(self.min(dim, keepdim)); } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 179dced2f601c4..337b80f96b39af 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -7,6 +7,7 @@ #include #include +#include namespace at { namespace native { @@ -604,5 +605,30 @@ int64_t numel(const Tensor& self) { return self.pImpl->numel(); } +std::vector meshgrid(TensorList tensors) { + int64_t size = tensors.size(); + AT_CHECK(size > 0, "meshgrid expects a non-empty TensorList"); + std::vector shape(size); + for(int64_t i = 0; i < size; i++) { + switch (tensors[i].dim()) { + case 0: + shape[i] = 1; + break; + case 1: + shape[i] = tensors[i].size(0); + break; + default: + AT_ERROR("Expected scalar or 1D tensor in the tensor list but got: ", tensors[i]); + } + } + std::vector grids; + for(int64_t i = 0; i < size; i++) { + std::vector view_shape(size, 1); + view_shape[i] = -1; + grids.push_back(tensors[i].view(view_shape).expand(shape)); + } + return grids; +} + } } diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp index 797ca89f1483a5..8bce12cac2a691 100644 --- a/aten/src/ATen/native/TensorTransformations.cpp +++ b/aten/src/ATen/native/TensorTransformations.cpp @@ -19,7 +19,7 @@ Tensor flip_cpu(const Tensor& self, IntList dims) { auto indices = std::vector(flip_dims_size); for (int64_t i = 0; i < flip_dims_size; i++) { - indices[i] = at::arange(self.size(i) - 1, -1, -1, self.type().toScalarType(at::kLong)); + indices[i] = at::arange(self.size(flip_dims_v[i]) - 1, -1, -1, self.type().toScalarType(at::kLong)); // creates a meshgrid auto temp = std::vector(flip_dims_size, 1); temp[i] = indices[i].size(0); diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 96f648c681ab2f..affe20d71c7914 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -163,18 +163,19 @@ __global__ void embedding_backward_kernel( template __global__ void renorm_kernel( scalar_t* weights, int64_t* indices, accscalar_t max_norm, - accscalar_t norm_type, int dim) { + accscalar_t norm_type, int64_t dim, + int64_t weights_stride0, int64_t weights_stride1) { // Some casting hacks since dynamic shared memory and templates don't work together: 
extern __shared__ unsigned char smem[]; auto sdata = reinterpret_cast(smem); int tid = threadIdx.x; - int base_index = indices[blockIdx.x] * dim; + int base_index = indices[blockIdx.x] * weights_stride0; accscalar_t v = 0; for (int i = tid; i < dim; i += blockDim.x) { - auto x = static_cast(weights[base_index + i]); + auto x = static_cast(weights[base_index + i * weights_stride1]); if (norm_type == 1) { v += std::abs(x); } else if (norm_type == 2) { @@ -196,30 +197,31 @@ __global__ void renorm_kernel( if (sdata[0] > max_norm) { auto factor = static_cast(max_norm / (sdata[0] + 1e-7)); for (int i = tid; i < dim; i += blockDim.x) { - weights[base_index + i] *= factor; + weights[base_index + i * weights_stride1] *= factor; } } } } // anonymous namespace -Tensor embedding_backward_cuda(const Tensor & grad_, const Tensor & indices, +Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { auto grad_arg = TensorArg(grad_, "grad", 1); auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding_backward", indices_arg, kLong); - checkContiguous("embedding_backward", indices_arg); checkSameGPU("embedding_backward", grad_arg, indices_arg); auto num_indices = indices.numel(); auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)}); - auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.type()); + auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options()); int64_t stride = grad_weight.stride(0); cudaStream_t stream = globalContext().getCurrentCUDAStream(); if (num_indices <= 768 && !scale_grad_by_freq) { + auto indices_contig = indices.contiguous(); + dim3 grid(THCCeilDiv(stride, (int64_t)WARP_SIZE)); dim3 block(WARP_SIZE, BLOCKDIMY); @@ -234,7 +236,7 @@ Tensor embedding_backward_cuda(const Tensor & grad_, const Tensor & indices, block, sizeof(accscalar_t)*WARP_SIZE*BLOCKDIMY + sizeof(int)*WARP_SIZE*BLOCKDIMY, stream>>> - (indices.data(), + (indices_contig.data(), grad.data(), grad_weight.data(), num_indices, @@ -246,8 +248,8 @@ Tensor embedding_backward_cuda(const Tensor & grad_, const Tensor & indices, return grad_weight; } - auto sorted_indices = indices.type().tensor(indices.sizes()); - auto orig_indices = indices.type().tensor(indices.sizes()); + auto sorted_indices = at::empty_like(indices); + auto orig_indices = at::empty_like(indices); using device_ptr = thrust::device_ptr; // Sort the inputs into sorted with the corresponding indices; we @@ -272,7 +274,7 @@ Tensor embedding_backward_cuda(const Tensor & grad_, const Tensor & indices, Tensor count; if (scale_grad_by_freq) { - count = indices.type().tensor(indices.sizes()); + count = at::empty_like(indices); auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); auto policy = thrust::cuda::par(allocator).on(stream); @@ -327,8 +329,6 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, double max_norm, double norm_type) { auto self_arg = TensorArg(self, "self", 1); auto indices_arg = TensorArg(indices, "indices", 1); - checkContiguous("embedding_renorm_", self_arg); - checkContiguous("embedding_renorm", indices_arg); checkDim("embedding_renorm_", self_arg, 2); checkSameGPU("embedding_renorm", self_arg, indices_arg); @@ -339,7 +339,8 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, using device_ptr = thrust::device_ptr; auto num_indices = indices.numel(); - auto indices_data = device_ptr(indices.data()); + auto indices_contig = 
indices.contiguous(); + auto indices_data = device_ptr(indices_contig.data()); // FIXME: thrust::unique only removes consecutive elements that are equal. // We have race conditions when indices contain duplicates which are not @@ -360,7 +361,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices, unique_indices.data(), static_cast(max_norm), static_cast(norm_type), - dim); + dim, self.stride(0), self.stride(1)); }); THCudaCheck(cudaGetLastError()); diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 075dbbc2fc1661..9169cb0375c552 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -23,25 +23,27 @@ namespace native { namespace { +// This kernel assumes that all input tensors except `weight` are contiguous. template __global__ void EmbeddingBag_updateOutputKernel( int64_t *input, int64_t *offsets, scalar_t *weight, scalar_t *output, - int64_t *offset2bag, int64_t numIndices, int64_t numBags, int64_t stride, + int64_t *offset2bag, int64_t numIndices, int64_t numBags, + int64_t featureSize, int64_t weight_stide0, int64_t weight_stride1, int mode, int64_t *bag_size, int64_t *max_indices) { // the strategy here is that each bag x feature is handled by a single thread using accscalar_t = acc_type; - int64_t chunksPerBag = THCCeilDiv(stride, (int64_t)blockDim.x); + int64_t chunksPerBag = THCCeilDiv(featureSize, (int64_t)blockDim.x); int64_t numChunks = numBags * chunksPerBag; int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; int64_t chunkStride = gridDim.x * blockDim.y; for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; - if (featureDim < stride) { + if (featureDim < featureSize) { int64_t bag = chunk / chunksPerBag; - scalar_t *weightFeat = weight + featureDim; + scalar_t *weightFeat = weight + featureDim * weight_stride1; int64_t begin = offsets[bag]; int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices; assert(end >= begin); @@ -52,7 +54,7 @@ __global__ void EmbeddingBag_updateOutputKernel( int64_t bag_size_ = 0; int64_t maxWord = -1; for (int64_t emb = begin; emb < end; emb++) { - const int weightRow = ((int)input[emb]) * stride; + const int64_t weightRow = input[emb] * weight_stide0; scalar_t weightValue = weightFeat[weightRow]; if (mode == MODE_MAX) { @@ -75,11 +77,11 @@ __global__ void EmbeddingBag_updateOutputKernel( } if (mode == MODE_MEAN || mode == MODE_SUM) { - output[bag * stride + featureDim] = static_cast(weightFeatSum); + output[bag * featureSize + featureDim] = static_cast(weightFeatSum); } else if (mode == MODE_MAX) { - max_indices[bag * stride + featureDim] = maxWord; - output[bag * stride + featureDim] = weightFeatMax; + max_indices[bag * featureSize + featureDim] = maxWord; + output[bag * featureSize + featureDim] = weightFeatMax; } } } @@ -90,6 +92,7 @@ __global__ void EmbeddingBag_updateOutputKernel( // does not need EmbeddingBag (LookupTable + Sum works fine), but would // still be nice to not be slow in that case. +// This kernel assumes that all input tensors are contiguous. template __global__ void EmbeddingBag_accGradParametersKernel_sum_avg( int64_t *input, int64_t *indices, scalar_t *gradOutput, @@ -298,40 +301,39 @@ Tensor embedding_bag_backward_cuda_max(const Tensor &grad, } } +// Assumes all input tensors are contiguous. 
+// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details std::tuple -embedding_bag_cuda(const Tensor &weight, const Tensor &indices, +_embedding_bag_cuda(const Tensor &weight, const Tensor &indices, const Tensor &offsets, const bool scale_grad_by_freq, const int64_t mode, bool sparse) { auto indices_arg = TensorArg(indices, "indices", 1); checkScalarType("embedding_bag_cuda", indices_arg, kLong); - checkContiguous("embedding_bag_cuda", indices_arg); auto offsets_arg = TensorArg(offsets, "offsets", 1); checkScalarType("embedding_bag_cuda", offsets_arg, kLong); - checkContiguous("embedding_bag_cuda", offsets_arg); auto weight_arg = TensorArg(weight, "weight", 1); - checkContiguous("embedding_bag_cuda", weight_arg); checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg); checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg); int64_t numIndices = indices.size(0); int64_t numBags = offsets.size(0); - int64_t stride = weight.size(1); + int64_t featureSize = weight.size(1); - auto bag_size = at::zeros(offsets.sizes(), indices.type()); + auto bag_size = at::zeros(offsets.sizes(), indices.options()); auto offset2bag = - at::zeros({indices.size(0)}, indices.type()); // offset2bag = [0 0 0 0 0] + at::zeros({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0] cudaStream_t stream = globalContext().getCurrentCUDAStream(); - auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.type()); + auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.options()); Tensor max_indices; if (mode == MODE_MAX) { - max_indices = at::zeros({offsets.size(0), weight.size(1)}, indices.type()); + max_indices = at::zeros({offsets.size(0), weight.size(1)}, indices.options()); } else { // No need to allocate if we aren't doing a backwards pass - max_indices = at::zeros({0}, indices.type()); + max_indices = at::zeros({0}, indices.options()); } dim3 block = dim3(32, 8); @@ -340,30 +342,32 @@ embedding_bag_cuda(const Tensor &weight, const Tensor &indices, EmbeddingBag_updateOutputKernel<<>>( indices.data(), offsets.data(), weight.data(), output.data(), - offset2bag.data(), numIndices, numBags, stride, mode, - bag_size.data(), mode == MODE_MAX ? max_indices.data() : NULL); + offset2bag.data(), numIndices, numBags, featureSize, + weight.stride(0), weight.stride(1), mode, bag_size.data(), + mode == MODE_MAX ? max_indices.data() : NULL); }); THCudaCheck(cudaGetLastError()); return std::tuple(output, offset2bag, bag_size, max_indices); } -Tensor embedding_bag_backward_cuda(const Tensor &grad_, const Tensor &indices, +Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &indices, const Tensor &offsets, const Tensor &offset2bag, const Tensor &bag_size_, const Tensor &max_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) { + // indices, offsets and offset2bag are assumed having correct dtypes and + // contiguous here due to the checks in _embedding_bag_backward in + // EmbeddingBag.cpp. + // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml + // for more details. 
+ Tensor grad = grad_.contiguous(); auto indices_arg = TensorArg(indices, "indices", 1); - checkScalarType("embedding_bag_cuda", indices_arg, kLong); - checkContiguous("embedding_bag_cuda", indices_arg); auto offsets_arg = TensorArg(offsets, "offsets", 1); - checkScalarType("embedding_bag_cuda", offsets_arg, kLong); - checkContiguous("embedding_bag_cuda", offsets_arg); auto grad_arg = TensorArg(grad, "grad", 1); - checkContiguous("embedding_bag_cuda", grad_arg); checkSameGPU("embedding_bag_cuda", grad_arg, offsets_arg); checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index bbfc49d1ad4c04..f753304838758d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -451,8 +451,8 @@ - func: embedding_dense_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor variants: function dispatch: - CPU: embedding_backward_cpu - CUDA: embedding_backward_cuda + CPU: embedding_dense_backward_cpu + CUDA: embedding_dense_backward_cuda - func: embedding_renorm_(Tensor self, IndexTensor indices, double max_norm, double norm_type) -> Tensor variants: function @@ -463,23 +463,35 @@ - func: embedding_sparse_backward(Tensor grad, IndexTensor indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) -> Tensor variants: function +# NOTE [ embedding_bag Native Functions ] +# The `_embedding_bag.*` variants assume that input tensors except for `weight`, +# e.g. `indices` and `offsets` (and `offset2bag`), are contiguous. +# We really only need to enforce this for `_embedding_bag` (the forward) because +# the backward inputs are the same as forward ones. +# The above `embedding_bag` wrapper is created to achieve this, e.g., +# applying indices = indices.contiguous(). +# The backward functions apply a check that these input tensors are contiguous. 
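Aside (illustrative sketch, not part of the patch): the NOTE above describes a caller-facing contract. The public `embedding_bag` accepts non-contiguous `indices`/`offsets`, calls `.contiguous()` on them, and dispatches to `_embedding_bag`, so only the internal `_embedding_bag*` forward/backward functions assume contiguity. A minimal caller-side sketch in ATen C++, with hypothetical tensor setup and the function signature as declared in this file:

    #include <ATen/ATen.h>
    #include <tuple>

    void embedding_bag_contiguity_sketch() {
      // Hypothetical setup: a small weight matrix, a non-contiguous 1-D
      // int64 index tensor (stride 2 after slicing), and two bags.
      at::Tensor weight  = at::zeros({10, 4}, at::CPU(at::kFloat));
      at::Tensor indices = at::zeros({8}, at::CPU(at::kLong)).slice(0, 0, 8, 2);
      at::Tensor offsets = at::arange(0, 4, 2, at::CPU(at::kLong));  // bag starts {0, 2}

      // The public wrapper makes indices/offsets contiguous and dispatches to
      // _embedding_bag, so _embedding_bag_cpu/_cuda may assume contiguous inputs.
      at::Tensor output, offset2bag, bag_size, max_indices;
      std::tie(output, offset2bag, bag_size, max_indices) =
          at::embedding_bag(weight, indices, offsets,
                            /*scale_grad_by_freq=*/false, /*mode=*/0 /*MODE_SUM*/, /*sparse=*/false);
    }

The backward functions, by contrast, only verify the contract (see the `checkContiguous` calls added to `_embedding_bag_backward` above) instead of copying the inputs again.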
+ - func: embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false) -> (Tensor, Tensor, Tensor, Tensor) variants: function + +- func: _embedding_bag(Tensor weight, IndexTensor indices, IndexTensor offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false) -> (Tensor, Tensor, Tensor, Tensor) + variants: function dispatch: - CPU: embedding_bag_cpu - CUDA: embedding_bag_cuda + CPU: _embedding_bag_cpu + CUDA: _embedding_bag_cuda -- func: embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor +- func: _embedding_bag_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse) -> Tensor variants: function -- func: embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor +- func: _embedding_bag_sparse_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor variants: function -- func: embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor +- func: _embedding_bag_dense_backward(Tensor grad, IndexTensor indices, IndexTensor offsets, IndexTensor offset2bag, IndexTensor bag_size, IndexTensor maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode) -> Tensor variants: function dispatch: - CPU: embedding_bag_backward_cpu - CUDA: embedding_bag_backward_cuda + CPU: _embedding_bag_dense_backward_cpu + CUDA: _embedding_bag_dense_backward_cuda - func: empty(IntList size, TensorOptions options={}) -> Tensor variants: function @@ -703,6 +715,11 @@ - func: is_sparse(Tensor self) -> bool device_guard: false +- func: kthvalue(Tensor self, int64_t k, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) + +- func: kthvalue_out(Tensor values, Tensor indices, Tensor self, int64_t k, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) + variants: function + - func: layer_norm(Tensor input, IntList normalized_shape, Tensor? weight={}, Tensor? 
bias={}, double eps=1e-5, bool cudnn_enable=True) -> Tensor variants: function @@ -819,6 +836,11 @@ - func: matmul_out(Tensor result, Tensor self, Tensor other) -> Tensor variants: function +- func: max(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) + +- func: max_out(Tensor max, Tensor max_values, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) + variants: function + - func: max_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor - func: max_pool1d_with_indices(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor) @@ -853,6 +875,16 @@ - func: mean_out(Tensor result, Tensor self, int64_t dim, *, ScalarType dtype) -> Tensor variants: function +- func: median(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) + +- func: median_out(Tensor values, Tensor indices, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) + variants: function + +- func: min(Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) + +- func: min_out(Tensor min, Tensor min_indices, Tensor self, int64_t dim, bool keepdim=false) -> (Tensor, Tensor) + variants: function + - func: min_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor - func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation) -> Tensor @@ -872,6 +904,11 @@ - func: mm_out(Tensor result, Tensor self, Tensor mat2) -> Tensor variants: function +- func: mode(Tensor self, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) + +- func: mode_out(Tensor values, Tensor indices, Tensor self, int64_t dim=-1, bool keepdim=false) -> (Tensor, Tensor) + variants: function + - func: mv(Tensor self, Tensor vec) -> Tensor - func: mv_out(Tensor result, Tensor self, Tensor vec) -> Tensor @@ -903,6 +940,8 @@ - func: pin_memory(Tensor self) -> Tensor +- func: pinverse(Tensor self, double rcond=1e-15) -> Tensor + - func: rand(IntList size, *, TensorOptions options={}) -> Tensor variants: function @@ -1803,3 +1842,6 @@ - func: get_device(Tensor self) -> int64_t device_guard: False + +- func: meshgrid(TensorList tensors) -> TensorList + variants: function diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 7554e82b56f47c..0cac9bcb9131fa 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -330,6 +330,9 @@ SparseTensor& sparse_mask_out_cpu(SparseTensor& r, const Tensor& t, const Sparse AT_CHECK(mask.is_coalesced(), "sparse_mask: mask is uncoalesced"); AT_CHECK(mask.sizes().equals(t.sizes()), "sparse_mask: operands have incompatible sizes; self has size ", t.sizes(), " but mask has size ", mask.sizes()); + AT_ASSERT(!t.is_cuda()); // we were supposed to have dispatched on this + AT_CHECK(!r.is_cuda(), "sparse_mask: expected 'out' to be CPU, but got CUDA"); + AT_CHECK(!mask.is_cuda(), "sparse_mask: expected 'mask' to be CPU, but got CUDA"); resize_as_sparse_(r, mask); if (mask._nnz() == 0) { r.zero_(); diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index d84f9476a314ee..4a25665c1e4fea 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -107,7 +107,7 @@ SparseTensor& log1p_out_sparse(SparseTensor& r, const SparseTensor& t) { if (isSameTensor(r, t)) { // don't have in-place log1p for 
uncoalesced input because coalesce() is not in-place AT_CHECK( - r.is_coalesced(), "in-place log1p on uncoalesced tensors is not supported yet!"); + r.is_coalesced(), "log1p: in-place on uncoalesced tensors is not supported yet!"); } else { r = raw_copy_sparse_(r, t.coalesce()); @@ -117,7 +117,7 @@ SparseTensor& log1p_out_sparse(SparseTensor& r, const SparseTensor& t) { } SparseTensor& log1p_sparse_(SparseTensor& t) { - AT_CHECK(t.is_coalesced(), "in-place log1p on uncoalesced tensors is not supported yet!"); + AT_CHECK(t.is_coalesced(), "log1p: in-place on uncoalesced tensors is not supported yet!"); return log1p_out_sparse(t, t); } @@ -130,7 +130,7 @@ SparseTensor& log1p_sparse_(SparseTensor& t) { SparseTensor& pow_out_sparse_scalar(SparseTensor& r, const SparseTensor& t_, Scalar value) { AT_ASSERT(r.is_sparse()); AT_ASSERT(t_.is_sparse()); - AT_CHECK(value.toDouble() != 0, "cannot raise to zeroth power on sparse tensor; it would make the result tensor dense"); + AT_CHECK(value.toDouble() != 0, "pow: cannot raise to zeroth power on sparse tensor; it would make the result tensor dense"); // This coalesce is why we can't easily provide an inplace variant SparseTensor t = t_.coalesce(); @@ -202,8 +202,11 @@ Tensor norm_sparse(const SparseTensor& self, Scalar value) { SparseTensor& s_add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, Scalar value) { AT_ASSERT(r.is_sparse()); AT_ASSERT(t.is_sparse()); + AT_ASSERT(!t.is_cuda()); // the dispatch argument + AT_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!src.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor"); - AT_CHECK(t.sizes().equals(src.sizes()), "cadd operands have incompatible sizes"); + AT_CHECK(t.sizes().equals(src.sizes()), "add: expected sizes of 'self' and 'other' to match, but ", t.sizes(), " != ", src.sizes()); if (src._nnz() == 0) { return raw_copy_sparse_(r, t); @@ -212,7 +215,7 @@ SparseTensor& s_add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const return mul_out_sparse_scalar(r, src, value); } - AT_CHECK(_is_same_density(t, src), "cadd operands have incompatible desnitities"); + AT_CHECK(_is_same_density(t, src), "add: expected 'self' and 'other' to have same density, but 'self' has ", t._sparseDims(), " sparse dimensions while 'other' has ", src._sparseDims(), " sparse dimensions"); // saving those because they can be overwritten when doing in-place operations int64_t t_nnz = t._nnz(), s_nnz = src._nnz(), max_nnz = t_nnz + s_nnz; @@ -329,11 +332,19 @@ void add_dense_sparse_worker_cpu(Tensor& r, Scalar value, const SparseTensor& sp } Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, SparseTensorRef sparse__, Scalar value) { + const SparseTensor& sparse_ = sparse__.tref; + AT_ASSERT(!r.is_sparse()); AT_ASSERT(!dense.is_sparse()); - AT_ASSERT(sparse__.tref.is_sparse()); + AT_ASSERT(sparse_.is_sparse()); + + AT_ASSERT(!dense.is_cuda()); // dispatch argument + AT_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!sparse_.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor"); + + AT_CHECK(dense.sizes().equals(sparse_.sizes()), "add: expected 'self' and 'other' to have same size, but self has size ", + dense.sizes(), " while other has size ", sparse_.sizes(), " (FYI: dense-sparse addition does not currently support broadcasting)"); - const SparseTensor& sparse_ = sparse__.tref; r.resize_as_(dense); SparseTensor sparse = 
sparse_.coalesce(); @@ -381,6 +392,10 @@ Tensor& add_dense_sparse_cpu_(Tensor& t, SparseTensorRef src, Scalar alpha) { // -------------------------------------------------------------------- SparseTensor& s_sub_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, Scalar value) { + AT_ASSERT(!t.is_cuda()); // dispatch argument + AT_CHECK(!r.is_cuda(), "sub: expected 'out' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!src.is_cuda(), "sub: expected 'other' to be a CPU tensor, but got a CUDA tensor"); + // UGH... We're doing two dispatches on scalar type here for no good reason. // NB: I tried adding an operator- to Scalar, but there isn't any good way // to negate the tensor, because I have a TensorBase... @@ -409,6 +424,12 @@ SparseTensor& s_sub_sparse_cpu_(SparseTensor& t, const SparseTensor& src, Scalar SparseTensor& s_mul_out_sparse_cpu(SparseTensor& r, const SparseTensor& t_, const SparseTensor& src_) { AT_CHECK(t_.sizes().equals(src_.sizes()), "mul operands have incompatible sizes"); + AT_ASSERT(!t_.is_cuda()); // dispatch argument + AT_CHECK(!r.is_cuda(), "mul: expected 'out' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!src_.is_cuda(), "mul: expected 'other' to be a CPU tensor, but got a CUDA tensor"); + + AT_CHECK(t_.sizes().equals(src_.sizes()), "mul: expected 'self' and 'other' to have same sizes, but ", t_.sizes(), " != ", src_.sizes()); + if (src_._nnz() == 0 || t_._nnz() == 0) { return r.zero_(); } @@ -549,7 +570,7 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, dense_ptr + col * dense_stride0, dense_stride1, r_ptr + h * r_stride0, r_stride1); } else { - AT_ERROR("index out of bound. spmm: ", col, " not between 1 and ", dim_j); + AT_ERROR("addmm: index out of bound: ", col, " not between 1 and ", dim_j); } } } @@ -564,10 +585,15 @@ Tensor& s_addmm_out_sparse_dense_cpu( Scalar alpha ) { // TODO: This error message seems awfully opaque - AT_CHECK(sparse_._sparseDims() == 2, "matrices expected, got ", sparse_._sparseDims(), "D tensor"); - AT_CHECK(sparse_._denseDims() == 0, "scalar values expected, got ", sparse_._denseDims(), "D values"); - AT_CHECK(dense.numel() != 0, "matrices expected, got empty tensor"); - AT_CHECK(dense.dim() == 2, "matrices expected, got ", dense.dim(), "D tensor"); + AT_ASSERT(!t.is_cuda()); + AT_CHECK(!r.is_cuda(), "addmm: expected 'out' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!sparse_.is_cuda(), "addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor"); + AT_CHECK(!dense.is_cuda(), "addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor"); + + AT_CHECK(sparse_._sparseDims() == 2, "addmm: matrices expected, got ", sparse_._sparseDims(), "D tensor"); + AT_CHECK(sparse_._denseDims() == 0, "addmm: scalar values expected, got ", sparse_._denseDims(), "D values"); + AT_CHECK(dense.numel() != 0, "addmm: matrices expected, got empty tensor"); + AT_CHECK(dense.dim() == 2, "addmm: matrices expected, got ", dense.dim(), "D tensor"); SparseTensor sparse = sparse_.coalesce(); @@ -576,14 +602,14 @@ Tensor& s_addmm_out_sparse_dense_cpu( int64_t dim_j = sparse.size(1); int64_t dim_k = dense.size(1); - r.resize_({dim_i, dim_k}); - AT_CHECK(dense.size(0) == dim_j, - "Argument #3 (dense): Expected dim 0 size ", dim_j, ", got ", dense.size(0)); + "addmm: Argument #3 (dense): Expected dim 0 size ", dim_j, ", got ", dense.size(0)); AT_CHECK(t.size(0) == dim_i, - "Argument #1 (t): Expected dim 0 size ", dim_i, ", got ", t.size(0)); + "addmm: Argument #1 (t): Expected dim 
0 size ", dim_i, ", got ", t.size(0)); AT_CHECK(t.size(1) == dim_k, - "Argument #1 (t): Expected dim 1 size ", dim_k, ", got ", t.size(1)); + "addmm: Argument #1 (t): Expected dim 1 size ", dim_k, ", got ", t.size(1)); + + r.resize_({dim_i, dim_k}); int64_t nnz = sparse._nnz(); @@ -636,19 +662,25 @@ Tensor& s_addmm_sparse_dense_cpu_( SparseTensor& hspmm_out_sparse_cpu(SparseTensor& r, const SparseTensor& sparse_, const Tensor& dense) { // TODO: Make this a real argument Scalar alpha = 1; + + AT_ASSERT(!sparse_.is_cuda()); // dispatch argument + AT_CHECK(!r.is_cuda(), "hspmm: expected 'out' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!dense.is_cuda(), "hspmm: expected 'other' to be a CPU tensor, but got a CUDA tensor"); + AT_CHECK(sparse_._sparseDims() == 2, - "Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor"); + "hspmm: Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor"); AT_CHECK(sparse_._denseDims() == 0, - "Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values"); + "hspmm: Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values"); AT_CHECK(dense.dim() == 2, - "Argument #3: matrices expected, got ", dense.dim(), "D tensor"); + "hspmm: Argument #3: matrices expected, got ", dense.dim(), "D tensor"); int64_t m = sparse_.size(0); int64_t k = sparse_.size(1); int64_t n = dense.size(1); AT_CHECK(dense.size(0) == k, - "Argument #3: Expected dim 0 size ", k, ", got ", dense.size(0)); + "hspmm: Argument #3: Expected dim 0 size ", k, ", got ", dense.size(0)); + _get_sparse_impl(r)->raw_resize_(1, 1, {m, n}); SparseTensor sparse = sparse_.coalesce(); @@ -711,12 +743,17 @@ SparseTensor& _sspaddmm_out_cpu( Scalar beta, Scalar alpha ) { + AT_ASSERT(!t.is_cuda()); // dispatch argument + AT_CHECK(!r.is_cuda(), "sspaddmm: expected 'out' to be CPU tensor, but got CUDA tensor"); + AT_CHECK(!sparse_.is_cuda(), "sspaddmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor"); + AT_CHECK(!dense.is_cuda(), "sspaddmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor"); + AT_CHECK(sparse_._sparseDims() == 2, - "Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor"); + "sspaddmm: Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor"); AT_CHECK(sparse_._denseDims() == 0, - "Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values"); + "sspaddmm: Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values"); AT_CHECK(dense.dim() == 2, - "Argument #2: matrices expected, got ", dense.dim(), "D tensor"); + "sspaddmm: Argument #2: matrices expected, got ", dense.dim(), "D tensor"); SparseTensor sparse = sparse_.coalesce(); @@ -725,14 +762,16 @@ SparseTensor& _sspaddmm_out_cpu( int64_t dim_j = sparse.size(1); int64_t dim_k = dense.size(1); + // NB: This has to occur before the checks, because r may alias t. 
+ // See test_saddmm r.sparse_raw_resize_({dim_i, dim_k}, 2, 0); AT_CHECK(dense.size(0) == dim_j, - "Argument #3: Expected dim 0 size ", dim_j, ", got ", dense.size(0)); + "sspaddmm: Argument #3: Expected dim 0 size ", dim_j, ", got ", dense.size(0)); AT_CHECK(t.size(0) == dim_i, - "Argument #1: Expected dim 0 size ", dim_i, ", got ", t.size(0)); + "sspaddmm: Argument #1: Expected dim 0 size ", dim_i, ", got ", t.size(0)); AT_CHECK(t.size(1) == dim_k, - "Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1)); + "sspaddmm: Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1)); int64_t nnz = sparse._nnz(); LongTensor indices = sparse._indices(); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp index ff25867255fb04..68ab33ae6d5466 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp @@ -9,6 +9,9 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars AT_CHECK(mask.is_coalesced(), "sparse_mask: mask is uncoalesced"); AT_CHECK(mask.sizes().equals(t.sizes()), "sparse_mask: operands have incompatible sizes; self has size ", t.sizes(), " but mask has size ", mask.sizes()); + AT_ASSERT(t.is_cuda()); // dispatch argument + AT_CHECK(mask.is_cuda(), "sparse_mask: expected 'mask' to be CUDA, but got CPU"); + AT_CHECK(r.is_cuda(), "sparse_mask: expected 'out' to be CUDA, but got CPU"); AT_CHECK(_check_device({r, t, mask}), "sparse_mask: arguments are located on different devices; self is on device ", t.get_device(), ", mask is on device ", mask.get_device(), ", out is on device ", r.get_device()); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index f56cb83e706024..eb369113dc899b 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -43,29 +43,31 @@ namespace { Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseTensor& sparse_, const Tensor& dense, Scalar beta, Scalar alpha) { #ifndef __HIP_PLATFORM_HCC__ + AT_ASSERT(t.is_cuda()); // dispatch argument + AT_CHECK(r_.is_cuda(), "addmm: expected 'out' to be CUDA, but got CPU"); + AT_CHECK(sparse_.is_cuda(), "addmm: expected 'mat1' to be CUDA, but got CPU"); + AT_CHECK(dense.is_cuda(), "addmm: expected 'mat2' to be CUDA, but got CPU"); + AT_CHECK(_check_device({sparse_, r_, t, dense})); - // THCudaIntTensor *csr; - // THCIndexTensor *indices; - // THCTensor *values, *r__, *dense_; // TODO: This error message seems awfully opaque - AT_CHECK(sparse_._sparseDims() == 2, "matrices expected, got ", sparse_._sparseDims(), "D tensor"); - AT_CHECK(sparse_._denseDims() == 0, "scalar values expected, got ", sparse_._denseDims(), "D values"); - AT_CHECK(dense.dim() == 2, "matrices expected, got ", dense.dim(), "D tensor"); + AT_CHECK(sparse_._sparseDims() == 2, "addmm: matrices expected, got ", sparse_._sparseDims(), "D tensor"); + AT_CHECK(sparse_._denseDims() == 0, "addmm: scalar values expected, got ", sparse_._denseDims(), "D values"); + AT_CHECK(dense.dim() == 2, "addmm: matrices expected, got ", dense.dim(), "D tensor"); // mxk * kxn = mxn int64_t m = sparse_.size(0); int64_t k = sparse_.size(1); int64_t n = dense.size(1); - r_.resize_({m, n}); - AT_CHECK(t.size(0) == m, - "Argument #1 (t): Expected dim 0 size ", m, ", got ", t.size(0)); + "addmm: Argument #1 (t): Expected dim 0 
size ", m, ", got ", t.size(0)); AT_CHECK(t.size(1) == n, - "Argument #1 (t): Expected dim 1 size ", n, ", got ", t.size(1)); + "addmm: Argument #1 (t): Expected dim 1 size ", n, ", got ", t.size(1)); AT_CHECK(dense.size(0) == k, - "Argument #3 (dense): Expected dim 0 size ", k, ", got ", dense.size(0)); + "addmm: Argument #3 (dense): Expected dim 0 size ", k, ", got ", dense.size(0)); + + r_.resize_({m, n}); SparseTensor sparse = sparse_.coalesce(); @@ -174,27 +176,32 @@ Tensor& s_addmm_sparse_dense_cuda_( SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse_, const Tensor& dense/* , Scalar alpha */) { #ifndef __HIP_PLATFORM_HCC__ - cudaStream_t stream = globalContext().getCurrentCUDAStream(); - auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); - auto policy = thrust::cuda::par(allocator).on(stream); + AT_ASSERT(sparse_.is_cuda()); // dispatch argument + AT_CHECK(r_.is_cuda(), "hspmm: expected 'out' to be CUDA, but got CPU"); + AT_CHECK(dense.is_cuda(), "hspmm: expected 'mat2' to be CUDA, but got CPU"); AT_CHECK(_check_device({r_, sparse_, dense})); AT_CHECK(sparse_._sparseDims() == 2, - "Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor"); + "hspmm: Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor"); AT_CHECK(sparse_._denseDims() == 0, - "Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values"); + "hspmm: Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values"); AT_CHECK(dense.dim() == 2, - "Argument #3: matrices expected, got ", dense.dim(), "D tensor"); + "hspmm: Argument #3: matrices expected, got ", dense.dim(), "D tensor"); int64_t m = sparse_.size(0); int64_t k = sparse_.size(1); int64_t n = dense.size(1); AT_CHECK(dense.size(0) == k, - "Argument #3: Expected dim 0 size ", k, ", got ", dense.size(0)); + "hspmm: Argument #3: Expected dim 0 size ", k, ", got ", dense.size(0)); + _get_sparse_impl(r_)->raw_resize_(1, 1, {m, n}); + cudaStream_t stream = globalContext().getCurrentCUDAStream(); + auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA()); + auto policy = thrust::cuda::par(allocator).on(stream); + SparseTensor sparse = sparse_.coalesce(); int64_t nnz = sparse._nnz(); @@ -240,8 +247,15 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR #ifndef __HIP_PLATFORM_HCC__ const SparseTensor& sparse = sparse_.tref; + AT_ASSERT(dense.is_cuda()); // dispatch argument + AT_CHECK(sparse.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); + AT_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); + AT_CHECK(_check_device({sparse, r_, dense})); + AT_CHECK(dense.sizes().equals(sparse.sizes()), "add: expected 'self' and 'other' to have same size, but self has size ", + dense.sizes(), " while other has size ", sparse.sizes(), " (FYI: dense-sparse addition does not currently support broadcasting)"); + const int64_t nnz = sparse._nnz(); if (nnz == 0) { r_.resize_as_(dense); @@ -254,7 +268,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR r_.resize_as_(dense); r_.copy_(dense); } else { - AT_CHECK(r_.is_contiguous(), "CUDA dense-sparse addition known bug"); + AT_CHECK(r_.is_contiguous(), "add: CUDA dense-sparse addition with a non-contiguous output tensor does not work; shout if you need it (see https://github.com/pytorch/pytorch/issues/1521 )"); r = r_.contiguous(); } @@ -271,7 +285,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, 
SparseTensorR cudaGetDevice(&curDevice); cudaStream_t stream = globalContext().getCurrentCUDAStreamOnDevice(curDevice); if (sparse._denseDims() == 0) { - AT_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "Argument #0: tensor too large or too many dimensions"); + AT_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); AT_DISPATCH_ALL_TYPES_AND_HALF( values.type(), "add_out_dense_sparse_cuda", [&] { @@ -282,7 +296,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR static_cast(nnz)); }); } else { - AT_CHECK(cuda::getApplyGrid(nnz * block.x, grid, curDevice), "Argument #0: tensor too large or too many dimensions"); + AT_CHECK(cuda::getApplyGrid(nnz * block.x, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); AT_DISPATCH_ALL_TYPES_AND_HALF( values.type(), "add_out_dense_sparse_cuda", [&] { @@ -294,8 +308,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR }); } } else { - LongTensor indices1D = _newFlattenedIndices(sparse, 0); - indices1D.resize_({nnz}); + LongTensor indices1D = _newFlattenedIndices(sparse, 0).squeeze_(0).narrow(0, 0, nnz); // FIXME: at some point we can wrap the scale into indexAdd // NB: Purposely not inplace! @@ -316,7 +329,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, SparseTensorR } Tensor r_view = r.view({view_rows, view_columns}); - values.resize_({nnz, view_columns}); + values = values.narrow(0, 0, nnz).reshape({nnz, view_columns}); r_view.index_add_(0, indices1D, values); } THCudaCheck(cudaGetLastError()); @@ -343,8 +356,12 @@ Tensor& add_dense_sparse_cuda_(Tensor& t, SparseTensorRef src, Scalar alpha) { SparseTensor& s_add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const SparseTensor& src, Scalar value) { #ifndef __HIP_PLATFORM_HCC__ + AT_ASSERT(t.is_cuda()); // dispatch argument + AT_CHECK(src.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); + AT_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); + AT_CHECK(_check_device({r_, t, src})); - AT_CHECK(t.sizes().equals(src.sizes()), "cadd operands have incompatible sizes"); + AT_CHECK(t.sizes().equals(src.sizes()), "add: expected 'self' and 'other' to have same size, but ", t.sizes(), " != ", src.sizes()); if (src._nnz() == 0) { return raw_copy_sparse_(r_, t); @@ -353,7 +370,7 @@ SparseTensor& s_add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, con return mul_out_sparse_scalar(r_, src, value); } - AT_CHECK(_is_same_density(t, src), "cadd operands have incompatible desnitities"); + AT_CHECK(_is_same_density(t, src), "add: expected 'self' and 'other' to have same density, but 'self' has ", t._sparseDims(), " sparse dimensions while 'other' has ", src._sparseDims(), " sparse dimensions"); // We deliberately choose to simply concat the indices and values tensors // rather than merging them. 
This removes the need to synchronously fetch nnz @@ -405,6 +422,10 @@ SparseTensor& s_add_sparse_cuda_(SparseTensor& t, const SparseTensor& src, Scala // -------------------------------------------------------------------- SparseTensor& s_sub_out_sparse_cuda(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, Scalar value) { + AT_ASSERT(t.is_cuda()); // dispatch argument + AT_CHECK(src.is_cuda(), "sub: expected 'other' to be CUDA, but got CPU"); + AT_CHECK(r.is_cuda(), "sub: expected 'out' to be CUDA, but got CPU"); + AT_DISPATCH_ALL_TYPES( t.type(), "sub_sparse", [&] { scalar_t cast_value = value.to(); @@ -430,8 +451,12 @@ SparseTensor& s_sub_sparse_cuda_(SparseTensor& t, const SparseTensor& src, Scala SparseTensor& s_mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, const SparseTensor& src_) { #ifndef __HIP_PLATFORM_HCC__ + AT_ASSERT(t_.is_cuda()); // dispatch argument + AT_CHECK(src_.is_cuda(), "mul: expected 'other' to be CUDA, but got CPU"); + AT_CHECK(r_.is_cuda(), "mul: expected 'out' to be CUDA, but got CPU"); + AT_CHECK(_check_device({r_, t_, src_})); - AT_CHECK(t_.sizes().equals(src_.sizes()), "mul operands have incompatible sizes"); + AT_CHECK(t_.sizes().equals(src_.sizes()), "mul: expected 'self' and 'other' to have same size, but ", t_.sizes(), " != ", src_.sizes()); SparseTensor t = t_.coalesce(); SparseTensor src = src_.coalesce(); @@ -459,7 +484,7 @@ SparseTensor& s_mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, co int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = globalContext().getCurrentCUDAStreamOnDevice(curDevice); - AT_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "Argument #0: tensor too large or too many dimensions"); + AT_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "mul: Argument #0: tensor too large or too many dimensions"); LongTensor resultNnz = at::empty({1}, CUDA(kLong)); AT_DISPATCH_ALL_TYPES_AND_HALF( diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py index c40309638580a7..13d852d2f1e148 100644 --- a/aten/src/ATen/native_parse.py +++ b/aten/src/ATen/native_parse.py @@ -124,12 +124,14 @@ def run(paths): fn_name, arguments = func_decl.split('(') arguments = arguments.split(')')[0] declaration['name'] = func.get('name', fn_name) - declaration['return'] = list(func.get('return', return_type)) + return_type = list(func.get('return', return_type)) + arguments = parse_arguments(arguments, func, declaration['name'], return_type) + output_arguments = [x for x in arguments if x.get('output')] + declaration['return'] = return_type if len(output_arguments) == 0 else output_arguments declaration['variants'] = func.get('variants', ['method', 'function']) declaration['deprecated'] = func.get('deprecated', False) declaration['device_guard'] = func.get('device_guard', True) - declaration['arguments'] = func.get('arguments', parse_arguments(arguments, func, - declaration['name'], declaration['return'])) + declaration['arguments'] = func.get('arguments', arguments) declaration['type_method_definition_dispatch'] = func.get('dispatch', declaration['name']) declaration['aten_sparse'] = has_sparse_dispatches( declaration['type_method_definition_dispatch']) diff --git a/aten/src/TH/THGeneral.h.in b/aten/src/TH/THGeneral.h.in index a7abd684c54d4d..9038dfb2b10929 100644 --- a/aten/src/TH/THGeneral.h.in +++ b/aten/src/TH/THGeneral.h.in @@ -56,6 +56,12 @@ # define TH_UNUSED #endif +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ 
__attribute__((no_sanitize("float-divide-by-zero"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#endif + #ifndef M_PI # define M_PI 3.14159265358979323846 #endif diff --git a/aten/src/TH/THStorage.cpp b/aten/src/TH/THStorage.cpp index a73681b7384f28..aa99e0d19e3d1a 100644 --- a/aten/src/TH/THStorage.cpp +++ b/aten/src/TH/THStorage.cpp @@ -17,17 +17,16 @@ void THStorage_free(THStorage *storage) { AT_ASSERT(storage->backend == at::kCPU); - if(!storage) + if (!storage) { return; + } - if((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount.load() > 0)) - { - if(--storage->refcount == 0) - { - if(storage->flag & TH_STORAGE_FREEMEM) { + if ((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount.load() > 0)) { + if (--storage->refcount == 0) { + if (storage->flag & TH_STORAGE_FREEMEM) { static_cast(storage->allocatorVoidPtr)->free(storage->allocatorContext, storage->data_ptr); } - if(storage->flag & TH_STORAGE_VIEW) { + if (storage->flag & TH_STORAGE_VIEW) { THStorage_free(storage->view); } storage->refcount.~atomic(); @@ -96,3 +95,139 @@ THStorage* THStorage_newWithAllocator(at::ScalarType scalar_type, ptrdiff_t size storage->device = INT_MIN; // device is not meaningful on CPU return storage; } + +ptrdiff_t THStorage_size(const THStorage *self) +{ + return self->size; +} + +size_t THStorage_elementSize(const THStorage *self) +{ + return at::elementSize(self->scalar_type); +} + +THStorage* THStorage_newWithMapping(at::ScalarType scalar_type, const char *filename, ptrdiff_t size, int flags) +{ + THMapAllocatorContext *ctx = THMapAllocatorContext_new(filename, flags); + + THStorage *storage = THStorage_newWithAllocator(scalar_type, size, + &THMapAllocator, + ctx); + + if (size <= 0) { + storage->size = THMapAllocatorContext_size(ctx)/THStorage_elementSize(storage); + } + + THStorage_clearFlag(storage, TH_STORAGE_RESIZABLE); + + return storage; +} + +void THStorage_setFlag(THStorage *storage, const char flag) +{ + storage->flag |= flag; +} + +void THStorage_clearFlag(THStorage *storage, const char flag) +{ + storage->flag &= ~flag; +} + +void THStorage_retain(THStorage *storage) +{ + if (storage && (storage->flag & TH_STORAGE_REFCOUNTED)) { + ++storage->refcount; + } +} + +int THStorage_retainIfLive(THStorage *storage) +{ + // TODO: Check if TH_STORAGE_REFCOUNTED? 
+ int refcount = storage->refcount.load(); + while (refcount > 0) { + if (storage->refcount.compare_exchange_strong(refcount, refcount + 1)) { + return 1; + } + refcount = storage->refcount.load(); + } + return 0; +} + +THStorage* THStorage_newWithData(at::ScalarType scalar_type, void *data, ptrdiff_t size) +{ + return THStorage_newWithDataAndAllocator(scalar_type, data, size, + &THDefaultAllocator, NULL); +} + +THStorage* THStorage_newWithDataAndAllocator(at::ScalarType scalar_type, + void* data, ptrdiff_t size, + THAllocator* allocator, + void* allocatorContext) { + THStorage *storage = static_cast(THAlloc(sizeof(THStorage))); + storage->backend = at::kCPU; + storage->scalar_type = scalar_type; + storage->data_ptr = data; + storage->size = size; + storage->refcount = 1; + storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + storage->allocatorVoidPtr = allocator; + storage->allocatorContext = allocatorContext; + storage->device = 0; + return storage; +} + +void THStorage_resize(THStorage *storage, ptrdiff_t size) +{ + AT_ASSERT(storage->backend == at::kCPU); + + auto* th_allocator = static_cast(storage->allocatorVoidPtr); + + if (storage->flag & TH_STORAGE_RESIZABLE) + { + if (th_allocator->realloc == nullptr) { + /* case when the allocator does not have a realloc defined */ + void *old_data = storage->data_ptr; + ptrdiff_t old_size = storage->size; + if (size == 0) { + storage->data_ptr = nullptr; + } else { + storage->data_ptr = th_allocator->malloc( + storage->allocatorContext, + at::elementSize(storage->scalar_type)*size); + } + storage->size = size; + if (old_data != nullptr) { + ptrdiff_t copy_size = old_size; + if (storage->size < copy_size) { + copy_size = storage->size; + } + if (copy_size > 0) { + memcpy(storage->data_ptr, old_data, at::elementSize(storage->scalar_type)*copy_size); + } + th_allocator->free(storage->allocatorContext, old_data); + } + } else { + storage->data_ptr = th_allocator->realloc( + storage->allocatorContext, + storage->data_ptr, + at::elementSize(storage->scalar_type)*size); + storage->size = size; + } + } else { + THError("Trying to resize storage that is not resizable"); + } +} + +void THStorage_swap(THStorage *storage1, THStorage *storage2) +{ +#define SWAP(val) { std::swap(storage1->val, storage2->val); } + SWAP(data_ptr); + SWAP(size); + SWAP(flag); + // don't swap refcount! 
+ SWAP(allocatorVoidPtr); + SWAP(allocatorContext); + SWAP(view); + SWAP(device); +#undef SWAP +} diff --git a/aten/src/TH/THStorage.h b/aten/src/TH/THStorage.h index 9c00bffb791ef8..954def8143c776 100644 --- a/aten/src/TH/THStorage.h +++ b/aten/src/TH/THStorage.h @@ -1,5 +1,4 @@ -#ifndef TH_STORAGE_INC -#define TH_STORAGE_INC +#pragma once #include "THGeneral.h" #include "THAllocator.h" @@ -23,5 +22,3 @@ TH_API void THStorage_free(THStorage *storage); TH_API THDescBuff THLongStorage_sizeDesc(const THLongStorage *size); TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement); - -#endif diff --git a/aten/src/TH/THStorage.hpp b/aten/src/TH/THStorage.hpp index d5a68ec181a664..f179d248b11d72 100644 --- a/aten/src/TH/THStorage.hpp +++ b/aten/src/TH/THStorage.hpp @@ -44,3 +44,18 @@ TH_API THStorage* THStorage_newWithSize(at::ScalarType scalar_type, ptrdiff_t si TH_API THStorage* THStorage_newWithAllocator(at::ScalarType scalar_type, ptrdiff_t size, THAllocator *allocator, void *allocatorContext); + +ptrdiff_t THStorage_size(const THStorage *self); +size_t THStorage_elementSize(); +THStorage* THStorage_newWithMapping(at::ScalarType scalar_type, const char *filename, ptrdiff_t size, int flags); +void THStorage_setFlag(THStorage *storage, const char flag); +void THStorage_clearFlag(THStorage *storage, const char flag); +void THStorage_retain(THStorage *storage); +int THStorage_retainIfLive(THStorage *storage); +THStorage* THStorage_newWithData(at::ScalarType scalar_type, void *data, ptrdiff_t size); +THStorage* THStorage_newWithDataAndAllocator(at::ScalarType scalar_type, + void* data, ptrdiff_t size, + THAllocator* allocator, + void* allocatorContext); +void THStorage_resize(THStorage *storage, ptrdiff_t size); +void THStorage_swap(THStorage *storage1, THStorage *storage2); diff --git a/aten/src/TH/generic/THStorage.cpp b/aten/src/TH/generic/THStorage.cpp index 7b163cbbfe491a..66e3a625a4a205 100644 --- a/aten/src/TH/generic/THStorage.cpp +++ b/aten/src/TH/generic/THStorage.cpp @@ -11,7 +11,7 @@ real* THStorage_(data)(const THStorage *self) ptrdiff_t THStorage_(size)(const THStorage *self) { - return self->size; + return THStorage_size(self); } size_t THStorage_(elementSize)() @@ -39,18 +39,7 @@ THStorage* THStorage_(newWithAllocator)(ptrdiff_t size, THStorage* THStorage_(newWithMapping)(const char *filename, ptrdiff_t size, int flags) { - THMapAllocatorContext *ctx = THMapAllocatorContext_new(filename, flags); - - THStorage *storage = THStorage_(newWithAllocator)(size, - &THMapAllocator, - ctx); - - if(size <= 0) - storage->size = THMapAllocatorContext_size(ctx)/sizeof(real); - - THStorage_(clearFlag)(storage, TH_STORAGE_RESIZABLE); - - return storage; + return THStorage_newWithMapping(at::CTypeToScalarType>::to(), filename, size, flags); } THStorage* THStorage_(newWithSize1)(real data0) @@ -93,31 +82,22 @@ THStorage* THStorage_(newWithSize4)(real data0, real data1, real data2, real dat void THStorage_(setFlag)(THStorage *storage, const char flag) { - storage->flag |= flag; + THStorage_setFlag(storage, flag); } void THStorage_(clearFlag)(THStorage *storage, const char flag) { - storage->flag &= ~flag; + THStorage_clearFlag(storage, flag); } void THStorage_(retain)(THStorage *storage) { - if(storage && (storage->flag & TH_STORAGE_REFCOUNTED)) - ++storage->refcount; + THStorage_retain(storage); } int THStorage_(retainIfLive)(THStorage *storage) { - // TODO: Check if TH_STORAGE_REFCOUNTED? 
- int refcount = storage->refcount.load(); - while (refcount > 0) { - if (storage->refcount.compare_exchange_strong(refcount, refcount + 1)) { - return 1; - } - refcount = storage->refcount.load(); - } - return 0; + return THStorage_retainIfLive(storage); } void THStorage_(free)(THStorage *storage) @@ -127,66 +107,18 @@ void THStorage_(free)(THStorage *storage) THStorage* THStorage_(newWithData)(real *data, ptrdiff_t size) { - return THStorage_(newWithDataAndAllocator)(data, size, - &THDefaultAllocator, NULL); + return THStorage_newWithData(at::CTypeToScalarType>::to(), data, size); } THStorage* THStorage_(newWithDataAndAllocator)(real* data, ptrdiff_t size, THAllocator* allocator, void* allocatorContext) { - THStorage *storage = static_cast(THAlloc(sizeof(THStorage))); - storage->backend = at::kCPU; - storage->scalar_type = at::CTypeToScalarType>::to(); - storage->data_ptr = data; - storage->size = size; - storage->refcount = 1; - storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - storage->allocatorVoidPtr = allocator; - storage->allocatorContext = allocatorContext; - storage->device = 0; - return storage; + return THStorage_newWithDataAndAllocator(at::CTypeToScalarType>::to(), data, size, allocator, allocatorContext); } void THStorage_(resize)(THStorage *storage, ptrdiff_t size) { - AT_ASSERT(storage->backend == at::kCPU); - - auto* th_allocator = static_cast(storage->allocatorVoidPtr); - - if(storage->flag & TH_STORAGE_RESIZABLE) - { - if(th_allocator->realloc == NULL) { - /* case when the allocator does not have a realloc defined */ - real *old_data = THStorage_(data)(storage); - ptrdiff_t old_size = storage->size; - if (size == 0) { - storage->data_ptr = NULL; - } else { - storage->data_ptr = th_allocator->malloc( - storage->allocatorContext, - sizeof(real)*size); - } - storage->size = size; - if (old_data != NULL) { - ptrdiff_t copy_size = old_size; - if (storage->size < copy_size) { - copy_size = storage->size; - } - if (copy_size > 0) { - memcpy(THStorage_(data)(storage), old_data, sizeof(real)*copy_size); - } - th_allocator->free(storage->allocatorContext, old_data); - } - } else { - storage->data_ptr = th_allocator->realloc( - storage->allocatorContext, - THStorage_(data)(storage), - sizeof(real)*size); - storage->size = size; - } - } else { - THError("Trying to resize storage that is not resizable"); - } + return THStorage_resize(storage, size); } void THStorage_(fill)(THStorage *storage, real value) @@ -210,24 +142,7 @@ real THStorage_(get)(const THStorage *self, ptrdiff_t idx) void THStorage_(swap)(THStorage *storage1, THStorage *storage2) { -#define SWAP(val) { val = storage1->val; storage1->val = storage2->val; storage2->val = val; } - void *data_ptr; - ptrdiff_t size; - char flag; - void *allocatorVoidPtr; - void *allocatorContext; - struct THStorage *view; - int device; - - SWAP(data_ptr); - SWAP(size); - SWAP(flag); - // don't swap refcount! 
- SWAP(allocatorVoidPtr); - SWAP(allocatorContext); - SWAP(view); - SWAP(device); -#undef SWAP + THStorage_swap(storage1, storage2); } #endif diff --git a/aten/src/TH/vector/AVX.cpp b/aten/src/TH/vector/AVX.cpp index 5fbc8bd806e652..b39b803c86c695 100644 --- a/aten/src/TH/vector/AVX.cpp +++ b/aten/src/TH/vector/AVX.cpp @@ -6,6 +6,7 @@ #endif #include "AVX.h" +#include "THGeneral.h" void THDoubleVector_copy_AVX(double *y, const double *x, const ptrdiff_t n) { ptrdiff_t i; @@ -36,7 +37,7 @@ void THDoubleVector_fill_AVX(double *x, const double c, const ptrdiff_t n) { } } -void THDoubleVector_cdiv_AVX(double *z, const double *x, const double *y, const ptrdiff_t n) { +void THDoubleVector_cdiv_AVX(double *z, const double *x, const double *y, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ { ptrdiff_t i; __m256d YMM0, YMM1, YMM2, YMM3; for (i=0; i<=((n)-8); i+=8) { @@ -54,7 +55,7 @@ void THDoubleVector_cdiv_AVX(double *z, const double *x, const double *y, const } } -void THDoubleVector_divs_AVX(double *y, const double *x, const double c, const ptrdiff_t n) { +void THDoubleVector_divs_AVX(double *y, const double *x, const double c, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ { ptrdiff_t i; __m256d YMM15 = _mm256_set_pd(c, c, c, c); __m256d YMM0, YMM1; @@ -168,7 +169,7 @@ void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n) { } } -void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n) { +void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ { ptrdiff_t i; __m256 YMM0, YMM1, YMM2, YMM3; for (i=0; i<=((n)-16); i+=16) { @@ -186,7 +187,7 @@ void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrd } } -void THFloatVector_divs_AVX(float *y, const float *x, const float c, const ptrdiff_t n) { +void THFloatVector_divs_AVX(float *y, const float *x, const float c, const ptrdiff_t n) __ubsan_ignore_float_divide_by_zero__ { ptrdiff_t i; __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c); __m256 YMM0, YMM1; diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index 933079bfa08ffb..90feef38fa1ec8 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -59,6 +59,7 @@ set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} INSTALL(FILES THC.h ${CMAKE_CURRENT_BINARY_DIR}/THCGeneral.h + THCGeneral.hpp THCBlas.h THCSleep.h THCStorage.h diff --git a/aten/src/THC/THCStorage.cpp b/aten/src/THC/THCStorage.cpp index 93d58bd379bf88..161cefd49d12c4 100644 --- a/aten/src/THC/THCStorage.cpp +++ b/aten/src/THC/THCStorage.cpp @@ -60,12 +60,6 @@ THCStorage* THCStorage_newWithAllocator(THCState *state, return storage; } -void THCStorage_retain(THCState *state, THCStorage *self) -{ - if(self && (self->flag & TH_STORAGE_REFCOUNTED)) - self->refcount++; -} - void THCStorage_free(THCState *state, THCStorage *self) { AT_ASSERT(self->backend == at::kCUDA); @@ -163,3 +157,35 @@ void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size) int THCStorage_getDevice(THCState* state, const THCStorage* storage) { return storage->device; } + +THCStorage* THCStorage_newWithData(THCState *state, at::ScalarType scalar_type, void *data, ptrdiff_t size) +{ + return THCStorage_newWithDataAndAllocator(state, scalar_type, data, size, + state->cudaDeviceAllocator, + state->cudaDeviceAllocator->state); +} + +THCStorage* THCStorage_newWithDataAndAllocator( + THCState *state, at::ScalarType scalar_type, void *data, ptrdiff_t size, + THCDeviceAllocator 
*allocator, void *allocatorContext) { + THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); + memset(storage, 0, sizeof(THCStorage)); + storage->backend = at::kCUDA; + storage->scalar_type = scalar_type; + storage->data_ptr = data; + storage->size = size; + storage->refcount = 1; + storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; + storage->allocatorVoidPtr = allocator; + storage->allocatorContext = allocatorContext; + int device; + if (data) { + struct cudaPointerAttributes attr; + THCudaCheck(cudaPointerGetAttributes(&attr, data)); + device = attr.device; + } else { + THCudaCheck(cudaGetDevice(&device)); + } + storage->device = device; + return storage; +} diff --git a/aten/src/THC/THCStorage.hpp b/aten/src/THC/THCStorage.hpp index f626d75c9daf16..10fe9a7025b3d3 100644 --- a/aten/src/THC/THCStorage.hpp +++ b/aten/src/THC/THCStorage.hpp @@ -33,3 +33,8 @@ THC_API void THCStorage_free(THCState *state, THCStorage *self); THC_API void THCStorage_resize(THCState *state, THCStorage *storage, ptrdiff_t size); THC_API int THCStorage_getDevice(THCState* state, const THCStorage* storage); + +THC_API THCStorage* THCStorage_newWithData(THCState *state, at::ScalarType scalar_type, void *data, ptrdiff_t size); +THC_API THCStorage* THCStorage_newWithDataAndAllocator( + THCState *state, at::ScalarType scalar_type, void *data, ptrdiff_t size, + THCDeviceAllocator *allocator, void *allocatorContext); diff --git a/aten/src/THC/THCTensor.cpp b/aten/src/THC/THCTensor.cpp index 4dfa763b6bd26a..598a6128a727b7 100644 --- a/aten/src/THC/THCTensor.cpp +++ b/aten/src/THC/THCTensor.cpp @@ -187,7 +187,7 @@ void THCTensor_setStorageNd(THCState *state, THCTensor *self, THCStorage *storag if(storage) { self->storage = storage; - THCStorage_retain(state, self->storage); + THStorage_retain(self->storage); } else self->storage = THCStorage_new(state, scalar_type); diff --git a/aten/src/THC/generic/THCStorage.cpp b/aten/src/THC/generic/THCStorage.cpp index 6be2aa9b009468..7e8ea52658ec72 100644 --- a/aten/src/THC/generic/THCStorage.cpp +++ b/aten/src/THC/generic/THCStorage.cpp @@ -9,7 +9,7 @@ real* THCStorage_(data)(THCState *state, const THCStorage *self) ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage *self) { - return self->size; + return THStorage_size(self); } int THCStorage_(elementSize)(THCState *state) @@ -98,62 +98,33 @@ THCStorage* THCStorage_(newWithMapping)(THCState *state, const char *fileName, p THCStorage* THCStorage_(newWithData)(THCState *state, real *data, ptrdiff_t size) { - return THCStorage_(newWithDataAndAllocator)(state, data, size, - state->cudaDeviceAllocator, - state->cudaDeviceAllocator->state); + return THCStorage_newWithData(state, at::CTypeToScalarType::to(), data, size); } THCStorage* THCStorage_(newWithDataAndAllocator)( THCState *state, real *data, ptrdiff_t size, THCDeviceAllocator *allocator, void *allocatorContext) { - THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); - memset(storage, 0, sizeof(THCStorage)); - storage->backend = at::kCUDA; - storage->scalar_type = at::CTypeToScalarType::to(); - storage->data_ptr = data; - storage->size = size; - storage->refcount = 1; - storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; - storage->allocatorVoidPtr = allocator; - storage->allocatorContext = allocatorContext; - int device; - if (data) { - struct cudaPointerAttributes attr; - THCudaCheck(cudaPointerGetAttributes(&attr, data)); - device = attr.device; - } else { - 
THCudaCheck(cudaGetDevice(&device)); - } - storage->device = device; - return storage; + return THCStorage_newWithDataAndAllocator(state, at::CTypeToScalarType::to(), data, size, allocator, allocatorContext); } void THCStorage_(setFlag)(THCState *state, THCStorage *storage, const char flag) { - storage->flag |= flag; + THStorage_setFlag(storage, flag); } void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char flag) { - storage->flag &= ~flag; + THStorage_clearFlag(storage, flag); } void THCStorage_(retain)(THCState *state, THCStorage *self) { - THCStorage_retain(state, self); + THStorage_retain(self); } int THCStorage_(retainIfLive)(THCState *state, THCStorage *storage) { - // TODO: Check if THC_STORAGE_REFCOUNTED? - int refcount = storage->refcount.load(); - while (refcount > 0) { - if (storage->refcount.compare_exchange_strong(refcount, refcount + 1)) { - return 1; - } - refcount = storage->refcount.load(); - } - return 0; + return THStorage_retainIfLive(storage); } void THCStorage_(free)(THCState *state, THCStorage *self) diff --git a/caffe2/core/dispatch/OpSchema_test.cpp b/caffe2/core/dispatch/OpSchema_test.cpp index 03e0f36f4fd206..77936a0347a041 100644 --- a/caffe2/core/dispatch/OpSchema_test.cpp +++ b/caffe2/core/dispatch/OpSchema_test.cpp @@ -22,8 +22,3 @@ static_assert(6 == OpSchema::signature::num_args, "test num_dispatch_ static_assert(3 == OpSchema::signature::num_tensor_args, "test num_dispatch_args"); static_assert(std::is_same::signature::return_type>::value, "test num_dispatch_args"); static_assert(std::is_same, float, Tensor, Tensor, unsigned int>, typename OpSchema::signature::parameter_types>::value, "test num_dispatch_args"); - -int main() { - return 0; -} - diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc index 8979a5364b435f..a74005904e1da8 100644 --- a/caffe2/onnx/backend.cc +++ b/caffe2/onnx/backend.cc @@ -335,6 +335,8 @@ Caffe2Backend::get_special_operators() const { const static std:: unordered_map kSpecialOperators = { + {"ArgMax", &Caffe2Backend::CreateArgMaxMin}, + {"ArgMin", &Caffe2Backend::CreateArgMaxMin}, {"Cast", &Caffe2Backend::CreateCast}, {"Constant", &Caffe2Backend::CreateConstant}, {"Conv", &Caffe2Backend::CreateConvPoolOpBase}, @@ -363,6 +365,17 @@ Caffe2Backend::get_special_operators() const { // Special Operator Converters //============================ +Caffe2Ops Caffe2Backend::CreateArgMaxMin( + OnnxNode* onnx_node, + int opset_version) { + auto& attributes = onnx_node->attributes; + if (!attributes.HasAttribute("axis")) { + auto* attr = attributes.AddRewrittenAttribute("axis"); + attr->set_i(0); + } + return CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version); +} + Caffe2Ops Caffe2Backend::CreateCast(OnnxNode* onnx_node, int opset_version) { auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version); diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h index 2c954e4adae9fa..e8a8ec3c65bc57 100644 --- a/caffe2/onnx/backend.h +++ b/caffe2/onnx/backend.h @@ -160,6 +160,8 @@ class Caffe2Backend { Caffe2Ops CommonOnnxNodeToCaffe2Ops(OnnxNode* onnx_node, int opset_version); + Caffe2Ops CreateArgMaxMin(OnnxNode* onnx_node, int opset_version); + Caffe2Ops CreateCast(OnnxNode* onnx_node, int opset_version); Caffe2Ops CreateConstant(OnnxNode* onnx_node, int opset_version); diff --git a/caffe2/onnx/onnx_exporter.cc b/caffe2/onnx/onnx_exporter.cc index 9190724248c8ec..2cfe8a2f6e9cf4 100644 --- a/caffe2/onnx/onnx_exporter.cc +++ b/caffe2/onnx/onnx_exporter.cc @@ -222,6 +222,8 @@ const std::unordered_map& 
OnnxExporter::get_special_operators() const { const static std::unordered_map kSpecialOperators = { + {"ArgMax", &OnnxExporter::CreateArgMaxMinOpNodes}, + {"ArgMin", &OnnxExporter::CreateArgMaxMinOpNodes}, {"Add", &OnnxExporter::CreateBinaryElementwiseOpNodes}, {"Sub", &OnnxExporter::CreateBinaryElementwiseOpNodes}, {"Mul", &OnnxExporter::CreateBinaryElementwiseOpNodes}, @@ -351,6 +353,25 @@ ConvertedResult OnnxExporter::CommonCaffe2OpToOnnxNodes( return result; } +ConvertedResult OnnxExporter::CreateArgMaxMinOpNodes( + const caffe2::OperatorDef& def, + const std::unordered_map& shapes) { + auto result = CommonCaffe2OpToOnnxNodes(def); + auto& nodes = result.first; + + CAFFE_ENFORCE_EQ(nodes.size(), 1); + auto& node = nodes.back(); + + if (!ArgumentHelper::HasArgument(def, "axis")) { + const auto& x = def.input(0); + const auto& x_shape = shapes.at(x); + node.add_attribute()->CopyFrom( + MakeAttribute("axis", x_shape.dims().size() - 1)); + } + + return result; +} + ConvertedResult OnnxExporter::CreateBinaryElementwiseOpNodes( const caffe2::OperatorDef& def, const std::unordered_map& shapes) { diff --git a/caffe2/onnx/onnx_exporter.h b/caffe2/onnx/onnx_exporter.h index d6c72c7df4b1ff..7fcd54044d9d66 100644 --- a/caffe2/onnx/onnx_exporter.h +++ b/caffe2/onnx/onnx_exporter.h @@ -52,6 +52,10 @@ class OnnxExporter { private: ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def); + ConvertedResult CreateArgMaxMinOpNodes( + const caffe2::OperatorDef& def, + const std::unordered_map& shapes); + ConvertedResult CreateBinaryElementwiseOpNodes( const caffe2::OperatorDef& def, const std::unordered_map& shapes); diff --git a/caffe2/operators/arg_max_op.cc b/caffe2/operators/arg_max_op.cc deleted file mode 100644 index 641384bea87395..00000000000000 --- a/caffe2/operators/arg_max_op.cc +++ /dev/null @@ -1,58 +0,0 @@ -#include "caffe2/operators/arg_max_op.h" - -#include "caffe2/core/operator.h" -#include "caffe2/core/types.h" - -namespace caffe2 { - -vector TensorInferenceForRowWiseArgMax( - const OperatorDef& /* def */, - const vector& in) { - std::vector output_dims(2); - output_dims[0] = in[0].dims(0); // N - output_dims[1] = 1; // 1 - return vector{ - CreateTensorShape(vector{output_dims}, TensorProto::INT64)}; -} - -template <> -bool RowWiseArgMaxOp::RunOnDevice() { - auto& X = Input(0); - auto* result = Output(0); - CAFFE_ENFORCE(X.ndim() == 2, "Input should be a 2D tensor"); - const int N = X.dim32(0); - const int D = X.dim32(1); - const float* X_data = X.data(); - result->Resize(N, 1); - int* result_data = result->mutable_data(); - for (int n = 0; n < N; ++n) { - float mx = X_data[n * D]; - int argmx = n * D; - for (int d = 1; d < D; ++d) { - int idx = n * D + d; - if (X_data[idx] > mx) { - mx = X_data[idx]; - argmx = idx; - } - result_data[n] = argmx - (n * D); - } - } - return true; -} - -// RowWiseArgMax -REGISTER_CPU_OPERATOR(RowWiseArgMax, RowWiseArgMaxOp); -OPERATOR_SCHEMA(RowWiseArgMax) - .NumInputs(1) - .NumOutputs(1) - .TensorInferenceFunction(TensorInferenceForRowWiseArgMax) - .SetDoc(R"DOC( - Given a 2D (N X D) input tensor, this operator returns a 2D (N X 1) output - tensor with with the index of the maximum value in each row. If there are - duplicate max values in a row the index of the first occurence is returned. 
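(The ArgMax/ArgMin converters above exist because of a defaulting mismatch: ONNX reduces over axis 0 when no axis attribute is given, while the Caffe2 ops implicitly reduce over the last axis, so the exporter writes an explicit last-axis attribute and the backend writes an explicit axis=0. A small numpy illustration of the two defaults; this is a reading of the code, not part of the patch.)

import numpy as np

x = np.array([[1, 2, 3],
              [6, 5, 4]])
print(np.argmax(x, axis=0))    # ONNX-style default (axis 0)       -> [1 1 1]
print(np.argmax(x, axis=-1))   # Caffe2-style default (last axis)  -> [2 0]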
- )DOC") - .Input(0, "X", "2D (N X D) input tensor") - .Output(0, "Z", "2D (N X 1) output tensor"); - -NO_GRADIENT(RowWiseArgMax); -} // namespace caffe2 diff --git a/caffe2/operators/arg_max_op.h b/caffe2/operators/arg_max_op.h deleted file mode 100644 index 916ded1f8fbbeb..00000000000000 --- a/caffe2/operators/arg_max_op.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef CAFFE2_OPERATORS_ARG_MAX_OP_H_ -#define CAFFE2_OPERATORS_ARG_MAX_OP_H_ - -#include "caffe2/core/context.h" -#include "caffe2/core/operator.h" - -namespace caffe2 { - -template -class RowWiseArgMaxOp : public Operator { - public: - RowWiseArgMaxOp(const OperatorDef& def, Workspace* ws) - : Operator(def, ws) {} - USE_OPERATOR_CONTEXT_FUNCTIONS; - - bool RunOnDevice() override; - - protected: - INPUT_TAGS(X_IN); - OUTPUT_TAGS(ROWWISE_ARGMAX_OUT); -}; - -} // namespace caffe2 - -#endif // CAFFE2_OPERATORS_DISTANCE_OP_H_ diff --git a/caffe2/python/operator_test/arg_max_test.py b/caffe2/python/operator_test/arg_max_test.py deleted file mode 100644 index 6cbbe2e7071ffa..00000000000000 --- a/caffe2/python/operator_test/arg_max_test.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -from caffe2.python import core, workspace -import caffe2.python.hypothesis_test_util as hu -from hypothesis import given -import hypothesis.strategies as st -import numpy as np - - -class TestArgMaxOp(hu.HypothesisTestCase): - def _test_op( - self, - op_name, - original_inp, - expected_values, - ): - op = core.CreateOperator( - op_name, - ['input'], - ['output'], - ) - workspace.FeedBlob('input', np.array(original_inp, dtype=np.float32)) - workspace.RunOperatorOnce(op) - np.testing.assert_array_equal( - workspace.FetchBlob('output'), - np.array(expected_values), - ) - - def test_rowwise_argmax_op_with_large_input(self): - X = [[1, 2, 3, 100, 3, 2, 1, 1.5, 1, 1, 1, 1, 1, 1.0], - [1, 2, 3, 1, 3, 2, 1, 1.5, 1, 100, 1, 1, 1, 1.0], - [1, 2, 3, 1, 3, 2, 1, 1.5, 1, 1, 100, 1, 1, 1.0], - [1, 2, 3, 1, 3, 2, 1, 1.5, 1, 1, 1, 100, 1, 1.0], - [1, 2, 3, 1, 3, 2, 1, 1.5, 1, 1, 1, 1, 100, 1.0], - [100, 2, 3, 1, 3, 2, 1, 1.5, 1, 1, 1, 1, 1, 1.0], - [1, 2, 3, 100, 3, 2, 1, 1.5, 1, 1, 1, 1, 1, 1.0], - [1, 2, 3, 1, 3, 2, 100, 1.5, 1, 1, 1, 1, 1, 1.0]] - - self._test_op( - op_name='RowWiseArgMax', - original_inp=X, - expected_values=[[3], [9], [10], [11], [12], [0], [3], [6]], - ) - - def test_rowwise_argmax_op_with_small_input(self): - X = [[4.2, 6, 3.1], - [10, 20, 40.4], - [100.01, 25, 3]] - - self._test_op( - op_name='RowWiseArgMax', - original_inp=X, - expected_values=[[1], [2], [0]], - ) - - def test_rowwise_argmax_with_duplicate_values(self): - X = [[2, 2], [3, 3]] - self._test_op( - op_name='RowWiseArgMax', - original_inp=X, - expected_values=[[0], [0]], - ) - - def test_rowwise_argmax_with_1x1_tensor(self): - X = [[1]] - self._test_op( - op_name='RowWiseArgMax', - original_inp=X, - expected_values=[[0]], - ) - - @given( - x=hu.tensor( - min_dim=2, max_dim=2, dtype=np.float32, - elements=st.integers(min_value=-100, max_value=100)), - ) - def test_rowwise_argmax_shape_inference(self, x): - workspace.FeedBlob('x', x) - - net = core.Net("rowwise_argmax_test") - result = net.RowWiseArgMax(['x']) - (shapes, types) = workspace.InferShapesAndTypes([net]) - workspace.RunNetOnce(net) - - self.assertEqual(shapes[result], list(workspace.blobs[result].shape)) - self.assertEqual(types[result], core.DataType.INT64) - - -if __name__ == "__main__": - 
import unittest - unittest.main() diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 3a9ed06862ea73..7c396da1b39d45 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1,6 +1,26 @@ +# UBSAN triggers when compiling protobuf, so we need to disable it. +set(UBSAN_FLAG "-fsanitize=undefined") + +macro(disable_ubsan) + if (CMAKE_C_FLAGS MATCHES ${UBSAN_FLAG} OR CMAKE_CXX_FLAGS MATCHES ${UBSAN_FLAG}) + set(CAFFE2_UBSAN_ENABLED ON) + string(REPLACE ${UBSAN_FLAG} "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) + string(REPLACE ${UBSAN_FLAG} "" CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + endif() +endmacro() + +macro(enable_ubsan) + if (CAFFE2_UBSAN_ENABLED) + set(CMAKE_C_FLAGS "${UBSAN_FLAG} ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${UBSAN_FLAG} ${CMAKE_CXX_FLAGS}") + endif() +endmacro() + # ---[ Custom Protobuf if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) + disable_ubsan() include(${CMAKE_CURRENT_LIST_DIR}/ProtoBuf.cmake) + enable_ubsan() endif() # ---[ Threads diff --git a/docs/source/scripts/build_activation_images.py b/docs/source/scripts/build_activation_images.py index dface307129d83..ce424d1ff188fa 100644 --- a/docs/source/scripts/build_activation_images.py +++ b/docs/source/scripts/build_activation_images.py @@ -67,14 +67,7 @@ def plot_function(function, **args): # Start a new plot pylab.clf() - - # Add an overlay on the background of faint traces of all the other - # functions. This is nice when flipping through images - for background_function in functions: - plot_function( - torch.nn.modules.activation.__dict__[background_function](), - alpha=0.03, color='k' - ) + pylab.grid(color='k', alpha=0.2, linestyle='--') # Plot the current function plot_function(function) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 5ff9f441d62572..16ab49c26e11b8 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -222,6 +222,7 @@ view of a storage and defines numeric operations on it. .. automethod:: expand_as .. automethod:: exponential_ .. automethod:: fill_ + .. automethod:: flip .. automethod:: float .. automethod:: floor .. automethod:: floor_ @@ -306,6 +307,7 @@ view of a storage and defines numeric operations on it. .. automethod:: ormqr .. automethod:: permute .. automethod:: pin_memory + .. automethod:: pinverse .. automethod:: potrf .. automethod:: potri .. automethod:: potrs diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 3c6e6aa367d89b..d71fd97b8c8e1f 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -253,6 +253,7 @@ Other Operations .. autofunction:: diagflat .. autofunction:: diagonal .. autofunction:: einsum +.. autofunction:: flip .. autofunction:: histc .. autofunction:: renorm .. autofunction:: trace @@ -288,6 +289,7 @@ BLAS and LAPACK Operations .. autofunction:: mv .. autofunction:: orgqr .. autofunction:: ormqr +.. autofunction:: pinverse .. autofunction:: potrf .. autofunction:: potri .. 
autofunction:: potrs diff --git a/setup.py b/setup.py index f80d33cd4fea54..7b1f484423bbf8 100644 --- a/setup.py +++ b/setup.py @@ -1060,9 +1060,9 @@ def make_relative_rpath(path): 'lib/include/ATen/cuda/detail/*.cuh', 'lib/include/pybind11/*.h', 'lib/include/pybind11/detail/*.h', - 'lib/include/TH/*.h', - 'lib/include/TH/generic/*.h', - 'lib/include/THC/*.h', + 'lib/include/TH/*.h*', + 'lib/include/TH/generic/*.h*', + 'lib/include/THC/*.h*', 'lib/include/THC/*.cuh', 'lib/include/THC/generic/*.h', 'lib/include/THCUNN/*.cuh', diff --git a/test/common.py b/test/common.py index d28c18b51af198..5debc66905624c 100644 --- a/test/common.py +++ b/test/common.py @@ -70,7 +70,9 @@ def run_tests(argv=UNITTEST_ARGS): TEST_MKL = torch.backends.mkl.is_available() -NO_MULTIPROCESSING_SPAWN = 'NO_MULTIPROCESSING_SPAWN' in os.environ +NO_MULTIPROCESSING_SPAWN = os.environ.get('NO_MULTIPROCESSING_SPAWN', '0') == '1' +TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', '0') == '1' +TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1' def skipIfNoLapack(fn): diff --git a/test/common_nn.py b/test/common_nn.py index 2f5b6a213b91fa..ba161b39f0b2eb 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -754,7 +754,10 @@ def _analytical_jacobian(self, module, input, jacobian_input=True, jacobian_para jacobian_param = torch.zeros(num_param, output_size) for i in range(output_size): - _, d_param = self._get_parameters(module) + param, d_param = self._get_parameters(module) + # make non grad zeros + d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)] + d_out = torch.zeros_like(output) flat_d_out = d_out.view(-1) flat_d_out[i] = 1 @@ -948,7 +951,16 @@ def noncontiguize(self, obj): return [self.noncontiguize(o) for o in obj] tensor = obj ndim = tensor.dim() - noncontig = torch.stack([torch.zeros_like(tensor), tensor], ndim).select(ndim, 1).detach() + # Always making only the last dimension noncontiguous is easy to hide + # bugs because .view(-1) will still work. So try to find a dim with size + # > 1 and make that non-contiguous, i.e., stack + select on the + # dimension directly after that. 
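# A standalone illustration of the stack+select trick described in the comment
# above, on a throwaway tensor (not part of the test suite): it yields a tensor
# with the same values as the input but a non-contiguous layout.
import torch
t = torch.arange(6.).reshape(2, 3)
d = 0                                                        # first dim with size > 1
nc = torch.stack([torch.empty_like(t), t], d + 1).select(d + 1, 1)
print(nc.is_contiguous())                                    # False
print(nc.eq(t).all())                                        # values preserved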
+ dim = ndim + for d in range(ndim): + if tensor.size(d) > 1: + dim = d + 1 + break + noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach() assert noncontig.numel() == 1 or not noncontig.is_contiguous() noncontig.requires_grad = tensor.requires_grad return noncontig @@ -977,11 +989,7 @@ def test_noncontig(self, test_case, module, input): test_case._zero_grad_parameters(module) test_case._zero_grad_input(i) with freeze_rng_state(): - try: - out = test_case._forward(module, i) - except Exception: - # Some modules will fail because of non contiguous inputs and we're ok with that - continue + out = test_case._forward(module, i) grad = test_case._backward(module, i, out, go) test_case.assertEqual(out, output) diff --git a/test/expect/TestCudaSparse.test_add_dense_sparse_mismatch.expect b/test/expect/TestCudaSparse.test_add_dense_sparse_mismatch.expect new file mode 100644 index 00000000000000..b6af4e9f4280b5 --- /dev/null +++ b/test/expect/TestCudaSparse.test_add_dense_sparse_mismatch.expect @@ -0,0 +1 @@ +add: expected 'self' and 'other' to have same size, but self has size [3, 4] while other has size [3, 4, 4] (FYI: dense-sparse addition does not currently support broadcasting) \ No newline at end of file diff --git a/test/expect/TestCudaSparse.test_log1p-backward.expect b/test/expect/TestCudaSparse.test_log1p-backward.expect new file mode 100644 index 00000000000000..8e4e1fc8c1c18f --- /dev/null +++ b/test/expect/TestCudaSparse.test_log1p-backward.expect @@ -0,0 +1 @@ +log1p of a sparse tensor is made to be non-differentiable since local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. Use a different mathematical operation which preserves sparsity of gradients, or report a bug if you think this is an error. \ No newline at end of file diff --git a/test/expect/TestCudaSparse.test_log1p-uncoalesced.expect b/test/expect/TestCudaSparse.test_log1p-uncoalesced.expect new file mode 100644 index 00000000000000..b25c0d3db02b2e --- /dev/null +++ b/test/expect/TestCudaSparse.test_log1p-uncoalesced.expect @@ -0,0 +1 @@ +log1p: in-place on uncoalesced tensors is not supported yet! \ No newline at end of file diff --git a/test/expect/TestCudaUncoalescedSparse.test_add_dense_sparse_mismatch.expect b/test/expect/TestCudaUncoalescedSparse.test_add_dense_sparse_mismatch.expect new file mode 100644 index 00000000000000..b6af4e9f4280b5 --- /dev/null +++ b/test/expect/TestCudaUncoalescedSparse.test_add_dense_sparse_mismatch.expect @@ -0,0 +1 @@ +add: expected 'self' and 'other' to have same size, but self has size [3, 4] while other has size [3, 4, 4] (FYI: dense-sparse addition does not currently support broadcasting) \ No newline at end of file diff --git a/test/expect/TestCudaUncoalescedSparse.test_log1p-backward.expect b/test/expect/TestCudaUncoalescedSparse.test_log1p-backward.expect new file mode 100644 index 00000000000000..8e4e1fc8c1c18f --- /dev/null +++ b/test/expect/TestCudaUncoalescedSparse.test_log1p-backward.expect @@ -0,0 +1 @@ +log1p of a sparse tensor is made to be non-differentiable since local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. Use a different mathematical operation which preserves sparsity of gradients, or report a bug if you think this is an error. 
\ No newline at end of file diff --git a/test/expect/TestCudaUncoalescedSparse.test_log1p-uncoalesced.expect b/test/expect/TestCudaUncoalescedSparse.test_log1p-uncoalesced.expect new file mode 100644 index 00000000000000..b25c0d3db02b2e --- /dev/null +++ b/test/expect/TestCudaUncoalescedSparse.test_log1p-uncoalesced.expect @@ -0,0 +1 @@ +log1p: in-place on uncoalesced tensors is not supported yet! \ No newline at end of file diff --git a/test/expect/TestSparse.test_add_dense_sparse_mismatch.expect b/test/expect/TestSparse.test_add_dense_sparse_mismatch.expect new file mode 100644 index 00000000000000..b6af4e9f4280b5 --- /dev/null +++ b/test/expect/TestSparse.test_add_dense_sparse_mismatch.expect @@ -0,0 +1 @@ +add: expected 'self' and 'other' to have same size, but self has size [3, 4] while other has size [3, 4, 4] (FYI: dense-sparse addition does not currently support broadcasting) \ No newline at end of file diff --git a/test/expect/TestSparse.test_log1p-backward.expect b/test/expect/TestSparse.test_log1p-backward.expect new file mode 100644 index 00000000000000..8e4e1fc8c1c18f --- /dev/null +++ b/test/expect/TestSparse.test_log1p-backward.expect @@ -0,0 +1 @@ +log1p of a sparse tensor is made to be non-differentiable since local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. Use a different mathematical operation which preserves sparsity of gradients, or report a bug if you think this is an error. \ No newline at end of file diff --git a/test/expect/TestSparse.test_log1p-uncoalesced.expect b/test/expect/TestSparse.test_log1p-uncoalesced.expect new file mode 100644 index 00000000000000..b25c0d3db02b2e --- /dev/null +++ b/test/expect/TestSparse.test_log1p-uncoalesced.expect @@ -0,0 +1 @@ +log1p: in-place on uncoalesced tensors is not supported yet! \ No newline at end of file diff --git a/test/expect/TestSparseOneOff.test_cuda_from_cpu.expect b/test/expect/TestSparseOneOff.test_cuda_from_cpu.expect new file mode 100644 index 00000000000000..fab1614da93d4a --- /dev/null +++ b/test/expect/TestSparseOneOff.test_cuda_from_cpu.expect @@ -0,0 +1 @@ +backend of indices (CUDA) must match backend of values (CPU) \ No newline at end of file diff --git a/test/expect/TestSparseOneOff.test_cuda_sparse_cpu_dense_add.expect b/test/expect/TestSparseOneOff.test_cuda_sparse_cpu_dense_add.expect new file mode 100644 index 00000000000000..77b0b500f3b692 --- /dev/null +++ b/test/expect/TestSparseOneOff.test_cuda_sparse_cpu_dense_add.expect @@ -0,0 +1 @@ +add: expected 'other' to be a CPU tensor, but got a CUDA tensor \ No newline at end of file diff --git a/test/expect/TestUncoalescedSparse.test_add_dense_sparse_mismatch.expect b/test/expect/TestUncoalescedSparse.test_add_dense_sparse_mismatch.expect new file mode 100644 index 00000000000000..b6af4e9f4280b5 --- /dev/null +++ b/test/expect/TestUncoalescedSparse.test_add_dense_sparse_mismatch.expect @@ -0,0 +1 @@ +add: expected 'self' and 'other' to have same size, but self has size [3, 4] while other has size [3, 4, 4] (FYI: dense-sparse addition does not currently support broadcasting) \ No newline at end of file diff --git a/test/expect/TestUncoalescedSparse.test_log1p-backward.expect b/test/expect/TestUncoalescedSparse.test_log1p-backward.expect new file mode 100644 index 00000000000000..8e4e1fc8c1c18f --- /dev/null +++ b/test/expect/TestUncoalescedSparse.test_log1p-backward.expect @@ -0,0 +1 @@ +log1p of a sparse tensor is made to be non-differentiable since local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. 
Use a different mathematical operation which preserves sparsity of gradients, or report a bug if you think this is an error. \ No newline at end of file diff --git a/test/expect/TestUncoalescedSparse.test_log1p-uncoalesced.expect b/test/expect/TestUncoalescedSparse.test_log1p-uncoalesced.expect new file mode 100644 index 00000000000000..b25c0d3db02b2e --- /dev/null +++ b/test/expect/TestUncoalescedSparse.test_log1p-uncoalesced.expect @@ -0,0 +1 @@ +log1p: in-place on uncoalesced tensors is not supported yet! \ No newline at end of file diff --git a/test/test_autograd.py b/test/test_autograd.py index 6e8a746aa4175d..ecdde404668504 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -2040,6 +2040,29 @@ def run_test(input_size, exponent): run_test((10, 10), torch.zeros(10, 10)) run_test((10,), 0) + def test_pinverse(self): + # Why is pinverse tested this way, and not ordinarily as other linear algebra methods? + # 1. Pseudo-inverses are not generally continuous, which means that they are not differentiable + # 2. Derivatives for pseudo-inverses exist typically for constant rank (Golub et al, 1973) + # 3. This method creates two orthogonal matrices, and a constructs a test case with large + # singular values (given by x to the function). + # 4. This will ensure that small perturbations don't affect the rank of matrix, in which case + # a derivative exists. + # 5. This test exists since pinverse is implemented using SVD, and is hence a backpropable method + m, n = 5, 10 + U = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n + V = torch.randn(n, m).qr()[0].t() # Orthogonal with dimensions m x n + + def func(x): + S = torch.cat([x, torch.zeros(n - m)], 0) + M = U.mm(torch.diag(S)).mm(V.t()) + return M.pinverse() + + gradcheck(func, [torch.rand(m).add_(1).requires_grad_()]) + gradcheck(func, [torch.rand(m).add_(10).requires_grad_()]) + gradgradcheck(func, [torch.rand(m).add_(1).requires_grad_()]) + gradgradcheck(func, [torch.rand(m).add_(10).requires_grad_()]) + def test_profiler(self): x = torch.randn(10, 10) diff --git a/test/test_cuda.py b/test/test_cuda.py index 6f6ac0f0826e5f..22703e2e18000d 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -104,7 +104,14 @@ def cast_tensor(tensor, t): def make_tensor(t, *sizes): - return t(*sizes).copy_(torch.randn(*sizes)) + if 'Half' in t.__name__: + return t(*sizes).copy_(torch.randn(*sizes)) + else: + tensor = t(*sizes) + if tensor.is_floating_point(): + return tensor.normal_() + else: + return tensor.random_(0, 10) def make_sparse_tensor(t, n, *sizes): @@ -408,7 +415,7 @@ def tmp(t): ('unsqueeze', new_t(2, 3, 4), lambda t: [2],), ('unsqueeze', new_t(2, 3, 4), lambda t: [-2], 'neg_dim'), ('view', small_3d, lambda t: [100, 10], 'contiguous'), - ('view_as', small_3d, lambda t: [t(100, 10)],), + ('view_as', small_3d, lambda t: [make_tensor(t, 100, 10)],), ('zero', small_3d, lambda t: [],), ('zeros', small_3d, lambda t: [1, 2, 3, 4],), ('eye', small_2d, lambda t: [3, 4],), @@ -1385,6 +1392,10 @@ def test_caching_pinned_memory_multi_gpu(self): def _select_broadcastable_dims(dims_full=None): return TestTorch._select_broadcastable_dims(dims_full) + @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") + def test_pinverse(self): + TestTorch._test_pinverse(self, lambda t: t.cuda()) + @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") def test_det_logdet_slogdet(self): TestTorch._test_det_logdet_slogdet(self, lambda t: t.cuda()) diff --git a/test/test_distributions.py b/test/test_distributions.py index 
e47c16d2f0347a..2f97370f713d8f 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -30,7 +30,7 @@ from random import shuffle import torch -from common import TestCase, run_tests, set_rng_seed +from common import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN from common_cuda import TEST_CUDA from torch.autograd import grad, gradcheck from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, @@ -597,14 +597,16 @@ def _gradcheck_log_prob(self, dist_ctor, ctor_params): # performs gradient checks on log_prob distribution = dist_ctor(*ctor_params) s = distribution.sample() + if s.is_floating_point(): + s.detach_().requires_grad_() expected_shape = distribution.batch_shape + distribution.event_shape self.assertEqual(s.size(), expected_shape) - def apply_fn(*params): + def apply_fn(s, *params): return dist_ctor(*params).log_prob(s) - gradcheck(apply_fn, ctor_params, raise_exception=True) + gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True) def _check_log_prob(self, dist, asset_fn): # checks that the log_prob matches a reference function @@ -1099,8 +1101,10 @@ def test_relaxed_one_hot_categorical_1d(self): def test_relaxed_one_hot_categorical_2d(self): probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]] probabilities_1 = [[1.0, 0.0], [0.0, 1.0]] - temp = torch.tensor([3.00], requires_grad=True) - temp_2 = torch.tensor([0.2], requires_grad=True) + temp = torch.tensor([3.0], requires_grad=True) + # The lower the temperature, the more unstable the log_prob gradcheck is + # w.r.t. the sample. Values below 0.25 empirically fail the default tol. + temp_2 = torch.tensor([0.25], requires_grad=True) p = torch.tensor(probabilities, requires_grad=True) s = torch.tensor(probabilities_1, requires_grad=True) self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3)) @@ -3704,6 +3708,7 @@ def test_valid(self): for i, param in enumerate(params): Dist(validate_args=True, **param) + @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") def test_invalid(self): for Dist, params in BAD_EXAMPLES: for i, param in enumerate(params): diff --git a/test/test_indexing.py b/test/test_indexing.py index 4c1fe642637cf7..18be819a641023 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -442,6 +442,36 @@ def test_boolean_indexing_twodim(self): [4, 0, 6], [0, 8, 0]])) + def test_boolean_indexing_weirdness(self): + # Weird boolean indexing things + a = torch.ones((2, 3, 4)) + if torch._C._use_zero_size_dim(): + self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape) + else: + self.assertEqual((0,), a[False, True, ...].shape) + self.assertEqual(torch.ones(1, 2), a[True, [0, 1], True, True, [1], [[2]]]) + if torch._C._use_zero_size_dim(): + self.assertRaises(RuntimeError, lambda: a[False, [0, 1], ...]) + + def test_boolean_indexing_weirdness_tensors(self): + # Weird boolean indexing things + false = torch.tensor(False) + true = torch.tensor(True) + a = torch.ones((2, 3, 4)) + if torch._C._use_zero_size_dim(): + self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape) + else: + self.assertEqual((0,), a[False, True, ...].shape) + self.assertEqual(torch.ones(1, 2), a[true, [0, 1], true, true, [1], [[2]]]) + if torch._C._use_zero_size_dim(): + self.assertRaises(RuntimeError, lambda: a[false, [0, 1], ...]) + + def test_boolean_indexing_alldims(self): + true = torch.tensor(True) + a = torch.ones((2, 3)) + self.assertEqual((1, 2, 3), a[True, True].shape) + self.assertEqual((1, 2, 3), a[true, true].shape) + def 
test_everything_returns_views(self): # Before `...` would return a itself. a = tensor([5]) diff --git a/test/test_jit.py b/test/test_jit.py index f186f42d48e3b8..453b6169fcfec8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9,7 +9,7 @@ from torch.autograd.function import traceable from torch.testing import assert_allclose from torch.onnx import OperatorExportTypes -from common import TestCase, run_tests, IS_WINDOWS +from common import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN from textwrap import dedent import os import io @@ -4732,6 +4732,45 @@ class TestJitGenerated(TestCase): pass +# UBSAN per-function exclusions don't seem to work with OpenMP pragmas, +# and we have to disable the failing tests here instead. +UBSAN_BLACKLISTED_TESTS = [ + "test___rdiv___constant", + "test___rdiv___scalar_constant", + "test_addcdiv", + "test_addcdiv_broadcast_all", + "test_addcdiv_broadcast_rhs", + "test_addcdiv_scalar", + "test_addcdiv_scalar_broadcast_lhs", + "test_addcdiv_scalar_broadcast_rhs", + "test_addcdiv_scalar_scale", + "test_addcdiv_scalar_scale_broadcast_lhs", + "test_addcdiv_scalar_scale_broadcast_rhs", + "test_addcdiv_scale", + "test_addcdiv_scale_broadcast_all", + "test_addcdiv_scale_broadcast_rhs", + "test_add_broadcast_all", + "test_add_broadcast_lhs", + "test_add_broadcast_rhs", + "test_add_constant", + "test_add_scalar", + "test_add_scalar_broadcast_lhs", + "test_add_scalar_broadcast_rhs", + "test_div", + "test_div_broadcast_all", + "test_div_broadcast_lhs", + "test_div_broadcast_rhs", + "test_div_scalar", + "test_div_scalar_broadcast_lhs", + "test_div_scalar_broadcast_rhs", + "test_rsqrt", + "test_rsqrt_scalar", + "test_add", + "test_reciprocal", + "test_reciprocal_scalar", +] + + def add_test( name, self_size, @@ -4811,7 +4850,8 @@ def fn(*inputs, **kwargs): for skip in skipTestIf: do_test = skip(do_test) - setattr(TestJitGenerated, test_name, do_test) + if not (TEST_WITH_UBSAN and test_name in UBSAN_BLACKLISTED_TESTS): + setattr(TestJitGenerated, test_name, do_test) for test in method_tests: add_test(*test) diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 62aca4026052d7..099f4d774ccc60 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -11,7 +11,7 @@ import torch.multiprocessing as mp from torch.autograd import Variable from torch.nn import Parameter -from common import TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN +from common import TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, TEST_WITH_ASAN TEST_REPEATS = 30 @@ -21,7 +21,6 @@ sys.platform != 'darwin' and \ sys.platform != 'win32' TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1 -TEST_WITH_ASAN = os.getenv('PYTORCH_TEST_WITH_ASAN', False) class SubProcess(mp.Process): diff --git a/test/test_nn.py b/test/test_nn.py index c25863edb15414..0b0e5ef1d29812 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -100,7 +100,7 @@ def _ordered_sequence(self, tensor_type): """Create ordered list of random sequences""" seqs = [tensor_type(random.randint(1, self.max_length)) for _ in range(self.batch_size)] - seqs = [s.random_() for s in seqs] + seqs = [s.random_(-128, 128) for s in seqs] ordered = sorted(seqs, key=len, reverse=True) return ordered @@ -510,8 +510,6 @@ def _get_parameters(self, module): params = [] d_params = [] for p in module.parameters(): - if p.grad is None: - p._grad = torch.zeros_like(p) params.append(p) d_params.append(p.grad) return params, d_params @@ -7136,16 +7134,33 @@ def 
multimarginloss_weights_no_reduce_test(): dict( module_name='Embedding', constructor_args=(4, 3), - input_fn=lambda: torch.randperm(2).repeat(1, 2), + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), jacobian_input=False, check_gradgrad=False, ), dict( module_name='EmbeddingBag', constructor_args=(4, 3), - input_fn=lambda:torch.randperm(2).repeat(1, 2), + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + jacobian_input=False, + check_gradgrad=False, + desc='mean', + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3, None, 2, False, 'sum'), + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), + jacobian_input=False, + check_gradgrad=False, + desc='sum', + ), + dict( + module_name='EmbeddingBag', + constructor_args=(4, 3, None, 2, False, 'max'), + input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4), jacobian_input=False, check_gradgrad=False, + desc='max', ), dict( fullname='EmbeddingBag_sparse', diff --git a/test/test_optim.py b/test/test_optim.py index 889b13b6894e30..57fe9e5da53944 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -10,7 +10,7 @@ from torch.autograd import Variable from torch import sparse from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau -from common import TestCase, run_tests +from common import TestCase, run_tests, TEST_WITH_UBSAN def rosenbrock(tensor): @@ -475,6 +475,7 @@ def test_lbfgs(self): ignore_multidevice=True ) + @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") def test_lbfgs_return_type(self): params = [torch.randn(10, 5), torch.randn(10)] opt1 = optim.LBFGS(params, 0.01, tolerance_grad=float('inf')) diff --git a/test/test_sparse.py b/test/test_sparse.py index 3c0efa35c0ce04..da316bd26a0179 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -2,6 +2,7 @@ from torch import sparse import itertools +import functools import random import unittest from common import TestCase, run_tests @@ -11,6 +12,7 @@ def cpu_only(inner): + @functools.wraps(inner) def outer(self, *args, **kwargs): if self.is_cuda: raise unittest.SkipTest("Test is CPU-only") @@ -19,6 +21,7 @@ def outer(self, *args, **kwargs): def cuda_only(inner): + @functools.wraps(inner) def outer(self, *args, **kwargs): if not self.is_cuda: raise unittest.SkipTest("Test is GPU-only") @@ -34,8 +37,10 @@ def setUp(self): # tests self.is_cuda = False self.is_uncoalesced = False + self.device = 'cpu' self.IndexTensor = torch.LongTensor self.ValueTensor = torch.DoubleTensor + self.value_dtype = torch.float64 self.SparseTensor = torch.sparse.DoubleTensor super(TestSparse, self).setUp() @@ -517,6 +522,27 @@ def _test_spadd_shape(self, shape_i, shape_v=None): self.assertEqual(res, expected) + x, i, v = self._gen_sparse(len(shape_i), 10, shape) + nnz = i.size(1) + + # Non contiguous sparse indices tensor + x_ = self.SparseTensor(i[:, ::2], v[:int(nnz / 2)], x.shape) + res = torch.add(y, r, x_) + expected = y + r * self.safeToDense(x_) + self.assertEqual(res, expected) + + # Non contiguous sparse values tensor + x_ = self.SparseTensor(i[:, :int(nnz / 2)], v[::2], x.shape) + res = torch.add(y, r, x_) + expected = y + r * self.safeToDense(x_) + self.assertEqual(res, expected) + + # Non contiguous sparse indices and values tensors + x_ = self.SparseTensor(i[:, 1::2], v[1::2], x.shape) + res = torch.add(y, r, x_) + expected = y + r * self.safeToDense(x_) + self.assertEqual(res, expected) + def test_spadd(self): self._test_spadd_shape([5, 
6]) self._test_spadd_shape([10, 10, 10]) @@ -609,6 +635,13 @@ def test_basic_ops_hybrid(self): self._test_basic_ops_shape([50, 30, 20], [2]) self._test_basic_ops_shape([5, 5, 5, 5, 5, 5], [2]) + def test_add_dense_sparse_mismatch(self): + x = torch.zeros([3, 4], dtype=self.value_dtype, device=self.device) + sparse_y = self.SparseTensor(torch.zeros(1, 4, dtype=torch.int64, device=self.device), + torch.randn(4, 4, 4, dtype=self.value_dtype, device=self.device), + torch.Size([3, 4, 4])) + self.assertExpectedRaises(RuntimeError, lambda: x + sparse_y) + def _test_sparse_mask_shape(self, shape_i, shape_v=None): shape = shape_i + (shape_v or []) x1, _, _ = self._gen_sparse(len(shape_i), 9, shape) @@ -676,9 +709,7 @@ def test_log1p(self): self.assertEqual(expected_output, input.coalesce().log1p_().to_dense()) # test in-place op on uncoalesced input - with self.assertRaisesRegex(RuntimeError, - "in-place log1p on uncoalesced tensors is not supported yet!"): - input.log1p_() + self.assertExpectedRaises(RuntimeError, lambda: input.log1p_(), subname="uncoalesced") input.requires_grad_() self.assertTrue(input.requires_grad) @@ -686,9 +717,7 @@ def test_log1p(self): # test autograd x = input.clone() y = input.log1p() - with self.assertRaisesRegex(RuntimeError, - "log1p of a sparse tensor is made to be non-differentiable since.*"): - y.backward(x) + self.assertExpectedRaises(RuntimeError, lambda: y.backward(x), subname="backward") # test uncoalesced input input_uncoalesced = torch.sparse.DoubleTensor( @@ -1064,6 +1093,7 @@ class TestCudaSparse(TestSparse): def setUp(self): super(TestCudaSparse, self).setUp() self.is_cuda = True + self.device = 'cuda' self.IndexTensor = torch.cuda.LongTensor self.ValueTensor = torch.cuda.DoubleTensor self.SparseTensor = torch.cuda.sparse.DoubleTensor @@ -1075,5 +1105,24 @@ def setUp(self): super(TestCudaUncoalescedSparse, self).setUp() self.is_uncoalesced = True + +class TestSparseOneOff(TestCase): + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_cuda_from_cpu(self): + self.assertExpectedRaises( + RuntimeError, + lambda: torch.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), + torch.randn(4, 4, 4), + [3, 4, 4])) + + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_cuda_sparse_cpu_dense_add(self): + x = torch.zeros(3, 4, 4) + sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), + torch.randn(4, 4, 4).cuda(), + [3, 4, 4]) + self.assertExpectedRaises(RuntimeError, lambda: x + sparse_y) + + if __name__ == '__main__': run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index 10b4491fd1b012..d568d7a9b7c143 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1514,7 +1514,10 @@ def test_einsum(self): def do_einsum(*args): return torch.einsum(test[0], args) - self.assertTrue(torch.autograd.gradcheck(do_einsum, test[1:])) + # FIXME: following test cases fail gradcheck + if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}: + gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:]) + self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps)) self.assertTrue(A._version == 0) # check that we do not use inplace ops def test_sum_all(self): @@ -4041,6 +4044,33 @@ def test_inverse(self): self.assertFalse(MII.is_contiguous(), 'MII is contiguous') self.assertEqual(MII, MI, 0, 'inverse value in-place') + @staticmethod + def _test_pinverse(self, conv_fn): + def run_test(M): + # Testing against definition for pseudo-inverses + MPI = torch.pinverse(M) + self.assertEqual(M, M.mm(MPI).mm(M), 1e-8, 
'pseudo-inverse condition 1') + self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8, 'pseudo-inverse condition 2') + self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8, 'pseudo-inverse condition 3') + self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8, 'pseudo-inverse condition 4') + + # Square matrix + M = conv_fn(torch.randn(5, 5)) + run_test(M) + + # Rectangular matrix + M = conv_fn(torch.randn(3, 4)) + run_test(M) + + # Test inverse and pseudo-inverse for invertible matrix + M = torch.randn(5, 5) + M = conv_fn(M.mm(M.t())) + self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7, 'pseudo-inverse for invertible matrix') + + @skipIfNoLapack + def test_pinverse(self): + self._test_pinverse(self, conv_fn=lambda x: x) + @staticmethod def _test_det_logdet_slogdet(self, conv_fn): def reference_det(M): @@ -6153,6 +6183,17 @@ def _test_flip(self, use_cuda=False): self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), expanded_data.flip(0)) self.assertEqual(torch.tensor([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2), tranposed_data.flip(0, 1, 2)) + # test rectangular case + data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3) + flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]) + flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]) + if use_cuda: + data = data.cuda() + flip0_result = flip0_result.cuda() + flip1_result = flip1_result.cuda() + self.assertEqual(flip0_result, data.flip(0)) + self.assertEqual(flip1_result, data.flip(1)) + def test_flip(self): self._test_flip(self, use_cuda=False) @@ -7712,6 +7753,25 @@ def test_is_nonzero(self): self.assertFalse(torch.tensor([[0]]).is_nonzero()) self.assertTrue(torch.tensor([[1]]).is_nonzero()) + def test_meshgrid(self): + a = torch.tensor(1) + b = torch.tensor([1, 2, 3]) + c = torch.tensor([1, 2]) + grid_a, grid_b, grid_c = torch.meshgrid([a, b, c]) + self.assertEqual(grid_a.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_b.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_c.shape, torch.Size([1, 3, 2])) + expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64) + expected_grid_b = torch.tensor([[[1, 1], + [2, 2], + [3, 3]]]) + expected_grid_c = torch.tensor([[[1, 2], + [1, 2], + [1, 2]]]) + self.assertTrue(grid_a.equal(expected_grid_a)) + self.assertTrue(grid_b.equal(expected_grid_b)) + self.assertTrue(grid_c.equal(expected_grid_c)) + # Functions to test negative dimension wrapping METHOD = 1 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7caaced8bdb70b..cfb48fe1f33312 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -342,7 +342,7 @@ self: -at::mm(output.t(), at::mm(grad, output.t())) - name: kthvalue(Tensor self, int64_t k, int64_t dim, bool keepdim) - self: index_select_backward(grad, dim, indices, self.sizes(), keepdim) + self: index_select_backward(grad, dim, result1, self.sizes(), keepdim) - name: le_(Tensor self, Scalar other) self: zeros_like(self) @@ -407,7 +407,7 @@ self: zeros_like(self).masked_scatter_(mask, grad) - name: max(Tensor self, int64_t dim, bool keepdim) - self: index_select_backward(grad, dim, max_indices, self.sizes(), keepdim) + self: index_select_backward(grad, dim, result1, self.sizes(), keepdim) - name: max(Tensor self) self: select_equals_backward(grad, self, result) @@ -440,10 +440,10 @@ # The backward implementation is correct in the sense that it returns the # subgradient on one side. 
- name: median(Tensor self, int64_t dim, bool keepdim) - self: index_select_backward(grad, dim, indices, self.sizes(), keepdim) + self: index_select_backward(grad, dim, result1, self.sizes(), keepdim) - name: min(Tensor self, int64_t dim, bool keepdim) - self: index_select_backward(grad, dim, min_indices, self.sizes(), keepdim) + self: index_select_backward(grad, dim, result1, self.sizes(), keepdim) - name: min(Tensor self) self: select_equals_backward(grad, self, result) @@ -457,7 +457,7 @@ mat2: mm_mat2_backward(grad, self, mat2.sizes(), mat2.strides(), 1) - name: mode(Tensor self, int64_t dim, bool keepdim) - self: index_select_backward(grad, dim, indices, self.sizes(), keepdim) + self: index_select_backward(grad, dim, result1, self.sizes(), keepdim) - name: mul(Tensor self, Scalar other) self: grad * other @@ -769,8 +769,8 @@ - name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse) -- name: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) - weight: embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse) +- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse) + weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse) - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type) self: not_implemented("embedding_renorm") diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 858d93280da2fc..38bc0af4f929a2 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -24,7 +24,7 @@ '_arange.*', '_range.*', '_linspace.*', '_logspace.*', 'index', '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin', - '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_sum.*', '_th_prod.*', + '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*', '_th_*', 'arange.*', 'range.*', '_gesv.*', 'slice', 'max_pool1d', 'max_pool2d', 'max_pool3d' ] diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 5d617983f47f75..4e6fc46201db93 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -2522,3 +2522,10 @@ def callable(a, b) -> number See :func:`torch.slogdet` """) + +add_docstr_all('pinverse', + r""" +pinverse() -> Tensor + +See :func:`torch.pinverse` +""") diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b0db3911bb2ad3..675770d2104cb0 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5264,6 +5264,52 @@ def parse_kwargs(desc): (tensor(-1.), tensor(1.5731)) """) +add_docstr(torch.pinverse, + r""" +pinverse(input, rcond=1e-15) -> Tensor + +Calculates the pseudo-inverse (also known as the Moore-Penrose inverse) of a 2D tensor. +Please look at `Moore-Penrose inverse`_ for more details. + +.. note:: + This method is implemented using the Singular Value Decomposition. + +.. note:: + The pseudo-inverse is not necessarily a continuous function in the elements of the matrix `[1]`_. + Therefore, derivatives do not always exist, and exist only for a constant rank `[2]`_. + However, this method is backprop-able because it is implemented using the SVD results, and + could be unstable.
Double-backward will also be unstable due to the use of SVD internally. + See :meth:`~torch.svd` for more details. + +Arguments: + input (Tensor): The input 2D tensor of dimensions :math:`m \times n` + rcond (float): A floating point value to determine the cutoff for small singular values. + Default: 1e-15 + +Returns: + The pseudo-inverse of :attr:`input` of dimensions :math:`n \times m` + +Example:: + + >>> input = torch.randn(3, 5) + >>> input + tensor([[ 0.5495, 0.0979, -1.4092, -0.1128, 0.4132], + [-1.1143, -0.3662, 0.3042, 1.6374, -0.9294], + [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) + >>> torch.pinverse(input) + tensor([[ 0.0600, -0.1933, -0.2090], + [-0.0903, -0.0817, -0.4752], + [-0.7124, -0.1631, -0.2272], + [ 0.1356, 0.3933, -0.5023], + [-0.0308, -0.1725, -0.5216]]) + +.. _Moore-Penrose inverse: https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse + +.. _[1]: https://epubs.siam.org/doi/10.1137/0117004 + +.. _[2]: https://www.jstor.org/stable/2156365 +""") + add_docstr(torch.fft, r""" fft(input, signal_ndim, normalized=False) -> Tensor @@ -5787,3 +5833,36 @@ def parse_kwargs(desc): Tensor: A 1-D tensor of size :math:`(\text{{window_length}},)` containing the window """.format(**factory_common_args)) + + +add_docstr(torch.meshgrid, + r""" +meshgrid(seq) -> seq + +Take a sequence of :math:`N` tensors, each of which can be either a scalar or a 1-dimensional +vector, and create :math:`N` N-dimensional grids, where the :math:`i`th grid is defined by +expanding the :math:`i`th input over dimensions defined by the other inputs. + +Arguments: + seq (sequence of Tensors): sequence of scalars or 1-dimensional tensors. Scalars will be + treated as tensors of size :math:`(1,)` automatically. + +Returns: + seq (sequence of Tensors): If the input has :math:`k` tensors of size + :math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output will also have :math:`k` tensors, + where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`. + +Example:: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([4, 5, 6]) + >>> grid_x, grid_y = torch.meshgrid([x, y]) + >>> grid_x + tensor([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + >>> grid_y + tensor([[4, 5, 6], + [4, 5, 6], + [4, 5, 6]]) +""") diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 7e020f3f884755..8697497a77bdd9 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -170,15 +170,22 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True tupled_inputs = _as_tuple(inputs) # Make sure that gradients are saved for all inputs + any_input_requiring_grad = False for inp in tupled_inputs: if isinstance(inp, torch.Tensor): - if inp.requires_grad and inp.dtype != torch.float64: - warnings.warn( - 'At least one of the inputs that requires gradient \ - is not of double precision floating point. ' - 'This check will likely fail if all the inputs are not of \ - double precision floating point. ') + if inp.requires_grad: + if inp.dtype != torch.float64: + warnings.warn( + 'At least one of the inputs that requires gradient ' + 'is not of double precision floating point. ' + 'This check will likely fail if all the inputs are ' + 'not of double precision floating point. 
') + any_input_requiring_grad = True inp.retain_grad() + if not any_input_requiring_grad: + raise ValueError( + 'gradcheck expects at least one input tensor to require gradient, ' + 'but none of them have requires_grad=True.') output = _differentiable_outputs(func(*inputs)) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 56fc925fddc7bf..6dd26948dff03e 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -126,8 +126,16 @@ static PyObject * THPModule_crashIfCsrcASAN(PyObject *module, PyObject *arg) { return PyLong_FromLong(x[0]); } +static PyObject * THPModule_crashIfCsrcUBSAN(PyObject *module, PyObject *arg) { + THPUtils_assert(THPUtils_checkLong(arg), "crash_if_csrc_ubsan expects an int, " + "but got %s", THPUtils_typename(arg)); + int32_t x = static_cast<int32_t>(THPUtils_unpackLong(arg)); + double y = 1.0 / x; + return PyLong_FromLong((int)y); +} + static PyObject * THPModule_crashIfATenASAN(PyObject *module, PyObject *arg) { - THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, " + THPUtils_assert(THPUtils_checkLong(arg), "crash_if_aten_asan expects an int, " "but got %s", THPUtils_typename(arg)); return PyLong_FromLong(at::_crash_if_asan(static_cast<int32_t>(THPUtils_unpackLong(arg)))); } @@ -401,6 +409,7 @@ static PyMethodDef TorchMethods[] = { {"_set_default_dtype", (PyCFunction)THPModule_setDefaultDtype, METH_O, NULL}, {"_infer_size", (PyCFunction)THPModule_inferSize, METH_VARARGS, NULL}, {"_crash_if_csrc_asan", (PyCFunction)THPModule_crashIfCsrcASAN, METH_O, NULL}, + {"_crash_if_csrc_ubsan", (PyCFunction)THPModule_crashIfCsrcUBSAN, METH_O, NULL}, {"_crash_if_aten_asan", (PyCFunction)THPModule_crashIfATenASAN, METH_O, NULL}, {"_set_backcompat_broadcast_warn", (PyCFunction)THPModule_setBackcompatBroadcastWarn, METH_O, NULL}, {"_get_backcompat_broadcast_warn", (PyCFunction)THPModule_getBackcompatBroadcastWarn, METH_NOARGS, NULL}, diff --git a/torch/csrc/PtrWrapper.cpp b/torch/csrc/PtrWrapper.cpp index 38b4fb35ae63a8..544895bfd0a44d 100644 --- a/torch/csrc/PtrWrapper.cpp +++ b/torch/csrc/PtrWrapper.cpp @@ -1,4 +1,5 @@ #include "torch/csrc/python_headers.h" +#include "ATen/Utils.h" #include static PyObject* THPWrapperClass = NULL; @@ -44,7 +45,8 @@ static PyObject * THPWrapper_pynew(PyTypeObject *type, PyObject *args, PyObject return self; } -static void THPWrapper_dealloc(THPWrapper* self) +// UBSAN error: https://github.com/pytorch/pytorch/issues/9054 +static void THPWrapper_dealloc(THPWrapper* self) __ubsan_ignore_function__ { self->destructor(self->data); Py_TYPE(self)->tp_free((PyObject*)self); diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 1c2672b9d7e428..6573650f209358 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -52,7 +52,7 @@ static int64_t count_specified_dimensions(PyObject* index) { } else { count++; } - } else if (obj != Py_None && obj != Py_Ellipsis) { + } else if (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && obj != Py_False) { count++; } } @@ -122,6 +122,15 @@ static Variable valueToTensor(const Type & type, PyObject* value) { throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString()); } +static Variable boolToIndexingTensor(const Variable& self, bool value) { + // Booleans add a dimension of size 1: true indexes this dimension as if it were 0:, false as if it were an empty slice. 
+ if (value) { + return at::zeros({1}, self.options().dtype(kLong)); + } else { + return at::empty({0}, self.options().dtype(kLong)); + } +} + static Variable applySlicing(const Variable& self, PyObject* index, variable_list& outIndices) { int64_t size = PyTuple_GET_SIZE(index); int64_t dim = 0; @@ -157,11 +166,19 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis } else if (obj == Py_None) { result = result.unsqueeze(dim); dim++; + } else if (PyBool_Check(obj)) { + result = result.unsqueeze(dim); + handle_var(boolToIndexingTensor(result, obj == Py_True)); } else if (THPVariable_Check(obj)) { auto& var = THPVariable_Unpack(obj); auto scalar_type = var.type().scalarType(); - if (var.dim() == 0 && at::isIntegralType(scalar_type) && scalar_type != at::kByte) { - result = applySelect(result, dim, THPUtils_unpackLong(obj)); + if (var.dim() == 0 && at::isIntegralType(scalar_type)) { + if (scalar_type != at::kByte) { + result = applySelect(result, dim, THPUtils_unpackLong(obj)); + } else { + result = result.unsqueeze(dim); + handle_var(boolToIndexingTensor(result, var.toCByte() != 0)); + } } else { handle_var(var); } @@ -259,18 +276,6 @@ static THPObjectPtr wrapTuple(PyObject* index) { return res; } -static bool isSingleBoolScalar(const variable_list& vars) { - return vars.size() == 1 && vars[0].type().scalarType() == ScalarType::Byte && vars[0].dim() == 0; -} - -static PyObject* applyBoolGetitem(const Variable& self, bool index) { - if (index) { - return wrap(self.type().copy(self.unsqueeze(0))); - } else { - return wrap(at::empty({0}, self.options())); - } -} - PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; @@ -283,8 +288,6 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { return wrap(at::alias(self_)); } else if (THPUtils_checkLong(index)) { return wrap(applySelect(self_, 0, THPUtils_unpackLong(index))); - } else if (PyBool_Check(index)) { - return applyBoolGetitem(self_, index == Py_True); } else if (PySlice_Check(index)) { return wrap(applySlice(self_, 0, index, true)); } @@ -301,9 +304,6 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { } return wrap(sliced); } - if (isSingleBoolScalar(variableIndices)) { - return applyBoolGetitem(self_, variableIndices[0].toCByte()); - } // indexing by tensors ("advanced" indexing) return wrap(dispatch_index(sliced, variableIndices)); @@ -311,22 +311,25 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { END_HANDLE_TH_ERRORS } -static void copy_to(Variable dst, const Variable& src) { - Tensor b_src; - // To match numpy semantics: - // As a special case for backwards compatibility, - // strip away unit dimensions from the left of 'src' - auto src_sizes = src.sizes(); - size_t first_nonzero_src = src_sizes.size(); - for (size_t i = 0; i < src_sizes.size(); ++i) { - if (src_sizes[i] != 1) { - first_nonzero_src = i; +// To match numpy semantics: +// As a special case for backwards compatibility, +// strip away unit dimensions from the left of 'src' +static IntList slicePrefix1sSize(IntList sizes) { + size_t first_non1_src = sizes.size(); + for (size_t i = 0; i < sizes.size(); ++i) { + if (sizes[i] != 1) { + first_non1_src = i; break; } } - src_sizes = src_sizes.slice(first_nonzero_src); - std::tie(b_src) = expand_inplace(dst, src.view(src_sizes), "setitem"); + return sizes.slice(first_non1_src); +} + +static void copy_to(Variable dst, const Variable& src) { + Tensor b_src; + IntList 
sliced_src_sizes = slicePrefix1sSize(src.sizes()); + std::tie(b_src) = expand_inplace(dst, src.view(sliced_src_sizes), "setitem"); dst.copy_(b_src); } @@ -364,15 +367,10 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { copy_to(sliced, value); return 0; } - if (isSingleBoolScalar(variableIndices)) { - if (variableIndices[0].toCByte()) { - copy_to(self_.unsqueeze(0), value); - } - return 0; - } - // indexing by tensors ("advanced" indexing) - dispatch_index_put_(sliced, variableIndices, value); + IntList slicedValueSizes = slicePrefix1sSize(value.sizes()); + auto valuesSliced = value.view(slicedValueSizes); + dispatch_index_put_(sliced, variableIndices, valuesSliced); return 0; END_HANDLE_TH_ERRORS_RET(-1) } diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 5284207988cf07..de6de04f4045d9 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -59,7 +59,7 @@ static PyObject * THPStorage_(newTHView)(THWStorage *base, ptrdiff_t offset, siz #ifndef THC_GENERIC_FILE // TODO: move this somewhere - we only need one version static std::string THPStorage_(__newHandle)() { - std::random_device rd; + static std::random_device rd; std::string handle = "/torch_"; #ifdef _MSC_VER handle += std::to_string(GetCurrentProcessId()); diff --git a/torch/csrc/jit/attributes.h b/torch/csrc/jit/attributes.h index 58da7794b7e932..2e27635de9a947 100644 --- a/torch/csrc/jit/attributes.h +++ b/torch/csrc/jit/attributes.h @@ -5,6 +5,7 @@ #include #include #include +#include "ATen/Utils.h" #include "torch/csrc/jit/interned_strings.h" #include "torch/csrc/assertions.h" @@ -195,7 +196,8 @@ struct Attributes { } private: - Derived* This() { + // UBSAN error: https://github.com/pytorch/pytorch/issues/9055 + Derived* This() __ubsan_ignore_vptr__ { return static_cast(this); } template diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index bf1dfc7d822c3b..ff5bcc3e81f3c5 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -9,8 +9,10 @@ class ExpRelaxedCategorical(Distribution): r""" - Creates a ExpRelaxedCategorical parameterized by `probs` and `temperature`. - Returns the log of a point in the simplex. Based on the interface to OneHotCategorical. + Creates a ExpRelaxedCategorical parameterized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits`. + Returns the log of a point in the simplex. Based on the interface to + :class:`OneHotCategorical`. Implementation based on [1]. @@ -74,9 +76,10 @@ def log_prob(self, value): class RelaxedOneHotCategorical(TransformedDistribution): r""" - Creates a RelaxedOneHotCategorical distribution parametrized by `temperature` and either `probs` or `logits`. - This is a relaxed version of the `OneHotCategorical` distribution, so its - values are on simplex, and has reparametrizable samples. + Creates a RelaxedOneHotCategorical distribution parametrized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits`. + This is a relaxed version of the :class:`OneHotCategorical` distribution, so + its samples are on simplex, and are reparametrizable. 
Example:: diff --git a/torch/legacy/nn/BatchNormalization.py b/torch/legacy/nn/BatchNormalization.py index 56b835b0b24298..223879823e4492 100644 --- a/torch/legacy/nn/BatchNormalization.py +++ b/torch/legacy/nn/BatchNormalization.py @@ -50,6 +50,7 @@ def __init__(self, nOutput, eps=1e-5, momentum=0.1, affine=True): self.save_mean = None self.save_std = None + self._input = None self._gradOutput = None if self.affine: diff --git a/torch/legacy/nn/SpatialConvolutionLocal.py b/torch/legacy/nn/SpatialConvolutionLocal.py index db587dc4428d05..0e0cbafde8f7b1 100644 --- a/torch/legacy/nn/SpatialConvolutionLocal.py +++ b/torch/legacy/nn/SpatialConvolutionLocal.py @@ -32,6 +32,7 @@ def __init__(self, nInputPlane, nOutputPlane, iW, iH, kW, kH, dW=1, dH=1, padW=0 self.reset() self.finput = None self.fgradInput = None + self._input = None self._gradOutput = None def reset(self, stdv=None): diff --git a/torch/legacy/nn/SpatialFullConvolution.py b/torch/legacy/nn/SpatialFullConvolution.py index 63ba58dab4df56..9dc924138278e4 100644 --- a/torch/legacy/nn/SpatialFullConvolution.py +++ b/torch/legacy/nn/SpatialFullConvolution.py @@ -32,6 +32,7 @@ def __init__(self, nInputPlane, nOutputPlane, kW, kH, dW=1, dH=1, padW=0, padH=N self.finput = None self.fgradInput = None self.zeroScalar = None + self._input = None self._gradOutput = None self.reset() diff --git a/torch/legacy/nn/VolumetricConvolution.py b/torch/legacy/nn/VolumetricConvolution.py index 0a3ace4b2e5153..8e506a1d93865b 100644 --- a/torch/legacy/nn/VolumetricConvolution.py +++ b/torch/legacy/nn/VolumetricConvolution.py @@ -29,6 +29,7 @@ def __init__(self, nInputPlane, nOutputPlane, kT, kW, kH, dT=1, dW=1, dH=1, padT self.finput = None self.fgradInput = None + self._input = None self._gradOutput = None def reset(self, stdv=None): diff --git a/torch/legacy/nn/VolumetricFullConvolution.py b/torch/legacy/nn/VolumetricFullConvolution.py index e0eb3faed23362..3236a7ede019bf 100644 --- a/torch/legacy/nn/VolumetricFullConvolution.py +++ b/torch/legacy/nn/VolumetricFullConvolution.py @@ -39,6 +39,7 @@ def __init__(self, nInputPlane, nOutputPlane, self.ones = torch.Tensor() self.finput = torch.Tensor() self.fgradInput = torch.Tensor() + self._input = None self._gradOutput = None self.reset() diff --git a/torch/nn/functional.py b/torch/nn/functional.py index a6c00b160f8857..969c24aa67e7d7 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1111,7 +1111,6 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2, [ 0.0000, 0.0000, 0.0000], [ 0.6262, 0.2438, 0.7471]]]) """ - input = input.contiguous() if padding_idx is not None: if padding_idx > 0: assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings' @@ -1121,6 +1120,10 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2, elif padding_idx is None: padding_idx = -1 if max_norm is not None: + # `embedding_renorm_` will call .contiguous() on input anyways, so we + # call it here and take advantage of the improved locality in the + # `embedding` call below too. 
+ input = input.contiguous() with torch.no_grad(): torch.embedding_renorm_(weight, input, max_norm, norm_type) return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) @@ -1206,7 +1209,7 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2, offsets = torch.arange(0, input.numel(), input.size(1), dtype=torch.long, device=input.device) - input = input.view(-1) + input = input.reshape(-1) elif input.dim() == 1: if offsets is None: raise ValueError("offsets has to be a 1D Tensor but got None") diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 2af36229af6696..d56485ee46fd7f 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -7,7 +7,7 @@ class Linear(Module): - r"""Applies a linear transformation to the incoming data: :math:`y = Ax + b` + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` Args: in_features: size of each input sample diff --git a/torch/storage.py b/torch/storage.py index 680e27eb99c373..a8ea4da336d7cd 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -1,3 +1,5 @@ +import io + import torch from ._utils import _type, _cuda @@ -28,7 +30,9 @@ def __deepcopy__(self, memo): return new_storage def __reduce__(self): - return type(self), (self.tolist(),) + b = io.BytesIO() + torch.save(self, b) + return (_load_from_bytes, (b.getvalue(),)) def __sizeof__(self): return super(_StorageBase, self).__sizeof__() + self.element_size() * self.size() @@ -116,5 +120,9 @@ def _new_shared(cls, size): return cls._new_using_fd(size) +def _load_from_bytes(b): + return torch.load(io.BytesIO(b)) + + _StorageBase.type = _type _StorageBase.cuda = _cuda diff --git a/torch/utils/trainer/plugins/loss.py b/torch/utils/trainer/plugins/loss.py index eea44ca81f0a23..1bd93f2b7fd909 100644 --- a/torch/utils/trainer/plugins/loss.py +++ b/torch/utils/trainer/plugins/loss.py @@ -5,4 +5,4 @@ class LossMonitor(Monitor): stat_name = 'loss' def _get_value(self, iteration, input, target, output, loss): - return loss[0] + return loss.item()
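The last two hunks above are small but easy to miss: `_StorageBase.__reduce__` now pickles a storage by serializing it through `torch.save` into a bytes buffer (restored via the new `_load_from_bytes` helper) instead of going through `tolist()`, and `LossMonitor` reads the scalar loss with `.item()` because a reduced loss is a 0-dimensional tensor. A minimal sketch of the resulting behaviour, assuming a build with these changes applied; the storage size and loss values are illustrative only::

    import pickle

    import torch

    # Storages now pickle by round-tripping through torch.save, which keeps
    # the element type without materializing a Python list.
    storage = torch.FloatStorage(4)
    restored = pickle.loads(pickle.dumps(storage))
    assert len(restored) == len(storage)

    # A reduced loss is a 0-dim tensor: .item() extracts the Python number,
    # whereas indexing it as loss[0] is deprecated for 0-dim tensors.
    loss = (torch.randn(3) ** 2).mean()
    print(loss.item())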