diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index e4c553dd4e652..7e5b98ee628cd 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -44,13 +44,10 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)") fi -export ATEN_DISABLE_AVX= -export ATEN_DISABLE_AVX2= if [[ "${JOB_BASE_NAME}" == *-NO_AVX-* ]]; then - export ATEN_DISABLE_AVX=1 -fi -if [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then - export ATEN_DISABLE_AVX2=1 + export ATEN_CPU_CAPABILITY=default +elif [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then + export ATEN_CPU_CAPABILITY=avx fi test_python_nn() { diff --git a/aten/src/ATen/Layout.h b/aten/src/ATen/Layout.h index 010248a010a5f..a610bc7f2ec0f 100644 --- a/aten/src/ATen/Layout.h +++ b/aten/src/ATen/Layout.h @@ -1,6 +1,9 @@ #pragma once #include +#include + +#include namespace at { enum class Layout { Strided, Sparse }; @@ -18,3 +21,14 @@ inline Layout layout_from_backend(Backend backend) { } } } // namespace at + +inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) { + switch (layout) { + case at::kStrided: + return stream << "Strided"; + case at::kSparse: + return stream << "Sparse"; + default: + AT_ERROR("Unknown layout"); + } +} diff --git a/aten/src/ATen/ScalarType.h b/aten/src/ATen/ScalarType.h index 4cb68a6370625..fb4581b34328a 100644 --- a/aten/src/ATen/ScalarType.h +++ b/aten/src/ATen/ScalarType.h @@ -1,11 +1,12 @@ #pragma once -#include - #include "ATen/ArrayRef.h" #include "ATen/ATenGeneral.h" #include "ATen/Half.h" +#include +#include + namespace at { // NB: Order matters for this macro; it is relied upon in @@ -168,3 +169,9 @@ typedef ArrayRef IntList; typedef ArrayRef TensorList; } // namespace at + +inline std::ostream& operator<<( + std::ostream& stream, + at::ScalarType scalar_type) { + return stream << at::toString(scalar_type); +} diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 62c8356351700..b320a1d009684 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -19,7 +19,7 @@ namespace at { // we don't currently support zero-size dimensions, so we can't actually // do this; so we just allocate zero-size tensors for everything. 
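The operator<< overloads added to Layout.h and ScalarType.h above make these enums directly streamable. A minimal usage sketch, not part of the patch, assuming the ATen/ATen.h umbrella header pulls in both files:

    #include <iostream>
    #include <ATen/ATen.h>

    int main() {
      std::cout << at::kStrided << "\n";           // prints "Strided" via the Layout overload above
      std::cout << at::kSparse << "\n";            // prints "Sparse"
      std::cout << at::ScalarType::Float << "\n";  // prints the scalar name via at::toString, e.g. "Float"
      return 0;
    }
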
SparseTensorImpl::SparseTensorImpl(Type * type) - : TensorImpl(type) + : TensorImpl(type, nullptr) , size_{0} , sparseDims_(1) , denseDims_(0) diff --git a/aten/src/ATen/TensorImpl.cpp b/aten/src/ATen/TensorImpl.cpp index a77664d95c2a1..c5457f3c57886 100644 --- a/aten/src/ATen/TensorImpl.cpp +++ b/aten/src/ATen/TensorImpl.cpp @@ -3,6 +3,8 @@ #include #include +#include + namespace at { Tensor& TensorImpl::grad() { AT_ERROR("grad is not implemented for Tensor"); @@ -33,4 +35,42 @@ void Tensor::backward( bool create_graph) { pImpl->backward(std::move(gradient), keep_graph, create_graph); } + +TensorImpl::~TensorImpl() { + if (tensor) tensor->release(); +} + +IntList TensorImpl::sizes() const { + // NB: dim in tensor is not synchronized with THTensor, so it's + // important to apply dim here + return IntList(THTensor_getSizePtr(tensor), dim()); +} + +IntList TensorImpl::strides() const { + // NB: dim in tensor is not synchronized with THTensor, so it's + // important to apply dim here + return IntList(THTensor_getStridePtr(tensor), dim()); +} + +void TensorImpl::release_resources() { + if (tensor) { + tensor->release(); + tensor = nullptr; + } +} + +int64_t TensorImpl::dim() const { + if (isScalar()) { + return 0; + } + return tensor->dim(); +} + +void * TensorImpl::unsafeGetTH(bool retain) { + if (retain) { + tensor->retain(); + } + return tensor; +} + } // namespace at diff --git a/aten/src/ATen/TensorImpl.h b/aten/src/ATen/TensorImpl.h index f5abf15f4cf94..a2d28eca86b44 100644 --- a/aten/src/ATen/TensorImpl.h +++ b/aten/src/ATen/TensorImpl.h @@ -7,6 +7,8 @@ #include "ATen/ScalarType.h" #include "ATen/optional.h" +struct THTensor; + namespace at { class Scalar; struct Type; @@ -15,23 +17,27 @@ struct Tensor; } // namespace at namespace at { -struct TensorImpl : public Retainable { - explicit TensorImpl(Type * type) - : is_scalar(false), type_(type) {} +struct AT_API TensorImpl : public Retainable { + explicit TensorImpl(Type * type, THTensor * tensor) + : is_scalar(false), type_(type), tensor(tensor) {} + + virtual ~TensorImpl(); + + virtual void release_resources() override; Type & type() const { return *type_; } virtual const char * toString() const = 0; - virtual IntList sizes() const = 0; - virtual IntList strides() const = 0; - virtual int64_t dim() const = 0; + virtual IntList sizes() const; + virtual IntList strides() const; + virtual int64_t dim() const; /** * Perform a conversion of this tensor to a scalar, if numel() == 1. * Otherwise, raise an error. */ virtual Scalar localScalar() = 0; - virtual void * unsafeGetTH(bool retain) = 0; + virtual void * unsafeGetTH(bool retain); virtual std::unique_ptr storage() = 0; friend struct Type; @@ -69,30 +75,32 @@ struct TensorImpl : public Retainable { // Some methods below are defined in TensorImpl.cpp because Tensor is an // incomplete type. 
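The TensorImpl changes above move THTensor ownership into the base class: the destructor releases the tensor, release_resources() drops the reference early, and unsafeGetTH(true) hands out an extra reference. A self-contained sketch of that same retain/release pattern, with hypothetical RefCounted/Handle names standing in for THTensor/TensorImpl (illustration only, not ATen code):

    #include <atomic>
    #include <cstdio>

    struct RefCounted {
      std::atomic<int> refcount{1};
      void retain() { ++refcount; }
      void release() { if (--refcount == 0) delete this; }
      virtual ~RefCounted() { std::puts("underlying object freed"); }
    };

    struct Handle {
      explicit Handle(RefCounted* t) : tensor(t) {}
      ~Handle() { if (tensor) tensor->release(); }           // mirrors TensorImpl::~TensorImpl
      void release_resources() {                             // mirrors TensorImpl::release_resources
        if (tensor) { tensor->release(); tensor = nullptr; }
      }
      RefCounted* unsafeGet(bool retain) {                   // mirrors TensorImpl::unsafeGetTH
        if (retain) tensor->retain();
        return tensor;
      }
      RefCounted* tensor;
    };

    int main() {
      Handle h(new RefCounted());
      RefCounted* raw = h.unsafeGet(/*retain=*/true);  // caller now co-owns the object
      h.release_resources();                           // handle gives up its reference early
      raw->release();                                  // last reference released; object is freed
      return 0;
    }
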
- AT_API virtual void set_requires_grad(bool requires_grad) { + virtual void set_requires_grad(bool requires_grad) { AT_ERROR("set_requires_grad is not implemented for Tensor"); } - AT_API virtual bool requires_grad() const { + virtual bool requires_grad() const { AT_ERROR("requires_grad is not implemented for Tensor"); } - AT_API virtual Tensor& grad(); - AT_API virtual const Tensor& grad() const; + virtual Tensor& grad(); + virtual const Tensor& grad() const; - AT_API virtual Tensor detach() const; - AT_API virtual void detach_() { + virtual Tensor detach() const; + virtual void detach_() { AT_ERROR("detach_ is not implemented for Tensor"); } - AT_API virtual void backward( + virtual void backward( at::optional gradient, bool keep_graph, bool create_graph); - AT_API virtual void set_data(Tensor new_data); + virtual void set_data(Tensor new_data); protected: bool is_scalar; Type * type_; +public: + THTensor * tensor; }; } // namespace at diff --git a/aten/src/ATen/TensorOptions.cpp b/aten/src/ATen/TensorOptions.cpp index cb8b9bfedb021..c7b218ebcd680 100644 --- a/aten/src/ATen/TensorOptions.cpp +++ b/aten/src/ATen/TensorOptions.cpp @@ -6,6 +6,8 @@ #include #include +#include + namespace at { TensorOptions::TensorOptions(bool use_thread_local_default_options) { @@ -17,3 +19,13 @@ TensorOptions::TensorOptions(bool use_thread_local_default_options) { } } } // namespace at + +std::ostream& operator<<( + std::ostream& stream, + const at::TensorOptions& options) { + return stream << "TensorOptions(dtype=" << options.dtype() + << ", device=" << options.device() + << ", layout=" << options.layout() + << ", requires_grad=" << std::boolalpha + << options.requires_grad() << ")"; +} diff --git a/aten/src/ATen/TensorOptions.h b/aten/src/ATen/TensorOptions.h index 53ad9d827c628..20b0d1ed71d78 100644 --- a/aten/src/ATen/TensorOptions.h +++ b/aten/src/ATen/TensorOptions.h @@ -9,6 +9,7 @@ #include #include +#include #include namespace at { @@ -277,3 +278,7 @@ inline Tensor Tensor::to(Device device, bool non_blocking) const { return detail::to(*this, options().device(device), non_blocking); } } // namespace at + +std::ostream& operator<<( + std::ostream& stream, + const at::TensorOptions& options); diff --git a/aten/src/ATen/UndefinedTensor.cpp b/aten/src/ATen/UndefinedTensor.cpp index 9c9e989417ac1..0de3c05a127ae 100644 --- a/aten/src/ATen/UndefinedTensor.cpp +++ b/aten/src/ATen/UndefinedTensor.cpp @@ -6,7 +6,7 @@ namespace at { // should this use the globalContext? Can it get a context passed in somehow? 
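With the operator<< declared above, a TensorOptions value can be logged directly. A minimal usage sketch, not part of the patch; it assumes the ATen/ATen.h umbrella header and default-constructed options (mirroring the `= {}` default argument used for from_blob elsewhere in this patch), so the exact dtype/device values depend on the thread-local defaults:

    #include <iostream>
    #include <ATen/ATen.h>

    int main() {
      at::TensorOptions options{};   // default options
      std::cout << options << "\n";
      // Expected shape of the output, per the definition above (field values depend on defaults):
      // TensorOptions(dtype=Float, device=cpu, layout=Strided, requires_grad=false)
      return 0;
    }
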
UndefinedTensor::UndefinedTensor() -: TensorImpl(&(globalContext().getType(Backend::Undefined,ScalarType::Undefined))) { +: TensorImpl(&(globalContext().getType(Backend::Undefined,ScalarType::Undefined)), nullptr) { } const char * UndefinedTensor::toString() const { diff --git a/aten/src/ATen/code_template.py b/aten/src/ATen/code_template.py index f239030db0658..1cebf11839e7c 100644 --- a/aten/src/ATen/code_template.py +++ b/aten/src/ATen/code_template.py @@ -50,7 +50,9 @@ def replace(match): comma_after = ', ' key = key[:-1] v = lookup(key) - if indent is not None and isinstance(v, list): + if indent is not None: + if not isinstance(v, list): + v = [v] return indent_lines(indent, v) elif isinstance(v, list): middle = ', '.join([str(x) for x in v]) @@ -58,7 +60,7 @@ def replace(match): return middle return comma_before + middle + comma_after else: - return (indent or '') + str(v) + return str(v) return self.subtitution.sub(replace, self.pattern) diff --git a/aten/src/ATen/copy_wrapper.py b/aten/src/ATen/copy_wrapper.py index 02eb56e4129c2..feda7573a7818 100644 --- a/aten/src/ATen/copy_wrapper.py +++ b/aten/src/ATen/copy_wrapper.py @@ -116,7 +116,7 @@ def create_one_copy(dst_type, all_types): cuda = '' state = [] if src_type['Backend'] == 'CUDA' or dst_type['Backend'] == 'CUDA': - state.append('context->getTHCState()') + state.append('globalContext().getTHCState()') if src_type['Backend'] == 'CUDA': if dst_type['Backend'] == 'CUDA': cuda = 'Cuda' @@ -183,7 +183,7 @@ def create_one_copy_from(src_type, all_types): if src_type['Backend'] == 'CUDA': cuda = 'Cuda' if dst_type['Backend'] == 'CUDA' or src_type['Backend'] == 'CUDA': - state.append('context->getTHCState()') + state.append('globalContext().getTHCState()') body_env = nested_dict({ 'src_scalar_name': src_type['ScalarName'], diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 1c06654fe891b..dc43fbc9cac1b 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -180,7 +180,7 @@ def TypedDict(name, attrs, total=True): # type: ignore }""") BUFFER_DEFINITION = CodeTemplate("""\ -auto ${name}_ = new ${Tensor}(context); +auto ${name}_ = new ${Tensor}(${THTensor}_new()); auto ${name} = Tensor(${name}_, false);""") CONDITIONAL_INITIALIZER = CodeTemplate("""\ @@ -277,7 +277,7 @@ def __init__(self, reason): 'THStorage*': CodeTemplate('checked_cast_storage<${Storage}>(&${arg_name},"${arg_name}",${arg_pos})'), 'THGenerator*': CodeTemplate( - 'check_generator<${Backend}Generator>(${arg_name}, &context->defaultGenerator(backend()))'), + 'check_generator<${Backend}Generator>(${arg_name}, &globalContext().defaultGenerator(backend()))'), # This is a cast done via direct-construction 'THSize*': CodeTemplate('THLongStorageView ${result_name}(${arg_name}, THLongStorageViewKind::SIZE);'), # This is a cast done via direct-construction @@ -306,14 +306,24 @@ def __init__(self, reason): CHECKED_USE_NULLABLE = CodeTemplate('${arg_name}_ ? 
${usage} : NULL') +ALLOC_NOARGS_WRAP = { + 'THTensor*': 'detail::new_${Tensor}()', + 'THBoolTensor*': 'detail::new_${Backend}ByteTensor()', + 'THIndexTensor*': 'detail::new_${Backend}LongTensor()', + 'THIntegerTensor*': 'detail::new_${Backend}IntTensor()', + 'THSTensor*': 'detail::new_Sparse${Tensor}()', + 'THDenseTensor*': 'detail::new_${DenseTensor}()', + 'THDenseIndexTensor*': 'detail::new_${DenseBackend}LongTensor()', +} + ALLOC_WRAP = { - 'THTensor*': 'new ${Tensor}(context${,arguments})', - 'THBoolTensor*': 'new ${Backend}ByteTensor(context${,arguments})', - 'THIndexTensor*': 'new ${Backend}LongTensor(context${,arguments})', - 'THIntegerTensor*': 'new ${Backend}IntTensor(context${,arguments})', - 'THSTensor*': 'new Sparse${Tensor}(context${,arguments})', - 'THDenseTensor*': 'new ${DenseTensor}(context${,arguments})', - 'THDenseIndexTensor*': 'new ${DenseBackend}LongTensor(context${,arguments})', + 'THTensor*': 'new ${Tensor}(${arguments})', + 'THBoolTensor*': 'new ${Backend}ByteTensor(${arguments})', + 'THIndexTensor*': 'new ${Backend}LongTensor(${arguments})', + 'THIntegerTensor*': 'new ${Backend}IntTensor(${arguments})', + 'THSTensor*': 'new Sparse${Tensor}(${arguments})', + 'THDenseTensor*': 'new ${DenseTensor}(${arguments})', + 'THDenseIndexTensor*': 'new ${DenseBackend}LongTensor(${arguments})', } # Replacements for constants when calling into TH @@ -1228,7 +1238,10 @@ def handle_sparse(env, option): def allocate_arg(env, arg, output_count): # type: (Environment, THFormal, int) -> List[str] name = arg['name'] - allocation = CodeTemplate(ALLOC_WRAP[arg['type']]).substitute(env, arguments=[]) + state = '' + if is_cuda: + state = 'globalContext().getTHCState()' + allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(env) tensor_arg = '{}_'.format(name) if arg.get('mask', False): allocation = 'output_mask[{}] ? 
{} : nullptr'.format(output_count, allocation) @@ -1257,7 +1270,7 @@ def handle_call(env, option, cimpl): is_nn = option['mode'] == 'NN' actuals = get_arguments(cimpl['arguments'], option) if is_cuda or is_nn: - actuals = ['context->getTHCState()'] + actuals + actuals = ['globalContext().getTHCState()'] + actuals cname = cimpl['cname'] if option.get('sparse', False): diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index 6d3598a1957cd..0f2aaffd6eac9 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -273,7 +273,7 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations env['THStorage'] = 'THCuda{}Storage'.format(sname) env['THTensor'] = 'THCuda{}Tensor'.format(sname) env['THIndexTensor'] = 'THCudaLongTensor' - env['state'] = ['context->getTHCState()'] + env['state'] = ['globalContext().getTHCState()'] env['isCUDA'] = 'true' env['storage_device'] = 'return storage->device;' env['Generator'] = 'CUDAGenerator' diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp new file mode 100644 index 0000000000000..662ae580c599a --- /dev/null +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -0,0 +1,44 @@ +#include "DispatchStub.h" + +#include + +#include +#include +#include + +namespace at { namespace native { + +static CPUCapability compute_cpu_capability() { + auto envar = std::getenv("ATEN_CPU_CAPABILITY"); + if (envar) { + if (strcmp(envar, "avx2") == 0) { + return CPUCapability::AVX2; + } + if (strcmp(envar, "avx") == 0) { + return CPUCapability::AVX; + } + if (strcmp(envar, "default") == 0) { + return CPUCapability::DEFAULT; + } + AT_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar); + } + +#ifndef __powerpc__ + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) { + return CPUCapability::AVX2; + } + if (cpuinfo_has_x86_avx()) { + return CPUCapability::AVX; + } + } +#endif + return CPUCapability::DEFAULT; +} + +CPUCapability get_cpu_capability() { + static CPUCapability capability = compute_cpu_capability(); + return capability; +} + +}} // namespace at::native diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h new file mode 100644 index 0000000000000..64a33c60d6f71 --- /dev/null +++ b/aten/src/ATen/native/DispatchStub.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include + +// Implements instruction set specific function dispatch. +// +// Kernels that may make use of specialized instruction sets (e.g. AVX) are +// compiled multiple times with different compiler flags (e.g. -mavx). A +// DispatchStub contains a table of function pointers for a kernel. At runtime, +// the fastest available kernel is chosen based on the features reported by +// cpuinfo. +// +// Example: +// +// In native/MyKernel.h: +// using fn_type = void(*)(const Tensor& x); +// DECLARE_DISPATCH(fn_type, stub); +// +// In native/MyKernel.cpp +// DEFINE_DISPATCH(stub); +// +// In native/cpu/MyKernel.cpp: +// void kernel(const Tensor& x) { ... 
} +// REGISTER_DISPATCH(stub, &kernel); +// +// To call: +// stub(kCPU, tensor); + +// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wundefined-var-template" +#endif + +namespace at { namespace native { + +enum class CPUCapability { + DEFAULT = 0, + AVX = 1, + AVX2 = 2, + NUM_OPTIONS +}; + +CPUCapability get_cpu_capability(); + +template +struct DispatchStub { + static_assert(std::is_pointer::value, "FnPtr should be a pointer type"); + + template + void operator()(Backend backend, ArgTypes... args) { + if (backend == Backend::CPU) { + if (!cpu_dispatch_ptr) { + cpu_dispatch_ptr = choose_cpu_impl(); + } + (*cpu_dispatch_ptr)(args...); + } else if (backend == Backend::CUDA) { + AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel"); + (*cuda_dispatch_ptr)(args...); + } else { + AT_ERROR("DispatchStub: unsupported backend", backend); + } + } + + FnPtr choose_cpu_impl() { + auto capability = static_cast(get_cpu_capability()); + (void)capability; +#ifdef HAVE_AVX2_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::AVX2)) { + AT_ASSERTM(AVX2, "DispatchStub: missing AVX2 kernel"); + return AVX2; + } +#endif +#ifdef HAVE_AVX_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::AVX)) { + AT_ASSERTM(AVX, "DispatchStub: missing AVX kernel"); + return AVX; + } +#endif + AT_ASSERTM(DEFAULT, "DispatchStub: missing default kernel"); + return DEFAULT; + } + + FnPtr cpu_dispatch_ptr = nullptr; + FnPtr cuda_dispatch_ptr = nullptr; + static FnPtr DEFAULT; +#ifdef HAVE_AVX_CPU_DEFINITION + static FnPtr AVX; +#endif +#ifdef HAVE_AVX2_CPU_DEFINITION + static FnPtr AVX2; +#endif +}; + +namespace { +template +struct RegisterDispatch { + RegisterDispatch(DispatchStub& stub, FnPtr value) { + stub.cuda_dispatch_ptr = value; + } +}; +} // anonymous namespace + +#define DECLARE_DISPATCH(fn, name) \ + extern struct name : DispatchStub {} name + +#define DEFINE_DISPATCH(name) struct name name + +#if defined(__CUDACC__) +#define REGISTER_DISPATCH(name, fn) \ + static RegisterDispatch name ## __register(name, fn); +#elif defined(CPU_CAPABILITY) +#define REGISTER_DISPATCH(name, fn) \ + template <> decltype(fn) DispatchStub::CPU_CAPABILITY = fn; +#endif + + +}} // namespace at::native + + +#if defined(__clang__) +#pragma clang diagnostic pop +#endif diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index affa9d24059d9..8a8187df4e8d9 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -17,6 +17,9 @@ namespace at { namespace native { +DEFINE_DISPATCH(sum_kernel); +DEFINE_DISPATCH(prod_kernel); + static inline Tensor integer_upcast(const Tensor& self, optional dtype) { ScalarType scalarType = self.type().scalarType(); ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType) ? 
ScalarType::Long : scalarType); @@ -127,7 +130,7 @@ Tensor sum(const Tensor &self) { Tensor _sum_cpu(const Tensor& self) { if (self.is_contiguous()) { Tensor result = at::empty({}, self.type()); - sum_kernel(result, self, at::nullopt); + sum_kernel(kCPU, result, self, at::nullopt); return result; } return self._sumall(); @@ -148,7 +151,7 @@ Tensor prod(const Tensor &self) { Tensor _prod_cpu(const Tensor &self) { if (self.is_contiguous()) { Tensor result = at::empty({}, self.type()); - prod_kernel(result, self, at::nullopt); + prod_kernel(kCPU, result, self, at::nullopt); return result; } return self._prodall(); @@ -222,7 +225,7 @@ Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_, return result; if (self.is_contiguous() && result.is_contiguous()) { _dimreduce_setup(result, self, dim); - sum_kernel(result, self, dim); + sum_kernel(kCPU, result, self, dim); if (!keepdim) result.squeeze_(dim); return result; } @@ -260,7 +263,7 @@ Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_, return result; if (self.is_contiguous() && result.is_contiguous()) { _dimreduce_setup(result, self, dim); - prod_kernel(result, self, dim); + prod_kernel(kCPU, result, self, dim); if (!keepdim) result.squeeze_(dim); return result; } diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index 546c75829a6d5..aebb021696084 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -128,7 +128,7 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_) { dim >= 0 && dim < input.dim(), "dim must be non-negative and less than input dimensions"); if (input.ndimension() > 0 && dim == input.ndimension() - 1) { - softmax_lastdim_kernel(output, input); + softmax_lastdim_kernel(kCPU, output, input); } else { AT_DISPATCH_FLOATING_TYPES(input.type(), "softmax", [&] { host_softmax(output, input, dim); @@ -147,7 +147,7 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_) { dim >= 0 && dim < input.dim(), "dim must be non-negative and less than input dimensions"); if (input.ndimension() > 0 && dim == input.ndimension() - 1) { - log_softmax_lastdim_kernel(output, input); + log_softmax_lastdim_kernel(kCPU, output, input); } else { AT_DISPATCH_FLOATING_TYPES(input.type(), "log_softmax", [&] { host_softmax(output, input, dim); @@ -176,7 +176,7 @@ Tensor softmax_backward_cpu( dim >= 0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions"); if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) { - softmax_backward_lastdim_kernel(grad_input, grad, output); + softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output); } else { AT_DISPATCH_FLOATING_TYPES(grad.type(), "softmax_backward", [&] { host_softmax_backward(grad_input, grad, output, dim); @@ -205,7 +205,7 @@ Tensor log_softmax_backward_cpu( dim >= 0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions"); if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) { - log_softmax_backward_lastdim_kernel(grad_input, grad, output); + log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output); } else { AT_DISPATCH_FLOATING_TYPES(grad.type(), "log_softmax_backward", [&] { host_softmax_backward(grad_input, grad, output, dim); @@ -213,5 +213,11 @@ Tensor log_softmax_backward_cpu( } return grad_input; } + +DEFINE_DISPATCH(softmax_lastdim_kernel); +DEFINE_DISPATCH(log_softmax_lastdim_kernel); +DEFINE_DISPATCH(softmax_backward_lastdim_kernel); +DEFINE_DISPATCH(log_softmax_backward_lastdim_kernel); + } } diff --git 
a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index dbfc623b0ccba..f988e261eb991 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -92,14 +92,14 @@ Tensor& fill_(Tensor& self, const Tensor& value) { Tensor& _##op##__cpu(Tensor& self_) { \ if (self_.numel() > 0) { \ Tensor self = sort_strides(self_); \ - op##Impl(self, self); \ + op##Impl(kCPU, self, self); \ } \ return self_; \ } \ Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \ result.resize_(self.sizes()); \ if (result.numel() > 0) { \ - op##Impl(result, self); \ + op##Impl(kCPU, result, self); \ } \ return result; \ } @@ -145,5 +145,29 @@ IMPLEMENT_UNARY_OP_VEC(tan) IMPLEMENT_UNARY_OP_VEC(tanh) IMPLEMENT_UNARY_OP_VEC(trunc) +DEFINE_DISPATCH(absImpl); +DEFINE_DISPATCH(acosImpl); +DEFINE_DISPATCH(asinImpl); +DEFINE_DISPATCH(atanImpl); +DEFINE_DISPATCH(ceilImpl); +DEFINE_DISPATCH(cosImpl); +DEFINE_DISPATCH(erfImpl); +DEFINE_DISPATCH(erfcImpl); +DEFINE_DISPATCH(expImpl); +DEFINE_DISPATCH(expm1Impl); +DEFINE_DISPATCH(floorImpl); +DEFINE_DISPATCH(logImpl); +DEFINE_DISPATCH(log10Impl); +DEFINE_DISPATCH(log1pImpl); +DEFINE_DISPATCH(log2Impl); +DEFINE_DISPATCH(roundImpl); +DEFINE_DISPATCH(rsqrtImpl); +DEFINE_DISPATCH(sigmoidImpl); +DEFINE_DISPATCH(sinImpl); +DEFINE_DISPATCH(sqrtImpl); +DEFINE_DISPATCH(tanImpl); +DEFINE_DISPATCH(tanhImpl); +DEFINE_DISPATCH(truncImpl); + } } // namespace at diff --git a/aten/src/ATen/native/cpu/CapabilityDispatch.h b/aten/src/ATen/native/cpu/CapabilityDispatch.h deleted file mode 100644 index 6cb0f279872d6..0000000000000 --- a/aten/src/ATen/native/cpu/CapabilityDispatch.h +++ /dev/null @@ -1,97 +0,0 @@ -#pragma once - -#include -#include -#include - -// Implements instruction set specific function dispatch. -// -// Kernels that may make use of specialized instruction sets (e.g. AVX) are -// compiled multiple times with different compiler flags (e.g. -mavx). A -// DispatchStub contains a table of function pointers for a kernel. At runtime, -// the fastest available kernel is chosen based on the features reported by -// cpuinfo. -// -// Example: -// -// In native/cpu/MyKernel.h: -// using fn_type = void(*)(const Tensor& x); -// DispatchStub stub; -// -// In native/cpu/MyKernel.cpp: -// void kernel(const Tensor& x) { ... } -// REGISTER_DISPATCH(stub, &kernel); -// -// To call: -// stub(tensor); -// - -namespace at { -namespace native { - -enum class CPUCapability { DEFAULT, AVX, AVX2, NUM_OPTIONS }; - -template -struct DispatchStub { - static_assert(std::is_pointer::value, "FnPtr should be a pointer type"); - - template - void operator()(ArgTypes... 
args) { - if (!dispatch_ptr) { - dispatch_ptr = choose_impl(); - } - (*dispatch_ptr)(args...); - } - - FnPtr choose_impl() { -// Do not use cpuinfo on PowerPC as it shows confusing errors when run on ppc -#ifndef __powerpc__ - if (cpuinfo_initialize()) { - int avx2 = static_cast(CPUCapability::AVX2); - if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() && - cpuinfo_has_x86_fma3() && table[avx2]) { - return table[avx2]; - } - int avx = static_cast(CPUCapability::AVX); - if (!std::getenv("ATEN_DISABLE_AVX") && cpuinfo_has_x86_avx() && table[avx]) { - return table[avx]; - } - } -#endif - int def = static_cast(CPUCapability::DEFAULT); - AT_ASSERTM(table[def], "DispatchStub: missing default kernel"); - return table[def]; - } - - FnPtr dispatch_ptr = nullptr; - FnPtr table[static_cast(CPUCapability::NUM_OPTIONS)]; -}; - - -#if defined(CPU_CAPABILITY) - -constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::CPU_CAPABILITY; - -// Registers an implementation a kernel for the current CPU capability. -template -struct RegisterDispatch { - RegisterDispatch(DispatchStub& stub, FnPtr value) { - stub.table[static_cast(CURRENT_CAPABILITY)] = value; - } -}; - -// We only define the stub once in the DEFAULT capability compilation -#if defined(CPU_CAPABILITY_DEFAULT) -#define _DEFINE_STUB(stub, fn) DispatchStub stub -#else -#define _DEFINE_STUB(stub, fn) -#endif - -#define REGISTER_DISPATCH(stub, fn) \ - _DEFINE_STUB(stub, fn); \ - static RegisterDispatch stub ## __register(stub, fn); - -#endif - -} -} diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.h b/aten/src/ATen/native/cpu/ReduceOpsKernel.h index 9481b90fe7696..2423b4b523018 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.h +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.h @@ -1,16 +1,14 @@ #pragma once #include +#include #include -#include "CapabilityDispatch.h" -namespace at { -namespace native { +namespace at { namespace native { using reduce_fn = void(*)(Tensor &, const Tensor &, at::optional); -extern DispatchStub sum_kernel; -extern DispatchStub prod_kernel; +DECLARE_DISPATCH(reduce_fn, sum_kernel); +DECLARE_DISPATCH(reduce_fn, prod_kernel); -} -} +}} // namespace at::native diff --git a/aten/src/ATen/native/cpu/SoftmaxKernel.h b/aten/src/ATen/native/cpu/SoftmaxKernel.h index dbd703b6d3c02..0fb2a8e18a5ff 100644 --- a/aten/src/ATen/native/cpu/SoftmaxKernel.h +++ b/aten/src/ATen/native/cpu/SoftmaxKernel.h @@ -1,7 +1,7 @@ #pragma once #include -#include "CapabilityDispatch.h" +#include namespace at { namespace native { @@ -9,10 +9,10 @@ namespace native { using forward_fn = void(*)(Tensor &, const Tensor &); using backward_fn = void(*)(Tensor &, const Tensor &, const Tensor&); -extern DispatchStub softmax_lastdim_kernel; -extern DispatchStub log_softmax_lastdim_kernel; -extern DispatchStub softmax_backward_lastdim_kernel; -extern DispatchStub log_softmax_backward_lastdim_kernel; +DECLARE_DISPATCH(forward_fn, softmax_lastdim_kernel); +DECLARE_DISPATCH(forward_fn, log_softmax_lastdim_kernel); +DECLARE_DISPATCH(backward_fn, softmax_backward_lastdim_kernel); +DECLARE_DISPATCH(backward_fn, log_softmax_backward_lastdim_kernel); } } diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 7416923cfd886..459838a9b6c68 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -4,7 +4,7 @@ #include "ATen/Dispatch.h" #include "ATen/cpu/vml.h" #include "ATen/CPUApplyUtils.h" -#include "ATen/native/cpu/CapabilityDispatch.h" 
+#include "ATen/native/DispatchStub.h" #ifdef __AVX2__ #include "ATen/native/cpu/avx_mathfun.h" #endif diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.h b/aten/src/ATen/native/cpu/UnaryOpsKernel.h index d9bffadd1e1fb..157dda8b2598a 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.h +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.h @@ -1,38 +1,38 @@ #pragma once #include +#include #include -#include "CapabilityDispatch.h" namespace at { namespace native { using unary_fn = void(*)(Tensor&, const Tensor&); -extern DispatchStub absImpl; -extern DispatchStub acosImpl; -extern DispatchStub asinImpl; -extern DispatchStub atanImpl; -extern DispatchStub ceilImpl; -extern DispatchStub cosImpl; -// extern DispatchStub coshImpl; -extern DispatchStub erfImpl; -extern DispatchStub erfcImpl; -extern DispatchStub expImpl; -extern DispatchStub expm1Impl; -extern DispatchStub floorImpl; -extern DispatchStub logImpl; -extern DispatchStub log10Impl; -extern DispatchStub log1pImpl; -extern DispatchStub log2Impl; -extern DispatchStub roundImpl; -extern DispatchStub rsqrtImpl; -extern DispatchStub sigmoidImpl; -extern DispatchStub sinImpl; -// extern DispatchStub sinhImpl; -extern DispatchStub sqrtImpl; -extern DispatchStub tanImpl; -extern DispatchStub tanhImpl; -extern DispatchStub truncImpl; +DECLARE_DISPATCH(unary_fn, absImpl); +DECLARE_DISPATCH(unary_fn, acosImpl); +DECLARE_DISPATCH(unary_fn, asinImpl); +DECLARE_DISPATCH(unary_fn, atanImpl); +DECLARE_DISPATCH(unary_fn, ceilImpl); +DECLARE_DISPATCH(unary_fn, cosImpl); +// DECLARE_DISPATCH(unary_fn, coshImpl); +DECLARE_DISPATCH(unary_fn, erfImpl); +DECLARE_DISPATCH(unary_fn, erfcImpl); +DECLARE_DISPATCH(unary_fn, expImpl); +DECLARE_DISPATCH(unary_fn, expm1Impl); +DECLARE_DISPATCH(unary_fn, floorImpl); +DECLARE_DISPATCH(unary_fn, logImpl); +DECLARE_DISPATCH(unary_fn, log10Impl); +DECLARE_DISPATCH(unary_fn, log1pImpl); +DECLARE_DISPATCH(unary_fn, log2Impl); +DECLARE_DISPATCH(unary_fn, roundImpl); +DECLARE_DISPATCH(unary_fn, rsqrtImpl); +DECLARE_DISPATCH(unary_fn, sigmoidImpl); +DECLARE_DISPATCH(unary_fn, sinImpl); +// DECLARE_DISPATCH(unary_fn, sinhImpl); +DECLARE_DISPATCH(unary_fn, sqrtImpl); +DECLARE_DISPATCH(unary_fn, tanImpl); +DECLARE_DISPATCH(unary_fn, tanhImpl); +DECLARE_DISPATCH(unary_fn, truncImpl); // Missing unary functions diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index 8599d17611bd7..6351b8aa635ba 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -483,44 +483,47 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_){ AT_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions"); int64_t outer_size = 1; int64_t dim_size = input.size(dim); - int64_t inner_size = 1; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - for (int64_t i = 0; i < dim; ++i) - outer_size *= input.size(i); - for (int64_t i = dim + 1; i < input.dim(); ++i) - inner_size *= input.size(i); - // This kernel spawns a block per each element in the batch. 
- // XXX: it assumes that inner_size == 1 - if (inner_size == 1) { - const int ILP = 2; - dim3 grid(outer_size); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] { - using accscalar_t = acc_type; - cunn_SoftMaxForward - <<>>( - output.data(), input.data(), dim_size - ); - }); - // This kernel runs in a 2D grid, where each application along y dimension has a fixed - // outer_size, and runs in parallel over inner_size. Dimension x is parallel over outer_size. - // Reductions over dim are done in a single-threaded manner. - } else { - uint32_t smem_size; - dim3 grid, block; - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] { - using accscalar_t = acc_type; - SpatialSoftMax_getLaunchSizes( - &cunn_SpatialSoftMaxForward, - outer_size, dim_size, inner_size, - grid, block, smem_size); - cunn_SpatialSoftMaxForward - <<>>( - output.data(), input.data(), outer_size, dim_size, inner_size - ); - }); + + if (input.numel() > 0) { + int64_t inner_size = 1; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + for (int64_t i = 0; i < dim; ++i) + outer_size *= input.size(i); + for (int64_t i = dim + 1; i < input.dim(); ++i) + inner_size *= input.size(i); + // This kernel spawns a block per each element in the batch. + // XXX: it assumes that inner_size == 1 + if (inner_size == 1) { + const int ILP = 2; + dim3 grid(outer_size); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] { + using accscalar_t = acc_type; + cunn_SoftMaxForward + <<>>( + output.data(), input.data(), dim_size + ); + }); + // This kernel runs in a 2D grid, where each application along y dimension has a fixed + // outer_size, and runs in parallel over inner_size. Dimension x is parallel over outer_size. + // Reductions over dim are done in a single-threaded manner. 
+ } else { + uint32_t smem_size; + dim3 grid, block; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "host_softmax", [&] { + using accscalar_t = acc_type; + SpatialSoftMax_getLaunchSizes( + &cunn_SpatialSoftMaxForward, + outer_size, dim_size, inner_size, + grid, block, smem_size); + cunn_SpatialSoftMaxForward + <<>>( + output.data(), input.data(), outer_size, dim_size, inner_size + ); + }); + } + THCudaCheck(cudaGetLastError()); } - THCudaCheck(cudaGetLastError()); return output; } diff --git a/aten/src/ATen/native/cuda/TensorTransformations.cu b/aten/src/ATen/native/cuda/TensorTransformations.cu index c1c0e943fde7c..ee4d030f775e7 100644 --- a/aten/src/ATen/native/cuda/TensorTransformations.cu +++ b/aten/src/ATen/native/cuda/TensorTransformations.cu @@ -75,9 +75,13 @@ Tensor flip_cuda(const Tensor& self, IntList dims) { dim3 dim_block(block_size); dim3 dim_grid((N + block_size - 1) / block_size); + auto out_tensor = at::empty_like(in_tensor); + if (out_tensor.numel() == 0) { + return out_tensor; + } + // use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work if (flip_dims_size == 1 && in_tensor.is_contiguous() && (dims[0] == 0 || dims[0] == total_dims - 1)) { - auto out_tensor = at::empty_like(self); AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] { auto in_tensor_info = cuda::detail::getTensorInfo(in_tensor); auto out_tensor_info = cuda::detail::getTensorInfo(out_tensor); @@ -99,8 +103,6 @@ Tensor flip_cuda(const Tensor& self, IntList dims) { auto strides = std::vector(in_tensor.strides()); auto strides_t = at::CPU(kLong).tensorFromBlob(strides.data(), {static_cast(strides.size())}); - auto out_tensor = at::empty_like(in_tensor); - // stride_contiguous is the stride of non-contiguous tensor after calling contiguous(), // it is used to compute indices for each element in non-contiguous tensor Tensor stride_contiguous = at::zeros({total_dims}, kLong); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index ddeae3bd5707a..aa5b7d6d8f344 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -51,9 +51,9 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT AT_CHECK(_check_device({sparse_, r_, t, dense})); // TODO: This error message seems awfully opaque - AT_CHECK(sparse_._sparseDims() == 2, "addmm: matrices expected, got ", sparse_._sparseDims(), "D tensor"); + AT_CHECK(sparse_._sparseDims() == 2, "addmm: 2D tensor expected, got ", sparse_._sparseDims(), "D tensor"); AT_CHECK(sparse_._denseDims() == 0, "addmm: scalar values expected, got ", sparse_._denseDims(), "D values"); - AT_CHECK(dense.dim() == 2, "addmm: matrices expected, got ", dense.dim(), "D tensor"); + AT_CHECK(dense.dim() == 2, "addmm: 2D tensor expected, got ", dense.dim(), "D tensor"); // mxk * kxn = mxn int64_t m = sparse_.size(0); @@ -183,11 +183,11 @@ SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse AT_CHECK(_check_device({r_, sparse_, dense})); AT_CHECK(sparse_._sparseDims() == 2, - "hspmm: Argument #2: matrices expected, got ", sparse_._sparseDims(), "D tensor"); + "hspmm: Argument #2: 2D tensor expected, got ", sparse_._sparseDims(), "D tensor"); AT_CHECK(sparse_._denseDims() == 0, "hspmm: Argument #2: scalar values expected, got ", sparse_._denseDims(), "D values"); AT_CHECK(dense.dim() == 2, - "hspmm: 
Argument #3: matrices expected, got ", dense.dim(), "D tensor"); + "hspmm: Argument #3: 2D tensor expected, got ", dense.dim(), "D tensor"); int64_t m = sparse_.size(0); int64_t k = sparse_.size(1); diff --git a/aten/src/ATen/templates/NativeFunctions.h b/aten/src/ATen/templates/NativeFunctions.h index 2c84f212a7cc4..c2174e933a90d 100644 --- a/aten/src/ATen/templates/NativeFunctions.h +++ b/aten/src/ATen/templates/NativeFunctions.h @@ -35,7 +35,7 @@ inline Tensor from_blob( void* data, IntList sizes, const TensorOptions& options = {}) { - return native::from_blob(data, sizes, [](void*) {}, options); + return native::from_blob(data, sizes, /*deleter=*/[](void*) {}, options); } // These functions are defined in native/TensorFactories.cpp. diff --git a/aten/src/ATen/templates/StorageDerived.cpp b/aten/src/ATen/templates/StorageDerived.cpp index 28e17e0d0c168..398969236eba0 100644 --- a/aten/src/ATen/templates/StorageDerived.cpp +++ b/aten/src/ATen/templates/StorageDerived.cpp @@ -10,17 +10,17 @@ namespace at { -${Storage}::${Storage}(Context* context): - Storage(${THStorage}_new(${state})), context(context) {} +${Storage}::${Storage}(): + Storage(${THStorage}_new(${state})) {} -${Storage}::${Storage}(Context* context, THStorage* storage): - Storage(storage), context(context) {} +${Storage}::${Storage}(THStorage* storage): + Storage(storage) {} -${Storage}::${Storage}(Context* context, size_t storage_size) - : Storage(${THStorage}_newWithSize(${state,} storage_size)), context(context) {} +${Storage}::${Storage}(size_t storage_size) + : Storage(${THStorage}_newWithSize(${state,} storage_size)) {} -${Storage}::${Storage}(Context* context, size_t size, Allocator* allocator) - : Storage(nullptr), context(context) { +${Storage}::${Storage}(size_t size, Allocator* allocator) + : Storage(nullptr) { storage = ${THStorage}_newWithAllocator(${state,} size, allocator); ${THStorage}_clearFlag(${state,} storage, TH_STORAGE_RESIZABLE); } @@ -35,7 +35,7 @@ static int getPointerDevice(void* ptr) { } #endif -${Storage}::${Storage}(Context* context, +${Storage}::${Storage}( void * data, size_t size, const std::function & deleter) : Storage(${THStorage}_newWithDataAndAllocator(${state,} InefficientStdFunctionContext::makeDataPtr(data, deleter, @@ -46,7 +46,7 @@ static int getPointerDevice(void* ptr) { #endif ), size, /* allocator */ nullptr - )), context(context) { + )) { ${THStorage}_clearFlag(${state,} storage, TH_STORAGE_RESIZABLE); } @@ -57,7 +57,7 @@ size_t ${Storage}::elementSize() const { } Type& ${Storage}::type() const { - return context->getType(Backend::${Backend},ScalarType::${ScalarName}); + return globalContext().getType(Backend::${Backend},ScalarType::${ScalarName}); } const char * ${Storage}::typeString() { diff --git a/aten/src/ATen/templates/StorageDerived.h b/aten/src/ATen/templates/StorageDerived.h index 8cfa8c1d01d3b..cb091a52d3c51 100644 --- a/aten/src/ATen/templates/StorageDerived.h +++ b/aten/src/ATen/templates/StorageDerived.h @@ -15,11 +15,11 @@ struct Allocator; struct ${Storage} final : public Storage { public: - explicit ${Storage}(Context* context); - ${Storage}(Context* context, THStorage *wrapped); - ${Storage}(Context* context, size_t size); - ${Storage}(Context* context, size_t size, Allocator* allocator); - ${Storage}(Context* context, + ${Storage}(); + ${Storage}(THStorage *wrapped); + ${Storage}(size_t size); + ${Storage}(size_t size, Allocator* allocator); + ${Storage}( void * data, size_t size, const std::function & deleter); ~${Storage}(); @@ -31,7 +31,6 @@ struct 
${Storage} final : public Storage { protected: friend struct ${Type}; - Context* context; }; } // namespace at diff --git a/aten/src/ATen/templates/TensorDense.cpp b/aten/src/ATen/templates/TensorDense.cpp index 1ca2cda09fa7d..ed4eb3271fc56 100644 --- a/aten/src/ATen/templates/TensorDense.cpp +++ b/aten/src/ATen/templates/TensorDense.cpp @@ -1,17 +1,12 @@ // included as 'TensorDenseOrSparse' in TensorDerived.cpp -IntList ${Tensor}::strides() const { - // NB: THTensor doesn't agree with Tensor for scalars, so we - // have to construct a fresh IntList - return IntList(THTensor_getStridePtr(tensor), dim()); -} Scalar ${Tensor}::localScalar() { int64_t numel = ${THTensor}_nElement(${state,}tensor); AT_CHECK(numel == 1,"a Tensor with ", numel, " elements cannot be converted to Scalar"); return Scalar(${to_at_type}(${THStorage}_get(${state,} THTensor_getStoragePtr(tensor), tensor->storage_offset()))); } std::unique_ptr ${Tensor}::storage() { - auto storage = ${THTensor}_storage(${state,}tensor); - ${THStorage}_retain(${state,}storage); - return std::unique_ptr(new ${Storage}(&type().get_context(), storage)); + auto storage = THTensor_getStoragePtr(tensor); + THStorage_retain(storage); + return std::unique_ptr(new ${Storage}(storage)); } diff --git a/aten/src/ATen/templates/TensorDerived.cpp b/aten/src/ATen/templates/TensorDerived.cpp index 70f2cc260b62f..0d5bb415293f0 100644 --- a/aten/src/ATen/templates/TensorDerived.cpp +++ b/aten/src/ATen/templates/TensorDerived.cpp @@ -15,48 +15,23 @@ namespace at { -${Tensor}::${Tensor}(Context* context) -: ${Tensor}(context,${THTensor}_new(${state})) {} - -${Tensor}::${Tensor}(Context* context, ${THTensor} * tensor) -: TensorImpl(&context->getType(Backend::${Backend},ScalarType::${ScalarName})), - tensor(tensor), - context(context) {} - -${Tensor}::~${Tensor}() { - if (tensor) tensor->release(); +namespace detail { + ${Tensor}* new_${Tensor}() { + return new ${Tensor}(${THTensor}_new(${state})); + } } +${Tensor}::${Tensor}(${THTensor} * tensor) +: TensorImpl(&globalContext().getType(Backend::${Backend},ScalarType::${ScalarName}), tensor) +{} + const char * ${Tensor}::toString() const { return "${Tensor}"; } -IntList ${Tensor}::sizes() const { - // NB: dim in tensor is not synchronized with THTensor, so it's - // important to apply dim here - return IntList(THTensor_getSizePtr(tensor), dim()); -} - -int64_t ${Tensor}::dim() const { - if(isScalar()) - return 0; - return tensor->dim(); -} - const char * ${Tensor}::typeString() { return "${Type}"; } -void * ${Tensor}::unsafeGetTH(bool retain) { - if (retain) { - tensor->retain(); - } - return tensor; -} - -void ${Tensor}::release_resources() { - tensor->release(); - tensor = nullptr; -} ${TensorDenseOrSparse} diff --git a/aten/src/ATen/templates/TensorDerived.h b/aten/src/ATen/templates/TensorDerived.h index 892d6bcca5827..c9e5b9f870def 100644 --- a/aten/src/ATen/templates/TensorDerived.h +++ b/aten/src/ATen/templates/TensorDerived.h @@ -6,32 +6,25 @@ #include "ATen/Tensor.h" #include "ATen/TensorImpl.h" -#include "ATen/Context.h" #include "ATen/TensorMethods.h" namespace at { struct ${Tensor} final : public TensorImpl { public: - explicit ${Tensor}(Context* context); - ${Tensor}(Context* context, ${THTensor} * tensor); - virtual ~${Tensor}(); + ${Tensor}(THTensor * tensor); virtual const char * toString() const override; - virtual IntList sizes() const override; - virtual IntList strides() const override; - virtual int64_t dim() const override; virtual Scalar localScalar() override; - virtual void * 
unsafeGetTH(bool retain) override; virtual std::unique_ptr storage() override; - virtual void release_resources() override; static const char * typeString(); - -//TODO(zach): sort of friend permissions later so this -// can be protected -public: - ${THTensor} * tensor; - Context* context; - friend struct ${Type}; }; +namespace detail { + // This is just a temporary function to help out code generation. + // Eventually, the codegen code should construct tensors using + // a new Tensor constructor that takes scalar type and backend, + // but I haven't written this yet. + ${Tensor}* new_${Tensor}(); +} + } // namespace at diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp index 6699070685f56..67009473dddef 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/TypeDerived.cpp @@ -44,28 +44,28 @@ bool ${Type}::is_sparse() const { return backend() == kSparseCPU || backend() == bool ${Type}::is_distributed() const { return false; } std::unique_ptr ${Type}::storage() const { - return std::unique_ptr(new ${Storage}(context)); + return std::unique_ptr(new ${Storage}()); } std::unique_ptr ${Type}::storage(size_t size) const { - return std::unique_ptr(new ${Storage}(context,size)); + return std::unique_ptr(new ${Storage}(size)); } std::unique_ptr ${Type}::storageFromBlob(void * data, int64_t size, const std::function & deleter) const { return std::unique_ptr( - new ${Storage}(context,data,size,deleter)); + new ${Storage}(data,size,deleter)); } std::unique_ptr ${Type}::storageWithAllocator(int64_t size, Allocator* allocator) const { return std::unique_ptr( - new ${Storage}(context, size, allocator)); + new ${Storage}(size, allocator)); } Tensor ${Type}::unsafeTensorFromTH(void * th_pointer, bool retain) const { if (retain) ${THTensor}_retain(${state,} (${THTensor}*) th_pointer); - return Tensor(new ${Tensor}(context,(${THTensor}*)(th_pointer)), false); + return Tensor(new ${Tensor}((${THTensor}*)(th_pointer)), false); } std::unique_ptr ${Type}::unsafeStorageFromTH(void * th_pointer, bool retain) const { if (retain) ${THStorage}_retain(${state,} (${THStorage}*) th_pointer); - return std::unique_ptr(new ${Storage}(context, (${THStorage}*) th_pointer)); + return std::unique_ptr(new ${Storage}((${THStorage}*) th_pointer)); } std::unique_ptr ${Type}::generator() const { return std::unique_ptr(new ${Generator}(context)); diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt index 86fd8db5ff55c..32f0b9d691e80 100644 --- a/aten/src/TH/CMakeLists.txt +++ b/aten/src/TH/CMakeLists.txt @@ -32,6 +32,13 @@ set(ATen_TH_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/THStorageClass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THStorageFunctions.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THTensorCopy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THTensorRandom.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THTensorMath.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THTensorMoreMath.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THTensorEvenMoreMath.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THTensorConv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THTensorLapack.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THBlas.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THLapack.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THLogAdd.cpp diff --git a/aten/src/TH/THStorageFunctions.hpp b/aten/src/TH/THStorageFunctions.hpp index b4b1d4c51e882..712f8b081c32b 100644 --- a/aten/src/TH/THStorageFunctions.hpp +++ b/aten/src/TH/THStorageFunctions.hpp @@ -42,5 +42,4 @@ TH_API void THStorage_resize(THStorage *storage, ptrdiff_t size); TH_API void 
THStorage_swap(THStorage *storage1, THStorage *storage2); TH_API void THStorage_weakRetain(THStorage *weak_storage); -TH_API void THStorage_weakFree(THStorage *weak_storage); TH_API THStorage* THStorage_weakLock(THStorage *weak_storage); diff --git a/aten/src/TH/THTensor.cpp b/aten/src/TH/THTensor.cpp index 5c6bdb48bd936..13df5128e5f5f 100644 --- a/aten/src/TH/THTensor.cpp +++ b/aten/src/TH/THTensor.cpp @@ -1,16 +1,4 @@ -#include -#include - -#include #include "THTensor.hpp" -#include "THVector.h" -#include "generic/simd/simd.h" - -#include "THBlas.h" -#include "THLapack.h" -#include "THRandom.h" -#include "THTensorDimApply.h" -#include "THMath.h" #include "generic/THTensor.cpp" #include "THGenerateAllTypes.h" @@ -18,24 +6,6 @@ #include "generic/THTensor.cpp" #include "THGenerateHalfType.h" -#include "generic/THTensorCopy.cpp" -#include "THGenerateAllTypes.h" - -#include "generic/THTensorCopy.cpp" -#include "THGenerateHalfType.h" - -#include "generic/THTensorRandom.cpp" -#include "THGenerateAllTypes.h" - -#include "generic/THTensorMath.cpp" -#include "THGenerateAllTypes.h" - -#include "generic/THTensorConv.cpp" -#include "THGenerateAllTypes.h" - -#include "generic/THTensorLapack.cpp" -#include "THGenerateFloatTypes.h" - #include void THTensor_free(THTensor *self) diff --git a/aten/src/TH/THTensor.hpp b/aten/src/TH/THTensor.hpp index 7912506cec107..f641ca3269098 100644 --- a/aten/src/TH/THTensor.hpp +++ b/aten/src/TH/THTensor.hpp @@ -17,7 +17,6 @@ struct THTensor , storage_offset_(0) , sizes_{0} , strides_{1} - , dim_(1) {} ~THTensor() { @@ -35,7 +34,6 @@ struct THTensor std::vector sizes_; std::vector strides_; - int64_t dim_; template inline T * data() const { @@ -51,11 +49,11 @@ struct THTensor // _dim() returns the "old" TH dimension view where no dimensions represents an empty tensor. // dim() returns the ATen view of the dimensionality, i.e. 0-sized dimensions are supported. inline int64_t _dim() const { - return is_empty() ? 0 : dim_; + return is_empty() ? 0 : dim(); } inline int64_t dim() const { - return dim_; + return sizes_.size(); } ptrdiff_t storage_offset() const { @@ -64,7 +62,7 @@ struct THTensor // represents that numel() == 0. 
inline bool is_empty() const { - for (int64_t i = 0; i < dim_; ++i) { + for (int64_t i = 0; i < dim(); ++i) { if (sizes_[i] == 0) { return true; } @@ -113,7 +111,6 @@ inline int64_t* THTensor_getStridePtr(THTensor* tensor) { } inline void THTensor_resizeDim(THTensor* tensor, int64_t ndim) { - tensor->dim_ = ndim; // NB: This is *truly* a resize; calling code (e.g., squeeze) // assumes that old values are preserved tensor->sizes_.resize(ndim); @@ -121,7 +118,6 @@ inline void THTensor_resizeDim(THTensor* tensor, int64_t ndim) { } inline void THTensor_setSizesAndStrides(THTensor* tensor, std::vector&& new_size, std::vector&& new_stride) { - tensor->dim_ = new_size.size(); tensor->sizes_ = std::move(new_size); tensor->strides_ = std::move(new_stride); } diff --git a/aten/src/TH/THTensorConv.cpp b/aten/src/TH/THTensorConv.cpp new file mode 100644 index 0000000000000..cc55916c36752 --- /dev/null +++ b/aten/src/TH/THTensorConv.cpp @@ -0,0 +1,6 @@ +#include "THTensor.hpp" +#include "THVector.h" + +#include "generic/THTensorConv.cpp" +#include "THGenerateAllTypes.h" + diff --git a/aten/src/TH/THTensorCopy.cpp b/aten/src/TH/THTensorCopy.cpp new file mode 100644 index 0000000000000..d8df519e26bdb --- /dev/null +++ b/aten/src/TH/THTensorCopy.cpp @@ -0,0 +1,8 @@ +#include "THTensor.hpp" +#include "THVector.h" + +#include "generic/THTensorCopy.cpp" +#include "THGenerateAllTypes.h" + +#include "generic/THTensorCopy.cpp" +#include "THGenerateHalfType.h" diff --git a/aten/src/TH/THTensorEvenMoreMath.cpp b/aten/src/TH/THTensorEvenMoreMath.cpp new file mode 100644 index 0000000000000..bb04fa1e24c38 --- /dev/null +++ b/aten/src/TH/THTensorEvenMoreMath.cpp @@ -0,0 +1,7 @@ +#include "THTensor.hpp" +#include "THVector.h" +#include "THBlas.h" +#include "THTensorDimApply.h" + +#include "generic/THTensorEvenMoreMath.cpp" +#include "THGenerateAllTypes.h" diff --git a/aten/src/TH/THTensorLapack.cpp b/aten/src/TH/THTensorLapack.cpp new file mode 100644 index 0000000000000..467a46f82e9f4 --- /dev/null +++ b/aten/src/TH/THTensorLapack.cpp @@ -0,0 +1,5 @@ +#include "THTensor.hpp" +#include "THLapack.h" + +#include "generic/THTensorLapack.cpp" +#include "THGenerateFloatTypes.h" diff --git a/aten/src/TH/THTensorMath.cpp b/aten/src/TH/THTensorMath.cpp new file mode 100644 index 0000000000000..1454823ac8ee1 --- /dev/null +++ b/aten/src/TH/THTensorMath.cpp @@ -0,0 +1,7 @@ +#include "THTensor.hpp" +#include "THVector.h" +#include "THBlas.h" +#include "THTensorDimApply.h" + +#include "generic/THTensorMath.cpp" +#include "THGenerateAllTypes.h" diff --git a/aten/src/TH/THTensorMoreMath.cpp b/aten/src/TH/THTensorMoreMath.cpp new file mode 100644 index 0000000000000..a0d10b127be85 --- /dev/null +++ b/aten/src/TH/THTensorMoreMath.cpp @@ -0,0 +1,7 @@ +#include "THTensor.hpp" +#include "THVector.h" +#include "THBlas.h" +#include "THTensorDimApply.h" + +#include "generic/THTensorMoreMath.cpp" +#include "THGenerateAllTypes.h" diff --git a/aten/src/TH/THTensorRandom.cpp b/aten/src/TH/THTensorRandom.cpp new file mode 100644 index 0000000000000..aa987474fadb9 --- /dev/null +++ b/aten/src/TH/THTensorRandom.cpp @@ -0,0 +1,5 @@ +#include "THTensor.hpp" +#include "THVector.h" + +#include "generic/THTensorRandom.cpp" +#include "THGenerateAllTypes.h" diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp index c281e916c58f0..bb09f3d1df9ca 100644 --- a/aten/src/TH/generic/THTensor.cpp +++ b/aten/src/TH/generic/THTensor.cpp @@ -413,7 +413,7 @@ void THTensor_(select)(THTensor *self, THTensor *src, int dimension, int64_t 
sli THTensor_setSizeAtDim(self, d, self->size(d+1)); THTensor_setStrideAtDim(self, d, self->stride(d+1)); } - THTensor_resizeDim(self, self->dim_ - 1); + THTensor_resizeDim(self, self->dim() - 1); } void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1, int dimension2) @@ -535,7 +535,7 @@ void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension) THTensor_setSizeAtDim(self, d, self->size(d+1)); THTensor_setStrideAtDim(self, d, self->stride(d+1)); } - THTensor_resizeDim(self, self->dim_ - 1); + THTensor_resizeDim(self, self->dim() - 1); } } diff --git a/aten/src/TH/generic/THTensor.h b/aten/src/TH/generic/THTensor.h index cdc8f7edef41c..429c7a10d2d04 100644 --- a/aten/src/TH/generic/THTensor.h +++ b/aten/src/TH/generic/THTensor.h @@ -103,6 +103,7 @@ TH_API void THTensor_(setStorage4d)(THTensor *self, THStorage *storage_, ptrdiff TH_API void THTensor_(narrow)(THTensor *self, THTensor *src, int dimension_, int64_t firstIndex_, int64_t size_); TH_API void THTensor_(select)(THTensor *self, THTensor *src, int dimension_, int64_t sliceIndex_); TH_API void THTensor_(transpose)(THTensor *self, THTensor *src, int dimension1_, int dimension2_); +TH_API int THTensor_(isTransposed)(const THTensor *self); TH_API void THTensor_(unfold)(THTensor *self, THTensor *src, int dimension_, int64_t size_, int64_t step_); TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src); diff --git a/aten/src/TH/generic/THTensorApply.hpp b/aten/src/TH/generic/THTensorApply.hpp new file mode 100644 index 0000000000000..09952b303a06c --- /dev/null +++ b/aten/src/TH/generic/THTensorApply.hpp @@ -0,0 +1,175 @@ +#ifndef NAN + #define NAN (nan(NULL)) +#endif + +#ifdef _OPENMP +#include +#endif + +#define HYPER_TH_OMP_OVERHEAD_THRESHOLD 2000 +#define ORDIN_TH_OMP_OVERHEAD_THRESHOLD 20000 +#define UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD 50000 +#define TH_OMP_OVERHEAD_THRESHOLD 100000 + +#ifdef _OPENMP + +#ifndef _WIN32 +#define PRAGMA(P) _Pragma(#P) +#else +#define PRAGMA(P) __pragma(P) +#endif + +#define TH_TENSOR_APPLY_CONTIG(TYPE, TENSOR, CODE) \ +{ \ + int inOmp = omp_in_parallel(); \ + ptrdiff_t TH_TENSOR_size = THTensor_(nElement)(TENSOR); \ + PRAGMA(omp parallel if ((TH_TENSOR_size > TH_OMP_OVERHEAD_THRESHOLD) && (!inOmp))) \ + { \ + size_t num_threads = omp_get_num_threads(); \ + size_t tid = omp_get_thread_num(); \ + ptrdiff_t TH_TENSOR_offset = tid * (TH_TENSOR_size / num_threads); \ + ptrdiff_t TH_TENSOR_end = tid == num_threads - 1 ? TH_TENSOR_size : \ + TH_TENSOR_offset + TH_TENSOR_size / num_threads; \ + ptrdiff_t TENSOR##_len = TH_TENSOR_end - TH_TENSOR_offset; \ + TYPE *TENSOR##_data = THTensor_(data)(TENSOR) + TH_TENSOR_offset; \ + CODE \ + } \ +} +#else +#define TH_TENSOR_APPLY_CONTIG(TYPE, TENSOR, CODE) \ +{ \ + TYPE *TENSOR##_data = THTensor_(data)(TENSOR); \ + ptrdiff_t TENSOR##_len = THTensor_(nElement)(TENSOR); \ + CODE \ +} +#endif + +#ifdef _OPENMP +#define TH_TENSOR_APPLY2_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \ +{ \ + int inOmp = omp_in_parallel(); \ + ptrdiff_t TH_TENSOR_size = THTensor_(nElement)(TENSOR1); \ + PRAGMA(omp parallel if ((TH_TENSOR_size > TH_OMP_OVERHEAD_THRESHOLD) && (!inOmp))) \ + { \ + size_t num_threads = omp_get_num_threads(); \ + size_t tid = omp_get_thread_num(); \ + ptrdiff_t TH_TENSOR_offset = tid * (TH_TENSOR_size / num_threads); \ + ptrdiff_t TH_TENSOR_end = tid == num_threads - 1 ? 
TH_TENSOR_size : \ + TH_TENSOR_offset + TH_TENSOR_size / num_threads; \ + ptrdiff_t TENSOR1##_len = TH_TENSOR_end - TH_TENSOR_offset; \ + TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1) + TH_TENSOR_offset; \ + TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2) + TH_TENSOR_offset; \ + CODE \ + } \ +} +#else +#define TH_TENSOR_APPLY2_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \ +{ \ + TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1); \ + TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2); \ + ptrdiff_t TENSOR1##_len = THTensor_(nElement)(TENSOR1); \ + CODE \ +} +#endif + +#ifdef _OPENMP +#define TH_TENSOR_APPLY3_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, CODE) \ +{ \ + int inOmp = omp_in_parallel(); \ + ptrdiff_t TH_TENSOR_size = THTensor_(nElement)(TENSOR1); \ + PRAGMA(omp parallel if ((TH_TENSOR_size > TH_OMP_OVERHEAD_THRESHOLD) && (!inOmp))) \ + { \ + size_t num_threads = omp_get_num_threads(); \ + size_t tid = omp_get_thread_num(); \ + ptrdiff_t TH_TENSOR_offset = tid * (TH_TENSOR_size / num_threads); \ + ptrdiff_t TH_TENSOR_end = tid == num_threads - 1 ? TH_TENSOR_size : \ + TH_TENSOR_offset + TH_TENSOR_size / num_threads; \ + ptrdiff_t TENSOR1##_len = TH_TENSOR_end - TH_TENSOR_offset; \ + TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1) + TH_TENSOR_offset; \ + TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2) + TH_TENSOR_offset; \ + TYPE3 *TENSOR3##_data = THTensor_(data)(TENSOR3) + TH_TENSOR_offset; \ + CODE \ + } \ +} +#else +#define TH_TENSOR_APPLY3_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, CODE) \ +{ \ + TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1); \ + TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2); \ + TYPE3 *TENSOR3##_data = THTensor_(data)(TENSOR3); \ + ptrdiff_t TENSOR1##_len = THTensor_(nElement)(TENSOR1); \ + CODE \ +} +#endif + +#define TH_CHECK_SAME_SIZE(TENSOR1, TENSOR2) \ +{ \ + if(!THTensor_(isSameSizeAs)(TENSOR1, TENSOR2)) { \ + AT_ERROR("inconsistent tensor size, expected ", #TENSOR1, " ", TENSOR1->sizes(), " and ", #TENSOR2, " ", TENSOR2->sizes(), " to have the same size"); \ + } \ +} + +// Used for `scatter` and `scatterAdd` +// Assumes TENSOR1 is real +// TENSOR2 is src +// TENSOR3 is index +// Tests: +// 1. index->size(d) <= src->size(d) for all d +// 2. 
index->size(d) <= real->size(d) for all d != dim +#define TH_TENSOR_DIM_APPLY3_SIZE_SCATTER(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \ +{ \ + int shape_check_flag = 0; \ + for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->_dim(); TH_TENSOR_DIM_APPLY_i++) \ + { \ + int64_t TENSOR3##_dim_size = TENSOR3->size(TH_TENSOR_DIM_APPLY_i); \ + if (TH_TENSOR_DIM_APPLY_i != DIMENSION) { \ + if (TENSOR3##_dim_size > TENSOR1->size(TH_TENSOR_DIM_APPLY_i)) { \ + shape_check_flag = 1; \ + break; \ + } \ + } \ + if (TENSOR3##_dim_size > TENSOR2->size(TH_TENSOR_DIM_APPLY_i)) { \ + shape_check_flag = 1; \ + break; \ + } \ + } \ + if (shape_check_flag == 1) { \ + AT_ERROR("Expected ", #TENSOR3, " ", TENSOR3->sizes(), " to be smaller size than ", #TENSOR2, " ", TENSOR2->sizes(), " and to be smaller than ", #TENSOR1, " ", TENSOR1->sizes(), " apart from dimension ", DIMENSION); \ + } \ +} + +#undef th_isnan +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) +#define th_isnan(val) \ +(std::isnan(val)) +#else +#define th_isnan(val) (0) +#endif + +#undef th_isnan_break +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) +#define th_isnan_break(val) \ +if (std::isnan(val)) break; +#else +#define th_isnan_break(val) +#endif + +static inline real THTensor_(powOne)(real x, real y) { +#if defined(TH_REAL_IS_FLOAT) + return powf(x, y); +#elif defined(TH_REAL_IS_DOUBLE) + return pow(x, y); +#else + THArgCheck(y >= 0, 1, + "Integers to negative integer powers are not allowed"); + real result = 1; + while (y) { + if (y & 1) { + result *= x; + } + y /= 2; + x *= x; + } + return result; +#endif +} diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp new file mode 100644 index 0000000000000..4b88c95872292 --- /dev/null +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -0,0 +1,971 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorEvenMoreMath.cpp" +#else + +#include + +void THTensor_(fill)(THTensor *r_, real value) +{ + if (THTensor_(isContiguous)(r_) || THTensor_(isTransposed)(r_)) { + TH_TENSOR_APPLY_CONTIG(real, r_, THVector_(fill)(r__data, value, r__len);); + } else { + TH_TENSOR_APPLY(real, r_, + if (r__stride == 1) { + THVector_(fill)(r__data, value, r__size); + r__i = r__size; + r__data += r__stride * r__size; + break; + } else { + *r__data = value; + } + ); + } +} + +void THTensor_(zero)(THTensor *r_) +{ + THTensor_(fill)(r_, 0); +} + +void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value) +{ + TH_TENSOR_APPLY2(real, tensor, unsigned char, mask, + if (*mask_data > 1) + { + THFree(mask_counter); + THFree(tensor_counter); + THError("Mask tensor can take 0 and 1 values only"); + } + else if (*mask_data == 1) + { + *tensor_data = value; + }); +} + +void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src ) +{ + THTensor *srct = THTensor_(newContiguous)(src); + real *src_data = THTensor_(data)(srct); + ptrdiff_t cntr = 0; + ptrdiff_t nelem = THTensor_(nElement)(srct); + if (THTensor_(nElement)(tensor) != THByteTensor_nElement(mask)) + { + THTensor_(free)(srct); + THError("Number of elements of destination tensor != Number of elements in mask"); + } + TH_TENSOR_APPLY2(real, tensor, unsigned char, mask, + if (*mask_data > 1) + { + THTensor_(free)(srct); + THFree(mask_counter); + THFree(tensor_counter); + THError("Mask tensor can take 0 and 1 values only"); + } + else if (*mask_data == 1) + { + if (cntr == nelem) + { + THTensor_(free)(srct); + THFree(mask_counter); + 
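// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// The integer branch of THTensor_(powOne) above is binary exponentiation
// (square-and-multiply). The standalone int64_t version below, with an invented
// name, shows the same loop so it can be unit-tested outside the TH macros.
#include <cstdint>

static int64_t int_pow(int64_t x, int64_t y) {
  // Precondition, as in powOne: y >= 0 for integer types.
  int64_t result = 1;
  while (y) {
    if (y & 1) result *= x;  // fold in the contribution of the current bit
    y /= 2;                  // move to the next bit of the exponent
    x *= x;                  // square the base
  }
  return result;
}
// int_pow(3, 5) == 243 using O(log y) multiplications.
// ---------------------------------------------------------------------------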
THFree(tensor_counter); + THError("Number of elements of src < number of ones in mask"); + } + *tensor_data = *src_data; + src_data++; + cntr++; + }); + THTensor_(free)(srct); +} + +void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask) +{ + ptrdiff_t numel = THByteTensor_sumall(mask); + real *tensor_data; + +#ifdef DEBUG + THAssert(numel <= LONG_MAX); +#endif + THTensor_(resize1d)(tensor,numel); + tensor_data = THTensor_(data)(tensor); + TH_TENSOR_APPLY2(real, src, unsigned char, mask, + if (*mask_data > 1) + { + THFree(mask_counter); + THFree(src_counter); + THError("Mask tensor can take 0 and 1 values only"); + } + else if (*mask_data == 1) + { + *tensor_data = *src_data; + tensor_data++; + }); +} + +// Finds non-zero elements of a tensor and returns their subscripts +void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) +{ + ptrdiff_t numel = 0; + int64_t *subscript_data; + int64_t i = 0; + int64_t dim; + int64_t div = 1; +#ifdef TH_REAL_IS_HALF +#define IS_NONZERO(val) ((val.x & 0x7fff) != 0) +#else +#define IS_NONZERO(val) ((val)!=0) +#endif + + /* First Pass to determine size of subscripts */ + TH_TENSOR_APPLY(real, tensor, + if IS_NONZERO(*tensor_data) { + ++numel; + }); +#ifdef DEBUG + THAssert(numel <= LONG_MAX); +#endif + THLongTensor_resize2d(subscript, numel, tensor->dim()); + + /* Second pass populates subscripts */ + subscript_data = THLongTensor_data(subscript); + TH_TENSOR_APPLY(real, tensor, + if IS_NONZERO(*tensor_data) { + div = 1; + + for (dim = tensor->dim() - 1; dim >= 0; dim--) { + *(subscript_data + dim) = (i/div) % tensor->size(dim); + div *= tensor->size(dim); + } + + subscript_data += tensor->dim(); + } + ++i;); +} + +void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index) +{ + ptrdiff_t i, numel; + THLongStorage *newSize; + THTensor *tSlice, *sSlice; + int64_t *index_data; + real *tensor_data, *src_data; + +#ifndef USE_TH_SIZE_ZERO_DIM + THArgCheck(index->_dim() <= 1, 3, "Index is supposed to be an empty tensor or a vector"); + THArgCheck(dim < src->_dim(), 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); + THArgCheck(src->_dim() > 0, 2, "Source tensor is empty"); +#else + THArgCheck(index->dim() == 1, 3, "Index is supposed to be 1-dimensional"); + THArgCheck(dim < src->dim(), 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); + //THArgCheck(src->dim() > 0, 2, "Source tensor is empty"); +#endif + + numel = THLongTensor_nElement(index); + + newSize = THLongStorage_newWithSize(src->dim()); + THLongStorage_rawCopy(newSize, THTensor_getSizePtr(src)); +#ifdef DEBUG + THAssert(numel <= LONG_MAX); +#endif + THLongStorage_data(newSize)[dim] = numel; + THTensor_(resize)(tensor,newSize,NULL); + THLongStorage_free(newSize); + + index = THLongTensor_newContiguous(index); + index_data = THLongTensor_data(index); + + if (dim == 0 && THTensor_(isContiguous)(src) && THTensor_(isContiguous)(tensor)) + { + tensor_data = THTensor_(data)(tensor); + src_data = THTensor_(data)(src); + ptrdiff_t rowsize = src->size(0) == 0 ? 
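// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// THTensor_(nonzero) above turns the running linear index i into per-dimension
// subscripts by repeated division starting from the innermost dimension. The
// helper below (invented name) shows the same arithmetic for a row-major shape.
#include <cstdint>
#include <vector>

static std::vector<int64_t> unravel_index(int64_t linear,
                                          const std::vector<int64_t>& sizes) {
  std::vector<int64_t> subscript(sizes.size());
  int64_t div = 1;
  for (int64_t dim = (int64_t)sizes.size() - 1; dim >= 0; dim--) {
    subscript[dim] = (linear / div) % sizes[dim];
    div *= sizes[dim];
  }
  return subscript;
}
// For sizes {2, 3}, linear index 4 maps to subscripts {1, 1}.
// ---------------------------------------------------------------------------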
1: THTensor_(nElement)(src) / src->size(0); + + // check that the indices are within range + int64_t max = src->size(0) - 1 + TH_INDEX_BASE; + for (i=0; i max) { + THLongTensor_free(index); + THError("index out of range"); + } + } + + if (src->dim() == 1) { + #pragma omp parallel for if(numel > TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; idim() == 1) + { + for (i=0; idim() > 1 ) + { + tSlice = THTensor_(new)(); + sSlice = THTensor_(new)(); + + for (i=0; isizes(); + auto stride = tensor->strides(); + int nDim = tensor->_dim(); + ptrdiff_t dataOffset = 0; + for (int i = nDim - 1; i >= 0; i--) { + dataOffset += (linearIndex % size[i]) * stride[i]; + linearIndex /= size[i]; + } + return dataOffset; +} + +static inline void THTensor_(checkLinearIndex)(int64_t linearIndex, int64_t numel) { + THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel); +} + +static inline int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) { + return linearIndex < 0 ? linearIndex + numel : linearIndex; +} + +void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index) +{ + THTensor_(resizeNd)(r_, index->dim(), THTensor_getSizePtr(index), NULL); + THTensor* dst = THTensor_(newContiguous)(r_); + + index = THLongTensor_newContiguous(index); + int64_t* index_data = THLongTensor_data(index); + ptrdiff_t srcElements = THTensor_(nElement)(src); + real* src_data = THTensor_(data)(src); + real* dst_data = THTensor_(data)(dst); + ptrdiff_t nIndices = THLongTensor_nElement(index); + int isContiguous = THTensor_(isContiguous)(src); + + // Exceptions must not be thrown across OpenMP parallel sections, so we + // record the position of the invalid index and throw the exception after the + // loop. + std::atomic invalidIdxPos(-1); + + ptrdiff_t i; + #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i = 0; i < nIndices; i++) { + int64_t idx = index_data[i]; + if (idx < srcElements && idx >= -srcElements) { + idx = THTensor_(wrapLinearIndex)(idx, srcElements); + if (isContiguous) { + dst_data[i] = src_data[idx]; + } else { + dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)]; + } + } else { + int64_t tmp = -1; + invalidIdxPos.compare_exchange_strong(tmp, i); + } + } + + if (invalidIdxPos >= 0) { + THTensor_(checkLinearIndex)(index_data[invalidIdxPos], srcElements); + } + + THLongTensor_free(index); + THTensor_(freeCopyTo)(dst, r_); +} + +void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate) +{ + THArgCheck(THLongTensor_nElement(index) == THTensor_(nElement)(src), 3, + "src should have the same number of elements as index"); + + index = THLongTensor_newContiguous(index); + src = THTensor_(newContiguous)(src); + real* data = THTensor_(data)(tensor); + ptrdiff_t numel = THTensor_(nElement)(tensor); + int is_contiguous = THTensor_(isContiguous)(tensor); + + TH_TENSOR_APPLY2(int64_t, index, real, src, + THTensor_(checkLinearIndex)(*index_data, numel); + int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel); + int64_t dataOffset = is_contiguous ? 
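// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// take/put above accept negative linear indices and, for non-contiguous tensors,
// translate a linear index into a strided data offset. The two helpers below
// (invented names) restate THTensor_(wrapLinearIndex) and THTensor_(dataOffset)
// for a plain sizes/strides pair.
#include <cstdint>
#include <vector>

static int64_t wrap_index(int64_t idx, int64_t numel) {
  return idx < 0 ? idx + numel : idx;  // negative indices count from the end
}

static int64_t strided_offset(int64_t linear,
                              const std::vector<int64_t>& sizes,
                              const std::vector<int64_t>& strides) {
  int64_t offset = 0;
  for (int64_t d = (int64_t)sizes.size() - 1; d >= 0; d--) {
    offset += (linear % sizes[d]) * strides[d];
    linear /= sizes[d];
  }
  return offset;
}
// sizes {2, 3} with strides {1, 2}: linear index 4 -> (1, 1) -> offset 1*1 + 1*2 = 3.
// ---------------------------------------------------------------------------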
linearIndex : THTensor_(dataOffset)(tensor, linearIndex); + if (accumulate) { + data[dataOffset] += *src_data; + } else { + data[dataOffset] = *src_data; + } + ); + + THTensor_(free)(src); + THLongTensor_free(index); +} + +void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src) +{ + ptrdiff_t i, numel; + THTensor *tSlice, *sSlice; + int64_t *index_data; + + numel = THLongTensor_nElement(index); +#ifndef USE_TH_SIZE_ZERO_DIM + THArgCheck(index->_dim() == 1, 3, "Index is supposed to be a vector"); + THArgCheck(dim < src->_dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); +#else + THArgCheck(index->dim() == 1, 3, "Index is supposed to be a vector"); + THArgCheck(dim < src->dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); +#endif + THArgCheck(numel == src->size(dim),4,"Number of indices should be equal to source:size(dim)"); + + index = THLongTensor_newContiguous(index); + index_data = THLongTensor_data(index); + + if (tensor->dim() > 1) + { + tSlice = THTensor_(new)(); + sSlice = THTensor_(new)(); + + for (i=0; i_dim() == 1, 3, "Index is supposed to be a vector"); + THArgCheck(dim < tensor->_dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); +#else + THArgCheck(index->dim() == 1, 3, "Index is supposed to be a vector"); + THArgCheck(dim < tensor->dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); +#endif + + index = THLongTensor_newContiguous(index); + index_data = THLongTensor_data(index); + + for (i=0; idim() > 1) + { + tSlice = THTensor_(new)(); + THTensor_(select)(tSlice, tensor,dim,index_data[i] - TH_INDEX_BASE); + THTensor_(fill)(tSlice, val); + THTensor_(free)(tSlice); + } + else + { + THTensor_(set1d)(tensor, index_data[i] - TH_INDEX_BASE, val); + } + } + THLongTensor_free(index); +} + +void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index) +{ + int64_t elems_per_row, i, idx; + + THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(src), 4, + "Index tensor must have same dimensions as input tensor"); + THArgCheck(dim >= 0 && dim < THTensor_(nDimension)(tensor), 3, + "Index dimension is out of bounds"); + THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 2, + "Input tensor must have same dimensions as output tensor"); + + elems_per_row = THLongTensor_size(index, dim); + + TH_TENSOR_DIM_APPLY3(real, tensor, real, src, int64_t, index, dim, + TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + for (i = 0; i < elems_per_row; ++i) + { + idx = *(index_data + i*index_stride); + if (idx < TH_INDEX_BASE || idx >= src_size + TH_INDEX_BASE) + { + THFree(TH_TENSOR_DIM_APPLY_counter); + THError("Invalid index in gather"); + } + *(tensor_data + i*tensor_stride) = src_data[(idx - TH_INDEX_BASE) * src_stride]; + }) +} + +void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src) +{ + int64_t elems_per_row, i, idx; + +#ifndef USE_TH_SIZE_ZERO_DIM + THArgCheck(dim < THTensor_(_nDimension)(tensor), 2, "Index dimension is out of bounds"); + THArgCheck(THLongTensor__nDimension(index) == THTensor_(_nDimension)(tensor), 3, + "Index tensor must have same dimensions as output tensor"); + THArgCheck(THTensor_(_nDimension)(src) == THTensor_(_nDimension)(tensor), 4, + "Input tensor must have same dimensions as output tensor"); +#else + THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds"); + THArgCheck(THLongTensor_nDimension(index) == 
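// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// For a 2-D tensor and dim == 1, the TH_TENSOR_DIM_APPLY3 loop in THTensor_(gather)
// above reduces to out[i][j] = src[i][index[i][j]]. The helper below (invented
// name, TH_INDEX_BASE taken as 0) spells that out on nested vectors.
#include <cstddef>
#include <cstdint>
#include <vector>

static std::vector<std::vector<int64_t>> gather_dim1(
    const std::vector<std::vector<int64_t>>& src,
    const std::vector<std::vector<int64_t>>& index) {
  std::vector<std::vector<int64_t>> out(index.size());
  for (std::size_t i = 0; i < index.size(); i++) {
    out[i].resize(index[i].size());
    for (std::size_t j = 0; j < index[i].size(); j++) {
      out[i][j] = src[i][index[i][j]];  // pick along dimension 1
    }
  }
  return out;
}
// gather_dim1({{10, 11, 12}}, {{2, 0}}) == {{12, 10}}.
// ---------------------------------------------------------------------------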
THTensor_(nDimension)(tensor), 3, + "Index tensor must have same dimensions as output tensor"); + THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4, + "Input tensor must have same dimensions as output tensor"); +#endif + + elems_per_row = THLongTensor_size(index, dim); + + TH_TENSOR_DIM_APPLY3(real, tensor, real, src, int64_t, index, dim, + TH_TENSOR_DIM_APPLY3_SIZE_SCATTER, + for (i = 0; i < elems_per_row; ++i) + { + idx = *(index_data + i*index_stride); + if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE) + { + THFree(TH_TENSOR_DIM_APPLY_counter); + THError("Invalid index in scatter"); + } + tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] = *(src_data + i*src_stride); + }) +} + +void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src) +{ + int64_t elems_per_row, i, idx; + + THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds"); + THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(tensor), 3, + "Index tensor must have same dimensions as output tensor"); + THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4, + "Input tensor must have same dimensions as output tensor"); + + elems_per_row = THLongTensor_size(index, dim); + + TH_TENSOR_DIM_APPLY3(real, tensor, real, src, int64_t, index, dim, + TH_TENSOR_DIM_APPLY3_SIZE_SCATTER, + for (i = 0; i < elems_per_row; ++i) + { + idx = *(index_data + i*index_stride); + if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE) + { + THFree(TH_TENSOR_DIM_APPLY_counter); + THError("Invalid index in scatterAdd"); + } + tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] += *(src_data + i*src_stride); + }) +} + +void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real val) +{ + int64_t elems_per_row, i, idx; + + THArgCheck(dim < THTensor_(_nDimension)(tensor), 2, "Index dimension is out of bounds"); + THArgCheck(THLongTensor__nDimension(index) == THTensor_(_nDimension)(tensor), 3, + "Index tensor must have same dimensions as output tensor"); + + elems_per_row = THLongTensor_size(index, dim); + + TH_TENSOR_DIM_APPLY2(real, tensor, int64_t, index, dim, + for (i = 0; i < elems_per_row; ++i) + { + idx = *(index_data + i*index_stride); + if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE) + { + THFree(TH_TENSOR_DIM_APPLY_counter); + THError("Invalid index in scatter"); + } + tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] = val; + }) +} + +accreal THTensor_(dot)(THTensor *tensor, THTensor *src) +{ + accreal sum = 0; + /* we use a trick here. careful with that. */ + TH_TENSOR_APPLY2(real, tensor, real, src, + int64_t sz = (tensor_size-tensor_i < src_size-src_i ? 
tensor_size-tensor_i : src_size-src_i); + sum += THBlas_(dot)(sz, src_data, src_stride, tensor_data, tensor_stride); + tensor_i += sz; + src_i += sz; + tensor_data += sz*tensor_stride; + src_data += sz*src_stride; + break;); + return sum; +} + +real THTensor_(minall)(THTensor *tensor) +{ + real theMin; + real value; + + THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension"); + theMin = THTensor_(data)(tensor)[0]; + TH_TENSOR_APPLY(real, tensor, + value = *tensor_data; + /* This is not the same as value= theMin)) + { + theMin = value; + th_isnan_break(value) + }); + return theMin; +} + +real THTensor_(maxall)(THTensor *tensor) +{ + real theMax; + real value; + + THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension"); + theMax = THTensor_(data)(tensor)[0]; + TH_TENSOR_APPLY(real, tensor, + value = *tensor_data; + /* This is not the same as value>theMax in the case of NaNs */ + if(!(value <= theMax)) + { + theMax = value; + th_isnan_break(value) + }); + return theMax; +} + +accreal THTensor_(sumall)(THTensor *tensor) +{ + accreal sum = 0; + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if(inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, +:sum, sum += *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } +#else + serial_path = 1; +#endif + if (serial_path) { + TH_TENSOR_APPLY(real, tensor, sum += *tensor_data;); + } + return sum; +} + +accreal THTensor_(prodall)(THTensor *tensor) +{ + accreal prod = 1; + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if(inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, *:prod, prod *= *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } +#else + serial_path = 1; +#endif + if (serial_path) { + TH_TENSOR_APPLY(real, tensor, prod *= *tensor_data;); + } + return prod; +} + +void THTensor_(add)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int serial_path = 0; + if (r_Contig && tContig) { + TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(adds)(r__data, t_data, value, r__len);); + } else { +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = *t_data + value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD) + } +#else + (void)r_Size; + serial_path = 1; +#endif + } + if (serial_path) { + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data + value;); + } +} + +void THTensor_(sub)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(add)(r_, t, -value); +} + +void THTensor_(add_scaled)(THTensor *r_, THTensor *t, real value, real alpha) +{ + THTensor_(add)(r_, t, value * alpha); +} + +void THTensor_(sub_scaled)(THTensor *r_, THTensor *t, real value, real alpha) +{ + THTensor_(add)(r_, t, -value * alpha); +} + +void THTensor_(mul)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int serial_path = 0; + if (r_Contig && tContig) { + TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(muls)(r__data, t_data, value, r__len);); + } else { +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, 
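// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// THTensor_(minall) above tests !(value >= theMin) rather than value < theMin so
// that a NaN (for which every comparison is false) replaces the running minimum,
// after which th_isnan_break stops the scan. The helper below (invented name,
// assumes n >= 1 just as the kernel requires a non-empty tensor) shows the effect.
#include <cmath>
#include <cstddef>

static double nan_propagating_min(const double* data, std::size_t n) {
  double the_min = data[0];
  for (std::size_t i = 1; i < n; i++) {
    double value = data[i];
    if (!(value >= the_min)) {        // true when value < the_min OR value is NaN
      the_min = value;
      if (std::isnan(value)) break;   // same effect as th_isnan_break
    }
  }
  return the_min;
}
// {2.0, NAN, 1.0} yields NaN; {2.0, 1.0, 3.0} yields 1.0.
// ---------------------------------------------------------------------------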
*r__data = *t_data * value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD) + } +#else + (void)r_Size; + serial_path = 1; +#endif + } + if (serial_path) { + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data * value;); + } +} + +void THTensor_(div)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int serial_path = 0; + if (r_Contig && tContig) { + TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(divs)(r__data, t_data, value, r__len);); + } else { +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = *t_data / value;, ORDIN_TH_OMP_OVERHEAD_THRESHOLD) + } +#else + (void)r_Size; + serial_path = 1; +#endif + } + if (serial_path) { + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data / value;); + } +} + +void THTensor_(lshift)(THTensor *r_, THTensor *t, real value) +{ +#if defined(TH_REAL_IS_FLOAT) + return THTensor_(mul)(r_, t, powf(2, value)); +#elif defined(TH_REAL_IS_DOUBLE) + return THTensor_(mul)(r_, t, pow(2, value)); +#elif defined(TH_REAL_IS_HALF) + return THError("lshift is not supported for torch.HalfTensor"); +#else + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int serial_path = 0; + if (r_Contig && tContig) { + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + int64_t i; + #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) + for (i=0; i> value; +#else + rp[i] = ((ureal) tp[i]) >> value; +#endif + } + } else { +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { +#if defined(TH_REAL_IS_BYTE) + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = (((real) *t_data) >> value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); +#else + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = (((ureal) *t_data) >> value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); +#endif + } +#else + serial_path = 1; +#endif + } + if (serial_path) { +#if defined(TH_REAL_IS_BYTE) + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (((real) *t_data) >> value);); +#else + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (((ureal) *t_data) >> value);); +#endif + } +#endif +} + +void THTensor_(fmod)(THTensor *r_, THTensor *t, real value) +{ + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int serial_path = 0; + if (r_Contig && tContig) { + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + int64_t i; + #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) + for (i=0; i -#endif +#include + +// HEY YOU! +// +// Looking for a function which used to be in THTensorMath.cpp, but +// can't find it anymore? Check THTensorMoreMath.cpp and +// THTensorEvenMoreMath.cpp. These source files have been split up +// because they were getting too big (a whopping 4669 lines at time +// of writing) and causing MSVC to run out of memory. 
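// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// add, mul, div, fmod and the other pointwise kernels above share one dispatch:
// a vectorized contiguous path, an OpenMP apply path when not already inside a
// parallel region, and a serial TH_TENSOR_APPLY fallback. The skeleton below
// restates that control flow with invented names and callables standing in for
// the three code paths.
#ifdef _OPENMP
#include <omp.h>
#endif

template <typename ContigFn, typename OmpFn, typename SerialFn>
static void pointwise_dispatch(bool r_contig, bool t_contig,
                               ContigFn contig, OmpFn omp_apply, SerialFn serial) {
  int serial_path = 0;
  if (r_contig && t_contig) {
    contig();            // e.g. THVector_(adds) over the flat data
  } else {
#ifdef _OPENMP
    if (omp_in_parallel()) {
      serial_path = 1;   // never nest parallel regions
    } else {
      omp_apply();       // TH_TENSOR_APPLY2_OMP-style path
    }
#else
    serial_path = 1;
#endif
  }
  if (serial_path) {
    serial();            // TH_TENSOR_APPLY2-style path
  }
}
// ---------------------------------------------------------------------------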
Did you come +// here because you saw: +// +// fatal error C1002: compiler is out of heap space in pass 2 +// +// Try splitting up the file some more. +// +// At some point, we should reorganize these files in a way that makes +// sense (rather than just having cut the file down the middle, which is +// what I did when I split these up originally). -#define HYPER_TH_OMP_OVERHEAD_THRESHOLD 2000 -#define ORDIN_TH_OMP_OVERHEAD_THRESHOLD 20000 -#define UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD 50000 -#define TH_OMP_OVERHEAD_THRESHOLD 100000 - -#ifdef _OPENMP - -#ifndef _WIN32 -#define PRAGMA(P) _Pragma(#P) -#else -#define PRAGMA(P) __pragma(P) -#endif -#define TH_TENSOR_APPLY_CONTIG(TYPE, TENSOR, CODE) \ -{ \ - int inOmp = omp_in_parallel(); \ - ptrdiff_t TH_TENSOR_size = THTensor_(nElement)(TENSOR); \ - PRAGMA(omp parallel if ((TH_TENSOR_size > TH_OMP_OVERHEAD_THRESHOLD) && (!inOmp))) \ - { \ - size_t num_threads = omp_get_num_threads(); \ - size_t tid = omp_get_thread_num(); \ - ptrdiff_t TH_TENSOR_offset = tid * (TH_TENSOR_size / num_threads); \ - ptrdiff_t TH_TENSOR_end = tid == num_threads - 1 ? TH_TENSOR_size : \ - TH_TENSOR_offset + TH_TENSOR_size / num_threads; \ - ptrdiff_t TENSOR##_len = TH_TENSOR_end - TH_TENSOR_offset; \ - TYPE *TENSOR##_data = THTensor_(data)(TENSOR) + TH_TENSOR_offset; \ - CODE \ - } \ -} -#else -#define TH_TENSOR_APPLY_CONTIG(TYPE, TENSOR, CODE) \ -{ \ - TYPE *TENSOR##_data = THTensor_(data)(TENSOR); \ - ptrdiff_t TENSOR##_len = THTensor_(nElement)(TENSOR); \ - CODE \ +// Should wrap if the value (a) has a different sign than the divisor (b), but is not 0. +static inline bool modulo_wrap(real a, real b) { + return (a != 0) && (a < 0) != (b < 0); } -#endif -#ifdef _OPENMP -#define TH_TENSOR_APPLY2_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \ -{ \ - int inOmp = omp_in_parallel(); \ - ptrdiff_t TH_TENSOR_size = THTensor_(nElement)(TENSOR1); \ - PRAGMA(omp parallel if ((TH_TENSOR_size > TH_OMP_OVERHEAD_THRESHOLD) && (!inOmp))) \ - { \ - size_t num_threads = omp_get_num_threads(); \ - size_t tid = omp_get_thread_num(); \ - ptrdiff_t TH_TENSOR_offset = tid * (TH_TENSOR_size / num_threads); \ - ptrdiff_t TH_TENSOR_end = tid == num_threads - 1 ? 
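// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// modulo_wrap above flags the case where a truncating C remainder and the divisor
// have different signs; that is exactly when `remainder` (unlike `fmod`) must add
// the divisor back so the result takes the divisor's sign. Helper names below are
// invented; d must be non-zero.
#include <cstdint>

static bool needs_wrap(int64_t a, int64_t b) {
  return (a != 0) && ((a < 0) != (b < 0));  // same test as modulo_wrap
}

static int64_t python_style_mod(int64_t x, int64_t d) {
  int64_t r = x % d;              // truncating remainder, carries the sign of x
  if (needs_wrap(r, d)) r += d;   // shift into the divisor's sign
  return r;
}
// python_style_mod(-7, 3) == 2, whereas -7 % 3 == -1 in C/C++.
// ---------------------------------------------------------------------------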
TH_TENSOR_size : \ - TH_TENSOR_offset + TH_TENSOR_size / num_threads; \ - ptrdiff_t TENSOR1##_len = TH_TENSOR_end - TH_TENSOR_offset; \ - TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1) + TH_TENSOR_offset; \ - TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2) + TH_TENSOR_offset; \ - CODE \ - } \ -} +void THTensor_(bitor)(THTensor *r_, THTensor *t, real value) +{ +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) + (void)r_; + (void)t; + (void)value; + return THError("bitor is only supported for integer type tensors"); #else -#define TH_TENSOR_APPLY2_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, CODE) \ -{ \ - TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1); \ - TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2); \ - ptrdiff_t TENSOR1##_len = THTensor_(nElement)(TENSOR1); \ - CODE \ -} -#endif - + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int serial_path = 0; + if (r_Contig && tContig) { + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + int64_t i; + #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) && (!inOmp))) \ - { \ - size_t num_threads = omp_get_num_threads(); \ - size_t tid = omp_get_thread_num(); \ - ptrdiff_t TH_TENSOR_offset = tid * (TH_TENSOR_size / num_threads); \ - ptrdiff_t TH_TENSOR_end = tid == num_threads - 1 ? TH_TENSOR_size : \ - TH_TENSOR_offset + TH_TENSOR_size / num_threads; \ - ptrdiff_t TENSOR1##_len = TH_TENSOR_end - TH_TENSOR_offset; \ - TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1) + TH_TENSOR_offset; \ - TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2) + TH_TENSOR_offset; \ - TYPE3 *TENSOR3##_data = THTensor_(data)(TENSOR3) + TH_TENSOR_offset; \ - CODE \ - } \ -} -#else -#define TH_TENSOR_APPLY3_CONTIG(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, CODE) \ -{ \ - TYPE1 *TENSOR1##_data = THTensor_(data)(TENSOR1); \ - TYPE2 *TENSOR2##_data = THTensor_(data)(TENSOR2); \ - TYPE3 *TENSOR3##_data = THTensor_(data)(TENSOR3); \ - ptrdiff_t TENSOR1##_len = THTensor_(nElement)(TENSOR1); \ - CODE \ -} -#endif - -#define TH_CHECK_SAME_SIZE(TENSOR1, TENSOR2) \ -{ \ - if(!THTensor_(isSameSizeAs)(TENSOR1, TENSOR2)) { \ - AT_ERROR("inconsistent tensor size, expected ", #TENSOR1, " ", TENSOR1->sizes(), " and ", #TENSOR2, " ", TENSOR2->sizes(), " to have the same size"); \ - } \ -} - -// Used for `scatter` and `scatterAdd` -// Assumes TENSOR1 is real -// TENSOR2 is src -// TENSOR3 is index -// Tests: -// 1. index->size(d) <= src->size(d) for all d -// 2. 
index->size(d) <= real->size(d) for all d != dim -#define TH_TENSOR_DIM_APPLY3_SIZE_SCATTER(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \ -{ \ - int shape_check_flag = 0; \ - for(TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->_dim(); TH_TENSOR_DIM_APPLY_i++) \ - { \ - int64_t TENSOR3##_dim_size = TENSOR3->size(TH_TENSOR_DIM_APPLY_i); \ - if (TH_TENSOR_DIM_APPLY_i != DIMENSION) { \ - if (TENSOR3##_dim_size > TENSOR1->size(TH_TENSOR_DIM_APPLY_i)) { \ - shape_check_flag = 1; \ - break; \ - } \ - } \ - if (TENSOR3##_dim_size > TENSOR2->size(TH_TENSOR_DIM_APPLY_i)) { \ - shape_check_flag = 1; \ - break; \ - } \ - } \ - if (shape_check_flag == 1) { \ - AT_ERROR("Expected ", #TENSOR3, " ", TENSOR3->sizes(), " to be smaller size than ", #TENSOR2, " ", TENSOR2->sizes(), " and to be smaller than ", #TENSOR1, " ", TENSOR1->sizes(), " apart from dimension ", DIMENSION); \ - } \ -} - -static inline real THTensor_(powOne)(real x, real y) { -#if defined(TH_REAL_IS_FLOAT) - return powf(x, y); -#elif defined(TH_REAL_IS_DOUBLE) - return pow(x, y); -#else - THArgCheck(y >= 0, 1, - "Integers to negative integer powers are not allowed"); - real result = 1; - while (y) { - if (y & 1) { - result *= x; + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = *t_data | value;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); } - y /= 2; - x *= x; - } - return result; +#else + serial_path = 1; #endif -} - -void THTensor_(fill)(THTensor *r_, real value) -{ - if (THTensor_(isContiguous)(r_) || THTensor_(isTransposed)(r_)) { - TH_TENSOR_APPLY_CONTIG(real, r_, THVector_(fill)(r__data, value, r__len);); - } else { - TH_TENSOR_APPLY(real, r_, - if (r__stride == 1) { - THVector_(fill)(r__data, value, r__size); - r__i = r__size; - r__data += r__stride * r__size; - break; - } else { - *r__data = value; - } - ); } -} - -void THTensor_(zero)(THTensor *r_) -{ - THTensor_(fill)(r_, 0); -} - -void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, real value) -{ - TH_TENSOR_APPLY2(real, tensor, unsigned char, mask, - if (*mask_data > 1) - { - THFree(mask_counter); - THFree(tensor_counter); - THError("Mask tensor can take 0 and 1 values only"); - } - else if (*mask_data == 1) - { - *tensor_data = value; - }); -} - -void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src ) -{ - THTensor *srct = THTensor_(newContiguous)(src); - real *src_data = THTensor_(data)(srct); - ptrdiff_t cntr = 0; - ptrdiff_t nelem = THTensor_(nElement)(srct); - if (THTensor_(nElement)(tensor) != THByteTensor_nElement(mask)) - { - THTensor_(free)(srct); - THError("Number of elements of destination tensor != Number of elements in mask"); + if (serial_path) { + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data | value;); } - TH_TENSOR_APPLY2(real, tensor, unsigned char, mask, - if (*mask_data > 1) - { - THTensor_(free)(srct); - THFree(mask_counter); - THFree(tensor_counter); - THError("Mask tensor can take 0 and 1 values only"); - } - else if (*mask_data == 1) - { - if (cntr == nelem) - { - THTensor_(free)(srct); - THFree(mask_counter); - THFree(tensor_counter); - THError("Number of elements of src < number of ones in mask"); - } - *tensor_data = *src_data; - src_data++; - cntr++; - }); - THTensor_(free)(srct); -} - -void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask) -{ - ptrdiff_t numel = THByteTensor_sumall(mask); - real *tensor_data; - -#ifdef DEBUG - THAssert(numel <= LONG_MAX); -#endif - 
THTensor_(resize1d)(tensor,numel); - tensor_data = THTensor_(data)(tensor); - TH_TENSOR_APPLY2(real, src, unsigned char, mask, - if (*mask_data > 1) - { - THFree(mask_counter); - THFree(src_counter); - THError("Mask tensor can take 0 and 1 values only"); - } - else if (*mask_data == 1) - { - *tensor_data = *src_data; - tensor_data++; - }); -} - -// Finds non-zero elements of a tensor and returns their subscripts -void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) -{ - ptrdiff_t numel = 0; - int64_t *subscript_data; - int64_t i = 0; - int64_t dim; - int64_t div = 1; -#ifdef TH_REAL_IS_HALF -#define IS_NONZERO(val) ((val.x & 0x7fff) != 0) -#else -#define IS_NONZERO(val) ((val)!=0) -#endif - - /* First Pass to determine size of subscripts */ - TH_TENSOR_APPLY(real, tensor, - if IS_NONZERO(*tensor_data) { - ++numel; - }); -#ifdef DEBUG - THAssert(numel <= LONG_MAX); #endif - THLongTensor_resize2d(subscript, numel, tensor->dim()); - - /* Second pass populates subscripts */ - subscript_data = THLongTensor_data(subscript); - TH_TENSOR_APPLY(real, tensor, - if IS_NONZERO(*tensor_data) { - div = 1; - - for (dim = tensor->dim() - 1; dim >= 0; dim--) { - *(subscript_data + dim) = (i/div) % tensor->size(dim); - div *= tensor->size(dim); - } - - subscript_data += tensor->dim(); - } - ++i;); } -void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index) +void THTensor_(bitxor)(THTensor *r_, THTensor *t, real value) { - ptrdiff_t i, numel; - THLongStorage *newSize; - THTensor *tSlice, *sSlice; - int64_t *index_data; - real *tensor_data, *src_data; - -#ifndef USE_TH_SIZE_ZERO_DIM - THArgCheck(index->_dim() <= 1, 3, "Index is supposed to be an empty tensor or a vector"); - THArgCheck(dim < src->_dim(), 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); - THArgCheck(src->_dim() > 0, 2, "Source tensor is empty"); +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) + (void)r_; + (void)t; + (void)value; + return THError("bitxor is only supported for integer type tensors"); #else - THArgCheck(index->dim() == 1, 3, "Index is supposed to be 1-dimensional"); - THArgCheck(dim < src->dim(), 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); - //THArgCheck(src->dim() > 0, 2, "Source tensor is empty"); -#endif - - numel = THLongTensor_nElement(index); - - newSize = THLongStorage_newWithSize(src->dim()); - THLongStorage_rawCopy(newSize, THTensor_getSizePtr(src)); -#ifdef DEBUG - THAssert(numel <= LONG_MAX); -#endif - THLongStorage_data(newSize)[dim] = numel; - THTensor_(resize)(tensor,newSize,NULL); - THLongStorage_free(newSize); - - index = THLongTensor_newContiguous(index); - index_data = THLongTensor_data(index); - - if (dim == 0 && THTensor_(isContiguous)(src) && THTensor_(isContiguous)(tensor)) - { - tensor_data = THTensor_(data)(tensor); - src_data = THTensor_(data)(src); - ptrdiff_t rowsize = src->size(0) == 0 ? 
1: THTensor_(nElement)(src) / src->size(0); - - // check that the indices are within range - int64_t max = src->size(0) - 1 + TH_INDEX_BASE; - for (i=0; i max) { - THLongTensor_free(index); - THError("index out of range"); - } + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int serial_path = 0; + if (r_Contig && tContig) { + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + int64_t i; + #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) + for (i=0; idim() == 1) { - #pragma omp parallel for if(numel > TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; idim() == 1) - { - for (i=0; idim() > 1 ) - { - tSlice = THTensor_(new)(); - sSlice = THTensor_(new)(); - - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i max_value ? max_value : tp[i]); + } else { +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = (*t_data < min_value) ? min_value : (*t_data > max_value ? max_value : *t_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); } +#else + serial_path = 1; +#endif } - THLongTensor_free(index); -} - -static ptrdiff_t THTensor_(dataOffset)(THTensor* tensor, ptrdiff_t linearIndex) { - auto size = tensor->sizes(); - auto stride = tensor->strides(); - int nDim = tensor->_dim(); - ptrdiff_t dataOffset = 0; - for (int i = nDim - 1; i >= 0; i--) { - dataOffset += (linearIndex % size[i]) * stride[i]; - linearIndex /= size[i]; + if (serial_path) { + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (*t_data < min_value) ? min_value : (*t_data > max_value ? max_value : *t_data);); } - return dataOffset; -} - -static inline void THTensor_(checkLinearIndex)(int64_t linearIndex, int64_t numel) { - THArgCheck(linearIndex < numel && linearIndex >= -numel, 2, "out of range: %d out of %d", (int)linearIndex, (int)numel); -} - -static inline int64_t THTensor_(wrapLinearIndex)(int64_t linearIndex, int64_t numel) { - return linearIndex < 0 ? linearIndex + numel : linearIndex; } -void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index) +void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src) { - THTensor_(resizeNd)(r_, index->dim(), THTensor_getSizePtr(index), NULL); - THTensor* dst = THTensor_(newContiguous)(r_); - - index = THLongTensor_newContiguous(index); - int64_t* index_data = THLongTensor_data(index); - ptrdiff_t srcElements = THTensor_(nElement)(src); - real* src_data = THTensor_(data)(src); - real* dst_data = THTensor_(data)(dst); - ptrdiff_t nIndices = THLongTensor_nElement(index); - int isContiguous = THTensor_(isContiguous)(src); - - // Exceptions must not be thrown across OpenMP parallel sections, so we - // record the position of the invalid index and throw the exception after the - // loop. 
- std::atomic invalidIdxPos(-1); - - ptrdiff_t i; - #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i = 0; i < nIndices; i++) { - int64_t idx = index_data[i]; - if (idx < srcElements && idx >= -srcElements) { - idx = THTensor_(wrapLinearIndex)(idx, srcElements); - if (isContiguous) { - dst_data[i] = src_data[idx]; + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int64_t srcSize = THTensor_(nElement)(src); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int srcContig = THTensor_(isContiguous)(src); + int serial_path = 0; + if (srcSize == r_Size){ + if (r_Contig && tContig && srcContig) { + if(r_ == t) { + THBlas_(axpy)(THTensor_(nElement)(t), value, THTensor_(data)(src), 1, THTensor_(data)(r_), 1); } else { - dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)]; + TH_TENSOR_APPLY3_CONTIG(real, r_, real, t, real, src, THVector_(cadd)(r__data, t_data, src_data, value, r__len);); } } else { - int64_t tmp = -1; - invalidIdxPos.compare_exchange_strong(tmp, i); +#if _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data + value * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } +#else + serial_path = 1; +#endif } + } else { + serial_path = 1; } - - if (invalidIdxPos >= 0) { - THTensor_(checkLinearIndex)(index_data[invalidIdxPos], srcElements); + if (serial_path) { + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data + value * *src_data;); } - - THLongTensor_free(index); - THTensor_(freeCopyTo)(dst, r_); } -void THTensor_(put)(THTensor *tensor, THLongTensor *index, THTensor *src, int accumulate) +void THTensor_(csub)(THTensor *r_, THTensor *t, real value, THTensor *src) { - THArgCheck(THLongTensor_nElement(index) == THTensor_(nElement)(src), 3, - "src should have the same number of elements as index"); - - index = THLongTensor_newContiguous(index); - src = THTensor_(newContiguous)(src); - real* data = THTensor_(data)(tensor); - ptrdiff_t numel = THTensor_(nElement)(tensor); - int is_contiguous = THTensor_(isContiguous)(tensor); - - TH_TENSOR_APPLY2(int64_t, index, real, src, - THTensor_(checkLinearIndex)(*index_data, numel); - int64_t linearIndex = THTensor_(wrapLinearIndex)(*index_data, numel); - int64_t dataOffset = is_contiguous ? 
linearIndex : THTensor_(dataOffset)(tensor, linearIndex); - if (accumulate) { - data[dataOffset] += *src_data; - } else { - data[dataOffset] = *src_data; - } - ); - - THTensor_(free)(src); - THLongTensor_free(index); + THTensor_(cadd)(r_, t, -value, src); } -void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src) +void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src) { - ptrdiff_t i, numel; - THTensor *tSlice, *sSlice; - int64_t *index_data; - - numel = THLongTensor_nElement(index); -#ifndef USE_TH_SIZE_ZERO_DIM - THArgCheck(index->_dim() == 1, 3, "Index is supposed to be a vector"); - THArgCheck(dim < src->_dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int64_t srcSize = THTensor_(nElement)(src); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int srcContig = THTensor_(isContiguous)(src); + int serial_path = 0; + if (srcSize == r_Size){ + if (r_Contig && tContig && srcContig) { + TH_TENSOR_APPLY3_CONTIG(real, r_, real, t, real, src, THVector_(cmul)(r__data, t_data, src_data, r__len);); + } else { +#if _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } #else - THArgCheck(index->dim() == 1, 3, "Index is supposed to be a vector"); - THArgCheck(dim < src->dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); + serial_path = 1; #endif - THArgCheck(numel == src->size(dim),4,"Number of indices should be equal to source:size(dim)"); - - index = THLongTensor_newContiguous(index); - index_data = THLongTensor_data(index); - - if (tensor->dim() > 1) - { - tSlice = THTensor_(new)(); - sSlice = THTensor_(new)(); - - for (i=0; i_dim() == 1, 3, "Index is supposed to be a vector"); - THArgCheck(dim < tensor->_dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); -#else - THArgCheck(index->dim() == 1, 3, "Index is supposed to be a vector"); - THArgCheck(dim < tensor->dim(), 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE); -#endif - - index = THLongTensor_newContiguous(index); - index_data = THLongTensor_data(index); - - for (i=0; idim() > 1) - { - tSlice = THTensor_(new)(); - THTensor_(select)(tSlice, tensor,dim,index_data[i] - TH_INDEX_BASE); - THTensor_(fill)(tSlice, val); - THTensor_(free)(tSlice); - } - else - { - THTensor_(set1d)(tensor, index_data[i] - TH_INDEX_BASE, val); - } + THTensor_(resizeAs)(r_, t); + if(value == 1){ + THTensor_(copy)(r_, t); + } + else if(value == 2){ + THTensor_(cmul)(r_, t, t); + } + else if(value == 3){ + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data * *t_data * *t_data;); } - THLongTensor_free(index); -} - -void THTensor_(gather)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index) -{ - int64_t elems_per_row, i, idx; - - THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(src), 4, - "Index tensor must have same dimensions as input tensor"); - THArgCheck(dim >= 0 && dim < THTensor_(nDimension)(tensor), 3, - "Index dimension is out of bounds"); - THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 2, - "Input tensor must have same dimensions as output tensor"); - - elems_per_row = THLongTensor_size(index, dim); - - TH_TENSOR_DIM_APPLY3(real, tensor, real, src, int64_t, 
index, dim, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - for (i = 0; i < elems_per_row; ++i) - { - idx = *(index_data + i*index_stride); - if (idx < TH_INDEX_BASE || idx >= src_size + TH_INDEX_BASE) - { - THFree(TH_TENSOR_DIM_APPLY_counter); - THError("Invalid index in gather"); - } - *(tensor_data + i*tensor_stride) = src_data[(idx - TH_INDEX_BASE) * src_stride]; - }) -} - -void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src) -{ - int64_t elems_per_row, i, idx; - -#ifndef USE_TH_SIZE_ZERO_DIM - THArgCheck(dim < THTensor_(_nDimension)(tensor), 2, "Index dimension is out of bounds"); - THArgCheck(THLongTensor__nDimension(index) == THTensor_(_nDimension)(tensor), 3, - "Index tensor must have same dimensions as output tensor"); - THArgCheck(THTensor_(_nDimension)(src) == THTensor_(_nDimension)(tensor), 4, - "Input tensor must have same dimensions as output tensor"); -#else - THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds"); - THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(tensor), 3, - "Index tensor must have same dimensions as output tensor"); - THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4, - "Input tensor must have same dimensions as output tensor"); -#endif - - elems_per_row = THLongTensor_size(index, dim); - - TH_TENSOR_DIM_APPLY3(real, tensor, real, src, int64_t, index, dim, - TH_TENSOR_DIM_APPLY3_SIZE_SCATTER, - for (i = 0; i < elems_per_row; ++i) - { - idx = *(index_data + i*index_stride); - if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE) - { - THFree(TH_TENSOR_DIM_APPLY_counter); - THError("Invalid index in scatter"); - } - tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] = *(src_data + i*src_stride); - }) -} - -void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src) -{ - int64_t elems_per_row, i, idx; - - THArgCheck(dim < THTensor_(nDimension)(tensor), 2, "Index dimension is out of bounds"); - THArgCheck(THLongTensor_nDimension(index) == THTensor_(nDimension)(tensor), 3, - "Index tensor must have same dimensions as output tensor"); - THArgCheck(THTensor_(nDimension)(src) == THTensor_(nDimension)(tensor), 4, - "Input tensor must have same dimensions as output tensor"); - - elems_per_row = THLongTensor_size(index, dim); - - TH_TENSOR_DIM_APPLY3(real, tensor, real, src, int64_t, index, dim, - TH_TENSOR_DIM_APPLY3_SIZE_SCATTER, - for (i = 0; i < elems_per_row; ++i) - { - idx = *(index_data + i*index_stride); - if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE) - { - THFree(TH_TENSOR_DIM_APPLY_counter); - THError("Invalid index in scatterAdd"); - } - tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] += *(src_data + i*src_stride); - }) -} - -void THTensor_(scatterFill)(THTensor *tensor, int dim, THLongTensor *index, real val) -{ - int64_t elems_per_row, i, idx; - - THArgCheck(dim < THTensor_(_nDimension)(tensor), 2, "Index dimension is out of bounds"); - THArgCheck(THLongTensor__nDimension(index) == THTensor_(_nDimension)(tensor), 3, - "Index tensor must have same dimensions as output tensor"); - - elems_per_row = THLongTensor_size(index, dim); - - TH_TENSOR_DIM_APPLY2(real, tensor, int64_t, index, dim, - for (i = 0; i < elems_per_row; ++i) - { - idx = *(index_data + i*index_stride); - if (idx < TH_INDEX_BASE || idx >= tensor_size + TH_INDEX_BASE) - { - THFree(TH_TENSOR_DIM_APPLY_counter); - THError("Invalid index in scatter"); - } - tensor_data[(idx - TH_INDEX_BASE) * tensor_stride] = val; - }) -} - 
-accreal THTensor_(dot)(THTensor *tensor, THTensor *src) -{ - accreal sum = 0; - /* we use a trick here. careful with that. */ - TH_TENSOR_APPLY2(real, tensor, real, src, - int64_t sz = (tensor_size-tensor_i < src_size-src_i ? tensor_size-tensor_i : src_size-src_i); - sum += THBlas_(dot)(sz, src_data, src_stride, tensor_data, tensor_stride); - tensor_i += sz; - src_i += sz; - tensor_data += sz*tensor_stride; - src_data += sz*src_stride; - break;); - return sum; -} - - -#undef th_isnan -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) -#define th_isnan(val) \ -(std::isnan(val)) -#else -#define th_isnan(val) (0) -#endif - -#undef th_isnan_break #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) -#define th_isnan_break(val) \ -if (std::isnan(val)) break; +#if defined (TH_REAL_IS_FLOAT) +#define TH_MATH_NAME(fn) fn##f #else -#define th_isnan_break(val) +#define TH_MATH_NAME(fn) fn #endif - -real THTensor_(minall)(THTensor *tensor) -{ - real theMin; - real value; - - THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension"); - theMin = THTensor_(data)(tensor)[0]; - TH_TENSOR_APPLY(real, tensor, - value = *tensor_data; - /* This is not the same as value= theMin)) - { - theMin = value; - th_isnan_break(value) - }); - return theMin; -} - -real THTensor_(maxall)(THTensor *tensor) -{ - real theMax; - real value; - - THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension"); - theMax = THTensor_(data)(tensor)[0]; - TH_TENSOR_APPLY(real, tensor, - value = *tensor_data; - /* This is not the same as value>theMax in the case of NaNs */ - if(!(value <= theMax)) - { - theMax = value; - th_isnan_break(value) - }); - return theMax; -} - -static void THTensor_(quickselectnoidx)(real *arr, int64_t k, int64_t elements, int64_t stride); - -real THTensor_(medianall)(THTensor *tensor) -{ - THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension"); - - real theMedian; - ptrdiff_t numel; - int64_t k; - THTensor *temp_; - real *temp__data; - - numel = THTensor_(nElement)(tensor); - k = (numel-1) >> 1; - - temp_ = THTensor_(newClone)(tensor); - temp__data = THTensor_(data)(temp_); - - THTensor_(quickselectnoidx)(temp__data, k, numel, 1); - - theMedian = temp__data[k]; - - THTensor_(free)(temp_); - - return theMedian; -} - -accreal THTensor_(sumall)(THTensor *tensor) -{ - accreal sum = 0; - int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if(inOMP) { - serial_path = 1; - } else { - TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, +:sum, sum += *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + else if(value == 0.5){ + THTensor_(sqrt)(r_, t); } -#else - serial_path = 1; -#endif - if (serial_path) { - TH_TENSOR_APPLY(real, tensor, sum += *tensor_data;); + else if(value == -0.5){ + THTensor_(rsqrt)(r_, t); } - return sum; -} - -accreal THTensor_(prodall)(THTensor *tensor) -{ - accreal prod = 1; - int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if(inOMP) { - serial_path = 1; - } else { - TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, *:prod, prod *= *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + else if(value == -1){ + THTensor_(cinv)(r_, t); + } + else if(value == -2){ + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = TH_MATH_NAME(1.0) / (*t_data * *t_data);); + } + else{ + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = TH_MATH_NAME(pow)(*t_data, value);); } +#undef TH_MATH_NAME #else - serial_path = 1; -#endif - if (serial_path) { - TH_TENSOR_APPLY(real, tensor, prod *= *tensor_data;); + else { + TH_TENSOR_APPLY2(real, r_, 
real, t, *r__data = THTensor_(powOne)(*t_data, value);); } - return prod; +#endif } -void THTensor_(add)(THTensor *r_, THTensor *t, real value) +void THTensor_(cpow)(THTensor *r_, THTensor *t, THTensor *src) { THTensor_(resizeAs)(r_, t); int64_t r_Size = THTensor_(nElement)(r_); + int64_t srcSize = THTensor_(nElement)(src); int r_Contig = THTensor_(isContiguous)(r_); int tContig = THTensor_(isContiguous)(t); + int srcContig = THTensor_(isContiguous)(src); int serial_path = 0; - if (r_Contig && tContig) { - TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(adds)(r__data, t_data, value, r__len);); - } else { -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; + if (srcSize == r_Size){ + if (r_Contig && tContig && srcContig) { + real *tp = THTensor_(data)(t); + real *sp = THTensor_(data)(src); + real *rp = THTensor_(data)(r_); + int64_t i; + #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i> sp[i]; #else - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (((ureal) *t_data) << value);); -#endif - } + rp[i] = ((ureal) tp[i]) >> sp[i]; #endif -} - -void THTensor_(rshift)(THTensor *r_, THTensor *t, real value) -{ + } + } else { +#if _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { #if defined(TH_REAL_IS_FLOAT) - return THTensor_(div)(r_, t, powf(2, value)); + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data / powf(2, *src_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); #elif defined(TH_REAL_IS_DOUBLE) - return THTensor_(div)(r_, t, pow(2, value)); -#elif defined(TH_REAL_IS_HALF) - return THError("rshift is not supported for torch.HalfTensor"); -#else - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - int serial_path = 0; - if (r_Contig && tContig) { - real *tp = THTensor_(data)(t); - real *rp = THTensor_(data)(r_); - int64_t i; - #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) - for (i=0; i> value; + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data / pow(2, *src_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); +#elif defined(TH_REAL_IS_BYTE) + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = ((real)*t_data) >> *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); #else - rp[i] = ((ureal) tp[i]) >> value; + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = ((ureal)*t_data) >> *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); #endif - } - } else { -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { -#if defined(TH_REAL_IS_BYTE) - TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = (((real) *t_data) >> value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } #else - TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = (((ureal) *t_data) >> value);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + serial_path = 1; #endif } -#else + } else { serial_path = 1; -#endif } if (serial_path) { -#if defined(TH_REAL_IS_BYTE) - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (((real) *t_data) >> value);); +#if 
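// --- Illustrative sketch (editor's addition, not part of this patch) ---------
// THTensor_(pow) above short-circuits a handful of common exponents before falling
// back to pow()/powf() (or to powOne for integer types). The helper below (invented
// name, double only) summarizes that dispatch for a single value.
#include <cmath>

static double pow_special_cased(double x, double value) {
  if (value == 1)    return x;             // copy
  if (value == 2)    return x * x;         // cmul(t, t)
  if (value == 3)    return x * x * x;
  if (value == 0.5)  return std::sqrt(x);  // floating types only
  if (value == -0.5) return 1.0 / std::sqrt(x);
  if (value == -1)   return 1.0 / x;       // cinv
  if (value == -2)   return 1.0 / (x * x);
  return std::pow(x, value);               // general case
}
// ---------------------------------------------------------------------------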
defined(TH_REAL_IS_FLOAT) + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data / powf(2, *src_data);); +#elif defined(TH_REAL_IS_DOUBLE) + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data / pow(2, *src_data);); +#elif defined(TH_REAL_IS_BYTE) + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = ((real)*t_data) >> *src_data;); #else - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (((ureal) *t_data) >> value);); + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = ((ureal)*t_data) >> *src_data;); #endif } -#endif } -void THTensor_(fmod)(THTensor *r_, THTensor *t, real value) +void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src) { THTensor_(resizeAs)(r_, t); int64_t r_Size = THTensor_(nElement)(r_); + int64_t srcSize = THTensor_(nElement)(src); int r_Contig = THTensor_(isContiguous)(r_); int tContig = THTensor_(isContiguous)(t); + int srcContig = THTensor_(isContiguous)(src); int serial_path = 0; - if (r_Contig && tContig) { - real *tp = THTensor_(data)(t); - real *rp = THTensor_(data)(r_); - int64_t i; - #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD * 100) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i max_value ? max_value : tp[i]); - } else { -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { - TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = (*t_data < min_value) ? min_value : (*t_data > max_value ? max_value : *t_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } -#else - serial_path = 1; -#endif - } - if (serial_path) { - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = (*t_data < min_value) ? min_value : (*t_data > max_value ? 
max_value : *t_data);); - } -} - -void THTensor_(cadd)(THTensor *r_, THTensor *t, real value, THTensor *src) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int64_t srcSize = THTensor_(nElement)(src); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - int srcContig = THTensor_(isContiguous)(src); - int serial_path = 0; - if (srcSize == r_Size){ - if (r_Contig && tContig && srcContig) { - if(r_ == t) { - THBlas_(axpy)(THTensor_(nElement)(t), value, THTensor_(data)(src), 1, THTensor_(data)(r_), 1); - } else { - TH_TENSOR_APPLY3_CONTIG(real, r_, real, t, real, src, THVector_(cadd)(r__data, t_data, src_data, value, r__len);); - } - } else { -#if _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { - TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data + value * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } -#else - serial_path = 1; -#endif - } - } else { - serial_path = 1; - } - if (serial_path) { - TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data + value * *src_data;); - } -} - -void THTensor_(csub)(THTensor *r_, THTensor *t, real value, THTensor *src) -{ - THTensor_(cadd)(r_, t, -value, src); -} - -void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int64_t srcSize = THTensor_(nElement)(src); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - int srcContig = THTensor_(isContiguous)(src); - int serial_path = 0; - if (srcSize == r_Size){ - if (r_Contig && tContig && srcContig) { - TH_TENSOR_APPLY3_CONTIG(real, r_, real, t, real, src, THVector_(cmul)(r__data, t_data, src_data, r__len);); - } else { -#if _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { - TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data * *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } -#else - serial_path = 1; -#endif - } - } else { - serial_path = 1; - } - if (serial_path) { - TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data * *src_data;); - } -} - -void THTensor_(pow)(THTensor *r_, THTensor *t, real value) -{ - THTensor_(resizeAs)(r_, t); - if(value == 1){ - THTensor_(copy)(r_, t); - } - else if(value == 2){ - THTensor_(cmul)(r_, t, t); - } - else if(value == 3){ - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = *t_data * *t_data * *t_data;); - } -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) -#if defined (TH_REAL_IS_FLOAT) -#define TH_MATH_NAME(fn) fn##f -#else -#define TH_MATH_NAME(fn) fn -#endif - else if(value == 0.5){ - THTensor_(sqrt)(r_, t); - } - else if(value == -0.5){ - THTensor_(rsqrt)(r_, t); - } - else if(value == -1){ - THTensor_(cinv)(r_, t); - } - else if(value == -2){ - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = TH_MATH_NAME(1.0) / (*t_data * *t_data);); - } - else{ - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = TH_MATH_NAME(pow)(*t_data, value);); - } -#undef TH_MATH_NAME -#else - else { - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = THTensor_(powOne)(*t_data, value);); - } -#endif -} - -void THTensor_(cpow)(THTensor *r_, THTensor *t, THTensor *src) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int64_t srcSize = THTensor_(nElement)(src); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - 
int srcContig = THTensor_(isContiguous)(src); - int serial_path = 0; - if (srcSize == r_Size){ - if (r_Contig && tContig && srcContig) { - real *tp = THTensor_(data)(t); - real *sp = THTensor_(data)(src); - real *rp = THTensor_(data)(r_); - int64_t i; - #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i> sp[i]; -#else - rp[i] = ((ureal) tp[i]) >> sp[i]; -#endif - } - } else { -#if _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { -#if defined(TH_REAL_IS_FLOAT) - TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data / powf(2, *src_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); -#elif defined(TH_REAL_IS_DOUBLE) - TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data / pow(2, *src_data);, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); -#elif defined(TH_REAL_IS_BYTE) - TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = ((real)*t_data) >> *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); -#else - TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = ((ureal)*t_data) >> *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); -#endif - } -#else - serial_path = 1; -#endif - } - } else { - serial_path = 1; - } - if (serial_path) { -#if defined(TH_REAL_IS_FLOAT) - TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data / powf(2, *src_data);); -#elif defined(TH_REAL_IS_DOUBLE) - TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data / pow(2, *src_data);); -#elif defined(TH_REAL_IS_BYTE) - TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = ((real)*t_data) >> *src_data;); -#else - TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = ((ureal)*t_data) >> *src_data;); -#endif - } -} - -void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src) -{ - THTensor_(resizeAs)(r_, t); - int64_t r_Size = THTensor_(nElement)(r_); - int64_t srcSize = THTensor_(nElement)(src); - int r_Contig = THTensor_(isContiguous)(r_); - int tContig = THTensor_(isContiguous)(t); - int srcContig = THTensor_(isContiguous)(src); - int serial_path = 0; - if (srcSize == r_Size){ - if (r_Contig && tContig && srcContig) { - real *tp = THTensor_(data)(t); - real *sp = THTensor_(data)(src); - real *rp = THTensor_(data)(r_); - int64_t i; - #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; i TH_OMP_OVERHEAD_THRESHOLD) private(i) - for (i=0; idim() != 2) || (vec->dim() != 1) ) - THError("matrix and vector expected, got %dD, %dD", - mat->dim(), vec->dim()); - - if( mat->size(1) != vec->size(0) ) { - THDescBuff bm = THTensor_(sizeDesc)(mat); - THDescBuff bv = THTensor_(sizeDesc)(vec); - THError("size mismatch, %s, %s", bm.str, bv.str); - } - - if(t->dim() != 1) - THError("vector expected, got t: %dD", t->dim()); - - if(t->size(0) != mat->size(0)) { - THDescBuff bt = THTensor_(sizeDesc)(t); - THDescBuff bm = THTensor_(sizeDesc)(mat); - THError("size mismatch, t: %s, mat: %s", bt.str, bm.str); - } - - if(r_ != t) - { - THTensor_(resizeAs)(r_, t); - 
THTensor_(copy)(r_, t); - } - - // n == 1 || lda >= max(1, m) - #define LDA_COND(M, N, LDA) ((N) == 1 || (LDA) >= THMax(1, (M))) - - if(mat->stride(0) == 1 && LDA_COND(mat->size(0), mat->size(1), mat->stride(1))) - { - THBlas_(gemv)('n', mat->size(0), mat->size(1), - alpha, THTensor_(data)(mat), mat->stride(1), - THTensor_(data)(vec), vec->stride(0), - beta, THTensor_(data)(r_), r_->stride(0)); - } - else if(mat->stride(1) == 1 && LDA_COND(mat->size(1), mat->size(0), mat->stride(0))) - { - THBlas_(gemv)('t', mat->size(1), mat->size(0), - alpha, THTensor_(data)(mat), mat->stride(0), - THTensor_(data)(vec), vec->stride(0), - beta, THTensor_(data)(r_), r_->stride(0)); - } - else - { - THTensor *cmat = THTensor_(newContiguous)(mat); - - THBlas_(gemv)('t', mat->size(1), mat->size(0), - alpha, THTensor_(data)(cmat), cmat->stride(0), - THTensor_(data)(vec), vec->stride(0), - beta, THTensor_(data)(r_), r_->stride(0)); - - THTensor_(free)(cmat); - } - - #undef LDA_COND -} - -void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain) -{ - int64_t N1 = m1->size(0); - int64_t N2 = m2->size(0); - int64_t dim; - real *m1_p; - real *m2_p; - real *r_p; - int64_t i; - - THTensor_(resize2d)(r_, N1, N2); - - m1 = THTensor_(newContiguous)(m1); - m2 = THTensor_(newContiguous)(m2); - - THTensor_(resize2d)(m1, N1, THTensor_(nElement)(m1) / N1); - THTensor_(resize2d)(m2, N2, THTensor_(nElement)(m2) / N2); - - dim = m1->size(1); - THArgCheck(m1->size(1) == m2->size(1), 3, "m1 and m2 must have the same inner vector dim"); - - m1_p = THTensor_(data)(m1); - m2_p = THTensor_(data)(m2); - r_p = THTensor_(data)(r_); - -#pragma omp parallel for private(i) - for (i=0; idim() != 2) || (m2->dim() != 2)) - THError("matrices expected, got %dD, %dD tensors", m1->dim(), m2->dim()); - - if(m1->size(1) != m2->size(0)) { - THDescBuff bm1 = THTensor_(sizeDesc)(m1); - THDescBuff bm2 = THTensor_(sizeDesc)(m2); - THError("size mismatch, m1: %s, m2: %s", bm1.str, bm2.str); - } - - if( t->dim() != 2 ) - THError("matrix expected, got %dD tensor for t", t->dim()); - - if( (t->size(0) != m1->size(0)) || (t->size(1) != m2->size(1)) ) { - THDescBuff bt = THTensor_(sizeDesc)(t); - THDescBuff bm1 = THTensor_(sizeDesc)(m1); - THDescBuff bm2 = THTensor_(sizeDesc)(m2); - THError("size mismatch, t: %s, m1: %s, m2: %s", bt.str, bm1.str, bm2.str); - } - - if(t != r_) - { - THTensor_(resizeAs)(r_, t); - if (beta != 0.0) { - THTensor_(copy)(r_, t); - } - } - - // n == 1 || ldc >= max(1, m) - #define LDC_COND(M, N, LDC) ((N) == 1 || (LDC) >= THMax(1, M)) - - /* r_ */ - if(r_->stride(0) == 1 && - LDC_COND(r_->size(0), r_->size(1), r_->stride(1))) - { - transpose_r = 'n'; - r__ = r_; - } - else if(r_->stride(1) == 1 && - LDC_COND(r_->size(1), r_->size(0), r_->stride(0))) - { - THTensor *swap = m2; - m2 = m1; - m1 = swap; - transpose_r = 't'; - r__ = r_; - } - else - { - transpose_r = 'n'; - // make r__ FORTRAN contiguous - THTensor *transp_r_ = THTensor_(newTranspose)(r_, 0, 1); - r__ = THTensor_(newClone)(transp_r_); - THTensor_(free)(transp_r_); - THTensor_(transpose)(r__, NULL, 0, 1); - } - - #undef LDC_COND - - int64_t m = r__->size((transpose_r == 'n' ? 0 : 1)); - int64_t n = r__->size((transpose_r == 'n' ? 1 : 0)); - int64_t k = m1->size((transpose_r == 'n' ? 1 : 0)); - int64_t ldr__ = r__->stride((transpose_r == 'n' ? 1 : 0)); - - /* m1 */ - /* Need ldm1_ >= max(1, (transpose_m1 == 'n' ? m : k)) */ - if(m1->stride((transpose_r == 'n' ? 0 : 1)) == 1 && - m1->stride((transpose_r == 'n' ? 
1 : 0)) >= THMax(1, m)) - { - transpose_m1 = 'n'; - m1_ = m1; - } - else if(m1->stride((transpose_r == 'n' ? 1 : 0)) == 1 && - m1->stride((transpose_r == 'n' ? 0 : 1)) >= THMax(1, k)) - { - transpose_m1 = 't'; - m1_ = m1; - } - else - { - transpose_m1 = (transpose_r == 'n' ? 't' : 'n'); - m1_ = THTensor_(newContiguous)(m1); - free_m1 = 1; - } - - /* m2 */ - /* Need ldm2_ >= max(1, (transpose_m2 == 'n' ? k : n)) */ - if(m2->stride((transpose_r == 'n' ? 0 : 1)) == 1 && - m2->stride((transpose_r == 'n' ? 1 : 0)) >= THMax(1, k)) - { - transpose_m2 = 'n'; - m2_ = m2; - } - else if(m2->stride((transpose_r == 'n' ? 1 : 0)) == 1 && - m2->stride((transpose_r == 'n' ? 0 : 1)) >= THMax(1, n)) - { - transpose_m2 = 't'; - m2_ = m2; - } - else - { - transpose_m2 = (transpose_r == 'n' ? 't' : 'n'); - m2_ = THTensor_(newContiguous)(m2); - free_m2 = 1; - } - - int64_t ldm1_ = (transpose_m1 == 'n' ? m1_->stride((transpose_r == 'n' ? 1 : 0)) : m1_->stride((transpose_r == 'n' ? 0 : 1))); - int64_t ldm2_ = (transpose_m2 == 'n' ? m2_->stride((transpose_r == 'n' ? 1 : 0)) : m2_->stride((transpose_r == 'n' ? 0 : 1))); - -#pragma omp critical(blasgemm) - /* do the operation */ - THBlas_(gemm)(transpose_m1, - transpose_m2, - m, - n, - k, - alpha, - THTensor_(data)(m1_), - ldm1_, - THTensor_(data)(m2_), - ldm2_, - beta, - THTensor_(data)(r__), - ldr__); - - /* free intermediate variables */ - if(free_m1) - THTensor_(free)(m1_); - - if(free_m2) - THTensor_(free)(m2_); - - if(r__ != r_) - THTensor_(freeCopyTo)(r__, r_); -} - -void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2) -{ - if( (vec1->dim() != 1) || (vec2->dim() != 1) ) - THError("vector and vector expected, got %dD, %dD tensors", - vec1->dim(), vec2->dim()); - - if(t->dim() != 2) - THError("expected matrix, got %dD tensor for t", t->dim()); - - if( (t->size(0) != vec1->size(0)) || (t->size(1) != vec2->size(0)) ) { - THDescBuff bt = THTensor_(sizeDesc)(t); - THDescBuff bv1 = THTensor_(sizeDesc)(vec1); - THDescBuff bv2 = THTensor_(sizeDesc)(vec2); - THError("size mismatch, t: %s, vec1: %s, vec2: %s", bt.str, bv1.str, bv2.str); - } - - if(r_ != t) - { - THTensor_(resizeAs)(r_, t); - THTensor_(copy)(r_, t); - } - - if(beta == 0) { - THTensor_(zero)(r_); - } - else if(beta != 1) - THTensor_(mul)(r_, r_, beta); - - // n == 1 || lda >= max(1, m) - #define LDA_COND(M, N, LDA) ((N) == 1 || (LDA) >= THMax(1, (M))) - - if(r_->stride(0) == 1 && LDA_COND(vec1->size(0), vec2->size(0), r_->stride(1))) - { - THBlas_(ger)(vec1->size(0), vec2->size(0), - alpha, THTensor_(data)(vec1), vec1->stride(0), - THTensor_(data)(vec2), vec2->stride(0), - THTensor_(data)(r_), r_->stride(1)); - } - else if(r_->stride(1) == 1 && LDA_COND(vec2->size(0), vec1->size(0), r_->stride(0))) - { - THBlas_(ger)(vec2->size(0), vec1->size(0), - alpha, THTensor_(data)(vec2), vec2->stride(0), - THTensor_(data)(vec1), vec1->stride(0), - THTensor_(data)(r_), r_->stride(0)); - } - else - { - THTensor *cr = THTensor_(newClone)(r_); - - THBlas_(ger)(vec2->size(0), vec1->size(0), - alpha, THTensor_(data)(vec2), vec2->stride(0), - THTensor_(data)(vec1), vec1->stride(0), - THTensor_(data)(cr), cr->stride(0)); - - THTensor_(freeCopyTo)(cr, r_); - } - - #undef LDA_COND -} - -void THTensor_(addbmm)(THTensor *result, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2) -{ - int64_t batch; - - THArgCheck(THTensor_(nDimension)(batch1) == 3, 1, "expected 3D tensor"); - THArgCheck(THTensor_(nDimension)(batch2) == 3, 2, "expected 3D tensor"); - 
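The transpose/leading-dimension selection in addmv and addmm above follows one rule: a matrix whose stride(0) is 1 is already column-major and can be handed to BLAS untransposed, one whose stride(1) is 1 is row-major and is passed as its transpose, and anything else is copied contiguous first. A minimal sketch of that decision (standalone C, hypothetical struct and names, mirroring the LDA_COND test) is:

#include <stdint.h>

/* Decide how a strided rows x cols matrix is handed to column-major BLAS.
 * Mirrors LDA_COND: the leading dimension must satisfy n == 1 || lda >= max(1, m). */
struct blas_view { char trans; int64_t lda; int needs_copy; };

static struct blas_view pick_blas_view(int64_t rows, int64_t cols,
                                       int64_t stride0, int64_t stride1)
{
  struct blas_view v = { 'n', 1, 0 };
  if (stride0 == 1 && (cols == 1 || stride1 >= (rows > 1 ? rows : 1))) {
    v.trans = 'n';                 /* already column-major */
    v.lda = stride1;
  } else if (stride1 == 1 && (rows == 1 || stride0 >= (cols > 1 ? cols : 1))) {
    v.trans = 't';                 /* row-major: pass the transpose */
    v.lda = stride0;
  } else {
    v.needs_copy = 1;              /* neither layout works: fall back to newContiguous */
  }
  return v;
}

addmm applies the same kind of test to r_, m1 and m2 before the gemm call, which is where the transpose_r/transpose_m1/transpose_m2 characters above come from.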
THArgCheck(THTensor_(size)(batch1, 0) == THTensor_(size)(batch2, 0), 2, - "equal number of batches expected, got %d, %d", - THTensor_(size)(batch1, 0), THTensor_(size)(batch2, 0)); - THArgCheck(THTensor_(size)(batch1, 2) == THTensor_(size)(batch2, 1), 2, - "wrong matrix size, batch1: %dx%d, batch2: %dx%d", - THTensor_(size)(batch1, 1), THTensor_(size)(batch1,2), - THTensor_(size)(batch2, 1), THTensor_(size)(batch2,2)); - - int64_t dim1 = THTensor_(size)(batch1, 1); - int64_t dim2 = THTensor_(size)(batch2, 2); - THArgCheck(THTensor_(size)(t, 0) == dim1, 1, "output tensor of incorrect size"); - THArgCheck(THTensor_(size)(t, 1) == dim2, 1, "output tensor of incorrect size"); - - if (t != result) { - THTensor_(resizeAs)(result, t); - if (beta != 0.0) { - THTensor_(copy)(result, t); - } - } - - THTensor *matrix1 = THTensor_(new)(); - THTensor *matrix2 = THTensor_(new)(); - - for (batch = 0; batch < THTensor_(size)(batch1, 0); ++batch) { - THTensor_(select)(matrix1, batch1, 0, batch); - THTensor_(select)(matrix2, batch2, 0, batch); - - THTensor_(addmm)(result, beta, result, alpha, matrix1, matrix2); - beta = 1; // accumulate output once - } - - THTensor_(free)(matrix1); - THTensor_(free)(matrix2); -} - -void THTensor_(baddbmm)(THTensor *result, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2) -{ - int64_t batch; - - THArgCheck(THTensor_(nDimension)(batch1) == 3, 1, "expected 3D tensor, got %dD", THTensor_(nDimension)(batch1)); - THArgCheck(THTensor_(nDimension)(batch2) == 3, 2, "expected 3D tensor, got %dD", THTensor_(nDimension)(batch2)); - THArgCheck(THTensor_(size)(batch1, 0) == THTensor_(size)(batch2, 0), 2, - "equal number of batches expected, got %d, %d", - THTensor_(size)(batch1, 0), THTensor_(size)(batch2, 0)); - THArgCheck(THTensor_(size)(batch1, 2) == THTensor_(size)(batch2, 1), 2, - "wrong matrix size, batch1: %dx%d, batch2: %dx%d", - THTensor_(size)(batch1, 1), THTensor_(size)(batch1, 2), - THTensor_(size)(batch2, 1), THTensor_(size)(batch2, 2)); - - int64_t bs = THTensor_(size)(batch1, 0); - int64_t dim1 = THTensor_(size)(batch1, 1); - int64_t dim2 = THTensor_(size)(batch2, 2); - THArgCheck(THTensor_(size)(t, 0) == bs, 1, "output tensor of incorrect size"); - THArgCheck(THTensor_(size)(t, 1) == dim1, 1, "output tensor of incorrect size"); - THArgCheck(THTensor_(size)(t, 2) == dim2, 1, "output tensor of incorrect size"); - - if (t != result) { - THTensor_(resizeAs)(result, t); - if (beta != 0.0) { - THTensor_(copy)(result, t); - } - } - - THTensor *matrix1 = THTensor_(new)(); - THTensor *matrix2 = THTensor_(new)(); - THTensor *result_matrix = THTensor_(new)(); - - for (batch = 0; batch < THTensor_(size)(batch1, 0); ++batch) { - THTensor_(select)(matrix1, batch1, 0, batch); - THTensor_(select)(matrix2, batch2, 0, batch); - THTensor_(select)(result_matrix, result, 0, batch); - - THTensor_(addmm)(result_matrix, beta, result_matrix, alpha, matrix1, matrix2); - } - - THTensor_(free)(matrix1); - THTensor_(free)(matrix2); - THTensor_(free)(result_matrix); -} - -ptrdiff_t THTensor_(numel)(THTensor *t) -{ - return THTensor_(nElement)(t); -} - - -// Helper function to be used in a reduction operation. -// Due to resize semantics of outputs, if the specified output tensor r_ has -// same size as the output of the reduction operation, then any noncontiguities -// in r_ should be preserved. 
-// The reduction operation, however, needs to act on r_ with an extra dimension -// (the reduced dimension), so this function "resizes" r_ and preserves its -// noncontiguities if necessary. -void THTensor_(preserveReduceDimSemantics)( - THTensor *r_, int in_dims, int reduce_dimension, int keepdim) { - if (r_ && !keepdim && - THTensor_(_nDimension)(r_) == in_dims - 1 && - THTensor_(_nDimension)(r_) != 0) { - THTensor_(unsqueeze1d)(r_, r_, reduce_dimension); - } -} - -void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) -{ - THLongStorage *dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - int in_dims = THTensor_(_nDimension)(t); - THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); - THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(values_, dim, NULL); - THLongTensor_resize(indices_, dim, NULL); - THLongStorage_free(dim); - - // two implementations optimized for data locality - if (t->stride(dimension) == 1) { - real theMax; - real value; - int64_t theIndex; - int64_t i; - TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - theMax = t_data[0]; - theIndex = 0; - - for(i = 0; i < t_size; i++) - { - value = t_data[i*t_stride]; - /* This is not the same as value>theMax in the case of NaNs */ - if(!(value <= theMax)) - { - theIndex = i; - theMax = value; - th_isnan_break(value) - } - } - *indices__data = theIndex; - *values__data = theMax;); - } else { - if (THTensor_(_nDimension)(t) > 1) { - THTensor *t0 = THTensor_(newSelect)(t, dimension, 0); - THTensor_(copy)(values_, t0); - THTensor_(free)(t0); - } else { - THTensor_(fill)(values_, THTensor_(get1d)(t, 0)); - } - THLongTensor_zero(indices_); - - if(t->size(dimension) == 1) { - if (!keepdim) { - THTensor_(squeeze1d)(values_, values_, dimension); - THLongTensor_squeeze1d(indices_, indices_, dimension); - } - return; - } - - THTensor *tempValues_ = THTensor_(newWithTensor)(values_); - // tempValues_.expand_as(t) - THTensor_setSizeAtDim(tempValues_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(tempValues_, dimension, 0); - - THLongTensor *tempIndices_ = THLongTensor_newWithTensor(indices_); - // tempIndices_.expand_as(t) - THTensor_setSizeAtDim(tempIndices_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(tempIndices_, dimension, 0); - - TH_TENSOR_APPLY3_D(real, t, real, tempValues_, int64_t, tempIndices_, dimension, - if(!(*t_data <= *tempValues__data) && !th_isnan(*tempValues__data)) { - *tempValues__data = *t_data; - *tempIndices__data = *tempIndices__dimOffset; - }); - - THTensor_(free)(tempValues_); - THLongTensor_free(tempIndices_); - } - - if (!keepdim) { - THTensor_(squeeze1d)(values_, values_, dimension); - THLongTensor_squeeze1d(indices_, indices_, dimension); - } -} - -void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) -{ - THLongStorage *dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - int in_dims = THTensor_(_nDimension)(t); - THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); - THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); - dim = 
THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(values_, dim, NULL); - THLongTensor_resize(indices_, dim, NULL); - THLongStorage_free(dim); - - // two implementations optimized for data locality - if (t->stride(dimension) == 1) { - real theMax; - real value; - int64_t theIndex; - int64_t i; - TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - theMax = t_data[0]; - theIndex = 0; - - for(i = 0; i < t_size; i++) - { - value = t_data[i*t_stride]; - /* This is not the same as value>theMax in the case of NaNs */ - if(!(value >= theMax)) - { - theIndex = i; - theMax = value; - th_isnan_break(value) - } - } - *indices__data = theIndex; - *values__data = theMax;); - } else { - if (THTensor_(_nDimension)(t) > 1) { - THTensor *t0 = THTensor_(newSelect)(t, dimension, 0); - THTensor_(copy)(values_, t0); - THTensor_(free)(t0); - } else { - THTensor_(fill)(values_, THTensor_(get1d)(t, 0)); - } - THLongTensor_zero(indices_); - - if(t->size(dimension) == 1) { - if (!keepdim) { - THTensor_(squeeze1d)(values_, values_, dimension); - THLongTensor_squeeze1d(indices_, indices_, dimension); - } - return; - } - - THTensor *tempValues_ = THTensor_(newWithTensor)(values_); - // tempValues_.expand_as(t) - THTensor_setSizeAtDim(tempValues_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(tempValues_, dimension, 0); - - THLongTensor *tempIndices_ = THLongTensor_newWithTensor(indices_); - // tempIndices_.expand_as(t) - THTensor_setSizeAtDim(tempIndices_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(tempIndices_, dimension, 0); - - TH_TENSOR_APPLY3_D(real, t, real, tempValues_, int64_t, tempIndices_, dimension, - if(!(*t_data >= *tempValues__data) && !th_isnan(*tempValues__data)) { - *tempValues__data = *t_data; - *tempIndices__data = *tempIndices__dimOffset; - }); - - THTensor_(free)(tempValues_); - THLongTensor_free(tempIndices_); - } - - if (!keepdim) { - THTensor_(squeeze1d)(values_, values_, dimension); - THLongTensor_squeeze1d(indices_, indices_, dimension); - } -} - -void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim) -{ - THLongStorage *dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); - - int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { - int r_Contig = THTensor_(isContiguous)(r_); - real *tp = THTensor_(data)(t); - real *rp = THTensor_(data)(r_); - if(r_Contig && (tp != rp)){ - ptrdiff_t iter = 0; - ptrdiff_t r_Size = THTensor_(nElement)(r_); - int r_Dim = r_->_dim(); - #pragma omp parallel for if ( r_Size > HYPER_TH_OMP_OVERHEAD_THRESHOLD) - for (iter = 0; iter < r_Size; iter++) { - int j; - int64_t quot; - int64_t rem = iter; - ptrdiff_t tBasicIndex = 0; - - for(j = 0; j < r_Dim; ++j) { - if(j != dimension){ - quot = rem/r_->stride(j); - rem = rem%r_->stride(j); - tBasicIndex += quot*t->stride(j); - } - } - real *t_data = tp+tBasicIndex; - real *r__data = rp+iter; - *r__data = 0; - for(j=0; j < t->size(dimension); ++j) { - *r__data += *(t_data + j*t->stride(dimension)); - } - } - } else { - serial_path = 1; - } - } -#else - serial_path = 1; -#endif - if (serial_path) { - // two 
implementations optimized for data locality - if (t->stride(dimension) == 1) { - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - accreal sum = 0; - int64_t i; - for(i = 0; i < t_size; i++) - sum += t_data[i*t_stride]; - *r__data = (real)sum;); - } else { - THTensor_(zero)(r_); - THTensor *temp_ = THTensor_(newWithTensor)(r_); - // r_.expand_as(t) - THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(temp_, dimension, 0); - - TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data + *t_data;); - THTensor_(free)(temp_); - } - } - - if (!keepdim) { - THTensor_(squeeze1d)(r_, r_, dimension); - } -} - -void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim) -{ - THLongStorage *dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); - - int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { - int r_Contig = THTensor_(isContiguous)(r_); - real *tp = THTensor_(data)(t); - real *rp = THTensor_(data)(r_); - if(r_Contig && (tp != rp)){ - ptrdiff_t iter = 0; - ptrdiff_t r_Size = THTensor_(nElement)(r_); - int r_Dim = r_->_dim(); - #pragma omp parallel for if ( r_Size > HYPER_TH_OMP_OVERHEAD_THRESHOLD) - for (iter = 0; iter < r_Size; iter++) { - int j; - int64_t quot; - int64_t rem = iter; - ptrdiff_t tBasicIndex = 0; - - for(j = 0; j < r_Dim; ++j) { - if(j != dimension){ - quot = rem/r_->stride(j); - rem = rem%r_->stride(j); - tBasicIndex += quot*t->stride(j); - } - } - real *t_data = tp+tBasicIndex; - real *r__data = rp+iter; - *r__data = 1; - for(j=0; j < t->size(dimension); ++j) { - *r__data *= *(t_data + j*t->stride(dimension)); - } - } - } else { - serial_path = 1; - } - } -#else - serial_path = 1; -#endif - - if(serial_path) { - // two implementations optimized for data locality - if (t->stride(dimension) == 1) { - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - accreal prod = 1; - int64_t i; - for(i = 0; i < t_size; i++) - prod *= t_data[i*t_stride]; - *r__data = (real)prod;); - } else { - THTensor_(fill)(r_, 1); - THTensor *temp_ = THTensor_(newWithTensor)(r_); - // r_.expand_as(t) - THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(temp_, dimension, 0); - - TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data * *t_data;); - THTensor_(free)(temp_); - } - } - if (!keepdim) { - THTensor_(squeeze1d)(r_, r_, dimension); - } -} - -void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension) -{ - THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - THTensor_(resizeAs)(r_, t); - - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - accreal cumsum = 0; - int64_t i; - for(i = 0; i < t_size; i++) - { - cumsum += t_data[i*t_stride]; - r__data[i*r__stride] = (real)cumsum; - }); -} - -void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension) -{ - THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - THTensor_(resizeAs)(r_, t); - - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - accreal cumprod = 1; - int64_t i; - for(i = 0; i < t_size; i++) - { - 
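The OpenMP branch of sum and prod above parallelizes over the contiguous output and, for each flat output index, recovers the matching base offset in t by repeated division and modulo with the output strides, then walks the reduced dimension. A standalone sketch of that index arithmetic (plain C, hypothetical parameter names, not part of the patch) is:

#include <stdint.h>

static double reduce_sum_at(const double *t_data,
                            const int64_t *t_stride, const int64_t *r_stride,
                            int ndim, int red_dim, int64_t red_size,
                            int64_t out_flat_index)
{
  int64_t rem = out_flat_index;
  int64_t t_base = 0;
  /* Decompose the flat index into per-dimension coordinates using the
   * (contiguous) output strides, skipping the reduced dimension. */
  for (int j = 0; j < ndim; ++j) {
    if (j == red_dim)
      continue;
    int64_t quot = rem / r_stride[j];
    rem = rem % r_stride[j];
    t_base += quot * t_stride[j];
  }
  /* Accumulate along the reduced dimension of t using its own stride. */
  double acc = 0;
  for (int64_t j = 0; j < red_size; ++j)
    acc += t_data[t_base + j * t_stride[red_dim]];
  return acc;
}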
cumprod *= t_data[i*t_stride]; - r__data[i*r__stride] = (real)cumprod; - }); -} - - -void THTensor_(sign)(THTensor *r_, THTensor *t) -{ - THTensor_(resizeAs)(r_, t); - -#if defined (TH_REAL_IS_BYTE) - TH_TENSOR_APPLY2(real, r_, real, t, - if (*t_data > 0) *r__data = 1; - else *r__data = 0;); -#else - TH_TENSOR_APPLY2(real, r_, real, t, - if (*t_data > 0) *r__data = 1; - else if (*t_data < 0) *r__data = -1; - else *r__data = 0;); -#endif -} - - -accreal THTensor_(trace)(THTensor *t) -{ - real *t_data = THTensor_(data)(t); - accreal sum = 0; - int64_t i = 0; - int64_t t_stride_0, t_stride_1, t_diag_size; - - THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix"); - - t_stride_0 = THTensor_(stride)(t, 0); - t_stride_1 = THTensor_(stride)(t, 1); - t_diag_size = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)); - while(i < t_diag_size) - { - sum += t_data[i*(t_stride_0+t_stride_1)]; - i++; - } - - return sum; -} - -void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension) -{ - int i; - - if(THTensor_(nDimension)(a) != THTensor_(nDimension)(b)) - THError("inconsistent tensor dimension %dD, %dD", - THTensor_(nDimension)(a), THTensor_(nDimension)(b)); - - for(i = 0; i < THTensor_(nDimension)(a); i++) - { - if(THTensor_(size)(a, i) != THTensor_(size)(b, i)) { - THDescBuff ba = THTensor_(sizeDesc)(a); - THDescBuff bb = THTensor_(sizeDesc)(b); - THError("inconsistent tensor sizes %s, %s", ba.str, bb.str); - } - } - - if(dimension < 0) - { - for(i = 0; i < THTensor_(nDimension)(a); i++) - { - if(THTensor_(size)(a, i) == 3) - { - dimension = i; - break; - } - } - if(dimension < 0) { - THDescBuff ba = THTensor_(sizeDesc)(a); - THError("no dimension of size 3 in a: %s", ba.str); - } - } - - THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(a), 3, "dimension %d out of range", - dimension + TH_INDEX_BASE); - THArgCheck(THTensor_(size)(a, dimension) == 3, 3, "dimension %d does not have size 3", - dimension + TH_INDEX_BASE); - - THTensor_(resizeAs)(r_, a); - - TH_TENSOR_DIM_APPLY3(real, a, real, b, real, r_, dimension, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - r__data[0*r__stride] = a_data[1*a_stride]*b_data[2*b_stride] - a_data[2*a_stride]*b_data[1*b_stride]; - r__data[1*r__stride] = a_data[2*a_stride]*b_data[0*b_stride] - a_data[0*a_stride]*b_data[2*b_stride]; - r__data[2*r__stride] = a_data[0*a_stride]*b_data[1*b_stride] - a_data[1*a_stride]*b_data[0*b_stride];); -} - -void THTensor_(cmax)(THTensor *r, THTensor *t, THTensor *src) { - THTensor_(resizeAs)(r, t); - TH_TENSOR_APPLY3(real, r, real, t, real, src, - *r_data = *t_data > *src_data ? *t_data : *src_data;); -} - -void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src) { - THTensor_(resizeAs)(r, t); - TH_TENSOR_APPLY3(real, r, real, t, real, src, - *r_data = *t_data < *src_data ? *t_data : *src_data;); -} - -void THTensor_(cmaxValue)(THTensor *r, THTensor *t, real value) { - THTensor_(resizeAs)(r, t); - TH_TENSOR_APPLY2(real, r, real, t, - *r_data = *t_data < value ? value : *t_data;); // this order propagates NaN -} - -void THTensor_(cminValue)(THTensor *r, THTensor *t, real value) { - THTensor_(resizeAs)(r, t); - TH_TENSOR_APPLY2(real, r, real, t, - *r_data = *t_data > value ? 
value : *t_data;); // this order propagates NaN -} - -void THTensor_(zerosLike)(THTensor *r_, THTensor *input) -{ - THTensor_(resizeAs)(r_, input); - THTensor_(zero)(r_); -} - -void THTensor_(onesLike)(THTensor *r_, THTensor *input) -{ - THTensor_(resizeAs)(r_, input); - THTensor_(fill)(r_, 1); -} - -void THTensor_(diag)(THTensor *r_, THTensor *t, int k) -{ -#ifndef USE_TH_SIZE_ZERO_DIM - AT_ASSERT(!t->is_empty()) -#endif - THArgCheck(THTensor_(nDimension)(t) == 1 || THTensor_(nDimension)(t) == 2, 1, "matrix or a vector expected"); - - if(THTensor_(nDimension)(t) == 1) - { - real *t_data = THTensor_(data)(t); - int64_t t_stride_0 = THTensor_(stride)(t, 0); - int64_t t_size = THTensor_(size)(t, 0); - int64_t sz = t_size + (k >= 0 ? k : -k); - real *r__data; - int64_t r__stride_0; - int64_t r__stride_1; - int64_t i; - - THTensor_(resize2d)(r_, sz, sz); - THTensor_(zero)(r_); - r__data = THTensor_(data)(r_); - r__stride_0 = THTensor_(stride)(r_, 0); - r__stride_1 = THTensor_(stride)(r_, 1); - r__data += (k >= 0 ? k*r__stride_1 : -k*r__stride_0); - - for(i = 0; i < t_size; i++) - r__data[i*(r__stride_0+r__stride_1)] = t_data[i*t_stride_0]; - } - else - { - real *t_data = THTensor_(data)(t); - int64_t t_stride_0 = THTensor_(stride)(t, 0); - int64_t t_stride_1 = THTensor_(stride)(t, 1); - int64_t sz; - real *r__data; - int64_t r__stride_0; - int64_t i; - - if(k >= 0) - sz = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)-k); - else - sz = THMin(THTensor_(size)(t, 0)+k, THTensor_(size)(t, 1)); - THTensor_(resize1d)(r_, sz); - r__data = THTensor_(data)(r_); - r__stride_0 = THTensor_(stride)(r_, 0); - - t_data += (k >= 0 ? k*t_stride_1 : -k*t_stride_0); - for(i = 0; i < sz; i++) - r__data[i*r__stride_0] = t_data[i*(t_stride_0+t_stride_1)]; - } -} - -void THTensor_(eye)(THTensor *r_, int64_t n, int64_t m) -{ - real *r__data; - int64_t i, sz; - - THArgCheck(n > 0, 1, "invalid argument"); - - if(m <= 0) - m = n; - - THTensor_(resize2d)(r_, n, m); - THTensor_(zero)(r_); - - i = 0; - r__data = THTensor_(data)(r_); - sz = THMin(THTensor_(size)(r_, 0), THTensor_(size)(r_, 1)); - for(i = 0; i < sz; i++) - r__data[i*(r_->stride(0)+r_->stride(1))] = 1; -} - - -void THTensor_(range)(THTensor *r_, accreal xmin, accreal xmax, accreal step) -{ - ptrdiff_t size; - real i = 0; - - THArgCheck(step > 0 || step < 0, 3, "step must be nonzero"); - THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin)) - , 2, "upper bound and larger bound inconsistent with step sign"); - - size = (ptrdiff_t) (((xmax - xmin) / step) + 1); - - if (THTensor_(nElement)(r_) != size) { - THTensor_(resize1d)(r_, size); - } - - TH_TENSOR_APPLY(real, r_, *r__data = xmin + (i++)*step;); -} - -void THTensor_(arange)(THTensor *r_, accreal xmin, accreal xmax, accreal step) { - ptrdiff_t size; - real i = 0; - - THArgCheck(step > 0 || step < 0, 3, "step must be nonzero"); - THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin)) - , 2, "upper bound and larger bound inconsistent with step sign"); - - size = (ptrdiff_t) ceil((double)(xmax - xmin) / step); - - if (THTensor_(nElement)(r_) != size) { - THTensor_(resize1d)(r_, size); - } - - TH_TENSOR_APPLY(real, r_, *r__data = xmin + (i++)*step;); -} - -void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n) -{ - real *r__data; - int64_t r__stride_0; - int64_t i; - - THArgCheck(n > 0, 1, "must be strictly positive"); - - THTensor_(resize1d)(r_, n); - r__data = THTensor_(data)(r_); - r__stride_0 = THTensor_(stride)(r_,0); - - for(i = 0; i < n; 
i++) - r__data[i*r__stride_0] = (real)(i); - - for(i = 0; i < n-1; i++) - { - int64_t z = THRandom_random(_generator) % (n-i); - real sav = r__data[i*r__stride_0]; - r__data[i*r__stride_0] = r__data[(z+i)*r__stride_0]; - r__data[(z+i)*r__stride_0] = sav; - } -} - -/* I cut and pasted (slightly adapted) the quicksort code from - Sedgewick's 1978 "Implementing Quicksort Programs" article - http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf - - It is the state of the art existing implementation. The macros - are here to make as close a match as possible to the pseudocode of - Program 2 p.851 - - Note that other partition schemes exist, and are typically presented - in textbook, but those are less efficient. See e.g. - http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto - - Julien, November 12th 2013 -*/ -#define MAX_LEVELS 300 -#define M_SMALL 10 /* Limit for small subfiles */ - -#define ARR(III) arr[(III)*stride] -#define IDX(III) idx[(III)*stride] - -#define LONG_SWAP(AAA, BBB) swap = AAA; AAA = BBB; BBB = swap -#define REAL_SWAP(AAA, BBB) rswap = AAA; AAA = BBB; BBB = rswap - -#define ARR_SWAP(III, JJJ) \ - REAL_SWAP(ARR(III), ARR(JJJ)); - -#define BOTH_SWAP(III, JJJ) \ - REAL_SWAP(ARR(III), ARR(JJJ)); \ - LONG_SWAP(IDX(III), IDX(JJJ)) - -static void THTensor_(quicksortascend)(real *arr, int64_t *idx, int64_t elements, int64_t stride) -{ - int64_t beg[MAX_LEVELS], end[MAX_LEVELS], i, j, L, R, P, swap, pid, stack = 0, sz_right, sz_left; - real rswap, piv; - unsigned char done = 0; - - /* beg[0]=0; end[0]=elements; */ - stack = 0; - L = 0; R = elements-1; - done = elements-1 <= M_SMALL; - - while(!done) { - /* Use median of three for pivot choice */ - P=(L+R)>>1; - BOTH_SWAP(P, L+1); - if (ARR(L+1) > ARR(R)) { BOTH_SWAP(L+1, R); } - if (ARR(L) > ARR(R)) { BOTH_SWAP(L, R); } - if (ARR(L+1) > ARR(L)) { BOTH_SWAP(L+1, L); } - - i = L+1; j = R; piv = ARR(L); pid = IDX(L); - - do { - do { i = i+1; } while(ARR(i) < piv); - do { j = j-1; } while(ARR(j) > piv); - if (j < i) - break; - BOTH_SWAP(i, j); - } while(1); - BOTH_SWAP(L, j); - /* Left subfile is (L, j-1) */ - /* Right subfile is (i, R) */ - sz_left = j-L; - sz_right = R-i+1; - if (sz_left <= M_SMALL && sz_right <= M_SMALL) { - /* both subfiles are small */ - /* if stack empty */ - if (stack == 0) { - done = 1; - } else { - stack--; - L = beg[stack]; - R = end[stack]; - } - } else if (sz_left <= M_SMALL || sz_right <= M_SMALL) { - /* exactly one of the subfiles is small */ - /* (L,R) = large subfile */ - if (sz_left > sz_right) { - /* Implicit: L = L; */ - R = j-1; - } else { - L = i; - /* Implicit: R = R; */ - } - } else { - /* none of the subfiles is small */ - /* push large subfile */ - /* (L,R) = small subfile */ - if (sz_left > sz_right) { - beg[stack] = L; - end[stack] = j-1; - stack++; - L = i; - /* Implicit: R = R */ - } else { - beg[stack] = i; - end[stack] = R; - stack++; - /* Implicit: L = L; */ - R = j-1; - } - } - } /* while not done */ - /* Now insertion sort on the concatenation of subfiles */ - for(i=elements-2; i>=0; i--) { - if (ARR(i) > ARR(i+1)) { - piv = ARR(i); - pid = IDX(i); - j = i+1; - do { - ARR(j-1) = ARR(j); - IDX(j-1) = IDX(j); - j = j+1; - } while(j < elements && ARR(j) < piv); - ARR(j-1) = piv; - IDX(j-1) = pid; - } - } -} - -static void THTensor_(quicksortdescend)(real *arr, int64_t *idx, int64_t elements, int64_t stride) -{ - int64_t beg[MAX_LEVELS], end[MAX_LEVELS], i, j, L, R, P, swap, pid, stack = 0, sz_right, sz_left; - real rswap, piv; - unsigned char done = 0; - - /* 
beg[0]=0; end[0]=elements; */ - stack = 0; - L = 0; R = elements-1; - done = elements-1 <= M_SMALL; - - while(!done) { - /* Use median of three for pivot choice */ - P=(L+R)>>1; - BOTH_SWAP(P, L+1); - if (ARR(L+1) < ARR(R)) { BOTH_SWAP(L+1, R); } - if (ARR(L) < ARR(R)) { BOTH_SWAP(L, R); } - if (ARR(L+1) < ARR(L)) { BOTH_SWAP(L+1, L); } - - i = L+1; j = R; piv = ARR(L); pid = IDX(L); - - do { - do { i = i+1; } while(ARR(i) > piv); - do { j = j-1; } while(ARR(j) < piv); - if (j < i) - break; - BOTH_SWAP(i, j); - } while(1); - BOTH_SWAP(L, j); - /* Left subfile is (L, j-1) */ - /* Right subfile is (i, R) */ - sz_left = j-L; - sz_right = R-i+1; - if (sz_left <= M_SMALL && sz_right <= M_SMALL) { - /* both subfiles are small */ - /* if stack empty */ - if (stack == 0) { - done = 1; - } else { - stack--; - L = beg[stack]; - R = end[stack]; - } - } else if (sz_left <= M_SMALL || sz_right <= M_SMALL) { - /* exactly one of the subfiles is small */ - /* (L,R) = large subfile */ - if (sz_left > sz_right) { - /* Implicit: L = L; */ - R = j-1; - } else { - L = i; - /* Implicit: R = R; */ - } - } else { - /* none of the subfiles is small */ - /* push large subfile */ - /* (L,R) = small subfile */ - if (sz_left > sz_right) { - beg[stack] = L; - end[stack] = j-1; - stack++; - L = i; - /* Implicit: R = R */ - } else { - beg[stack] = i; - end[stack] = R; - stack++; - /* Implicit: L = L; */ - R = j-1; - } - } - } /* while not done */ - /* Now insertion sort on the concatenation of subfiles */ - for(i=elements-2; i>=0; i--) { - if (ARR(i) < ARR(i+1)) { - piv = ARR(i); - pid = IDX(i); - j = i+1; - do { - ARR(j-1) = ARR(j); - IDX(j-1) = IDX(j); - j = j+1; - } while(j < elements && ARR(j) > piv); - ARR(j-1) = piv; - IDX(j-1) = pid; - } - } -} - -#undef MAX_LEVELS -#undef M_SMALL - -void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder) -{ - THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d", - dimension + TH_INDEX_BASE); - - THTensor_(resizeAs)(rt_, t); - THTensor_(copy)(rt_, t); - - { - THLongStorage *size = THTensor_(newSizeOf)(t); - THLongTensor_resize(ri_, size, NULL); - THLongStorage_free(size); - } - - if(descendingOrder) - { - TH_TENSOR_DIM_APPLY2(real, rt_, int64_t, ri_, dimension, - int64_t i; - for(i = 0; i < ri__size; i++) - ri__data[i*ri__stride] = i; - THTensor_(quicksortdescend)(rt__data, ri__data, rt__size, rt__stride);) - } - else - { - TH_TENSOR_DIM_APPLY2(real, rt_, int64_t, ri_, dimension, - int64_t i; - for(i = 0; i < ri__size; i++) - ri__data[i*ri__stride] = i; - THTensor_(quicksortascend)(rt__data, ri__data, rt__size, rt__stride);) - } -} - -/* Implementation of the Quickselect algorithm, based on Nicolas Devillard's -public domain implementation at http://ndevilla.free.fr/median/median/ -Adapted similarly to the above Quicksort algorithm. -This version does not produce indices along with values. 
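kthvalue, median and topk below all lean on the quickselect contract: after selecting rank k, position k holds the k-th smallest element and the array is partitioned around it. The sketch below is a standalone illustration of that contract (it uses Lomuto partitioning for brevity; the patch's version keeps Hoare-style partitioning with a median-of-three pivot and permutes an index array alongside the values):

#include <stdint.h>

/* After quickselect_sketch(a, 0, n - 1, k - 1), a[k - 1] is the k-th smallest
 * value, everything to its left is <= it and everything to its right is >= it,
 * which is exactly how kthvalue reads its result. */
static void quickselect_sketch(float *arr, int64_t left, int64_t right, int64_t k)
{
  while (left < right) {
    float pivot = arr[right];
    int64_t store = left;
    for (int64_t i = left; i < right; ++i) {
      if (arr[i] < pivot) {
        float tmp = arr[i]; arr[i] = arr[store]; arr[store] = tmp;
        ++store;
      }
    }
    arr[right] = arr[store];       /* move the pivot into its final slot */
    arr[store] = pivot;
    if (store == k)
      return;
    if (store < k)
      left = store + 1;
    else
      right = store - 1;
  }
}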
*/ -static void THTensor_(quickselectnoidx)(real *arr, int64_t k, int64_t elements, int64_t stride) -{ - int64_t P, L, R, i, j; - real rswap, piv; - L = 0; - R = elements-1; - - do { - if (R <= L) /* One element only */ - return; - - if (R == L+1) { /* Two elements only */ - if (ARR(L) > ARR(R)) { - ARR_SWAP(L, R); - } - return; - } - - /* Use median of three for pivot choice */ - P=(L+R)>>1; - ARR_SWAP(P, L+1); - if (ARR(L+1) > ARR(R)) { ARR_SWAP(L+1, R); } - if (ARR(L) > ARR(R)) { ARR_SWAP(L, R); } - if (ARR(L+1) > ARR(L)) { ARR_SWAP(L+1, L); } - - i = L+1; - j = R; - piv = ARR(L); - do { - do i++; while(ARR(i) < piv); - do j--; while(ARR(j) > piv); - if (j < i) - break; - ARR_SWAP(i, j); - } while(1); - ARR_SWAP(L, j); - - /* Re-set active partition */ - if (j <= k) L=i; - if (j >= k) R=j-1; - } while(1); -} - -/* Implementation of the Quickselect algorithm, based on Nicolas Devillard's -public domain implementation at http://ndevilla.free.fr/median/median/ -Adapted similarly to the above Quicksort algorithm. */ -static void THTensor_(quickselect)(real *arr, int64_t *idx, int64_t k, int64_t elements, int64_t stride) -{ - int64_t P, L, R, i, j, swap; - real rswap, piv; - L = 0; - R = elements-1; - - do { - if (R <= L) /* One element only */ - return; - - if (R == L+1) { /* Two elements only */ - if (ARR(L) > ARR(R)) { - BOTH_SWAP(L, R); - } - return; - } - - /* Use median of three for pivot choice */ - P=(L+R)>>1; - BOTH_SWAP(P, L+1); - if (ARR(L+1) > ARR(R)) { BOTH_SWAP(L+1, R); } - if (ARR(L) > ARR(R)) { BOTH_SWAP(L, R); } - if (ARR(L+1) > ARR(L)) { BOTH_SWAP(L+1, L); } - - i = L+1; - j = R; - piv = ARR(L); - do { - do i++; while(ARR(i) < piv); - do j--; while(ARR(j) > piv); - if (j < i) - break; - BOTH_SWAP(i, j); - } while(1); - BOTH_SWAP(L, j); - - /* Re-set active partition */ - if (j <= k) L=i; - if (j >= k) R=j-1; - } while(1); -} - -#undef ARR -#undef IDX -#undef LONG_SWAP -#undef REAL_SWAP -#undef BOTH_SWAP - -void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) -{ - THLongStorage *dim; - THTensor *temp_; - THLongTensor *tempi_; - real *temp__data; - int64_t *tempi__data; - int64_t t_size_dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range"); - - int in_dims = THTensor_(_nDimension)(t); - THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); - THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(values_, dim, NULL); - THLongTensor_resize(indices_, dim, NULL); - THLongStorage_free(dim); - - t_size_dim = THTensor_(size)(t, dimension); - - temp_ = THTensor_(new)(); - THTensor_(resize1d)(temp_, t_size_dim); - temp__data = THTensor_(data)(temp_); - - tempi_ = THLongTensor_new(); - THLongTensor_resize1d(tempi_, t_size_dim); - tempi__data = THLongTensor_data(tempi_); - - TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - int64_t i; - real mode = 0; - int64_t modei = 0; - int64_t temp_freq = 0; - int64_t max_freq = 0; - for(i = 0; i < t_size_dim; i++) - temp__data[i] = t_data[i*t_stride]; - for(i = 0; i < t_size_dim; i++) - tempi__data[i] = i; - THTensor_(quicksortascend)(temp__data, tempi__data, t_size_dim, 1); - - for(i = 0; i < t_size_dim; i++) - { - temp_freq++; - if ((i == t_size_dim - 1) || (temp__data[i] != temp__data[i+1])) - { - if (temp_freq > max_freq) - { - mode = 
temp__data[i]; - modei = tempi__data[i]; - max_freq = temp_freq; - } - temp_freq = 0; - } - } - *values__data = mode; - *indices__data = modei;); - - THTensor_(free)(temp_); - THLongTensor_free(tempi_); - if (!keepdim) { - THTensor_(squeeze1d)(values_, values_, dimension); - THLongTensor_squeeze1d(indices_, indices_, dimension); - } -} - -void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim) -{ - THLongStorage *dim; - THTensor *temp_; - THLongTensor *tempi_; - real *temp__data; - int64_t *tempi__data; - int64_t t_size_dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range"); - THArgCheck(k > 0 && k <= t->size(dimension), 2, "selected index out of range"); - - int in_dims = THTensor_(_nDimension)(t); - THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); - THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(values_, dim, NULL); - THLongTensor_resize(indices_, dim, NULL); - THLongStorage_free(dim); - - t_size_dim = THTensor_(size)(t, dimension); - - temp_ = THTensor_(new)(); - THTensor_(resize1d)(temp_, t_size_dim); - temp__data = THTensor_(data)(temp_); - - tempi_ = THLongTensor_new(); - THLongTensor_resize1d(tempi_, t_size_dim); - tempi__data = THLongTensor_data(tempi_); - - TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - int64_t i; - for(i = 0; i < t_size_dim; i++) - temp__data[i] = t_data[i*t_stride]; - for(i = 0; i < t_size_dim; i++) - tempi__data[i] = i; - THTensor_(quickselect)(temp__data, tempi__data, k - 1, t_size_dim, 1); - *values__data = temp__data[k-1]; - *indices__data = tempi__data[k-1];); - - THTensor_(free)(temp_); - THLongTensor_free(tempi_); - if (!keepdim) { - THTensor_(squeeze1d)(values_, values_, dimension); - THLongTensor_squeeze1d(indices_, indices_, dimension); - } -} - -void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) -{ - int64_t t_size_dim, k; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range"); - - t_size_dim = THTensor_(size)(t, dimension); - k = (t_size_dim-1) >> 1; /* take middle or one-before-middle element */ - - THTensor_(kthvalue)(values_, indices_, t, k+1, dimension, keepdim); -} - -void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted) -{ -#ifndef USE_TH_SIZE_ZERO_DIM - int numDims = THTensor_(_nDimension)(t); -#else - int numDims = THTensor_(nDimension)(t); -#endif - THArgCheck(dim >= 0 && dim < numDims, 3, "dim not in range"); - - int64_t sliceSize = THTensor_(size)(t, dim); -#ifndef USE_TH_SIZE_ZERO_DIM - THArgCheck(k > 0 && k <= sliceSize, 2, "k not in range for dimension"); -#else - THArgCheck(k >= 0 && k <= sliceSize, 2, "k not in range for dimension"); -#endif - - THTensor *tmpResults = THTensor_(new)(); - THTensor_(resize1d)(tmpResults, sliceSize); - real *tmp__data = THTensor_(data)(tmpResults); - - THLongTensor *tmpIndices = THLongTensor_new(); - THLongTensor_resize1d(tmpIndices, sliceSize); - int64_t *tmpi__data = THLongTensor_data(tmpIndices); - - THLongStorage *topKSize = THTensor_(newSizeOf)(t); - THLongStorage_set(topKSize, dim, k); - THTensor_(resize)(rt_, topKSize, NULL); - THLongTensor_resize(ri_, topKSize, NULL); - THLongStorage_free(topKSize); - - if 
(dir) { - /* k largest elements, descending order (optional: see sorted) */ - int64_t K = sliceSize - k; - TH_TENSOR_DIM_APPLY3(real, t, real, rt_, int64_t, ri_, dim, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - int64_t i; - for(i = 0; i < sliceSize; i++) - { - tmp__data[i] = t_data[i*t_stride]; - tmpi__data[i] = i; - } - if (K > 0) - THTensor_(quickselect)(tmp__data, tmpi__data, K - 1, sliceSize, 1); - if (sorted) - THTensor_(quicksortdescend)(tmp__data + K, tmpi__data + K, k, 1); - for(i = 0; i < k; i++) - { - rt__data[i*rt__stride] = tmp__data[i + K]; - ri__data[i*ri__stride] = tmpi__data[i + K]; - }) - } - else { - /* k smallest elements, ascending order (optional: see sorted) */ - TH_TENSOR_DIM_APPLY3(real, t, real, rt_, int64_t, ri_, dim, - TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, - int64_t i; - for(i = 0; i < sliceSize; i++) - { - tmp__data[i] = t_data[i*t_stride]; - tmpi__data[i] = i; - } - THTensor_(quickselect)(tmp__data, tmpi__data, k - 1, sliceSize, 1); - if (sorted) - THTensor_(quicksortascend)(tmp__data, tmpi__data, k - 1, 1); - for(i = 0; i < k; i++) - { - rt__data[i*rt__stride] = tmp__data[i]; - ri__data[i*ri__stride] = tmpi__data[i]; - }) - } - - THTensor_(free)(tmpResults); - THLongTensor_free(tmpIndices); -} - -void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k) -{ - int64_t t_size_0, t_size_1; - int64_t t_stride_0, t_stride_1; - int64_t r__stride_0, r__stride_1; - real *t_data, *r__data; - int64_t r, c; - - THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix"); - - THTensor_(resizeAs)(r_, t); - - t_size_0 = THTensor_(size)(t, 0); - t_size_1 = THTensor_(size)(t, 1); - t_stride_0 = THTensor_(stride)(t, 0); - t_stride_1 = THTensor_(stride)(t, 1); - r__stride_0 = THTensor_(stride)(r_, 0); - r__stride_1 = THTensor_(stride)(r_, 1); - r__data = THTensor_(data)(r_); - t_data = THTensor_(data)(t); - - for(r = 0; r < t_size_0; r++) - { - int64_t sz = THMin(r+k+1, t_size_1); - for(c = THMax(0, r+k+1); c < t_size_1; c++) - r__data[r*r__stride_0+c*r__stride_1] = 0; - for(c = 0; c < sz; c++) - r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1]; - } -} - -void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k) -{ - int64_t t_size_0, t_size_1; - int64_t t_stride_0, t_stride_1; - int64_t r__stride_0, r__stride_1; - real *t_data, *r__data; - int64_t r, c; - - THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix"); - - THTensor_(resizeAs)(r_, t); - - t_size_0 = THTensor_(size)(t, 0); - t_size_1 = THTensor_(size)(t, 1); - t_stride_0 = THTensor_(stride)(t, 0); - t_stride_1 = THTensor_(stride)(t, 1); - r__stride_0 = THTensor_(stride)(r_, 0); - r__stride_1 = THTensor_(stride)(r_, 1); - r__data = THTensor_(data)(r_); - t_data = THTensor_(data)(t); - - for(r = 0; r < t_size_0; r++) - { - int64_t sz = THMin(r+k, t_size_1); - for(c = THMax(0, r+k); c < t_size_1; c++) - r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1]; - for(c = 0; c < sz; c++) - r__data[r*r__stride_0+c*r__stride_1] = 0; - } -} - -void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension) -{ - THTensor* inputs[2]; - inputs[0] = ta; - inputs[1] = tb; - THTensor_(catArray)(r_, inputs, 2, dimension); -} - -void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension); -inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension) -{ - int first_dims = first->dim(); - int second_dims = second->dim(); - THArgCheck(first_dims == second_dims, 0, - "Tensors must have same number of 
dimensions: got %d and %d", - first_dims, second_dims); - for (int dim = 0; dim < first_dims; dim++) { - if (dim == dimension) { - continue; - } - int64_t first_dim_size = first->size(dim); - int64_t second_dim_size = second->size(dim); - THArgCheck(first_dim_size == second_dim_size, 0, - "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d", - dimension, (long long)first_dim_size, (long long)second_dim_size, dim); - } -} - -void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension) -{ - // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible - // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors - // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific - // size (i.e. other empty sizes are not skipped). - // FIXME: warn if this is the case - bool allSkipped= true; - int64_t nDims = 0; - THTensor *notSkippedTensor; // non-owning reference - auto should_skip = [](THTensor *t) { return t->is_empty() && t->dim() == 1; }; - for (int i = 0; i < numInputs; i++) { - if (should_skip(inputs[i])) { - continue; - } - // We've found a non-empty tensor - allSkipped = false; - notSkippedTensor = inputs[i]; - nDims = notSkippedTensor->dim(); - break; - } - if (allSkipped) { - return; - } - - // Compute cat_dimension based on the non-empty tensor - THArgCheck(dimension < nDims, 4, "invalid dimension %d", dimension); - THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs); - - // Compute size of the result in the cat dimension - int64_t cat_dim_size = 0; - for (int i = 0; i < numInputs; i++) { - THTensor *tensor = inputs[i]; - if (should_skip(tensor)) { - continue; - } - THTensor_(check_shape_except_dim)(notSkippedTensor, tensor, dimension); - cat_dim_size += tensor->size(dimension); - } - - // Compute the size of the result - THLongStorage *size = THLongStorage_newWithSize(nDims); - for (int dim = 0; dim < nDims; dim++) { - int64_t result_dim_size = notSkippedTensor->size(dim); - if (dim == dimension) { - result_dim_size = cat_dim_size; - } - THLongStorage_data(size)[dim] = result_dim_size; - } - THTensor_(resize)(result, size, NULL); - - // Check contiguity of all inputs and result - bool allContiguous = true; - for (int i = 0; i < numInputs; i++) { - if(!should_skip(inputs[i])) { - allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]); - } - } - allContiguous = allContiguous && THTensor_(isContiguous)(result); - - // First path is for contiguous inputs along dim 0 - // Second path for non-contiguous - int64_t offset; - if (dimension == 0 && allContiguous) { - real* result_data = THStorage_(data)(THTensor_getStoragePtr(result)) + result->storage_offset(); - offset = 0; - for (int j = 0; j < numInputs; j++) { - if (!should_skip(inputs[j])) { - THTensor* input0 = inputs[j]; - real* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + input0->storage_offset(); - int64_t input0_size = THTensor_(nElement)(input0); - // C standard says you can't pass nullptrs to memcpy, even if the size is 0; ubsan checks this. 
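When every input and the result are contiguous and the concatenation dimension is 0, catArray above degenerates into copying each input's storage end-to-end. A minimal sketch of that fast path (standalone C, illustrative names, not part of the patch) is:

#include <stddef.h>
#include <string.h>

/* Contiguous dim-0 concatenation: append each input's elements after the
 * previous one.  The zero-size check mirrors the memcpy/ubsan caveat noted
 * in the patch: memcpy must not be handed a null source even with size 0. */
static void cat_dim0_sketch(float *result, const float *const *inputs,
                            const size_t *numels, int ninputs)
{
  size_t offset = 0;
  for (int j = 0; j < ninputs; ++j) {
    if (numels[j] != 0)
      memcpy(result + offset, inputs[j], numels[j] * sizeof(float));
    offset += numels[j];
  }
}

For any other dimension, or for non-contiguous inputs, the patch instead narrows the result and copies with THTensor_(copy), which handles arbitrary strides.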
- if (input0_size != 0) { - memcpy(result_data + offset, input0_data, input0_size*sizeof(real)); - } - offset += input0_size; - } - } - } else { - offset = 0; - for (int j = 0; j < numInputs; j++) { - if (!should_skip(inputs[j])) { - int64_t dimSize = inputs[j]->size(dimension); - THTensor *nt = THTensor_(newWithTensor)(result); - THTensor_(narrow)(nt, NULL, dimension, offset, dimSize); - THTensor_(copy)(nt, inputs[j]); - THTensor_(free)(nt); - offset += dimSize; - } - } - } - THLongStorage_free(size); -} - -int THTensor_(equal)(THTensor *ta, THTensor* tb) -{ - int equal = 1; - if(!THTensor_(isSameSizeAs)(ta, tb)) - return 0; - - if (THTensor_(isContiguous)(ta) && THTensor_(isContiguous)(tb)) { - real *tap = THTensor_(data)(ta); - real *tbp = THTensor_(data)(tb); - ptrdiff_t sz = THTensor_(nElement)(ta); - ptrdiff_t i; - for (i=0; idim(), THTensor_getSizePtr(t), NULL); \ - TH_TENSOR_APPLY2(unsigned char, r_, real, t, \ - *r__data = (*t_data OP value) ? 1 : 0;); \ - } \ - void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, real value) \ - { \ - THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ - TH_TENSOR_APPLY2(real, r_, real, t, \ - *r__data = (*t_data OP value) ? 1 : 0;); \ - } \ - void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \ - { \ - THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ - TH_TENSOR_APPLY3(unsigned char, r_, real, ta, real, tb, \ - *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ - } \ - void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \ - { \ - THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ - TH_TENSOR_APPLY3(real, r_, real, ta, real, tb, \ - *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ - } \ - - -TENSOR_IMPLEMENT_LOGICAL(lt,<) -TENSOR_IMPLEMENT_LOGICAL(gt,>) -TENSOR_IMPLEMENT_LOGICAL(le,<=) -TENSOR_IMPLEMENT_LOGICAL(ge,>=) -TENSOR_IMPLEMENT_LOGICAL(eq,==) -TENSOR_IMPLEMENT_LOGICAL(ne,!=) - - -#ifdef _OPENMP - -#define LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, OMP_THRESHOLD) \ - void THTensor_(NAME)(THTensor *r_, THTensor *t) \ - { \ - THTensor_(resizeAs)(r_, t); \ - ptrdiff_t r_Size = THTensor_(nElement)(r_); \ - int r_Contig = THTensor_(isContiguous)(r_); \ - int tContig = THTensor_(isContiguous)(t); \ - int inOMP = omp_in_parallel(); \ - if( !inOMP ){ \ - TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = CFUNC(*t_data);, OMP_THRESHOLD); \ - } else { \ - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = CFUNC(*t_data);); \ - } \ - } - -#define LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS(NAME, CFUNC) \ - LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD) - -#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS(NAME, CFUNC, OMP_THRESHOLD) \ - void THTensor_(NAME)(THTensor *r_, THTensor *t) \ - { \ - THTensor_(resizeAs)(r_, t); \ - ptrdiff_t r_Size = THTensor_(nElement)(r_); \ - int r_Contig = THTensor_(isContiguous)(r_); \ - int tContig = THTensor_(isContiguous)(t); \ - if (r_Contig && tContig) { \ - TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(NAME)(r__data, t_data, r__len);); \ - } else { \ - int inOMP = omp_in_parallel(); \ - if( !inOMP ){ \ - TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = CFUNC(*t_data);, OMP_THRESHOLD); \ - } \ - else { \ - TH_TENSOR_APPLY2(real, r_, real, t, *r__data = CFUNC(*t_data);); \ - } \ - } \ - } - -#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS(NAME, CFUNC) \ - LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS(NAME, CFUNC, 
UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD) - -#else - -#define LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS(NAME, CFUNC) \ - void THTensor_(NAME)(THTensor *r_, THTensor *t) \ - { \ - THTensor_(resizeAs)(r_, t); \ - TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \ - } \ - -#define LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, PSEUDO_OMP_THRESHOLD) \ - LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS(NAME, CFUNC) - -#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS(NAME, CFUNC) \ - void THTensor_(NAME)(THTensor *r_, THTensor *t) \ - { \ - THTensor_(resizeAs)(r_, t); \ - int r_Contig = THTensor_(isContiguous)(r_); \ - int tContig = THTensor_(isContiguous)(t); \ - if (r_Contig && tContig) { \ - TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(NAME)(r__data, t_data, r__len);); \ - } else { \ - TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \ - } \ - } \ - -#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS(NAME, CFUNC, PSEUDO_OMP_THRESHOLD) \ - LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS(NAME, CFUNC) - -#endif - -#define EXPAND(...) __VA_ARGS__ - -#define GET_4TH_ARG(ARG0, ARG1, ARG2, ARG3, ...) ARG3 - -#define LAB_IMPLEMENT_BASIC_FUNCTION_CHOOSE(...) \ - EXPAND(GET_4TH_ARG(__VA_ARGS__, LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS, LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS, )) - -#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_CHOOSE(...) \ - EXPAND(GET_4TH_ARG(__VA_ARGS__, LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS, LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS, )) - -#define LAB_IMPLEMENT_BASIC_FUNCTION(...) EXPAND(LAB_IMPLEMENT_BASIC_FUNCTION_CHOOSE(__VA_ARGS__)(__VA_ARGS__)) - -#define LAB_IMPLEMENT_VECTORIZED_FUNCTION(...) EXPAND(LAB_IMPLEMENT_VECTORIZED_FUNCTION_CHOOSE(__VA_ARGS__)(__VA_ARGS__)) - -/* - * LAB_IMPLEMENT_BASIC_FUNCTION is a macro with optional parameters, you can use it flexibly. - * The macro will discard the invalid openmp threshold if openmp is unavailable. The macro will give a default threshold even if you forget to pass one. - * In other word, - * (A), If openmp is UNavailable, the two usage below is both right. - * (1) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity, OMP_OVERHEAD_THRESHOLD) // discard the invalid openmp threshold - * (2) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity) - * (B), If openmp is available, the two usage below is also both right. - * (1) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity, OMP_OVERHEAD_THRESHOLD) - * (2) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity) // pass the default openmp threshold - * So do LAB_IMPLEMENT_VECTORIZED_FUNCTION. 
-*/ - -LAB_IMPLEMENT_BASIC_FUNCTION(neg,-) - -#if defined(TH_REAL_IS_LONG) -LAB_IMPLEMENT_BASIC_FUNCTION(abs,labs) -#endif /* int64_t only part */ - -#if defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) -LAB_IMPLEMENT_BASIC_FUNCTION(abs,abs) -#endif /* int only part */ - -#if defined(TH_REAL_IS_BYTE) - -int THTensor_(logicalAndAll)(THTensor *tensor) -{ - real prod = 1; - int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if(inOMP) { - serial_path = 1; - } else { - TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, &&:prod, prod = prod && *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, tContig, srcContig, real, r_, real, t, real, src, *r__data = *t_data | *src_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } #else - serial_path = 1; + serial_path = 1; #endif + } + } else { + serial_path = 1; + } if (serial_path) { - TH_TENSOR_APPLY(real, tensor, prod = prod && *tensor_data;); + TH_TENSOR_APPLY3(real, r_, real, t, real, src, *r__data = *t_data | *src_data;); } - return prod; +#endif } -int THTensor_(logicalAnyAll)(THTensor *tensor) +void THTensor_(cbitxor)(THTensor *r_, THTensor *t, THTensor *src) { - real sum = 0; +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_HALF) + (void)r_; + (void)t; + (void)src; + return THError("cbitxor is only supported for integer type tensors"); +#else + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int64_t srcSize = THTensor_(nElement)(src); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); + int srcContig = THTensor_(isContiguous)(src); int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if(inOMP) { - serial_path = 1; - } else { - TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, ||:sum, sum = sum || *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); - } + if (srcSize == r_Size){ + if (r_Contig && tContig && srcContig) { + real *tp = THTensor_(data)(t); + real *sp = THTensor_(data)(src); + real *rp = THTensor_(data)(r_); + int64_t i; + #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; i= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); - + THTensor_(resizeAs)(r_, t); + int64_t r_Size = THTensor_(nElement)(r_); + int r_Contig = THTensor_(isContiguous)(r_); + int tContig = THTensor_(isContiguous)(t); int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { - int r_Contig = THTensor_(isContiguous)(r_); + if (r_Contig && tContig) { real *tp = THTensor_(data)(t); real *rp = THTensor_(data)(r_); - if(r_Contig && (tp != rp)){ - ptrdiff_t iter = 0; - ptrdiff_t r_Size = THTensor_(nElement)(r_); - int r_Dim = r_->_dim(); - #pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD) - for (iter = 0; iter < r_Size; iter++) { - int j; - int64_t quot; - int64_t rem = iter; - ptrdiff_t tBasicIndex = 0; - - for(j = 0; j < r_Dim; ++j) { - if(j != dimension){ - quot = rem/r_->stride(j); - rem = rem%r_->stride(j); - tBasicIndex += quot*t->stride(j); - } - } - real *t_data = tp+tBasicIndex; - real *r__data = rp+iter; - *r__data = 1; - for(j=0; j < t->size(dimension); ++j) 
{ - *r__data = *r__data && *(t_data + j*t->stride(dimension)); - } - } - } else { + int64_t i; + #pragma omp parallel for if(r_Size > TH_OMP_OVERHEAD_THRESHOLD) private(i) + for (i=0; istride(dimension) == 1) { - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - accreal prod = 1; - int64_t i; - for(i = 0; i < t_size; i++) - prod = prod && t_data[i*t_stride]; - *r__data = (real)prod;); - } else { - THTensor_(fill)(r_, 1); - THTensor *temp_ = THTensor_(newWithTensor)(r_); - // r_.expand_as(t) - THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(temp_, dimension, 0); - - TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data && *t_data;); - THTensor_(free)(temp_); - } } - if (!keepdim) { - THTensor_(squeeze1d)(r_, r_, dimension); + if (serial_path) { + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = THTensor_(powOne)(value, *t_data);); } } -void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim) +void THTensor_(addcmul)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2) { - THLongStorage *dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", - dimension + TH_INDEX_BASE); - - THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); - + if(r_ != t) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } + int64_t r_Size = THTensor_(nElement)(r_); + int64_t src1Size = THTensor_(nElement)(src1); + int64_t src2Size = THTensor_(nElement)(src2); + int r_Contig = THTensor_(isContiguous)(r_); + int src1Contig = THTensor_(isContiguous)(src1); + int src2Contig = THTensor_(isContiguous)(src2); int serial_path = 0; -#ifdef _OPENMP - int inOMP = omp_in_parallel(); - if (inOMP) { - serial_path = 1; - } else { - int r_Contig = THTensor_(isContiguous)(r_); - real *tp = THTensor_(data)(t); - real *rp = THTensor_(data)(r_); - if(r_Contig && (tp != rp)){ - ptrdiff_t iter = 0; - ptrdiff_t r_Size = THTensor_(nElement)(r_); - int r_Dim = r_->_dim(); - #pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD) - for (iter = 0; iter < r_Size; iter++) { - int j; - int64_t quot; - int64_t rem = iter; - ptrdiff_t tBasicIndex = 0; - - for(j = 0; j < r_Dim; ++j) { - if(j != dimension){ - quot = rem/r_->stride(j); - rem = rem%r_->stride(j); - tBasicIndex += quot*t->stride(j); - } - } - real *t_data = tp+tBasicIndex; - real *r__data = rp+iter; - *r__data = 0; - for(j=0; j < t->size(dimension); ++j) { - *r__data = *r__data || *(t_data + j*t->stride(dimension)); - } - } - } else { + if( (src1Size == src2Size) && (src1Size == r_Size) ){ +#if _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { serial_path = 1; + } else { + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, src1Contig, src2Contig, real, r_, real, src1, real, src2, *r__data += value * *src1_data * *src2_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); } - } #else - serial_path = 1; + (void)r_Contig; + (void)src1Contig; + (void)src2Contig; + serial_path = 1; #endif - if (serial_path) { - // two implementations optimized for data locality - if (t->stride(dimension) == 1) { - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - accreal sum = 0; - int64_t i; - for(i = 0; i < t_size; i++) - sum = sum || t_data[i*t_stride]; - *r__data = (real)sum;); - } else { - THTensor_(zero)(r_); - THTensor *temp_ = THTensor_(newWithTensor)(r_); - // r_.expand_as(t) - 
THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); - THTensor_setStrideAtDim(temp_, dimension, 0); - - TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data || *t_data;); - THTensor_(free)(temp_); - } + } else { + serial_path = 1; } - - if (!keepdim) { - THTensor_(squeeze1d)(r_, r_, dimension); + if (serial_path) { + TH_TENSOR_APPLY3(real, r_, real, src1, real, src2, *r__data += value * *src1_data * *src2_data;); } } -#endif /* Byte only part */ - -/* floating point only now */ -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) - -#if defined (TH_REAL_IS_FLOAT) -#define TH_MATH_NAME(fn) fn##f +void THTensor_(addcdiv)(THTensor *r_, THTensor *t, real value, THTensor *src1, THTensor *src2) +{ + if(r_ != t) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } + int64_t r_Size = THTensor_(nElement)(r_); + int64_t src1Size = THTensor_(nElement)(src1); + int64_t src2Size = THTensor_(nElement)(src2); + int r_Contig = THTensor_(isContiguous)(r_); + int src1Contig = THTensor_(isContiguous)(src1); + int src2Contig = THTensor_(isContiguous)(src2); + int serial_path = 0; + if( (src1Size == src2Size) && (src1Size == r_Size) ){ +#if _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY3_OMP(r_Size, r_Contig, src1Contig, src2Contig, real, r_, real, src1, real, src2, *r__data += value * *src1_data / *src2_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } #else -#define TH_MATH_NAME(fn) fn + (void)r_Contig; + (void)src1Contig; + (void)src2Contig; + serial_path = 1; #endif - -LAB_IMPLEMENT_BASIC_FUNCTION(log,TH_MATH_NAME(log)) -LAB_IMPLEMENT_BASIC_FUNCTION(lgamma,TH_MATH_NAME(lgamma)) -LAB_IMPLEMENT_BASIC_FUNCTION(digamma,TH_MATH_NAME(TH_digamma)) -LAB_IMPLEMENT_BASIC_FUNCTION(trigamma,TH_MATH_NAME(TH_trigamma)) -LAB_IMPLEMENT_BASIC_FUNCTION(log10,TH_MATH_NAME(log10)) -LAB_IMPLEMENT_BASIC_FUNCTION(log1p,TH_MATH_NAME(log1p)) -LAB_IMPLEMENT_BASIC_FUNCTION(log2,TH_MATH_NAME(log2)) -LAB_IMPLEMENT_BASIC_FUNCTION(erf,TH_MATH_NAME(erf)) -LAB_IMPLEMENT_BASIC_FUNCTION(erfc,TH_MATH_NAME(erfc)) -LAB_IMPLEMENT_BASIC_FUNCTION(erfinv,TH_erfinv) -LAB_IMPLEMENT_BASIC_FUNCTION(ceil,TH_MATH_NAME(ceil)) -LAB_IMPLEMENT_BASIC_FUNCTION(floor,TH_MATH_NAME(floor)) -LAB_IMPLEMENT_BASIC_FUNCTION(round,TH_MATH_NAME(round)) -LAB_IMPLEMENT_BASIC_FUNCTION(abs,TH_MATH_NAME(fabs)) -LAB_IMPLEMENT_BASIC_FUNCTION(trunc,TH_MATH_NAME(trunc)) -LAB_IMPLEMENT_BASIC_FUNCTION(frac,TH_MATH_NAME(TH_frac)) -LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / ) - -LAB_IMPLEMENT_BASIC_FUNCTION(exp,TH_MATH_NAME(exp),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(expm1,TH_MATH_NAME(expm1),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(cos,TH_MATH_NAME(cos),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(acos,TH_MATH_NAME(acos),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(cosh,TH_MATH_NAME(cosh),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(sin,TH_MATH_NAME(sin),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(asin,TH_MATH_NAME(asin),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(sinh,TH_MATH_NAME(sinh),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(tan,TH_MATH_NAME(tan),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(atan,TH_MATH_NAME(atan),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(tanh,TH_MATH_NAME(tanh),HYPER_TH_OMP_OVERHEAD_THRESHOLD) 
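
The two- and three-argument LAB_IMPLEMENT_BASIC_FUNCTION invocations in this hunk are routed through the EXPAND/GET_4TH_ARG chooser defined earlier, which falls back to UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD when no threshold is passed and discards the threshold entirely in the non-OpenMP build. A minimal standalone sketch of that optional-macro-argument dispatch, with purely illustrative names (IMPL, my_log, my_exp are not part of TH):

```cpp
#include <cmath>
#include <cstdio>

#define EXPAND(...) __VA_ARGS__
#define GET_4TH_ARG(A0, A1, A2, A3, ...) A3

// Three-argument form: the caller supplies an explicit threshold.
#define IMPL_3(NAME, CFUNC, THRESHOLD)                                        \
  void NAME(double x) {                                                       \
    std::printf(#NAME "(%g) = %g  [threshold %d]\n", x, CFUNC(x), THRESHOLD); \
  }
// Two-argument form: falls back to a default threshold.
#define IMPL_2(NAME, CFUNC) IMPL_3(NAME, CFUNC, 100000)

// Counting the arguments picks IMPL_3 or IMPL_2, just as
// LAB_IMPLEMENT_BASIC_FUNCTION_CHOOSE does above.
#define IMPL_CHOOSE(...) EXPAND(GET_4TH_ARG(__VA_ARGS__, IMPL_3, IMPL_2, ))
#define IMPL(...) EXPAND(IMPL_CHOOSE(__VA_ARGS__)(__VA_ARGS__))

IMPL(my_log, std::log)        // two args: picks up the default threshold
IMPL(my_exp, std::exp, 2048)  // three args: explicit threshold

int main() {
  my_log(2.0);
  my_exp(1.0);
}
```
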
-LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,TH_MATH_NAME(sqrt),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt),HYPER_TH_OMP_OVERHEAD_THRESHOLD) - -LAB_IMPLEMENT_VECTORIZED_FUNCTION(sigmoid,TH_MATH_NAME(TH_sigmoid),HYPER_TH_OMP_OVERHEAD_THRESHOLD) - -void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty) -{ - THTensor_(resizeAs)(r_, tx); - TH_TENSOR_APPLY3(real, r_, real, tx, real, ty, *r__data = TH_MATH_NAME(atan2)(*tx_data,*ty_data);); -} - -void THTensor_(polygamma)(THTensor *r_, int64_t n, THTensor *t) { - switch (n) { - case 0: THTensor_(digamma)(r_, t); return; - case 1: THTensor_(trigamma)(r_, t); return; - default: THError("polygamma(n,x) is not implemented for n>=2"); + } else { + serial_path = 1; } -} - -void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight) -{ - THArgCheck(THTensor_(nElement)(a) == THTensor_(nElement)(b), 2, "sizes do not match"); - THTensor_(resizeAs)(r_, a); - TH_TENSOR_APPLY3(real, r_, real, a, real, b, *r__data = TH_MATH_NAME(TH_lerp)(*a_data, *b_data, weight);); -} - -void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim) -{ - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "invalid dimension %d", - dimension + TH_INDEX_BASE); - - THTensor_(sum)(r_, t, dimension, keepdim); - THTensor_(div)(r_, r_, t->size(dimension)); -} - -void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim) -{ - THLongStorage *dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d", - dimension + TH_INDEX_BASE); - - THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); - - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - // Uses Welford's algorithm for numeric stability - accreal mean = 0; - accreal M2 = 0; - - int64_t i; - for (i = 0; i < t_size; i++) - { - real z = t_data[i*t_stride]; - real delta = z - mean; - mean += delta / (i + 1); - real delta2 = z - mean; - M2 += delta * delta2; - } - - if (biased && t_size >= 2) - { - *r__data = TH_MATH_NAME(sqrt)(M2 / t_size); - } else if (!biased && t_size >= 2) { - *r__data = TH_MATH_NAME(sqrt)(M2 / (t_size - 1)); - } else if (biased && t_size == 1) { - *r__data = 0; - } else { - *r__data = NAN; - }); - - if (!keepdim) { - THTensor_(squeeze1d)(r_, r_, dimension); + if (serial_path) { + TH_TENSOR_APPLY3(real, r_, real, src1, real, src2, *r__data += value * *src1_data / *src2_data;); } } -void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim) +void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *mat, THTensor *vec) { - THLongStorage *dim; - - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d", - dimension + TH_INDEX_BASE); - - THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); - - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - // Uses Welford's algorithm for numeric stability - accreal mean = 0; - accreal M2 = 0; - - int64_t i; - for (i = 0; i < t_size; i++) - { - real z = t_data[i*t_stride]; - real delta = z - mean; - mean += delta / (i + 1); - real delta2 = z - mean; - M2 += delta * delta2; - } - - if 
(biased && t_size >= 2) - { - *r__data = M2 / t_size; - } else if (!biased && t_size >= 2) { - *r__data = M2 / (t_size - 1); - } else if (biased && t_size == 1) { - *r__data = 0; - } else { - *r__data = NAN; - }); + if( (mat->dim() != 2) || (vec->dim() != 1) ) + THError("matrix and vector expected, got %dD, %dD", + mat->dim(), vec->dim()); - if (!keepdim) { - THTensor_(squeeze1d)(r_, r_, dimension); + if( mat->size(1) != vec->size(0) ) { + THDescBuff bm = THTensor_(sizeDesc)(mat); + THDescBuff bv = THTensor_(sizeDesc)(vec); + THError("size mismatch, %s, %s", bm.str, bv.str); } -} -void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim) -{ - THLongStorage *dim; + if(t->dim() != 1) + THError("vector expected, got t: %dD", t->dim()); - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d", - dimension + TH_INDEX_BASE); + if(t->size(0) != mat->size(0)) { + THDescBuff bt = THTensor_(sizeDesc)(t); + THDescBuff bm = THTensor_(sizeDesc)(mat); + THError("size mismatch, t: %s, mat: %s", bt.str, bm.str); + } - THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); + if(r_ != t) + { + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); + } - #define DIM_REDUCE(reduce, transform) \ - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, \ - accreal sum = 0; \ - int64_t i; \ - for(i = 0; i < t_size; i++) { \ - (reduce); \ - } \ - (transform);) \ + // n == 1 || lda >= max(1, m) + #define LDA_COND(M, N, LDA) ((N) == 1 || (LDA) >= THMax(1, (M))) - if(value == 0) { - DIM_REDUCE(sum += t_data[i*t_stride] != 0.0, - *r__data = sum); - } else if (value == 1) { - DIM_REDUCE(sum += TH_MATH_NAME(fabs)(t_data[i*t_stride]), - *r__data = sum); - } else if (value == 2) { - DIM_REDUCE(sum += t_data[i*t_stride] * t_data[i*t_stride], - *r__data = TH_MATH_NAME(sqrt)(sum)); - } else if (value == 3) { - DIM_REDUCE(sum += TH_MATH_NAME(fabs)(t_data[i*t_stride] * t_data[i*t_stride] * t_data[i*t_stride]), - *r__data = TH_MATH_NAME(pow)(sum, 1.0/3)); - } else if (value == INFINITY) { - DIM_REDUCE(sum = THMax(sum, TH_MATH_NAME(fabs)(t_data[i*t_stride])), - *r__data = sum); - } else { - DIM_REDUCE(sum += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(t_data[i*t_stride]), value), - *r__data = TH_MATH_NAME(pow)(sum, 1.0/value)); + if(mat->stride(0) == 1 && LDA_COND(mat->size(0), mat->size(1), mat->stride(1))) + { + THBlas_(gemv)('n', mat->size(0), mat->size(1), + alpha, THTensor_(data)(mat), mat->stride(1), + THTensor_(data)(vec), vec->stride(0), + beta, THTensor_(data)(r_), r_->stride(0)); } - - if (!keepdim) { - THTensor_(squeeze1d)(r_, r_, dimension); + else if(mat->stride(1) == 1 && LDA_COND(mat->size(1), mat->size(0), mat->stride(0))) + { + THBlas_(gemv)('t', mat->size(1), mat->size(0), + alpha, THTensor_(data)(mat), mat->stride(0), + THTensor_(data)(vec), vec->stride(0), + beta, THTensor_(data)(r_), r_->stride(0)); } - #undef DIM_REDUCE -} + else + { + THTensor *cmat = THTensor_(newContiguous)(mat); -accreal THTensor_(normall)(THTensor *tensor, real value) -{ - accreal sum = 0; - if(value == 0) { - TH_TENSOR_APPLY(real, tensor, sum += *tensor_data != 0.0;); - return sum; - } else if(value == 1) { - TH_TENSOR_APPLY(real, tensor, sum += TH_MATH_NAME(fabs)(*tensor_data);); - return sum; - } else if(value == 2) { - TH_TENSOR_APPLY(real, tensor, accreal z = *tensor_data; sum += z*z;); - return sqrt(sum); 
- } else if(value == 3) { - TH_TENSOR_APPLY(real, tensor, accreal z = *tensor_data; sum += std::abs(z*z*z);); - return TH_MATH_NAME(pow)(sum, 1.0/3); - } else if(value == INFINITY) { - TH_TENSOR_APPLY(real, tensor, sum = THMax(sum, TH_MATH_NAME(fabs)(*tensor_data));); - return sum; - } else { - TH_TENSOR_APPLY(real, tensor, sum += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*tensor_data), value);); - return TH_MATH_NAME(pow)(sum, 1.0/value); + THBlas_(gemv)('t', mat->size(1), mat->size(0), + alpha, THTensor_(data)(cmat), cmat->stride(0), + THTensor_(data)(vec), vec->stride(0), + beta, THTensor_(data)(r_), r_->stride(0)); + + THTensor_(free)(cmat); } + + #undef LDA_COND } -void THTensor_(renorm)(THTensor *res, THTensor *src, real value, int dimension, real maxnorm) +void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain) { - THTensor *rowR, *rowS; - - THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(src), 3, "invalid dimension %d", - dimension + TH_INDEX_BASE); - THArgCheck(value > 0, 2, "non-positive-norm not supported"); - THArgCheck(THTensor_(nDimension)(src) > 1, 1, "need at least 2 dimensions, got %d dimensions", - THTensor_(nDimension)(src)); - - rowR = THTensor_(new)(); - rowS = THTensor_(new)(); + int64_t N1 = m1->size(0); + int64_t N2 = m2->size(0); + int64_t dim; + real *m1_p; + real *m2_p; + real *r_p; + int64_t i; - THTensor_(resizeAs)(res, src); + THTensor_(resize2d)(r_, N1, N2); - for (int64_t i = 0; i < src->size(dimension); i++) - { - real norm = 0; - real new_norm; + m1 = THTensor_(newContiguous)(m1); + m2 = THTensor_(newContiguous)(m2); - THTensor_(select)(rowS, src, dimension, i); - THTensor_(select)(rowR, res, dimension, i); - if (value == 1) { - TH_TENSOR_APPLY(real, rowS, norm += fabs(*rowS_data);); - } else if (value == 2) { - TH_TENSOR_APPLY(real, rowS, accreal z = *rowS_data; norm += z*z;); - } else if (value == INFINITY) { - TH_TENSOR_APPLY(real, rowS, norm = THMax(norm, TH_MATH_NAME(fabs)(*rowS_data));); - } else { - TH_TENSOR_APPLY(real, rowS, norm += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*rowS_data), value);); - } + THTensor_(resize2d)(m1, N1, THTensor_(nElement)(m1) / N1); + THTensor_(resize2d)(m2, N2, THTensor_(nElement)(m2) / N2); - if (value != INFINITY) { - norm = pow(norm, 1/value); - } + dim = m1->size(1); + THArgCheck(m1->size(1) == m2->size(1), 3, "m1 and m2 must have the same inner vector dim"); - if (norm > maxnorm) - { - new_norm = maxnorm / (norm + 1e-7); + m1_p = THTensor_(data)(m1); + m2_p = THTensor_(data)(m2); + r_p = THTensor_(data)(r_); - TH_TENSOR_APPLY2( - real, rowR, real, rowS, - *rowR_data = (*rowS_data) * new_norm; - ) +#pragma omp parallel for private(i) + for (i=0; i(0, THTensor_(nElement)(tensor) - (biased ? 0 : 1)); - return sum; -} + if( (m1->dim() != 2) || (m2->dim() != 2)) + THError("matrices expected, got %dD, %dD tensors", m1->dim(), m2->dim()); -accreal THTensor_(stdall)(THTensor *tensor, int biased) -{ - return sqrt(THTensor_(varall)(tensor, biased)); -} + if(m1->size(1) != m2->size(0)) { + THDescBuff bm1 = THTensor_(sizeDesc)(m1); + THDescBuff bm2 = THTensor_(sizeDesc)(m2); + THError("size mismatch, m1: %s, m2: %s", bm1.str, bm2.str); + } -void THTensor_(linspace)(THTensor *r_, real a, real b, int64_t n) -{ - real i = 0; + if( t->dim() != 2 ) + THError("matrix expected, got %dD tensor for t", t->dim()); - // NumPy allows you to pass different points even if n <= 1 -- should we? 
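
The dimension-wise std/var kernels removed above note that they accumulate mean and M2 with Welford's update for numeric stability, while the whole-tensor varall/stdall divide an accumulated sum of squared deviations by n or n - 1 (the (biased ? 0 : 1) adjustment visible above). A compact sketch of both routes, in plain C++ with illustrative names and the small-n corner cases omitted:

```cpp
#include <cmath>
#include <cstdio>

// Welford's running update, as used in the std/var kernels above:
//   delta  = x_k - mean_{k-1}
//   mean_k = mean_{k-1} + delta / k
//   M2_k   = M2_{k-1} + delta * (x_k - mean_k)
double welford_var(const double* x, long n, bool biased) {
  double mean = 0.0, M2 = 0.0;
  for (long i = 0; i < n; ++i) {
    double delta = x[i] - mean;
    mean += delta / (double)(i + 1);
    M2 += delta * (x[i] - mean);
  }
  return M2 / (double)(biased ? n : n - 1);
}

// Two-pass variance: mean first, then squared deviations.
double twopass_var(const double* x, long n, bool biased) {
  double mean = 0.0, sum = 0.0;
  for (long i = 0; i < n; ++i) mean += x[i];
  mean /= (double)n;
  for (long i = 0; i < n; ++i) sum += (x[i] - mean) * (x[i] - mean);
  return sum / (double)(biased ? n : n - 1);
}

int main() {
  const double x[] = {2, 4, 4, 4, 5, 5, 7, 9};
  // both print 4 (biased variance of this classic example)
  std::printf("%g %g\n", welford_var(x, 8, true), twopass_var(x, 8, true));
}
```
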
- THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + if( (t->size(0) != m1->size(0)) || (t->size(1) != m2->size(1)) ) { + THDescBuff bt = THTensor_(sizeDesc)(t); + THDescBuff bm1 = THTensor_(sizeDesc)(m1); + THDescBuff bm2 = THTensor_(sizeDesc)(m2); + THError("size mismatch, t: %s, m1: %s, m2: %s", bt.str, bm1.str, bm2.str); + } - if (THTensor_(nElement)(r_) != n) { - THTensor_(resize1d)(r_, n); + if(t != r_) + { + THTensor_(resizeAs)(r_, t); + if (beta != 0.0) { + THTensor_(copy)(r_, t); + } } - if (n == 0) { - } else if (n == 1) { - THTensor_(set1d)(r_, 0, a); - } else { - TH_TENSOR_APPLY(real, r_, - *r__data = a + (b-a)/((real)(n-1))*i; - i++; - ); + // n == 1 || ldc >= max(1, m) + #define LDC_COND(M, N, LDC) ((N) == 1 || (LDC) >= THMax(1, M)) + + /* r_ */ + if(r_->stride(0) == 1 && + LDC_COND(r_->size(0), r_->size(1), r_->stride(1))) + { + transpose_r = 'n'; + r__ = r_; + } + else if(r_->stride(1) == 1 && + LDC_COND(r_->size(1), r_->size(0), r_->stride(0))) + { + THTensor *swap = m2; + m2 = m1; + m1 = swap; + transpose_r = 't'; + r__ = r_; + } + else + { + transpose_r = 'n'; + // make r__ FORTRAN contiguous + THTensor *transp_r_ = THTensor_(newTranspose)(r_, 0, 1); + r__ = THTensor_(newClone)(transp_r_); + THTensor_(free)(transp_r_); + THTensor_(transpose)(r__, NULL, 0, 1); } -} -void THTensor_(logspace)(THTensor *r_, real a, real b, int64_t n) -{ - real i = 0; + #undef LDC_COND - // NumPy allows you to pass different points even if n <= 1 -- should we? - THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + int64_t m = r__->size((transpose_r == 'n' ? 0 : 1)); + int64_t n = r__->size((transpose_r == 'n' ? 1 : 0)); + int64_t k = m1->size((transpose_r == 'n' ? 1 : 0)); + int64_t ldr__ = r__->stride((transpose_r == 'n' ? 1 : 0)); - if (THTensor_(nElement)(r_) != n) { - THTensor_(resize1d)(r_, n); + /* m1 */ + /* Need ldm1_ >= max(1, (transpose_m1 == 'n' ? m : k)) */ + if(m1->stride((transpose_r == 'n' ? 0 : 1)) == 1 && + m1->stride((transpose_r == 'n' ? 1 : 0)) >= THMax(1, m)) + { + transpose_m1 = 'n'; + m1_ = m1; } - - if (n == 0) { - } else if (n == 1) { - THTensor_(set1d)(r_, 0, TH_MATH_NAME(pow)(10.0, a)); - } else { - TH_TENSOR_APPLY(real, r_, - *r__data = TH_MATH_NAME(pow)(10.0, a + i*(b-a)/((real)(n-1))); - i++; - ); + else if(m1->stride((transpose_r == 'n' ? 1 : 0)) == 1 && + m1->stride((transpose_r == 'n' ? 0 : 1)) >= THMax(1, k)) + { + transpose_m1 = 't'; + m1_ = m1; + } + else + { + transpose_m1 = (transpose_r == 'n' ? 't' : 'n'); + m1_ = THTensor_(newContiguous)(m1); + free_m1 = 1; } -} - -void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, real minvalue, real maxvalue) -{ - real minval; - real maxval; - real *h_data; - THTensor_(resize1d)(hist, nbins); - THTensor_(zero)(hist); - minval = minvalue; - maxval = maxvalue; - if (minval == maxval) + /* m2 */ + /* Need ldm2_ >= max(1, (transpose_m2 == 'n' ? k : n)) */ + if(m2->stride((transpose_r == 'n' ? 0 : 1)) == 1 && + m2->stride((transpose_r == 'n' ? 1 : 0)) >= THMax(1, k)) + { + transpose_m2 = 'n'; + m2_ = m2; + } + else if(m2->stride((transpose_r == 'n' ? 1 : 0)) == 1 && + m2->stride((transpose_r == 'n' ? 0 : 1)) >= THMax(1, n)) { - minval = THTensor_(minall)(tensor); - maxval = THTensor_(maxall)(tensor); + transpose_m2 = 't'; + m2_ = m2; } - if (minval == maxval) + else { - minval = minval - 1; - maxval = maxval + 1; + transpose_m2 = (transpose_r == 'n' ? 
't' : 'n'); + m2_ = THTensor_(newContiguous)(m2); + free_m2 = 1; } - h_data = THTensor_(data)(hist); + int64_t ldm1_ = (transpose_m1 == 'n' ? m1_->stride((transpose_r == 'n' ? 1 : 0)) : m1_->stride((transpose_r == 'n' ? 0 : 1))); + int64_t ldm2_ = (transpose_m2 == 'n' ? m2_->stride((transpose_r == 'n' ? 1 : 0)) : m2_->stride((transpose_r == 'n' ? 0 : 1))); - TH_TENSOR_APPLY(real, tensor, - if (*tensor_data >= minval && *tensor_data <= maxval) { - const int bin = (int)((*tensor_data-minval) / (maxval-minval) * nbins); - h_data[THMin(bin, nbins-1)] += 1; - } - ); +#pragma omp critical(blasgemm) + /* do the operation */ + THBlas_(gemm)(transpose_m1, + transpose_m2, + m, + n, + k, + alpha, + THTensor_(data)(m1_), + ldm1_, + THTensor_(data)(m2_), + ldm2_, + beta, + THTensor_(data)(r__), + ldr__); + + /* free intermediate variables */ + if(free_m1) + THTensor_(free)(m1_); + + if(free_m2) + THTensor_(free)(m2_); + + if(r__ != r_) + THTensor_(freeCopyTo)(r__, r_); } -void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, int64_t nbins, real minvalue, real maxvalue) +void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2) { - THArgCheck(THTensor_(_nDimension)(tensor) < 3, 2, "invalid dimension %d, the input must be a 2d tensor", THTensor_(_nDimension)(tensor)); - - int dimension = 1; - THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(tensor), 2, "invalid dimension %d", - dimension + TH_INDEX_BASE); + if( (vec1->dim() != 1) || (vec2->dim() != 1) ) + THError("vector and vector expected, got %dD, %dD tensors", + vec1->dim(), vec2->dim()); - real minval; - real maxval; + if(t->dim() != 2) + THError("expected matrix, got %dD tensor for t", t->dim()); - THTensor_(resize2d)(hist, tensor->size(0), nbins); - THTensor_(zero)(hist); + if( (t->size(0) != vec1->size(0)) || (t->size(1) != vec2->size(0)) ) { + THDescBuff bt = THTensor_(sizeDesc)(t); + THDescBuff bv1 = THTensor_(sizeDesc)(vec1); + THDescBuff bv2 = THTensor_(sizeDesc)(vec2); + THError("size mismatch, t: %s, vec1: %s, vec2: %s", bt.str, bv1.str, bv2.str); + } - minval = minvalue; - maxval = maxvalue; - if (minval == maxval) + if(r_ != t) { - minval = THTensor_(minall)(tensor); - maxval = THTensor_(maxall)(tensor); + THTensor_(resizeAs)(r_, t); + THTensor_(copy)(r_, t); } - if (minval == maxval) - { - minval = minval - 1; - maxval = maxval + 1; + + if(beta == 0) { + THTensor_(zero)(r_); } + else if(beta != 1) + THTensor_(mul)(r_, r_, beta); - TH_TENSOR_DIM_APPLY2(real, tensor, real, hist, dimension, int64_t i; - for(i = 0; i < tensor_size; i++) - { - if(tensor_data[i*tensor_stride] >= minval && tensor_data[i*tensor_stride] <= maxval) { - const int bin = (int)((tensor_data[i*tensor_stride]-minval) / (maxval-minval) * nbins); - hist_data[THMin(bin, nbins-1)] += 1; - } - } - ); -} + // n == 1 || lda >= max(1, m) + #define LDA_COND(M, N, LDA) ((N) == 1 || (LDA) >= THMax(1, (M))) -// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. -// Assumes x is close to zero and uses a Taylor expansion. 
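
Written out, the small-x series evaluated by beta_grad_alpha_small just below (read directly off its loop, with psi denoting the digamma function) is

\[
x\,(1-x)^{-\beta}\sum_{i=0}^{10}\frac{c_i}{\alpha+i}\left(\psi(\alpha)-\psi(\alpha+\beta)-\log x+\frac{1}{\alpha+i}\right),
\qquad c_0 = 1,\quad c_i = c_{i-1}\,\frac{(i-\beta)\,x}{i},
\]

so each term couples the common factor psi(alpha) - psi(alpha+beta) - log x with an extra 1/(alpha+i), and NaN results are clamped to 0, as in the code.
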
-static inline real THTensor_(beta_grad_alpha_small)(real x, real alpha, real beta) { - const real factor = TH_MATH_NAME(TH_digamma)(alpha) - TH_MATH_NAME(TH_digamma)(alpha + beta) - TH_MATH_NAME(log)(x); - real numer = 1; - real series = numer / alpha * (factor + 1 / alpha); - for (int i = 1; i <= 10; ++i) { - numer *= (i - beta) * x / i; - const real denom = alpha + i; - series += numer / denom * (factor + 1 / denom); + if(r_->stride(0) == 1 && LDA_COND(vec1->size(0), vec2->size(0), r_->stride(1))) + { + THBlas_(ger)(vec1->size(0), vec2->size(0), + alpha, THTensor_(data)(vec1), vec1->stride(0), + THTensor_(data)(vec2), vec2->stride(0), + THTensor_(data)(r_), r_->stride(1)); } - const real result = x * TH_MATH_NAME(pow)(1 - x, -beta) * series; - return th_isnan(result) ? 0.0 : result; -} - -// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta. -// Assumes x is close to zero and uses a Taylor expansion. -static inline real THTensor_(beta_grad_beta_small)(real x, real alpha, real beta) { - const real factor = TH_MATH_NAME(TH_digamma)(alpha+beta) - TH_MATH_NAME(TH_digamma)(beta); - real numer = 1; - real betas = 1; - real dbetas = 0; - real series = factor / alpha; - for (int i = 1; i <= 8; ++i) { - numer *= -x / i; - dbetas = dbetas * (beta - i) + betas; - betas = betas * (beta - i); - series += numer / (alpha + i) * (dbetas + factor * betas); + else if(r_->stride(1) == 1 && LDA_COND(vec2->size(0), vec1->size(0), r_->stride(0))) + { + THBlas_(ger)(vec2->size(0), vec1->size(0), + alpha, THTensor_(data)(vec2), vec2->stride(0), + THTensor_(data)(vec1), vec1->stride(0), + THTensor_(data)(r_), r_->stride(0)); } - const real result = -TH_MATH_NAME(pow)(1 - x, 1 - beta) * series; - return th_isnan(result) ? 0.0 : result; -} + else + { + THTensor *cr = THTensor_(newClone)(r_); -// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. -// Assumes alpha and beta are both large and uses a Rice saddle point expansion. -// To ensure numerical stability, this computation is performed at higher precision. -static inline real THTensor_(beta_grad_alpha_mid)(double x, double alpha, double beta) { - const double total = alpha + beta; - const double mean = alpha / total; - const double std = sqrt(alpha * beta / (total + 1)) / total; - if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) { - // Avoid the singularity at x = mean. 
- const double poly = 47 * x * (beta*beta)*(beta*beta) + alpha * ( - (43 + 20 * (16 + 27 * beta) * x) * (beta*beta)*beta + alpha * ( - 3 * (59 + 180 * beta - 90 * x) * (beta*beta) + alpha * ( - (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * ( - 8 * (1 - x) * (135 * beta - 11))))); - const double prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total); - const double prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total); - return prefactor_num / (1 - x) * poly / prefactor_den; + THBlas_(ger)(vec2->size(0), vec1->size(0), + alpha, THTensor_(data)(vec2), vec2->stride(0), + THTensor_(data)(vec1), vec1->stride(0), + THTensor_(data)(cr), cr->stride(0)); + + THTensor_(freeCopyTo)(cr, r_); } - const double prefactor = -x / sqrt(2 * alpha * beta / total); - const double stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha*alpha)) - * (1 + 1 / (12 * beta) + 1 / (288 * beta*beta)) - / (1 + 1 / (12 * total) + 1 / (288 * total*total)); - const double term1_num = 2 * (alpha*alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta*beta); - const double axbx = alpha * (x-1) + beta * x; - const double term1_den = sqrt(2 * alpha / beta) * pow(total, 1.5f) * axbx*axbx; - const double term1 = term1_num / term1_den; - const double term2 = 0.5f * log(alpha / (total * x)); - const double term3_num = sqrt(8 * alpha * beta / total); - const double term3_den = beta * x + alpha * (x - 1); - const double term3 = term3_num / term3_den; - const double term4_base = beta * log(beta / (total * (1 - x))) + - alpha * log(alpha / (total * x)); - const double term4 = pow(term4_base, -1.5f); - const double term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4)); - return stirling * prefactor * term1234; -} -// Computes a scaled reparameterized gradient -// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x) -// for random number x drawn from a Beta distribution Beta(alpha,beta). -// This function inputs total=alpha+beta to make it easy to implement -// Dirichlet reparameterized gradients in terms of Betas. -static inline real THTensor_(dirichlet_grad_one)(real x, real alpha, real total) { - const real beta = total - alpha; - const real boundary = total * x * (1 - x); + #undef LDA_COND +} - // Use an asymptotic approximation for x close to 0. - if (x <= 0.5f && boundary < 2.5f) { - return THTensor_(beta_grad_alpha_small)(x, alpha, beta); - } +void THTensor_(addbmm)(THTensor *result, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2) +{ + int64_t batch; - // Use an asymptotic approximation for x close to 1. - if (x >= 0.5f && boundary < 0.75f) { - return -THTensor_(beta_grad_beta_small)(1 - x, beta, alpha); - } + THArgCheck(THTensor_(nDimension)(batch1) == 3, 1, "expected 3D tensor"); + THArgCheck(THTensor_(nDimension)(batch2) == 3, 2, "expected 3D tensor"); + THArgCheck(THTensor_(size)(batch1, 0) == THTensor_(size)(batch2, 0), 2, + "equal number of batches expected, got %d, %d", + THTensor_(size)(batch1, 0), THTensor_(size)(batch2, 0)); + THArgCheck(THTensor_(size)(batch1, 2) == THTensor_(size)(batch2, 1), 2, + "wrong matrix size, batch1: %dx%d, batch2: %dx%d", + THTensor_(size)(batch1, 1), THTensor_(size)(batch1,2), + THTensor_(size)(batch2, 1), THTensor_(size)(batch2,2)); - // Use an asymptotic approximation when alpha and (total - alpha) are both large. 
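
Per the comment above dirichlet_grad_one, the quantity all of these branches approximate is the scaled reparameterized gradient

\[
-\,\frac{\partial_{\alpha} F(x;\alpha,\beta)}{f(x;\alpha,\beta)\,(1-x)},
\]

where F and f are the Beta CDF and PDF. The dispatch selects the small-x series when x <= 0.5 and total*x*(1-x) < 2.5, the mirrored series at 1 - x when x >= 0.5 and that boundary is below 0.75, the large-parameter expansion when alpha > 6 and beta > 6, and otherwise the rational correction tabulated below.
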
- if (alpha > 6 && beta > 6) { - return THTensor_(beta_grad_alpha_mid)(x, alpha, beta); - } + int64_t dim1 = THTensor_(size)(batch1, 1); + int64_t dim2 = THTensor_(size)(batch2, 2); + THArgCheck(THTensor_(size)(t, 0) == dim1, 1, "output tensor of incorrect size"); + THArgCheck(THTensor_(size)(t, 1) == dim2, 1, "output tensor of incorrect size"); - // Use a rational correction to an analytic approximation. - static const real c[2][3][3][4] = { - {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863}, - {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033}, - {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}}, - {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814}, - {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057}, - {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}}, - {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565}, - {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181}, - {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}}, - {{{1, -0.02924021934, -0.04438342661, 0.007285809825}, - {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521}, - {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}}, - {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273}, - {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956}, - {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}}, - {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05}, - {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05}, - {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}}, - }; - const real u = TH_MATH_NAME(log)(x); - const real a = TH_MATH_NAME(log)(alpha) - u; - const real b = TH_MATH_NAME(log)(total) - a; - const real pow_u[3] = {1, u, u * u}; - const real pow_a[3] = {1, a, a * a}; - real p = 0.0; - real q = 0.0; - for (int i = 0; i < 3; ++i) { - for (int j = 0; j < 3; ++j) { - const real ua = pow_u[i] * pow_a[j]; - p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3]))); - q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3]))); + if (t != result) { + THTensor_(resizeAs)(result, t); + if (beta != 0.0) { + THTensor_(copy)(result, t); } } - const real approx = x * (TH_MATH_NAME(TH_digamma)(total) - TH_MATH_NAME(TH_digamma)(alpha)) / beta; - return p / q * approx; -} -void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alpha, THTensor *total) -{ - x = THTensor_(newContiguous)(x); - alpha = THTensor_(newContiguous)(alpha); - total = THTensor_(newContiguous)(total); - TH_CHECK_SAME_SIZE(alpha, x); - TH_CHECK_SAME_SIZE(total, x); - THTensor_(resizeAs)(self, x); - THTensor* grad = THTensor_(newContiguous)(self); + THTensor *matrix1 = THTensor_(new)(); + THTensor *matrix2 = THTensor_(new)(); - real*const grad_data = THTensor_(data)(grad); - real*const x_data = THTensor_(data)(x); - real*const alpha_data = THTensor_(data)(alpha); - real*const total_data = THTensor_(data)(total); - const int64_t numel = THTensor_(nElement)(x); - int64_t i; - #pragma omp parallel for if(numel > TH_OMP_OVERHEAD_THRESHOLD) private(i) - for(i = 0; i < numel; ++i) { - grad_data[i] = THTensor_(dirichlet_grad_one)(x_data[i], alpha_data[i], total_data[i]); + for (batch = 0; batch < THTensor_(size)(batch1, 0); ++batch) { + THTensor_(select)(matrix1, batch1, 0, batch); + THTensor_(select)(matrix2, batch2, 0, batch); + + THTensor_(addmm)(result, beta, result, alpha, 
matrix1, matrix2); + beta = 1; // accumulate output once } - THTensor_(freeCopyTo)(grad, self); + THTensor_(free)(matrix1); + THTensor_(free)(matrix2); } - -#undef TH_MATH_NAME -#endif /* floating point only part */ -#undef IS_NONZERO -#endif +#endif /* TH_GENERIC_FILE */ diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp new file mode 100644 index 0000000000000..69ca98f8787bd --- /dev/null +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -0,0 +1,2408 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensorMoreMath.cpp" +#else + +#include + +void THTensor_(baddbmm)(THTensor *result, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2) +{ + int64_t batch; + + THArgCheck(THTensor_(nDimension)(batch1) == 3, 1, "expected 3D tensor, got %dD", THTensor_(nDimension)(batch1)); + THArgCheck(THTensor_(nDimension)(batch2) == 3, 2, "expected 3D tensor, got %dD", THTensor_(nDimension)(batch2)); + THArgCheck(THTensor_(size)(batch1, 0) == THTensor_(size)(batch2, 0), 2, + "equal number of batches expected, got %d, %d", + THTensor_(size)(batch1, 0), THTensor_(size)(batch2, 0)); + THArgCheck(THTensor_(size)(batch1, 2) == THTensor_(size)(batch2, 1), 2, + "wrong matrix size, batch1: %dx%d, batch2: %dx%d", + THTensor_(size)(batch1, 1), THTensor_(size)(batch1, 2), + THTensor_(size)(batch2, 1), THTensor_(size)(batch2, 2)); + + int64_t bs = THTensor_(size)(batch1, 0); + int64_t dim1 = THTensor_(size)(batch1, 1); + int64_t dim2 = THTensor_(size)(batch2, 2); + THArgCheck(THTensor_(size)(t, 0) == bs, 1, "output tensor of incorrect size"); + THArgCheck(THTensor_(size)(t, 1) == dim1, 1, "output tensor of incorrect size"); + THArgCheck(THTensor_(size)(t, 2) == dim2, 1, "output tensor of incorrect size"); + + if (t != result) { + THTensor_(resizeAs)(result, t); + if (beta != 0.0) { + THTensor_(copy)(result, t); + } + } + + THTensor *matrix1 = THTensor_(new)(); + THTensor *matrix2 = THTensor_(new)(); + THTensor *result_matrix = THTensor_(new)(); + + for (batch = 0; batch < THTensor_(size)(batch1, 0); ++batch) { + THTensor_(select)(matrix1, batch1, 0, batch); + THTensor_(select)(matrix2, batch2, 0, batch); + THTensor_(select)(result_matrix, result, 0, batch); + + THTensor_(addmm)(result_matrix, beta, result_matrix, alpha, matrix1, matrix2); + } + + THTensor_(free)(matrix1); + THTensor_(free)(matrix2); + THTensor_(free)(result_matrix); +} + +ptrdiff_t THTensor_(numel)(THTensor *t) +{ + return THTensor_(nElement)(t); +} + + +// Helper function to be used in a reduction operation. +// Due to resize semantics of outputs, if the specified output tensor r_ has +// same size as the output of the reduction operation, then any noncontiguities +// in r_ should be preserved. +// The reduction operation, however, needs to act on r_ with an extra dimension +// (the reduced dimension), so this function "resizes" r_ and preserves its +// noncontiguities if necessary. 
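
As a concrete illustration of the helper defined below: summing a (2,3,4) input over dim 1 with keepdim=false produces a (2,4) result, so an output tensor that already has that shape is unsqueezed back to (2,1,4) before the reduction kernel runs and squeezed again afterwards. A plain-C++ sketch of just the shape bookkeeping (no TH calls; names are illustrative):

```cpp
#include <cstdio>
#include <vector>

// Mirrors the check in THTensor_(preserveReduceDimSemantics) below: if the
// output already has the keepdim=false rank (input rank minus one), insert a
// size-1 dimension at the reduced position so the kernel can address it with
// the same rank as the input.
std::vector<long> preserve_reduce_dim(std::vector<long> out_sizes,
                                      int in_dims, int dim, bool keepdim) {
  if (!keepdim && (int)out_sizes.size() == in_dims - 1 && !out_sizes.empty())
    out_sizes.insert(out_sizes.begin() + dim, 1);  // unsqueeze1d equivalent
  return out_sizes;
}

int main() {
  // caller-provided output for a (2,3,4) reduction over dim 1, keepdim=false
  std::vector<long> out = {2, 4};
  out = preserve_reduce_dim(out, /*in_dims=*/3, /*dim=*/1, /*keepdim=*/false);
  for (long s : out) std::printf("%ld ", s);  // prints: 2 1 4
  std::printf("\n");
}
```
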
+void THTensor_(preserveReduceDimSemantics)( + THTensor *r_, int in_dims, int reduce_dimension, int keepdim) { + if (r_ && !keepdim && + THTensor_(_nDimension)(r_) == in_dims - 1 && + THTensor_(_nDimension)(r_) != 0) { + THTensor_(unsqueeze1d)(r_, r_, reduce_dimension); + } +} + +void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + int in_dims = THTensor_(_nDimension)(t); + THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); + THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(values_, dim, NULL); + THLongTensor_resize(indices_, dim, NULL); + THLongStorage_free(dim); + + // two implementations optimized for data locality + if (t->stride(dimension) == 1) { + real theMax; + real value; + int64_t theIndex; + int64_t i; + TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, + TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + theMax = t_data[0]; + theIndex = 0; + + for(i = 0; i < t_size; i++) + { + value = t_data[i*t_stride]; + /* This is not the same as value>theMax in the case of NaNs */ + if(!(value <= theMax)) + { + theIndex = i; + theMax = value; + th_isnan_break(value) + } + } + *indices__data = theIndex; + *values__data = theMax;); + } else { + if (THTensor_(_nDimension)(t) > 1) { + THTensor *t0 = THTensor_(newSelect)(t, dimension, 0); + THTensor_(copy)(values_, t0); + THTensor_(free)(t0); + } else { + THTensor_(fill)(values_, THTensor_(get1d)(t, 0)); + } + THLongTensor_zero(indices_); + + if(t->size(dimension) == 1) { + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } + return; + } + + THTensor *tempValues_ = THTensor_(newWithTensor)(values_); + // tempValues_.expand_as(t) + THTensor_setSizeAtDim(tempValues_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(tempValues_, dimension, 0); + + THLongTensor *tempIndices_ = THLongTensor_newWithTensor(indices_); + // tempIndices_.expand_as(t) + THTensor_setSizeAtDim(tempIndices_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(tempIndices_, dimension, 0); + + TH_TENSOR_APPLY3_D(real, t, real, tempValues_, int64_t, tempIndices_, dimension, + if(!(*t_data <= *tempValues__data) && !th_isnan(*tempValues__data)) { + *tempValues__data = *t_data; + *tempIndices__data = *tempIndices__dimOffset; + }); + + THTensor_(free)(tempValues_); + THLongTensor_free(tempIndices_); + } + + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } +} + +void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + int in_dims = THTensor_(_nDimension)(t); + THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); + THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(values_, dim, NULL); + THLongTensor_resize(indices_, dim, NULL); + THLongStorage_free(dim); + + // two implementations 
optimized for data locality + if (t->stride(dimension) == 1) { + real theMax; + real value; + int64_t theIndex; + int64_t i; + TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, + TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + theMax = t_data[0]; + theIndex = 0; + + for(i = 0; i < t_size; i++) + { + value = t_data[i*t_stride]; + /* This is not the same as value>theMax in the case of NaNs */ + if(!(value >= theMax)) + { + theIndex = i; + theMax = value; + th_isnan_break(value) + } + } + *indices__data = theIndex; + *values__data = theMax;); + } else { + if (THTensor_(_nDimension)(t) > 1) { + THTensor *t0 = THTensor_(newSelect)(t, dimension, 0); + THTensor_(copy)(values_, t0); + THTensor_(free)(t0); + } else { + THTensor_(fill)(values_, THTensor_(get1d)(t, 0)); + } + THLongTensor_zero(indices_); + + if(t->size(dimension) == 1) { + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } + return; + } + + THTensor *tempValues_ = THTensor_(newWithTensor)(values_); + // tempValues_.expand_as(t) + THTensor_setSizeAtDim(tempValues_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(tempValues_, dimension, 0); + + THLongTensor *tempIndices_ = THLongTensor_newWithTensor(indices_); + // tempIndices_.expand_as(t) + THTensor_setSizeAtDim(tempIndices_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(tempIndices_, dimension, 0); + + TH_TENSOR_APPLY3_D(real, t, real, tempValues_, int64_t, tempIndices_, dimension, + if(!(*t_data >= *tempValues__data) && !th_isnan(*tempValues__data)) { + *tempValues__data = *t_data; + *tempIndices__data = *tempIndices__dimOffset; + }); + + THTensor_(free)(tempValues_); + THLongTensor_free(tempIndices_); + } + + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } +} + +void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + int r_Contig = THTensor_(isContiguous)(r_); + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + if(r_Contig && (tp != rp)){ + ptrdiff_t iter = 0; + ptrdiff_t r_Size = THTensor_(nElement)(r_); + int r_Dim = r_->_dim(); + #pragma omp parallel for if ( r_Size > HYPER_TH_OMP_OVERHEAD_THRESHOLD) + for (iter = 0; iter < r_Size; iter++) { + int j; + int64_t quot; + int64_t rem = iter; + ptrdiff_t tBasicIndex = 0; + + for(j = 0; j < r_Dim; ++j) { + if(j != dimension){ + quot = rem/r_->stride(j); + rem = rem%r_->stride(j); + tBasicIndex += quot*t->stride(j); + } + } + real *t_data = tp+tBasicIndex; + real *r__data = rp+iter; + *r__data = 0; + for(j=0; j < t->size(dimension); ++j) { + *r__data += *(t_data + j*t->stride(dimension)); + } + } + } else { + serial_path = 1; + } + } +#else + serial_path = 1; +#endif + if (serial_path) { + // two implementations optimized for data locality + if (t->stride(dimension) == 1) { + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal sum = 0; + int64_t i; + for(i = 0; i < t_size; i++) + sum += 
t_data[i*t_stride]; + *r__data = (real)sum;); + } else { + THTensor_(zero)(r_); + THTensor *temp_ = THTensor_(newWithTensor)(r_); + // r_.expand_as(t) + THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(temp_, dimension, 0); + + TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data + *t_data;); + THTensor_(free)(temp_); + } + } + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } +} + +void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + int r_Contig = THTensor_(isContiguous)(r_); + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + if(r_Contig && (tp != rp)){ + ptrdiff_t iter = 0; + ptrdiff_t r_Size = THTensor_(nElement)(r_); + int r_Dim = r_->_dim(); + #pragma omp parallel for if ( r_Size > HYPER_TH_OMP_OVERHEAD_THRESHOLD) + for (iter = 0; iter < r_Size; iter++) { + int j; + int64_t quot; + int64_t rem = iter; + ptrdiff_t tBasicIndex = 0; + + for(j = 0; j < r_Dim; ++j) { + if(j != dimension){ + quot = rem/r_->stride(j); + rem = rem%r_->stride(j); + tBasicIndex += quot*t->stride(j); + } + } + real *t_data = tp+tBasicIndex; + real *r__data = rp+iter; + *r__data = 1; + for(j=0; j < t->size(dimension); ++j) { + *r__data *= *(t_data + j*t->stride(dimension)); + } + } + } else { + serial_path = 1; + } + } +#else + serial_path = 1; +#endif + + if(serial_path) { + // two implementations optimized for data locality + if (t->stride(dimension) == 1) { + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal prod = 1; + int64_t i; + for(i = 0; i < t_size; i++) + prod *= t_data[i*t_stride]; + *r__data = (real)prod;); + } else { + THTensor_(fill)(r_, 1); + THTensor *temp_ = THTensor_(newWithTensor)(r_); + // r_.expand_as(t) + THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(temp_, dimension, 0); + + TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data * *t_data;); + THTensor_(free)(temp_); + } + } + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } +} + +void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension) +{ + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + THTensor_(resizeAs)(r_, t); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal cumsum = 0; + int64_t i; + for(i = 0; i < t_size; i++) + { + cumsum += t_data[i*t_stride]; + r__data[i*r__stride] = (real)cumsum; + }); +} + +void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension) +{ + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + THTensor_(resizeAs)(r_, t); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal cumprod = 1; + int64_t i; + for(i = 0; i < t_size; i++) + { + cumprod *= t_data[i*t_stride]; + r__data[i*r__stride] = (real)cumprod; + }); +} + + +void THTensor_(sign)(THTensor *r_, THTensor *t) +{ + THTensor_(resizeAs)(r_, t); + +#if defined (TH_REAL_IS_BYTE) + 
TH_TENSOR_APPLY2(real, r_, real, t, + if (*t_data > 0) *r__data = 1; + else *r__data = 0;); +#else + TH_TENSOR_APPLY2(real, r_, real, t, + if (*t_data > 0) *r__data = 1; + else if (*t_data < 0) *r__data = -1; + else *r__data = 0;); +#endif +} + + +accreal THTensor_(trace)(THTensor *t) +{ + real *t_data = THTensor_(data)(t); + accreal sum = 0; + int64_t i = 0; + int64_t t_stride_0, t_stride_1, t_diag_size; + + THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix"); + + t_stride_0 = THTensor_(stride)(t, 0); + t_stride_1 = THTensor_(stride)(t, 1); + t_diag_size = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)); + while(i < t_diag_size) + { + sum += t_data[i*(t_stride_0+t_stride_1)]; + i++; + } + + return sum; +} + +void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension) +{ + int i; + + if(THTensor_(nDimension)(a) != THTensor_(nDimension)(b)) + THError("inconsistent tensor dimension %dD, %dD", + THTensor_(nDimension)(a), THTensor_(nDimension)(b)); + + for(i = 0; i < THTensor_(nDimension)(a); i++) + { + if(THTensor_(size)(a, i) != THTensor_(size)(b, i)) { + THDescBuff ba = THTensor_(sizeDesc)(a); + THDescBuff bb = THTensor_(sizeDesc)(b); + THError("inconsistent tensor sizes %s, %s", ba.str, bb.str); + } + } + + if(dimension < 0) + { + for(i = 0; i < THTensor_(nDimension)(a); i++) + { + if(THTensor_(size)(a, i) == 3) + { + dimension = i; + break; + } + } + if(dimension < 0) { + THDescBuff ba = THTensor_(sizeDesc)(a); + THError("no dimension of size 3 in a: %s", ba.str); + } + } + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(a), 3, "dimension %d out of range", + dimension + TH_INDEX_BASE); + THArgCheck(THTensor_(size)(a, dimension) == 3, 3, "dimension %d does not have size 3", + dimension + TH_INDEX_BASE); + + THTensor_(resizeAs)(r_, a); + + TH_TENSOR_DIM_APPLY3(real, a, real, b, real, r_, dimension, + TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + r__data[0*r__stride] = a_data[1*a_stride]*b_data[2*b_stride] - a_data[2*a_stride]*b_data[1*b_stride]; + r__data[1*r__stride] = a_data[2*a_stride]*b_data[0*b_stride] - a_data[0*a_stride]*b_data[2*b_stride]; + r__data[2*r__stride] = a_data[0*a_stride]*b_data[1*b_stride] - a_data[1*a_stride]*b_data[0*b_stride];); +} + +void THTensor_(cmax)(THTensor *r, THTensor *t, THTensor *src) { + THTensor_(resizeAs)(r, t); + TH_TENSOR_APPLY3(real, r, real, t, real, src, + *r_data = *t_data > *src_data ? *t_data : *src_data;); +} + +void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src) { + THTensor_(resizeAs)(r, t); + TH_TENSOR_APPLY3(real, r, real, t, real, src, + *r_data = *t_data < *src_data ? *t_data : *src_data;); +} + +void THTensor_(cmaxValue)(THTensor *r, THTensor *t, real value) { + THTensor_(resizeAs)(r, t); + TH_TENSOR_APPLY2(real, r, real, t, + *r_data = *t_data < value ? value : *t_data;); // this order propagates NaN +} + +void THTensor_(cminValue)(THTensor *r, THTensor *t, real value) { + THTensor_(resizeAs)(r, t); + TH_TENSOR_APPLY2(real, r, real, t, + *r_data = *t_data > value ? 
value : *t_data;); // this order propagates NaN +} + +void THTensor_(zerosLike)(THTensor *r_, THTensor *input) +{ + THTensor_(resizeAs)(r_, input); + THTensor_(zero)(r_); +} + +void THTensor_(onesLike)(THTensor *r_, THTensor *input) +{ + THTensor_(resizeAs)(r_, input); + THTensor_(fill)(r_, 1); +} + +void THTensor_(diag)(THTensor *r_, THTensor *t, int k) +{ +#ifndef USE_TH_SIZE_ZERO_DIM + AT_ASSERT(!t->is_empty()) +#endif + THArgCheck(THTensor_(nDimension)(t) == 1 || THTensor_(nDimension)(t) == 2, 1, "matrix or a vector expected"); + + if(THTensor_(nDimension)(t) == 1) + { + real *t_data = THTensor_(data)(t); + int64_t t_stride_0 = THTensor_(stride)(t, 0); + int64_t t_size = THTensor_(size)(t, 0); + int64_t sz = t_size + (k >= 0 ? k : -k); + real *r__data; + int64_t r__stride_0; + int64_t r__stride_1; + int64_t i; + + THTensor_(resize2d)(r_, sz, sz); + THTensor_(zero)(r_); + r__data = THTensor_(data)(r_); + r__stride_0 = THTensor_(stride)(r_, 0); + r__stride_1 = THTensor_(stride)(r_, 1); + r__data += (k >= 0 ? k*r__stride_1 : -k*r__stride_0); + + for(i = 0; i < t_size; i++) + r__data[i*(r__stride_0+r__stride_1)] = t_data[i*t_stride_0]; + } + else + { + real *t_data = THTensor_(data)(t); + int64_t t_stride_0 = THTensor_(stride)(t, 0); + int64_t t_stride_1 = THTensor_(stride)(t, 1); + int64_t sz; + real *r__data; + int64_t r__stride_0; + int64_t i; + + if(k >= 0) + sz = THMin(THTensor_(size)(t, 0), THTensor_(size)(t, 1)-k); + else + sz = THMin(THTensor_(size)(t, 0)+k, THTensor_(size)(t, 1)); + THTensor_(resize1d)(r_, sz); + r__data = THTensor_(data)(r_); + r__stride_0 = THTensor_(stride)(r_, 0); + + t_data += (k >= 0 ? k*t_stride_1 : -k*t_stride_0); + for(i = 0; i < sz; i++) + r__data[i*r__stride_0] = t_data[i*(t_stride_0+t_stride_1)]; + } +} + +void THTensor_(eye)(THTensor *r_, int64_t n, int64_t m) +{ + real *r__data; + int64_t i, sz; + + THArgCheck(n > 0, 1, "invalid argument"); + + if(m <= 0) + m = n; + + THTensor_(resize2d)(r_, n, m); + THTensor_(zero)(r_); + + i = 0; + r__data = THTensor_(data)(r_); + sz = THMin(THTensor_(size)(r_, 0), THTensor_(size)(r_, 1)); + for(i = 0; i < sz; i++) + r__data[i*(r_->stride(0)+r_->stride(1))] = 1; +} + + +void THTensor_(range)(THTensor *r_, accreal xmin, accreal xmax, accreal step) +{ + ptrdiff_t size; + real i = 0; + + THArgCheck(step > 0 || step < 0, 3, "step must be nonzero"); + THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin)) + , 2, "upper bound and larger bound inconsistent with step sign"); + + size = (ptrdiff_t) (((xmax - xmin) / step) + 1); + + if (THTensor_(nElement)(r_) != size) { + THTensor_(resize1d)(r_, size); + } + + TH_TENSOR_APPLY(real, r_, *r__data = xmin + (i++)*step;); +} + +void THTensor_(arange)(THTensor *r_, accreal xmin, accreal xmax, accreal step) { + ptrdiff_t size; + real i = 0; + + THArgCheck(step > 0 || step < 0, 3, "step must be nonzero"); + THArgCheck(((step > 0) && (xmax >= xmin)) || ((step < 0) && (xmax <= xmin)) + , 2, "upper bound and larger bound inconsistent with step sign"); + + size = (ptrdiff_t) ceil((double)(xmax - xmin) / step); + + if (THTensor_(nElement)(r_) != size) { + THTensor_(resize1d)(r_, size); + } + + TH_TENSOR_APPLY(real, r_, *r__data = xmin + (i++)*step;); +} + +void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n) +{ + real *r__data; + int64_t r__stride_0; + int64_t i; + + THArgCheck(n > 0, 1, "must be strictly positive"); + + THTensor_(resize1d)(r_, n); + r__data = THTensor_(data)(r_); + r__stride_0 = THTensor_(stride)(r_,0); + + for(i = 0; i < n; 
i++) + r__data[i*r__stride_0] = (real)(i); + + for(i = 0; i < n-1; i++) + { + int64_t z = THRandom_random(_generator) % (n-i); + real sav = r__data[i*r__stride_0]; + r__data[i*r__stride_0] = r__data[(z+i)*r__stride_0]; + r__data[(z+i)*r__stride_0] = sav; + } +} + +/* I cut and pasted (slightly adapted) the quicksort code from + Sedgewick's 1978 "Implementing Quicksort Programs" article + http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf + + It is the state of the art existing implementation. The macros + are here to make as close a match as possible to the pseudocode of + Program 2 p.851 + + Note that other partition schemes exist, and are typically presented + in textbook, but those are less efficient. See e.g. + http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto + + Julien, November 12th 2013 +*/ +#define MAX_LEVELS 300 +#define M_SMALL 10 /* Limit for small subfiles */ + +#define ARR(III) arr[(III)*stride] +#define IDX(III) idx[(III)*stride] + +#define LONG_SWAP(AAA, BBB) swap = AAA; AAA = BBB; BBB = swap +#define REAL_SWAP(AAA, BBB) rswap = AAA; AAA = BBB; BBB = rswap + +#define ARR_SWAP(III, JJJ) \ + REAL_SWAP(ARR(III), ARR(JJJ)); + +#define BOTH_SWAP(III, JJJ) \ + REAL_SWAP(ARR(III), ARR(JJJ)); \ + LONG_SWAP(IDX(III), IDX(JJJ)) + +static void THTensor_(quicksortascend)(real *arr, int64_t *idx, int64_t elements, int64_t stride) +{ + int64_t beg[MAX_LEVELS], end[MAX_LEVELS], i, j, L, R, P, swap, pid, stack = 0, sz_right, sz_left; + real rswap, piv; + unsigned char done = 0; + + /* beg[0]=0; end[0]=elements; */ + stack = 0; + L = 0; R = elements-1; + done = elements-1 <= M_SMALL; + + while(!done) { + /* Use median of three for pivot choice */ + P=(L+R)>>1; + BOTH_SWAP(P, L+1); + if (ARR(L+1) > ARR(R)) { BOTH_SWAP(L+1, R); } + if (ARR(L) > ARR(R)) { BOTH_SWAP(L, R); } + if (ARR(L+1) > ARR(L)) { BOTH_SWAP(L+1, L); } + + i = L+1; j = R; piv = ARR(L); pid = IDX(L); + + do { + do { i = i+1; } while(ARR(i) < piv); + do { j = j-1; } while(ARR(j) > piv); + if (j < i) + break; + BOTH_SWAP(i, j); + } while(1); + BOTH_SWAP(L, j); + /* Left subfile is (L, j-1) */ + /* Right subfile is (i, R) */ + sz_left = j-L; + sz_right = R-i+1; + if (sz_left <= M_SMALL && sz_right <= M_SMALL) { + /* both subfiles are small */ + /* if stack empty */ + if (stack == 0) { + done = 1; + } else { + stack--; + L = beg[stack]; + R = end[stack]; + } + } else if (sz_left <= M_SMALL || sz_right <= M_SMALL) { + /* exactly one of the subfiles is small */ + /* (L,R) = large subfile */ + if (sz_left > sz_right) { + /* Implicit: L = L; */ + R = j-1; + } else { + L = i; + /* Implicit: R = R; */ + } + } else { + /* none of the subfiles is small */ + /* push large subfile */ + /* (L,R) = small subfile */ + if (sz_left > sz_right) { + beg[stack] = L; + end[stack] = j-1; + stack++; + L = i; + /* Implicit: R = R */ + } else { + beg[stack] = i; + end[stack] = R; + stack++; + /* Implicit: L = L; */ + R = j-1; + } + } + } /* while not done */ + /* Now insertion sort on the concatenation of subfiles */ + for(i=elements-2; i>=0; i--) { + if (ARR(i) > ARR(i+1)) { + piv = ARR(i); + pid = IDX(i); + j = i+1; + do { + ARR(j-1) = ARR(j); + IDX(j-1) = IDX(j); + j = j+1; + } while(j < elements && ARR(j) < piv); + ARR(j-1) = piv; + IDX(j-1) = pid; + } + } +} + +static void THTensor_(quicksortdescend)(real *arr, int64_t *idx, int64_t elements, int64_t stride) +{ + int64_t beg[MAX_LEVELS], end[MAX_LEVELS], i, j, L, R, P, swap, pid, stack = 0, sz_right, sz_left; + real rswap, piv; + unsigned char done = 0; + + /* 
beg[0]=0; end[0]=elements; */ + stack = 0; + L = 0; R = elements-1; + done = elements-1 <= M_SMALL; + + while(!done) { + /* Use median of three for pivot choice */ + P=(L+R)>>1; + BOTH_SWAP(P, L+1); + if (ARR(L+1) < ARR(R)) { BOTH_SWAP(L+1, R); } + if (ARR(L) < ARR(R)) { BOTH_SWAP(L, R); } + if (ARR(L+1) < ARR(L)) { BOTH_SWAP(L+1, L); } + + i = L+1; j = R; piv = ARR(L); pid = IDX(L); + + do { + do { i = i+1; } while(ARR(i) > piv); + do { j = j-1; } while(ARR(j) < piv); + if (j < i) + break; + BOTH_SWAP(i, j); + } while(1); + BOTH_SWAP(L, j); + /* Left subfile is (L, j-1) */ + /* Right subfile is (i, R) */ + sz_left = j-L; + sz_right = R-i+1; + if (sz_left <= M_SMALL && sz_right <= M_SMALL) { + /* both subfiles are small */ + /* if stack empty */ + if (stack == 0) { + done = 1; + } else { + stack--; + L = beg[stack]; + R = end[stack]; + } + } else if (sz_left <= M_SMALL || sz_right <= M_SMALL) { + /* exactly one of the subfiles is small */ + /* (L,R) = large subfile */ + if (sz_left > sz_right) { + /* Implicit: L = L; */ + R = j-1; + } else { + L = i; + /* Implicit: R = R; */ + } + } else { + /* none of the subfiles is small */ + /* push large subfile */ + /* (L,R) = small subfile */ + if (sz_left > sz_right) { + beg[stack] = L; + end[stack] = j-1; + stack++; + L = i; + /* Implicit: R = R */ + } else { + beg[stack] = i; + end[stack] = R; + stack++; + /* Implicit: L = L; */ + R = j-1; + } + } + } /* while not done */ + /* Now insertion sort on the concatenation of subfiles */ + for(i=elements-2; i>=0; i--) { + if (ARR(i) < ARR(i+1)) { + piv = ARR(i); + pid = IDX(i); + j = i+1; + do { + ARR(j-1) = ARR(j); + IDX(j-1) = IDX(j); + j = j+1; + } while(j < elements && ARR(j) > piv); + ARR(j-1) = piv; + IDX(j-1) = pid; + } + } +} + +#undef MAX_LEVELS +#undef M_SMALL + +void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder) +{ + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d", + dimension + TH_INDEX_BASE); + + THTensor_(resizeAs)(rt_, t); + THTensor_(copy)(rt_, t); + + { + THLongStorage *size = THTensor_(newSizeOf)(t); + THLongTensor_resize(ri_, size, NULL); + THLongStorage_free(size); + } + + if(descendingOrder) + { + TH_TENSOR_DIM_APPLY2(real, rt_, int64_t, ri_, dimension, + int64_t i; + for(i = 0; i < ri__size; i++) + ri__data[i*ri__stride] = i; + THTensor_(quicksortdescend)(rt__data, ri__data, rt__size, rt__stride);) + } + else + { + TH_TENSOR_DIM_APPLY2(real, rt_, int64_t, ri_, dimension, + int64_t i; + for(i = 0; i < ri__size; i++) + ri__data[i*ri__stride] = i; + THTensor_(quicksortascend)(rt__data, ri__data, rt__size, rt__stride);) + } +} + +/* Implementation of the Quickselect algorithm, based on Nicolas Devillard's +public domain implementation at http://ndevilla.free.fr/median/median/ +Adapted similarly to the above Quicksort algorithm. +This version does not produce indices along with values. 
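For reference, a minimal standalone sketch of the same selection scheme on a plain contiguous array, median-of-three pivot included (illustrative only, not part of this diff; element type and names are placeholders):

    // Rearranges v[0..n-1] so that v[k] ends up holding the k-th smallest element.
    static void quickselect_sketch(double *v, long n, long k) {
      long L = 0, R = n - 1;
      while (R > L) {
        if (R == L + 1) {                               // two elements left
          if (v[L] > v[R]) { double t = v[L]; v[L] = v[R]; v[R] = t; }
          break;
        }
        long P = (L + R) / 2, i = L + 1, j = R;
        double t = v[P]; v[P] = v[L + 1]; v[L + 1] = t; // median of three ends up in v[L]
        if (v[L + 1] > v[R]) { t = v[L + 1]; v[L + 1] = v[R]; v[R] = t; }
        if (v[L]     > v[R]) { t = v[L];     v[L]     = v[R]; v[R] = t; }
        if (v[L + 1] > v[L]) { t = v[L + 1]; v[L + 1] = v[L]; v[L] = t; }
        double piv = v[L];
        for (;;) {                                      // Hoare-style partition around piv
          do i++; while (v[i] < piv);
          do j--; while (v[j] > piv);
          if (j < i) break;
          t = v[i]; v[i] = v[j]; v[j] = t;
        }
        v[L] = v[j]; v[j] = piv;                        // drop pivot into its final slot
        if (j <= k) L = i;                              // keep only the side that holds k
        if (j >= k) R = j - 1;
      }
    }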
*/ +static void THTensor_(quickselectnoidx)(real *arr, int64_t k, int64_t elements, int64_t stride) +{ + int64_t P, L, R, i, j; + real rswap, piv; + L = 0; + R = elements-1; + + do { + if (R <= L) /* One element only */ + return; + + if (R == L+1) { /* Two elements only */ + if (ARR(L) > ARR(R)) { + ARR_SWAP(L, R); + } + return; + } + + /* Use median of three for pivot choice */ + P=(L+R)>>1; + ARR_SWAP(P, L+1); + if (ARR(L+1) > ARR(R)) { ARR_SWAP(L+1, R); } + if (ARR(L) > ARR(R)) { ARR_SWAP(L, R); } + if (ARR(L+1) > ARR(L)) { ARR_SWAP(L+1, L); } + + i = L+1; + j = R; + piv = ARR(L); + do { + do i++; while(ARR(i) < piv); + do j--; while(ARR(j) > piv); + if (j < i) + break; + ARR_SWAP(i, j); + } while(1); + ARR_SWAP(L, j); + + /* Re-set active partition */ + if (j <= k) L=i; + if (j >= k) R=j-1; + } while(1); +} + +/* Implementation of the Quickselect algorithm, based on Nicolas Devillard's +public domain implementation at http://ndevilla.free.fr/median/median/ +Adapted similarly to the above Quicksort algorithm. */ +static void THTensor_(quickselect)(real *arr, int64_t *idx, int64_t k, int64_t elements, int64_t stride) +{ + int64_t P, L, R, i, j, swap; + real rswap, piv; + L = 0; + R = elements-1; + + do { + if (R <= L) /* One element only */ + return; + + if (R == L+1) { /* Two elements only */ + if (ARR(L) > ARR(R)) { + BOTH_SWAP(L, R); + } + return; + } + + /* Use median of three for pivot choice */ + P=(L+R)>>1; + BOTH_SWAP(P, L+1); + if (ARR(L+1) > ARR(R)) { BOTH_SWAP(L+1, R); } + if (ARR(L) > ARR(R)) { BOTH_SWAP(L, R); } + if (ARR(L+1) > ARR(L)) { BOTH_SWAP(L+1, L); } + + i = L+1; + j = R; + piv = ARR(L); + do { + do i++; while(ARR(i) < piv); + do j--; while(ARR(j) > piv); + if (j < i) + break; + BOTH_SWAP(i, j); + } while(1); + BOTH_SWAP(L, j); + + /* Re-set active partition */ + if (j <= k) L=i; + if (j >= k) R=j-1; + } while(1); +} + +#undef ARR +#undef IDX +#undef LONG_SWAP +#undef REAL_SWAP +#undef BOTH_SWAP + +real THTensor_(medianall)(THTensor *tensor) +{ + THArgCheck(tensor->_dim() > 0, 1, "tensor must have one dimension"); + + real theMedian; + ptrdiff_t numel; + int64_t k; + THTensor *temp_; + real *temp__data; + + numel = THTensor_(nElement)(tensor); + k = (numel-1) >> 1; + + temp_ = THTensor_(newClone)(tensor); + temp__data = THTensor_(data)(temp_); + + THTensor_(quickselectnoidx)(temp__data, k, numel, 1); + + theMedian = temp__data[k]; + + THTensor_(free)(temp_); + + return theMedian; +} + +void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) +{ + THLongStorage *dim; + THTensor *temp_; + THLongTensor *tempi_; + real *temp__data; + int64_t *tempi__data; + int64_t t_size_dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range"); + + int in_dims = THTensor_(_nDimension)(t); + THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); + THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(values_, dim, NULL); + THLongTensor_resize(indices_, dim, NULL); + THLongStorage_free(dim); + + t_size_dim = THTensor_(size)(t, dimension); + + temp_ = THTensor_(new)(); + THTensor_(resize1d)(temp_, t_size_dim); + temp__data = THTensor_(data)(temp_); + + tempi_ = THLongTensor_new(); + THLongTensor_resize1d(tempi_, t_size_dim); + tempi__data = THLongTensor_data(tempi_); + + TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, + 
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + int64_t i; + real mode = 0; + int64_t modei = 0; + int64_t temp_freq = 0; + int64_t max_freq = 0; + for(i = 0; i < t_size_dim; i++) + temp__data[i] = t_data[i*t_stride]; + for(i = 0; i < t_size_dim; i++) + tempi__data[i] = i; + THTensor_(quicksortascend)(temp__data, tempi__data, t_size_dim, 1); + + for(i = 0; i < t_size_dim; i++) + { + temp_freq++; + if ((i == t_size_dim - 1) || (temp__data[i] != temp__data[i+1])) + { + if (temp_freq > max_freq) + { + mode = temp__data[i]; + modei = tempi__data[i]; + max_freq = temp_freq; + } + temp_freq = 0; + } + } + *values__data = mode; + *indices__data = modei;); + + THTensor_(free)(temp_); + THLongTensor_free(tempi_); + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } +} + +void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim) +{ + THLongStorage *dim; + THTensor *temp_; + THLongTensor *tempi_; + real *temp__data; + int64_t *tempi__data; + int64_t t_size_dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range"); + THArgCheck(k > 0 && k <= t->size(dimension), 2, "selected index out of range"); + + int in_dims = THTensor_(_nDimension)(t); + THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim); + THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(values_, dim, NULL); + THLongTensor_resize(indices_, dim, NULL); + THLongStorage_free(dim); + + t_size_dim = THTensor_(size)(t, dimension); + + temp_ = THTensor_(new)(); + THTensor_(resize1d)(temp_, t_size_dim); + temp__data = THTensor_(data)(temp_); + + tempi_ = THLongTensor_new(); + THLongTensor_resize1d(tempi_, t_size_dim); + tempi__data = THLongTensor_data(tempi_); + + TH_TENSOR_DIM_APPLY3(real, t, real, values_, int64_t, indices_, dimension, + TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + int64_t i; + for(i = 0; i < t_size_dim; i++) + temp__data[i] = t_data[i*t_stride]; + for(i = 0; i < t_size_dim; i++) + tempi__data[i] = i; + THTensor_(quickselect)(temp__data, tempi__data, k - 1, t_size_dim, 1); + *values__data = temp__data[k-1]; + *indices__data = tempi__data[k-1];); + + THTensor_(free)(temp_); + THLongTensor_free(tempi_); + if (!keepdim) { + THTensor_(squeeze1d)(values_, values_, dimension); + THLongTensor_squeeze1d(indices_, indices_, dimension); + } +} + +void THTensor_(median)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim) +{ + int64_t t_size_dim, k; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "dimension out of range"); + + t_size_dim = THTensor_(size)(t, dimension); + k = (t_size_dim-1) >> 1; /* take middle or one-before-middle element */ + + THTensor_(kthvalue)(values_, indices_, t, k+1, dimension, keepdim); +} + +void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted) +{ +#ifndef USE_TH_SIZE_ZERO_DIM + int numDims = THTensor_(_nDimension)(t); +#else + int numDims = THTensor_(nDimension)(t); +#endif + THArgCheck(dim >= 0 && dim < numDims, 3, "dim not in range"); + + int64_t sliceSize = THTensor_(size)(t, dim); +#ifndef USE_TH_SIZE_ZERO_DIM + THArgCheck(k > 0 && k <= sliceSize, 2, "k not in range for dimension"); +#else + THArgCheck(k >= 0 && k <= sliceSize, 2, "k not in range for dimension"); +#endif + + 
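The body below gathers each slice into a temporary buffer, uses quickselect to move the wanted k elements to one end, and only then sorts that k-element block, carrying the original positions along in a parallel index buffer. A simplified standalone sketch of that values-plus-indices contract (a full sort is used here for brevity where the code below partitions first; all names are hypothetical and not part of this diff):

    #include <stdlib.h>

    typedef struct { double val; long idx; } entry;

    static int cmp_entry_asc(const void *a, const void *b) {
      double x = ((const entry *)a)->val, y = ((const entry *)b)->val;
      return (x > y) - (x < y);
    }

    // Writes the k smallest elements of v[0..n-1] (ascending) and their original
    // positions; the real code avoids the full sort by quickselect-partitioning
    // and sorting only the k-element block.
    static void topk_sketch(const double *v, long n, long k,
                            double *out_val, long *out_idx) {
      entry *e = malloc((size_t)n * sizeof *e);
      long i;
      if (!e) return;
      for (i = 0; i < n; i++) { e[i].val = v[i]; e[i].idx = i; }
      qsort(e, (size_t)n, sizeof *e, cmp_entry_asc);
      for (i = 0; i < k; i++) { out_val[i] = e[i].val; out_idx[i] = e[i].idx; }
      free(e);
    }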
THTensor *tmpResults = THTensor_(new)(); + THTensor_(resize1d)(tmpResults, sliceSize); + real *tmp__data = THTensor_(data)(tmpResults); + + THLongTensor *tmpIndices = THLongTensor_new(); + THLongTensor_resize1d(tmpIndices, sliceSize); + int64_t *tmpi__data = THLongTensor_data(tmpIndices); + + THLongStorage *topKSize = THTensor_(newSizeOf)(t); + THLongStorage_set(topKSize, dim, k); + THTensor_(resize)(rt_, topKSize, NULL); + THLongTensor_resize(ri_, topKSize, NULL); + THLongStorage_free(topKSize); + + if (dir) { + /* k largest elements, descending order (optional: see sorted) */ + int64_t K = sliceSize - k; + TH_TENSOR_DIM_APPLY3(real, t, real, rt_, int64_t, ri_, dim, + TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + int64_t i; + for(i = 0; i < sliceSize; i++) + { + tmp__data[i] = t_data[i*t_stride]; + tmpi__data[i] = i; + } + if (K > 0) + THTensor_(quickselect)(tmp__data, tmpi__data, K - 1, sliceSize, 1); + if (sorted) + THTensor_(quicksortdescend)(tmp__data + K, tmpi__data + K, k, 1); + for(i = 0; i < k; i++) + { + rt__data[i*rt__stride] = tmp__data[i + K]; + ri__data[i*ri__stride] = tmpi__data[i + K]; + }) + } + else { + /* k smallest elements, ascending order (optional: see sorted) */ + TH_TENSOR_DIM_APPLY3(real, t, real, rt_, int64_t, ri_, dim, + TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM, + int64_t i; + for(i = 0; i < sliceSize; i++) + { + tmp__data[i] = t_data[i*t_stride]; + tmpi__data[i] = i; + } + THTensor_(quickselect)(tmp__data, tmpi__data, k - 1, sliceSize, 1); + if (sorted) + THTensor_(quicksortascend)(tmp__data, tmpi__data, k - 1, 1); + for(i = 0; i < k; i++) + { + rt__data[i*rt__stride] = tmp__data[i]; + ri__data[i*ri__stride] = tmpi__data[i]; + }) + } + + THTensor_(free)(tmpResults); + THLongTensor_free(tmpIndices); +} + +void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k) +{ + int64_t t_size_0, t_size_1; + int64_t t_stride_0, t_stride_1; + int64_t r__stride_0, r__stride_1; + real *t_data, *r__data; + int64_t r, c; + + THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix"); + + THTensor_(resizeAs)(r_, t); + + t_size_0 = THTensor_(size)(t, 0); + t_size_1 = THTensor_(size)(t, 1); + t_stride_0 = THTensor_(stride)(t, 0); + t_stride_1 = THTensor_(stride)(t, 1); + r__stride_0 = THTensor_(stride)(r_, 0); + r__stride_1 = THTensor_(stride)(r_, 1); + r__data = THTensor_(data)(r_); + t_data = THTensor_(data)(t); + + for(r = 0; r < t_size_0; r++) + { + int64_t sz = THMin(r+k+1, t_size_1); + for(c = THMax(0, r+k+1); c < t_size_1; c++) + r__data[r*r__stride_0+c*r__stride_1] = 0; + for(c = 0; c < sz; c++) + r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1]; + } +} + +void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k) +{ + int64_t t_size_0, t_size_1; + int64_t t_stride_0, t_stride_1; + int64_t r__stride_0, r__stride_1; + real *t_data, *r__data; + int64_t r, c; + + THArgCheck(THTensor_(_nDimension)(t) == 2, 1, "expected a matrix"); + + THTensor_(resizeAs)(r_, t); + + t_size_0 = THTensor_(size)(t, 0); + t_size_1 = THTensor_(size)(t, 1); + t_stride_0 = THTensor_(stride)(t, 0); + t_stride_1 = THTensor_(stride)(t, 1); + r__stride_0 = THTensor_(stride)(r_, 0); + r__stride_1 = THTensor_(stride)(r_, 1); + r__data = THTensor_(data)(r_); + t_data = THTensor_(data)(t); + + for(r = 0; r < t_size_0; r++) + { + int64_t sz = THMin(r+k, t_size_1); + for(c = THMax(0, r+k); c < t_size_1; c++) + r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1]; + for(c = 0; c < sz; c++) + r__data[r*r__stride_0+c*r__stride_1] = 0; + } +} + +void 
THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension) +{ + THTensor* inputs[2]; + inputs[0] = ta; + inputs[1] = tb; + THTensor_(catArray)(r_, inputs, 2, dimension); +} + +void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension); +inline void THTensor_(check_shape_except_dim)(THTensor *first, THTensor *second, int dimension) +{ + int first_dims = first->dim(); + int second_dims = second->dim(); + THArgCheck(first_dims == second_dims, 0, + "Tensors must have same number of dimensions: got %d and %d", + first_dims, second_dims); + for (int dim = 0; dim < first_dims; dim++) { + if (dim == dimension) { + continue; + } + int64_t first_dim_size = first->size(dim); + int64_t second_dim_size = second->size(dim); + THArgCheck(first_dim_size == second_dim_size, 0, + "Sizes of tensors must match except in dimension %d. Got %lld and %lld in dimension %d", + dimension, (long long)first_dim_size, (long long)second_dim_size, dim); + } +} + +void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension) +{ + // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible + // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors + // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific + // size (i.e. other empty sizes are not skipped). + // FIXME: warn if this is the case + bool allSkipped= true; + int64_t nDims = 0; + THTensor *notSkippedTensor; // non-owning reference + auto should_skip = [](THTensor *t) { return t->is_empty() && t->dim() == 1; }; + for (int i = 0; i < numInputs; i++) { + if (should_skip(inputs[i])) { + continue; + } + // We've found a non-empty tensor + allSkipped = false; + notSkippedTensor = inputs[i]; + nDims = notSkippedTensor->dim(); + break; + } + if (allSkipped) { + return; + } + + // Compute cat_dimension based on the non-empty tensor + THArgCheck(dimension < nDims, 4, "invalid dimension %d", dimension); + THArgCheck(numInputs > 0, 3, "invalid number of inputs %d", numInputs); + + // Compute size of the result in the cat dimension + int64_t cat_dim_size = 0; + for (int i = 0; i < numInputs; i++) { + THTensor *tensor = inputs[i]; + if (should_skip(tensor)) { + continue; + } + THTensor_(check_shape_except_dim)(notSkippedTensor, tensor, dimension); + cat_dim_size += tensor->size(dimension); + } + + // Compute the size of the result + THLongStorage *size = THLongStorage_newWithSize(nDims); + for (int dim = 0; dim < nDims; dim++) { + int64_t result_dim_size = notSkippedTensor->size(dim); + if (dim == dimension) { + result_dim_size = cat_dim_size; + } + THLongStorage_data(size)[dim] = result_dim_size; + } + THTensor_(resize)(result, size, NULL); + + // Check contiguity of all inputs and result + bool allContiguous = true; + for (int i = 0; i < numInputs; i++) { + if(!should_skip(inputs[i])) { + allContiguous = allContiguous && THTensor_(isContiguous)(inputs[i]); + } + } + allContiguous = allContiguous && THTensor_(isContiguous)(result); + + // First path is for contiguous inputs along dim 0 + // Second path for non-contiguous + int64_t offset; + if (dimension == 0 && allContiguous) { + real* result_data = THStorage_(data)(THTensor_getStoragePtr(result)) + result->storage_offset(); + offset = 0; + for (int j = 0; j < numInputs; j++) { + if (!should_skip(inputs[j])) { + THTensor* input0 = inputs[j]; + real* input0_data = THStorage_(data)(THTensor_getStoragePtr(input0)) + 
input0->storage_offset(); + int64_t input0_size = THTensor_(nElement)(input0); + // C standard says you can't pass nullptrs to memcpy, even if the size is 0; ubsan checks this. + if (input0_size != 0) { + memcpy(result_data + offset, input0_data, input0_size*sizeof(real)); + } + offset += input0_size; + } + } + } else { + offset = 0; + for (int j = 0; j < numInputs; j++) { + if (!should_skip(inputs[j])) { + int64_t dimSize = inputs[j]->size(dimension); + THTensor *nt = THTensor_(newWithTensor)(result); + THTensor_(narrow)(nt, NULL, dimension, offset, dimSize); + THTensor_(copy)(nt, inputs[j]); + THTensor_(free)(nt); + offset += dimSize; + } + } + } + THLongStorage_free(size); +} + +int THTensor_(equal)(THTensor *ta, THTensor* tb) +{ + int equal = 1; + if(!THTensor_(isSameSizeAs)(ta, tb)) + return 0; + + if (THTensor_(isContiguous)(ta) && THTensor_(isContiguous)(tb)) { + real *tap = THTensor_(data)(ta); + real *tbp = THTensor_(data)(tb); + ptrdiff_t sz = THTensor_(nElement)(ta); + ptrdiff_t i; + for (i=0; idim(), THTensor_getSizePtr(t), NULL); \ + TH_TENSOR_APPLY2(unsigned char, r_, real, t, \ + *r__data = (*t_data OP value) ? 1 : 0;); \ + } \ + void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, real value) \ + { \ + THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ + TH_TENSOR_APPLY2(real, r_, real, t, \ + *r__data = (*t_data OP value) ? 1 : 0;); \ + } \ + void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ + TH_TENSOR_APPLY3(unsigned char, r_, real, ta, real, tb, \ + *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ + } \ + void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ + TH_TENSOR_APPLY3(real, r_, real, ta, real, tb, \ + *r__data = (*ta_data OP *tb_data) ? 
1 : 0;); \ + } \ + + +TENSOR_IMPLEMENT_LOGICAL(lt,<) +TENSOR_IMPLEMENT_LOGICAL(gt,>) +TENSOR_IMPLEMENT_LOGICAL(le,<=) +TENSOR_IMPLEMENT_LOGICAL(ge,>=) +TENSOR_IMPLEMENT_LOGICAL(eq,==) +TENSOR_IMPLEMENT_LOGICAL(ne,!=) + + +#ifdef _OPENMP + +#define LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, OMP_THRESHOLD) \ + void THTensor_(NAME)(THTensor *r_, THTensor *t) \ + { \ + THTensor_(resizeAs)(r_, t); \ + ptrdiff_t r_Size = THTensor_(nElement)(r_); \ + int r_Contig = THTensor_(isContiguous)(r_); \ + int tContig = THTensor_(isContiguous)(t); \ + int inOMP = omp_in_parallel(); \ + if( !inOMP ){ \ + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = CFUNC(*t_data);, OMP_THRESHOLD); \ + } else { \ + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = CFUNC(*t_data);); \ + } \ + } + +#define LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS(NAME, CFUNC) \ + LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD) + +#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS(NAME, CFUNC, OMP_THRESHOLD) \ + void THTensor_(NAME)(THTensor *r_, THTensor *t) \ + { \ + THTensor_(resizeAs)(r_, t); \ + ptrdiff_t r_Size = THTensor_(nElement)(r_); \ + int r_Contig = THTensor_(isContiguous)(r_); \ + int tContig = THTensor_(isContiguous)(t); \ + if (r_Contig && tContig) { \ + TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(NAME)(r__data, t_data, r__len);); \ + } else { \ + int inOMP = omp_in_parallel(); \ + if( !inOMP ){ \ + TH_TENSOR_APPLY2_OMP(r_Size, r_Contig, tContig, real, r_, real, t, *r__data = CFUNC(*t_data);, OMP_THRESHOLD); \ + } \ + else { \ + TH_TENSOR_APPLY2(real, r_, real, t, *r__data = CFUNC(*t_data);); \ + } \ + } \ + } + +#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS(NAME, CFUNC) \ + LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS(NAME, CFUNC, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD) + +#else + +#define LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS(NAME, CFUNC) \ + void THTensor_(NAME)(THTensor *r_, THTensor *t) \ + { \ + THTensor_(resizeAs)(r_, t); \ + TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \ + } \ + +#define LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, PSEUDO_OMP_THRESHOLD) \ + LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS(NAME, CFUNC) + +#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS(NAME, CFUNC) \ + void THTensor_(NAME)(THTensor *r_, THTensor *t) \ + { \ + THTensor_(resizeAs)(r_, t); \ + int r_Contig = THTensor_(isContiguous)(r_); \ + int tContig = THTensor_(isContiguous)(t); \ + if (r_Contig && tContig) { \ + TH_TENSOR_APPLY2_CONTIG(real, r_, real, t, THVector_(NAME)(r__data, t_data, r__len);); \ + } else { \ + TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \ + } \ + } \ + +#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS(NAME, CFUNC, PSEUDO_OMP_THRESHOLD) \ + LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS(NAME, CFUNC) + +#endif + +#define EXPAND(...) __VA_ARGS__ + +#define GET_4TH_ARG(ARG0, ARG1, ARG2, ARG3, ...) ARG3 + +#define LAB_IMPLEMENT_BASIC_FUNCTION_CHOOSE(...) \ + EXPAND(GET_4TH_ARG(__VA_ARGS__, LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS, LAB_IMPLEMENT_BASIC_FUNCTION_2_ARGS, )) + +#define LAB_IMPLEMENT_VECTORIZED_FUNCTION_CHOOSE(...) \ + EXPAND(GET_4TH_ARG(__VA_ARGS__, LAB_IMPLEMENT_VECTORIZED_FUNCTION_3_ARGS, LAB_IMPLEMENT_VECTORIZED_FUNCTION_2_ARGS, )) + +#define LAB_IMPLEMENT_BASIC_FUNCTION(...) EXPAND(LAB_IMPLEMENT_BASIC_FUNCTION_CHOOSE(__VA_ARGS__)(__VA_ARGS__)) + +#define LAB_IMPLEMENT_VECTORIZED_FUNCTION(...) 
EXPAND(LAB_IMPLEMENT_VECTORIZED_FUNCTION_CHOOSE(__VA_ARGS__)(__VA_ARGS__)) + +/* + * LAB_IMPLEMENT_BASIC_FUNCTION is a macro with optional parameters, you can use it flexibly. + * The macro will discard the invalid openmp threshold if openmp is unavailable. The macro will give a default threshold even if you forget to pass one. + * In other word, + * (A), If openmp is UNavailable, the two usage below is both right. + * (1) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity, OMP_OVERHEAD_THRESHOLD) // discard the invalid openmp threshold + * (2) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity) + * (B), If openmp is available, the two usage below is also both right. + * (1) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity, OMP_OVERHEAD_THRESHOLD) + * (2) LAB_IMPLEMENT_BASIC_FUNCTION(type_func, func_entity) // pass the default openmp threshold + * So do LAB_IMPLEMENT_VECTORIZED_FUNCTION. +*/ + +LAB_IMPLEMENT_BASIC_FUNCTION(neg,-) + +#if defined(TH_REAL_IS_LONG) +LAB_IMPLEMENT_BASIC_FUNCTION(abs,labs) +#endif /* int64_t only part */ + +#if defined(TH_REAL_IS_SHORT) || defined(TH_REAL_IS_INT) +LAB_IMPLEMENT_BASIC_FUNCTION(abs,abs) +#endif /* int only part */ + +#if defined(TH_REAL_IS_BYTE) /* Byte only part */ + +int THTensor_(logicalAndAll)(THTensor *tensor) +{ + real prod = 1; + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if(inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, &&:prod, prod = prod && *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } +#else + serial_path = 1; +#endif + if (serial_path) { + TH_TENSOR_APPLY(real, tensor, prod = prod && *tensor_data;); + } + return prod; +} + +int THTensor_(logicalAnyAll)(THTensor *tensor) +{ + real sum = 0; + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if(inOMP) { + serial_path = 1; + } else { + TH_TENSOR_APPLY_REDUCTION_OMP(real, tensor, ||:sum, sum = sum || *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD); + } +#else + serial_path = 1; +#endif + if (serial_path) { + TH_TENSOR_APPLY(real, tensor, sum = sum || *tensor_data;); + } + return (bool)sum; +} + +void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + int r_Contig = THTensor_(isContiguous)(r_); + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + if(r_Contig && (tp != rp)){ + ptrdiff_t iter = 0; + ptrdiff_t r_Size = THTensor_(nElement)(r_); + int r_Dim = r_->_dim(); + #pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD) + for (iter = 0; iter < r_Size; iter++) { + int j; + int64_t quot; + int64_t rem = iter; + ptrdiff_t tBasicIndex = 0; + + for(j = 0; j < r_Dim; ++j) { + if(j != dimension){ + quot = rem/r_->stride(j); + rem = rem%r_->stride(j); + tBasicIndex += quot*t->stride(j); + } + } + real *t_data = tp+tBasicIndex; + real *r__data = rp+iter; + *r__data = 1; + for(j=0; j < t->size(dimension); ++j) { + *r__data = *r__data && *(t_data + j*t->stride(dimension)); + } + } + } else { + serial_path = 1; + } + } +#else + serial_path 
= 1; +#endif + + if(serial_path) { + // two implementations optimized for data locality + if (t->stride(dimension) == 1) { + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal prod = 1; + int64_t i; + for(i = 0; i < t_size; i++) + prod = prod && t_data[i*t_stride]; + *r__data = (real)prod;); + } else { + THTensor_(fill)(r_, 1); + THTensor *temp_ = THTensor_(newWithTensor)(r_); + // r_.expand_as(t) + THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(temp_, dimension, 0); + + TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data && *t_data;); + THTensor_(free)(temp_); + } + } + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } +} + +void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "dimension %d out of range", + dimension + TH_INDEX_BASE); + + THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + int serial_path = 0; +#ifdef _OPENMP + int inOMP = omp_in_parallel(); + if (inOMP) { + serial_path = 1; + } else { + int r_Contig = THTensor_(isContiguous)(r_); + real *tp = THTensor_(data)(t); + real *rp = THTensor_(data)(r_); + if(r_Contig && (tp != rp)){ + ptrdiff_t iter = 0; + ptrdiff_t r_Size = THTensor_(nElement)(r_); + int r_Dim = r_->_dim(); + #pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD) + for (iter = 0; iter < r_Size; iter++) { + int j; + int64_t quot; + int64_t rem = iter; + ptrdiff_t tBasicIndex = 0; + + for(j = 0; j < r_Dim; ++j) { + if(j != dimension){ + quot = rem/r_->stride(j); + rem = rem%r_->stride(j); + tBasicIndex += quot*t->stride(j); + } + } + real *t_data = tp+tBasicIndex; + real *r__data = rp+iter; + *r__data = 0; + for(j=0; j < t->size(dimension); ++j) { + *r__data = *r__data || *(t_data + j*t->stride(dimension)); + } + } + } else { + serial_path = 1; + } + } +#else + serial_path = 1; +#endif + if (serial_path) { + // two implementations optimized for data locality + if (t->stride(dimension) == 1) { + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + accreal sum = 0; + int64_t i; + for(i = 0; i < t_size; i++) + sum = sum || t_data[i*t_stride]; + *r__data = (real)sum;); + } else { + THTensor_(zero)(r_); + THTensor *temp_ = THTensor_(newWithTensor)(r_); + // r_.expand_as(t) + THTensor_setSizeAtDim(temp_, dimension, t->size(dimension)); + THTensor_setStrideAtDim(temp_, dimension, 0); + + TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data || *t_data;); + THTensor_(free)(temp_); + } + } + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } +} + +#endif /* Byte only part */ + +/* floating point only now */ +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) + +#if defined (TH_REAL_IS_FLOAT) +#define TH_MATH_NAME(fn) fn##f +#else +#define TH_MATH_NAME(fn) fn +#endif + +LAB_IMPLEMENT_BASIC_FUNCTION(log,TH_MATH_NAME(log)) +LAB_IMPLEMENT_BASIC_FUNCTION(lgamma,TH_MATH_NAME(lgamma)) +LAB_IMPLEMENT_BASIC_FUNCTION(digamma,TH_MATH_NAME(TH_digamma)) +LAB_IMPLEMENT_BASIC_FUNCTION(trigamma,TH_MATH_NAME(TH_trigamma)) +LAB_IMPLEMENT_BASIC_FUNCTION(log10,TH_MATH_NAME(log10)) +LAB_IMPLEMENT_BASIC_FUNCTION(log1p,TH_MATH_NAME(log1p)) +LAB_IMPLEMENT_BASIC_FUNCTION(log2,TH_MATH_NAME(log2)) +LAB_IMPLEMENT_BASIC_FUNCTION(erf,TH_MATH_NAME(erf)) 
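The invocations here pass two arguments and pick up the default OpenMP threshold, while the exp/cos group a little further down passes an explicit third argument; both routes go through the GET_4TH_ARG dispatch described in the comment above. A minimal standalone illustration of that argument-count dispatch (hypothetical macro and names, not part of this diff):

    #include <stdio.h>

    #define MY_EXPAND(...) __VA_ARGS__
    #define MY_GET_4TH_ARG(A0, A1, A2, A3, ...) A3
    #define MY_IMPL_2(name, fn)            printf("%s: default threshold\n", name)
    #define MY_IMPL_3(name, fn, threshold) printf("%s: threshold %d\n", name, (int)(threshold))
    // Padding the argument list makes the 4th argument resolve to the 3-arg or
    // the 2-arg variant, depending on how many arguments the caller supplied.
    #define MY_IMPL_CHOOSE(...) \
      MY_EXPAND(MY_GET_4TH_ARG(__VA_ARGS__, MY_IMPL_3, MY_IMPL_2, ))
    #define MY_IMPL(...) MY_EXPAND(MY_IMPL_CHOOSE(__VA_ARGS__)(__VA_ARGS__))

    int main(void) {
      MY_IMPL("log", 0);         // two arguments   -> MY_IMPL_2
      MY_IMPL("exp", 0, 2048);   // three arguments -> MY_IMPL_3
      return 0;
    }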
+LAB_IMPLEMENT_BASIC_FUNCTION(erfc,TH_MATH_NAME(erfc)) +LAB_IMPLEMENT_BASIC_FUNCTION(erfinv,TH_erfinv) +LAB_IMPLEMENT_BASIC_FUNCTION(ceil,TH_MATH_NAME(ceil)) +LAB_IMPLEMENT_BASIC_FUNCTION(floor,TH_MATH_NAME(floor)) +LAB_IMPLEMENT_BASIC_FUNCTION(round,TH_MATH_NAME(round)) +LAB_IMPLEMENT_BASIC_FUNCTION(abs,TH_MATH_NAME(fabs)) +LAB_IMPLEMENT_BASIC_FUNCTION(trunc,TH_MATH_NAME(trunc)) +LAB_IMPLEMENT_BASIC_FUNCTION(frac,TH_MATH_NAME(TH_frac)) +LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / ) + +LAB_IMPLEMENT_BASIC_FUNCTION(exp,TH_MATH_NAME(exp),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(expm1,TH_MATH_NAME(expm1),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(cos,TH_MATH_NAME(cos),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(acos,TH_MATH_NAME(acos),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(cosh,TH_MATH_NAME(cosh),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(sin,TH_MATH_NAME(sin),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(asin,TH_MATH_NAME(asin),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(sinh,TH_MATH_NAME(sinh),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(tan,TH_MATH_NAME(tan),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(atan,TH_MATH_NAME(atan),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(tanh,TH_MATH_NAME(tanh),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,TH_MATH_NAME(sqrt),HYPER_TH_OMP_OVERHEAD_THRESHOLD) +LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt),HYPER_TH_OMP_OVERHEAD_THRESHOLD) + +LAB_IMPLEMENT_VECTORIZED_FUNCTION(sigmoid,TH_MATH_NAME(TH_sigmoid),HYPER_TH_OMP_OVERHEAD_THRESHOLD) + +void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty) +{ + THTensor_(resizeAs)(r_, tx); + TH_TENSOR_APPLY3(real, r_, real, tx, real, ty, *r__data = TH_MATH_NAME(atan2)(*tx_data,*ty_data);); +} + +void THTensor_(polygamma)(THTensor *r_, int64_t n, THTensor *t) { + switch (n) { + case 0: THTensor_(digamma)(r_, t); return; + case 1: THTensor_(trigamma)(r_, t); return; + default: THError("polygamma(n,x) is not implemented for n>=2"); + } +} + +void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight) +{ + THArgCheck(THTensor_(nElement)(a) == THTensor_(nElement)(b), 2, "sizes do not match"); + THTensor_(resizeAs)(r_, a); + TH_TENSOR_APPLY3(real, r_, real, a, real, b, *r__data = TH_MATH_NAME(TH_lerp)(*a_data, *b_data, weight);); +} + +void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim) +{ + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 2, "invalid dimension %d", + dimension + TH_INDEX_BASE); + + THTensor_(sum)(r_, t, dimension, keepdim); + THTensor_(div)(r_, r_, t->size(dimension)); +} + +void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d", + dimension + TH_INDEX_BASE); + + THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + // Uses Welford's algorithm for numeric stability + accreal mean = 0; + accreal M2 = 0; + + int64_t i; + for (i = 0; i < t_size; i++) + { + real z = t_data[i*t_stride]; + real delta = z - mean; + mean += delta / (i + 
1); + real delta2 = z - mean; + M2 += delta * delta2; + } + + if (biased && t_size >= 2) + { + *r__data = TH_MATH_NAME(sqrt)(M2 / t_size); + } else if (!biased && t_size >= 2) { + *r__data = TH_MATH_NAME(sqrt)(M2 / (t_size - 1)); + } else if (biased && t_size == 1) { + *r__data = 0; + } else { + *r__data = NAN; + }); + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } +} + +void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d", + dimension + TH_INDEX_BASE); + + THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, + // Uses Welford's algorithm for numeric stability + accreal mean = 0; + accreal M2 = 0; + + int64_t i; + for (i = 0; i < t_size; i++) + { + real z = t_data[i*t_stride]; + real delta = z - mean; + mean += delta / (i + 1); + real delta2 = z - mean; + M2 += delta * delta2; + } + + if (biased && t_size >= 2) + { + *r__data = M2 / t_size; + } else if (!biased && t_size >= 2) { + *r__data = M2 / (t_size - 1); + } else if (biased && t_size == 1) { + *r__data = 0; + } else { + *r__data = NAN; + }); + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } +} + +void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim) +{ + THLongStorage *dim; + + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(t), 3, "invalid dimension %d", + dimension + TH_INDEX_BASE); + + THTensor_(preserveReduceDimSemantics)(r_, THTensor_(_nDimension)(t), dimension, keepdim); + dim = THTensor_(newSizeOf)(t); + THLongStorage_set(dim, dimension, 1); + THTensor_(resize)(r_, dim, NULL); + THLongStorage_free(dim); + + #define DIM_REDUCE(reduce, transform) \ + TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, \ + accreal sum = 0; \ + int64_t i; \ + for(i = 0; i < t_size; i++) { \ + (reduce); \ + } \ + (transform);) \ + + if(value == 0) { + DIM_REDUCE(sum += t_data[i*t_stride] != 0.0, + *r__data = sum); + } else if (value == 1) { + DIM_REDUCE(sum += TH_MATH_NAME(fabs)(t_data[i*t_stride]), + *r__data = sum); + } else if (value == 2) { + DIM_REDUCE(sum += t_data[i*t_stride] * t_data[i*t_stride], + *r__data = TH_MATH_NAME(sqrt)(sum)); + } else if (value == 3) { + DIM_REDUCE(sum += TH_MATH_NAME(fabs)(t_data[i*t_stride] * t_data[i*t_stride] * t_data[i*t_stride]), + *r__data = TH_MATH_NAME(pow)(sum, 1.0/3)); + } else if (value == INFINITY) { + DIM_REDUCE(sum = THMax(sum, TH_MATH_NAME(fabs)(t_data[i*t_stride])), + *r__data = sum); + } else { + DIM_REDUCE(sum += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(t_data[i*t_stride]), value), + *r__data = TH_MATH_NAME(pow)(sum, 1.0/value)); + } + + if (!keepdim) { + THTensor_(squeeze1d)(r_, r_, dimension); + } + #undef DIM_REDUCE +} + +accreal THTensor_(normall)(THTensor *tensor, real value) +{ + accreal sum = 0; + if(value == 0) { + TH_TENSOR_APPLY(real, tensor, sum += *tensor_data != 0.0;); + return sum; + } else if(value == 1) { + TH_TENSOR_APPLY(real, tensor, sum += TH_MATH_NAME(fabs)(*tensor_data);); + return sum; + } else if(value == 2) { + TH_TENSOR_APPLY(real, tensor, accreal z = *tensor_data; sum += z*z;); + return sqrt(sum); + } else if(value == 3) { + TH_TENSOR_APPLY(real, tensor, accreal z = *tensor_data; sum += std::abs(z*z*z);); + return 
TH_MATH_NAME(pow)(sum, 1.0/3); + } else if(value == INFINITY) { + TH_TENSOR_APPLY(real, tensor, sum = THMax(sum, TH_MATH_NAME(fabs)(*tensor_data));); + return sum; + } else { + TH_TENSOR_APPLY(real, tensor, sum += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*tensor_data), value);); + return TH_MATH_NAME(pow)(sum, 1.0/value); + } +} + +void THTensor_(renorm)(THTensor *res, THTensor *src, real value, int dimension, real maxnorm) +{ + THTensor *rowR, *rowS; + + THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(src), 3, "invalid dimension %d", + dimension + TH_INDEX_BASE); + THArgCheck(value > 0, 2, "non-positive-norm not supported"); + THArgCheck(THTensor_(nDimension)(src) > 1, 1, "need at least 2 dimensions, got %d dimensions", + THTensor_(nDimension)(src)); + + rowR = THTensor_(new)(); + rowS = THTensor_(new)(); + + THTensor_(resizeAs)(res, src); + + for (int64_t i = 0; i < src->size(dimension); i++) + { + real norm = 0; + real new_norm; + + THTensor_(select)(rowS, src, dimension, i); + THTensor_(select)(rowR, res, dimension, i); + if (value == 1) { + TH_TENSOR_APPLY(real, rowS, norm += fabs(*rowS_data);); + } else if (value == 2) { + TH_TENSOR_APPLY(real, rowS, accreal z = *rowS_data; norm += z*z;); + } else if (value == INFINITY) { + TH_TENSOR_APPLY(real, rowS, norm = THMax(norm, TH_MATH_NAME(fabs)(*rowS_data));); + } else { + TH_TENSOR_APPLY(real, rowS, norm += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*rowS_data), value);); + } + + if (value != INFINITY) { + norm = pow(norm, 1/value); + } + + if (norm > maxnorm) + { + new_norm = maxnorm / (norm + 1e-7); + + TH_TENSOR_APPLY2( + real, rowR, real, rowS, + *rowR_data = (*rowS_data) * new_norm; + ) + } + else + THTensor_(copy)(rowR, rowS); + } + + THTensor_(free)(rowR); + THTensor_(free)(rowS); +} + +accreal THTensor_(dist)(THTensor *tensor, THTensor *src, real value) +{ + real sum = 0; + TH_TENSOR_APPLY2(real, tensor, real, src, + sum += TH_MATH_NAME(pow)( + TH_MATH_NAME(fabs)(*tensor_data - *src_data), value);); + return TH_MATH_NAME(pow)(sum, 1.0/value); +} + +accreal THTensor_(meanall)(THTensor *tensor) +{ + return THTensor_(sumall)(tensor)/THTensor_(nElement)(tensor); +} + +accreal THTensor_(varall)(THTensor *tensor, int biased) +{ + accreal mean = THTensor_(meanall)(tensor); + accreal sum = 0; + TH_TENSOR_APPLY(real, tensor, sum += (*tensor_data - mean)*(*tensor_data - mean);); + sum /= std::max(0, THTensor_(nElement)(tensor) - (biased ? 0 : 1)); + return sum; +} + +accreal THTensor_(stdall)(THTensor *tensor, int biased) +{ + return sqrt(THTensor_(varall)(tensor, biased)); +} + +void THTensor_(linspace)(THTensor *r_, real a, real b, int64_t n) +{ + real i = 0; + + // NumPy allows you to pass different points even if n <= 1 -- should we? + THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + + if (THTensor_(nElement)(r_) != n) { + THTensor_(resize1d)(r_, n); + } + + if (n == 0) { + } else if (n == 1) { + THTensor_(set1d)(r_, 0, a); + } else { + TH_TENSOR_APPLY(real, r_, + *r__data = a + (b-a)/((real)(n-1))*i; + i++; + ); + } +} + +void THTensor_(logspace)(THTensor *r_, real a, real b, int64_t n) +{ + real i = 0; + + // NumPy allows you to pass different points even if n <= 1 -- should we? 
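As a standalone illustration of the spacing formulas used by linspace above and logspace here, assuming the n >= 2 case (the n == 0 and n == 1 cases are special-cased in the code; names are placeholders, not part of this diff):

    #include <math.h>
    #include <stdio.h>

    // n >= 2 evenly spaced points from a to b inclusive; logspace raises 10
    // to each linspace point.
    static void linspace_sketch(double *out, double a, double b, long n) {
      long i;
      for (i = 0; i < n; i++)
        out[i] = a + (b - a) / (double)(n - 1) * (double)i;
    }
    static void logspace_sketch(double *out, double a, double b, long n) {
      long i;
      for (i = 0; i < n; i++)
        out[i] = pow(10.0, a + (double)i * (b - a) / (double)(n - 1));
    }

    int main(void) {
      double v[5];
      int i;
      logspace_sketch(v, 0.0, 2.0, 5);   // 1, ~3.16, 10, ~31.6, 100
      for (i = 0; i < 5; i++) printf("%g\n", v[i]);
      return 0;
    }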
+ THArgCheck(n > 1 || ((n == 0 || n == 1) && (a == b)), 3, "invalid number of points"); + + if (THTensor_(nElement)(r_) != n) { + THTensor_(resize1d)(r_, n); + } + + if (n == 0) { + } else if (n == 1) { + THTensor_(set1d)(r_, 0, TH_MATH_NAME(pow)(10.0, a)); + } else { + TH_TENSOR_APPLY(real, r_, + *r__data = TH_MATH_NAME(pow)(10.0, a + i*(b-a)/((real)(n-1))); + i++; + ); + } +} + +void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, real minvalue, real maxvalue) +{ + real minval; + real maxval; + real *h_data; + + THTensor_(resize1d)(hist, nbins); + THTensor_(zero)(hist); + minval = minvalue; + maxval = maxvalue; + if (minval == maxval) + { + minval = THTensor_(minall)(tensor); + maxval = THTensor_(maxall)(tensor); + } + if (minval == maxval) + { + minval = minval - 1; + maxval = maxval + 1; + } + + h_data = THTensor_(data)(hist); + + TH_TENSOR_APPLY(real, tensor, + if (*tensor_data >= minval && *tensor_data <= maxval) { + const int bin = (int)((*tensor_data-minval) / (maxval-minval) * nbins); + h_data[THMin(bin, nbins-1)] += 1; + } + ); +} + +void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, int64_t nbins, real minvalue, real maxvalue) +{ + THArgCheck(THTensor_(_nDimension)(tensor) < 3, 2, "invalid dimension %d, the input must be a 2d tensor", THTensor_(_nDimension)(tensor)); + + int dimension = 1; + THArgCheck(dimension >= 0 && dimension < THTensor_(_nDimension)(tensor), 2, "invalid dimension %d", + dimension + TH_INDEX_BASE); + + real minval; + real maxval; + + THTensor_(resize2d)(hist, tensor->size(0), nbins); + THTensor_(zero)(hist); + + minval = minvalue; + maxval = maxvalue; + if (minval == maxval) + { + minval = THTensor_(minall)(tensor); + maxval = THTensor_(maxall)(tensor); + } + if (minval == maxval) + { + minval = minval - 1; + maxval = maxval + 1; + } + + TH_TENSOR_DIM_APPLY2(real, tensor, real, hist, dimension, int64_t i; + for(i = 0; i < tensor_size; i++) + { + if(tensor_data[i*tensor_stride] >= minval && tensor_data[i*tensor_stride] <= maxval) { + const int bin = (int)((tensor_data[i*tensor_stride]-minval) / (maxval-minval) * nbins); + hist_data[THMin(bin, nbins-1)] += 1; + } + } + ); +} + +// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. +// Assumes x is close to zero and uses a Taylor expansion. +static inline real THTensor_(beta_grad_alpha_small)(real x, real alpha, real beta) { + const real factor = TH_MATH_NAME(TH_digamma)(alpha) - TH_MATH_NAME(TH_digamma)(alpha + beta) - TH_MATH_NAME(log)(x); + real numer = 1; + real series = numer / alpha * (factor + 1 / alpha); + for (int i = 1; i <= 10; ++i) { + numer *= (i - beta) * x / i; + const real denom = alpha + i; + series += numer / denom * (factor + 1 / denom); + } + const real result = x * TH_MATH_NAME(pow)(1 - x, -beta) * series; + return th_isnan(result) ? 0.0 : result; +} + +// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta. +// Assumes x is close to zero and uses a Taylor expansion. +static inline real THTensor_(beta_grad_beta_small)(real x, real alpha, real beta) { + const real factor = TH_MATH_NAME(TH_digamma)(alpha+beta) - TH_MATH_NAME(TH_digamma)(beta); + real numer = 1; + real betas = 1; + real dbetas = 0; + real series = factor / alpha; + for (int i = 1; i <= 8; ++i) { + numer *= -x / i; + dbetas = dbetas * (beta - i) + betas; + betas = betas * (beta - i); + series += numer / (alpha + i) * (dbetas + factor * betas); + } + const real result = -TH_MATH_NAME(pow)(1 - x, 1 - beta) * series; + return th_isnan(result) ? 
0.0 : result; +} + +// Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha. +// Assumes alpha and beta are both large and uses a Rice saddle point expansion. +// To ensure numerical stability, this computation is performed at higher precision. +static inline real THTensor_(beta_grad_alpha_mid)(double x, double alpha, double beta) { + const double total = alpha + beta; + const double mean = alpha / total; + const double std = sqrt(alpha * beta / (total + 1)) / total; + if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) { + // Avoid the singularity at x = mean. + const double poly = 47 * x * (beta*beta)*(beta*beta) + alpha * ( + (43 + 20 * (16 + 27 * beta) * x) * (beta*beta)*beta + alpha * ( + 3 * (59 + 180 * beta - 90 * x) * (beta*beta) + alpha * ( + (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * ( + 8 * (1 - x) * (135 * beta - 11))))); + const double prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total); + const double prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total); + return prefactor_num / (1 - x) * poly / prefactor_den; + } + const double prefactor = -x / sqrt(2 * alpha * beta / total); + const double stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha*alpha)) + * (1 + 1 / (12 * beta) + 1 / (288 * beta*beta)) + / (1 + 1 / (12 * total) + 1 / (288 * total*total)); + const double term1_num = 2 * (alpha*alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta*beta); + const double axbx = alpha * (x-1) + beta * x; + const double term1_den = sqrt(2 * alpha / beta) * pow(total, 1.5f) * axbx*axbx; + const double term1 = term1_num / term1_den; + const double term2 = 0.5f * log(alpha / (total * x)); + const double term3_num = sqrt(8 * alpha * beta / total); + const double term3_den = beta * x + alpha * (x - 1); + const double term3 = term3_num / term3_den; + const double term4_base = beta * log(beta / (total * (1 - x))) + + alpha * log(alpha / (total * x)); + const double term4 = pow(term4_base, -1.5f); + const double term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4)); + return stirling * prefactor * term1234; +} + +// Computes a scaled reparameterized gradient +// -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x) +// for random number x drawn from a Beta distribution Beta(alpha,beta). +// This function inputs total=alpha+beta to make it easy to implement +// Dirichlet reparameterized gradients in terms of Betas. +static inline real THTensor_(dirichlet_grad_one)(real x, real alpha, real total) { + const real beta = total - alpha; + const real boundary = total * x * (1 - x); + + // Use an asymptotic approximation for x close to 0. + if (x <= 0.5f && boundary < 2.5f) { + return THTensor_(beta_grad_alpha_small)(x, alpha, beta); + } + + // Use an asymptotic approximation for x close to 1. + if (x >= 0.5f && boundary < 0.75f) { + return -THTensor_(beta_grad_beta_small)(1 - x, beta, alpha); + } + + // Use an asymptotic approximation when alpha and (total - alpha) are both large. + if (alpha > 6 && beta > 6) { + return THTensor_(beta_grad_alpha_mid)(x, alpha, beta); + } + + // Use a rational correction to an analytic approximation. 
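Each row of the coefficient table below is a cubic in b evaluated by Horner's rule, i.e. c0 + b*(c1 + b*(c2 + b*c3)). A one-line standalone form of that evaluation pattern (hypothetical helper, not part of this diff):

    // Horner evaluation of c[0] + c[1]*b + c[2]*b*b + c[3]*b*b*b.
    static double horner_cubic(const double c[4], double b) {
      return c[0] + b * (c[1] + b * (c[2] + b * c[3]));
    }

The double loop over i and j that follows sums pow_u[i] * pow_a[j] times such a cubic for both the numerator p and the denominator q.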
+ static const real c[2][3][3][4] = { + {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863}, + {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033}, + {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}}, + {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814}, + {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057}, + {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}}, + {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565}, + {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181}, + {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}}, + {{{1, -0.02924021934, -0.04438342661, 0.007285809825}, + {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521}, + {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}}, + {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273}, + {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956}, + {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}}, + {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05}, + {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05}, + {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}}, + }; + const real u = TH_MATH_NAME(log)(x); + const real a = TH_MATH_NAME(log)(alpha) - u; + const real b = TH_MATH_NAME(log)(total) - a; + const real pow_u[3] = {1, u, u * u}; + const real pow_a[3] = {1, a, a * a}; + real p = 0.0; + real q = 0.0; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + const real ua = pow_u[i] * pow_a[j]; + p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3]))); + q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3]))); + } + } + const real approx = x * (TH_MATH_NAME(TH_digamma)(total) - TH_MATH_NAME(TH_digamma)(alpha)) / beta; + return p / q * approx; +} + +void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alpha, THTensor *total) +{ + x = THTensor_(newContiguous)(x); + alpha = THTensor_(newContiguous)(alpha); + total = THTensor_(newContiguous)(total); + TH_CHECK_SAME_SIZE(alpha, x); + TH_CHECK_SAME_SIZE(total, x); + THTensor_(resizeAs)(self, x); + THTensor* grad = THTensor_(newContiguous)(self); + + real*const grad_data = THTensor_(data)(grad); + real*const x_data = THTensor_(data)(x); + real*const alpha_data = THTensor_(data)(alpha); + real*const total_data = THTensor_(data)(total); + const int64_t numel = THTensor_(nElement)(x); + int64_t i; + #pragma omp parallel for if(numel > TH_OMP_OVERHEAD_THRESHOLD) private(i) + for(i = 0; i < numel; ++i) { + grad_data[i] = THTensor_(dirichlet_grad_one)(x_data[i], alpha_data[i], total_data[i]); + } + + THTensor_(freeCopyTo)(grad, self); +} + +#undef TH_MATH_NAME +#endif /* floating point only part */ +#undef IS_NONZERO + +#endif /* TH_GENERIC_FILE */ diff --git a/aten/src/THC/THCTensor.cpp b/aten/src/THC/THCTensor.cpp index 6a599b3b655fd..cfa934800b9bd 100644 --- a/aten/src/THC/THCTensor.cpp +++ b/aten/src/THC/THCTensor.cpp @@ -222,7 +222,7 @@ void THCTensor_squeeze1d(THCState *state, THCTensor *self, THCTensor *src, int d THTensor_setSizeAtDim(self, d, self->size(d+1)); THTensor_setStrideAtDim(self, d, self->stride(d+1)); } - THTensor_resizeDim(self, self->dim_ - 1); + THTensor_resizeDim(self, self->dim() - 1); } } diff --git a/aten/src/THC/THCTensorSort.cu b/aten/src/THC/THCTensorSort.cu index ed1342f53b8e6..27c3b5106a619 100644 --- a/aten/src/THC/THCTensorSort.cu +++ 
b/aten/src/THC/THCTensorSort.cu @@ -3,60 +3,61 @@ void THCudaLongTensor_fillSliceWithIndex(THCState* state, THCudaLongTensor* t, int dim) { - int64_t dims = THCudaLongTensor__nDimension(state, t); + int64_t dims = THCudaLongTensor_nDimension(state, t); THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); ptrdiff_t inElements = THCudaLongTensor_nElement(state, t); - int64_t sliceSize = THCudaLongTensor_size(state, t, dim); - ptrdiff_t numSlices = inElements / sliceSize; + if (inElements > 0) { + int64_t sliceSize = THCudaLongTensor_size(state, t, dim); + ptrdiff_t numSlices = inElements == 0 ? 0 : inElements / sliceSize; - dim3 grid; - if (!THC_getGridFromTiles(numSlices, grid)) { - THError("Slice to fill with indices is too large"); - } + dim3 grid; + if (!THC_getGridFromTiles(numSlices, grid)) { + THError("Slice to fill with indices is too large"); + } - int64_t maxThreads = - THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock; - int64_t numThreads = sliceSize; - if (numThreads > maxThreads) { - numThreads = maxThreads; - } + int64_t maxThreads = + THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock; + int64_t numThreads = sliceSize; + if (numThreads > maxThreads) { + numThreads = maxThreads; + } - dim3 block(numThreads); + dim3 block(numThreads); -#define FILL_INDEX(T, DIM) \ - fillSliceWithIndex \ - <<>>( \ - info, numSlices, sliceSize, info.strides[collapseDim]) +#define FILL_INDEX(T, DIM) \ + fillSliceWithIndex \ + <<>>( \ + info, numSlices, sliceSize, info.strides[collapseDim]) - if (THCTensor_canUse32BitIndexMath(state, t)) { - TensorInfo info = - getTensorInfo(state, t); - info.reduceDim(dim); - int collapseDim = info.collapseDims(dim); + if (THCTensor_canUse32BitIndexMath(state, t)) { + TensorInfo info = + getTensorInfo(state, t); + info.reduceDim(dim); + int collapseDim = info.collapseDims(dim); - if (info.isContiguous()) { - FILL_INDEX(unsigned int, -2); - } else { - if (info.dims == 1) { - FILL_INDEX(unsigned int, 1); - } else if (info.dims == 2) { - FILL_INDEX(unsigned int, 2); + if (info.isContiguous()) { + FILL_INDEX(unsigned int, -2); } else { - FILL_INDEX(unsigned int, -1); + if (info.dims == 1) { + FILL_INDEX(unsigned int, 1); + } else if (info.dims == 2) { + FILL_INDEX(unsigned int, 2); + } else { + FILL_INDEX(unsigned int, -1); + } } + } else { + TensorInfo info = + getTensorInfo(state, t); + info.reduceDim(dim); + int collapseDim = info.collapseDims(dim); + + // catch-all implementation + FILL_INDEX(uint64_t, -1); } - } else { - TensorInfo info = - getTensorInfo(state, t); - info.reduceDim(dim); - int collapseDim = info.collapseDims(dim); - - // catch-all implementation - FILL_INDEX(uint64_t, -1); - } #undef FILL_INDEX - - THCudaCheck(cudaGetLastError()); + THCudaCheck(cudaGetLastError()); + } } diff --git a/aten/src/THC/generic/THCTensor.cpp b/aten/src/THC/generic/THCTensor.cpp index 023a55d176be8..de5dc476deacf 100644 --- a/aten/src/THC/generic/THCTensor.cpp +++ b/aten/src/THC/generic/THCTensor.cpp @@ -407,7 +407,7 @@ void THCTensor_(select)(THCState *state, THCTensor *self, THCTensor *src, int di THTensor_setSizeAtDim(self, d, self->size(d+1)); THTensor_setStrideAtDim(self, d, self->stride(d+1)); } - THTensor_resizeDim(self, self->dim_ - 1); + THTensor_resizeDim(self, self->dim() - 1); } void THCTensor_(transpose)(THCState *state, THCTensor *self, THCTensor *src, int dimension1, int dimension2) diff --git a/aten/src/THC/generic/THCTensorIndex.cu b/aten/src/THC/generic/THCTensorIndex.cu index 1d934595aabd2..6e04953a3d1d2 100644 --- 
a/aten/src/THC/generic/THCTensorIndex.cu +++ b/aten/src/THC/generic/THCTensorIndex.cu @@ -9,10 +9,10 @@ static ptrdiff_t THCTensor_(getSliceSize)(THCState *state, THCTensor *dst, THCudaLongTensor *index, THCTensor *src) { - int dstDims = THCTensor_(_nDimension)(state, dst); - int srcDims = (src == nullptr) ? dstDims : THCTensor_(_nDimension)(state, src); + int dstDims = THCTensor_(nDimension)(state, dst); + int srcDims = (src == nullptr) ? dstDims : THCTensor_(nDimension)(state, src); - THArgCheck(THCudaLongTensor__nDimension(state, index) == 1, 4, + THArgCheck(THCudaLongTensor_nDimension(state, index) == 1, 4, "expecting vector of indices"); THArgCheck(dim >= 0 && dim < dstDims, 2, "Indexing dim is out of bounds"); @@ -97,11 +97,11 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices)); - int dims = THCTensor_(_nDimension)(state, dst); + int dims = THCTensor_(nDimension)(state, dst); THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - dims = THCTensor_(_nDimension)(state, src); + dims = THCTensor_(nDimension)(state, src); THArgCheck(dims <= MAX_CUTORCH_DIMS, 5, CUTORCH_DIM_WARNING); - dims = THCudaLongTensor__nDimension(state, indices); + dims = THCudaLongTensor_nDimension(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING); // The `src` is partitioned into two parts: @@ -112,8 +112,12 @@ void THCTensor_(indexCopy)(THCState *state, THCTensor *dst, int dim, THCudaLongT ptrdiff_t sliceSize = THCTensor_(getSliceSize)(state, dst, dim, indices, src); ptrdiff_t srcTotalSize = THCTensor_(nElement)(state, src); int64_t dstCopyDimSize = THCTensor_(size)(state, dst, dim); - ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices); + + if (sliceSize == 0) { + return; + } + cudaStream_t stream = THCState_getCurrentStream(state); int indContig = THCudaLongTensor_isContiguous(state, indices); @@ -282,11 +286,11 @@ void THCTensor_(indexAdd)(THCState *state, THCTensor *dst, int dim, THCudaLongTe THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices)); - int dims = THCTensor_(_nDimension)(state, dst); + int dims = THCTensor_(nDimension)(state, dst); THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - dims = THCTensor_(_nDimension)(state, src); + dims = THCTensor_(nDimension)(state, src); THArgCheck(dims <= MAX_CUTORCH_DIMS, 5, CUTORCH_DIM_WARNING); - dims = THCudaLongTensor__nDimension(state, indices); + dims = THCudaLongTensor_nDimension(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING); // The `src` is partitioned into two parts: @@ -297,8 +301,11 @@ void THCTensor_(indexAdd)(THCState *state, THCTensor *dst, int dim, THCudaLongTe ptrdiff_t sliceSize = THCTensor_(getSliceSize)(state, dst, dim, indices, src); ptrdiff_t srcTotalSize = THCTensor_(nElement)(state, src); int64_t dstAddDimSize = THCTensor_(size)(state, dst, dim); - ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices); + + if (sliceSize == 0) { + return; + } cudaStream_t stream = THCState_getCurrentStream(state); int indContig = THCudaLongTensor_isContiguous(state, indices); @@ -402,9 +409,9 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, dst)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices)); - int dims = 
THCTensor_(_nDimension)(state, dst); + int dims = THCTensor_(nDimension)(state, dst); THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - dims = THCudaLongTensor__nDimension(state, indices); + dims = THCudaLongTensor_nDimension(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING); // The `src` is partitioned into two parts: @@ -416,8 +423,11 @@ void THCTensor_(indexFill)(THCState *state, THCTensor *dst, int dim, THCudaLongT THCTensor_(getSliceSize)(state, dst, dim, indices, nullptr); ptrdiff_t dstTotalSize = THCTensor_(nElement)(state, dst); int64_t dstFillDimSize = THCTensor_(size)(state, dst, dim); - ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices); + + if (sliceSize == 0) { + return; + } cudaStream_t stream = THCState_getCurrentStream(state); int indContig = THCudaLongTensor_isContiguous(state, indices); @@ -508,25 +518,26 @@ void THCTensor_(indexSelect)(THCState *state, THCTensor *dst, THCTensor *src, in { THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, dst, src, indices)); - int dims = THCTensor_(_nDimension)(state, dst); + int dims = THCTensor_(nDimension)(state, dst); THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - dims = THCTensor_(_nDimension)(state, src); + dims = THCTensor_(nDimension)(state, src); THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING); - dims = THCudaLongTensor__nDimension(state, indices); + dims = THCudaLongTensor_nDimension(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 5, CUTORCH_DIM_WARNING); ptrdiff_t numIndices = THCudaLongTensor_nElement(state, indices); - int srcDims = THCTensor_(_nDimension)(state, src); + int srcDims = THCTensor_(nDimension)(state, src); cudaStream_t stream = THCState_getCurrentStream(state); - THArgCheck(THCudaLongTensor__nDimension(state, indices) <= 1, 3, + THArgCheck(THCudaLongTensor_nDimension(state, indices) <= 1, 3, "Index is supposed to be an empty tensor or a vector"); THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds"); THArgCheck(srcDims > 0, 2, "Source tensor is empty"); THLongStorage *newSize; +#ifndef USE_TH_SIZE_ZERO_DIM if (numIndices == 0) { newSize = THCTensor_(newSizeOf)(state, src); THLongStorage_set(newSize, 0, numIndices); @@ -534,12 +545,18 @@ void THCTensor_(indexSelect)(THCState *state, THCTensor *dst, THCTensor *src, in THLongStorage_free(newSize); return; } +#endif newSize = THCTensor_(newSizeOf)(state, src); THLongStorage_set(newSize, dim, numIndices); THCTensor_(resize)(state, dst, newSize, NULL); THLongStorage_free(newSize); + ptrdiff_t dstTotalSize = THCTensor_(nElement)(state, dst); + if (dstTotalSize == 0) { + return; + } + int indContig = THCudaLongTensor_isContiguous(state, indices); // The `src` is partitioned into two parts: @@ -547,7 +564,6 @@ void THCTensor_(indexSelect)(THCState *state, THCTensor *dst, THCTensor *src, in // total size of the tensor ignoring dimension `dim`; // -the number of indices we are choosing, which is the total size // of the tensor `indices`. 
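Note on the sliceSize == 0 and dstTotalSize == 0 early returns above: they guard the subsequent kernel launches against empty tensors, where the computed grid or block size would be zero (an invalid CUDA launch configuration) and the slice-size division would be meaningless. A minimal host-side sketch of the pattern, using a hypothetical launch_fill_slice helper rather than the actual kernels:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical launch-parameter computation mirroring the guards above: skip
    // the launch entirely when the tensor is empty, otherwise derive grid/block
    // sizes from the slice layout.
    void launch_fill_slice(int64_t numel, int64_t slice_size, int64_t max_threads) {
      if (numel == 0) {
        return;  // empty tensor: nothing to do, and the grid would be zero-sized
      }
      const int64_t num_slices = numel / slice_size;  // slice_size > 0 when numel > 0
      const int64_t num_threads = slice_size < max_threads ? slice_size : max_threads;
      std::printf("grid=%lld block=%lld\n",
                  static_cast<long long>(num_slices),
                  static_cast<long long>(num_threads));
      // real code: kernel<<<num_slices, num_threads, 0, stream>>>(...);
    }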
- ptrdiff_t dstTotalSize = THCTensor_(nElement)(state, dst); int64_t srcSelectDimSize = THCTensor_(size)(state, src, dim); ptrdiff_t sliceSize = dstTotalSize / numIndices; diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu index babbd6d24eb61..4ebf6d93e8067 100644 --- a/aten/src/THC/generic/THCTensorMathBlas.cu +++ b/aten/src/THC/generic/THCTensorMathBlas.cu @@ -50,7 +50,8 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, r_, t, mat, vec)); if( (mat->dim() != 2) || (vec->dim() != 1) ) - THError("matrix and vector expected"); + THError("2D tensor and 1D tensor expected, got %dD, %dD tensors", + mat->dim(), vec->dim()); if( mat->size(1) != vec->size(0) ) THError("size mismatch"); @@ -151,7 +152,8 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, r_, t, vec1, vec2)); if ( (vec1->dim() != 1) || (vec2->dim() != 1) ) { - THError("vector and vector expected"); + THError("1D tensors expected, got %dD, %dD tensors", + vec1->dim(), vec2->dim()); } if (t->dim() != 2) { @@ -248,10 +250,10 @@ THCTensor_(addmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real THCTensor *r__, *m1_, *m2_; if( (m1->dim() != 2) || (m2->dim() != 2) ) - THError("matrices expected, got %dD, %dD tensors", m1->dim(), m2->dim()); + THError("2D tensors expected, got %dD, %dD tensors", m1->dim(), m2->dim()); if(t->dim() != 2) - THError("matrix expected, got %dD tensor for t", t->dim()); + THError("2D tensor expected, got %dD tensor for t", t->dim()); if(m1->size(1) != m2->size(0)) { THCDescBuff bm1 = THCTensor_(sizeDesc)(state, m1); diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu index 1c9d9eac6ac60..0563a103f73ef 100644 --- a/aten/src/THC/generic/THCTensorMathReduce.cu +++ b/aten/src/THC/generic/THCTensorMathReduce.cu @@ -61,22 +61,25 @@ THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, real value, THCTensor *self_; THCTensor *src_ = THCTensor_(newTranspose)(state, src, dimension, 0); THCTensor *data = THCTensor_(newClone)(state, src_); - ptrdiff_t size = THCTensor_(nElement)(state, data)/data->size(0); + int64_t numel = THCTensor_(nElement)(state, data); - THArgCheck(dimension >= 0 && dimension < THCTensor_(_nDimension)(state, src), 3, "invalid dimension"); + THArgCheck(dimension >= 0 && dimension < THCTensor_(nDimension)(state, src), 3, "invalid dimension"); THArgCheck(THCNumerics::gt(value, scalar_cast(0)), 2, "non-positive-norm not supported"); - THArgCheck(THCTensor_(_nDimension)(state, src) > 1, 1, "need at least 2 dimensions"); + THArgCheck(THCTensor_(nDimension)(state, src) > 1, 1, "need at least 2 dimensions"); - dim3 grid(data->size(0)); - dim3 threads(32); + if (numel > 0) { + ptrdiff_t size = numel / data->size(0); + dim3 grid(data->size(0)); + dim3 threads(32); - THCTensor_kernel_renorm - <<>> - (THCTensor_(data)(state, data), scalar_cast(value), size, scalar_cast(maxnorm)); + THCTensor_kernel_renorm + <<>> + (THCTensor_(data)(state, data), scalar_cast(value), size, scalar_cast(maxnorm)); - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) - THError(cudaGetErrorString(errcode)); + cudaError errcode = cudaGetLastError(); + 
if(errcode != cudaSuccess) + THError(cudaGetErrorString(errcode)); + } THCTensor_(free)(state, src_); self_ = THCTensor_(newTranspose)(state, data, dimension, 0); diff --git a/aten/src/THC/generic/THCTensorScatterGather.cu b/aten/src/THC/generic/THCTensorScatterGather.cu index f04ae5a557b95..a050e1975edfe 100644 --- a/aten/src/THC/generic/THCTensorScatterGather.cu +++ b/aten/src/THC/generic/THCTensorScatterGather.cu @@ -12,25 +12,25 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index)); - THArgCheck(THCudaLongTensor__nDimension(state, index) == THCTensor_(_nDimension)(state, src), 4, + THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 4, "Index tensor must have same dimensions as input tensor"); THLongStorage *indexSize = THCudaLongTensor_newSizeOf(state, index); THArgCheck(THCTensor_(isSize)(state, tensor, indexSize), 4, "Index tensor must have the same size as output tensor."); THLongStorage_free(indexSize); - THArgCheck(dim >= 0 && dim < THCTensor_(_nDimension)(state, tensor), 3, + THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 3, "Index dimension is out of bounds"); - THArgCheck(THCTensor_(_nDimension)(state, src) == THCTensor_(_nDimension)(state, tensor), 2, + THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 2, "Input tensor must have same dimensions as output tensor"); - for (int d = 0; d < THCTensor_(_nDimension)(state, tensor); d++) { + for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) { if (d != dim) { THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 2, "Input tensor must have same size as output tensor apart from the specified dimension"); } } - THArgCheck(THCTensor_(_nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, + THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, 1, CUTORCH_DIM_WARNING); @@ -47,44 +47,46 @@ void THCTensor_(gather)(THCState* state, THCTensor *tensor, tensor = THCTensor_(newContiguous)(state, tensor); } - if (THCTensor_canUse32BitIndexMath(state, tensor) && - THCTensor_canUse32BitIndexMath(state, src) && - THCTensor_canUse32BitIndexMath(state, index)) { - TensorInfo tensorInfo = - getTensorInfo(state, tensor); - TensorInfo srcInfo = - getTensorInfo(state, src); - TensorInfo indexInfo = - getTensorInfo(state, index); - - // Specialize for a small number of dimensions. - switch (indexInfo.dims) { - case 1: - RUN(unsigned int, 1, real); - THCudaCheck(cudaGetLastError()); - break; - case 2: - RUN(unsigned int, 2, real); - THCudaCheck(cudaGetLastError()); - break; - case 3: - RUN(unsigned int, 3, real); - THCudaCheck(cudaGetLastError()); - break; - default: - RUN(unsigned int, -1, real); - THCudaCheck(cudaGetLastError()); - break; + if (totalElements > 0) { + if (THCTensor_canUse32BitIndexMath(state, tensor) && + THCTensor_canUse32BitIndexMath(state, src) && + THCTensor_canUse32BitIndexMath(state, index)) { + TensorInfo tensorInfo = + getTensorInfo(state, tensor); + TensorInfo srcInfo = + getTensorInfo(state, src); + TensorInfo indexInfo = + getTensorInfo(state, index); + + // Specialize for a small number of dimensions. 
+ switch (indexInfo.dims) { + case 1: + RUN(unsigned int, 1, real); + THCudaCheck(cudaGetLastError()); + break; + case 2: + RUN(unsigned int, 2, real); + THCudaCheck(cudaGetLastError()); + break; + case 3: + RUN(unsigned int, 3, real); + THCudaCheck(cudaGetLastError()); + break; + default: + RUN(unsigned int, -1, real); + THCudaCheck(cudaGetLastError()); + break; + } + } else { + TensorInfo tensorInfo = + getTensorInfo(state, tensor); + TensorInfo srcInfo = + getTensorInfo(state, src); + TensorInfo indexInfo = + getTensorInfo(state, index); + RUN(uint64_t, -1, real); + THCudaCheck(cudaGetLastError()); } - } else { - TensorInfo tensorInfo = - getTensorInfo(state, tensor); - TensorInfo srcInfo = - getTensorInfo(state, src); - TensorInfo indexInfo = - getTensorInfo(state, index); - RUN(uint64_t, -1, real); - THCudaCheck(cudaGetLastError()); } if (oldTensor) { @@ -107,14 +109,14 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index)); - THArgCheck(dim >= 0 && dim < THCTensor_(_nDimension)(state, tensor), 2, + THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2, "Index dimension is out of bounds"); - THArgCheck(THCudaLongTensor__nDimension(state, index) == THCTensor_(_nDimension)(state, src), 3, + THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3, "Index tensor must have same dimensions as input tensor"); - THArgCheck(THCTensor_(_nDimension)(state, src) == THCTensor_(_nDimension)(state, tensor), 4, + THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4, "Input tensor must have same dimensions as output tensor"); - for (int d = 0; d < THCTensor_(_nDimension)(state, tensor); d++) { + for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) { int64_t indexSizeD = THCudaLongTensor_size(state, index, d); if (d != dim) { THArgCheck(indexSizeD <= THCTensor_(size)(state, tensor, d), 3, @@ -126,7 +128,7 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong THCudaLongTensor_sizeDesc(state, index).str, THCTensor_(sizeDesc)(state, src).str); } - THArgCheck(THCTensor_(_nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, + THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, 1, CUTORCH_DIM_WARNING); const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index); @@ -142,40 +144,42 @@ void THCTensor_(scatter)(THCState* state, THCTensor *tensor, int dim, THCudaLong tensor = THCTensor_(newContiguous)(state, tensor); } - if (THCTensor_canUse32BitIndexMath(state, tensor) && - THCTensor_canUse32BitIndexMath(state, src) && - THCTensor_canUse32BitIndexMath(state, index)) { - TensorInfo tensorInfo = - getTensorInfo(state, tensor); - TensorInfo srcInfo = - getTensorInfo(state, src); - TensorInfo indexInfo = - getTensorInfo(state, index); - - // Specialize for a small number of dimensions. 
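Each gather/scatter variant above dispatches the same way: when every tensor can be addressed with 32-bit arithmetic, a 32-bit kernel specialized on 1, 2 or 3 collapsed dimensions is launched (with -1 as the generic instantiation); otherwise a 64-bit catch-all is used, and with this change the whole dispatch is skipped for an empty index tensor. A compact sketch of that shape, with a hypothetical run_kernel standing in for the RUN(TYPE, DIMS, real) macro:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for RUN(TYPE, DIMS, real); the real macro launches a
    // CUDA kernel template instantiated on the index type and dimensionality.
    template <typename IndexT, int Dims>
    void run_kernel(int64_t numel) {
      std::printf("numel=%lld index_width=%zu dims=%d\n",
                  static_cast<long long>(numel), sizeof(IndexT), Dims);
    }

    template <typename IndexT>
    void dispatch_on_dims(int dims, int64_t numel) {
      switch (dims) {  // specialize for a small number of dimensions
        case 1:  run_kernel<IndexT, 1>(numel);  break;
        case 2:  run_kernel<IndexT, 2>(numel);  break;
        case 3:  run_kernel<IndexT, 3>(numel);  break;
        default: run_kernel<IndexT, -1>(numel); break;  // generic fallback
      }
    }

    void dispatch(int dims, int64_t numel, bool fits_in_32_bits) {
      if (numel == 0) {
        return;  // new outer guard: nothing to launch for empty index tensors
      }
      if (fits_in_32_bits) {
        dispatch_on_dims<uint32_t>(dims, numel);
      } else {
        dispatch_on_dims<uint64_t>(dims, numel);  // 64-bit catch-all
      }
    }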
- switch (indexInfo.dims) { - case 1: - RUN(unsigned int, 1, real); - break; - case 2: - RUN(unsigned int, 2, real); - break; - case 3: - RUN(unsigned int, 3, real); - break; - default: - RUN(unsigned int, -1, real); - break; + if (totalElements > 0) { + if (THCTensor_canUse32BitIndexMath(state, tensor) && + THCTensor_canUse32BitIndexMath(state, src) && + THCTensor_canUse32BitIndexMath(state, index)) { + TensorInfo tensorInfo = + getTensorInfo(state, tensor); + TensorInfo srcInfo = + getTensorInfo(state, src); + TensorInfo indexInfo = + getTensorInfo(state, index); + + // Specialize for a small number of dimensions. + switch (indexInfo.dims) { + case 1: + RUN(unsigned int, 1, real); + break; + case 2: + RUN(unsigned int, 2, real); + break; + case 3: + RUN(unsigned int, 3, real); + break; + default: + RUN(unsigned int, -1, real); + break; + } + } else { + TensorInfo tensorInfo = + getTensorInfo(state, tensor); + TensorInfo srcInfo = + getTensorInfo(state, src); + TensorInfo indexInfo = + getTensorInfo(state, index); + + RUN(uint64_t, -1, real) } - } else { - TensorInfo tensorInfo = - getTensorInfo(state, tensor); - TensorInfo srcInfo = - getTensorInfo(state, src); - TensorInfo indexInfo = - getTensorInfo(state, index); - - RUN(uint64_t, -1, real) } if (oldTensor) { @@ -197,25 +201,30 @@ void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaL THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, tensor, src)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index)); - THArgCheck(dim >= 0 && dim < THCTensor_(_nDimension)(state, tensor), 2, + THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2, "Index dimension is out of bounds"); - THArgCheck(THCudaLongTensor__nDimension(state, index) == THCTensor_(_nDimension)(state, src), 3, + THArgCheck(THCudaLongTensor_nDimension(state, index) == THCTensor_(nDimension)(state, src), 3, "Index tensor must have same dimensions as input tensor"); - THArgCheck(THCTensor_(_nDimension)(state, src) == THCTensor_(_nDimension)(state, tensor), 4, + THArgCheck(THCTensor_(nDimension)(state, src) == THCTensor_(nDimension)(state, tensor), 4, "Input tensor must have same dimensions as output tensor"); - THLongStorage *indexDims = THCudaLongTensor_newSizeOf(state, index); - THArgCheck(THCTensor_(isSize)(state, src, indexDims), 3, - "Index tensor must have the same size as input tensor."); - THLongStorage_free(indexDims); - for (int d = 0; d < THCTensor_(_nDimension)(state, tensor); d++) { + for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) { if (d != dim) { THArgCheck(THCTensor_(size)(state, tensor, d) == THCTensor_(size)(state, src, d), 4, "Input tensor must have same size as output tensor apart from the specified dimension"); } + int64_t indexSizeD = THCudaLongTensor_size(state, index, d); + if (d != dim) { + THArgCheck(indexSizeD <= THCTensor_(size)(state, tensor, d), 3, + "Index tensor must not have larger size than output tensor apart from the specified dimension %d, but got index %s output %s", + dim, THCudaLongTensor_sizeDesc(state, index).str, THCTensor_(sizeDesc)(state, tensor).str); + } + THArgCheck(indexSizeD <= THCTensor_(size)(state, src, d), 3, + "Index tensor must not have larger size than input tensor, but got index %s input %s", + THCudaLongTensor_sizeDesc(state, index).str, THCTensor_(sizeDesc)(state, src).str); } - THArgCheck(THCTensor_(_nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, + THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, 1, CUTORCH_DIM_WARNING); const ptrdiff_t 
totalElements = THCudaLongTensor_nElement(state, index); @@ -231,40 +240,42 @@ void THCTensor_(scatterAdd)(THCState* state, THCTensor *tensor, int dim, THCudaL tensor = THCTensor_(newContiguous)(state, tensor); } - if (THCTensor_canUse32BitIndexMath(state, tensor) && - THCTensor_canUse32BitIndexMath(state, src) && - THCTensor_canUse32BitIndexMath(state, index)) { - TensorInfo tensorInfo = - getTensorInfo(state, tensor); - TensorInfo srcInfo = - getTensorInfo(state, src); - TensorInfo indexInfo = - getTensorInfo(state, index); - - // Specialize for a small number of dimensions. - switch (indexInfo.dims) { - case 1: - RUN(unsigned int, 1, real); - break; - case 2: - RUN(unsigned int, 2, real); - break; - case 3: - RUN(unsigned int, 3, real); - break; - default: - RUN(unsigned int, -1, real); - break; + if (totalElements > 0) { + if (THCTensor_canUse32BitIndexMath(state, tensor) && + THCTensor_canUse32BitIndexMath(state, src) && + THCTensor_canUse32BitIndexMath(state, index)) { + TensorInfo tensorInfo = + getTensorInfo(state, tensor); + TensorInfo srcInfo = + getTensorInfo(state, src); + TensorInfo indexInfo = + getTensorInfo(state, index); + + // Specialize for a small number of dimensions. + switch (indexInfo.dims) { + case 1: + RUN(unsigned int, 1, real); + break; + case 2: + RUN(unsigned int, 2, real); + break; + case 3: + RUN(unsigned int, 3, real); + break; + default: + RUN(unsigned int, -1, real); + break; + } + } else { + TensorInfo tensorInfo = + getTensorInfo(state, tensor); + TensorInfo srcInfo = + getTensorInfo(state, src); + TensorInfo indexInfo = + getTensorInfo(state, index); + + RUN(uint64_t, -1, real) } - } else { - TensorInfo tensorInfo = - getTensorInfo(state, tensor); - TensorInfo srcInfo = - getTensorInfo(state, src); - TensorInfo indexInfo = - getTensorInfo(state, index); - - RUN(uint64_t, -1, real) } if (oldTensor) { @@ -288,13 +299,13 @@ THCTensor_(scatterFill)(THCState* state, THCTensor *tensor, THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, tensor)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index)); - THArgCheck(dim >= 0 && dim < THCTensor_(_nDimension)(state, tensor), 2, + THArgCheck(dim >= 0 && dim < THCTensor_(nDimension)(state, tensor), 2, "Index dimension is out of bounds"); - THArgCheck(THCudaLongTensor__nDimension(state, index) == - THCTensor_(_nDimension)(state, tensor), 3, + THArgCheck(THCudaLongTensor_nDimension(state, index) == + THCTensor_(nDimension)(state, tensor), 3, "Index tensor must have same dimensions as output tensor"); - for (int d = 0; d < THCTensor_(_nDimension)(state, tensor); d++) { + for (int d = 0; d < THCTensor_(nDimension)(state, tensor); d++) { if (d != dim) { THArgCheck(THCTensor_(size)(state, tensor, d) == THCudaLongTensor_size(state, index, d), 4, @@ -302,7 +313,7 @@ THCTensor_(scatterFill)(THCState* state, THCTensor *tensor, } } - THArgCheck(THCTensor_(_nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, + THArgCheck(THCTensor_(nDimension)(state, tensor) <= MAX_CUTORCH_DIMS, 1, CUTORCH_DIM_WARNING); const ptrdiff_t totalElements = THCudaLongTensor_nElement(state, index); diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu index a97d19b96bcc8..c6e95def0a4d2 100644 --- a/aten/src/THC/generic/THCTensorSort.cu +++ b/aten/src/THC/generic/THCTensorSort.cu @@ -13,20 +13,21 @@ THC_API void THCTensor_(sortKeyValueInplace)(THCState* state, THArgCheck(THCTensor_(isSize)(state, key, valueSize), 2, "Key tensor must have same size as value tensor"); THLongStorage_free(valueSize); - int dims = 
THCudaLongTensor__nDimension(state, value); + int dims = THCudaLongTensor_nDimension(state, value); THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING); - dims = THCTensor_(_nDimension)(state, key); + dims = THCTensor_(nDimension)(state, key); THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); ptrdiff_t inElements = THCTensor_(nElement)(state, key); - int64_t keySliceSize = THCTensor_(size)(state, key, dim); - ptrdiff_t keySlices = inElements / keySliceSize; - if (THCTensor_(_nDimension)(state, key) == 0) { + if (inElements == 0) { // Zero-dim tensor; do nothing return; } + int64_t keySliceSize = THCTensor_(size)(state, key, dim); + ptrdiff_t keySlices = inElements / keySliceSize; + // The amount of shared memory and block size is based on // 2^ceil(lg(n)); we choose that sorting implementation for a given // size. @@ -283,11 +284,11 @@ THC_API void THCTensor_(sort)(THCState* state, int dim, int order) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, sorted, input)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, indices)); - int64_t dims = THCTensor_(_nDimension)(state, sorted); + int64_t dims = THCTensor_(nDimension)(state, sorted); THArgCheck(dims <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - dims = THCTensor_(_nDimension)(state, input); + dims = THCTensor_(nDimension)(state, input); THArgCheck(dims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING); - dims = THCudaLongTensor__nDimension(state, indices); + dims = THCudaLongTensor_nDimension(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING); // Make sure sufficient output space is allocated diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index c2f3a285fecb6..de120e8c304ab 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -9,16 +9,16 @@ THC_API void THCTensor_(topk)(THCState* state, int64_t k, int dim, int dir, int sorted) { THAssert(topK != NULL && indices != NULL && input_ != NULL); THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, topK, indices, input_)); - THArgCheck(THCTensor_(_nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); - int64_t dims = THCudaLongTensor__nDimension(state, indices); + THArgCheck(THCTensor_(nDimension)(state, topK) <= MAX_CUTORCH_DIMS, 2, CUTORCH_DIM_WARNING); + int64_t dims = THCudaLongTensor_nDimension(state, indices); THArgCheck(dims <= MAX_CUTORCH_DIMS, 3, CUTORCH_DIM_WARNING); - int numDims = THCTensor_(_nDimension)(state, input_); + int numDims = THCTensor_(nDimension)(state, input_); THArgCheck(numDims <= MAX_CUTORCH_DIMS, 4, CUTORCH_DIM_WARNING); THArgCheck(dim >= 0 && dim < numDims, 6, "dim not in range"); int64_t sliceSize = THCTensor_(size)(state, input_, dim); - THArgCheck(k > 0 && k <= sliceSize, 5, "k not in range for dimension"); + THArgCheck(k >= 0 && k <= sliceSize, 5, "k not in range for dimension"); THCTensor *input = THCTensor_(newContiguous)(state, input_); @@ -110,14 +110,16 @@ THC_API void THCTensor_(topk)(THCState* state, \ RUN_DIM(INDEX_T); - // Based on required index size, run the algorithm with the - // appropriate index type - if (THCTensor_canUse32BitIndexMath(state, input) && - THCTensor_canUse32BitIndexMath(state, topK) && - THCTensor_canUse32BitIndexMath(state, indices)) { - RUN_T(uint32_t); - } else { - RUN_T(uint64_t); + if (THCTensor_nElement(state, input) > 0) { + // Based on required index size, run the algorithm with the + // appropriate index type + if (THCTensor_canUse32BitIndexMath(state, input) && + 
THCTensor_canUse32BitIndexMath(state, topK) && + THCTensor_canUse32BitIndexMath(state, indices)) { + RUN_T(uint32_t); + } else { + RUN_T(uint64_t); + } } #undef RUN_T #undef RUN_DIM diff --git a/caffe2/core/dispatch/DispatchTable.h b/caffe2/core/dispatch/DispatchTable.h index 0f119791dbfa0..223138bc12bf5 100644 --- a/caffe2/core/dispatch/DispatchTable.h +++ b/caffe2/core/dispatch/DispatchTable.h @@ -22,7 +22,7 @@ class ThreadsafeOperatorTable_ final { template void emplace(Key_&& key, void* value) { bool res = map_.write([&](ska::flat_hash_map& map) -> bool { - auto result = map->emplace(std::forward(key), value); + auto result = map.emplace(std::forward(key), value); return result.second; }); if (!res) { @@ -35,7 +35,7 @@ class ThreadsafeOperatorTable_ final { void erase(const Key& key) { auto num_removed = map_.write([&](ska::flat_hash_map& map) -> size_t { - return map->erase(key); + return map.erase(key); }); assert(num_removed <= 1); // This is not a multi-map if (num_removed == 0) { @@ -46,8 +46,8 @@ class ThreadsafeOperatorTable_ final { void* lookup(const Key& key) const { return map_.read([&](const ska::flat_hash_map& map) -> void* { - auto found = map->find(key); - if (found != map->end()) { + auto found = map.find(key); + if (found != map.end()) { return found->second; } else { return nullptr; diff --git a/caffe2/core/dispatch/LeftRight.h b/caffe2/core/dispatch/LeftRight.h index dc60a303c412c..a4f1dfde0f60d 100644 --- a/caffe2/core/dispatch/LeftRight.h +++ b/caffe2/core/dispatch/LeftRight.h @@ -17,7 +17,7 @@ class LeftRight { } template - auto read(F&& readFunc) -> typename std::result_of::type { + auto read(F&& readFunc) const -> typename std::result_of::type { auto localCounterIndex = counterIndex_.load(); ++counters_[localCounterIndex]; try { @@ -34,7 +34,7 @@ class LeftRight { template auto write(F&& writeFunc) -> typename std::result_of::type { std::unique_lock lock(mutex_); - uniqueWrite(std::forward(writeFunc)); + return uniqueWrite(std::forward(writeFunc)); } private: @@ -64,7 +64,7 @@ class LeftRight { std::mutex mutex_; std::atomic counterIndex_{0}; std::atomic dataIndex_{0}; - std::atomic counters_[2]; + mutable std::atomic counters_[2]; T data_[2]; }; diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpClasses.h b/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpClasses.h index 1e8156abe4217..70490856b5eca 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpClasses.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpClasses.h @@ -659,336 +659,3 @@ class NHWC2NCHW : public NeuralNetOperator { private: }; - -class Int8Quantize : public NeuralNetOperator { - public: - Int8Quantize() : NeuralNetOperator(NNKind::Int8Quantize) {} - - ~Int8Quantize() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Quantize); - - private: -}; - -class Int8Dequantize : public NeuralNetOperator { - public: - Int8Dequantize() : NeuralNetOperator(NNKind::Int8Dequantize) {} - - ~Int8Dequantize() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Dequantize); - - private: -}; - -class Int8AveragePool : public NeuralNetOperator { - public: - Int8AveragePool() : NeuralNetOperator(NNKind::Int8AveragePool) {} - - Int8AveragePool(const AveragePool& averagePool) - : NeuralNetOperator(NNKind::Int8AveragePool) {} - - ~Int8AveragePool() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8AveragePool); - - private: -}; - -class Int8Conv : public NeuralNetOperator { - public: - Int8Conv() : NeuralNetOperator(NNKind::Int8Conv) {} - - Int8Conv(const Conv& conv) : NeuralNetOperator(NNKind::Int8Conv) {} - - 
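The DispatchTable and LeftRight fixes a few hunks above address three related problems: the lambdas receive the hash map by reference, so member access uses '.' rather than '->'; write() must return the functor's result instead of silently discarding it; and read() becomes const, which in turn forces the reader counters to be mutable. A simplified sketch of that interface, using a single mutex in place of the real left-right synchronization:

    #include <atomic>
    #include <mutex>
    #include <utility>

    // Simplified sketch only: one mutex replaces the real left-right scheme, to
    // show the const-correctness and return-value fixes in isolation.
    template <class T>
    class GuardedValue {
     public:
      template <class F>
      auto read(F&& fn) const -> decltype(fn(std::declval<const T&>())) {
        ++readers_;  // mutable member: may be modified from a const method
        struct Dec {
          std::atomic<int>& c;
          ~Dec() { --c; }
        } dec{readers_};  // decrement even if fn throws
        return fn(data_);
      }

      template <class F>
      auto write(F&& fn) -> decltype(fn(std::declval<T&>())) {
        std::lock_guard<std::mutex> lock(mutex_);
        return fn(data_);  // forward the functor's result to the caller
      }

     private:
      T data_{};
      std::mutex mutex_;
      mutable std::atomic<int> readers_{0};
    };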
~Int8Conv() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Conv); - - private: -}; - -class Int8ConvTranspose : public NeuralNetOperator { - public: - Int8ConvTranspose() : NeuralNetOperator(NNKind::Int8ConvTranspose) {} - - Int8ConvTranspose(const ConvTranspose& convTranspose) - : NeuralNetOperator(NNKind::Int8ConvTranspose) {} - - ~Int8ConvTranspose() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8ConvTranspose); - - private: -}; - -class Int8FC : public NeuralNetOperator { - public: - Int8FC() : NeuralNetOperator(NNKind::Int8FC) {} - - Int8FC(const FC& fC) : NeuralNetOperator(NNKind::Int8FC) {} - - ~Int8FC() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8FC); - - private: -}; - -class Int8MaxPool : public NeuralNetOperator { - public: - Int8MaxPool() : NeuralNetOperator(NNKind::Int8MaxPool) {} - - Int8MaxPool(const MaxPool& maxPool) - : NeuralNetOperator(NNKind::Int8MaxPool) {} - - ~Int8MaxPool() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8MaxPool); - - private: -}; - -class Int8Relu : public NeuralNetOperator { - public: - Int8Relu() : NeuralNetOperator(NNKind::Int8Relu) {} - - Int8Relu(const Relu& relu) : NeuralNetOperator(NNKind::Int8Relu) {} - - ~Int8Relu() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Relu); - - private: -}; - -class Int8GivenTensorFill : public NeuralNetOperator { - public: - Int8GivenTensorFill() : NeuralNetOperator(NNKind::Int8GivenTensorFill) {} - - Int8GivenTensorFill(const GivenTensorFill& givenTensorFill) - : NeuralNetOperator(NNKind::Int8GivenTensorFill) {} - - ~Int8GivenTensorFill() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8GivenTensorFill); - - private: -}; - -class Int8Concat : public NeuralNetOperator { - public: - Int8Concat() : NeuralNetOperator(NNKind::Int8Concat) {} - - Int8Concat(const Concat& concat) : NeuralNetOperator(NNKind::Int8Concat) {} - - ~Int8Concat() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Concat); - - private: -}; - -class Int8Softmax : public NeuralNetOperator { - public: - Int8Softmax() : NeuralNetOperator(NNKind::Int8Softmax) {} - - Int8Softmax(const Softmax& softmax) - : NeuralNetOperator(NNKind::Int8Softmax) {} - - ~Int8Softmax() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Softmax); - - private: -}; - -class Int8ChannelShuffle : public NeuralNetOperator { - public: - Int8ChannelShuffle() : NeuralNetOperator(NNKind::Int8ChannelShuffle) {} - - Int8ChannelShuffle(const ChannelShuffle& channelShuffle) - : NeuralNetOperator(NNKind::Int8ChannelShuffle) {} - - ~Int8ChannelShuffle() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8ChannelShuffle); - - private: -}; - -class Int8Sum : public NeuralNetOperator { - public: - Int8Sum() : NeuralNetOperator(NNKind::Int8Sum) {} - - Int8Sum(const Sum& sum) : NeuralNetOperator(NNKind::Int8Sum) {} - - ~Int8Sum() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Sum); - - private: -}; - -class Int8Add : public NeuralNetOperator { - public: - Int8Add() : NeuralNetOperator(NNKind::Int8Add) {} - - Int8Add(const Add& add) : NeuralNetOperator(NNKind::Int8Add) {} - - ~Int8Add() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Add); - - private: -}; - -class Int8Reshape : public NeuralNetOperator { - public: - Int8Reshape() : NeuralNetOperator(NNKind::Int8Reshape) {} - - Int8Reshape(const Reshape& reshape) - : NeuralNetOperator(NNKind::Int8Reshape) {} - - ~Int8Reshape() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Reshape); - - private: -}; - -class Int8Flatten : public NeuralNetOperator { - public: - Int8Flatten() : NeuralNetOperator(NNKind::Int8Flatten) {} - - Int8Flatten(const Flatten& flatten) - : NeuralNetOperator(NNKind::Int8Flatten) {} - - ~Int8Flatten() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8Flatten); - - 
private: -}; - -class Int8ConvRelu : public NeuralNetOperator { - public: - Int8ConvRelu() : NeuralNetOperator(NNKind::Int8ConvRelu) {} - - Int8ConvRelu(const ConvRelu& convRelu) - : NeuralNetOperator(NNKind::Int8ConvRelu) {} - - ~Int8ConvRelu() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8ConvRelu); - - private: -}; - -class Int8SumRelu : public NeuralNetOperator { - public: - Int8SumRelu() : NeuralNetOperator(NNKind::Int8SumRelu) {} - - Int8SumRelu(const SumRelu& sumRelu) - : NeuralNetOperator(NNKind::Int8SumRelu) {} - - ~Int8SumRelu() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8SumRelu); - - private: -}; - -class Int8AveragePoolRelu : public NeuralNetOperator { - public: - Int8AveragePoolRelu() : NeuralNetOperator(NNKind::Int8AveragePoolRelu) {} - - Int8AveragePoolRelu(const AveragePoolRelu& averagePoolRelu) - : NeuralNetOperator(NNKind::Int8AveragePoolRelu) {} - - ~Int8AveragePoolRelu() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8AveragePoolRelu); - - private: -}; - -class Int8MaxPoolRelu : public NeuralNetOperator { - public: - Int8MaxPoolRelu() : NeuralNetOperator(NNKind::Int8MaxPoolRelu) {} - - Int8MaxPoolRelu(const MaxPoolRelu& maxPoolRelu) - : NeuralNetOperator(NNKind::Int8MaxPoolRelu) {} - - ~Int8MaxPoolRelu() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(Int8MaxPoolRelu); - - private: -}; - -class BatchMatMul : public NeuralNetOperator { - public: - BatchMatMul(bool transA = false, bool transB = true, bool broadcast = false) - : NeuralNetOperator(NNKind::BatchMatMul), - TransA(transA), - TransB(transB), - Broadcast(broadcast) {} - - ~BatchMatMul() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(BatchMatMul); - - bool getTransA() const { - return TransA; - } - - bool getTransB() const { - return TransB; - } - - bool getBroadcast() const { - return Broadcast; - } - - void setTransA(bool transA) { - TransA = transA; - } - - void setTransB(bool transB) { - TransB = transB; - } - - void setBroadcast(bool broadcast) { - Broadcast = broadcast; - } - - private: - bool TransA; - bool TransB; - bool Broadcast; -}; - -class BatchGather : public NeuralNetOperator { - public: - BatchGather() : NeuralNetOperator(NNKind::BatchGather) {} - - ~BatchGather() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(BatchGather); - - private: -}; - -class ConcatBatchMatMulBatchGatherOp : public NeuralNetOperator { - public: - ConcatBatchMatMulBatchGatherOp() - : NeuralNetOperator(NNKind::ConcatBatchMatMulBatchGatherOp) {} - - ~ConcatBatchMatMulBatchGatherOp() {} - - NOMNIGRAPH_DEFINE_NN_RTTI(ConcatBatchMatMulBatchGatherOp); - - private: -}; diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpEnum.h b/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpEnum.h index 9c4277293d0b4..4d15dd4061340 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpEnum.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpEnum.h @@ -1,9 +1,4 @@ Relu, Conv, ConvRelu, ConvTranspose, AveragePool, AveragePoolRelu, MaxPool, MaxPoolRelu, Sum, SumRelu, Send, Receive, BatchNormalization, FC, GivenTensorFill, Concat, Softmax, ChannelShuffle, Add, Reshape, Flatten, - NCHW2NHWC, NHWC2NCHW, Int8Quantize, Int8Dequantize, Int8AveragePool, - Int8Conv, Int8ConvTranspose, Int8FC, Int8MaxPool, Int8Relu, - Int8GivenTensorFill, Int8Concat, Int8Softmax, Int8ChannelShuffle, Int8Sum, - Int8Add, Int8Reshape, Int8Flatten, Int8ConvRelu, Int8SumRelu, - Int8AveragePoolRelu, Int8MaxPoolRelu, BatchMatMul, BatchGather, - ConcatBatchMatMulBatchGatherOp + NCHW2NHWC, NHWC2NCHW diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpNames.h 
b/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpNames.h index 87ffda3c4f343..88ffa0b1ba6bb 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpNames.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Generated/OpNames.h @@ -1,92 +1,68 @@ case NNKind::Relu: return "Relu"; + case NNKind::Conv: return "Conv"; + case NNKind::ConvRelu: return "ConvRelu"; + case NNKind::ConvTranspose: return "ConvTranspose"; + case NNKind::AveragePool: return "AveragePool"; + case NNKind::AveragePoolRelu: return "AveragePoolRelu"; + case NNKind::MaxPool: return "MaxPool"; + case NNKind::MaxPoolRelu: return "MaxPoolRelu"; + case NNKind::Sum: return "Sum"; + case NNKind::SumRelu: return "SumRelu"; + case NNKind::Send: return "Send"; + case NNKind::Receive: return "Receive"; + case NNKind::BatchNormalization: return "BatchNormalization"; + case NNKind::FC: return "FC"; + case NNKind::GivenTensorFill: return "GivenTensorFill"; + case NNKind::Concat: return "Concat"; + case NNKind::Softmax: return "Softmax"; + case NNKind::ChannelShuffle: return "ChannelShuffle"; + case NNKind::Add: return "Add"; + case NNKind::Reshape: return "Reshape"; + case NNKind::Flatten: return "Flatten"; + case NNKind::NCHW2NHWC: return "NCHW2NHWC"; + case NNKind::NHWC2NCHW: return "NHWC2NCHW"; -case NNKind::Int8Quantize: - return "Int8Quantize"; -case NNKind::Int8Dequantize: - return "Int8Dequantize"; -case NNKind::Int8AveragePool: - return "Int8AveragePool"; -case NNKind::Int8Conv: - return "Int8Conv"; -case NNKind::Int8ConvTranspose: - return "Int8ConvTranspose"; -case NNKind::Int8FC: - return "Int8FC"; -case NNKind::Int8MaxPool: - return "Int8MaxPool"; -case NNKind::Int8Relu: - return "Int8Relu"; -case NNKind::Int8GivenTensorFill: - return "Int8GivenTensorFill"; -case NNKind::Int8Concat: - return "Int8Concat"; -case NNKind::Int8Softmax: - return "Int8Softmax"; -case NNKind::Int8ChannelShuffle: - return "Int8ChannelShuffle"; -case NNKind::Int8Sum: - return "Int8Sum"; -case NNKind::Int8Add: - return "Int8Add"; -case NNKind::Int8Reshape: - return "Int8Reshape"; -case NNKind::Int8Flatten: - return "Int8Flatten"; -case NNKind::Int8ConvRelu: - return "Int8ConvRelu"; -case NNKind::Int8SumRelu: - return "Int8SumRelu"; -case NNKind::Int8AveragePoolRelu: - return "Int8AveragePoolRelu"; -case NNKind::Int8MaxPoolRelu: - return "Int8MaxPoolRelu"; -case NNKind::BatchMatMul: - return "BatchMatMul"; -case NNKind::BatchGather: - return "BatchGather"; -case NNKind::ConcatBatchMatMulBatchGatherOp: - return "ConcatBatchMatMulBatchGatherOp"; diff --git a/caffe2/core/nomnigraph/op_gen.py b/caffe2/core/nomnigraph/op_gen.py index c62148ea52cff..2d1125f5762ad 100755 --- a/caffe2/core/nomnigraph/op_gen.py +++ b/caffe2/core/nomnigraph/op_gen.py @@ -6,6 +6,8 @@ from __future__ import unicode_literals import argparse +from textwrap import dedent +from subprocess import call def parse_lines(lines): @@ -22,25 +24,27 @@ def parse_lines(lines): index = 0 while index < len(lines): line = lines[index] - if line.lower().startswith('macro'): - assert (parse_state == EMPTY) - macro_line = line.split(' ') + if line.lower().startswith("macro"): + assert parse_state == EMPTY + macro_line = line.split(" ") # Support macros that look like attributes # e.g. 
macro - CONV_LIKE - curr_macro = ' '.join(macro_line[1:]) - assert (curr_macro not in macros) + curr_macro = " ".join(macro_line[1:]) + assert curr_macro not in macros, 'Macro "{}" defined twice.'.format( + curr_macro + ) macros[curr_macro] = [] parse_state = MACRO - lines = lines[:index] + lines[index + 1:] + lines = lines[:index] + lines[index + 1 :] continue - elif line.lower().startswith('endmacro'): - assert (parse_state == MACRO) + elif line.lower().startswith("endmacro"): + assert parse_state == MACRO parse_state = EMPTY - lines = lines[:index] + lines[index + 1:] + lines = lines[:index] + lines[index + 1 :] continue elif parse_state == MACRO: macros[curr_macro].append(line) - lines = lines[:index] + lines[index + 1:] + lines = lines[:index] + lines[index + 1 :] continue index += 1 @@ -48,7 +52,7 @@ def parse_lines(lines): while index < len(lines): line = lines[index] if line in macros: - lines = lines[:index] + macros[line] + lines[index + 1:] + lines = lines[:index] + macros[line] + lines[index + 1 :] index += len(macros[line]) - 1 index += 1 @@ -63,20 +67,20 @@ def parse_lines(lines): for line in lines: if not len(line): continue - if line[0] == '-': - assert (parse_state is OP) - attr = [_.strip() for _ in line[1:].split(':')] - assert (attr[0][0].isupper()) - if (len(attr) == 2): # attribute : type + if line[0] == "-": + assert parse_state is OP + attr = [_.strip() for _ in line[1:].split(":")] + assert attr[0][0].isupper() + if len(attr) == 2: # attribute : type ops[curr_op]["attributes"].append((attr[0], attr[1])) - elif (len(attr) == 3): # attribute : type + elif len(attr) == 3: # attribute : type ops[curr_op]["attributes"].append((attr[0], attr[1], attr[2])) else: - op = [l.strip() for l in line.split(':')] - assert (len(op[0].split(' ')) == 1) + op = [l.strip() for l in line.split(":")] + assert len(op[0].split(" ")) == 1 parse_state = OP curr_op = op[0] - assert (curr_op not in ops) + assert curr_op not in ops ops[curr_op] = {} op_list.append(curr_op) if len(op) > 1: @@ -101,20 +105,26 @@ def gen_class(op, op_def): attr_arg = "{type} {lower_name}".format( type=t, lower_name=lower_name + default_arg ) - attr_init = "{name}({lower_name})".format( - name=name, lower_name=lower_name - ) + attr_init = "{name}({lower_name})".format(name=name, lower_name=lower_name) attr_declare = "{type} {name};".format(type=t, name=name) - attr_get = """ - {type} get{name}() const {{ - return {name}; - }} -""".format(type=t, name=name) - attr_set = """ - void set{name}({type} {lower_name}) {{ - {name} = {lower_name}; - }} -""".format(type=t, name=name, lower_name=lower_name) + attr_get = dedent( + """ + {type} get{name}() const {{ + return {name}; + }} + """.format( + type=t, name=name + ) + ) + attr_set = dedent( + """ + void set{name}({type} {lower_name}) {{ + {name} = {lower_name}; + }} + """.format( + type=t, name=name, lower_name=lower_name + ) + ) attribute_args.append(attr_arg) attribute_init.append(attr_init) attribute_declarations.append(attr_declare) @@ -132,38 +142,43 @@ def gen_class(op, op_def): name=attr[0], other_op=lower_other_op ) ) - init = """ - {op}(const {other_op}& {lower_other_op}) : - {other_init} {{}} -""".format( - op=op, - other_op=other_op, - lower_other_op=lower_other_op, - other_init=',\n '.join(other_init) + init = dedent( + """ + {op}(const {other_op}& {lower_other_op}) : + {other_init} {{}} + """.format( + op=op, + other_op=other_op, + lower_other_op=lower_other_op, + other_init=",\n ".join(other_init), + ) ) extra_init += init - return """class {op} : public 
NeuralNetOperator {{ - public: - {op}({attribute_args}) : - {attribute_init} {{}} - {extra_init} - ~{op}() {{}} - - NOMNIGRAPH_DEFINE_NN_RTTI({op}); -{getters}{setters} - private: - {attribute_declarations} -}}; - -""".format( - op=op, - extra_init=extra_init, - getters=''.join(attribute_getters), - setters=''.join(attribute_setters), - attribute_args=',\n '.join(attribute_args), - attribute_init=',\n '.join(attribute_init), - attribute_declarations='\n '.join(attribute_declarations) + return dedent( + """ + class {op} : public NeuralNetOperator {{ + public: + {op}({attribute_args}) : + {attribute_init} {{}} + {extra_init} + ~{op}() {{}} + + NOMNIGRAPH_DEFINE_NN_RTTI({op}); + {getters}{setters} + private: + {attribute_declarations} + }}; + + """.format( + op=op, + extra_init=extra_init, + getters="".join(attribute_getters), + setters="".join(attribute_setters), + attribute_args=",\n".join(attribute_args), + attribute_init=",\n".join(attribute_init), + attribute_declarations="\n".join(attribute_declarations), + ) ) @@ -175,33 +190,51 @@ def gen_classes(ops, op_list): def gen_enum(op_list): - return ',\n'.join([op for op in op_list]) + '\n' + return ",\n".join([op for op in op_list]) + "\n" def gen_names(op_list): f = "" for op in op_list: - f += """case NNKind::{name}: - return \"{name}\"; -""".format(name=op) + f += dedent( + """ + case NNKind::{name}: + return \"{name}\"; + """.format( + name=op + ) + ) return f if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Generate op files.') - parser.add_argument('--install_dir', help='installation directory') - parser.add_argument('--source_def', help='ops.def') + parser = argparse.ArgumentParser(description="Generate op files.") + parser.add_argument("--install_dir", help="installation directory") + parser.add_argument("--source_def", help="ops.def", action="append") args = parser.parse_args() install_dir = args.install_dir + sources = args.source_def - with open(args.source_def, 'rb') as f: - lines = f.readlines() - lines = [l.strip().decode("utf-8") for l in lines] + lines = [] + for source in sources: + with open(source, "rb") as f: + lines_tmp = f.readlines() + lines += [l.strip().decode("utf-8") for l in lines_tmp] ops, op_list = parse_lines(lines) - with open(install_dir + '/OpClasses.h', 'wb') as f: + with open(install_dir + "/OpClasses.h", "wb") as f: f.write(gen_classes(ops, op_list).encode("utf-8")) - with open(install_dir + '/OpNames.h', 'wb') as f: + with open(install_dir + "/OpNames.h", "wb") as f: f.write(gen_names(op_list).encode("utf-8")) - with open(install_dir + '/OpEnum.h', 'wb') as f: + with open(install_dir + "/OpEnum.h", "wb") as f: f.write(gen_enum(op_list).encode("utf-8")) + + try: + cmd = ["clang-format", "-i", install_dir + "/OpClasses.h"] + call(cmd) + cmd = ["clang-format", "-i", install_dir + "/OpNames.h"] + call(cmd) + cmd = ["clang-format", "-i", install_dir + "/OpEnum.h"] + call(cmd) + except Exception: + pass diff --git a/caffe2/core/nomnigraph/ops.def b/caffe2/core/nomnigraph/ops.def index 53dd951c8fc1c..6183e3c25726a 100644 --- a/caffe2/core/nomnigraph/ops.def +++ b/caffe2/core/nomnigraph/ops.def @@ -69,30 +69,3 @@ CopyFromOpenCL NCHW2NHWC NHWC2NCHW -Int8Quantize -Int8Dequantize -Int8AveragePool : AveragePool -Int8Conv : Conv -Int8ConvTranspose : ConvTranspose -Int8FC : FC -Int8MaxPool : MaxPool -Int8Relu : Relu -Int8GivenTensorFill : GivenTensorFill -Int8Concat : Concat -Int8Softmax : Softmax -Int8ChannelShuffle : ChannelShuffle -Int8Sum : Sum -Int8Add : Add -Int8Reshape : Reshape 
-Int8Flatten : Flatten -Int8ConvRelu : ConvRelu -Int8SumRelu : SumRelu -Int8AveragePoolRelu : AveragePoolRelu -Int8MaxPoolRelu : MaxPoolRelu - -BatchMatMul -- TransA : bool : false -- TransB : bool : true -- Broadcast: bool : false -BatchGather -ConcatBatchMatMulBatchGatherOp diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index ac27cd7253b86..44eb9c7a430a8 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -44,7 +44,7 @@ class IDEEPFallbackOp final : public IDEEPOperator { IDEEPFallbackOp(const OperatorDef& def, Workspace* ws) : IDEEPOperator(def, ws) { CAFFE_ENFORCE_EQ(def.device_option().device_type(), IDEEP); - OperatorDef base_def_(def); + base_def_.CopyFrom(def); // base_def_ runs on CPU, so we will set its device option to CPU. // Copy to allow random_seed to be correctly propagated. base_def_.mutable_device_option()->CopyFrom(def.device_option()); @@ -134,21 +134,14 @@ class IDEEPFallbackOp final : public IDEEPOperator { dtensor->resize(dst_dims, itensor::data_type::f32); } dtensor->set_data_handle(const_cast(src.raw_data())); - } else if (src.template IsType()) { - Blob* dst = OperatorBase::OutputBlob(i); - if (!dst->template IsType()) { - dst->Reset(new itensor()); - } - - auto src_dims = src.dims(); - itensor::dims dst_dims(src_dims.begin(), src_dims.end()); - auto dtensor = dst->template GetMutable(); - if (dtensor->get_dims() != dst_dims) { - dtensor->resize(dst_dims, itensor::data_type::s32); - } - dtensor->set_data_handle(const_cast(src.raw_data())); } else { - CAFFE_THROW("ideep memory only supports float data type."); + VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor"; + auto src_dims = src.dims(); + Blob* dst = OperatorBase::OutputBlob(i); + dst->Reset(new Tensor()); + auto dtensor = dst->template GetMutable(); + dtensor->Resize(src_dims); + dtensor->ShareData(src); } } return true; @@ -159,6 +152,7 @@ class IDEEPFallbackOp final : public IDEEPOperator { vector local_output_blobs_; std::unique_ptr base_op_; std::unique_ptr local_ws_; + OperatorDef base_def_; }; } // namespace caffe2 diff --git a/caffe2/ideep/operators/utility_ops.cc b/caffe2/ideep/operators/utility_ops.cc index 9a2ec875d426b..67d7d2ca2d732 100644 --- a/caffe2/ideep/operators/utility_ops.cc +++ b/caffe2/ideep/operators/utility_ops.cc @@ -30,10 +30,22 @@ class CopyIDEEPToCPUOp final : public IDEEPOperator { USE_SIMPLE_IDEEP_CTOR_DTOR(CopyIDEEPToCPUOp); USE_IDEEP_DEF_ALIASES(); bool RunOnDevice() override { - const auto& X = OperatorBase::Input(0); - auto* Y = OperatorBase::Output(0); - Y->Resize(X.get_dims()); - X.reorder_to(Y->template mutable_data()); + const auto& input_blob = OperatorBase::InputBlob(0); + if (input_blob.template IsType()) { + VLOG(2) << "Directing sharing of TensorCPU"; + const auto& X = OperatorBase::Input(0); + auto* Y = OperatorBase::Output(0); + Y->CopyFrom(X); + } else { + const auto& X = OperatorBase::Input(0); + auto* Y = OperatorBase::Output(0); + Y->Resize(X.get_dims()); + if (X.get_data_type() == itensor::data_type::f32) { + X.reorder_to(Y->template mutable_data()); + } else { + CAFFE_THROW("Unsupported ideep type: ", X.get_data_type()); + } + } return true; } }; diff --git a/caffe2/operators/batch_matmul_op.h b/caffe2/operators/batch_matmul_op.h index 99df277482fee..e594f526a6bf6 100644 --- a/caffe2/operators/batch_matmul_op.h +++ b/caffe2/operators/batch_matmul_op.h @@ -255,7 +255,7 @@ class BatchMatMulOp final : public Operator 
{ // TODO(T23893772): doing this in a loop is likely going to be slow on GPU for (size_t p = 0; p < num_outer_batches; ++p) { - math::GemmBatched( + math::GemmStridedBatched( trans_a_ ? CblasTrans : CblasNoTrans, trans_b_ ? CblasTrans : CblasNoTrans, num_sub_batches, @@ -264,11 +264,13 @@ class BatchMatMulOp final : public Operator { K, 1.0f, data_A + p * A_stride, + M * K, data_B + p * B_stride, + K * N, 0.0f, Y_data + p * Y_stride, - &context_, - use_scratch_ ? scratch_.get() : nullptr); + M * N, + &context_); } } return true; diff --git a/caffe2/operators/conv_op.h b/caffe2/operators/conv_op.h index 7153b14229960..efdc30f161a31 100644 --- a/caffe2/operators/conv_op.h +++ b/caffe2/operators/conv_op.h @@ -34,6 +34,26 @@ class ConvOp final : public ConvPoolOpBase { bool RunOnDeviceWithOrderNHWC() override; private: + bool Run1x1ConvOnDeviceWithOrderNCHW( + const int N, + const int C, + const int HxW, + const int M, + const T* X, + const T* filter, + const T* bias, + T* Y); + + bool Run1x1ConvOnDeviceWithOrderNHWC( + const int N, + const int C, + const int HxW, + const int M, + const T* X, + const T* filter, + const T* bias, + T* Y); + Tensor col_buffer_; Tensor bias_multiplier_; Tensor img_shape_device_; diff --git a/caffe2/operators/conv_op_impl.h b/caffe2/operators/conv_op_impl.h index 161887cfc5aa1..f8ad628c0ca72 100644 --- a/caffe2/operators/conv_op_impl.h +++ b/caffe2/operators/conv_op_impl.h @@ -2,12 +2,17 @@ #ifndef CAFFE2_OPERATORS_CONV_OP_IMPL_H_ #define CAFFE2_OPERATORS_CONV_OP_IMPL_H_ +#include "caffe2/operators/conv_op.h" + +#include +#include + #include "caffe2/core/context.h" #include "caffe2/core/flags.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" -#include "caffe2/operators/conv_op.h" #include "caffe2/operators/conv_pool_op_base.h" +#include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" namespace caffe2 { @@ -71,15 +76,25 @@ bool ConvOp::RunOnDeviceWithOrderNCHW() { // The col buffer is stored in CHW order as well - kernel_dim, and the height // and width. - const T* Xdata = X.template data(); + const T* X_data = X.template data(); + const T* filter_data = filter.template data(); + const T* bias_data = nullptr; if (InputSize() == 3) { const auto& bias = Input(BIAS); - CAFFE_ENFORCE(bias.ndim() == 1); - CAFFE_ENFORCE(bias.dim32(0) == M); + CAFFE_ENFORCE_EQ(bias.ndim(), 1); + CAFFE_ENFORCE_EQ(bias.dim32(0), M); + bias_data = bias.template data(); ConvPoolOpBase::template SetBiasMultiplier( output_image_size, &bias_multiplier_); } - T* Ydata = Y->template mutable_data(); + T* Y_data = Y->template mutable_data(); + + // Shortcut for 1x1 conv. 
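The "shortcut for 1x1 conv" flagged by the comment above rests on a simple identity: with stride 1 and no padding, a 1x1 convolution over an NCHW image is the (M x C) filter matrix multiplied by the (C x HxW) input matrix, one product per image, which the new GemmStridedBatched call performs in one shot (filter stride 0, input stride C*HxW, output stride M*HxW). A loop-based reference of that identity, written as a hypothetical helper rather than the operator's code:

    #include <cstddef>
    #include <vector>

    // Reference only: Y[n][m][p] = sum_c W[m][c] * X[n][c][p] (+ bias[m]),
    // where p runs over the HxW spatial positions.
    void conv1x1_nchw_reference(int N, int C, int HxW, int M,
                                const std::vector<float>& X,     // N*C*HxW
                                const std::vector<float>& W,     // M*C
                                const std::vector<float>& bias,  // size M, may be empty
                                std::vector<float>& Y) {         // N*M*HxW
      Y.assign(static_cast<std::size_t>(N) * M * HxW, 0.0f);
      for (int n = 0; n < N; ++n) {
        for (int m = 0; m < M; ++m) {
          float* y = &Y[(static_cast<std::size_t>(n) * M + m) * HxW];
          for (int c = 0; c < C; ++c) {
            const float w = W[static_cast<std::size_t>(m) * C + c];
            const float* x = &X[(static_cast<std::size_t>(n) * C + c) * HxW];
            for (int p = 0; p < HxW; ++p) {
              y[p] += w * x[p];
            }
          }
          if (!bias.empty()) {
            for (int p = 0; p < HxW; ++p) {
              y[p] += bias[m];  // bias broadcast over spatial positions
            }
          }
        }
      }
    }

For group_ > 1 the implementation further down instead builds per-(image, group) pointer arrays and calls GemmBatched, since the filter pointer no longer advances uniformly across the batch.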
+ if (kernel_dims_size == 1 && !HasPad() && !HasStride()) { + const int HxW = X.size() / (N * C); + return Run1x1ConvOnDeviceWithOrderNCHW( + N, C, HxW, M, X_data, filter_data, bias_data, Y_data); + } auto f = [&](Tensor* col_buffer) { col_buffer->Resize(buffer_shape); @@ -102,7 +117,7 @@ bool ConvOp::RunOnDeviceWithOrderNCHW() { pad_r(), stride_h(), stride_w(), - Xdata + group_id * input_offset, + X_data + group_id * input_offset, col_buffer_data, &context_); } else { @@ -116,7 +131,7 @@ bool ConvOp::RunOnDeviceWithOrderNCHW() { stride_.data(), dilation_.data(), pads_.data(), - Xdata + group_id * input_offset, + X_data + group_id * input_offset, col_buffer_data, &context_); } @@ -128,16 +143,15 @@ bool ConvOp::RunOnDeviceWithOrderNCHW() { output_image_size, kernel_dim, 1, - filter.template data() + group_id * filter_offset, + filter_data + group_id * filter_offset, col_buffer_data, 0, - Ydata + group_id * output_offset, + Y_data + group_id * output_offset, &context_); } - if (InputSize() == 3) { + if (bias_data != nullptr) { // Bias term can be carried out outside the group definition // to be efficient. - auto* bias_data = Input(BIAS).template data(); math::Gemm( CblasNoTrans, CblasNoTrans, @@ -148,14 +162,13 @@ bool ConvOp::RunOnDeviceWithOrderNCHW() { bias_data, bias_multiplier_.template data(), 1, - Ydata, + Y_data, &context_); } - Xdata += input_offset * group_; - Ydata += output_offset * group_; + X_data += input_offset * group_; + Y_data += output_offset * group_; } }; - if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) { runWithSharedBuffer(ws_, f); } else { @@ -194,119 +207,215 @@ bool ConvOp::RunOnDeviceWithOrderNHWC() { const int output_image_size = Y->dim32(1) * Y->dim32(2); // The col buffer is stored in HWC order as well - kernel_dim, and the height // and width. - const T* Xdata = X.template data(); - T* Ydata = Y->template mutable_data(); + const T* X_data = X.template data(); + const T* filter_data = filter.template data(); + const T* bias_data = nullptr; + T* Y_data = Y->template mutable_data(); + if (InputSize() == 3) { + const auto& bias = Input(BIAS); + CAFFE_ENFORCE_EQ(bias.ndim(), 1); + CAFFE_ENFORCE_EQ(bias.dim32(0), M); + bias_data = bias.template data(); + } // Specialized path for 1 by 1 convolution with stride 1, pad 0 - we // can skip im2col. - if (kernel_dim == C && Y->dim32(1) == X.dim32(1) && - Y->dim32(2) == X.dim32(2) && stride_h() == 1 && stride_w() == 1 && - pad_t() == 0 && pad_b() == 0 && pad_l() == 0 && pad_r() == 0) { - math::Gemm( - CblasNoTrans, - CblasTrans, - N * H * W, - M, - C, - 1, - Xdata, - filter.template data(), - 0, - Ydata, - &context_); - if (InputSize() == 3) { - auto& bias = Input(BIAS); - CAFFE_ENFORCE(1 == bias.ndim()); - CAFFE_ENFORCE(bias.dim32(0) == M); - if (bias_multiplier_.size() != N * H * W) { - // If the helper bias multiplier is not M, reshape and fill it with one. 
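For kernels larger than 1x1, both layouts keep the classic structure: expand each image into a column buffer with Im2Col, then multiply the filter against that buffer with a single Gemm. A minimal single-image reference of im2col in NCHW layout (no dilation or groups, unlike the real Im2Col helper), shown only to make the GEMM dimensions concrete:

    #include <cstddef>
    #include <vector>

    // Produces `col` with shape (C*kh*kw) x (out_h*out_w); the convolution is then
    // one GEMM of the (M x C*kh*kw) filter with `col`, plus an optional rank-1
    // bias update via the bias multiplier.
    void im2col_nchw_reference(const std::vector<float>& x, int C, int H, int W,
                               int kh, int kw, int pad_h, int pad_w,
                               int stride_h, int stride_w,
                               std::vector<float>& col) {
      const int out_h = (H + 2 * pad_h - kh) / stride_h + 1;
      const int out_w = (W + 2 * pad_w - kw) / stride_w + 1;
      col.assign(static_cast<std::size_t>(C) * kh * kw * out_h * out_w, 0.0f);
      for (int c = 0; c < C; ++c) {
        for (int i = 0; i < kh; ++i) {
          for (int j = 0; j < kw; ++j) {
            for (int oh = 0; oh < out_h; ++oh) {
              for (int ow = 0; ow < out_w; ++ow) {
                const int ih = oh * stride_h - pad_h + i;  // tapped input row
                const int iw = ow * stride_w - pad_w + j;  // tapped input column
                const std::size_t col_index =
                    ((static_cast<std::size_t>(c) * kh + i) * kw + j) *
                        (static_cast<std::size_t>(out_h) * out_w) +
                    static_cast<std::size_t>(oh) * out_w + ow;
                if (ih >= 0 && ih < H && iw >= 0 && iw < W) {
                  col[col_index] =
                      x[(static_cast<std::size_t>(c) * H + ih) * W + iw];
                }  // out-of-bounds taps are padding and stay zero
              }
            }
          }
        }
      }
    }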
- bias_multiplier_.Resize(vector(1, N * H * W)); - math::Set( - N * H * W, - static_cast(1), - bias_multiplier_.template mutable_data(), - &context_); - } + if (kernel_dim == C && !HasPad() && !HasStride()) { + const int HxW = X.size() / (N * C); + if (bias_data != nullptr) { + ConvPoolOpBase::template SetBiasMultiplier( + N * HxW, &bias_multiplier_); + } + return Run1x1ConvOnDeviceWithOrderNHWC( + N, C, HxW, M, X_data, filter_data, bias_data, Y_data); + } + + if (bias_data != nullptr) { + ConvPoolOpBase::template SetBiasMultiplier( + output_image_size, &bias_multiplier_); + } + auto f = [&](Tensor* col_buffer) { + col_buffer->Resize( + vector{Y->dim32(1), Y->dim32(2), kernel_h(), kernel_w(), C}); + T* col_buffer_data = col_buffer->template mutable_data(); + // Im2Col, followed by gemm. + for (int image_id = 0; image_id < N; ++image_id) { + math::Im2Col( + C, + H, + W, + kernel_h(), + kernel_w(), + dilation_h(), + dilation_w(), + pad_t(), + pad_l(), + pad_b(), + pad_r(), + stride_h(), + stride_w(), + X_data, + col_buffer_data, + &context_); + // Weight term math::Gemm( CblasNoTrans, - CblasNoTrans, - N * H * W, + CblasTrans, + output_image_size, M, + kernel_dim, 1, - 1, - bias_multiplier_.template data(), - bias.template data(), - 1, - Ydata, + col_buffer_data, + filter_data, + 0, + Y_data, &context_); - } - } else { - if (InputSize() == 3) { - const auto& bias = Input(BIAS); - CAFFE_ENFORCE(1 == bias.ndim()); - CAFFE_ENFORCE(bias.dim32(0) == M); - ConvPoolOpBase::template SetBiasMultiplier( - output_image_size, &bias_multiplier_); - } - auto f = [&](Tensor* col_buffer) { - col_buffer->Resize( - vector{Y->dim32(1), Y->dim32(2), kernel_h(), kernel_w(), C}); - T* col_buffer_data = col_buffer->template mutable_data(); - // Im2Col, followed by gemm. - for (int image_id = 0; image_id < N; ++image_id) { - math::Im2Col( - C, - H, - W, - kernel_h(), - kernel_w(), - dilation_h(), - dilation_w(), - pad_t(), - pad_l(), - pad_b(), - pad_r(), - stride_h(), - stride_w(), - Xdata, - col_buffer_data, - &context_); - // Weight term + if (bias_data != nullptr) { + // Bias term math::Gemm( CblasNoTrans, - CblasTrans, + CblasNoTrans, output_image_size, M, - kernel_dim, 1, - col_buffer_data, - filter.template data(), - 0, - Ydata, + 1, + bias_multiplier_.template data(), + bias_data, + 1, + Y_data, &context_); - if (InputSize() == 3) { - // Bias term - math::Gemm( - CblasNoTrans, - CblasNoTrans, - output_image_size, - M, - 1, - 1, - bias_multiplier_.template data(), - Input(BIAS).template data(), - 1, - Ydata, - &context_); - } - Xdata += input_offset; - Ydata += output_offset; } - }; - if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) { - runWithSharedBuffer(ws_, f); - } else { - f(&col_buffer_); + X_data += input_offset; + Y_data += output_offset; + } + }; + if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) { + runWithSharedBuffer(ws_, f); + } else { + f(&col_buffer_); + } + return true; +} + +template +bool ConvOp::Run1x1ConvOnDeviceWithOrderNCHW( + const int N, + const int C, + const int HxW, + const int M, + const T* X, + const T* filter, + const T* bias, + T* Y) { + const int G = group_; + if (G == 1) { + math::GemmStridedBatched( + CblasNoTrans, + CblasNoTrans, + N, + M, + HxW, + C, + 1.0f, + filter, + 0, + X, + C * HxW, + 0.0f, + Y, + M * HxW, + &context_); + } else { + const int batch_size = N * G; + const int D_X = C / G; + const int D_Y = M / G; + const int X_stride = D_X * HxW; + const int W_stride = D_Y * D_X; + const int Y_stride = D_Y * HxW; + std::vector X_ptr(N * G); + 
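+    // Each (image, group) pair is an independent (D_Y x D_X) * (D_X x HxW)
+    // product, so per-batch pointer arrays are built and handed to the
+    // pointer-based GemmBatched overload. The filter pointers repeat with
+    // period G because the weights are shared across images.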
std::vector W_ptr(N * G); + std::vector Y_ptr(N * G); + for (int i = 0; i < N; ++i) { + for (int j = 0; j < G; ++j) { + const int index = i * G + j; + X_ptr[index] = X + index * X_stride; + W_ptr[index] = filter + j * W_stride; + Y_ptr[index] = Y + index * Y_stride; + } } + math::GemmBatched( + CblasNoTrans, + CblasNoTrans, + batch_size, + D_Y, + HxW, + D_X, + 1.0f, + W_ptr.data(), + X_ptr.data(), + 0.0f, + Y_ptr.data(), + &context_); + } + if (bias != nullptr) { + const T* bias_multiplier_data = bias_multiplier_.template data(); + math::GemmStridedBatched( + CblasNoTrans, + CblasNoTrans, + N, + M, + HxW, + 1, + 1.0f, + bias, + 0, + bias_multiplier_data, + 0, + 1.0f, + Y, + M * HxW, + &context_); + } + return true; +} + +template +bool ConvOp::Run1x1ConvOnDeviceWithOrderNHWC( + const int N, + const int C, + const int HxW, + const int M, + const T* X, + const T* filter, + const T* bias, + T* Y) { + const int G = group_; + CAFFE_ENFORCE_EQ(G, 1); + math::Gemm( + CblasNoTrans, + CblasTrans, + N * HxW, + M, + C, + 1.0f, + X, + filter, + 0.0f, + Y, + &context_); + if (bias != nullptr) { + const T* bias_multiplier_data = bias_multiplier_.template data(); + math::Gemm( + CblasNoTrans, + CblasNoTrans, + N * HxW, + M, + 1, + 1.0f, + bias_multiplier_data, + bias, + 1.0f, + Y, + &context_); } return true; } diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h index 5d7b003ae97fa..723304994c4e9 100644 --- a/caffe2/operators/conv_pool_op_base.h +++ b/caffe2/operators/conv_pool_op_base.h @@ -319,6 +319,22 @@ class ConvPoolOpBase : public Operator { } } + bool HasPad() const { + if (kernel_.size() == 2) { + return pad_t() > 0 || pad_b() > 0 || pad_l() > 0 || pad_r() > 0; + } + return std::any_of( + pads_.cbegin(), pads_.cend(), [](const int x) { return x > 0; }); + } + + bool HasStride() const { + if (kernel_.size() == 2) { + return stride_h() > 1 || stride_w() > 1; + } + return std::any_of( + stride_.cbegin(), stride_.cend(), [](const int x) { return x > 1; }); + } + void SetDeviceTensor(const std::vector& data, Tensor* tensor) { bool reset_tensor_device_ = false; @@ -719,36 +735,38 @@ class ConvPoolOpBase : public Operator { } private: - inline void AllocateAndCopy(const vector& vec, Tensor& tensor) { - tensor.Resize(vec.size()); - context_.template Copy( - vec.size(), vec.data(), tensor.template mutable_data()); - } - -#define USE_CONV_POOL_BASE_FUNCTIONS(Context) \ - USE_OPERATOR_FUNCTIONS(Context); \ - using ConvPoolOpBase::pads_; \ - using ConvPoolOpBase::pad_t; \ - using ConvPoolOpBase::pad_l; \ - using ConvPoolOpBase::pad_b; \ - using ConvPoolOpBase::pad_r; \ - using ConvPoolOpBase::legacy_pad_; \ - using ConvPoolOpBase::global_pooling_; \ - using ConvPoolOpBase::kernel_; \ - using ConvPoolOpBase::kernel_h; \ - using ConvPoolOpBase::kernel_w; \ - using ConvPoolOpBase::dilation_; \ - using ConvPoolOpBase::dilation_h; \ - using ConvPoolOpBase::dilation_w; \ - using ConvPoolOpBase::stride_; \ - using ConvPoolOpBase::stride_h; \ - using ConvPoolOpBase::stride_w; \ - using ConvPoolOpBase::group_; \ - using ConvPoolOpBase::order_; \ - using ConvPoolOpBase::shared_buffer_; \ - using ConvPoolOpBase::GetDims; \ - using ConvPoolOpBase::GetDimsSize; \ - using ConvPoolOpBase::SetDeviceTensor; \ + inline void AllocateAndCopy(const vector& vec, Tensor& tensor) { + tensor.Resize(vec.size()); + context_.template Copy( + vec.size(), vec.data(), tensor.template mutable_data()); + } + +#define USE_CONV_POOL_BASE_FUNCTIONS(Context) \ + USE_OPERATOR_FUNCTIONS(Context); \ + using 
ConvPoolOpBase::pads_; \ + using ConvPoolOpBase::pad_t; \ + using ConvPoolOpBase::pad_l; \ + using ConvPoolOpBase::pad_b; \ + using ConvPoolOpBase::pad_r; \ + using ConvPoolOpBase::legacy_pad_; \ + using ConvPoolOpBase::global_pooling_; \ + using ConvPoolOpBase::kernel_; \ + using ConvPoolOpBase::kernel_h; \ + using ConvPoolOpBase::kernel_w; \ + using ConvPoolOpBase::dilation_; \ + using ConvPoolOpBase::dilation_h; \ + using ConvPoolOpBase::dilation_w; \ + using ConvPoolOpBase::stride_; \ + using ConvPoolOpBase::stride_h; \ + using ConvPoolOpBase::stride_w; \ + using ConvPoolOpBase::group_; \ + using ConvPoolOpBase::order_; \ + using ConvPoolOpBase::shared_buffer_; \ + using ConvPoolOpBase::GetDims; \ + using ConvPoolOpBase::GetDimsSize; \ + using ConvPoolOpBase::SetDeviceTensor; \ + using ConvPoolOpBase::HasPad; \ + using ConvPoolOpBase::HasStride; \ using ConvPoolOpBase::ws_ }; diff --git a/caffe2/operators/ctc_beam_search_decoder_op.cc b/caffe2/operators/ctc_beam_search_decoder_op.cc new file mode 100644 index 0000000000000..9dd426978b257 --- /dev/null +++ b/caffe2/operators/ctc_beam_search_decoder_op.cc @@ -0,0 +1,171 @@ +#include "caffe2/operators/ctc_beam_search_decoder_op.h" + +namespace caffe2 { + +namespace { + +template +const float* getTensorDataPtr(const Tensor& tensor, int t, int n) { + const auto& dims = tensor.dims(); + CAFFE_ENFORCE_EQ(dims.size(), 3); + int offset = (t * dims[1] + n) * dims[2]; + CAFFE_ENFORCE_LT(offset, tensor.size()); + return tensor.template data() + offset; +} + +} // namespace + +template <> +bool CTCBeamSearchDecoderOp::RunOnDevice() { + // shape: max_activation_length x batch_size x alphabet_size + auto& inputs = Input(INPUTS); + // shape: batch_size + auto* output_len = Output(OUTPUT_LEN); + // shape: sum over all decoded_length + auto* values = Output(VALUES); + + const auto& inputs_dims = inputs.dims(); + int32_t max_activation_length = inputs_dims[0]; + int32_t batch_size = inputs_dims[1]; + int32_t alphabet_size = inputs_dims[2]; + // [batch_size] + const int* seq_len_data = + (InputSize() == 2) ? Input(SEQ_LEN).data() : nullptr; + + vector values_cache; + output_len->Resize(vector{batch_size}); + int* output_len_data = output_len->mutable_data(); + + for (int32_t i = 0; i < batch_size; ++i) { + const int32_t activation_length = + (seq_len_data) ? seq_len_data[i] : max_activation_length; + std::multimap, std::greater> A_next_inv; + // For a given time step, Pb maps prefixes to the probability of all + // candidate sequences that end in a blank and Pnb maps prefixes to the + // probability of all candidate sequences that don't end in a blank. + vector, float>> Pb( + activation_length + 1, std::map, float>()); + vector, float>> Pnb( + activation_length + 1, std::map, float>()); + set> A_prev; + Pb[0][vector()] = 1; + Pnb[0][vector()] = 0; + A_prev.insert(vector()); + + for (int t = 0; t < activation_length; t++) { + const float* ctc = getTensorDataPtr(inputs, t, i); + + vector pruned_alpha; + for (int32_t c = 0; c < alphabet_size; c++) { + if (ctc[c] > prune_threshold_) { + pruned_alpha.push_back(c); + } + } + + // If the pruned alphabet is empty, don't use pruning. + if (pruned_alpha.size() == 0) { + pruned_alpha = vector(alphabet_size); + std::iota(pruned_alpha.begin(), pruned_alpha.end(), 0); + } + + for (auto const& l : A_prev) { + // We skip the code handling the end character from the article since + // our system does not support an end character. 
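+        // The loop below applies the standard prefix-beam-search updates:
+        //  - a blank keeps the prefix l and moves its mass into Pb[t + 1][l];
+        //  - a label equal to l's last symbol can only extend l through a
+        //    preceding blank (the Pb[t][l] factor); otherwise the repeat
+        //    collapses back into l via Pnb[t + 1][l];
+        //  - any other label extends l with both its blank and non-blank mass;
+        //  - if the extended prefix l_plus is not itself in the beam, the mass
+        //    of paths already ending in l_plus at time t is folded in as well,
+        //    so it can still re-enter the beam.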
+ + for (auto const c : pruned_alpha) { + // Assumption: blank character always mapped to index 0 + if (c == 0) { + Pb[t + 1][l] += ctc[c] * (Pb[t][l] + Pnb[t][l]); + } else { + vector l_plus = vector(l); + l_plus.push_back(c); + if (l.size() > 0 && c == l.back()) { + Pnb[t + 1][l_plus] += ctc[c] * Pb[t][l]; + Pnb[t + 1][l] += ctc[c] * Pnb[t][l]; + } else { + Pnb[t + 1][l_plus] += ctc[c] * (Pb[t][l] + Pnb[t][l]); + } + + if (A_prev.find(l_plus) == A_prev.end()) { + Pb[t + 1][l_plus] += ctc[0] * (Pb[t][l_plus] + Pnb[t][l_plus]); + Pnb[t + 1][l_plus] += ctc[c] * Pnb[t][l_plus]; + } + } + } + } + + std::map, float> A_next(Pb[t + 1]); + for (auto& it : Pnb[t + 1]) { + A_next[it.first] += it.second; + } + A_next_inv.clear(); + for (auto& it : A_next) { + A_next_inv.insert({it.second, it.first}); + } + + A_prev.clear(); + auto it = A_next_inv.begin(); + for (int j = 0; j < beam_width_; j++) { + if (it == A_next_inv.end()) { + break; + } + A_prev.insert(it->second); + it++; + } + } + + vector decoded = + (A_next_inv.empty()) ? vector() : A_next_inv.begin()->second; + + output_len_data[i] = decoded.size(); + values_cache.insert(values_cache.end(), decoded.begin(), decoded.end()); + } + + int32_t cache_size = values_cache.size(); + values->Resize(vector{cache_size}); + int* values_data = values->mutable_data(); + for (int i = 0; i < values_cache.size(); ++i) { + values_data[i] = values_cache.at(i); + } + values_cache.clear(); + + return true; +} + +REGISTER_CPU_OPERATOR(CTCBeamSearchDecoder, CTCBeamSearchDecoderOp); +OPERATOR_SCHEMA(CTCBeamSearchDecoder) + .NumInputs(1, 2) + .NumOutputs(2) + .SetDoc( + "Prefix beam search decoder for connectionist temporal classification.") + .Arg( + "beam_width", + "Maximum number of candidates to carry over to next activation step.") + .Arg( + "prune_threshold", + "Probability threshold below which outputs are ignored.") + .Input( + 0, + "INPUTS", + "3D float Tensor sized [max_activation_length, batch_size, alphabet_size] " + "of network logits (before softmax application).") + .Input( + 1, + "SEQ_LEN", + "(optional) 1D int vector containing sequence lengths, " + "having size [batch_size] " + "seq_len will be set to max_time if not provided.") + .Output( + 0, + "OUTPUT_LEN", + "Output_len matrix size (batch_size). " + "Each index stores final output length of its corresponding batch item.") + .Output( + 1, + "VALUES", + "Values vector, size (total_decoded_outputs). 
" + "The flattened vector of final output sequences, in batch order.") + .InheritOnnxSchema("CTCBeamSearchDecoder"); +SHOULD_NOT_DO_GRADIENT(CTCBeamSearchDecoder); + +} // namespace caffe2 diff --git a/caffe2/operators/ctc_beam_search_decoder_op.h b/caffe2/operators/ctc_beam_search_decoder_op.h new file mode 100644 index 0000000000000..3e087505878d9 --- /dev/null +++ b/caffe2/operators/ctc_beam_search_decoder_op.h @@ -0,0 +1,32 @@ +#ifndef CAFFE2_OPERATORS_CTC_BEAM_SEARCH_OP_H_ +#define CAFFE2_OPERATORS_CTC_BEAM_SEARCH_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template +class CTCBeamSearchDecoderOp : public Operator { + public: + USE_OPERATOR_CONTEXT_FUNCTIONS; + CTCBeamSearchDecoderOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws) { + beam_width_ = OperatorBase::GetSingleArgument("beam_width", 10); + prune_threshold_ = + OperatorBase::GetSingleArgument("prune_threshold", 0.001); + } + + bool RunOnDevice() override; + + protected: + int32_t beam_width_; + float prune_threshold_; + INPUT_TAGS(INPUTS, SEQ_LEN); + OUTPUT_TAGS(OUTPUT_LEN, VALUES); + // Input: X, 3D tensor; L, 1D tensor. Output: Y sparse tensor +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_CTC_BEAM_SEARCH_OP_H_ diff --git a/caffe2/operators/locally_connected_op_impl.h b/caffe2/operators/locally_connected_op_impl.h index b43b0c351c755..4d7762fccbbb3 100644 --- a/caffe2/operators/locally_connected_op_impl.h +++ b/caffe2/operators/locally_connected_op_impl.h @@ -246,7 +246,7 @@ void LocallyConnectedOp::RunOnDeviceWithOrderNCHWImpl( column_buffer->template data(), column_transposed_buffer->template mutable_data(), &context_); - math::GemmBatched( + math::GemmStridedBatched( CblasNoTrans, CblasNoTrans, shape.output_image_size * group_, @@ -255,9 +255,12 @@ void LocallyConnectedOp::RunOnDeviceWithOrderNCHWImpl( shape.kernel_size, 1.0f, filter_data, + shape.M / group_ * shape.kernel_size, column_transposed_buffer->template data(), + shape.kernel_size * shape.N, 0.0f, Y_transposed_buffer_data, + shape.M / group_ * shape.N, &context_); if (bias_data != nullptr) { math::Gemm( @@ -325,7 +328,7 @@ void LocallyConnectedOp::RunOnDeviceWithOrderNHWCImpl( column_buffer->template data(), column_transposed_buffer->template mutable_data(), &context_); - math::GemmBatched( + math::GemmStridedBatched( CblasNoTrans, CblasTrans, shape.output_image_size, @@ -334,9 +337,12 @@ void LocallyConnectedOp::RunOnDeviceWithOrderNHWCImpl( shape.kernel_size, 1.0f, column_transposed_buffer->template data(), + shape.N * shape.kernel_size, filter_data, + shape.kernel_size * shape.M, 0.0f, Y_transposed_buffer_data, + shape.N * shape.M, &context_); math::Transpose( shape.Y_transposed_dims.size(), @@ -612,7 +618,7 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNCHWImpl( &context_); // Gradient respect to filter. - math::GemmBatched( + math::GemmStridedBatched( CblasNoTrans, CblasTrans, shape.output_image_size * group_, @@ -621,9 +627,12 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNCHWImpl( shape.N, 1.0f, dY_transposed_buffer_data, + shape.M / group_ * shape.N, column_transposed_buffer->template data(), + shape.N * shape.kernel_size, 0.0f, dfilter_data, + shape.M / group_ * shape.kernel_size, &context_); if (dbias_data != nullptr) { @@ -642,7 +651,7 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNCHWImpl( if (dX_data != nullptr) { // Gradient respect to X. 
- math::GemmBatched( + math::GemmStridedBatched( CblasTrans, CblasNoTrans, shape.output_image_size * group_, @@ -651,9 +660,12 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNCHWImpl( shape.M / group_, 1.0f, filter_data, + shape.kernel_size * shape.M / group_, dY_transposed_buffer_data, + shape.M / group_ * shape.N, 0.0f, column_transposed_buffer->template mutable_data(), + shape.kernel_size * shape.N, &context_); math::Transpose( shape.column_transposed_dims.size(), @@ -760,7 +772,7 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNHWCImpl( &context_); // Gradient respect to filter. - math::GemmBatched( + math::GemmStridedBatched( CblasTrans, CblasNoTrans, shape.output_image_size, @@ -769,9 +781,12 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNHWCImpl( shape.N, 1.0f, dY_transposed_buffer_data, + shape.M * shape.N, column_transposed_buffer->template data(), + shape.N * shape.kernel_size, 0.0f, dfilter_data, + shape.M * shape.kernel_size, &context_); if (dbias_data != nullptr) { @@ -790,7 +805,7 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNHWCImpl( if (dX_data != nullptr) { // Gradient respect to X. - math::GemmBatched( + math::GemmStridedBatched( CblasNoTrans, CblasNoTrans, shape.output_image_size, @@ -799,9 +814,12 @@ void LocallyConnectedGradientOp::RunOnDeviceWithOrderNHWCImpl( shape.M, 1.0f, dY_transposed_buffer_data, + shape.N * shape.M, filter_data, + shape.M * shape.kernel_size, 0.0f, column_transposed_buffer->template mutable_data(), + shape.N * shape.kernel_size, &context_); math::Transpose( shape.column_transposed_dims.size(), diff --git a/caffe2/opt/converter.cc b/caffe2/opt/converter.cc index b4866618b4e60..69f4503b6da88 100644 --- a/caffe2/opt/converter.cc +++ b/caffe2/opt/converter.cc @@ -146,9 +146,6 @@ REGISTER_CONVERTER(SpatialBN, BatchNormalizationConverter); TRIVIAL_CONVERTER(Flatten); REGISTER_CONVERTER(Flatten, FlattenConverter); -TRIVIAL_CONVERTER(BatchGather); -REGISTER_CONVERTER(BatchGather, BatchGatherConverter); - class AveragePoolConverter : public Converter { std::unique_ptr convertToNeuralNetOperator( const OperatorDef& op) override { @@ -205,37 +202,6 @@ class ConcatConverter : public Converter { }; REGISTER_CONVERTER(Concat, ConcatConverter); -class BatchMatMulConverter : public Converter { - std::unique_ptr convertToNeuralNetOperator( - const OperatorDef& op) override { - std::unique_ptr nnOp = - util::make_unique(); - auto argMap = getArgumentsFromOperator(op); - - auto c = dyn_cast(nnOp.get()); - if (argMap.count("trans_a")) { - CAFFE_ENFORCE(argMap["trans_a"].has_i(), "Invalid axis argument"); - int trans_a = static_cast(argMap["trans_a"].i()); - c->setTransA(!!trans_a); - } - if (argMap.count("trans_b")) { - CAFFE_ENFORCE(argMap["trans_b"].has_i(), "Invalid add_axis argument"); - int trans_b = static_cast(argMap["trans_b"].i()); - c->setTransB(!!trans_b); - } - if (argMap.count("broadcast")) { - CAFFE_ENFORCE(argMap["broadcast"].has_i(), "Invalid add_axis argument"); - int broadcast = static_cast(argMap["broadcast"].i()); - c->setBroadcast(!!broadcast); - } - return nnOp; - } - // Does not override default converter to OperatorDef - - virtual ~BatchMatMulConverter() {} -}; -REGISTER_CONVERTER(BatchMatMul, BatchMatMulConverter); - } // namespace std::unique_ptr convertToNeuralNetOperator( diff --git a/caffe2/python/ideep/copy_op_test.py b/caffe2/python/ideep/copy_op_test.py index 55d243bc4999e..9599e994b679c 100644 --- a/caffe2/python/ideep/copy_op_test.py +++ b/caffe2/python/ideep/copy_op_test.py @@ -16,9 +16,9 @@ 
def _get_deep_device(self): def test_copy_to_ideep(self): op = core.CreateOperator( - "CopyCPUToIDEEP", - ["X"], - ["X_ideep"], + "CopyCPUToIDEEP", + ["X"], + ["X_ideep"], ) op.device_option.CopyFrom(self._get_deep_device()) n = randint(1, 128) @@ -33,9 +33,9 @@ def test_copy_to_ideep(self): def test_copy_from_ideep(self): op = core.CreateOperator( - "CopyIDEEPToCPU", - ["X_ideep"], - ["X"], + "CopyIDEEPToCPU", + ["X_ideep"], + ["X"], ) op.device_option.CopyFrom(self._get_deep_device()) n = randint(1, 128) @@ -48,3 +48,18 @@ def test_copy_from_ideep(self): X_ideep = workspace.FetchBlob("X") np.testing.assert_allclose(X, X_ideep) + def test_copy_from_ideep_fallthrough(self): + op = core.CreateOperator( + "CopyIDEEPToCPU", + ["X_ideep"], + ["X"],) + op.device_option.CopyFrom(self._get_deep_device()) + n = randint(1, 128) + c = randint(1, 64) + h = randint(1, 128) + w = randint(1, 128) + X = np.random.rand(n, c, h, w).astype(np.float32) + workspace.FeedBlob("X_ideep", X) + workspace.RunOperatorOnce(op) + X_ideep = workspace.FetchBlob("X") + np.testing.assert_allclose(X, X_ideep) diff --git a/caffe2/python/layer_model_helper.py b/caffe2/python/layer_model_helper.py index ce709b4e42c8f..750f6aaf98fa3 100644 --- a/caffe2/python/layer_model_helper.py +++ b/caffe2/python/layer_model_helper.py @@ -107,6 +107,9 @@ def add_metric_field(self, name, value): ) def add_ad_hoc_plot_blob(self, blob, dtype=None): + assert isinstance( + blob, (six.string_types, core.BlobReference) + ), "expect type str or BlobReference, but got {}".format(type(blob)) dtype = dtype or (np.float, (1, )) self.add_metric_field(str(blob), schema.Scalar(dtype, blob)) self.ad_hoc_plot_blobs.append(blob) diff --git a/caffe2/python/operator_test/conv_test.py b/caffe2/python/operator_test/conv_test.py index 1798c211373a1..3210f180c81b1 100644 --- a/caffe2/python/operator_test/conv_test.py +++ b/caffe2/python/operator_test/conv_test.py @@ -224,8 +224,8 @@ def test_convolution_gradients(self, op_type, stride, pad, kernel, dilation, self.assertGradientChecks(gc, op, inputs, i, [0]) def _nd_convolution_nchw(self, n, input_channels, output_channels, - batch_size, stride, size, kernel, dilation, pad, - use_bias, gc, dc): + batch_size, stride, size, kernel, dilation, pad, + use_bias, gc, dc): dkernel = dilation * (kernel - 1) + 1 for op_type in ["Conv", "Conv" + str(n) + "D"]: op = core.CreateOperator( @@ -546,6 +546,55 @@ def test_use_cudnn_engine_interactions(self): self.assertEqual(model.Proto().op[-1].engine, expected_engine) + @given(op_type=st.sampled_from(["Conv", "Conv2D"]), N=st.integers(1, 4), + G=st.integers(1, 4), DX=st.integers(1, 4), DY=st.integers(1, 4), + H=st.integers(1, 4), W=st.integers(1, 4), use_bias=st.booleans(), + **hu.gcs) + def test_1x1_conv(self, op_type, N, G, DX, DY, H, W, use_bias, gc, dc): + op = core.CreateOperator( + op_type, + ["X", "filter", "bias"] if use_bias else ["X", "filter"], + ["Y"], + stride_h=1, + stride_w=1, + pad_t=0, + pad_l=0, + pad_b=0, + pad_r=0, + kernel=1, + order="NCHW", + group=G, + ) + C = G * DX + M = G * DY + X = np.random.randn(N, C, H, W).astype(np.float32) + filter = np.random.randn(M, DX, 1, 1).astype(np.float32) + bias = np.random.randn(M).astype(np.float32) + inputs = [X, filter, bias] if use_bias else [X, filter] + + def conv_1x1_ref(X, filter, bias=None): + X = X.reshape(N, G, DX, -1) + filter = filter.reshape(G, DY, DX) + Y = np.zeros(shape=(N, G, DY, H * W), dtype=np.float32) + for i in range(N): + for j in range(G): + Y[i, j, :, :] = np.dot(filter[j, :, :], X[i, j, :, :]) + 
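+                # Each group j is an independent (DY, DX) x (DX, H*W) matmul,
+                # mirroring the grouped 1x1 fast path in conv_op_impl.h.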
Y = Y.reshape(N, M, H, W) + if bias is not None: + bias = bias.reshape(1, M, 1, 1) + Y = np.add(Y, bias) + return [Y] + + self.assertReferenceChecks( + device_option=gc, + op=op, + inputs=inputs, + reference=conv_1x1_ref, + ) + self.assertDeviceChecks(dc, op, inputs, [0]) + for i in range(len(inputs)): + self.assertGradientChecks(gc, op, inputs, i, [0]) + if __name__ == "__main__": import unittest diff --git a/caffe2/python/operator_test/ctc_beam_search_decoder_op_test.py b/caffe2/python/operator_test/ctc_beam_search_decoder_op_test.py new file mode 100644 index 0000000000000..51d2bbc6f484a --- /dev/null +++ b/caffe2/python/operator_test/ctc_beam_search_decoder_op_test.py @@ -0,0 +1,126 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from caffe2.python import core +from collections import defaultdict, Counter +from hypothesis import given +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np + +import unittest + +DEFAULT_BEAM_WIDTH = 10 +DEFAULT_PRUNE_THRESHOLD = 0.001 + + +class TestCTCBeamSearchDecoderOp(hu.HypothesisTestCase): + + @given( + batch=st.sampled_from([1, 2, 4]), + max_time=st.sampled_from([1, 8, 64]), + alphabet_size=st.sampled_from([1, 2, 32, 128, 512]), + beam_width=st.sampled_from([1, 2, 16, None]), + **hu.gcs_cpu_only + ) + def test_ctc_beam_search_decoder( + self, batch, max_time, + alphabet_size, beam_width, gc, dc + ): + if not beam_width: + beam_width = DEFAULT_BEAM_WIDTH + op_seq_len = core.CreateOperator('CTCBeamSearchDecoder', + ['INPUTS', 'SEQ_LEN'], + ['OUTPUT_LEN', 'VALUES']) + + op_no_seq_len = core.CreateOperator('CTCBeamSearchDecoder', + ['INPUTS'], + ['OUTPUT_LEN', 'VALUES']) + else: + op_seq_len = core.CreateOperator('CTCBeamSearchDecoder', + ['INPUTS', 'SEQ_LEN'], + ['OUTPUT_LEN', 'VALUES'], + beam_width=beam_width) + + op_no_seq_len = core.CreateOperator('CTCBeamSearchDecoder', + ['INPUTS'], + ['OUTPUT_LEN', 'VALUES'], + beam_width=beam_width) + + def input_generater(): + inputs = np.random.rand(max_time, batch, alphabet_size)\ + .astype(np.float32) + seq_len = np.random.randint(1, max_time + 1, size=batch)\ + .astype(np.int32) + return inputs, seq_len + + def ref_ctc_decoder(inputs, seq_len): + output_len = np.array([]).astype(np.int32) + val = np.array([]).astype(np.int32) + + for i in range(batch): + Pb, Pnb = defaultdict(Counter), defaultdict(Counter) + Pb[0][()] = 1 + Pnb[0][()] = 0 + A_prev = [()] + ctc = inputs[:, i, :] + ctc = np.vstack((np.zeros(alphabet_size), ctc)) + len_i = seq_len[i] if seq_len is not None else max_time + + for t in range(1, len_i + 1): + pruned_alphabet = np.where(ctc[t] > DEFAULT_PRUNE_THRESHOLD)[0] + for l in A_prev: + for c in pruned_alphabet: + if c == 0: + Pb[t][l] += ctc[t][c] * (Pb[t - 1][l] + Pnb[t - 1][l]) + else: + l_plus = l + (c,) + if len(l) > 0 and c == l[-1]: + Pnb[t][l_plus] += ctc[t][c] * Pb[t - 1][l] + Pnb[t][l] += ctc[t][c] * Pnb[t - 1][l] + else: + Pnb[t][l_plus] += \ + ctc[t][c] * (Pb[t - 1][l] + Pnb[t - 1][l]) + + if l_plus not in A_prev: + Pb[t][l_plus] += \ + ctc[t][0] * \ + (Pb[t - 1][l_plus] + Pnb[t - 1][l_plus]) + Pnb[t][l_plus] += ctc[t][c] * Pnb[t - 1][l_plus] + + A_next = Pb[t] + Pnb[t] + A_prev = sorted(A_next, key=A_next.get, reverse=True) + A_prev = A_prev[:beam_width] + + decoded = A_prev[0] if len(A_prev) > 0 else () + + val = np.hstack((val, decoded)) + output_len = np.append(output_len, len(decoded)) + + return [output_len, val] + 
+ def ref_ctc_decoder_max_time(inputs): + return ref_ctc_decoder(inputs, None) + + inputs, seq_len = input_generater() + + self.assertReferenceChecks( + device_option=gc, + op=op_seq_len, + inputs=[inputs, seq_len], + reference=ref_ctc_decoder, + ) + + self.assertReferenceChecks( + device_option=gc, + op=op_no_seq_len, + inputs=[inputs], + reference=ref_ctc_decoder_max_time, + ) + + +if __name__ == "__main__": + import random + random.seed(2603) + unittest.main() diff --git a/caffe2/python/operator_test/elementwise_ops_test.py b/caffe2/python/operator_test/elementwise_ops_test.py index bbbcbf9b8a29a..1231c5e4ba8b7 100644 --- a/caffe2/python/operator_test/elementwise_ops_test.py +++ b/caffe2/python/operator_test/elementwise_ops_test.py @@ -450,6 +450,10 @@ def ref(A, B): B = np.random.rand(n, m, k, t).astype(np.float32) + bias self._run_single_test(op, ref, A, B, True, test_grad, gc, dc) + A = np.random.rand(1, m, k, 1).astype(np.float32) + bias + B = np.random.rand(n, m, k, t).astype(np.float32) + bias + self._run_single_test(op, ref, A, B, True, test_grad, gc, dc) + A = np.random.rand(m, 1, t).astype(np.float32) + bias B = np.random.rand(n, m, k, t).astype(np.float32) + bias self._run_single_test(op, ref, A, B, True, test_grad, gc, dc) @@ -574,6 +578,10 @@ def ref(A, B): B = np.random.randint(128, size=(n, m, k, t)) self._run_single_test(op, ref, A, B, True, False, gc, dc) + A = np.random.randint(128, size=(1, m, k, 1)) + B = np.random.randint(128, size=(n, m, k, t)) + self._run_single_test(op, ref, A, B, True, False, gc, dc) + A = np.random.randint(128, size=(m, 1, t)) B = np.random.randint(128, size=(n, m, k, t)) self._run_single_test(op, ref, A, B, True, False, gc, dc) diff --git a/caffe2/utils/hip/math_hip.cc b/caffe2/utils/hip/math_hip.cc index 9b2eef5421618..d37da68923429 100644 --- a/caffe2/utils/hip/math_hip.cc +++ b/caffe2/utils/hip/math_hip.cc @@ -29,19 +29,21 @@ namespace math { namespace { -#define DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Func, expr) \ - template struct Func##Functor { \ - inline __host__ __device__ T operator()(const T &lhs, \ - const T &rhs) const { \ - return lhs expr rhs; \ - } \ - }; \ - template <> struct Func##Functor { \ - inline __host__ __device__ float16 operator()(const float16 &lhs, \ - const float16 &rhs) const { \ - return convert::To(convert::To( \ - lhs) expr convert::To(rhs)); \ - } \ +#define DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Func, expr) \ + template \ + struct Func##Functor { \ + inline __host__ __device__ T \ + operator()(const T& lhs, const T& rhs) const { \ + return lhs expr rhs; \ + } \ + }; \ + template <> \ + struct Func##Functor { \ + inline __host__ __device__ float16 \ + operator()(const float16& lhs, const float16& rhs) const { \ + return convert::To(convert::To( \ + lhs) expr convert::To(rhs)); \ + } \ }; DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Add, +) DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Sub, -) @@ -50,16 +52,25 @@ DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR(Div, /) #undef DELEGATE_SIMPLE_HOST_DEVICE_BINARY_FUNCTOR template -__global__ void SimpleBinaryOpHIPKernel(const int N, const BinaryOperator op, - const TIn *A, const TIn *B, TOut *C) { - HIP_1D_KERNEL_LOOP(i, N) { C[i] = op(A[i], B[i]); } +__global__ void SimpleBinaryOpHIPKernel( + const int N, + const BinaryOperator op, + const TIn* A, + const TIn* B, + TOut* C) { + HIP_1D_KERNEL_LOOP(i, N) { + C[i] = op(A[i], B[i]); + } } template -__global__ void RowwiseBinaryOpHIPKernel(const int size, - const FixedDivisor cols, - const BinaryOperator op, const TIn 
*A, - const TIn *B, TOut *C) { +__global__ void RowwiseBinaryOpHIPKernel( + const int size, + const FixedDivisor cols, + const BinaryOperator op, + const TIn* A, + const TIn* B, + TOut* C) { HIP_1D_KERNEL_LOOP(C_index, size) { const int j = cols.Mod(C_index); const int A_index = broadcast_1st ? j : C_index; @@ -69,10 +80,13 @@ __global__ void RowwiseBinaryOpHIPKernel(const int size, } template -__global__ void ColwiseBinaryOpHIPKernel(const int size, - const FixedDivisor cols, - const BinaryOperator op, const TIn *A, - const TIn *B, TOut *C) { +__global__ void ColwiseBinaryOpHIPKernel( + const int size, + const FixedDivisor cols, + const BinaryOperator op, + const TIn* A, + const TIn* B, + TOut* C) { HIP_1D_KERNEL_LOOP(C_index, size) { const int i = cols.Div(C_index); const int A_index = broadcast_1st ? i : C_index; @@ -82,12 +96,15 @@ __global__ void ColwiseBinaryOpHIPKernel(const int size, } template -__global__ void -BroadcastBinaryOpHIPKernel(const int size, const SimpleArray A_strides, - const SimpleArray B_strides, - const SimpleArray, D> C_dims, - const BinaryOperator op, const TIn *A, const TIn *B, - TOut *C) { +__global__ void BroadcastBinaryOpHIPKernel( + const int size, + const SimpleArray A_strides, + const SimpleArray B_strides, + const SimpleArray, D> C_dims, + const BinaryOperator op, + const TIn* A, + const TIn* B, + TOut* C) { HIP_1D_KERNEL_LOOP(C_index, size) { int A_index = 0; int B_index = 0; @@ -104,15 +121,16 @@ BroadcastBinaryOpHIPKernel(const int size, const SimpleArray A_strides, } template -void BinaryOpWith2DBroadcasting(const int ndim, const int *dims, - const int pivot, const bool rowwise_broadcast, - const bool broadcast_1st, - const BinaryOperator &op, const TIn *A, - const TIn *B, TOut *C, HIPContext *context) { - const int rows = - std::accumulate(dims, dims + pivot, 1, std::multiplies()); - const int cols = - std::accumulate(dims + pivot, dims + ndim, 1, std::multiplies()); +void BinaryOpWith2DBroadcasting( + const int rows, + const int cols, + const bool rowwise_broadcast, + const bool broadcast_1st, + const BinaryOperator& op, + const TIn* A, + const TIn* B, + TOut* C, + HIPContext* context) { if (rows == 0 || cols == 0) { return; } @@ -122,34 +140,71 @@ void BinaryOpWith2DBroadcasting(const int ndim, const int *dims, if (broadcast_1st) { hipLaunchKernelGGL( (RowwiseBinaryOpHIPKernel), - dim3(CAFFE_GET_BLOCKS(size)), dim3(CAFFE_HIP_NUM_THREADS), 0, - context->hip_stream(), size, cols_div, op, A, B, C); + dim3(CAFFE_GET_BLOCKS(size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + size, + cols_div, + op, + A, + B, + C); } else { hipLaunchKernelGGL( (RowwiseBinaryOpHIPKernel), - dim3(CAFFE_GET_BLOCKS(size)), dim3(CAFFE_HIP_NUM_THREADS), 0, - context->hip_stream(), size, cols_div, op, A, B, C); + dim3(CAFFE_GET_BLOCKS(size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + size, + cols_div, + op, + A, + B, + C); } } else { if (broadcast_1st) { hipLaunchKernelGGL( (ColwiseBinaryOpHIPKernel), - dim3(CAFFE_GET_BLOCKS(size)), dim3(CAFFE_HIP_NUM_THREADS), 0, - context->hip_stream(), size, cols_div, op, A, B, C); + dim3(CAFFE_GET_BLOCKS(size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + size, + cols_div, + op, + A, + B, + C); } else { hipLaunchKernelGGL( (ColwiseBinaryOpHIPKernel), - dim3(CAFFE_GET_BLOCKS(size)), dim3(CAFFE_HIP_NUM_THREADS), 0, - context->hip_stream(), size, cols_div, op, A, B, C); + dim3(CAFFE_GET_BLOCKS(size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + size, + cols_div, + 
op, + A, + B, + C); } } } template -void BroadcastBinaryOpImpl(const int *A_dims, const int *B_dims, - const int *C_dims, const BinaryOperator &op, - const TIn *A, const TIn *B, TOut *C, - HIPContext *context) { +void BroadcastBinaryOpImpl( + const int* A_dims, + const int* B_dims, + const int* C_dims, + const BinaryOperator& op, + const TIn* A, + const TIn* B, + TOut* C, + HIPContext* context) { SimpleArray A_strides_array; SimpleArray B_strides_array; SimpleArray, D> C_dims_array; @@ -167,69 +222,122 @@ void BroadcastBinaryOpImpl(const int *A_dims, const int *B_dims, } const int size = std::accumulate(C_dims, C_dims + D, 1, std::multiplies()); - hipLaunchKernelGGL((BroadcastBinaryOpHIPKernel), - dim3(CAFFE_GET_BLOCKS(size)), dim3(CAFFE_HIP_NUM_THREADS), - 0, context->hip_stream(), size, A_strides_array, - B_strides_array, C_dims_array, op, A, B, C); + hipLaunchKernelGGL( + (BroadcastBinaryOpHIPKernel), + dim3(CAFFE_GET_BLOCKS(size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + size, + A_strides_array, + B_strides_array, + C_dims_array, + op, + A, + B, + C); } template -void BroadcastBinaryOp(const int A_ndim, const int *A_dims, const int B_ndim, - const int *B_dims, const BinaryOperator &op, - const TIn *A, const TIn *B, TOut *C, - HIPContext *context) { +void BroadcastBinaryOp( + const int A_ndim, + const int* A_dims, + const int B_ndim, + const int* B_dims, + const BinaryOperator& op, + const TIn* A, + const TIn* B, + TOut* C, + HIPContext* context) { const int ndim = std::max(A_ndim, B_ndim); std::vector A_dims_array(ndim); std::vector B_dims_array(ndim); std::vector C_dims_array(ndim); - utils::ComputeBroadcastBinaryOpDims(A_ndim, A_dims, B_ndim, B_dims, - A_dims_array.data(), B_dims_array.data(), - C_dims_array.data()); + utils::ComputeBroadcastBinaryOpDims( + A_ndim, + A_dims, + B_ndim, + B_dims, + A_dims_array.data(), + B_dims_array.data(), + C_dims_array.data()); if (A_dims_array == B_dims_array) { - const int size = std::accumulate(C_dims_array.cbegin(), C_dims_array.cend(), - 1, std::multiplies()); - hipLaunchKernelGGL((SimpleBinaryOpHIPKernel), - dim3(CAFFE_GET_BLOCKS(size)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - size, op, A, B, C); + const int size = std::accumulate( + C_dims_array.cbegin(), C_dims_array.cend(), 1, std::multiplies()); + hipLaunchKernelGGL( + (SimpleBinaryOpHIPKernel), + dim3(CAFFE_GET_BLOCKS(size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + size, + op, + A, + B, + C); return; } - int pivot; + int rows; + int cols; bool broadcast_1st; - if (utils::IsRowwiseBroadcastBinaryOp(ndim, A_dims_array.data(), - B_dims_array.data(), &pivot, - &broadcast_1st)) { + if (utils::IsRowwiseBroadcastBinaryOp( + ndim, + A_dims_array.data(), + B_dims_array.data(), + &rows, + &cols, + &broadcast_1st)) { BinaryOpWith2DBroadcasting( - ndim, C_dims_array.data(), pivot, true, broadcast_1st, op, A, B, C, - context); + rows, cols, true, broadcast_1st, op, A, B, C, context); return; } - if (utils::IsColwiseBroadcastBinaryOp(ndim, A_dims_array.data(), - B_dims_array.data(), &pivot, - &broadcast_1st)) { + if (utils::IsColwiseBroadcastBinaryOp( + ndim, + A_dims_array.data(), + B_dims_array.data(), + &rows, + &cols, + &broadcast_1st)) { BinaryOpWith2DBroadcasting( - ndim, C_dims_array.data(), pivot, false, broadcast_1st, op, A, B, C, - context); + rows, cols, false, broadcast_1st, op, A, B, C, context); return; } DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_3( - ndim, BroadcastBinaryOpImpl, TIn, TOut, BinaryOperator, - A_dims_array.data(), 
B_dims_array.data(), C_dims_array.data(), op, A, B, - C, context); + ndim, + BroadcastBinaryOpImpl, + TIn, + TOut, + BinaryOperator, + A_dims_array.data(), + B_dims_array.data(), + C_dims_array.data(), + op, + A, + B, + C, + context); } } // namespace -#define DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(T, Func, op) \ - __global__ void Func##HIPKernel(const int N, const T *X, T *Y) { \ - HIP_1D_KERNEL_LOOP(i, N) { Y[i] = op(X[i]); } \ - } \ - template <> \ - void Func(const int N, const T *x, T *y, \ - HIPContext *context) { \ - hipLaunchKernelGGL((Func##HIPKernel), CAFFE_GET_BLOCKS(N), \ - CAFFE_HIP_NUM_THREADS, 0, context->hip_stream(), N, x, \ - y); \ +#define DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(T, Func, op) \ + __global__ void Func##HIPKernel(const int N, const T* X, T* Y) { \ + HIP_1D_KERNEL_LOOP(i, N) { \ + Y[i] = op(X[i]); \ + } \ + } \ + template <> \ + void Func( \ + const int N, const T* x, T* y, HIPContext* context) { \ + hipLaunchKernelGGL( \ + (Func##HIPKernel), \ + CAFFE_GET_BLOCKS(N), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + N, \ + x, \ + y); \ } DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(float, Exp, expf) @@ -251,40 +359,61 @@ DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(float, Cbrt, cbrtf) DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(float, Cube, utils::Cube) DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(double, Cube, utils::Cube) -DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(std::int32_t, Cube, - utils::Cube) -DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(std::int64_t, Cube, - utils::Cube) +DELEGATE_SIMPLE_HIP_UNARY_FUNCTION( + std::int32_t, + Cube, + utils::Cube) +DELEGATE_SIMPLE_HIP_UNARY_FUNCTION( + std::int64_t, + Cube, + utils::Cube) DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(bool, Not, utils::Not) DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(float, Neg, utils::Negate) DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(double, Neg, utils::Negate) -DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(std::int32_t, Neg, - utils::Negate) -DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(std::int64_t, Neg, - utils::Negate) +DELEGATE_SIMPLE_HIP_UNARY_FUNCTION( + std::int32_t, + Neg, + utils::Negate) +DELEGATE_SIMPLE_HIP_UNARY_FUNCTION( + std::int64_t, + Neg, + utils::Negate) DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(float, Sign, utils::Sign) DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(double, Sign, utils::Sign) -DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(std::int32_t, Sign, - utils::Sign) -DELEGATE_SIMPLE_HIP_UNARY_FUNCTION(std::int64_t, Sign, - utils::Sign) +DELEGATE_SIMPLE_HIP_UNARY_FUNCTION( + std::int32_t, + Sign, + utils::Sign) +DELEGATE_SIMPLE_HIP_UNARY_FUNCTION( + std::int64_t, + Sign, + utils::Sign) #undef DELEGATE_SIMPLE_HIP_UNARY_FUNCTION -#define DELEGATE_SINCOS_HIP_FUNCTION(T, fn) \ - __global__ void _Kernel_##T##_##SinCos(const int N, const T *x, T *ys, \ - T *yc) { \ - HIP_1D_KERNEL_LOOP(i, N) { fn(__ldg(x + i), ys + i, yc + i); } \ - } \ - template <> \ - void SinCos(const int N, const T *x, T *ys, T *yc, \ - HIPContext *context) { \ - hipLaunchKernelGGL((_Kernel_##T##_##SinCos), CAFFE_GET_BLOCKS(N), \ - CAFFE_HIP_NUM_THREADS, 0, context->hip_stream(), N, x, \ - ys, yc); \ +#define DELEGATE_SINCOS_HIP_FUNCTION(T, fn) \ + __global__ void _Kernel_##T##_##SinCos( \ + const int N, const T* x, T* ys, T* yc) { \ + HIP_1D_KERNEL_LOOP(i, N) { \ + fn(__ldg(x + i), ys + i, yc + i); \ + } \ + } \ + template <> \ + void SinCos( \ + const int N, const T* x, T* ys, T* yc, HIPContext* context) { \ + hipLaunchKernelGGL( \ + (_Kernel_##T##_##SinCos), \ + CAFFE_GET_BLOCKS(N), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + N, \ + x, \ + ys, \ + yc); \ } DELEGATE_SINCOS_HIP_FUNCTION(float, 
sincosf) @@ -292,18 +421,26 @@ DELEGATE_SINCOS_HIP_FUNCTION(double, sincos) #define DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(TIn, TOut, Func, Op) \ template <> \ - void Func(const int N, const TIn *A, const TIn *B, TOut *C, \ - HIPContext *context) { \ - hipLaunchKernelGGL((SimpleBinaryOpHIPKernel>), \ - CAFFE_GET_BLOCKS(N), CAFFE_HIP_NUM_THREADS, 0, \ - context->hip_stream(), N, Op(), A, B, C); \ - } - -#define DEFINE_SIMPLE_HIP_COMPARE_FUNCTION(Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(float, bool, Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(double, bool, Func, Op) \ + void Func( \ + const int N, const TIn* A, const TIn* B, TOut* C, HIPContext* context) { \ + hipLaunchKernelGGL( \ + (SimpleBinaryOpHIPKernel>), \ + CAFFE_GET_BLOCKS(N), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + N, \ + Op(), \ + A, \ + B, \ + C); \ + } + +#define DEFINE_SIMPLE_HIP_COMPARE_FUNCTION(Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(float, bool, Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(double, bool, Func, Op) \ DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(bool, bool, Func, Op) DEFINE_SIMPLE_HIP_COMPARE_FUNCTION(EQ, thrust::equal_to) @@ -315,11 +452,11 @@ DEFINE_SIMPLE_HIP_COMPARE_FUNCTION(GE, thrust::greater_equal) #undef DEFINE_SIMPLE_HIP_COMPARE_FUNCTION -#define DEFINE_SIMPLE_HIP_BINARY_FUNCTION(Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(float, float, Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(double, double, Func, Op) \ +#define DEFINE_SIMPLE_HIP_BINARY_FUNCTION(Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(float, float, Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(double, double, Func, Op) \ DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(float16, float16, Func, Op) DEFINE_SIMPLE_HIP_BINARY_FUNCTION(Add, AddFunctor) @@ -333,9 +470,9 @@ DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(bool, bool, And, thrust::logical_and) DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or) DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor) -#define DEFINE_SIMPLE_HIP_BITWISE_BINARY_FUNCTION(Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(bool, bool, Func, Op) \ - DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \ +#define DEFINE_SIMPLE_HIP_BITWISE_BINARY_FUNCTION(Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(bool, bool, Func, Op) \ + DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, Op) \ DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, Op) DEFINE_SIMPLE_HIP_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and) @@ -348,69 +485,117 @@ DELEGATE_SIMPLE_HIP_BINARY_FUNCTION(float, float, ElemwiseMax, thrust::maximum); #undef DELEGATE_SIMPLE_HIP_BINARY_FUNCTION -#define DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(TIn, TOut, Func, Op) \ - template <> \ - void Rowwise##Func(const int rows, const int cols, \ - const TIn *A, const TIn *B, \ - TOut *C, HIPContext *context) { \ - if (rows == 0 || cols == 0) { \ - 
return; \ - } \ - const int size = rows * cols; \ - const FixedDivisor cols_div(cols); \ - hipLaunchKernelGGL(RowwiseBinaryOpHIPKernel, true>, \ - CAFFE_GET_BLOCKS(size), CAFFE_HIP_NUM_THREADS, 0, \ - context->hip_stream(), size, cols_div, Op(), A, B, \ - C); \ - } \ - template <> \ - void Rowwise##Func(const int rows, const int cols, \ - const TIn *A, const TIn *B, \ - TOut *C, HIPContext *context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FixedDivisor cols_div(cols); \ - hipLaunchKernelGGL(RowwiseBinaryOpHIPKernel, false>, \ - CAFFE_GET_BLOCKS(size), CAFFE_HIP_NUM_THREADS, 0, \ - context->hip_stream(), size, cols_div, Op(), A, B, \ - C); \ - } \ - template <> \ - void Colwise##Func(const int rows, const int cols, \ - const TIn *A, const TIn *B, \ - TOut *C, HIPContext *context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FixedDivisor cols_div(cols); \ - hipLaunchKernelGGL(ColwiseBinaryOpHIPKernel, true>, \ - CAFFE_GET_BLOCKS(size), CAFFE_HIP_NUM_THREADS, 0, \ - context->hip_stream(), size, cols_div, Op(), A, B, \ - C); \ - } \ - template <> \ - void Colwise##Func(const int rows, const int cols, \ - const TIn *A, const TIn *B, \ - TOut *C, HIPContext *context) { \ - if (rows == 0 || cols == 0) { \ - return; \ - } \ - const int size = rows * cols; \ - const FixedDivisor cols_div(cols); \ - hipLaunchKernelGGL(ColwiseBinaryOpHIPKernel, false>, \ - CAFFE_GET_BLOCKS(size), CAFFE_HIP_NUM_THREADS, 0, \ - context->hip_stream(), size, cols_div, Op(), A, B, \ - C); \ - } - -#define DEFINE_2D_BROADCAST_HIP_COMPARE_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(float, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(double, bool, Func, Op) \ +#define DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(TIn, TOut, Func, Op) \ + template <> \ + void Rowwise##Func( \ + const int rows, \ + const int cols, \ + const TIn* A, \ + const TIn* B, \ + TOut* C, \ + HIPContext* context) { \ + if (rows == 0 || cols == 0) { \ + return; \ + } \ + const int size = rows * cols; \ + const FixedDivisor cols_div(cols); \ + hipLaunchKernelGGL( \ + RowwiseBinaryOpHIPKernel, true>, \ + CAFFE_GET_BLOCKS(size), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + size, \ + cols_div, \ + Op(), \ + A, \ + B, \ + C); \ + } \ + template <> \ + void Rowwise##Func( \ + const int rows, \ + const int cols, \ + const TIn* A, \ + const TIn* B, \ + TOut* C, \ + HIPContext* context) { \ + if (rows == 0 || cols == 0) { \ + return; \ + } \ + const int size = rows * cols; \ + const FixedDivisor cols_div(cols); \ + hipLaunchKernelGGL( \ + RowwiseBinaryOpHIPKernel, false>, \ + CAFFE_GET_BLOCKS(size), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + size, \ + cols_div, \ + Op(), \ + A, \ + B, \ + C); \ + } \ + template <> \ + void Colwise##Func( \ + const int rows, \ + const int cols, \ + const TIn* A, \ + const TIn* B, \ + TOut* C, \ + HIPContext* context) { \ + if (rows == 0 || cols == 0) { \ + return; \ + } \ + const int size = rows * cols; \ + const FixedDivisor cols_div(cols); \ + hipLaunchKernelGGL( \ + ColwiseBinaryOpHIPKernel, true>, \ + CAFFE_GET_BLOCKS(size), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + size, \ + cols_div, \ + Op(), \ + A, \ + B, \ + C); \ + } \ + template <> \ + void Colwise##Func( \ + const 
int rows, \ + const int cols, \ + const TIn* A, \ + const TIn* B, \ + TOut* C, \ + HIPContext* context) { \ + if (rows == 0 || cols == 0) { \ + return; \ + } \ + const int size = rows * cols; \ + const FixedDivisor cols_div(cols); \ + hipLaunchKernelGGL( \ + ColwiseBinaryOpHIPKernel, false>, \ + CAFFE_GET_BLOCKS(size), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + size, \ + cols_div, \ + Op(), \ + A, \ + B, \ + C); \ + } + +#define DEFINE_2D_BROADCAST_HIP_COMPARE_FUNCTION(Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(float, bool, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(double, bool, Func, Op) \ DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(bool, bool, Func, Op) DEFINE_2D_BROADCAST_HIP_COMPARE_FUNCTION(EQ, thrust::equal_to) @@ -422,13 +607,13 @@ DEFINE_2D_BROADCAST_HIP_COMPARE_FUNCTION(GE, thrust::greater_equal) #undef DEFINE_2D_BROADCAST_HIP_COMPARE_FUNCTION -#define DEFINE_2D_BROADCAST_HIP_BINARY_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, \ - Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, \ - Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(float, float, Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(double, double, Func, Op) \ +#define DEFINE_2D_BROADCAST_HIP_BINARY_FUNCTION(Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION( \ + std::int32_t, std::int32_t, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION( \ + std::int64_t, std::int64_t, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(float, float, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(double, double, Func, Op) \ DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(float16, float16, Func, Op) DEFINE_2D_BROADCAST_HIP_BINARY_FUNCTION(Add, AddFunctor) @@ -442,12 +627,12 @@ DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(bool, bool, And, thrust::logical_and) DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(bool, bool, Or, thrust::logical_or) DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(bool, bool, Xor, thrust::bit_xor) -#define DEFINE_2D_BROADCAST_HIP_BITWISE_BINARY_FUNCTION(Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(bool, bool, Func, Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int32_t, std::int32_t, Func, \ - Op) \ - DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(std::int64_t, std::int64_t, Func, \ - Op) +#define DEFINE_2D_BROADCAST_HIP_BITWISE_BINARY_FUNCTION(Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION(bool, bool, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION( \ + std::int32_t, std::int32_t, Func, Op) \ + DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION( \ + std::int64_t, std::int64_t, Func, Op) DEFINE_2D_BROADCAST_HIP_BITWISE_BINARY_FUNCTION(BitwiseAnd, thrust::bit_and) DEFINE_2D_BROADCAST_HIP_BITWISE_BINARY_FUNCTION(BitwiseOr, thrust::bit_or) @@ -457,21 +642,26 @@ DEFINE_2D_BROADCAST_HIP_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor) #undef DELEGATE_2D_BROADCAST_HIP_BINARY_FUNCTION -#define DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(TIn, TOut, Func, Op) \ - template <> \ - void Func(const int A_ndim, const int *A_dims, \ - const int B_ndim, const int *B_dims, \ - const TIn *A, const TIn *B, TOut *C, \ - HIPContext *context) { \ - BroadcastBinaryOp>(A_ndim, A_dims, B_ndim, B_dims, \ - Op(), A, B, C, context); \ - } - -#define DEFINE_BROADCAST_HIP_COMPARE_FUNCTION(Func, Op) \ - 
DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ - DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ - DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(float, bool, Func, Op) \ - DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(double, bool, Func, Op) \ +#define DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(TIn, TOut, Func, Op) \ + template <> \ + void Func( \ + const int A_ndim, \ + const int* A_dims, \ + const int B_ndim, \ + const int* B_dims, \ + const TIn* A, \ + const TIn* B, \ + TOut* C, \ + HIPContext* context) { \ + BroadcastBinaryOp>( \ + A_ndim, A_dims, B_ndim, B_dims, Op(), A, B, C, context); \ + } + +#define DEFINE_BROADCAST_HIP_COMPARE_FUNCTION(Func, Op) \ + DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(std::int32_t, bool, Func, Op) \ + DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(std::int64_t, bool, Func, Op) \ + DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(float, bool, Func, Op) \ + DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(double, bool, Func, Op) \ DELEGATE_BROADCAST_HIP_BINARY_FUNCTION(bool, bool, Func, Op) DEFINE_BROADCAST_HIP_COMPARE_FUNCTION(EQ, thrust::equal_to) @@ -514,20 +704,27 @@ DEFINE_BROADCAST_HIP_BITWISE_BINARY_FUNCTION(BitwiseXor, thrust::bit_xor) #undef DELEGATE_BROADCAST_HIP_BINARY_FUNCTION -#define DELEGATE_REDUCTION_FUNCTION(T, Funcname, func) \ - template <> \ - void Funcname(const int N, const T *src, T *dst, \ - Tensor *scratch_ptr, \ - HIPContext *context) { \ - size_t memRequired = 0; \ - cub::DeviceReduce::func(nullptr, memRequired, src, dst, N, \ - context->hip_stream()); \ - auto buffer_size = \ - static_cast((memRequired + sizeof(T) - 1) / sizeof(T)); \ - scratch_ptr->Resize(std::vector{buffer_size}); \ - cub::DeviceReduce::func( \ - static_cast(scratch_ptr->mutable_data()), memRequired, src, \ - dst, N, context->hip_stream()); \ +#define DELEGATE_REDUCTION_FUNCTION(T, Funcname, func) \ + template <> \ + void Funcname( \ + const int N, \ + const T* src, \ + T* dst, \ + Tensor* scratch_ptr, \ + HIPContext* context) { \ + size_t memRequired = 0; \ + cub::DeviceReduce::func( \ + nullptr, memRequired, src, dst, N, context->hip_stream()); \ + auto buffer_size = \ + static_cast((memRequired + sizeof(T) - 1) / sizeof(T)); \ + scratch_ptr->Resize(std::vector{buffer_size}); \ + cub::DeviceReduce::func( \ + static_cast(scratch_ptr->mutable_data()), \ + memRequired, \ + src, \ + dst, \ + N, \ + context->hip_stream()); \ } DELEGATE_REDUCTION_FUNCTION(float, ReduceMin, Min) @@ -540,34 +737,60 @@ DELEGATE_REDUCTION_FUNCTION(int64_t, ReduceMax, Max) // Caffe2 gemm provides a simpler interface to the gemm functions, with the // limitation that the data has to be contiguous in memory. template <> -void Gemm(const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, - const int N, const int K, const float alpha, - const float *A, const float *B, const float beta, - float *C, HIPContext *context, - TensorProto::DataType math_type) { +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const float* B, + const float beta, + float* C, + HIPContext* context, + TensorProto::DataType math_type) { // Note that rocblas follows fortran order, so the order is different from // the cblas convention. int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? N : K; rocblas_operation cuTransA = (TransA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; + ? 
rocblas_operation_none + : rocblas_operation_transpose; rocblas_operation cuTransB = (TransB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - ROCBLAS_ENFORCE(rocblas_sgemm(context->rocblas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); + ? rocblas_operation_none + : rocblas_operation_transpose; + ROCBLAS_ENFORCE(rocblas_sgemm( + context->rocblas_handle(), + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + N)); } template <> -void Gemm(const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, - const int N, const int K, const float alpha, - const float16 *A, const float16 *B, - const float beta, float16 *C, - HIPContext *context, - TensorProto::DataType math_type) { +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, + const int N, + const int K, + const float alpha, + const float16* A, + const float16* B, + const float beta, + float16* C, + HIPContext* context, + TensorProto::DataType math_type) { CAFFE_THROW("Unsupported math type"); #if ROCBLAS_FP16 // rocblas does not support fp16 yet // Note that cublas follows fortran order, so the order is different from @@ -575,15 +798,30 @@ void Gemm(const CBLAS_TRANSPOSE TransA, int lda = (TransA == CblasNoTrans) ? K : M; int ldb = (TransB == CblasNoTrans) ? N : K; rocblas_operation cuTransA = (TransA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; + ? rocblas_operation_none + : rocblas_operation_transpose; rocblas_operation cuTransB = (TransB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; + ? rocblas_operation_none + : rocblas_operation_transpose; if (math_type == TensorProto_DataType_FLOAT) { - ROCBLAS_CHECK(rocblas_sgemmEx(context->rocblas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, CUDA_R_16F, ldb, A, - CUDA_R_16F, lda, &beta, C, CUDA_R_16F, N)); + ROCBLAS_CHECK(rocblas_sgemmEx( + context->rocblas_handle(), + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + CUDA_R_16F, + ldb, + A, + CUDA_R_16F, + lda, + &beta, + C, + CUDA_R_16F, + N)); } else if (math_type == TensorProto_DataType_FLOAT16) { // convert alpha, beta from float -> __half @@ -614,171 +852,265 @@ void Gemm(const CBLAS_TRANSPOSE TransA, } template <> -void BiasCHW(const float *bias, const float *bias_multiplier, - const int bias_channels, const int image_size, - float *image, HIPContext *context) { - Gemm(CblasNoTrans, CblasNoTrans, bias_channels, image_size, - 1, 1, bias, bias_multiplier, 1, image, context); +void BiasCHW( + const float* bias, + const float* bias_multiplier, + const int bias_channels, + const int image_size, + float* image, + HIPContext* context) { + Gemm( + CblasNoTrans, + CblasNoTrans, + bias_channels, + image_size, + 1, + 1, + bias, + bias_multiplier, + 1, + image, + context); } template <> void GemmBatched( - const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, - const int batch_size, const int M, const int N, const int K, - const float alpha, const float *A, const float *B, const float beta, - float *C, HIPContext *context, Tensor *scratch, + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const float** A, + const float** B, + const float beta, + float** C, + HIPContext* context, + TensorProto::DataType math_type) { + // rocblas doesn't support SgemmBatched yet. 
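rocBLAS has no pointer-array batched SGEMM, so the GemmBatched overload above is implemented as one Gemm<float, HIPContext> call per batch item (the loop that follows), while the new GemmStridedBatched overload further down maps directly onto rocblas_sgemm_strided_batched, swapping the A and B operands to account for rocBLAS's column-major convention. A minimal, self-contained sketch of the strided-batched contract that the call sites in this patch rely on, assuming row-major storage and no transposes (the helper name is illustrative, not the caffe2::math or rocBLAS API):

#include <cstddef>

// C_i = alpha * A_i * B_i + beta * C_i, where operand i starts at
// base + i * stride; all matrices are dense and row-major.
void gemm_strided_batched_ref(
    std::size_t batch, std::size_t M, std::size_t N, std::size_t K,
    float alpha, const float* A, std::size_t A_stride,
    const float* B, std::size_t B_stride,
    float beta, float* C, std::size_t C_stride) {
  for (std::size_t b = 0; b < batch; ++b) {
    const float* A_b = A + b * A_stride;
    const float* B_b = B + b * B_stride;
    float* C_b = C + b * C_stride;
    for (std::size_t m = 0; m < M; ++m) {
      for (std::size_t n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (std::size_t k = 0; k < K; ++k) {
          acc += A_b[m * K + k] * B_b[k * N + n];
        }
        C_b[m * N + n] = alpha * acc + beta * C_b[m * N + n];
      }
    }
  }
}

For contiguous sub-batches the strides reduce to M * K, K * N and M * N, as in the BatchMatMul call site earlier in this patch; a zero stride, used for the shared filter and bias multiplier in the 1x1-conv path, simply reuses one operand for every batch item.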
+ for (int i = 0; i < batch_size; ++i) { + Gemm( + TransA, + TransB, + M, + N, + K, + alpha, + A[i], + B[i], + beta, + C[i], + context, + math_type); + } +} + +template <> +void GemmStridedBatched( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int A_stride, + const float* B, + const int B_stride, + const float beta, + float* C, + const int C_stride, + HIPContext* context, TensorProto::DataType math_type) { - const int a_stride = M * K; - const int b_stride = K * N; - const int c_stride = M * N; // Note that cublas follows fortran order, so the order is different from // the cblas convention. const int lda = (TransA == CblasNoTrans) ? K : M; const int ldb = (TransB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (TransA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (TransB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; + const rocblas_operation cuTransA = (TransA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + const rocblas_operation cuTransB = (TransB == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; ROCBLAS_ENFORCE(rocblas_sgemm_strided_batched( - context->rocblas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, - b_stride, A, lda, a_stride, &beta, C, N, c_stride, batch_size)); -} - -namespace { - -__global__ void FloatToHalfKernel(const int N, const float *X, half *Y) { - HIP_1D_KERNEL_LOOP(i, N) { Y[i] = __float2half(X[i]); } -} - -__global__ void HalfToFloatKernel(const int N, const half *X, float *Y) { - HIP_1D_KERNEL_LOOP(i, N) { Y[i] = __half2float(X[i]); } + context->rocblas_handle(), + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + B_stride, + A, + lda, + A_stride, + &beta, + C, + N, + C_stride, + batch_size)); } -}; // namespace - template <> -void GemmBatched( - const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, - const int batch_size, const int M, const int N, const int K, - const float alpha, const float16 *A, const float16 *B, const float beta, - float16 *C, HIPContext *context, Tensor *scratch, +void GemmStridedBatched( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const float16* A, + const int A_stride, + const float16* B, + const int B_stride, + const float beta, + float16* C, + const int C_stride, + HIPContext* context, TensorProto::DataType math_type) { - const int a_stride = M * K; - const int b_stride = K * N; - const int c_stride = M * N; - - // 3 options: - // 1) scratch != null = cast to fp32, SgemmStridedBatched, cast result to fp16 - // 2) math_type == FLOAT, scratch == nullptr = looped SgemmEx - // 3) math_type == FLOAT16, scratch == nullptr = batched Hgemm - - if (scratch != nullptr) { - const int A_size = a_stride * batch_size; - const int B_size = b_stride * batch_size; - // cast, cublasSgemmStridedBatched, cast - size_t in_elems = A_size + B_size; - size_t out_elems = c_stride * batch_size; - - scratch->Resize(in_elems + out_elems); - float *scratch_ptr = scratch->mutable_data(); - - float *A_fp32 = scratch_ptr; - float *B_fp32 = scratch_ptr + A_size; - float *C_fp32 = scratch_ptr + A_size + B_size; - - // cast A, B into fp32 - hipLaunchKernelGGL((HalfToFloatKernel), dim3(CAFFE_GET_BLOCKS(A_size)), - dim3(CAFFE_HIP_NUM_THREADS), 0, 
context->hip_stream(), - A_size, (half *)A, A_fp32); - hipLaunchKernelGGL((HalfToFloatKernel), dim3(CAFFE_GET_BLOCKS(B_size)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - B_size, (half *)B, B_fp32); - - // run fp32 batched Gemm - GemmBatched(TransA, TransB, batch_size, M, N, K, alpha, - A_fp32, B_fp32, beta, C_fp32, context); - - // cast result back to fp16 - hipLaunchKernelGGL((FloatToHalfKernel), - dim3(CAFFE_GET_BLOCKS(batch_size * M * N)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - batch_size * M * N, C_fp32, (half *)C); - } else { #if ROCBLAS_FP16 // rocblas does not support fp16 yet - if (math_type == TensorProto_DataType_FLOAT) { - // loop over matrices in the batch - for (int i = 0; i < batch_size; ++i) { - math::Gemm(TransA, TransB, M, N, K, alpha, - A + a_stride * i, B + b_stride * i, - beta, C + c_stride * i, context); - } - } else if (math_type == TensorProto_DataType_FLOAT16) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (TransA == CblasNoTrans) ? K : M; - const int ldb = (TransB == CblasNoTrans) ? N : K; - rocblas_operation cuTransA = (TransA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - rocblas_operation cuTransB = (TransB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - - // convert alpha, beta from float -> __half - auto alpha_fp16 = convert::floatToHalf(alpha); - auto beta_fp16 = convert::floatToHalf(beta); - ROCBLAS_ENFORCE(cublasHgemmStridedBatched( - context->rocblas_handle(), cuTransB, cuTransA, N, M, K, &alpha_fp16, - (const __half *)B, ldb, b_stride, (const __half *)A, lda, a_stride, - &beta_fp16, (__half *)C, N, c_stride, batch_size)); + if (math_type == TensorProto_DataType_FLOAT) { + // loop over matrices in the batch + for (int i = 0; i < batch_size; ++i) { + math::Gemm( + TransA, + TransB, + M, + N, + K, + alpha, + A + a_stride * i, + B + b_stride * i, + beta, + C + c_stride * i, + context); } -#endif + } else if (math_type == TensorProto_DataType_FLOAT16) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + const int lda = (TransA == CblasNoTrans) ? K : M; + const int ldb = (TransB == CblasNoTrans) ? N : K; + const rocblas_operation cuTransA = (TransA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + const rocblas_operation cuTransB = (TransB == CblasNoTrans) + ? 
rocblas_operation_none + : rocblas_operation_transpose; + + // convert alpha, beta from float -> __half + auto alpha_fp16 = convert::floatToHalf(alpha); + auto beta_fp16 = convert::floatToHalf(beta); + ROCBLAS_ENFORCE(cublasHgemmStridedBatched( + context->rocblas_handle(), + cuTransB, + cuTransA, + N, + M, + K, + &alpha_fp16, + (const __half*)B, + ldb, + B_stride, + (const __half*)A, + lda, + A_stride, + &beta_fp16, + (__half*)C, + N, + C_stride, + batch_size)); } +#endif } template <> -void GemmEx(const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, const int M, - const int N, const int K, const float alpha, - const float *A, const int lda, const float *B, - const int ldb, const float beta, float *C, - const int ldc, HIPContext *context) { +void GemmEx( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int lda, + const float* B, + const int ldb, + const float beta, + float* C, + const int ldc, + HIPContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. rocblas_operation cuTransA = (TransA == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; + ? rocblas_operation_none + : rocblas_operation_transpose; rocblas_operation cuTransB = (TransB == CblasNoTrans) - ? rocblas_operation_none - : rocblas_operation_transpose; - ROCBLAS_ENFORCE(rocblas_sgemm(context->rocblas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, - ldc)); + ? rocblas_operation_none + : rocblas_operation_transpose; + ROCBLAS_ENFORCE(rocblas_sgemm( + context->rocblas_handle(), + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc)); } template <> -void Gemv(const CBLAS_TRANSPOSE TransA, const int M, - const int N, const float alpha, const float *A, - const float *x, const float beta, float *y, - HIPContext *context, - TensorProto::DataType math_type) { +void Gemv( + const CBLAS_TRANSPOSE TransA, + const int M, + const int N, + const float alpha, + const float* A, + const float* x, + const float beta, + float* y, + HIPContext* context, + TensorProto::DataType math_type) { rocblas_operation cuTransA = (TransA == CblasNoTrans) - ? rocblas_operation_transpose - : rocblas_operation_none; - ROCBLAS_ENFORCE(rocblas_sgemv(context->rocblas_handle(), cuTransA, N, M, - &alpha, A, N, x, 1, &beta, y, 1)); + ? 
rocblas_operation_transpose + : rocblas_operation_none; + ROCBLAS_ENFORCE(rocblas_sgemv( + context->rocblas_handle(), + cuTransA, + N, + M, + &alpha, + A, + N, + x, + 1, + &beta, + y, + 1)); } // Batched Add variants namespace { template -__global__ void AddStripedBatchKernel(const int N, const T *first, T *Y, - const int stripe, const int batch) { +__global__ void AddStripedBatchKernel( + const int N, + const T* first, + T* Y, + const int stripe, + const int batch) { for (int j = 0; j < batch; j++) { - const T *x = first + j * stripe; + const T* x = first + j * stripe; HIP_1D_KERNEL_LOOP(i, N) { float tmpY = convert::To(Y[i]); tmpY += convert::To(x[i]); @@ -788,14 +1120,26 @@ __global__ void AddStripedBatchKernel(const int N, const T *first, T *Y, } } // namespace -#define CAFFE2_SPECIALIZED_HIP_ADD_STRIPED_BATCH(T) \ - template <> \ - void AddStripedBatch(const int N, const T *first, T *Y, \ - const int stripe, const int batch, \ - HIPContext *context) { \ - hipLaunchKernelGGL(AddStripedBatchKernel, CAFFE_GET_BLOCKS(N), \ - CAFFE_HIP_NUM_THREADS, 0, context->hip_stream(), N, \ - first, Y, stripe, batch); \ +#define CAFFE2_SPECIALIZED_HIP_ADD_STRIPED_BATCH(T) \ + template <> \ + void AddStripedBatch( \ + const int N, \ + const T* first, \ + T* Y, \ + const int stripe, \ + const int batch, \ + HIPContext* context) { \ + hipLaunchKernelGGL( \ + AddStripedBatchKernel, \ + CAFFE_GET_BLOCKS(N), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + N, \ + first, \ + Y, \ + stripe, \ + batch); \ } CAFFE2_SPECIALIZED_HIP_ADD_STRIPED_BATCH(float); @@ -803,16 +1147,22 @@ CAFFE2_SPECIALIZED_HIP_ADD_STRIPED_BATCH(float16); #undef CAFFE2_SPECIALIZED_HIP_ADD_STRIPED_BATCH template <> -void Gemv(const CBLAS_TRANSPOSE TransA, const int M, - const int N, const float alpha, const float16 *A, - const float16 *x, const float beta, float16 *y, - HIPContext *context, - TensorProto::DataType math_type) { +void Gemv( + const CBLAS_TRANSPOSE TransA, + const int M, + const int N, + const float alpha, + const float16* A, + const float16* x, + const float beta, + float16* y, + HIPContext* context, + TensorProto::DataType math_type) { CAFFE_THROW("Unsupported math type"); #if ROCBLAS_FP16 // rocblas does not support fp16 yet rocblas_operation cuTransA = (TransA == CblasNoTrans) - ? rocblas_operation_transpose - : rocblas_operation_none; + ? rocblas_operation_transpose + : rocblas_operation_none; // sort out what we need to call cublasSgemmEx / cublasHgemm int m = (cuTransA == rocblas_operation_none) ? 
N : M; @@ -821,18 +1171,43 @@ void Gemv(const CBLAS_TRANSPOSE TransA, const int M, int LDC = m; if (math_type == TensorProto_DataType_FLOAT) { - ROCBLAS_CHECK(cublasSgemmEx(context->rocblas_handle(), cuTransA, - rocblas_operation_none, m, 1, k, &alpha, A, - CUDA_R_16F, LDA, x, CUDA_R_16F, k, &beta, y, - CUDA_R_16F, LDC)); + ROCBLAS_CHECK(cublasSgemmEx( + context->rocblas_handle(), + cuTransA, + rocblas_operation_none, + m, + 1, + k, + &alpha, + A, + CUDA_R_16F, + LDA, + x, + CUDA_R_16F, + k, + &beta, + y, + CUDA_R_16F, + LDC)); } else if (math_type == TensorProto_DataType_FLOAT16) { auto alpha_fp16 = convert::floatToHalf(alpha); auto beta_fp16 = convert::floatToHalf(beta); - ROCBLAS_CHECK(cublasHgemm(context->rocblas_handle(), cuTransA, - rocblas_operation_none, m, 1, k, &alpha_fp16, - (const __half *)A, LDA, (const __half *)x, k, - &beta_fp16, (__half *)y, LDC)); + ROCBLAS_CHECK(cublasHgemm( + context->rocblas_handle(), + cuTransA, + rocblas_operation_none, + m, + 1, + k, + &alpha_fp16, + (const __half*)A, + LDA, + (const __half*)x, + k, + &beta_fp16, + (__half*)y, + LDC)); } else { // fail CAFFE_THROW("Unsupported math type"); @@ -842,18 +1217,26 @@ void Gemv(const CBLAS_TRANSPOSE TransA, const int M, namespace { template -__global__ void SetKernel(const int N, const T alpha, T *Y) { - HIP_1D_KERNEL_LOOP(i, N) { Y[i] = alpha; } +__global__ void SetKernel(const int N, const T alpha, T* Y) { + HIP_1D_KERNEL_LOOP(i, N) { + Y[i] = alpha; + } } } // namespace -#define CAFFE2_SPECIALIZED_HIP_SET(T) \ - template <> \ - void Set(const size_t N, const T alpha, T *Y, \ - HIPContext *context) { \ - hipLaunchKernelGGL((SetKernel), CAFFE_GET_BLOCKS(N), \ - CAFFE_HIP_NUM_THREADS, 0, context->hip_stream(), \ - static_cast(N), alpha, Y); \ +#define CAFFE2_SPECIALIZED_HIP_SET(T) \ + template <> \ + void Set( \ + const size_t N, const T alpha, T* Y, HIPContext* context) { \ + hipLaunchKernelGGL( \ + (SetKernel), \ + CAFFE_GET_BLOCKS(N), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + static_cast(N), \ + alpha, \ + Y); \ } CAFFE2_SPECIALIZED_HIP_SET(float); @@ -871,56 +1254,93 @@ CAFFE2_SPECIALIZED_HIP_SET(uint16_t); namespace { template -__global__ void UniformShift(const size_t N, const float min, const float max, - T *x) { +__global__ void +UniformShift(const size_t N, const float min, const float max, T* x) { float scale = max - min; HIP_1D_KERNEL_LOOP(i, N) { x[i] = convert::To(convert::To(x[i]) * scale + min); } } -__global__ void UniformIntFit(const size_t N, const int min, const int max, - unsigned int *x) { - int *x_int = reinterpret_cast(x); +__global__ void +UniformIntFit(const size_t N, const int min, const int max, unsigned int* x) { + int* x_int = reinterpret_cast(x); int range = (max - min + 1); - HIP_1D_KERNEL_LOOP(i, N) { x_int[i] = min + static_cast(x[i] % range); } + HIP_1D_KERNEL_LOOP(i, N) { + x_int[i] = min + static_cast(x[i] % range); + } } } // namespace template <> -void RandUniform(const size_t n, const float min, - const float max, float *r, - HIPContext *context) { +void RandUniform( + const size_t n, + const float min, + const float max, + float* r, + HIPContext* context) { HIPRAND_ENFORCE(hiprandGenerateUniform(context->hiprand_generator(), r, n)); - hipLaunchKernelGGL((UniformShift), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - min, max, r); + hipLaunchKernelGGL( + (UniformShift), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + min, + max, + r); } template <> -void 
RandUniform(const size_t n, const double min, - const double max, double *r, - HIPContext *context) { +void RandUniform( + const size_t n, + const double min, + const double max, + double* r, + HIPContext* context) { HIPRAND_ENFORCE( hiprandGenerateUniformDouble(context->hiprand_generator(), r, n)); - hipLaunchKernelGGL((UniformShift), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - min, max, r); + hipLaunchKernelGGL( + (UniformShift), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + min, + max, + r); } template <> -void RandUniform(const size_t n, const int min, const int max, - int *r, HIPContext *context) { - HIPRAND_ENFORCE(hiprandGenerate(context->hiprand_generator(), - reinterpret_cast(r), n)); - hipLaunchKernelGGL((UniformIntFit), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - min, max, reinterpret_cast(r)); +void RandUniform( + const size_t n, + const int min, + const int max, + int* r, + HIPContext* context) { + HIPRAND_ENFORCE(hiprandGenerate( + context->hiprand_generator(), reinterpret_cast(r), n)); + hipLaunchKernelGGL( + (UniformIntFit), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + min, + max, + reinterpret_cast(r)); } template -size_t HandleOddLengthRandGaussian(const size_t n, const T mean, const T std, - T *r, HIPContext *context) { +size_t HandleOddLengthRandGaussian( + const size_t n, + const T mean, + const T std, + T* r, + HIPContext* context) { if (n % 2 == 1) { std::default_random_engine generator; std::normal_distribution distribution(mean, std); @@ -932,31 +1352,41 @@ size_t HandleOddLengthRandGaussian(const size_t n, const T mean, const T std, } template <> -void RandGaussian(const size_t n, const float mean, - const float std, float *r, - HIPContext *context) { +void RandGaussian( + const size_t n, + const float mean, + const float std, + float* r, + HIPContext* context) { // If n is odd, we add a random Gaussian value at the end manually // and generate n-1 random values using curandGenerateNormal. // curandGenerateNormal requires n to be even. 
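
For reference, a host-only analogue of the odd-length handling used in the call just below: since hiprandGenerateNormal (like curandGenerateNormal) only accepts an even count, the tail element is filled separately and the bulk generator is asked for n - 1 values. std::normal_distribution stands in for the device generator here, and the function name is illustrative, not the Caffe2 code path.

#include <cstddef>
#include <cstdio>
#include <random>
#include <vector>

// Sketch: fill the odd tail element on the host, return an even count for
// the bulk (device) generator.
std::size_t handle_odd_length(std::size_t n, float mean, float std_dev, float* r) {
  if (n % 2 == 1) {
    std::default_random_engine generator;
    std::normal_distribution<float> distribution(mean, std_dev);
    r[n - 1] = distribution(generator);
    return n - 1;  // even count for the bulk generator
  }
  return n;
}

int main() {
  std::vector<float> r(7);
  const std::size_t even_n = handle_odd_length(r.size(), 0.0f, 1.0f, r.data());
  // A real build would now generate even_n values in one bulk call.
  std::printf("generate %zu values in bulk, last element filled on host\n", even_n);
  return 0;
}
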
const size_t even_n = HandleOddLengthRandGaussian(n, mean, std, r, context); - HIPRAND_ENFORCE(hiprandGenerateNormal(context->hiprand_generator(), r, even_n, - mean, std)); + HIPRAND_ENFORCE(hiprandGenerateNormal( + context->hiprand_generator(), r, even_n, mean, std)); } template <> -void RandGaussian(const size_t n, const double mean, - const double std, double *r, - HIPContext *context) { +void RandGaussian( + const size_t n, + const double mean, + const double std, + double* r, + HIPContext* context) { const size_t even_n = HandleOddLengthRandGaussian(n, mean, std, r, context); - HIPRAND_ENFORCE(hiprandGenerateNormalDouble(context->hiprand_generator(), r, - even_n, mean, std)); + HIPRAND_ENFORCE(hiprandGenerateNormalDouble( + context->hiprand_generator(), r, even_n, mean, std)); } template <> -void Dot(const int n, const float *a, const float *b, - float *y, HIPContext *context) { +void Dot( + const int n, + const float* a, + const float* b, + float* y, + HIPContext* context) { float result; ROCBLAS_ENFORCE( rocblas_sdot(context->rocblas_handle(), n, a, 1, b, 1, &result)); @@ -964,14 +1394,28 @@ void Dot(const int n, const float *a, const float *b, } template <> -void Dot(const int n, const float16 *a, const float16 *b, - float16 *y, HIPContext *context) { +void Dot( + const int n, + const float16* a, + const float16* b, + float16* y, + HIPContext* context) { CAFFE_THROW("Unsupported math type"); #if ROCBLAS_FP16 // rocblas does not support fp16 yet float16 result; // execute with 32-bit math - ROCBLAS_CHECK(cublasDotEx(context->rocblas_handle(), n, a, CUDA_R_16F, 1, b, - CUDA_R_16F, 1, &result, CUDA_R_16F, CUDA_R_32F)); + ROCBLAS_CHECK(cublasDotEx( + context->rocblas_handle(), + n, + a, + CUDA_R_16F, + 1, + b, + CUDA_R_16F, + 1, + &result, + CUDA_R_16F, + CUDA_R_32F)); context->Copy(1, &result, y); #endif } @@ -982,7 +1426,7 @@ void Dot(const int n, const float16 *a, const float16 *b, // reduction here. 
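
The reduction the comment refers to is staged: each of the 128 threads in SumKernel (defined just below) accumulates a strided partial sum into shared memory, the 128 partials are folded to 32, and thread 0 adds the remaining 32. A host-only sketch of the same staging, assuming the 128-thread block size:

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Host-side sketch of the two-level reduction pattern used by SumKernel.
float staged_sum(const std::vector<float>& X) {
  const int kThreads = 128;
  std::vector<float> buf(kThreads, 0.0f);
  for (int idx = 0; idx < kThreads; ++idx)
    for (std::size_t i = idx; i < X.size(); i += kThreads)
      buf[idx] += X[i];                       // per-thread strided partial sum
  for (int idx = 0; idx < 32; ++idx)          // 128 -> 32
    buf[idx] += buf[idx + 32] + buf[idx + 64] + buf[idx + 96];
  float sum = 0.0f;                           // 32 -> 1 (done by thread 0)
  for (int idx = 0; idx < 32; ++idx)
    sum += buf[idx];
  return sum;
}

int main() {
  std::vector<float> X(1000, 1.0f);
  assert(staged_sum(X) == std::accumulate(X.begin(), X.end(), 0.0f));
  return 0;
}
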
#define SUM_KERNEL_NTHREADS 128 template -__global__ void SumKernel(const int N, const T *X, T *Y, bool square) { +__global__ void SumKernel(const int N, const T* X, T* Y, bool square) { const int idx = threadIdx.x; __shared__ float reduction_buffer[SUM_KERNEL_NTHREADS]; @@ -1004,8 +1448,7 @@ __global__ void SumKernel(const int N, const T *X, T *Y, bool square) { // 128 -> 32 if (idx < 32) { reduction_buffer[idx] += reduction_buffer[idx + 32] + - reduction_buffer[idx + 64] + - reduction_buffer[idx + 96]; + reduction_buffer[idx + 64] + reduction_buffer[idx + 96]; } __syncthreads(); // 32 -> 1 @@ -1025,16 +1468,21 @@ __global__ void SumKernel(const int N, const T *X, T *Y, bool square) { namespace { -template __global__ void SumConvertKernel(float *sum, T *dest) { +template +__global__ void SumConvertKernel(float* sum, T* dest) { *dest = convert::To(*sum); } template -void SumGenericIter(const int N, IterT it, T *&dest, HIPContext *context, - Tensor *scratch_ptr) { +void SumGenericIter( + const int N, + IterT it, + T*& dest, + HIPContext* context, + Tensor* scratch_ptr) { size_t memRequired = 0; - cub::DeviceReduce::Sum(nullptr, memRequired, it, dest, N, - context->hip_stream()); + cub::DeviceReduce::Sum( + nullptr, memRequired, it, dest, N, context->hip_stream()); auto buffer_size = static_cast((memRequired + sizeof(T) - 1) / sizeof(T)); if (!dest) { @@ -1045,106 +1493,184 @@ void SumGenericIter(const int N, IterT it, T *&dest, HIPContext *context, scratch_ptr->Resize(std::vector{buffer_size}); } cub::DeviceReduce::Sum( - static_cast(scratch_ptr->template mutable_data()), memRequired, - it, dest, N, context->hip_stream()); + static_cast(scratch_ptr->template mutable_data()), + memRequired, + it, + dest, + N, + context->hip_stream()); } } // namespace template <> -void Sum(const int N, const float *x, float *y, - HIPContext *context, - Tensor *scratch_ptr) { +void Sum( + const int N, + const float* x, + float* y, + HIPContext* context, + Tensor* scratch_ptr) { if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { SumGenericIter(N, x, y, context, scratch_ptr); } else { - hipLaunchKernelGGL((SumKernel), dim3(1), dim3(SUM_KERNEL_NTHREADS), 0, - context->hip_stream(), N, x, y, false); + hipLaunchKernelGGL( + (SumKernel), + dim3(1), + dim3(SUM_KERNEL_NTHREADS), + 0, + context->hip_stream(), + N, + x, + y, + false); } } template <> -void Sum(const int N, const int32_t *x, int32_t *y, - HIPContext *context, - Tensor *scratch_ptr) { +void Sum( + const int N, + const int32_t* x, + int32_t* y, + HIPContext* context, + Tensor* scratch_ptr) { if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { SumGenericIter(N, x, y, context, scratch_ptr); } else { - hipLaunchKernelGGL((SumKernel), dim3(1), dim3(SUM_KERNEL_NTHREADS), 0, - context->hip_stream(), N, x, y, false); + hipLaunchKernelGGL( + (SumKernel), + dim3(1), + dim3(SUM_KERNEL_NTHREADS), + 0, + context->hip_stream(), + N, + x, + y, + false); } } namespace { -template struct FloatTransform { +template +struct FloatTransform { inline __host__ __device__ float operator()(const T v) const { return convert::To(v); } }; } // namespace -#define CAFFE2_MATH_SUM_FUNC(T) \ - template <> \ - void Sum(const int N, const T *x, T *y, HIPContext *context, \ - Tensor *scratch_ptr) { \ - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ - FloatTransform transform; \ - cub::TransformInputIterator, const T *> it( \ - x, transform); \ - float *sum = nullptr; \ - SumGenericIter(N, it, sum, context, scratch_ptr); \ - hipLaunchKernelGGL((SumConvertKernel), dim3(1), 
dim3(1), 0, \ - context->hip_stream(), sum, y); \ - } else { \ - hipLaunchKernelGGL((SumKernel), dim3(1), dim3(SUM_KERNEL_NTHREADS), 0, \ - context->hip_stream(), N, x, y, false); \ - } \ +#define CAFFE2_MATH_SUM_FUNC(T) \ + template <> \ + void Sum( \ + const int N, \ + const T* x, \ + T* y, \ + HIPContext* context, \ + Tensor* scratch_ptr) { \ + if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ + FloatTransform transform; \ + cub::TransformInputIterator, const T*> it( \ + x, transform); \ + float* sum = nullptr; \ + SumGenericIter(N, it, sum, context, scratch_ptr); \ + hipLaunchKernelGGL( \ + (SumConvertKernel), \ + dim3(1), \ + dim3(1), \ + 0, \ + context->hip_stream(), \ + sum, \ + y); \ + } else { \ + hipLaunchKernelGGL( \ + (SumKernel), \ + dim3(1), \ + dim3(SUM_KERNEL_NTHREADS), \ + 0, \ + context->hip_stream(), \ + N, \ + x, \ + y, \ + false); \ + } \ } CAFFE2_MATH_SUM_FUNC(float16) #undef CAFFE2_MATH_SUM_FUNC namespace { -template struct SqrTransform { - inline __host__ __device__ T operator()(const T v) const { return v * v; } +template +struct SqrTransform { + inline __host__ __device__ T operator()(const T v) const { + return v * v; + } }; } // namespace template <> -void SumSqr(const int N, const float *x, float *y, - HIPContext *context, - Tensor *scratch_ptr) { +void SumSqr( + const int N, + const float* x, + float* y, + HIPContext* context, + Tensor* scratch_ptr) { if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { SqrTransform transform; - cub::TransformInputIterator, const float *> it( + cub::TransformInputIterator, const float*> it( x, transform); SumGenericIter(N, it, y, context, scratch_ptr); } else { - hipLaunchKernelGGL((SumKernel), dim3(1), dim3(SUM_KERNEL_NTHREADS), 0, - context->hip_stream(), N, x, y, true); - } -} - -#define CAFFE2_MATH_SUMSQR_FUNC(T) \ - template <> \ - void SumSqr(const int N, const T *x, T *y, \ - HIPContext *context, \ - Tensor *scratch_ptr) { \ - if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ - FloatTransform float_transform; \ - cub::TransformInputIterator, const T *> \ - float_it(x, float_transform); \ - SqrTransform sqr_transform; \ - cub::TransformInputIterator, \ - decltype(float_it)> \ - it(float_it, sqr_transform); \ - float *sum = nullptr; \ - SumGenericIter(N, it, sum, context, scratch_ptr); \ - hipLaunchKernelGGL((SumConvertKernel), dim3(1), dim3(1), 0, \ - context->hip_stream(), sum, y); \ - } else { \ - hipLaunchKernelGGL((SumKernel), dim3(1), dim3(SUM_KERNEL_NTHREADS), 0, \ - context->hip_stream(), N, x, y, true); \ - } \ + hipLaunchKernelGGL( + (SumKernel), + dim3(1), + dim3(SUM_KERNEL_NTHREADS), + 0, + context->hip_stream(), + N, + x, + y, + true); + } +} + +#define CAFFE2_MATH_SUMSQR_FUNC(T) \ + template <> \ + void SumSqr( \ + const int N, \ + const T* x, \ + T* y, \ + HIPContext* context, \ + Tensor* scratch_ptr) { \ + if (scratch_ptr && N > DEVICE_REDUCE_SIZE_THRESHOLD) { \ + FloatTransform float_transform; \ + cub::TransformInputIterator, const T*> \ + float_it(x, float_transform); \ + SqrTransform sqr_transform; \ + cub::TransformInputIterator< \ + float, \ + SqrTransform, \ + decltype(float_it)> \ + it(float_it, sqr_transform); \ + float* sum = nullptr; \ + SumGenericIter(N, it, sum, context, scratch_ptr); \ + hipLaunchKernelGGL( \ + (SumConvertKernel), \ + dim3(1), \ + dim3(1), \ + 0, \ + context->hip_stream(), \ + sum, \ + y); \ + } else { \ + hipLaunchKernelGGL( \ + (SumKernel), \ + dim3(1), \ + dim3(SUM_KERNEL_NTHREADS), \ + 0, \ + context->hip_stream(), \ + N, \ + x, \ + y, \ + true); \ + } \ } 
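
The float16 Sum/SumSqr macros above all follow the same shape: promote each element to float through a transform iterator, reduce in float, and convert only the final scalar back. A host-only sketch of that shape, with std::transform_reduce in place of cub::TransformInputIterator plus cub::DeviceReduce::Sum, and uint8_t standing in for float16 purely so the sketch compiles without a half type:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<std::uint8_t> x = {1, 2, 3, 4, 5};
  // Transform each element to float, accumulate in float
  // (the FloatTransform + DeviceReduce::Sum step above).
  const float sum = std::transform_reduce(
      x.begin(), x.end(), 0.0f, std::plus<float>(),
      [](std::uint8_t v) { return static_cast<float>(v); });
  // Convert only the final scalar back (the SumConvertKernel step above).
  const std::uint8_t result = static_cast<std::uint8_t>(sum);
  std::printf("sum = %u\n", static_cast<unsigned>(result));
  return 0;
}
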
CAFFE2_MATH_SUMSQR_FUNC(float16) @@ -1153,32 +1679,59 @@ CAFFE2_MATH_SUMSQR_FUNC(float16) namespace { template -__global__ void SelectKernel(const int N, const int D, const T *x, - const int *idx, T *y) { - HIP_1D_KERNEL_LOOP(i, N) { y[i] = x[i * D + idx[i]]; } +__global__ void +SelectKernel(const int N, const int D, const T* x, const int* idx, T* y) { + HIP_1D_KERNEL_LOOP(i, N) { + y[i] = x[i * D + idx[i]]; + } } } // namespace template <> -void Select(const int N, const int D, const float *x, - const int *idx, float *y, HIPContext *context) { - hipLaunchKernelGGL((SelectKernel), dim3(CAFFE_GET_BLOCKS(N)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), N, - D, x, idx, y); +void Select( + const int N, + const int D, + const float* x, + const int* idx, + float* y, + HIPContext* context) { + hipLaunchKernelGGL( + (SelectKernel), + dim3(CAFFE_GET_BLOCKS(N)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + N, + D, + x, + idx, + y); } template <> -void Select(const int N, const int D, const float16 *x, - const int *idx, float16 *y, - HIPContext *context) { - hipLaunchKernelGGL((SelectKernel), dim3(CAFFE_GET_BLOCKS(N)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), N, - D, x, idx, y); +void Select( + const int N, + const int D, + const float16* x, + const int* idx, + float16* y, + HIPContext* context) { + hipLaunchKernelGGL( + (SelectKernel), + dim3(CAFFE_GET_BLOCKS(N)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + N, + D, + x, + idx, + y); } namespace { template -__global__ void ScaleKernel(const int n, const float alpha, const T *x, T *y) { +__global__ void ScaleKernel(const int n, const float alpha, const T* x, T* y) { HIP_1D_KERNEL_LOOP(i, n) { // y[i] = convert::To(convert::To(x[i]) * alpha); y[i] = convert::Get(convert::Get(x[i]) * alpha); @@ -1186,146 +1739,275 @@ __global__ void ScaleKernel(const int n, const float alpha, const T *x, T *y) { } template -__global__ void ScaleKernelDeviceAlpha(const int n, const float *alpha, - const T *x, T *y) { - HIP_1D_KERNEL_LOOP(i, n) { y[i] = x[i] * (*alpha); } +__global__ void +ScaleKernelDeviceAlpha(const int n, const float* alpha, const T* x, T* y) { + HIP_1D_KERNEL_LOOP(i, n) { + y[i] = x[i] * (*alpha); + } } template -__global__ void PowKernel(const int n, const T *x, const T exponent, T *y) { - HIP_1D_KERNEL_LOOP(i, n) { y[i] = powf(x[i], exponent); } +__global__ void PowKernel(const int n, const T* x, const T exponent, T* y) { + HIP_1D_KERNEL_LOOP(i, n) { + y[i] = powf(x[i], exponent); + } } // fp16 specialization template <> -__global__ void ScaleKernelDeviceAlpha(const int n, const float *alpha, - const float16 *x, float16 *y) { +__global__ void ScaleKernelDeviceAlpha( + const int n, + const float* alpha, + const float16* x, + float16* y) { HIP_1D_KERNEL_LOOP(i, n) { - y[i] = convert::To(convert::To(x[i]) * - (*alpha)); + y[i] = convert::To( + convert::To(x[i]) * (*alpha)); } } } // namespace template <> -void Powx(const int N, const float *a, const float b, - float *y, HIPContext *context) { - hipLaunchKernelGGL((PowKernel), dim3(CAFFE_GET_BLOCKS(N)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), N, - a, b, y); +void Powx( + const int N, + const float* a, + const float b, + float* y, + HIPContext* context) { + hipLaunchKernelGGL( + (PowKernel), + dim3(CAFFE_GET_BLOCKS(N)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + N, + a, + b, + y); } template <> -void Scale(const int n, const float alpha, const float *x, - float *y, HIPContext *context) { - 
hipLaunchKernelGGL((ScaleKernel), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - alpha, x, y); +void Scale( + const int n, + const float alpha, + const float* x, + float* y, + HIPContext* context) { + hipLaunchKernelGGL( + (ScaleKernel), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + alpha, + x, + y); } template <> -void Scale(const int n, const float alpha, - const float16 *x, float16 *y, - HIPContext *context) { - hipLaunchKernelGGL((ScaleKernel), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - alpha, x, y); +void Scale( + const int n, + const float alpha, + const float16* x, + float16* y, + HIPContext* context) { + hipLaunchKernelGGL( + (ScaleKernel), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + alpha, + x, + y); } template <> -void Scale(const int n, const float *alpha, const float *x, - float *y, HIPContext *context) { - hipLaunchKernelGGL((ScaleKernelDeviceAlpha), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - alpha, x, y); +void Scale( + const int n, + const float* alpha, + const float* x, + float* y, + HIPContext* context) { + hipLaunchKernelGGL( + (ScaleKernelDeviceAlpha), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + alpha, + x, + y); } template <> -void Scale(const int n, const float *alpha, - const float16 *x, float16 *y, - HIPContext *context) { - hipLaunchKernelGGL((ScaleKernelDeviceAlpha), - dim3(CAFFE_GET_BLOCKS(n)), dim3(CAFFE_HIP_NUM_THREADS), 0, - context->hip_stream(), n, alpha, x, y); +void Scale( + const int n, + const float* alpha, + const float16* x, + float16* y, + HIPContext* context) { + hipLaunchKernelGGL( + (ScaleKernelDeviceAlpha), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + alpha, + x, + y); } template <> -void Axpy(const int N, const float alpha, const float *X, - float *Y, HIPContext *context) { +void Axpy( + const int N, + const float alpha, + const float* X, + float* Y, + HIPContext* context) { ROCBLAS_ENFORCE( rocblas_saxpy(context->rocblas_handle(), N, &alpha, X, 1, Y, 1)); } template <> -void Axpy(const int N, const float alpha, const double *X, - double *Y, HIPContext *context) { +void Axpy( + const int N, + const float alpha, + const double* X, + double* Y, + HIPContext* context) { double alpha_d{alpha}; ROCBLAS_ENFORCE( rocblas_daxpy(context->rocblas_handle(), N, &alpha_d, X, 1, Y, 1)); } template <> -void Axpy(const int N, const float alpha, const float16 *X, - float16 *Y, HIPContext *context) { +void Axpy( + const int N, + const float alpha, + const float16* X, + float16* Y, + HIPContext* context) { CAFFE_THROW("Unsupported math type"); #if ROCBLAS_FP16 - ROCBLAS_CHECK(cublasAxpyEx(context->rocblas_handle(), N, &alpha, CUDA_R_16F, - X, CUDA_R_16F, 1, Y, CUDA_R_16F, 1, CUDA_R_32F)); + ROCBLAS_CHECK(cublasAxpyEx( + context->rocblas_handle(), + N, + &alpha, + CUDA_R_16F, + X, + CUDA_R_16F, + 1, + Y, + CUDA_R_16F, + 1, + CUDA_R_32F)); #endif } namespace { template -__global__ void AxpyKernel(const int n, const float *a, const T *x, T *y) { +__global__ void AxpyKernel(const int n, const float* a, const T* x, T* y) { HIP_1D_KERNEL_LOOP(index, n) { - y[index] = convert::Get(convert::Get(x[index]) * (*a) + - convert::Get(y[index])); + y[index] = convert::Get( + convert::Get(x[index]) * (*a) + convert::Get(y[index])); } } } // 
namespace template <> -void Axpy(const int n, const float *alpha, const float *X, - float *Y, HIPContext *context) { - hipLaunchKernelGGL((AxpyKernel), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - alpha, X, Y); +void Axpy( + const int n, + const float* alpha, + const float* X, + float* Y, + HIPContext* context) { + hipLaunchKernelGGL( + (AxpyKernel), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + alpha, + X, + Y); } template <> -void Axpy(const int n, const float *alpha, - const float16 *X, float16 *Y, - HIPContext *context) { - hipLaunchKernelGGL((AxpyKernel), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - alpha, X, Y); +void Axpy( + const int n, + const float* alpha, + const float16* X, + float16* Y, + HIPContext* context) { + hipLaunchKernelGGL( + (AxpyKernel), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + alpha, + X, + Y); } namespace { template -__global__ void AxpbyKernel(const int n, const T a, const T *x, const T b, - T *y) { - HIP_1D_KERNEL_LOOP(index, n) { y[index] = x[index] * a + y[index] * b; } +__global__ void +AxpbyKernel(const int n, const T a, const T* x, const T b, T* y) { + HIP_1D_KERNEL_LOOP(index, n) { + y[index] = x[index] * a + y[index] * b; + } } } // namespace template <> -void Axpby(const int n, const float a, const float *x, - const float b, float *y, HIPContext *context) { - hipLaunchKernelGGL((AxpbyKernel), dim3(CAFFE_GET_BLOCKS(n)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), n, - a, x, b, y); +void Axpby( + const int n, + const float a, + const float* x, + const float b, + float* y, + HIPContext* context) { + hipLaunchKernelGGL( + (AxpbyKernel), + dim3(CAFFE_GET_BLOCKS(n)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + n, + a, + x, + b, + y); } namespace { template __global__ void Im2ColNCHWHIPKernel( - const int n, const int input_h, const int input_w, const int kernel_h, - const int kernel_w, const int dilation_h, const int dilation_w, - const int pad_t, const int pad_l, const int stride_h, const int stride_w, - const int output_h, const int output_w, const T *img_data, T *col_data) { + const int n, + const int input_h, + const int input_w, + const int kernel_h, + const int kernel_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int stride_h, + const int stride_w, + const int output_h, + const int output_w, + const T* img_data, + T* col_data) { HIP_1D_KERNEL_LOOP(index, n) { const int w_out = index % output_w; const int h_index = index / output_w; @@ -1335,9 +2017,9 @@ __global__ void Im2ColNCHWHIPKernel( const int h_in = h_out * stride_h - pad_t; const int w_in = w_out * stride_w - pad_l; const int output_size = output_h * output_w; - T *col_data_ptr = + T* col_data_ptr = col_data + (channel_out * output_h + h_out) * output_w + w_out; - const T *img_data_ptr = + const T* img_data_ptr = img_data + (channel_in * input_h + h_in) * input_w + w_in; int dh = 0; for (int i = 0; i < kernel_h; ++i) { @@ -1346,8 +2028,8 @@ __global__ void Im2ColNCHWHIPKernel( const int h = h_in + dh; const int w = w_in + dw; *col_data_ptr = (h >= 0 && w >= 0 && h < input_h && w < input_w) - ? __ldg(img_data_ptr + dh * input_w + dw) - : 0; + ? 
__ldg(img_data_ptr + dh * input_w + dw) + : 0; col_data_ptr += output_size; dw += dilation_w; } @@ -1358,18 +2040,29 @@ __global__ void Im2ColNCHWHIPKernel( template __global__ void Im2ColNHWCHIPKernel( - const int n, const int input_h, const int input_w, const int kernel_h, - const int kernel_w, const int dilation_h, const int dilation_w, - const int pad_t, const int pad_l, const int stride_h, const int stride_w, - const int output_w, const int channels, const T *img_data, T *col_data) { + const int n, + const int input_h, + const int input_w, + const int kernel_h, + const int kernel_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int stride_h, + const int stride_w, + const int output_w, + const int channels, + const T* img_data, + T* col_data) { HIP_1D_KERNEL_LOOP(index, n) { const int channel_in = index % channels; const int w_out = index / channels % output_w; const int h_out = index / channels / output_w; const int h_in = h_out * stride_h - pad_t; const int w_in = w_out * stride_w - pad_l; - T *col_data_ptr = - col_data + (h_out * output_w + w_out) * channels * kernel_h * kernel_w + + T* col_data_ptr = col_data + + (h_out * output_w + w_out) * channels * kernel_h * kernel_w + channel_in; int dh = 0; for (int i = 0; i < kernel_h; ++i) { @@ -1377,10 +2070,9 @@ __global__ void Im2ColNHWCHIPKernel( for (int j = 0; j < kernel_w; ++j) { const int h = h_in + dh; const int w = w_in + dw; - *col_data_ptr = - (h >= 0 && w >= 0 && h < input_h && w < input_w) - ? __ldg(img_data + (h * input_w + w) * channels + channel_in) - : 0; + *col_data_ptr = (h >= 0 && w >= 0 && h < input_h && w < input_w) + ? __ldg(img_data + (h * input_w + w) * channels + channel_in) + : 0; col_data_ptr += channels; dw += dilation_w; } @@ -1390,12 +2082,22 @@ __global__ void Im2ColNHWCHIPKernel( } template -__global__ void -Col2ImNCHWHIPKernel(const int n, const int input_h, const int input_w, - const int patch_h, const int patch_w, const int dilation_h, - const int dilation_w, const int pad_t, const int pad_l, - const int stride_h, const int stride_w, const int output_h, - const int output_w, const T *col_data, T *img_data) { +__global__ void Col2ImNCHWHIPKernel( + const int n, + const int input_h, + const int input_w, + const int patch_h, + const int patch_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int stride_h, + const int stride_w, + const int output_h, + const int output_w, + const T* col_data, + T* img_data) { const int dpatch_h = dilation_h * (patch_h - 1) + 1; const int dpatch_w = dilation_w * (patch_w - 1) + 1; @@ -1431,12 +2133,22 @@ Col2ImNCHWHIPKernel(const int n, const int input_h, const int input_w, } template -__global__ void -Col2ImNHWCHIPKernel(const int n, const int input_w, const int channels, - const int patch_h, const int patch_w, const int dilation_h, - const int dilation_w, const int pad_t, const int pad_l, - const int stride_h, const int stride_w, const int output_h, - const int output_w, const T *col_data, T *img_data) { +__global__ void Col2ImNHWCHIPKernel( + const int n, + const int input_w, + const int channels, + const int patch_h, + const int patch_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int stride_h, + const int stride_w, + const int output_h, + const int output_w, + const T* col_data, + T* img_data) { const int dpatch_h = dilation_h * (patch_h - 1) + 1; const int dpatch_w = dilation_w * (patch_w - 1) + 1; @@ -1460,8 +2172,8 @@ 
Col2ImNHWCHIPKernel(const int n, const int input_w, const int channels, h_k /= dilation_h; w_k /= dilation_w; const int c_col = (h_k * patch_w + w_k) * channels + c; - val += __ldg(col_data + (h_col * output_w + w_col) * channels_col + - c_col); + val += __ldg( + col_data + (h_col * output_w + w_col) * channels_col + c_col); } } } @@ -1470,13 +2182,18 @@ Col2ImNHWCHIPKernel(const int n, const int input_w, const int channels, } template -__global__ void -Im2ColNdNCHWHIPKernel(const int outer_size, const int inner_size, - const int kernel_size, SimpleArray img_shape, - SimpleArray col_shape, - SimpleArray kernel_shape, - SimpleArray stride, SimpleArray dilation, - SimpleArray pad, const T *X_data, T *Y_data) { +__global__ void Im2ColNdNCHWHIPKernel( + const int outer_size, + const int inner_size, + const int kernel_size, + SimpleArray img_shape, + SimpleArray col_shape, + SimpleArray kernel_shape, + SimpleArray stride, + SimpleArray dilation, + SimpleArray pad, + const T* X_data, + T* Y_data) { int d_offset[N]; int d_iter[N]; for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { @@ -1499,7 +2216,7 @@ Im2ColNdNCHWHIPKernel(const int outer_size, const int inner_size, #pragma unroll for (int d_i = 0; d_i < N; ++d_i) { const int d_img = d_iter[d_i] * stride.data[d_i] - pad.data[d_i] + - d_offset[d_i] * dilation.data[d_i]; + d_offset[d_i] * dilation.data[d_i]; is_padding |= d_img < 0 || d_img >= img_shape.data[d_i + 1]; img_index = img_index * img_shape.data[d_i + 1] + d_img; } @@ -1513,16 +2230,22 @@ Im2ColNdNCHWHIPKernel(const int outer_size, const int inner_size, } template -void Im2ColNdNCHWHIPImpl(const int img_size, const int col_size, - const int *img_shape, const int *col_shape, - const int *kernel_shape, const int *stride, - const int *dilation, const int *pad, - const float *img_data, float *col_data, - HIPContext *context) { +void Im2ColNdNCHWHIPImpl( + const int img_size, + const int col_size, + const int* img_shape, + const int* col_shape, + const int* kernel_shape, + const int* stride, + const int* dilation, + const int* pad, + const float* img_data, + float* col_data, + HIPContext* context) { const int outer_size = col_shape[0]; const int inner_size = col_size / outer_size; - const int kernel_size = std::accumulate(kernel_shape, kernel_shape + N, 1, - std::multiplies()); + const int kernel_size = std::accumulate( + kernel_shape, kernel_shape + N, 1, std::multiplies()); SimpleArray img_shape_array; SimpleArray col_shape_array; SimpleArray kernel_shape_array; @@ -1535,25 +2258,42 @@ void Im2ColNdNCHWHIPImpl(const int img_size, const int col_size, std::memcpy(stride_array.data, stride, N * sizeof(int)); std::memcpy(dilation_array.data, dilation, N * sizeof(int)); std::memcpy(pad_array.data, pad, N * sizeof(int)); - hipLaunchKernelGGL((Im2ColNdNCHWHIPKernel), - dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - outer_size, inner_size, kernel_size, img_shape_array, - col_shape_array, kernel_shape_array, stride_array, - dilation_array, pad_array, img_data, col_data); + hipLaunchKernelGGL( + (Im2ColNdNCHWHIPKernel), + dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + outer_size, + inner_size, + kernel_size, + img_shape_array, + col_shape_array, + kernel_shape_array, + stride_array, + dilation_array, + pad_array, + img_data, + col_data); } template -void Col2ImNdNCHWHIPImpl(const int img_size, const int col_size, - const int *img_shape, const int *col_shape, 
- const int *kernel_shape, const int *stride, - const int *dilation, const int *pad, - const float *col_data, float *img_data, - HIPContext *context) { +void Col2ImNdNCHWHIPImpl( + const int img_size, + const int col_size, + const int* img_shape, + const int* col_shape, + const int* kernel_shape, + const int* stride, + const int* dilation, + const int* pad, + const float* col_data, + float* img_data, + HIPContext* context) { const int outer_size = col_shape[0]; const int inner_size = col_size / outer_size; - const int kernel_size = std::accumulate(kernel_shape, kernel_shape + N, 1, - std::multiplies()); + const int kernel_size = std::accumulate( + kernel_shape, kernel_shape + N, 1, std::multiplies()); SimpleArray img_shape_array; SimpleArray col_shape_array; SimpleArray kernel_shape_array; @@ -1567,130 +2307,309 @@ void Col2ImNdNCHWHIPImpl(const int img_size, const int col_size, std::memcpy(dilation_array.data, dilation, N * sizeof(int)); std::memcpy(pad_array.data, pad, N * sizeof(int)); Set(img_size, 0, img_data, context); - hipLaunchKernelGGL((Im2ColNdNCHWHIPKernel), - dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - outer_size, inner_size, kernel_size, img_shape_array, - col_shape_array, kernel_shape_array, stride_array, - dilation_array, pad_array, col_data, img_data); + hipLaunchKernelGGL( + (Im2ColNdNCHWHIPKernel), + dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + outer_size, + inner_size, + kernel_size, + img_shape_array, + col_shape_array, + kernel_shape_array, + stride_array, + dilation_array, + pad_array, + col_data, + img_data); } } // namespace template <> void Im2Col( - const int channels, const int height, const int width, const int kernel_h, - const int kernel_w, const int dilation_h, const int dilation_w, - const int pad_t, const int pad_l, const int pad_b, const int pad_r, - const int stride_h, const int stride_w, const float *img_data, - float *col_data, HIPContext *context) { + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int pad_b, + const int pad_r, + const int stride_h, + const int stride_w, + const float* img_data, + float* col_data, + HIPContext* context) { const int dkernel_h = dilation_h * (kernel_h - 1) + 1; const int dkernel_w = dilation_w * (kernel_w - 1) + 1; const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; const int num_kernels = channels * output_h * output_w; hipLaunchKernelGGL( - (Im2ColNCHWHIPKernel), dim3(CAFFE_GET_BLOCKS(num_kernels)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), num_kernels, - height, width, kernel_h, kernel_w, dilation_h, dilation_w, pad_t, pad_l, - stride_h, stride_w, output_h, output_w, img_data, col_data); + (Im2ColNCHWHIPKernel), + dim3(CAFFE_GET_BLOCKS(num_kernels)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + num_kernels, + height, + width, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + pad_t, + pad_l, + stride_h, + stride_w, + output_h, + output_w, + img_data, + col_data); } template <> void Im2Col( - const int channels, const int height, const int width, const int kernel_h, - const int kernel_w, const int dilation_h, const int dilation_w, - const int pad_t, const int pad_l, const int pad_b, const int pad_r, - 
const int stride_h, const int stride_w, const float *img_data, - float *col_data, HIPContext *context) { + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int pad_b, + const int pad_r, + const int stride_h, + const int stride_w, + const float* img_data, + float* col_data, + HIPContext* context) { const int dkernel_h = dilation_h * (kernel_h - 1) + 1; const int dkernel_w = dilation_w * (kernel_w - 1) + 1; const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; const int num_kernels = output_h * output_w * channels; hipLaunchKernelGGL( - (Im2ColNHWCHIPKernel), dim3(CAFFE_GET_BLOCKS(num_kernels)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), num_kernels, - height, width, kernel_h, kernel_w, dilation_h, dilation_w, pad_t, pad_l, - stride_h, stride_w, output_w, channels, img_data, col_data); + (Im2ColNHWCHIPKernel), + dim3(CAFFE_GET_BLOCKS(num_kernels)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + num_kernels, + height, + width, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + pad_t, + pad_l, + stride_h, + stride_w, + output_w, + channels, + img_data, + col_data); } template <> void Col2Im( - const int channels, const int height, const int width, const int kernel_h, - const int kernel_w, const int dilation_h, const int dilation_w, - const int pad_t, const int pad_l, const int pad_b, const int pad_r, - const int stride_h, const int stride_w, const float *col_data, - float *img_data, HIPContext *context) { + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int pad_b, + const int pad_r, + const int stride_h, + const int stride_w, + const float* col_data, + float* img_data, + HIPContext* context) { const int dkernel_h = dilation_h * (kernel_h - 1) + 1; const int dkernel_w = dilation_w * (kernel_w - 1) + 1; const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; const int num_kernels = channels * height * width; hipLaunchKernelGGL( - (Col2ImNCHWHIPKernel), dim3(CAFFE_GET_BLOCKS(num_kernels)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), num_kernels, - height, width, kernel_h, kernel_w, dilation_h, dilation_w, pad_t, pad_l, - stride_h, stride_w, output_h, output_w, col_data, img_data); + (Col2ImNCHWHIPKernel), + dim3(CAFFE_GET_BLOCKS(num_kernels)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + num_kernels, + height, + width, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + pad_t, + pad_l, + stride_h, + stride_w, + output_h, + output_w, + col_data, + img_data); } template <> void Col2Im( - const int channels, const int height, const int width, const int kernel_h, - const int kernel_w, const int dilation_h, const int dilation_w, - const int pad_t, const int pad_l, const int pad_b, const int pad_r, - const int stride_h, const int stride_w, const float *col_data, - float *img_data, HIPContext *context) { + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int dilation_h, + const int dilation_w, + const int pad_t, + const int pad_l, + const int pad_b, + const int pad_r, + const int stride_h, + const int 
stride_w, + const float* col_data, + float* img_data, + HIPContext* context) { const int dkernel_h = dilation_h * (kernel_h - 1) + 1; const int dkernel_w = dilation_w * (kernel_w - 1) + 1; const int output_h = (height + pad_t + pad_b - dkernel_h) / stride_h + 1; const int output_w = (width + pad_l + pad_r - dkernel_w) / stride_w + 1; const int num_kernels = height * width * channels; hipLaunchKernelGGL( - (Col2ImNHWCHIPKernel), dim3(CAFFE_GET_BLOCKS(num_kernels)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), num_kernels, width, - channels, kernel_h, kernel_w, dilation_h, dilation_w, pad_t, pad_l, - stride_h, stride_w, output_h, output_w, col_data, img_data); + (Col2ImNHWCHIPKernel), + dim3(CAFFE_GET_BLOCKS(num_kernels)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + num_kernels, + width, + channels, + kernel_h, + kernel_w, + dilation_h, + dilation_w, + pad_t, + pad_l, + stride_h, + stride_w, + output_h, + output_w, + col_data, + img_data); } template <> void Im2ColNd( - const int N, const int img_size, const int col_size, const int *img_shape, - const int *col_shape, const int *kernel_shape, const int *stride, - const int *dilation, const int *pad, const float *img_data, float *col_data, - HIPContext *context) { + const int N, + const int img_size, + const int col_size, + const int* img_shape, + const int* col_shape, + const int* kernel_shape, + const int* stride, + const int* dilation, + const int* pad, + const float* img_data, + float* col_data, + HIPContext* context) { DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - N, Im2ColNdNCHWHIPImpl, float, img_size, col_size, img_shape, col_shape, - kernel_shape, stride, dilation, pad, img_data, col_data, context); + N, + Im2ColNdNCHWHIPImpl, + float, + img_size, + col_size, + img_shape, + col_shape, + kernel_shape, + stride, + dilation, + pad, + img_data, + col_data, + context); } template <> void Col2ImNd( - const int N, const int img_size, const int col_size, const int *img_shape, - const int *col_shape, const int *kernel_shape, const int *stride, - const int *dilation, const int *pad, const float *col_data, float *img_data, - HIPContext *context) { + const int N, + const int img_size, + const int col_size, + const int* img_shape, + const int* col_shape, + const int* kernel_shape, + const int* stride, + const int* dilation, + const int* pad, + const float* col_data, + float* img_data, + HIPContext* context) { DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - N, Col2ImNdNCHWHIPImpl, float, img_size, col_size, img_shape, col_shape, - kernel_shape, stride, dilation, pad, col_data, img_data, context); + N, + Col2ImNdNCHWHIPImpl, + float, + img_size, + col_size, + img_shape, + col_shape, + kernel_shape, + stride, + dilation, + pad, + col_data, + img_data, + context); } template <> -void CopyMatrix(const size_t itemsize, const int M, const int N, - const void *A, const int lda, void *B, - const int ldb, HIPContext *context, - TypeMeta::TypedCopy copy) { +void CopyMatrix( + const size_t itemsize, + const int M, + const int N, + const void* A, + const int lda, + void* B, + const int ldb, + HIPContext* context, + TypeMeta::TypedCopy copy) { CAFFE_ENFORCE(!copy, "Copy constructor is not supported in HIP context"); - hipMemcpy2DAsync(B, ldb * itemsize, A, lda * itemsize, N * itemsize, M, - hipMemcpyDeviceToDevice, context->hip_stream()); + hipMemcpy2DAsync( + B, + ldb * itemsize, + A, + lda * itemsize, + N * itemsize, + M, + hipMemcpyDeviceToDevice, + context->hip_stream()); } template <> -void CopyVector(const int N, const float *src, 
float *dst, - HIPContext *context) { +void CopyVector( + const int N, + const float* src, + float* dst, + HIPContext* context) { if (src != dst && N > 0) { - hipMemcpyAsync(dst, src, sizeof(float) * N, hipMemcpyDeviceToDevice, - context->hip_stream()); + hipMemcpyAsync( + dst, + src, + sizeof(float) * N, + hipMemcpyDeviceToDevice, + context->hip_stream()); } } @@ -1700,9 +2619,13 @@ template using BlockReduce = cub::BlockReduce; template -__global__ void RowwiseReduceKernel(const int rows, const int cols, - const Reducer reducer, const T init, - const T *X, T *Y) { +__global__ void RowwiseReduceKernel( + const int rows, + const int cols, + const Reducer reducer, + const T init, + const T* X, + T* Y) { __shared__ typename BlockReduce::TempStorage temp_storage; for (int i = blockIdx.x; i < rows; i += gridDim.x) { T val = init; @@ -1718,9 +2641,13 @@ __global__ void RowwiseReduceKernel(const int rows, const int cols, } template -__global__ void ColwiseReduceKernel(const int rows, const int cols, - const Reducer reducer, const T init, - const T *X, T *Y) { +__global__ void ColwiseReduceKernel( + const int rows, + const int cols, + const Reducer reducer, + const T init, + const T* X, + T* Y) { __shared__ typename BlockReduce::TempStorage temp_storage; for (int i = blockIdx.x; i < cols; i += gridDim.x) { T val = init; @@ -1737,53 +2664,86 @@ __global__ void ColwiseReduceKernel(const int rows, const int cols, } // namespace -#define CAFFE2_SPECIALIZED_HIP_ROWWISE_MAX(T) \ - template <> \ - void RowwiseMax(const int N, const int D, const T *x, T *y, \ - HIPContext *context) { \ - hipLaunchKernelGGL(RowwiseReduceKernel, \ - std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS), \ - CAFFE_HIP_NUM_THREADS, 0, context->hip_stream(), N, D, \ - cub::Max(), std::numeric_limits::lowest(), x, y); \ +#define CAFFE2_SPECIALIZED_HIP_ROWWISE_MAX(T) \ + template <> \ + void RowwiseMax( \ + const int N, const int D, const T* x, T* y, HIPContext* context) { \ + hipLaunchKernelGGL( \ + RowwiseReduceKernel, \ + std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + N, \ + D, \ + cub::Max(), \ + std::numeric_limits::lowest(), \ + x, \ + y); \ } CAFFE2_SPECIALIZED_HIP_ROWWISE_MAX(float) #undef CAFFE2_SPECIALIZED_HIP_ROWWISE_MAX -#define CAFFE2_SPECIALIZED_HIP_COLWISE_MAX(T) \ - template <> \ - void ColwiseMax(const int N, const int D, const T *x, T *y, \ - HIPContext *context) { \ - hipLaunchKernelGGL(ColwiseReduceKernel, \ - std::min(D, CAFFE_MAXIMUM_NUM_BLOCKS), \ - CAFFE_HIP_NUM_THREADS, 0, context->hip_stream(), N, D, \ - cub::Max(), std::numeric_limits::lowest(), x, y); \ +#define CAFFE2_SPECIALIZED_HIP_COLWISE_MAX(T) \ + template <> \ + void ColwiseMax( \ + const int N, const int D, const T* x, T* y, HIPContext* context) { \ + hipLaunchKernelGGL( \ + ColwiseReduceKernel, \ + std::min(D, CAFFE_MAXIMUM_NUM_BLOCKS), \ + CAFFE_HIP_NUM_THREADS, \ + 0, \ + context->hip_stream(), \ + N, \ + D, \ + cub::Max(), \ + std::numeric_limits::lowest(), \ + x, \ + y); \ } CAFFE2_SPECIALIZED_HIP_COLWISE_MAX(float) #undef CAFFE2_SPECIALIZED_HIP_COLWISE_MAX namespace { -__global__ void maximum_kernel(const int N, const float alpha, const float *x, - float *y) { - HIP_1D_KERNEL_LOOP(i, N) { y[i] = fmaxf(x[i], alpha); } +__global__ void +maximum_kernel(const int N, const float alpha, const float* x, float* y) { + HIP_1D_KERNEL_LOOP(i, N) { + y[i] = fmaxf(x[i], alpha); + } } } // namespace template <> -void Maximum(const int N, const float alpha, const float *x, float *y, - HIPContext *context) { 
+void Maximum( + const int N, + const float alpha, + const float* x, + float* y, + HIPContext* context) { hipLaunchKernelGGL( - (maximum_kernel), dim3(std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), N, alpha, x, y); + (maximum_kernel), + dim3(std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + N, + alpha, + x, + y); } namespace { template -__global__ void -ReduceTensorHIPKernel(const int outer_size, const int inner_size, - SimpleArray X_strides, - SimpleArray, D> Y_dims, - const Reducer reducer, const T init, const T *X, T *Y) { +__global__ void ReduceTensorHIPKernel( + const int outer_size, + const int inner_size, + SimpleArray X_strides, + SimpleArray, D> Y_dims, + const Reducer reducer, + const T init, + const T* X, + T* Y) { __shared__ typename BlockReduce::TempStorage temp_storage; for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { T val = init; @@ -1811,34 +2771,56 @@ ReduceTensorHIPKernel(const int outer_size, const int inner_size, } template -void ReduceTensorHIPImpl(const int outer_size, const int inner_size, - const int *dims, const int *axes, - const Reducer &reducer, const T &init, const T *X, - T *Y, HIPContext *context) { +void ReduceTensorHIPImpl( + const int outer_size, + const int inner_size, + const int* dims, + const int* axes, + const Reducer& reducer, + const T& init, + const T* X, + T* Y, + HIPContext* context) { SimpleArray X_strides; SimpleArray, D> Y_dims; utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); for (int i = 0; i < D; ++i) { Y_dims.data[i] = FixedDivisor(dims[axes[i]]); } - hipLaunchKernelGGL((ReduceTensorHIPKernel), - dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - outer_size, inner_size, X_strides, Y_dims, reducer, init, - X, Y); + hipLaunchKernelGGL( + (ReduceTensorHIPKernel), + dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + outer_size, + inner_size, + X_strides, + Y_dims, + reducer, + init, + X, + Y); } template -void ReduceTensorHIP(const int num_dims, const int *dims, const int num_axes, - const int *axes, const Reducer &reducer, const T &init, - const T *X, T *Y, HIPContext *context) { +void ReduceTensorHIP( + const int num_dims, + const int* dims, + const int num_axes, + const int* axes, + const Reducer& reducer, + const T& init, + const T* X, + T* Y, + HIPContext* context) { CAFFE_ENFORCE_LE(num_axes, num_dims); if (X == Y) { return; } std::vector transpose_axes(num_dims); - utils::ComputeTransposeAxesForReduceOp(num_dims, num_axes, axes, - transpose_axes.data()); + utils::ComputeTransposeAxesForReduceOp( + num_dims, num_axes, axes, transpose_axes.data()); const int pivot = num_dims - num_axes; int outer_size = 1; for (int i = 0; i < pivot; ++i) { @@ -1850,25 +2832,50 @@ void ReduceTensorHIP(const int num_dims, const int *dims, const int num_axes, } if (outer_size > 0 && inner_size > 0) { if (transpose_axes[pivot] == pivot) { - hipLaunchKernelGGL((RowwiseReduceKernel), - dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - outer_size, inner_size, reducer, init, X, Y); + hipLaunchKernelGGL( + (RowwiseReduceKernel), + dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + outer_size, + inner_size, + reducer, + init, + X, + Y); return; } DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_2( 
- num_dims, ReduceTensorHIPImpl, T, Reducer, outer_size, inner_size, dims, - transpose_axes.data(), reducer, init, X, Y, context); + num_dims, + ReduceTensorHIPImpl, + T, + Reducer, + outer_size, + inner_size, + dims, + transpose_axes.data(), + reducer, + init, + X, + Y, + context); } else if (outer_size > 0) { math::Set(outer_size, init, Y, context); } } template -void ReduceMeanHIPImpl(const int num_dims, const int *dims, const int num_axes, - const int *axes, const T *X, T *Y, HIPContext *context) { - ReduceTensorHIP(num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, - context); +void ReduceMeanHIPImpl( + const int num_dims, + const int* dims, + const int num_axes, + const int* axes, + const T* X, + T* Y, + HIPContext* context) { + ReduceTensorHIP( + num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, context); const int X_size = std::accumulate(dims, dims + num_dims, 1, std::multiplies()); int scale = 1; @@ -1881,13 +2888,26 @@ void ReduceMeanHIPImpl(const int num_dims, const int *dims, const int num_axes, } // namespace -#define CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(T) \ - template <> \ - void ReduceMin(const int num_dims, const int *dims, \ - const int num_axes, const int *axes, \ - const T *X, T *Y, HIPContext *context) { \ - ReduceTensorHIP(num_dims, dims, num_axes, axes, cub::Min(), \ - std::numeric_limits::max(), X, Y, context); \ +#define CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(T) \ + template <> \ + void ReduceMin( \ + const int num_dims, \ + const int* dims, \ + const int num_axes, \ + const int* axes, \ + const T* X, \ + T* Y, \ + HIPContext* context) { \ + ReduceTensorHIP( \ + num_dims, \ + dims, \ + num_axes, \ + axes, \ + cub::Min(), \ + std::numeric_limits::max(), \ + X, \ + Y, \ + context); \ } CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(std::int32_t) CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(std::int64_t) @@ -1895,13 +2915,26 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(float) CAFFE2_SPECIALIZED_HIP_REDUCE_MIN(double) #undef CAFFE2_SPECIALIZED_HIP_REDUCE_MIN -#define CAFFE2_SPECIALIZED_HIP_REDUCE_MAX(T) \ - template <> \ - void ReduceMax(const int num_dims, const int *dims, \ - const int num_axes, const int *axes, \ - const T *X, T *Y, HIPContext *context) { \ - ReduceTensorHIP(num_dims, dims, num_axes, axes, cub::Max(), \ - std::numeric_limits::lowest(), X, Y, context); \ +#define CAFFE2_SPECIALIZED_HIP_REDUCE_MAX(T) \ + template <> \ + void ReduceMax( \ + const int num_dims, \ + const int* dims, \ + const int num_axes, \ + const int* axes, \ + const T* X, \ + T* Y, \ + HIPContext* context) { \ + ReduceTensorHIP( \ + num_dims, \ + dims, \ + num_axes, \ + axes, \ + cub::Max(), \ + std::numeric_limits::lowest(), \ + X, \ + Y, \ + context); \ } CAFFE2_SPECIALIZED_HIP_REDUCE_MAX(std::int32_t) CAFFE2_SPECIALIZED_HIP_REDUCE_MAX(std::int64_t) @@ -1909,13 +2942,18 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_MAX(float) CAFFE2_SPECIALIZED_HIP_REDUCE_MAX(double) #undef CAFFE2_SPECIALIZED_HIP_REDUCE_MAX -#define CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(T) \ - template <> \ - void ReduceSum(const int num_dims, const int *dims, \ - const int num_axes, const int *axes, \ - const T *X, T *Y, HIPContext *context) { \ - ReduceTensorHIP(num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, \ - context); \ +#define CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(T) \ + template <> \ + void ReduceSum( \ + const int num_dims, \ + const int* dims, \ + const int num_axes, \ + const int* axes, \ + const T* X, \ + T* Y, \ + HIPContext* context) { \ + ReduceTensorHIP( \ + num_dims, dims, num_axes, axes, cub::Sum(), T(0), X, Y, context); \ } 
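// --------------------------------------------------------------------------
// Editorial aside, not part of this patch: a minimal host-side sketch of the
// decomposition ReduceTensorHIP relies on. Assuming the reduced axes are
// already the trailing ones (the RowwiseReduceKernel fast path above), the
// tensor is viewed as an outer_size x inner_size matrix and each row is
// folded with a reducer and its identity value, which is what the
// CAFFE2_SPECIALIZED_HIP_REDUCE_* macros do with cub::Sum/Min/Max. The names
// reduce_trailing_axes and identity below are illustrative only.
// --------------------------------------------------------------------------
#include <cassert>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

template <typename T, typename Reducer>
void reduce_trailing_axes(
    const std::vector<int>& dims,
    int num_reduce_axes, // number of trailing axes being reduced
    const Reducer& reducer,
    T identity,
    const T* X,
    T* Y) {
  const int pivot = static_cast<int>(dims.size()) - num_reduce_axes;
  const int outer_size = std::accumulate(
      dims.begin(), dims.begin() + pivot, 1, std::multiplies<int>());
  const int inner_size = std::accumulate(
      dims.begin() + pivot, dims.end(), 1, std::multiplies<int>());
  for (int i = 0; i < outer_size; ++i) {
    T val = identity;
    for (int j = 0; j < inner_size; ++j) {
      val = reducer(val, X[i * inner_size + j]);
    }
    Y[i] = val; // one output element per outer index
  }
}

int main() {
  // Reduce a 2x3 tensor over its last axis: {{1,2,3},{4,5,6}} -> {6, 15}.
  const std::vector<float> X = {1, 2, 3, 4, 5, 6};
  std::vector<float> Y(2);
  reduce_trailing_axes<float>(
      {2, 3}, 1, std::plus<float>(), 0.0f, X.data(), Y.data());
  assert(Y[0] == 6.0f && Y[1] == 15.0f);
  std::cout << Y[0] << " " << Y[1] << std::endl;
  return 0;
}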
CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(std::int32_t) CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(std::int64_t) @@ -1923,12 +2961,17 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(float) CAFFE2_SPECIALIZED_HIP_REDUCE_SUM(double) #undef CAFFE2_SPECIALIZED_HIP_REDUCE_SUM -#define CAFFE2_SPECIALIZED_HIP_REDUCE_MEAN(T) \ - template <> \ - void ReduceMean(const int num_dims, const int *dims, \ - const int num_axes, const int *axes, \ - const T *X, T *Y, HIPContext *context) { \ - ReduceMeanHIPImpl(num_dims, dims, num_axes, axes, X, Y, context); \ +#define CAFFE2_SPECIALIZED_HIP_REDUCE_MEAN(T) \ + template <> \ + void ReduceMean( \ + const int num_dims, \ + const int* dims, \ + const int num_axes, \ + const int* axes, \ + const T* X, \ + T* Y, \ + HIPContext* context) { \ + ReduceMeanHIPImpl(num_dims, dims, num_axes, axes, X, Y, context); \ } CAFFE2_SPECIALIZED_HIP_REDUCE_MEAN(float) #undef CAFFE2_SPECIALIZED_HIP_REDUCE_MEAN @@ -1936,16 +2979,20 @@ CAFFE2_SPECIALIZED_HIP_REDUCE_MEAN(float) namespace { template -__global__ void -BroadcastHIPKernel(const int Y_size, const SimpleArray X_strides, - const SimpleArray Y_dims, const T *X, T *Y) { +__global__ void BroadcastHIPKernel( + const int Y_size, + const SimpleArray X_strides, + const SimpleArray Y_dims, + const T* X, + T* Y) { HIP_1D_KERNEL_LOOP(Y_index, Y_size) { int X_index = 0; int Y_index_val = Y_index; #pragma unroll for (int i = D - 1; i >= 0; --i) { - X_index += X_strides.data[i] == 0 ? 0 : (Y_index_val % Y_dims.data[i]) * - X_strides.data[i]; + X_index += X_strides.data[i] == 0 + ? 0 + : (Y_index_val % Y_dims.data[i]) * X_strides.data[i]; Y_index_val /= Y_dims.data[i]; } Y[Y_index] = __ldg(X + X_index); @@ -1953,8 +3000,13 @@ BroadcastHIPKernel(const int Y_size, const SimpleArray X_strides, } template -void BroadcastHIPImpl(const int X_ndim, const int *X_dims, const int *Y_dims, - const T *X, T *Y, HIPContext *context) { +void BroadcastHIPImpl( + const int X_ndim, + const int* X_dims, + const int* Y_dims, + const T* X, + T* Y, + HIPContext* context) { SimpleArray X_strides_array; SimpleArray Y_dims_array; const int d = D - X_ndim; @@ -1968,21 +3020,34 @@ void BroadcastHIPImpl(const int X_ndim, const int *X_dims, const int *Y_dims, std::copy_n(Y_dims, D, Y_dims_array.data); const int Y_size = std::accumulate(Y_dims, Y_dims + D, 1, std::multiplies()); - hipLaunchKernelGGL((BroadcastHIPKernel), dim3(CAFFE_GET_BLOCKS(Y_size)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - Y_size, X_strides_array, Y_dims_array, X, Y); + hipLaunchKernelGGL( + (BroadcastHIPKernel), + dim3(CAFFE_GET_BLOCKS(Y_size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + Y_size, + X_strides_array, + Y_dims_array, + X, + Y); } } // namespace -#define CAFFE2_SPECIALIZED_HIP_BROADCAST(T) \ - template <> \ - void Broadcast(const int X_ndim, const int *X_dims, \ - const int Y_ndim, const int *Y_dims, \ - const T *X, T *Y, HIPContext *context) { \ - CAFFE_ENFORCE_LE(X_ndim, Y_ndim); \ - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( \ - Y_ndim, BroadcastHIPImpl, T, X_ndim, X_dims, Y_dims, X, Y, context); \ +#define CAFFE2_SPECIALIZED_HIP_BROADCAST(T) \ + template <> \ + void Broadcast( \ + const int X_ndim, \ + const int* X_dims, \ + const int Y_ndim, \ + const int* Y_dims, \ + const T* X, \ + T* Y, \ + HIPContext* context) { \ + CAFFE_ENFORCE_LE(X_ndim, Y_ndim); \ + DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( \ + Y_ndim, BroadcastHIPImpl, T, X_ndim, X_dims, Y_dims, X, Y, context); \ } CAFFE2_SPECIALIZED_HIP_BROADCAST(std::int32_t) CAFFE2_SPECIALIZED_HIP_BROADCAST(std::int64_t) @@ 
-1993,8 +3058,12 @@ CAFFE2_SPECIALIZED_HIP_BROADCAST(double) namespace { template -__global__ void RowwiseMomentsHIPKernel(const int rows, const int cols, - const T *X, T *mean, T *variance) { +__global__ void RowwiseMomentsHIPKernel( + const int rows, + const int cols, + const T* X, + T* mean, + T* variance) { __shared__ typename BlockReduce::TempStorage m_storage; __shared__ typename BlockReduce::TempStorage v_storage; for (int i = blockIdx.x; i < rows; i += gridDim.x) { @@ -2016,10 +3085,14 @@ __global__ void RowwiseMomentsHIPKernel(const int rows, const int cols, } template -__global__ void MomentsHIPKernel(const int outer_size, const int inner_size, - SimpleArray X_strides, - SimpleArray, D> Y_dims, - const T *X, T *mean, T *variance) { +__global__ void MomentsHIPKernel( + const int outer_size, + const int inner_size, + SimpleArray X_strides, + SimpleArray, D> Y_dims, + const T* X, + T* mean, + T* variance) { __shared__ typename BlockReduce::TempStorage m_storage; __shared__ typename BlockReduce::TempStorage v_storage; for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { @@ -2048,30 +3121,50 @@ __global__ void MomentsHIPKernel(const int outer_size, const int inner_size, } template -void MomentsHIPImpl(const int outer_size, const int inner_size, const int *dims, - const int *axes, const T *X, T *mean, T *variance, - HIPContext *context) { +void MomentsHIPImpl( + const int outer_size, + const int inner_size, + const int* dims, + const int* axes, + const T* X, + T* mean, + T* variance, + HIPContext* context) { SimpleArray X_strides; SimpleArray, D> Y_dims; utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); for (int i = 0; i < D; ++i) { Y_dims.data[i] = FixedDivisor(dims[axes[i]]); } - hipLaunchKernelGGL((MomentsHIPKernel), - dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - outer_size, inner_size, X_strides, Y_dims, X, mean, - variance); + hipLaunchKernelGGL( + (MomentsHIPKernel), + dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + outer_size, + inner_size, + X_strides, + Y_dims, + X, + mean, + variance); } template -void MomentsHIP(const int num_dims, const int *dims, const int num_axes, - const int *axes, const T *X, T *mean, T *variance, - HIPContext *context) { +void MomentsHIP( + const int num_dims, + const int* dims, + const int num_axes, + const int* axes, + const T* X, + T* mean, + T* variance, + HIPContext* context) { CAFFE_ENFORCE_LE(num_axes, num_dims); std::vector transpose_axes(num_dims); - utils::ComputeTransposeAxesForReduceOp(num_dims, num_axes, axes, - transpose_axes.data()); + utils::ComputeTransposeAxesForReduceOp( + num_dims, num_axes, axes, transpose_axes.data()); const int pivot = num_dims - num_axes; int outer_size = 1; for (int i = 0; i < pivot; ++i) { @@ -2083,15 +3176,31 @@ void MomentsHIP(const int num_dims, const int *dims, const int num_axes, } if (outer_size > 0 && inner_size > 0) { if (transpose_axes[pivot] == pivot) { - hipLaunchKernelGGL((RowwiseMomentsHIPKernel), - dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - outer_size, inner_size, X, mean, variance); + hipLaunchKernelGGL( + (RowwiseMomentsHIPKernel), + dim3(std::min(outer_size, CAFFE_MAXIMUM_NUM_BLOCKS)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + outer_size, + inner_size, + X, + mean, + variance); return; } DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( - num_dims, 
MomentsHIPImpl, T, outer_size, inner_size, dims, - transpose_axes.data(), X, mean, variance, context); + num_dims, + MomentsHIPImpl, + T, + outer_size, + inner_size, + dims, + transpose_axes.data(), + X, + mean, + variance, + context); } } @@ -2099,9 +3208,15 @@ void MomentsHIP(const int num_dims, const int *dims, const int num_axes, #define CAFFE2_SPECIALIZED_HIP_MOMENTS(T) \ template <> \ - void Moments(const int num_dims, const int *dims, \ - const int num_axes, const int *axes, const T *X, \ - T *mean, T *variance, HIPContext *context) { \ + void Moments( \ + const int num_dims, \ + const int* dims, \ + const int num_axes, \ + const int* axes, \ + const T* X, \ + T* mean, \ + T* variance, \ + HIPContext* context) { \ MomentsHIP(num_dims, dims, num_axes, axes, X, mean, variance, context); \ } CAFFE2_SPECIALIZED_HIP_MOMENTS(float) @@ -2110,10 +3225,12 @@ CAFFE2_SPECIALIZED_HIP_MOMENTS(float) namespace { template -__global__ void -TransposeHIPKernel(const int size, const SimpleArray X_strides, - const SimpleArray, D> Y_dims, const T *X, - T *Y) { +__global__ void TransposeHIPKernel( + const int size, + const SimpleArray X_strides, + const SimpleArray, D> Y_dims, + const T* X, + T* Y) { HIP_1D_KERNEL_LOOP(Y_index, size) { int X_index = 0; int Y_index_val = Y_index; @@ -2128,8 +3245,12 @@ TransposeHIPKernel(const int size, const SimpleArray X_strides, } template -void TransposeHIPImpl(const int *dims, const int *axes, const T *X, T *Y, - HIPContext *context) { +void TransposeHIPImpl( + const int* dims, + const int* axes, + const T* X, + T* Y, + HIPContext* context) { SimpleArray X_strides; SimpleArray, D> Y_dims; utils::ComputeTransposedStrides(D, dims, axes, X_strides.data); @@ -2138,26 +3259,38 @@ void TransposeHIPImpl(const int *dims, const int *axes, const T *X, T *Y, Y_dims.data[i] = FixedDivisor(dims[axes[i]]); size *= dims[i]; } - hipLaunchKernelGGL((TransposeHIPKernel), dim3(CAFFE_GET_BLOCKS(size)), - dim3(CAFFE_HIP_NUM_THREADS), 0, context->hip_stream(), - size, X_strides, Y_dims, X, Y); + hipLaunchKernelGGL( + (TransposeHIPKernel), + dim3(CAFFE_GET_BLOCKS(size)), + dim3(CAFFE_HIP_NUM_THREADS), + 0, + context->hip_stream(), + size, + X_strides, + Y_dims, + X, + Y); } } // namespace -#define CAFFE2_SPECIALIZED_HIP_TRANSPOSE(T) \ - template <> \ - void Transpose(const int ndim, const int *dims, \ - const int *axes, const T *X, T *Y, \ - HIPContext *context) { \ - if (utils::IsIdentityPermutation(ndim, axes)) { \ - const int size = \ - std::accumulate(dims, dims + ndim, 1, std::multiplies()); \ - context->template Copy(size, X, Y); \ - return; \ - } \ - DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1(ndim, TransposeHIPImpl, T, dims, \ - axes, X, Y, context); \ +#define CAFFE2_SPECIALIZED_HIP_TRANSPOSE(T) \ + template <> \ + void Transpose( \ + const int ndim, \ + const int* dims, \ + const int* axes, \ + const T* X, \ + T* Y, \ + HIPContext* context) { \ + if (utils::IsIdentityPermutation(ndim, axes)) { \ + const int size = \ + std::accumulate(dims, dims + ndim, 1, std::multiplies()); \ + context->template Copy(size, X, Y); \ + return; \ + } \ + DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_1( \ + ndim, TransposeHIPImpl, T, dims, axes, X, Y, context); \ } CAFFE2_SPECIALIZED_HIP_TRANSPOSE(float) CAFFE2_SPECIALIZED_HIP_TRANSPOSE(double) diff --git a/caffe2/utils/hip/math_hip_test.cc b/caffe2/utils/hip/math_hip_test.cc index 19a4eed95632e..684d0218d322c 100644 --- a/caffe2/utils/hip/math_hip_test.cc +++ b/caffe2/utils/hip/math_hip_test.cc @@ -271,8 +271,11 @@ class GemmBatchedGPUTest trans_W_ = 
std::get<1>(GetParam()); } - void RunGemmBatched(const float alpha, const float beta) { - math::GemmBatched( + void RunGemmStridedBatched(const float alpha, const float beta) { + const int X_stride = 5 * 10; + const int W_stride = 10 * 6; + const int Y_stride = 5 * 6; + math::GemmStridedBatched( trans_X_ ? CblasTrans : CblasNoTrans, trans_W_ ? CblasTrans : CblasNoTrans, 3, @@ -281,9 +284,12 @@ class GemmBatchedGPUTest 10, alpha, X_->template data(), + X_stride, W_->template data(), + W_stride, beta, Y_->template mutable_data(), + Y_stride, hip_context_.get()); } @@ -308,11 +314,11 @@ TEST_P(GemmBatchedGPUTest, GemmBatchedGPUFloatTest) { if (!HasHipGPU()) { return; } - RunGemmBatched(1.0f, 0.0f); + RunGemmStridedBatched(1.0f, 0.0f); VerifyOutput(10.0f); - RunGemmBatched(1.0f, 0.5f); + RunGemmStridedBatched(1.0f, 0.5f); VerifyOutput(15.0f); - RunGemmBatched(0.5f, 1.0f); + RunGemmStridedBatched(0.5f, 1.0f); VerifyOutput(20.0f); } diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h index ca8535e4aa3dd..ac3b06681d858 100644 --- a/caffe2/utils/math.h +++ b/caffe2/utils/math.h @@ -310,8 +310,8 @@ void Transpose( // limitation that the data has to be contiguous in memory. template void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -327,8 +327,8 @@ void Gemm( // In most cases you probably want to use the function above, though. template void GemmEx( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -345,19 +345,37 @@ void GemmEx( // GemmBatched provides a simple abstraction into library routines template void GemmBatched( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const T** A, + const T** B, + const float beta, + T** C, + Context* context, + TensorProto::DataType math_type = TensorProto_DataType_FLOAT); + +template +void GemmStridedBatched( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int batch_size, const int M, const int N, const int K, const float alpha, const T* A, + const int A_stride, const T* B, + const int B_stride, const float beta, T* C, + const int C_stride, Context* context, - Tensor* scratch = nullptr, TensorProto::DataType math_type = TensorProto_DataType_FLOAT); // Gemv always takes in a M*N matrix A, and depending on whether we set TransA @@ -366,7 +384,7 @@ void GemmBatched( // CblasTrans: x is an M dim vector and y is an N dim vector. template void Gemv( - const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE trans_A, const int M, const int N, const float alpha, diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc index 6aae82ea6554d..17afeb3107aa0 100644 --- a/caffe2/utils/math_cpu.cc +++ b/caffe2/utils/math_cpu.cc @@ -75,8 +75,8 @@ namespace math { // CblasTrans, respectively, for each of A and B. 
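// --------------------------------------------------------------------------
// Editorial aside, not part of this patch: a minimal reference for the new
// GemmStridedBatched contract, assuming row-major storage and no transposes.
// Old GemmBatched callers with one contiguous buffer per operand migrate by
// passing the formerly implicit strides M*K, K*N and M*N explicitly (the
// updated tests above do exactly this with 5*10, 10*6 and 5*6). naive_gemm
// and gemm_strided_batched_ref are illustrative names, not Caffe2 APIs.
// --------------------------------------------------------------------------
#include <cstdio>
#include <vector>

// C = alpha * A(MxK) * B(KxN) + beta * C, all row-major, no transpose.
static void naive_gemm(int M, int N, int K, float alpha, const float* A,
                       const float* B, float beta, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}

// Batch i reads A + i * A_stride and B + i * B_stride and writes
// C + i * C_stride, matching the non-MKL CPU fallback loop in the patch.
static void gemm_strided_batched_ref(int batch_size, int M, int N, int K,
                                     float alpha, const float* A, int A_stride,
                                     const float* B, int B_stride, float beta,
                                     float* C, int C_stride) {
  for (int i = 0; i < batch_size; ++i) {
    naive_gemm(M, N, K, alpha, A, B, beta, C);
    A += A_stride;
    B += B_stride;
    C += C_stride;
  }
}

int main() {
  const int batch = 3, M = 5, N = 6, K = 10;
  std::vector<float> A(batch * M * K, 1.0f);
  std::vector<float> B(batch * K * N, 1.0f);
  std::vector<float> C(batch * M * N, 0.0f);
  gemm_strided_batched_ref(batch, M, N, K, 1.0f, A.data(), M * K, B.data(),
                           K * N, 0.0f, C.data(), M * N);
  std::printf("C[0] = %f (expected 10)\n", C[0]); // each entry sums K ones
  return 0;
}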
template <> void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -93,9 +93,9 @@ void Gemm( } else { C_mat *= beta; } - switch (TransA) { + switch (trans_A) { case CblasNoTrans: { - switch (TransB) { + switch (trans_B) { case CblasNoTrans: C_mat.noalias() += alpha * (ConstEigenMatrixMap(B, N, K) * @@ -107,11 +107,11 @@ void Gemm( ConstEigenMatrixMap(A, K, M)); return; default: - LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; } } case CblasTrans: { - switch (TransB) { + switch (trans_B) { case CblasNoTrans: C_mat.noalias() += alpha * (ConstEigenMatrixMap(B, N, K) * @@ -123,18 +123,18 @@ void Gemm( ConstEigenMatrixMap(A, M, K).transpose()); return; default: - LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; } } default: - LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransA"; + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_A"; } } template <> void GemmEx( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -156,9 +156,9 @@ void GemmEx( } else { C_mat *= beta; } - switch (TransA) { + switch (trans_A) { case CblasNoTrans: { - switch (TransB) { + switch (trans_B) { case CblasNoTrans: C_mat.noalias() += alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) * @@ -170,11 +170,11 @@ void GemmEx( ConstStridedMap(A, K, M, OuterStride(lda))); return; default: - LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; } } case CblasTrans: { - switch (TransB) { + switch (trans_B) { case CblasNoTrans: C_mat.noalias() += alpha * (ConstStridedMap(B, N, K, OuterStride(ldb)) * @@ -186,17 +186,17 @@ void GemmEx( ConstStridedMap(A, M, K, OuterStride(lda)).transpose()); return; default: - LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_B"; } } default: - LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransA"; + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for trans_A"; } } template <> void Gemv( - const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE trans_A, const int M, const int N, const float alpha, @@ -206,7 +206,7 @@ void Gemv( float* y, CPUContext* context, TensorProto::DataType math_type) { - EigenVectorMap y_vec(y, TransA == CblasNoTrans ? M : N); + EigenVectorMap y_vec(y, trans_A == CblasNoTrans ? M : N); if (beta == 0) { // In Caffe2 we often do a lazy initialization, which may contain NaNs in // the float values. As a result, if beta is 0, we explicitly do a setzero. @@ -214,7 +214,7 @@ void Gemv( } else { y_vec *= beta; } - switch (TransA) { + switch (trans_A) { case CblasNoTrans: { y_vec.noalias() += alpha * (ConstEigenMatrixMap(A, N, M).transpose() * @@ -292,8 +292,8 @@ CAFFE2_SPECIALIZED_AXPBY(float) template <> void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -304,12 +304,12 @@ void Gemm( float* C, CPUContext* /*context*/, TensorProto::DataType /*math_type*/) { - int lda = (TransA == CblasNoTrans) ? K : M; - int ldb = (TransB == CblasNoTrans) ? N : K; + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? 
N : K; cblas_sgemm( CblasRowMajor, - TransA, - TransB, + trans_A, + trans_B, M, N, K, @@ -325,8 +325,8 @@ void Gemm( template <> void GemmEx( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -341,8 +341,8 @@ void GemmEx( CPUContext* /*context*/) { cblas_sgemm( CblasRowMajor, - TransA, - TransB, + trans_A, + trans_B, M, N, K, @@ -358,7 +358,7 @@ void GemmEx( template <> void Gemv( - const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE trans_A, const int M, const int N, const float alpha, @@ -368,7 +368,7 @@ void Gemv( float* y, CPUContext* /*context*/, TensorProto::DataType /*math_type*/) { - cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1); + cblas_sgemv(CblasRowMajor, trans_A, M, N, alpha, A, N, x, 1, beta, y, 1); } #define CAFFE2_SPECIALIZED_SCALE(T, prefix) \ @@ -447,69 +447,106 @@ CAFFE2_SPECIALIZED_AXPBY(float, s) template <> void GemmBatched( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const float** A, + const float** B, + const float beta, + float** C, + CPUContext* context, + TensorProto::DataType /* math_type */) { +#ifdef CAFFE2_USE_MKL + (void)context; + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? N : K; + const int ldc = N; + cblas_sgemm_batch( + CblasRowMajor, + &trans_A, + &trans_B, + &M, + &N, + &K, + &alpha, + A, + &lda, + B, + &ldb, + &beta, + C, + &ldc, + 1, + &batch_size); +#else // CAFFE2_USE_MKL + // loop over matrices in the batch + for (int i = 0; i < batch_size; ++i) { + math::Gemm( + trans_A, trans_B, M, N, K, alpha, A[i], B[i], beta, C[i], context); + } +#endif // CAFFE2_USE_MKL +} + +template <> +void GemmStridedBatched( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int batch_size, const int M, const int N, const int K, const float alpha, const float* A, + const int A_stride, const float* B, + const int B_stride, const float beta, float* C, + const int C_stride, CPUContext* context, - Tensor*, /* scratch */ TensorProto::DataType /* math_type */) { - const int a_stride = M * K; - const int b_stride = K * N; - const int c_stride = M * N; - #ifdef CAFFE2_USE_MKL (void)context; - - const int lda = (TransA == CblasNoTrans) ? K : M; - const int ldb = (TransB == CblasNoTrans) ? N : K; - std::vector a_array(batch_size, nullptr); - std::vector b_array(batch_size, nullptr); - std::vector c_array(batch_size, nullptr); + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? 
N : K; + const int ldc = N; + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector C_array(batch_size); for (int i = 0; i < batch_size; ++i) { - a_array[i] = A + a_stride * i; - b_array[i] = B + b_stride * i; - c_array[i] = C + c_stride * i; + A_array[i] = A + i * A_stride; + B_array[i] = B + i * B_stride; + C_array[i] = C + i * C_stride; } cblas_sgemm_batch( CblasRowMajor, - &TransA, - &TransB, + &trans_A, + &trans_B, &M, &N, &K, &alpha, - a_array.data(), + A_array.data(), &lda, - b_array.data(), + B_array.data(), &ldb, &beta, - c_array.data(), - &N, // ldc_array + C_array.data(), + &ldc, 1, &batch_size); #else // CAFFE2_USE_MKL // loop over matrices in the batch for (int i = 0; i < batch_size; ++i) { math::Gemm( - TransA, - TransB, - M, - N, - K, - alpha, - A + a_stride * i, - B + b_stride * i, - beta, - C + c_stride * i, - context); + trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context); + A += A_stride; + B += B_stride; + C += C_stride; } #endif } @@ -1395,29 +1432,6 @@ void ColwiseBinaryOp( } } -template -void BinaryOpWith2DBroadcasting( - const int ndim, - const int* dims, - const int pivot, - const bool broadcast_1st, - const Operator1& op1, - const Operator2& op2, - const TIn* A, - const TIn* B, - TOut* C, - CPUContext* context) { - const int rows = - std::accumulate(dims, dims + pivot, 1, std::multiplies()); - const int cols = - std::accumulate(dims + pivot, dims + ndim, 1, std::multiplies()); - if (broadcast_1st) { - op1(rows, cols, A, B, C, context); - } else { - op2(rows, cols, A, B, C, context); - } -} - template void BroadcastBinaryOpImpl( const int ndim, @@ -1585,87 +1599,101 @@ DEFINE_2D_BROADCAST_1ST_DIV_FUNCTION(std::int32_t) DEFINE_2D_BROADCAST_1ST_DIV_FUNCTION(std::int64_t) #undef DEFINE_2D_BROADCAST_1ST_DIV_FUNCTION -#define DELEGATE_BROADCAST_BINARY_FUNCTION(TIn, TOut, Func, Op) \ - template <> \ - void Func( \ - const int A_ndim, \ - const int* A_dims, \ - const int B_ndim, \ - const int* B_dims, \ - const TIn* A, \ - const TIn* B, \ - TOut* C, \ - CPUContext* context) { \ - const int ndim = std::max(A_ndim, B_ndim); \ - std::vector A_dims_array(ndim); \ - std::vector B_dims_array(ndim); \ - std::vector C_dims_array(ndim); \ - utils::ComputeBroadcastBinaryOpDims( \ - A_ndim, \ - A_dims, \ - B_ndim, \ - B_dims, \ - A_dims_array.data(), \ - B_dims_array.data(), \ - C_dims_array.data()); \ - if (A_dims_array == B_dims_array) { \ - const int size = std::accumulate( \ - C_dims_array.cbegin(), \ - C_dims_array.cend(), \ - 1, \ - std::multiplies()); \ - Func(size, A, B, C, context); \ - return; \ - } \ - int pivot; \ - bool broadcast_1st; \ - if (utils::IsRowwiseBroadcastBinaryOp( \ - ndim, \ - A_dims_array.data(), \ - B_dims_array.data(), \ - &pivot, \ - &broadcast_1st)) { \ - BinaryOpWith2DBroadcasting( \ - ndim, \ - C_dims_array.data(), \ - pivot, \ - broadcast_1st, \ - Rowwise##Func, \ - Rowwise##Func, \ - A, \ - B, \ - C, \ - context); \ - return; \ - } \ - if (utils::IsColwiseBroadcastBinaryOp( \ - ndim, \ - A_dims_array.data(), \ - B_dims_array.data(), \ - &pivot, \ - &broadcast_1st)) { \ - BinaryOpWith2DBroadcasting( \ - ndim, \ - C_dims_array.data(), \ - pivot, \ - broadcast_1st, \ - Colwise##Func, \ - Colwise##Func, \ - A, \ - B, \ - C, \ - context); \ - return; \ - } \ - BroadcastBinaryOpImpl( \ - ndim, \ - A_dims_array.data(), \ - B_dims_array.data(), \ - C_dims_array.data(), \ - Op(), \ - A, \ - B, \ - C); \ +#define DELEGATE_BROADCAST_BINARY_FUNCTION(TIn, TOut, Func, Op) \ + template <> \ + void Func( \ + const int A_ndim, \ + 
const int* A_dims, \ + const int B_ndim, \ + const int* B_dims, \ + const TIn* A, \ + const TIn* B, \ + TOut* C, \ + CPUContext* context) { \ + const int ndim = std::max(A_ndim, B_ndim); \ + std::vector A_dims_array(ndim); \ + std::vector B_dims_array(ndim); \ + std::vector C_dims_array(ndim); \ + utils::ComputeBroadcastBinaryOpDims( \ + A_ndim, \ + A_dims, \ + B_ndim, \ + B_dims, \ + A_dims_array.data(), \ + B_dims_array.data(), \ + C_dims_array.data()); \ + if (A_dims_array == B_dims_array) { \ + const int size = std::accumulate( \ + C_dims_array.cbegin(), \ + C_dims_array.cend(), \ + 1, \ + std::multiplies()); \ + Func(size, A, B, C, context); \ + return; \ + } \ + int rows; \ + int cols; \ + bool broadcast_1st; \ + if (utils::IsRowwiseBroadcastBinaryOp( \ + ndim, \ + A_dims_array.data(), \ + B_dims_array.data(), \ + &rows, \ + &cols, \ + &broadcast_1st)) { \ + if (broadcast_1st) { \ + Rowwise##Func(rows, cols, A, B, C, context); \ + } else { \ + Rowwise##Func(rows, cols, A, B, C, context); \ + } \ + return; \ + } \ + if (utils::IsColwiseBroadcastBinaryOp( \ + ndim, \ + A_dims_array.data(), \ + B_dims_array.data(), \ + &rows, \ + &cols, \ + &broadcast_1st)) { \ + if (broadcast_1st) { \ + Colwise##Func(rows, cols, A, B, C, context); \ + } else { \ + Colwise##Func(rows, cols, A, B, C, context); \ + } \ + return; \ + } \ + int pre; \ + int mid; \ + int nxt; \ + if (utils::IsMiddleBroadcastBinaryOp( \ + ndim, \ + A_dims_array.data(), \ + B_dims_array.data(), \ + &pre, \ + &mid, \ + &nxt, \ + &broadcast_1st)) { \ + const int stride = mid * nxt; \ + for (int i = 0; i < pre; ++i) { \ + if (broadcast_1st) { \ + Colwise##Func( \ + mid, nxt, A, B + i * stride, C + i * stride, context); \ + } else { \ + Colwise##Func( \ + mid, nxt, A + i * stride, B, C + i * stride, context); \ + } \ + } \ + return; \ + } \ + BroadcastBinaryOpImpl( \ + ndim, \ + A_dims_array.data(), \ + B_dims_array.data(), \ + C_dims_array.data(), \ + Op(), \ + A, \ + B, \ + C); \ } #define DEFINE_BROADCAST_COMPARE_FUNCTION(Func, Op) \ diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu index e93c1a729b429..8704f59e35083 100644 --- a/caffe2/utils/math_gpu.cu +++ b/caffe2/utils/math_gpu.cu @@ -10,6 +10,7 @@ #include #include +#include #include #include "caffe2/core/context_gpu.h" @@ -130,9 +131,8 @@ __global__ void BroadcastBinaryOpCUDAKernel( template void BinaryOpWith2DBroadcasting( - const int ndim, - const int* dims, - const int pivot, + const int rows, + const int cols, const bool rowwise_broadcast, const bool broadcast_1st, const BinaryOperator& op, @@ -140,10 +140,6 @@ void BinaryOpWith2DBroadcasting( const TIn* B, TOut* C, CUDAContext* context) { - const int rows = - std::accumulate(dims, dims + pivot, 1, std::multiplies()); - const int cols = - std::accumulate(dims + pivot, dims + ndim, 1, std::multiplies()); if (rows == 0 || cols == 0) { return; } @@ -248,44 +244,29 @@ void BroadcastBinaryOp( context->cuda_stream()>>>(size, op, A, B, C); return; } - int pivot; + int rows; + int cols; bool broadcast_1st; if (utils::IsRowwiseBroadcastBinaryOp( ndim, A_dims_array.data(), B_dims_array.data(), - &pivot, + &rows, + &cols, &broadcast_1st)) { BinaryOpWith2DBroadcasting( - ndim, - C_dims_array.data(), - pivot, - true, - broadcast_1st, - op, - A, - B, - C, - context); + rows, cols, true, broadcast_1st, op, A, B, C, context); return; } if (utils::IsColwiseBroadcastBinaryOp( ndim, A_dims_array.data(), B_dims_array.data(), - &pivot, + &rows, + &cols, &broadcast_1st)) { BinaryOpWith2DBroadcasting( - ndim, - 
C_dims_array.data(), - pivot, - false, - broadcast_1st, - op, - A, - B, - C, - context); + rows, cols, false, broadcast_1st, op, A, B, C, context); return; } DISPATCH_FUNCTION_BY_VALUE_WITH_TYPE_3( @@ -686,8 +667,8 @@ DELEGATE_REDUCTION_FUNCTION(int64_t, ReduceMax, Max) // limitation that the data has to be contiguous in memory. template <> void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -700,16 +681,16 @@ void Gemm( TensorProto::DataType math_type) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. - int lda = (TransA == CblasNoTrans) ? K : M; - int ldb = (TransB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? N : K; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; CUBLAS_ENFORCE(cublasSgemm( context->cublas_handle(), - cuTransB, - cuTransA, + cu_trans_B, + cu_trans_A, N, M, K, @@ -725,8 +706,8 @@ void Gemm( template <> void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -739,17 +720,17 @@ void Gemm( TensorProto::DataType math_type) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. - int lda = (TransA == CblasNoTrans) ? K : M; - int ldb = (TransB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? N : K; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; if (math_type == TensorProto_DataType_FLOAT) { CUBLAS_CHECK(cublasSgemmEx( context->cublas_handle(), - cuTransB, - cuTransA, + cu_trans_B, + cu_trans_A, N, M, K, @@ -764,17 +745,15 @@ void Gemm( C, CUDA_R_16F, N)); - } else if (math_type == TensorProto_DataType_FLOAT16) { // convert alpha, beta from float -> __half - auto alpha_fp16 = convert::floatToHalf(alpha); - auto beta_fp16 = convert::floatToHalf(beta); - + const __half alpha_fp16 = convert::floatToHalf(alpha); + const __half beta_fp16 = convert::floatToHalf(beta); // call cublasHgemm CUBLAS_CHECK(cublasHgemm( context->cublas_handle(), - cuTransB, - cuTransA, + cu_trans_B, + cu_trans_A, N, M, K, @@ -816,224 +795,351 @@ void BiasCHW( template <> void GemmBatched( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int batch_size, const int M, const int N, const int K, const float alpha, - const float* A, - const float* B, + const float** A, + const float** B, const float beta, - float* C, + float** C, CUDAContext* context, - Tensor* scratch, TensorProto::DataType math_type) { - const int a_stride = M * K; - const int b_stride = K * N; - const int c_stride = M * N; #if __CUDACC_VER_MAJOR__ < 8 // loop over matrices in the batch for (int i = 0; i < batch_size; ++i) { math::Gemm( - TransA, - TransB, + trans_A, + trans_B, M, N, K, alpha, - A + a_stride * i, - B + b_stride * i, + A[i], + B[i], beta, - C + c_stride * i, - context); + C[i], + context, + math_type); } #else // Note that cublas follows fortran order, so the order is different from // the cblas convention. - const int lda = (TransA == CblasNoTrans) ? K : M; - const int ldb = (TransB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? N : K; + const int ldc = N; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + thrust::device_vector A_device(A, A + batch_size); + thrust::device_vector B_device(B, B + batch_size); + thrust::device_vector C_device(C, C + batch_size); + CUBLAS_ENFORCE(cublasSgemmBatched( + context->cublas_handle(), + cu_trans_B, + cu_trans_A, + N, + M, + K, + &alpha, + B_device.data().get(), + ldb, + A_device.data().get(), + lda, + &beta, + C_device.data().get(), + ldc, + batch_size)); +#endif +} + +template <> +void GemmStridedBatched( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int A_stride, + const float* B, + const int B_stride, + const float beta, + float* C, + const int C_stride, + CUDAContext* context, + TensorProto::DataType math_type) { +#if __CUDACC_VER_MAJOR__ < 8 + // loop over matrices in the batch + for (int i = 0; i < batch_size; ++i) { + math::Gemm( + trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context, math_type); + A += A_stride; + B += B_stride; + C += C_stride; + } +#else + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? 
N : K; + const int ldc = N; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; CUBLAS_ENFORCE(cublasSgemmStridedBatched( context->cublas_handle(), - cuTransB, - cuTransA, + cu_trans_B, + cu_trans_A, N, M, K, &alpha, B, ldb, - b_stride, + B_stride, A, lda, - a_stride, + A_stride, &beta, C, - N, - c_stride, + ldc, + C_stride, batch_size)); #endif } -namespace { - -__global__ void FloatToHalfKernel(const int N, const float* X, half* Y) { - CUDA_1D_KERNEL_LOOP(i, N) { - Y[i] = __float2half(X[i]); +template <> +void GemmBatched( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, + const int batch_size, + const int M, + const int N, + const int K, + const float alpha, + const float16** A, + const float16** B, + const float beta, + float16** C, + CUDAContext* context, + TensorProto::DataType math_type) { +#if __CUDACC_VER_MAJOR__ < 9 + // loop over matrices in the batch + for (int i = 0; i < batch_size; ++i) { + math::Gemm( + trans_A, + trans_B, + M, + N, + K, + alpha, + A[i], + B[i], + beta, + C[i], + context, + math_type); } -} - -__global__ void HalfToFloatKernel(const int N, const half* X, float* Y) { - CUDA_1D_KERNEL_LOOP(i, N) { - Y[i] = __half2float(X[i]); +#else + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? N : K; + const int ldc = N; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + if (math_type == TensorProto_DataType_FLOAT) { +#if __CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ < 1 + // loop over matrices in the batch + for (int i = 0; i < batch_size; ++i) { + math::Gemm( + trans_A, + trans_B, + M, + N, + K, + alpha, + A[i], + B[i], + beta, + C[i], + context, + math_type); + } +#else + thrust::device_vector A_device(A, A + batch_size); + thrust::device_vector B_device(B, B + batch_size); + thrust::device_vector C_device(C, C + batch_size); + CUBLAS_ENFORCE(cublasGemmBatchedEx( + context->cublas_handle(), + cu_trans_B, + cu_trans_A, + N, + M, + K, + &alpha, + B_device.data().get(), + CUDA_R_16F, + ldb, + A_device.data().get(), + CUDA_R_16F, + lda, + &beta, + C_device.data().get(), + CUDA_R_16F, + ldc, + batch_size, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +#endif + } else if (math_type == TensorProto_DataType_FLOAT16) { + // Convert alpha, beta from float -> __half + const __half alpha_fp16 = convert::floatToHalf(alpha); + const __half beta_fp16 = convert::floatToHalf(beta); + std::vector A_array(batch_size); + std::vector B_array(batch_size); + std::vector<__half*> C_array(batch_size); + for (int i = 0; i < batch_size; ++i) { + A_array[i] = reinterpret_cast(A[i]); + B_array[i] = reinterpret_cast(B[i]); + C_array[i] = reinterpret_cast<__half*>(C[i]); + } + thrust::device_vector A_device( + A_array.cbegin(), A_array.cend()); + thrust::device_vector B_device( + B_array.cbegin(), B_array.cend()); + thrust::device_vector<__half*> C_device(C_array.cbegin(), C_array.cend()); + CUBLAS_ENFORCE(cublasHgemmBatched( + context->cublas_handle(), + cu_trans_B, + cu_trans_A, + N, + M, + K, + &alpha_fp16, + B_device.data().get(), + ldb, + A_device.data().get(), + lda, + &beta_fp16, + C_device.data().get(), + ldc, + batch_size)); + } else { + 
CAFFE_THROW("Unsupported math type"); } +#endif } -}; // namespace - template <> -void GemmBatched( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, +void GemmStridedBatched( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int batch_size, const int M, const int N, const int K, const float alpha, const float16* A, + const int A_stride, const float16* B, + const int B_stride, const float beta, float16* C, + const int C_stride, CUDAContext* context, - Tensor* scratch, TensorProto::DataType math_type) { - const int a_stride = M * K; - const int b_stride = K * N; - const int c_stride = M * N; #if __CUDACC_VER_MAJOR__ < 8 // loop over matrices in the batch for (int i = 0; i < batch_size; ++i) { math::Gemm( - TransA, - TransB, - M, - N, - K, - alpha, - A + a_stride * i, - B + b_stride * i, - beta, - C + c_stride * i, - context); + trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context, math_type); + A += A_stride; + B += B_stride; + C += C_stride; } #else - // 3 options: - // 1) scratch != null = cast to fp32, SgemmStridedBatched, cast result to fp16 - // 2) math_type == FLOAT, scratch == nullptr = looped SgemmEx - // 3) math_type == FLOAT16, scratch == nullptr = batched Hgemm - - if (scratch != nullptr) { - const int A_size = a_stride * batch_size; - const int B_size = b_stride * batch_size; - // cast, cublasSgemmStridedBatched, cast - size_t in_elems = A_size + B_size; - size_t out_elems = c_stride * batch_size; - - scratch->Resize(in_elems + out_elems); - float* scratch_ptr = scratch->mutable_data(); - - float* A_fp32 = scratch_ptr; - float* B_fp32 = scratch_ptr + A_size; - float* C_fp32 = scratch_ptr + A_size + B_size; - - // cast A, B into fp32 - HalfToFloatKernel<<< - CAFFE_GET_BLOCKS(A_size), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(A_size, (half*)A, A_fp32); - HalfToFloatKernel<<< - CAFFE_GET_BLOCKS(B_size), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(B_size, (half*)B, B_fp32); - - // run fp32 batched Gemm - GemmBatched( - TransA, - TransB, - batch_size, + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? N : K; + const int ldc = N; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + if (math_type == TensorProto_DataType_FLOAT) { +#if (__CUDACC_VER_MAJOR__ < 9) || \ + (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ < 1) + // loop over matrices in the batch + for (int i = 0; i < batch_size; ++i) { + math::Gemm( + trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context, math_type); + A += A_stride; + B += B_stride; + C += C_stride; + } +#else + CUBLAS_ENFORCE(cublasGemmStridedBatchedEx( + context->cublas_handle(), + cu_trans_B, + cu_trans_A, + N, M, + K, + &alpha, + B, + CUDA_R_16F, + ldb, + B_stride, + A, + CUDA_R_16F, + lda, + A_stride, + &beta, + C, + CUDA_R_16F, + ldc, + C_stride, + batch_size, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +#endif + } else if (math_type == TensorProto_DataType_FLOAT16) { + // Convert alpha, beta from float -> __half + const __half alpha_fp16 = convert::floatToHalf(alpha); + const __half beta_fp16 = convert::floatToHalf(beta); + CUBLAS_ENFORCE(cublasHgemmStridedBatched( + context->cublas_handle(), + cu_trans_B, + cu_trans_A, N, + M, K, - alpha, - A_fp32, - B_fp32, - beta, - C_fp32, - context); - - // cast result back to fp16 - FloatToHalfKernel<<< - CAFFE_GET_BLOCKS(batch_size * M * N), - CAFFE_CUDA_NUM_THREADS, - 0, - context->cuda_stream()>>>(batch_size * M * N, C_fp32, (half*)C); + &alpha_fp16, + (const __half*)B, + ldb, + B_stride, + (const __half*)A, + lda, + A_stride, + &beta_fp16, + (__half*)C, + ldc, + C_stride, + batch_size)); } else { - if (math_type == TensorProto_DataType_FLOAT) { - // loop over matrices in the batch - for (int i = 0; i < batch_size; ++i) { - math::Gemm( - TransA, - TransB, - M, - N, - K, - alpha, - A + a_stride * i, - B + b_stride * i, - beta, - C + c_stride * i, - context); - } - } else if (math_type == TensorProto_DataType_FLOAT16) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - const int lda = (TransA == CblasNoTrans) ? K : M; - const int ldb = (TransB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // convert alpha, beta from float -> __half - auto alpha_fp16 = convert::floatToHalf(alpha); - auto beta_fp16 = convert::floatToHalf(beta); - CUBLAS_ENFORCE(cublasHgemmStridedBatched( - context->cublas_handle(), - cuTransB, - cuTransA, - N, - M, - K, - &alpha_fp16, - (const __half*)B, - ldb, - b_stride, - (const __half*)A, - lda, - a_stride, - &beta_fp16, - (__half*)C, - N, - c_stride, - batch_size)); - } + CAFFE_THROW("Unsupported math type"); } #endif } @@ -1043,8 +1149,8 @@ void GemmBatched( // No change, but required. Defer to default CUDA engine template <> void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -1056,13 +1162,13 @@ void Gemm( CUDAContext* context, TensorProto::DataType math_type) { return Gemm( - TransA, TransB, M, N, K, alpha, A, B, beta, C, context, math_type); + trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context, math_type); } template <> void Gemm( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -1075,12 +1181,12 @@ void Gemm( TensorProto::DataType math_type) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. 
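// --------------------------------------------------------------------------
// Editorial aside, not part of this patch: why the cuBLAS calls in this file
// pass cu_trans_B/cu_trans_A and the dimensions as N, M, K with the B buffer
// first. cuBLAS is column-major; a row-major buffer reinterpreted as
// column-major holds the transpose, and (A*B)^T = B^T * A^T, so handing the
// column-major routine B first (ldb = N) and A second (lda = K) yields
// (A*B)^T in column-major order, i.e. the row-major A*B. A sketch with a toy
// column-major GEMM (colmajor_gemm is illustrative, not a cuBLAS API):
// --------------------------------------------------------------------------
#include <cassert>
#include <vector>

// Column-major GEMM: C(m x n) = X(m x k) * Y(k x n); ldx/ldy/ldc leading dims.
static void colmajor_gemm(int m, int n, int k, const float* X, int ldx,
                          const float* Y, int ldy, float* C, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += X[i + p * ldx] * Y[p + j * ldy];
      }
      C[i + j * ldc] = acc;
    }
  }
}

int main() {
  // Row-major A (M x K) and B (K x N); we want row-major C = A * B (M x N).
  const int M = 2, N = 3, K = 2;
  const std::vector<float> A = {1, 2, 3, 4};        // [[1,2],[3,4]]
  const std::vector<float> B = {5, 6, 7, 8, 9, 10}; // [[5,6,7],[8,9,10]]
  std::vector<float> C(M * N, 0.0f);
  // Swapped call, matching the cuBLAS pattern: dimensions (N, M, K), the B
  // buffer first with ldb = N, then the A buffer with lda = K. The routine
  // sees B^T and A^T, computes (A*B)^T, and that column-major result laid
  // out in memory is exactly the row-major A*B.
  colmajor_gemm(N, M, K, B.data(), N, A.data(), K, C.data(), N);
  // Row-major expectation: [[21, 24, 27], [47, 54, 61]].
  assert(C[0] == 21 && C[1] == 24 && C[2] == 27);
  assert(C[3] == 47 && C[4] == 54 && C[5] == 61);
  return 0;
}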
- int lda = (TransA == CblasNoTrans) ? K : M; - int ldb = (TransB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int lda = (trans_A == CblasNoTrans) ? K : M; + const int ldb = (trans_B == CblasNoTrans) ? N : K; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; // enable TensorCore for this call on this handle if (TensorCoreAvailable()) { @@ -1090,8 +1196,8 @@ void Gemm( CUBLAS_CHECK(cublasGemmEx( context->cublas_handle(), - cuTransB, - cuTransA, + cu_trans_B, + cu_trans_A, N, M, K, @@ -1117,68 +1223,76 @@ void Gemm( } template <> -void GemmBatched( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, +void GemmStridedBatched( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int batch_size, const int M, const int N, const int K, const float alpha, const float* A, + const int A_stride, const float* B, + const int B_stride, const float beta, float* C, + const int C_stride, CUDAContext* context, - Tensor* scratch, TensorProto::DataType math_type) { - return GemmBatched( - TransA, - TransB, + return GemmStridedBatched( + trans_A, + trans_B, batch_size, M, N, K, alpha, A, + A_stride, B, + B_stride, beta, C, + C_stride, context, - scratch, math_type); } template <> -void GemmBatched( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, +void GemmStridedBatched( + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int batch_size, const int M, const int N, const int K, const float alpha, const float16* A, + const int A_stride, const float16* B, + const int B_stride, const float beta, float16* C, + const int C_stride, CUDAContext* context, - Tensor* scratch, TensorProto::DataType math_type) { - return GemmBatched( - TransA, - TransB, + return GemmStridedBatched( + trans_A, + trans_B, batch_size, M, N, K, alpha, A, + A_stride, B, + B_stride, beta, C, + C_stride, context, - scratch, math_type); } @@ -1186,8 +1300,8 @@ void GemmBatched( template <> void GemmEx( - const CBLAS_TRANSPOSE TransA, - const CBLAS_TRANSPOSE TransB, + const CBLAS_TRANSPOSE trans_A, + const CBLAS_TRANSPOSE trans_B, const int M, const int N, const int K, @@ -1202,14 +1316,14 @@ void GemmEx( CUDAContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const cublasOperation_t cu_trans_B = + (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; CUBLAS_ENFORCE(cublasSgemm( context->cublas_handle(), - cuTransB, - cuTransA, + cu_trans_B, + cu_trans_A, N, M, K, @@ -1225,7 +1339,7 @@ void GemmEx( template <> void Gemv( - const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE trans_A, const int M, const int N, const float alpha, @@ -1235,11 +1349,11 @@ void Gemv( float* y, CUDAContext* context, TensorProto::DataType math_type) { - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? 
CUBLAS_OP_T : CUBLAS_OP_N; CUBLAS_ENFORCE(cublasSgemv( context->cublas_handle(), - cuTransA, + cu_trans_A, N, M, &alpha, @@ -1295,7 +1409,7 @@ CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(float16); template <> void Gemv( - const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE trans_A, const int M, const int N, const float alpha, @@ -1305,19 +1419,19 @@ void Gemv( float16* y, CUDAContext* context, TensorProto::DataType math_type) { - cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t cu_trans_A = + (trans_A == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; // sort out what we need to call cublasSgemmEx / cublasHgemm - int m = (cuTransA == CUBLAS_OP_N) ? N : M; - int k = (cuTransA == CUBLAS_OP_N) ? M : N; - int LDA = (cuTransA == CUBLAS_OP_N) ? m : k; - int LDC = m; + const int m = (cu_trans_A == CUBLAS_OP_N) ? N : M; + const int k = (cu_trans_A == CUBLAS_OP_N) ? M : N; + const int lda = (cu_trans_A == CUBLAS_OP_N) ? m : k; + const int ldc = m; if (math_type == TensorProto_DataType_FLOAT) { CUBLAS_CHECK(cublasSgemmEx( context->cublas_handle(), - cuTransA, + cu_trans_A, CUBLAS_OP_N, m, 1, @@ -1325,33 +1439,32 @@ void Gemv( &alpha, A, CUDA_R_16F, - LDA, + lda, x, CUDA_R_16F, k, &beta, y, CUDA_R_16F, - LDC)); + ldc)); } else if (math_type == TensorProto_DataType_FLOAT16) { - auto alpha_fp16 = convert::floatToHalf(alpha); - auto beta_fp16 = convert::floatToHalf(beta); - + const __half alpha_fp16 = convert::floatToHalf(alpha); + const __half beta_fp16 = convert::floatToHalf(beta); CUBLAS_CHECK(cublasHgemm( context->cublas_handle(), - cuTransA, + cu_trans_A, CUBLAS_OP_N, m, 1, k, &alpha_fp16, (const __half*)A, - LDA, + lda, (const __half*)x, k, &beta_fp16, (__half*)y, - LDC)); + ldc)); } else { // fail CAFFE_THROW("Unsupported math type"); diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc index 8de888fc76a10..330f34181918c 100644 --- a/caffe2/utils/math_gpu_test.cc +++ b/caffe2/utils/math_gpu_test.cc @@ -1,3 +1,4 @@ +#include #include #include #include @@ -272,7 +273,19 @@ class GemmBatchedGPUTest } void RunGemmBatched(const float alpha, const float beta) { - math::GemmBatched( + const float* X_data = X_->template data(); + const float* W_data = W_->template data(); + float* Y_data = Y_->template mutable_data(); + const int X_stride = 5 * 10; + const int W_stride = 6 * 10; + const int Y_stride = 5 * 6; + std::array X_array = { + X_data, X_data + X_stride, X_data + 2 * X_stride}; + std::array W_array = { + W_data, W_data + W_stride, W_data + 2 * W_stride}; + std::array Y_array = { + Y_data, Y_data + Y_stride, Y_data + 2 * Y_stride}; + math::GemmBatched( trans_X_ ? CblasTrans : CblasNoTrans, trans_W_ ? CblasTrans : CblasNoTrans, 3, @@ -280,10 +293,35 @@ class GemmBatchedGPUTest 6, 10, alpha, - X_->template data(), - W_->template data(), + X_array.data(), + W_array.data(), beta, - Y_->template mutable_data(), + Y_array.data(), + cuda_context_.get()); + } + + void RunGemmStridedBatched(const float alpha, const float beta) { + const float* X_data = X_->template data(); + const float* W_data = W_->template data(); + float* Y_data = Y_->template mutable_data(); + const int X_stride = 5 * 10; + const int W_stride = 6 * 10; + const int Y_stride = 5 * 6; + math::GemmStridedBatched( + trans_X_ ? CblasTrans : CblasNoTrans, + trans_W_ ? 
CblasTrans : CblasNoTrans, + 3, + 5, + 6, + 10, + alpha, + X_data, + X_stride, + W_data, + W_stride, + beta, + Y_data, + Y_stride, cuda_context_.get()); } @@ -316,6 +354,18 @@ TEST_P(GemmBatchedGPUTest, GemmBatchedGPUFloatTest) { VerifyOutput(20.0f); } +TEST_P(GemmBatchedGPUTest, GemmStridedBatchedGPUFloatTest) { + if (!HasCudaGPU()) { + return; + } + RunGemmStridedBatched(1.0f, 0.0f); + VerifyOutput(10.0f); + RunGemmStridedBatched(1.0f, 0.5f); + VerifyOutput(15.0f); + RunGemmStridedBatched(0.5f, 1.0f); + VerifyOutput(20.0f); +} + INSTANTIATE_TEST_CASE_P( GemmBatchedGPUTrans, GemmBatchedGPUTest, @@ -402,12 +452,7 @@ TEST_F(ReduceTensorGPUTest, ReduceMinGPUTest) { num_dims, dims, num_axes, axes, X, Y, context); }; // Test for 1D tensor. - RunRedcueTensorTest( - reduce_min, - {3}, - {0}, - {1.0f, 2.0f, 3.0f}, - {1.0f}); + RunRedcueTensorTest(reduce_min, {3}, {0}, {1.0f, 2.0f, 3.0f}, {1.0f}); // Test for 2D Tensor. RunRedcueTensorTest( @@ -423,11 +468,7 @@ TEST_F(ReduceTensorGPUTest, ReduceMinGPUTest) { {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f}); RunRedcueTensorTest( - reduce_min, - {2, 3}, - {0, 1}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, - {1.0f}); + reduce_min, {2, 3}, {0, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f}); // Test for 3D tensor. RunRedcueTensorTest( @@ -465,12 +506,7 @@ TEST_F(ReduceTensorGPUTest, ReduceMaxGPUTest) { num_dims, dims, num_axes, axes, X, Y, context); }; // Test for 1D tensor. - RunRedcueTensorTest( - reduce_max, - {3}, - {0}, - {1.0f, 2.0f, 3.0f}, - {3.0f}); + RunRedcueTensorTest(reduce_max, {3}, {0}, {1.0f, 2.0f, 3.0f}, {3.0f}); // Test for 2D Tensor. RunRedcueTensorTest( @@ -486,11 +522,7 @@ TEST_F(ReduceTensorGPUTest, ReduceMaxGPUTest) { {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {4.0f, 5.0f, 6.0f}); RunRedcueTensorTest( - reduce_max, - {2, 3}, - {0, 1}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, - {6.0f}); + reduce_max, {2, 3}, {0, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {6.0f}); // Test for 3D tensor. RunRedcueTensorTest( diff --git a/caffe2/utils/math_test.cc b/caffe2/utils/math_test.cc index 8ade7d6f5cc45..6d3444553d51f 100644 --- a/caffe2/utils/math_test.cc +++ b/caffe2/utils/math_test.cc @@ -1,3 +1,4 @@ +#include #include #include @@ -182,6 +183,18 @@ class GemmBatchedTest } void RunGemmBatched(const float alpha, const float beta) { + const float* X_data = X_.template data(); + const float* W_data = W_.template data(); + float* Y_data = Y_.template mutable_data(); + const int X_stride = 5 * 10; + const int W_stride = 6 * 10; + const int Y_stride = 5 * 6; + std::array X_array = { + X_data, X_data + X_stride, X_data + 2 * X_stride}; + std::array W_array = { + W_data, W_data + W_stride, W_data + 2 * W_stride}; + std::array Y_array = { + Y_data, Y_data + Y_stride, Y_data + 2 * Y_stride}; math::GemmBatched( trans_X_ ? CblasTrans : CblasNoTrans, trans_W_ ? CblasTrans : CblasNoTrans, @@ -190,10 +203,35 @@ class GemmBatchedTest 6, 10, alpha, - X_.template data(), - W_.template data(), + X_array.data(), + W_array.data(), beta, - Y_.template mutable_data(), + Y_array.data(), + cpu_context_.get()); + } + + void RunGemmStridedBatched(const float alpha, const float beta) { + const float* X_data = X_.template data(); + const float* W_data = W_.template data(); + float* Y_data = Y_.template mutable_data(); + const int X_stride = 5 * 10; + const int W_stride = 6 * 10; + const int Y_stride = 5 * 6; + math::GemmStridedBatched( + trans_X_ ? CblasTrans : CblasNoTrans, + trans_W_ ? 
CblasTrans : CblasNoTrans, + 3, + 5, + 6, + 10, + alpha, + X_data, + X_stride, + W_data, + W_stride, + beta, + Y_data, + Y_stride, cpu_context_.get()); } @@ -221,6 +259,15 @@ TEST_P(GemmBatchedTest, GemmBatchedFloatTest) { VerifyOutput(20.0f); } +TEST_P(GemmBatchedTest, GemmStridedBatchedFloatTest) { + RunGemmStridedBatched(1.0f, 0.0f); + VerifyOutput(10.0f); + RunGemmStridedBatched(1.0f, 0.5f); + VerifyOutput(15.0f); + RunGemmStridedBatched(0.5f, 1.0f); + VerifyOutput(20.0f); +} + INSTANTIATE_TEST_CASE_P( GemmBatchedTrans, GemmBatchedTest, @@ -432,12 +479,7 @@ TEST_F(ReduceTensorTest, ReduceMinTest) { num_dims, dims, num_axes, axes, X, Y, context); }; // Test for 1D tensor. - RunRedcueTensorTest( - reduce_min, - {3}, - {0}, - {1.0f, 2.0f, 3.0f}, - {1.0f}); + RunRedcueTensorTest(reduce_min, {3}, {0}, {1.0f, 2.0f, 3.0f}, {1.0f}); // Test for 2D Tensor. RunRedcueTensorTest( @@ -453,11 +495,7 @@ TEST_F(ReduceTensorTest, ReduceMinTest) { {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f}); RunRedcueTensorTest( - reduce_min, - {2, 3}, - {0, 1}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, - {1.0f}); + reduce_min, {2, 3}, {0, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f}); // Test for 3D tensor. RunRedcueTensorTest( @@ -492,12 +530,7 @@ TEST_F(ReduceTensorTest, ReduceMaxTest) { num_dims, dims, num_axes, axes, X, Y, context); }; // Test for 1D tensor. - RunRedcueTensorTest( - reduce_max, - {3}, - {0}, - {1.0f, 2.0f, 3.0f}, - {3.0f}); + RunRedcueTensorTest(reduce_max, {3}, {0}, {1.0f, 2.0f, 3.0f}, {3.0f}); // Test for 2D Tensor. RunRedcueTensorTest( @@ -513,11 +546,7 @@ TEST_F(ReduceTensorTest, ReduceMaxTest) { {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {4.0f, 5.0f, 6.0f}); RunRedcueTensorTest( - reduce_max, - {2, 3}, - {0, 1}, - {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, - {6.0f}); + reduce_max, {2, 3}, {0, 1}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {6.0f}); // Test for 3D tensor. RunRedcueTensorTest( @@ -543,11 +572,7 @@ TEST_F(ReduceTensorTest, ReduceMaxTest) { TEST_F(ReduceTensorTest, ReduceSumTest) { // Test for 1D tensor. RunRedcueTensorTest( - math::ReduceSum, - {3}, - {0}, - {1.0f, 2.0f, 3.0f}, - {6.0f}); + math::ReduceSum, {3}, {0}, {1.0f, 2.0f, 3.0f}, {6.0f}); // Test for 2D Tensor. 
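The GemmStridedBatched tests added above exercise the new API only through its call site; for reference, a strided-batched GEMM is mathematically just a loop of ordinary GEMMs over contiguous batch slices, with each batch item i reading A + i*A_stride and B + i*B_stride and writing C + i*C_stride. A minimal C++ sketch of that semantics (row-major layout, no-transpose case only; the function and parameter names here are illustrative, not the Caffe2 API):

#include <cstddef>

// Computes C_i = alpha * A_i * B_i + beta * C_i for each of `batch` slices,
// where A_i is M x K, B_i is K x N, and C_i is M x N, all row-major.
void GemmStridedBatchedRef(
    std::size_t batch, std::size_t M, std::size_t N, std::size_t K,
    float alpha, const float* A, std::size_t A_stride,
    const float* B, std::size_t B_stride,
    float beta, float* C, std::size_t C_stride) {
  for (std::size_t b = 0; b < batch; ++b) {
    // Each batch item is addressed by a fixed element stride from the base pointer.
    const float* Ab = A + b * A_stride;
    const float* Bb = B + b * B_stride;
    float* Cb = C + b * C_stride;
    for (std::size_t m = 0; m < M; ++m) {
      for (std::size_t n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (std::size_t k = 0; k < K; ++k) {
          acc += Ab[m * K + k] * Bb[k * N + n];
        }
        Cb[m * N + n] = alpha * acc + beta * Cb[m * N + n];
      }
    }
  }
}

With the test's batch = 3, M = 5, N = 6, K = 10 and strides of 5*10, 6*10, and 5*6, each batch item is an independent 5x10 by 10x6 product over back-to-back slices of the same buffers, which is exactly the layout the pointer-array GemmBatched path builds by hand in X_array, W_array, and Y_array.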
RunRedcueTensorTest( diff --git a/caffe2/utils/math_utils.cc b/caffe2/utils/math_utils.cc index 1334111992510..0b05099d2f077 100644 --- a/caffe2/utils/math_utils.cc +++ b/caffe2/utils/math_utils.cc @@ -1,6 +1,8 @@ #include "caffe2/utils/math_utils.h" #include +#include +#include #include #include "caffe2/core/logging.h" @@ -68,7 +70,8 @@ bool IsRowwiseBroadcastBinaryOp( const int ndim, const int* A_dims, const int* B_dims, - int* pivot, + int* rows, + int* cols, bool* broadcast_1st) { if (ndim == 0) { return false; @@ -82,16 +85,32 @@ bool IsRowwiseBroadcastBinaryOp( if (A_pivot == B_pivot) { return false; } - *pivot = std::max(A_pivot, B_pivot); - *broadcast_1st = A_pivot > B_pivot; - return std::equal(A_dims + *pivot, A_dims + ndim, B_dims + *pivot); + const int pivot = std::max(A_pivot, B_pivot); + if (A_pivot > B_pivot) { + *rows = std::accumulate( + B_dims + B_pivot, B_dims + pivot, 1, std::multiplies()); + *broadcast_1st = true; + } else { + *rows = std::accumulate( + A_dims + A_pivot, A_dims + pivot, 1, std::multiplies()); + *broadcast_1st = false; + } + *cols = 1; + for (int i = pivot; i < ndim; ++i) { + if (A_dims[i] != B_dims[i]) { + return false; + } + *cols *= A_dims[i]; + } + return true; } bool IsColwiseBroadcastBinaryOp( const int ndim, const int* A_dims, const int* B_dims, - int* pivot, + int* rows, + int* cols, bool* broadcast_1st) { if (ndim == 0) { return false; @@ -105,9 +124,81 @@ bool IsColwiseBroadcastBinaryOp( if (A_pivot == B_pivot) { return false; } - *pivot = std::min(A_pivot, B_pivot) + 1; - *broadcast_1st = A_pivot < B_pivot; - return std::equal(A_dims, A_dims + *pivot, B_dims); + ++A_pivot; + ++B_pivot; + const int pivot = std::min(A_pivot, B_pivot); + if (A_pivot < B_pivot) { + *cols = std::accumulate( + B_dims + pivot, B_dims + B_pivot, 1, std::multiplies()); + *broadcast_1st = true; + } else { + *cols = std::accumulate( + A_dims + pivot, A_dims + A_pivot, 1, std::multiplies()); + *broadcast_1st = false; + } + *rows = 1; + for (int i = 0; i < pivot; ++i) { + if (A_dims[i] != B_dims[i]) { + return false; + } + *rows *= A_dims[i]; + } + return true; +} + +bool IsMiddleBroadcastBinaryOp( + const int ndim, + const int* A_dims, + const int* B_dims, + int* pre, + int* mid, + int* nxt, + bool* broadcast_1st) { + if (ndim == 0) { + return false; + } + int A_pre = 0; + for (; A_pre < ndim && A_dims[A_pre] == 1; ++A_pre) + ; + int B_pre = 0; + for (; B_pre < ndim && B_dims[B_pre] == 1; ++B_pre) + ; + int A_nxt = ndim - 1; + for (; A_nxt >= 0 && A_dims[A_nxt] == 1; --A_nxt) + ; + int B_nxt = ndim - 1; + for (; B_nxt >= 0 && B_dims[B_nxt] == 1; --B_nxt) + ; + ++A_nxt; + ++B_nxt; + if (A_pre == B_pre || A_nxt == B_nxt) { + return false; + } + if (A_pre > B_pre && A_nxt < B_nxt) { + *pre = std::accumulate( + B_dims + B_pre, B_dims + A_pre, 1, std::multiplies()); + *nxt = std::accumulate( + B_dims + A_nxt, B_dims + B_nxt, 1, std::multiplies()); + *broadcast_1st = true; + } else if (A_pre < B_pre && A_nxt > B_nxt) { + *pre = std::accumulate( + A_dims + A_pre, A_dims + B_pre, 1, std::multiplies()); + *nxt = std::accumulate( + A_dims + B_nxt, A_dims + A_nxt, 1, std::multiplies()); + *broadcast_1st = false; + } else { + return false; + } + const int l = std::max(A_pre, B_pre); + const int r = std::min(A_nxt, B_nxt); + *mid = 1; + for (int i = l; i < r; ++i) { + if (A_dims[i] != B_dims[i]) { + return false; + } + *mid *= A_dims[i]; + } + return true; } void ComputeTransposeAxesForReduceOp( diff --git a/caffe2/utils/math_utils.h b/caffe2/utils/math_utils.h index 
4b54ab1fe9c19..5f0770460ca8b 100644 --- a/caffe2/utils/math_utils.h +++ b/caffe2/utils/math_utils.h @@ -63,14 +63,25 @@ bool IsRowwiseBroadcastBinaryOp( const int ndim, const int* A_dims, const int* B_dims, - int* pivot, + int* rows, + int* cols, bool* broadcast_1st); bool IsColwiseBroadcastBinaryOp( const int ndim, const int* A_dims, const int* B_dims, - int* pivot, + int* rows, + int* cols, + bool* broadcast_1st); + +bool IsMiddleBroadcastBinaryOp( + const int ndim, + const int* A_dims, + const int* B_dims, + int* pre, + int* mid, + int* nxt, bool* broadcast_1st); void ComputeTransposeAxesForReduceOp( diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index db481d0f06532..d8be5b7536590 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -176,6 +176,15 @@ Probability distributions - torch.distributions :undoc-members: :show-inheritance: +:hidden:`LowRankMultivariateNormal` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torch.distributions.lowrank_multivariate_normal +.. autoclass:: LowRankMultivariateNormal + :members: + :undoc-members: + :show-inheritance: + :hidden:`Multinomial` ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/modules/observers/net_observer_reporter_print.cc b/modules/observers/net_observer_reporter_print.cc index 2355fedc9a1a7..b3341bef6ae08 100644 --- a/modules/observers/net_observer_reporter_print.cc +++ b/modules/observers/net_observer_reporter_print.cc @@ -14,63 +14,60 @@ void NetObserverReporterPrint::report( NetBase* net, std::map& info) { // Not allowed to use json library - std::map< - std::string, - std::map>> - caffe2_perf; + std::vector> caffe2_perf; for (auto& p : info) { if ((p.first == "NET_DELAY") && (info.size() == 1)) { // for Net_delay perf - caffe2_perf["NET"] = { - {"latency", - {{"value", caffe2::to_string(p.second.latency * 1000)}, - {"unit", "us"}}}, - {"flops", {{"value", "-1"}, {"unit", "flops"}}}}; + caffe2_perf.push_back( + {{"type", "NET"}, + {"value", caffe2::to_string(p.second.latency * 1000)}, + {"unit", "us"}, + {"metric", "latency"}}); } else if (p.first != "NET_DELAY") { // for operator perf std::string shape_str = get_tensor_shapes(p.second); std::string args_str = get_op_args(p.second); - caffe2_perf[p.first] = { - {"latency", - {{"value", caffe2::to_string(p.second.latency * 1000)}, - {"unit", "us"}}}, - {"flops", - {{ - "value", - caffe2::to_string(p.second.flops), - }, - {"unit", "flops"}}}, - {"tensor_shapes", {{"info_string", shape_str}, {"unit", ""}}}, - {"op_args", {{"info_string", args_str}, {"unit", ""}}}}; + caffe2_perf.push_back( + {{"type", p.first}, + {"value", caffe2::to_string(p.second.latency * 1000)}, + {"unit", "us"}, + {"metric", "latency"}}); + if (p.second.flops > 0) { + caffe2_perf.push_back({{"type", p.first}, + {"value", caffe2::to_string(p.second.flops)}, + {"unit", "flop"}, + {"metric", "flops"}}); + } + if (shape_str != "") { + caffe2_perf.push_back({{"type", p.first}, + {"info_string", shape_str}, + {"unit", ""}, + {"metric", "tensor_shapes"}}); + } + if (args_str != "") { + caffe2_perf.push_back({{"type", p.first}, + {"info_string", args_str}, + {"unit", ""}, + {"metric", "op_args"}}); + } } } for (auto it = caffe2_perf.begin(); it != caffe2_perf.end(); it++) { std::stringstream buffer; + auto entry = *it; buffer << IDENTIFIER << "{"; - buffer << "\"" << it->first << "\"" - << ": {"; - for (auto jt = it->second.begin(); jt != it->second.end(); jt++) { - buffer << "\"" << jt->first << "\"" - << ": {"; - for (auto kt = jt->second.begin(); kt != 
jt->second.end(); kt++) { - buffer << "\"" << kt->first << "\"" - << ": " - << "\"" << kt->second << "\""; - auto lt = kt; - if ((++lt) != jt->second.end()) { - buffer << ", "; - } - } - buffer << "}"; - auto lt = jt; - if ((++lt) != it->second.end()) { - buffer << ", "; - } + buffer << "\"type\": \"" << entry["type"] << "\"," + << "\"unit\": \"" << entry["unit"] << "\"," + << "\"metric\": \"" << entry["metric"] << "\","; + if (entry.find("value") != entry.end()) { + buffer << "\"value\": \"" << entry["value"] << "\""; + } else if (entry.find("info_string") != entry.end()) { + buffer << "\"info_string\": \"" << entry["info_string"] << "\""; } - buffer << "}}"; + buffer << "}"; LOG(INFO) << buffer.str(); } } @@ -90,7 +87,7 @@ static std::string get_tensor_shapes(PerformanceInformation p) { shape_stream << "]"; shape_str = shape_stream.str(); } else { - shape_str = "[]"; + shape_str = ""; } return shape_str; } @@ -118,7 +115,7 @@ static std::string get_op_args(PerformanceInformation p) { args << "]"; args_str = args.str(); } else { - args_str = "[]"; + args_str = ""; } return args_str; } diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp index 66b11c126df14..ea8b37d44db54 100644 --- a/test/cpp/api/module.cpp +++ b/test/cpp/api/module.cpp @@ -184,7 +184,8 @@ TEST_CASE("module/clone") { SECTION( "a module that overrides clone() does not throw when clone() is called ") { struct Cloneable : Module { - std::shared_ptr clone() const override { + std::shared_ptr clone( + at::optional device = at::nullopt) const override { return nullptr; } }; @@ -299,6 +300,56 @@ TEST_CASE("module/clone") { } } +TEST_CASE("module/clone-to-device", "[cuda]") { + struct TestModule : public Cloneable { + TestModule() { + reset(); + } + void reset() override { + l1 = register_module("l1", Linear(10, 3)); + l2 = register_module("l2", Linear(3, 5)); + l3 = register_module("l3", Linear(5, 100)); + buffer = register_buffer("buf", torch::ones({2, 2})); + } + + Linear l1{nullptr}, l2{nullptr}, l3{nullptr}; + torch::Tensor buffer; + }; + + SECTION("Cloning preserves the device of parameters/buffers") { + TestModule m; + torch::Device device(torch::kCUDA, 0); + + m.to(device); + + auto clone = m.clone(); + for (const auto& parameter : clone->parameters()) { + REQUIRE(parameter->device().type() == device.type()); + REQUIRE(parameter->device().index() == device.index()); + } + for (const auto& buffer : clone->buffers()) { + REQUIRE(buffer->device().type() == device.type()); + REQUIRE(buffer->device().index() == device.index()); + } + } + + SECTION( + "Cloning to a particular device places all parameters/buffers there") { + TestModule m; + torch::Device device(torch::kCUDA, 1); + // everything is on CPU here + auto clone = m.clone(device); + for (const auto& parameter : clone->parameters()) { + REQUIRE(parameter->device().type() == device.type()); + REQUIRE(parameter->device().index() == device.index()); + } + for (const auto& buffer : clone->buffers()) { + REQUIRE(buffer->device().type() == device.type()); + REQUIRE(buffer->device().index() == device.index()); + } + } +} + TEST_CASE("module/parameters") { torch::manual_seed(0); struct TestModule : Module { diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 2cacc75579097..0b522b62686dd 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -142,8 +142,13 @@ TEST_CASE("modules") { SECTION("embedding") { SECTION("basic") { - int dict_size = 10; + const int64_t dict_size = 10; Embedding model(dict_size, 2); + 
REQUIRE(model->parameters().contains("weight")); + REQUIRE(model->weight.ndimension() == 2); + REQUIRE(model->weight.size(0) == dict_size); + REQUIRE(model->weight.size(1) == 2); + // Cannot get gradients to change indices (input) - only for embedding // params auto x = torch::full({10}, dict_size - 1, torch::kInt64); @@ -156,7 +161,7 @@ TEST_CASE("modules") { REQUIRE(y.size(0) == 10); REQUIRE(y.size(1) == 2); - REQUIRE(model->parameters()["table"].grad().numel() == 2 * dict_size); + REQUIRE(model->parameters()["weight"].grad().numel() == 2 * dict_size); } SECTION("list") { diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp index 8aa608c4bb478..4d855cb10c9f8 100644 --- a/test/cpp/api/sequential.cpp +++ b/test/cpp/api/sequential.cpp @@ -295,6 +295,7 @@ TEST_CASE("sequential") { REQUIRE(params1.size() == params2.size()); for (auto& param : params1) { REQUIRE(!pointer_equal(param.value, params2[param.key])); + REQUIRE(param->device() == params2[param.key].device()); REQUIRE(param->allclose(params2[param.key])); param->data().add_(2); } @@ -303,3 +304,16 @@ TEST_CASE("sequential") { } } } + +TEST_CASE("sequential/clone-to-device", "[cuda]") { + Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3)); + torch::Device device(torch::kCUDA, 0); + Sequential clone = + std::static_pointer_cast(sequential->clone(device)); + for (const auto& p : clone->parameters()) { + REQUIRE(p->device() == device); + } + for (const auto& b : clone->buffers()) { + REQUIRE(b->device() == device); + } +} diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp index 07fa8c3fbfac4..ae38cdf2d0ccb 100644 --- a/test/cpp/api/tensor.cpp +++ b/test/cpp/api/tensor.cpp @@ -5,6 +5,8 @@ #include #include +#include +#include template bool exactly_equal(at::Tensor left, T right) { @@ -179,3 +181,14 @@ TEST_CASE("Tensor/UsesOptionsThatAreSupplied") { REQUIRE(exactly_equal(tensor[1], 2)); REQUIRE(exactly_equal(tensor[2], 3)); } + +TEST_CASE("FromBlob") { + std::vector v = {1, 2, 3}; + auto tensor = torch::from_blob( + reinterpret_cast(v.data()), v.size(), torch::kInt32); + REQUIRE(tensor.is_variable()); + REQUIRE(tensor.numel() == 3); + REQUIRE(tensor[0].toCInt() == 1); + REQUIRE(tensor[1].toCInt() == 2); + REQUIRE(tensor[2].toCInt() == 3); +} diff --git a/test/expect/TestJit.test_alexnet.expect b/test/expect/TestJit.test_alexnet.expect index 3c71802b9ff26..9a1105c8b8b17 100644 --- a/test/expect/TestJit.test_alexnet.expect +++ b/test/expect/TestJit.test_alexnet.expect @@ -28,23 +28,25 @@ graph(%0 : Double(1, 3, 224, 224) %29 : Double(1, 256, 13, 13) = aten::_convolution[stride=[1, 1], padding=[1, 1], dilation=[1, 1], transposed=0, output_padding=[0, 0], groups=1, benchmark=0, deterministic=0, cudnn_enabled=1](%28, %9, %10), scope: AlexNet/Sequential[features]/Conv2d[10] %30 : Double(1, 256, 13, 13) = aten::threshold[threshold={0}, value={0}](%29), scope: AlexNet/Sequential[features]/ReLU[11] %31 : Double(1, 256, 6, 6), %32 : Long(1, 256, 6, 6) = aten::max_pool2d_with_indices[kernel_size=[3, 3], stride=[2, 2], padding=[0, 0], dilation=[1, 1], ceil_mode=0](%30), scope: AlexNet/Sequential[features]/MaxPool2d[12] - %33 : Long() = prim::Constant[value={0}](), scope: AlexNet - %34 : Long() = aten::size(%31, %33), scope: AlexNet - %35 : Long() = prim::Constant[value={9216}](), scope: AlexNet - %36 : Dynamic = aten::stack[dim=0](%34, %35), scope: AlexNet - %37 : Double(1, 9216) = aten::view(%31, %36), scope: AlexNet - %38 : Double(1, 9216) = ^Dropout(0.5, True, False)(%37), scope: 
AlexNet/Sequential[classifier]/Dropout[0] - %39 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1] - %40 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%12), scope: AlexNet/Sequential[classifier]/Linear[1] - %41 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%40, %38, %39), scope: AlexNet/Sequential[classifier]/Linear[1] - %42 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%41), scope: AlexNet/Sequential[classifier]/ReLU[2] - %43 : Double(1, 4096) = ^Dropout(0.5, True, False)(%42), scope: AlexNet/Sequential[classifier]/Dropout[3] - %44 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4] - %45 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%14), scope: AlexNet/Sequential[classifier]/Linear[4] - %46 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%45, %43, %44), scope: AlexNet/Sequential[classifier]/Linear[4] - %47 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%46), scope: AlexNet/Sequential[classifier]/ReLU[5] - %48 : Double(4096!, 1000!) = aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6] - %49 : Double(1, 1000) = aten::expand[size=[1, 1000], implicit=1](%16), scope: AlexNet/Sequential[classifier]/Linear[6] - %50 : Double(1, 1000) = aten::addmm[beta={1}, alpha={1}](%49, %47, %48), scope: AlexNet/Sequential[classifier]/Linear[6] - return (%50); + %33 : int = prim::Constant[value=0](), scope: AlexNet + %34 : int = aten::size(%31, %33), scope: AlexNet + %35 : Long() = prim::NumToTensor(%34), scope: AlexNet + %36 : int = prim::TensorToNum(%35), scope: AlexNet + %37 : int = prim::Constant[value=9216](), scope: AlexNet + %38 : int[] = prim::ListConstruct(%36, %37), scope: AlexNet + %39 : Double(1, 9216) = aten::view(%31, %38), scope: AlexNet + %40 : Double(1, 9216) = ^Dropout(0.5, True, False)(%39), scope: AlexNet/Sequential[classifier]/Dropout[0] + %41 : Double(9216!, 4096!) = aten::t(%11), scope: AlexNet/Sequential[classifier]/Linear[1] + %42 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%12), scope: AlexNet/Sequential[classifier]/Linear[1] + %43 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%42, %40, %41), scope: AlexNet/Sequential[classifier]/Linear[1] + %44 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%43), scope: AlexNet/Sequential[classifier]/ReLU[2] + %45 : Double(1, 4096) = ^Dropout(0.5, True, False)(%44), scope: AlexNet/Sequential[classifier]/Dropout[3] + %46 : Double(4096!, 4096!) = aten::t(%13), scope: AlexNet/Sequential[classifier]/Linear[4] + %47 : Double(1, 4096) = aten::expand[size=[1, 4096], implicit=1](%14), scope: AlexNet/Sequential[classifier]/Linear[4] + %48 : Double(1, 4096) = aten::addmm[beta={1}, alpha={1}](%47, %45, %46), scope: AlexNet/Sequential[classifier]/Linear[4] + %49 : Double(1, 4096) = aten::threshold[threshold={0}, value={0}](%48), scope: AlexNet/Sequential[classifier]/ReLU[5] + %50 : Double(4096!, 1000!) 
= aten::t(%15), scope: AlexNet/Sequential[classifier]/Linear[6] + %51 : Double(1, 1000) = aten::expand[size=[1, 1000], implicit=1](%16), scope: AlexNet/Sequential[classifier]/Linear[6] + %52 : Double(1, 1000) = aten::addmm[beta={1}, alpha={1}](%51, %49, %50), scope: AlexNet/Sequential[classifier]/Linear[6] + return (%52); } diff --git a/test/expect/TestJit.test_decompose_addmm.expect b/test/expect/TestJit.test_decompose_addmm.expect index a409a810124b2..925362f4f6a4a 100644 --- a/test/expect/TestJit.test_decompose_addmm.expect +++ b/test/expect/TestJit.test_decompose_addmm.expect @@ -8,9 +8,9 @@ graph(%mat : Dynamic %7 : Dynamic = aten::mm(%mat1, %mat2) %8 : Dynamic = aten::add[alpha={1}](%mat, %7) %c : Dynamic = aten::addmm[beta={2}, alpha={4.2}](%mat, %mat1, %mat2) - %10 : Number = prim::TensorToNum(%beta) - %11 : Number = prim::TensorToNum(%alpha) - %d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %10, %11) + %10 : int = prim::TensorToNum(%alpha) + %11 : int = prim::TensorToNum(%beta) + %d : Dynamic = aten::addmm(%mat, %mat1, %mat2, %11, %10) %13 : Dynamic = aten::add[alpha={1}](%6, %8) %14 : Dynamic = aten::add[alpha={1}](%13, %c) %15 : Dynamic = aten::add[alpha={1}](%14, %d) diff --git a/test/expect/TestJit.test_trace_size.expect b/test/expect/TestJit.test_trace_size.expect index 1c6fdcd6eba00..567a0fc5a5ecb 100644 --- a/test/expect/TestJit.test_trace_size.expect +++ b/test/expect/TestJit.test_trace_size.expect @@ -1,11 +1,15 @@ graph(%0 : Double(5, 2, 4)) { - %1 : Long() = prim::Constant[value={1}]() - %2 : Long() = aten::size(%0, %1) - %3 : Long() = aten::mul[other={2}](%2) - %4 : Long() = prim::Constant[value={0}]() - %5 : Long() = aten::size(%0, %4) - %6 : Long() = prim::Constant[value={2}]() - %7 : Dynamic = aten::stack[dim=0](%3, %5, %6) - %8 : Double(4, 5, 2) = aten::view(%0, %7) - return (%8); + %1 : int = prim::Constant[value=1]() + %2 : int = aten::size(%0, %1) + %3 : Long() = prim::NumToTensor(%2) + %4 : Long() = aten::mul[other={2}](%3) + %5 : int = prim::TensorToNum(%4) + %6 : int = prim::Constant[value=0]() + %7 : int = aten::size(%0, %6) + %8 : Long() = prim::NumToTensor(%7) + %9 : int = prim::TensorToNum(%8) + %10 : int = prim::Constant[value=2]() + %11 : int[] = prim::ListConstruct(%5, %9, %10) + %12 : Double(4, 5, 2) = aten::view(%0, %11) + return (%12); } diff --git a/test/expect/TestJit.test_trace_size_with_grad.expect b/test/expect/TestJit.test_trace_size_with_grad.expect index 1c6fdcd6eba00..567a0fc5a5ecb 100644 --- a/test/expect/TestJit.test_trace_size_with_grad.expect +++ b/test/expect/TestJit.test_trace_size_with_grad.expect @@ -1,11 +1,15 @@ graph(%0 : Double(5, 2, 4)) { - %1 : Long() = prim::Constant[value={1}]() - %2 : Long() = aten::size(%0, %1) - %3 : Long() = aten::mul[other={2}](%2) - %4 : Long() = prim::Constant[value={0}]() - %5 : Long() = aten::size(%0, %4) - %6 : Long() = prim::Constant[value={2}]() - %7 : Dynamic = aten::stack[dim=0](%3, %5, %6) - %8 : Double(4, 5, 2) = aten::view(%0, %7) - return (%8); + %1 : int = prim::Constant[value=1]() + %2 : int = aten::size(%0, %1) + %3 : Long() = prim::NumToTensor(%2) + %4 : Long() = aten::mul[other={2}](%3) + %5 : int = prim::TensorToNum(%4) + %6 : int = prim::Constant[value=0]() + %7 : int = aten::size(%0, %6) + %8 : Long() = prim::NumToTensor(%7) + %9 : int = prim::TensorToNum(%8) + %10 : int = prim::Constant[value=2]() + %11 : int[] = prim::ListConstruct(%5, %9, %10) + %12 : Double(4, 5, 2) = aten::view(%0, %11) + return (%12); } diff --git 
a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect index 593a7cb2ac721..db478d2e22f9c 100644 --- a/test/expect/TestScript.test_call_python_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_python_fn_from_script_fn.expect @@ -1,8 +1,5 @@ graph(%x : Dynamic) { %1 : Dynamic = ^python_fn()(%x) - %2 : int = prim::Constant[value={1}]() - %3 : Dynamic = prim::NumToTensor(%2) - %4 : Dynamic = aten::type_as(%3, %1) - %6 : Dynamic = aten::add[alpha={1}](%1, %4) - return (%6); + %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + return (%5); } diff --git a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect index cf498dcb1ae26..ec5fd842f3b86 100644 --- a/test/expect/TestScript.test_call_python_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_python_mod_from_script_fn.expect @@ -1,8 +1,5 @@ graph(%x : Dynamic) { %1 : Dynamic = ^()(%x) - %2 : int = prim::Constant[value={1}]() - %3 : Dynamic = prim::NumToTensor(%2) - %4 : Dynamic = aten::type_as(%3, %1) - %6 : Dynamic = aten::add[alpha={1}](%1, %4) - return (%6); + %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + return (%5); } diff --git a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect index b50d2189407b0..e36a68926dccc 100644 --- a/test/expect/TestScript.test_call_script_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_script_fn_from_script_fn.expect @@ -1,8 +1,5 @@ graph(%x : Dynamic) { %1 : Dynamic = aten::neg(%x) - %2 : int = prim::Constant[value={1}]() - %3 : Dynamic = prim::NumToTensor(%2) - %4 : Dynamic = aten::type_as(%3, %1) - %6 : Dynamic = aten::add[alpha={1}](%1, %4) - return (%6); + %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + return (%5); } diff --git a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect index e3008f4e24634..e24d034b26e3d 100644 --- a/test/expect/TestScript.test_call_script_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_script_mod_from_script_fn.expect @@ -1,17 +1,12 @@ graph(%x : Dynamic) { - %1 : int = prim::Constant[value={4}]() - %2 : int = prim::Constant[value={3}]() - %3 : int = prim::Constant[value={6}]() - %4 : int = prim::Constant[value={0}]() - %5 : int[] = prim::Constant[value= 0 -1 [ CPULongTensor{2} ]]() - %6 : Dynamic = prim::NumToTensor(%1) - %7 : Dynamic = prim::NumToTensor(%2) - %8 : int[] = aten::stack[dim=0](%6, %7) - %9 : Dynamic = aten::zeros(%8, %3, %4, %5) - %10 : Dynamic = aten::mm(%x, %9) - %11 : int = prim::Constant[value={1}]() - %12 : Dynamic = prim::NumToTensor(%11) - %13 : Dynamic = aten::type_as(%12, %10) - %15 : Dynamic = aten::add[alpha={1}](%10, %13) - return (%15); + %1 : int = prim::Constant[value=4]() + %2 : int = prim::Constant[value=3]() + %3 : int = prim::Constant[value=6]() + %4 : int = prim::Constant[value=0]() + %5 : int[] = prim::Constant[value=[0, -1]]() + %6 : int[] = prim::ListConstruct(%1, %2) + %7 : Dynamic = aten::zeros(%6, %3, %4, %5) + %8 : Dynamic = aten::mm(%x, %7) + %12 : Dynamic = aten::add[other={1}, alpha={1}](%8) + return (%12); } diff --git a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect index a9894e7d3c1f4..83ce62e68e086 100644 --- 
a/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect +++ b/test/expect/TestScript.test_call_traced_fn_from_script_fn.expect @@ -1,8 +1,5 @@ graph(%x : Dynamic) { %1 : Double(3, 4) = aten::neg(%x) - %2 : int = prim::Constant[value={1}]() - %3 : Dynamic = prim::NumToTensor(%2) - %4 : Dynamic = aten::type_as(%3, %1) - %6 : Dynamic = aten::add[alpha={1}](%1, %4) - return (%6); + %5 : Dynamic = aten::add[other={1}, alpha={1}](%1) + return (%5); } diff --git a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect index e1e191223aaed..9a99fbe83f1d4 100644 --- a/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect +++ b/test/expect/TestScript.test_call_traced_mod_from_script_fn.expect @@ -1,9 +1,6 @@ graph(%x : Dynamic) { %1 : Double(4, 3) = prim::Constant[value=]() %2 : Double(3, 3) = aten::mm(%x, %1) - %3 : int = prim::Constant[value={1}]() - %4 : Dynamic = prim::NumToTensor(%3) - %5 : Dynamic = aten::type_as(%4, %2) - %7 : Dynamic = aten::add[alpha={1}](%2, %5) - return (%7); + %6 : Dynamic = aten::add[other={1}, alpha={1}](%2) + return (%6); } diff --git a/test/expect/TestScript.test_erase_number_types.expect b/test/expect/TestScript.test_erase_number_types.expect index 9632635635f31..6abce3392036a 100644 --- a/test/expect/TestScript.test_erase_number_types.expect +++ b/test/expect/TestScript.test_erase_number_types.expect @@ -1,10 +1,12 @@ graph(%a : Dynamic) { %1 : Long() = prim::Constant[value={7}]() %2 : Long() = prim::Constant[value={1}]() - %3 : Dynamic = aten::add[alpha={1}](%1, %2) + %3 : Long() = aten::add(%1, %2) %4 : Long() = prim::Constant[value={3}]() - %5 : Dynamic = aten::add[alpha={1}](%3, %4) - %c.1 : Dynamic = aten::add[alpha={1}](%a, %5) - %c : Dynamic = aten::add[alpha={1}](%c.1, %5) + %b : Long() = aten::add(%3, %4) + %6 : Long() = prim::Constant[value={1}]() + %c.1 : Dynamic = aten::add(%a, %b, %6) + %8 : Long() = prim::Constant[value={1}]() + %c : Dynamic = aten::add(%c.1, %b, %8) return (%c); } diff --git a/test/expect/TestScript.test_loop_unroll_unused_counter.expect b/test/expect/TestScript.test_loop_unroll_unused_counter.expect index 1a03ab906dbff..a4b5983c1e5d9 100644 --- a/test/expect/TestScript.test_loop_unroll_unused_counter.expect +++ b/test/expect/TestScript.test_loop_unroll_unused_counter.expect @@ -1,54 +1,30 @@ graph(%x : Dynamic) { %y.1 : Dynamic = ^FIXME_zerol()() - %2 : Byte() = prim::Constant[value={1}]() - %3 : Dynamic = aten::div[other={8}](%x) - %4 : Dynamic = aten::mul[other={8}](%3) - %5 : Dynamic = aten::sub[alpha={1}](%x, %4) - %y.3 : Dynamic = prim::Loop(%3, %2, %y.1) - block0(%i.1 : Dynamic, %8 : Dynamic) { - %9 : int = prim::Constant[value={1}]() - %10 : Dynamic = prim::NumToTensor(%9) - %11 : Dynamic = aten::type_as(%10, %8) - %y.12 : Dynamic = aten::add[alpha={1}](%8, %11) - %13 : int = prim::Constant[value={1}]() - %14 : Dynamic = prim::NumToTensor(%13) - %15 : Dynamic = aten::type_as(%14, %y.12) - %y.5 : Dynamic = aten::add[alpha={1}](%y.12, %15) - %17 : int = prim::Constant[value={1}]() - %18 : Dynamic = prim::NumToTensor(%17) - %19 : Dynamic = aten::type_as(%18, %y.5) - %y.6 : Dynamic = aten::add[alpha={1}](%y.5, %19) - %21 : int = prim::Constant[value={1}]() - %22 : Dynamic = prim::NumToTensor(%21) - %23 : Dynamic = aten::type_as(%22, %y.6) - %y.7 : Dynamic = aten::add[alpha={1}](%y.6, %23) - %25 : int = prim::Constant[value={1}]() - %26 : Dynamic = prim::NumToTensor(%25) - %27 : Dynamic = aten::type_as(%26, %y.7) - %y.8 : Dynamic = 
aten::add[alpha={1}](%y.7, %27) - %29 : int = prim::Constant[value={1}]() - %30 : Dynamic = prim::NumToTensor(%29) - %31 : Dynamic = aten::type_as(%30, %y.8) - %y.9 : Dynamic = aten::add[alpha={1}](%y.8, %31) - %33 : int = prim::Constant[value={1}]() - %34 : Dynamic = prim::NumToTensor(%33) - %35 : Dynamic = aten::type_as(%34, %y.9) - %y.10 : Dynamic = aten::add[alpha={1}](%y.9, %35) - %37 : int = prim::Constant[value={1}]() - %38 : Dynamic = prim::NumToTensor(%37) - %39 : Dynamic = aten::type_as(%38, %y.10) - %y.11 : Dynamic = aten::add[alpha={1}](%y.10, %39) - %41 : Byte() = prim::Constant[value={1}]() - -> (%41, %y.11) + %2 : int = prim::TensorToNum(%x) + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=8]() + %5 : int = aten::div(%2, %4) + %6 : int = prim::Constant[value=8]() + %7 : int = aten::mul(%5, %6) + %8 : int = aten::sub(%2, %7) + %y.3 : Dynamic = prim::Loop(%5, %3, %y.1) + block0(%i.1 : int, %11 : Dynamic) { + %y.12 : Dynamic = aten::add[other={1}, alpha={1}](%11) + %y.5 : Dynamic = aten::add[other={1}, alpha={1}](%y.12) + %y.6 : Dynamic = aten::add[other={1}, alpha={1}](%y.5) + %y.7 : Dynamic = aten::add[other={1}, alpha={1}](%y.6) + %y.8 : Dynamic = aten::add[other={1}, alpha={1}](%y.7) + %y.9 : Dynamic = aten::add[other={1}, alpha={1}](%y.8) + %y.10 : Dynamic = aten::add[other={1}, alpha={1}](%y.9) + %y.11 : Dynamic = aten::add[other={1}, alpha={1}](%y.10) + %20 : int = prim::Constant[value=1]() + -> (%20, %y.11) } - %y : Dynamic = prim::Loop(%5, %2, %y.3) - block0(%i : Dynamic, %44 : Dynamic) { - %45 : int = prim::Constant[value={1}]() - %46 : Dynamic = prim::NumToTensor(%45) - %47 : Dynamic = aten::type_as(%46, %44) - %y.4 : Dynamic = aten::add[alpha={1}](%44, %47) - %49 : Byte() = prim::Constant[value={1}]() - -> (%49, %y.4) + %y : Dynamic = prim::Loop(%8, %3, %y.3) + block0(%i : int, %23 : Dynamic) { + %y.4 : Dynamic = aten::add[other={1}, alpha={1}](%23) + %25 : int = prim::Constant[value=1]() + -> (%25, %y.4) } return (%y); } diff --git a/test/expect/TestScript.test_loop_unrolling.expect b/test/expect/TestScript.test_loop_unrolling.expect index 54fa974c0d136..0c77a4ec47e6e 100644 --- a/test/expect/TestScript.test_loop_unrolling.expect +++ b/test/expect/TestScript.test_loop_unrolling.expect @@ -1,37 +1,58 @@ graph(%x : Dynamic) { %y.1 : Dynamic = ^FIXME_zerol()() - %2 : Byte() = prim::Constant[value={1}]() - %3 : Long() = prim::Constant[value={0}]() - %4 : Dynamic = aten::div[other={8}](%x) - %5 : Dynamic = aten::mul[other={8}](%4) - %6 : Dynamic = aten::sub[alpha={1}](%x, %5) - %7 : Dynamic, %y.3 : Dynamic = prim::Loop(%4, %2, %3, %y.1) - block0(%i.1 : Dynamic, %10 : Dynamic, %11 : Dynamic) { - %y.12 : Dynamic = aten::add[alpha={1}](%11, %10) - %13 : Dynamic = aten::add[alpha={1}, other={1}](%10) - %y.5 : Dynamic = aten::add[alpha={1}](%y.12, %13) - %15 : Dynamic = aten::add[alpha={1}, other={1}](%13) - %y.6 : Dynamic = aten::add[alpha={1}](%y.5, %15) - %17 : Dynamic = aten::add[alpha={1}, other={1}](%15) - %y.7 : Dynamic = aten::add[alpha={1}](%y.6, %17) - %19 : Dynamic = aten::add[alpha={1}, other={1}](%17) - %y.8 : Dynamic = aten::add[alpha={1}](%y.7, %19) - %21 : Dynamic = aten::add[alpha={1}, other={1}](%19) - %y.9 : Dynamic = aten::add[alpha={1}](%y.8, %21) - %23 : Dynamic = aten::add[alpha={1}, other={1}](%21) - %y.10 : Dynamic = aten::add[alpha={1}](%y.9, %23) - %25 : Dynamic = aten::add[alpha={1}, other={1}](%23) - %y.11 : Dynamic = aten::add[alpha={1}](%y.10, %25) - %27 : Byte() = prim::Constant[value={1}]() - %28 : Dynamic = 
aten::add[alpha={1}, other={1}](%25) - -> (%27, %28, %y.11) + %2 : int = prim::TensorToNum(%x) + %3 : int = prim::Constant[value=1]() + %4 : int = prim::Constant[value=0]() + %5 : int = prim::Constant[value=8]() + %6 : int = aten::div(%2, %5) + %7 : int = prim::Constant[value=8]() + %8 : int = aten::mul(%6, %7) + %9 : int = aten::sub(%2, %8) + %10 : Dynamic, %y.3 : Dynamic = prim::Loop(%6, %3, %4, %y.1) + block0(%i.1 : int, %13 : Dynamic, %14 : Dynamic) { + %15 : Number = prim::Constant[value=1]() + %y.12 : Dynamic = aten::add(%14, %13, %15) + %17 : int = prim::Constant[value=1]() + %18 : int = aten::add(%13, %17) + %19 : Number = prim::Constant[value=1]() + %y.5 : Dynamic = aten::add(%y.12, %18, %19) + %21 : int = prim::Constant[value=1]() + %22 : int = aten::add(%18, %21) + %23 : Number = prim::Constant[value=1]() + %y.6 : Dynamic = aten::add(%y.5, %22, %23) + %25 : int = prim::Constant[value=1]() + %26 : int = aten::add(%22, %25) + %27 : Number = prim::Constant[value=1]() + %y.7 : Dynamic = aten::add(%y.6, %26, %27) + %29 : int = prim::Constant[value=1]() + %30 : int = aten::add(%26, %29) + %31 : Number = prim::Constant[value=1]() + %y.8 : Dynamic = aten::add(%y.7, %30, %31) + %33 : int = prim::Constant[value=1]() + %34 : int = aten::add(%30, %33) + %35 : Number = prim::Constant[value=1]() + %y.9 : Dynamic = aten::add(%y.8, %34, %35) + %37 : int = prim::Constant[value=1]() + %38 : int = aten::add(%34, %37) + %39 : Number = prim::Constant[value=1]() + %y.10 : Dynamic = aten::add(%y.9, %38, %39) + %41 : int = prim::Constant[value=1]() + %42 : int = aten::add(%38, %41) + %43 : Number = prim::Constant[value=1]() + %y.11 : Dynamic = aten::add(%y.10, %42, %43) + %45 : int = prim::Constant[value=1]() + %46 : int = prim::Constant[value=1]() + %47 : int = aten::add(%42, %46) + -> (%45, %47, %y.11) } - %29 : Dynamic, %y : Dynamic = prim::Loop(%6, %2, %7, %y.3) - block0(%i : Dynamic, %32 : Dynamic, %33 : Dynamic) { - %y.4 : Dynamic = aten::add[alpha={1}](%33, %32) - %35 : Byte() = prim::Constant[value={1}]() - %36 : Dynamic = aten::add[alpha={1}, other={1}](%32) - -> (%35, %36, %y.4) + %48 : Dynamic, %y : Dynamic = prim::Loop(%9, %3, %10, %y.3) + block0(%i : int, %51 : Dynamic, %52 : Dynamic) { + %53 : Number = prim::Constant[value=1]() + %y.4 : Dynamic = aten::add(%52, %51, %53) + %55 : int = prim::Constant[value=1]() + %56 : int = prim::Constant[value=1]() + %57 : int = aten::add(%51, %56) + -> (%55, %57, %y.4) } return (%y); } diff --git a/test/expect/TestScript.test_loop_unrolling_const-add_const.expect b/test/expect/TestScript.test_loop_unrolling_const-add_const.expect index 5a6fc16063872..8f810b0a6339b 100644 --- a/test/expect/TestScript.test_loop_unrolling_const-add_const.expect +++ b/test/expect/TestScript.test_loop_unrolling_const-add_const.expect @@ -1,44 +1,14 @@ graph() { %y.1 : Dynamic = ^FIXME_zerol()() - %1 : int = prim::Constant[value={1}]() - %2 : Dynamic = prim::NumToTensor(%1) - %3 : Dynamic = aten::type_as(%2, %y.1) - %y.11 : Dynamic = aten::add[alpha={1}](%y.1, %3) - %5 : int = prim::Constant[value={1}]() - %6 : Dynamic = prim::NumToTensor(%5) - %7 : Dynamic = aten::type_as(%6, %y.11) - %y.2 : Dynamic = aten::add[alpha={1}](%y.11, %7) - %9 : int = prim::Constant[value={1}]() - %10 : Dynamic = prim::NumToTensor(%9) - %11 : Dynamic = aten::type_as(%10, %y.2) - %y.3 : Dynamic = aten::add[alpha={1}](%y.2, %11) - %13 : int = prim::Constant[value={1}]() - %14 : Dynamic = prim::NumToTensor(%13) - %15 : Dynamic = aten::type_as(%14, %y.3) - %y.4 : Dynamic = aten::add[alpha={1}](%y.3, 
%15) - %17 : int = prim::Constant[value={1}]() - %18 : Dynamic = prim::NumToTensor(%17) - %19 : Dynamic = aten::type_as(%18, %y.4) - %y.5 : Dynamic = aten::add[alpha={1}](%y.4, %19) - %21 : int = prim::Constant[value={1}]() - %22 : Dynamic = prim::NumToTensor(%21) - %23 : Dynamic = aten::type_as(%22, %y.5) - %y.6 : Dynamic = aten::add[alpha={1}](%y.5, %23) - %25 : int = prim::Constant[value={1}]() - %26 : Dynamic = prim::NumToTensor(%25) - %27 : Dynamic = aten::type_as(%26, %y.6) - %y.7 : Dynamic = aten::add[alpha={1}](%y.6, %27) - %29 : int = prim::Constant[value={1}]() - %30 : Dynamic = prim::NumToTensor(%29) - %31 : Dynamic = aten::type_as(%30, %y.7) - %y.8 : Dynamic = aten::add[alpha={1}](%y.7, %31) - %33 : int = prim::Constant[value={1}]() - %34 : Dynamic = prim::NumToTensor(%33) - %35 : Dynamic = aten::type_as(%34, %y.8) - %y.9 : Dynamic = aten::add[alpha={1}](%y.8, %35) - %37 : int = prim::Constant[value={1}]() - %38 : Dynamic = prim::NumToTensor(%37) - %39 : Dynamic = aten::type_as(%38, %y.9) - %y.10 : Dynamic = aten::add[alpha={1}](%y.9, %39) + %y.11 : Dynamic = aten::add[other={1}, alpha={1}](%y.1) + %y.2 : Dynamic = aten::add[other={1}, alpha={1}](%y.11) + %y.3 : Dynamic = aten::add[other={1}, alpha={1}](%y.2) + %y.4 : Dynamic = aten::add[other={1}, alpha={1}](%y.3) + %y.5 : Dynamic = aten::add[other={1}, alpha={1}](%y.4) + %y.6 : Dynamic = aten::add[other={1}, alpha={1}](%y.5) + %y.7 : Dynamic = aten::add[other={1}, alpha={1}](%y.6) + %y.8 : Dynamic = aten::add[other={1}, alpha={1}](%y.7) + %y.9 : Dynamic = aten::add[other={1}, alpha={1}](%y.8) + %y.10 : Dynamic = aten::add[other={1}, alpha={1}](%y.9) return (%y.10); } diff --git a/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect b/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect index 020a401db01dc..2618493dc8ecb 100644 --- a/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect +++ b/test/expect/TestScript.test_loop_unrolling_const-add_iter.expect @@ -1,24 +1,43 @@ graph() { %y.1 : Dynamic = ^FIXME_zerol()() - %1 : Long() = prim::Constant[value={0}]() - %y.11 : Dynamic = aten::add[alpha={1}](%y.1, %1) - %3 : Dynamic = aten::add[alpha={1}, other={1}](%1) - %y.2 : Dynamic = aten::add[alpha={1}](%y.11, %3) - %5 : Dynamic = aten::add[alpha={1}, other={1}](%3) - %y.3 : Dynamic = aten::add[alpha={1}](%y.2, %5) - %7 : Dynamic = aten::add[alpha={1}, other={1}](%5) - %y.4 : Dynamic = aten::add[alpha={1}](%y.3, %7) - %9 : Dynamic = aten::add[alpha={1}, other={1}](%7) - %y.5 : Dynamic = aten::add[alpha={1}](%y.4, %9) - %11 : Dynamic = aten::add[alpha={1}, other={1}](%9) - %y.6 : Dynamic = aten::add[alpha={1}](%y.5, %11) - %13 : Dynamic = aten::add[alpha={1}, other={1}](%11) - %y.7 : Dynamic = aten::add[alpha={1}](%y.6, %13) - %15 : Dynamic = aten::add[alpha={1}, other={1}](%13) - %y.8 : Dynamic = aten::add[alpha={1}](%y.7, %15) - %17 : Dynamic = aten::add[alpha={1}, other={1}](%15) - %y.9 : Dynamic = aten::add[alpha={1}](%y.8, %17) - %19 : Dynamic = aten::add[alpha={1}, other={1}](%17) - %y.10 : Dynamic = aten::add[alpha={1}](%y.9, %19) + %1 : int = prim::Constant[value=0]() + %2 : Number = prim::Constant[value=1]() + %y.11 : Dynamic = aten::add(%y.1, %1, %2) + %4 : int = prim::Constant[value=1]() + %5 : int = aten::add(%1, %4) + %6 : Number = prim::Constant[value=1]() + %y.2 : Dynamic = aten::add(%y.11, %5, %6) + %8 : int = prim::Constant[value=1]() + %9 : int = aten::add(%5, %8) + %10 : Number = prim::Constant[value=1]() + %y.3 : Dynamic = aten::add(%y.2, %9, %10) + %12 : int = 
prim::Constant[value=1]() + %13 : int = aten::add(%9, %12) + %14 : Number = prim::Constant[value=1]() + %y.4 : Dynamic = aten::add(%y.3, %13, %14) + %16 : int = prim::Constant[value=1]() + %17 : int = aten::add(%13, %16) + %18 : Number = prim::Constant[value=1]() + %y.5 : Dynamic = aten::add(%y.4, %17, %18) + %20 : int = prim::Constant[value=1]() + %21 : int = aten::add(%17, %20) + %22 : Number = prim::Constant[value=1]() + %y.6 : Dynamic = aten::add(%y.5, %21, %22) + %24 : int = prim::Constant[value=1]() + %25 : int = aten::add(%21, %24) + %26 : Number = prim::Constant[value=1]() + %y.7 : Dynamic = aten::add(%y.6, %25, %26) + %28 : int = prim::Constant[value=1]() + %29 : int = aten::add(%25, %28) + %30 : Number = prim::Constant[value=1]() + %y.8 : Dynamic = aten::add(%y.7, %29, %30) + %32 : int = prim::Constant[value=1]() + %33 : int = aten::add(%29, %32) + %34 : Number = prim::Constant[value=1]() + %y.9 : Dynamic = aten::add(%y.8, %33, %34) + %36 : int = prim::Constant[value=1]() + %37 : int = aten::add(%33, %36) + %38 : Number = prim::Constant[value=1]() + %y.10 : Dynamic = aten::add(%y.9, %37, %38) return (%y.10); } diff --git a/test/expect/TestScript.test_loop_unrolling_nested.expect b/test/expect/TestScript.test_loop_unrolling_nested.expect index f19c54d150b4a..3b8832d03071a 100644 --- a/test/expect/TestScript.test_loop_unrolling_nested.expect +++ b/test/expect/TestScript.test_loop_unrolling_nested.expect @@ -1,44 +1,65 @@ graph(%x : Dynamic) { %y.1 : Dynamic = ^FIXME_zerol()() - %2 : int = prim::Constant[value={10}]() - %3 : Byte() = prim::Constant[value={1}]() + %2 : int = prim::Constant[value=10]() + %3 : int = prim::Constant[value=1]() %y : Dynamic = prim::Loop(%2, %3, %y.1) - block0(%i : Dynamic, %6 : Dynamic) { - %7 : Byte() = prim::Constant[value={1}]() - %8 : Long() = prim::Constant[value={0}]() - %9 : Dynamic = aten::div[other={8}](%x) - %10 : Dynamic = aten::mul[other={8}](%9) - %11 : Dynamic = aten::sub[alpha={1}](%x, %10) - %12 : Dynamic, %y.4 : Dynamic = prim::Loop(%9, %7, %8, %6) - block0(%j.1 : Dynamic, %15 : Dynamic, %16 : Dynamic) { - %y.13 : Dynamic = aten::add[alpha={1}](%16, %15) - %18 : Dynamic = aten::add[alpha={1}, other={1}](%15) - %y.6 : Dynamic = aten::add[alpha={1}](%y.13, %18) - %20 : Dynamic = aten::add[alpha={1}, other={1}](%18) - %y.7 : Dynamic = aten::add[alpha={1}](%y.6, %20) - %22 : Dynamic = aten::add[alpha={1}, other={1}](%20) - %y.8 : Dynamic = aten::add[alpha={1}](%y.7, %22) - %24 : Dynamic = aten::add[alpha={1}, other={1}](%22) - %y.9 : Dynamic = aten::add[alpha={1}](%y.8, %24) - %26 : Dynamic = aten::add[alpha={1}, other={1}](%24) - %y.10 : Dynamic = aten::add[alpha={1}](%y.9, %26) - %28 : Dynamic = aten::add[alpha={1}, other={1}](%26) - %y.11 : Dynamic = aten::add[alpha={1}](%y.10, %28) - %30 : Dynamic = aten::add[alpha={1}, other={1}](%28) - %y.12 : Dynamic = aten::add[alpha={1}](%y.11, %30) - %32 : Byte() = prim::Constant[value={1}]() - %33 : Dynamic = aten::add[alpha={1}, other={1}](%30) - -> (%32, %33, %y.12) + block0(%i : int, %6 : Dynamic) { + %7 : int = prim::TensorToNum(%x) + %8 : int = prim::Constant[value=1]() + %9 : int = prim::Constant[value=0]() + %10 : int = prim::Constant[value=8]() + %11 : int = aten::div(%7, %10) + %12 : int = prim::Constant[value=8]() + %13 : int = aten::mul(%11, %12) + %14 : int = aten::sub(%7, %13) + %15 : Dynamic, %y.4 : Dynamic = prim::Loop(%11, %8, %9, %6) + block0(%j.1 : int, %18 : Dynamic, %19 : Dynamic) { + %20 : Number = prim::Constant[value=1]() + %y.13 : Dynamic = aten::add(%19, %18, %20) + %22 : 
int = prim::Constant[value=1]() + %23 : int = aten::add(%18, %22) + %24 : Number = prim::Constant[value=1]() + %y.6 : Dynamic = aten::add(%y.13, %23, %24) + %26 : int = prim::Constant[value=1]() + %27 : int = aten::add(%23, %26) + %28 : Number = prim::Constant[value=1]() + %y.7 : Dynamic = aten::add(%y.6, %27, %28) + %30 : int = prim::Constant[value=1]() + %31 : int = aten::add(%27, %30) + %32 : Number = prim::Constant[value=1]() + %y.8 : Dynamic = aten::add(%y.7, %31, %32) + %34 : int = prim::Constant[value=1]() + %35 : int = aten::add(%31, %34) + %36 : Number = prim::Constant[value=1]() + %y.9 : Dynamic = aten::add(%y.8, %35, %36) + %38 : int = prim::Constant[value=1]() + %39 : int = aten::add(%35, %38) + %40 : Number = prim::Constant[value=1]() + %y.10 : Dynamic = aten::add(%y.9, %39, %40) + %42 : int = prim::Constant[value=1]() + %43 : int = aten::add(%39, %42) + %44 : Number = prim::Constant[value=1]() + %y.11 : Dynamic = aten::add(%y.10, %43, %44) + %46 : int = prim::Constant[value=1]() + %47 : int = aten::add(%43, %46) + %48 : Number = prim::Constant[value=1]() + %y.12 : Dynamic = aten::add(%y.11, %47, %48) + %50 : int = prim::Constant[value=1]() + %51 : int = prim::Constant[value=1]() + %52 : int = aten::add(%47, %51) + -> (%50, %52, %y.12) } - %34 : Dynamic, %y.3 : Dynamic = prim::Loop(%11, %7, %12, %y.4) - block0(%j : Dynamic, %37 : Dynamic, %38 : Dynamic) { - %y.5 : Dynamic = aten::add[alpha={1}](%38, %37) - %40 : Byte() = prim::Constant[value={1}]() - %41 : Dynamic = aten::add[alpha={1}, other={1}](%37) - -> (%40, %41, %y.5) + %53 : Dynamic, %y.3 : Dynamic = prim::Loop(%14, %8, %15, %y.4) + block0(%j : int, %56 : Dynamic, %57 : Dynamic) { + %58 : Number = prim::Constant[value=1]() + %y.5 : Dynamic = aten::add(%57, %56, %58) + %60 : int = prim::Constant[value=1]() + %61 : int = prim::Constant[value=1]() + %62 : int = aten::add(%56, %61) + -> (%60, %62, %y.5) } - %42 : Byte() = prim::Constant[value={1}]() - -> (%42, %y.3) + %63 : int = prim::Constant[value=1]() + -> (%63, %y.3) } return (%y); } diff --git a/test/expect/TestScript.test_math_numbers-float.expect b/test/expect/TestScript.test_math_numbers-float.expect index 67ea8b4c5eb39..1c9231145bf7b 100644 --- a/test/expect/TestScript.test_math_numbers-float.expect +++ b/test/expect/TestScript.test_math_numbers-float.expect @@ -1,16 +1,12 @@ graph(%x : Dynamic) { - %1 : float = prim::Constant[value={1.1}]() - %2 : float = prim::Constant[value={3.1}]() - %3 : Dynamic = prim::NumToTensor(%1) - %4 : Dynamic = prim::NumToTensor(%2) - %5 : Dynamic = aten::add[alpha={1}](%3, %4) - %c : float = prim::TensorToNum(%5) - %7 : int = prim::Constant[value={1}]() - %8 : int = prim::Constant[value={6}]() - %9 : int = prim::Constant[value={0}]() - %10 : int[] = prim::Constant[value= 0 -1 [ CPULongTensor{2} ]]() - %11 : Dynamic = prim::NumToTensor(%7) - %12 : int[] = aten::stack[dim=0](%11) - %13 : Dynamic = aten::full(%12, %c, %8, %9, %10) - return (%13); + %1 : float = prim::Constant[value=1.1]() + %2 : float = prim::Constant[value=3.1]() + %c : float = aten::add(%1, %2) + %4 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=6]() + %6 : int = prim::Constant[value=0]() + %7 : int[] = prim::Constant[value=[0, -1]]() + %8 : int[] = prim::ListConstruct(%4) + %9 : Dynamic = aten::full(%8, %c, %5, %6, %7) + return (%9); } diff --git a/test/expect/TestScript.test_math_numbers-int.expect b/test/expect/TestScript.test_math_numbers-int.expect index 9f028597ca071..385f904a8a5f8 100644 --- 
a/test/expect/TestScript.test_math_numbers-int.expect +++ b/test/expect/TestScript.test_math_numbers-int.expect @@ -1,16 +1,12 @@ graph(%x : Dynamic) { - %1 : int = prim::Constant[value={7}]() - %2 : int = prim::Constant[value={8}]() - %3 : Dynamic = prim::NumToTensor(%1) - %4 : Dynamic = prim::NumToTensor(%2) - %5 : Dynamic = aten::add[alpha={1}](%3, %4) - %c : int = prim::TensorToNum(%5) - %7 : int = prim::Constant[value={1}]() - %8 : int = prim::Constant[value={6}]() - %9 : int = prim::Constant[value={0}]() - %10 : int[] = prim::Constant[value= 0 -1 [ CPULongTensor{2} ]]() - %11 : Dynamic = prim::NumToTensor(%7) - %12 : int[] = aten::stack[dim=0](%11) - %13 : Dynamic = aten::full(%12, %c, %8, %9, %10) - return (%13); + %1 : int = prim::Constant[value=7]() + %2 : int = prim::Constant[value=8]() + %c : int = aten::add(%1, %2) + %4 : int = prim::Constant[value=1]() + %5 : int = prim::Constant[value=6]() + %6 : int = prim::Constant[value=0]() + %7 : int[] = prim::Constant[value=[0, -1]]() + %8 : int[] = prim::ListConstruct(%4) + %9 : Dynamic = aten::full(%8, %c, %5, %6, %7) + return (%9); } diff --git a/test/expect/TestScript.test_math_tensor_number.expect b/test/expect/TestScript.test_math_tensor_number.expect index f9d0ded7361b9..c0a88913280e5 100644 --- a/test/expect/TestScript.test_math_tensor_number.expect +++ b/test/expect/TestScript.test_math_tensor_number.expect @@ -1,7 +1,4 @@ graph(%x : Dynamic) { - %1 : int = prim::Constant[value={7}]() - %2 : Dynamic = prim::NumToTensor(%1) - %3 : Dynamic = aten::type_as(%2, %x) - %4 : Dynamic = aten::add[alpha={1}](%x, %3) - return (%4); + %1 : Dynamic = aten::add[other={7}, alpha={1}](%x) + return (%1); } diff --git a/test/expect/TestScript.test_onnx_export_script_module_if.expect b/test/expect/TestScript.test_onnx_export_script_module_if.expect index 264b86367ef2e..6063918f9c3c5 100644 --- a/test/expect/TestScript.test_onnx_export_script_module_if.expect +++ b/test/expect/TestScript.test_onnx_export_script_module_if.expect @@ -11,7 +11,7 @@ ModelProto { nodes: [ Node {type: "ReduceSum", inputs: [x.1], outputs: [1], attributes: [{ name: 'keepdims', type: int, value: 0}]}, Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, - Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []}, + Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: [{ name: 'broadcast', type: int, value: 1}]}, Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { name: "torch-jit-export1" diff --git a/test/expect/TestScript.test_sum-1.expect b/test/expect/TestScript.test_sum-1.expect index 8e165369778fe..f8599a2ac66ec 100644 --- a/test/expect/TestScript.test_sum-1.expect +++ b/test/expect/TestScript.test_sum-1.expect @@ -1,8 +1,7 @@ graph(%x : Dynamic) { - %1 : int = prim::Constant[value={4}]() - %2 : int = prim::Constant[value={0}]() - %3 : Dynamic = prim::NumToTensor(%1) - %4 : int[] = aten::stack[dim=0](%3) - %5 : Dynamic = aten::sum(%x, %4, %2) - return (%5); + %1 : int = prim::Constant[value=4]() + %2 : int = prim::Constant[value=0]() + %3 : int[] = prim::ListConstruct(%1) + %4 : Dynamic = aten::sum(%x, %3, %2) + return (%4); } diff --git a/test/expect/TestScript.test_sum-2.expect b/test/expect/TestScript.test_sum-2.expect index dece8c4d7cc0b..1d2b93741efb0 100644 --- a/test/expect/TestScript.test_sum-2.expect +++ b/test/expect/TestScript.test_sum-2.expect @@ -1,8 +1,7 @@ graph(%x : Double(1, 1, 1, 1, 4)) { - %1 : 
Long() = prim::Constant[value={4}]() - %2 : Long() = prim::Constant[value={0}]() - %3 : Long() = prim::NumToTensor(%1) - %4 : Dynamic = aten::stack[dim=0](%3) - %5 : Dynamic = aten::sum(%x, %4, %2) - return (%5); + %1 : int = prim::Constant[value=4]() + %2 : int = prim::Constant[value=0]() + %3 : int[] = prim::ListConstruct(%1) + %4 : Dynamic = aten::sum(%x, %3, %2) + return (%4); } diff --git a/test/test_dataloader.py b/test/test_dataloader.py index e299412cee6b0..7670b387ae63c 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -356,7 +356,6 @@ def test_multiple_dataloaders(self): next(loader1_it) next(loader2_it) - @unittest.skip("temporarily disable until flaky failures are fixed") def test_segfault(self): p = ErrorTrackingProcess(target=_test_segfault) p.start() diff --git a/test/test_distributions.py b/test/test_distributions.py index f53271e1ea027..ec0d2f061a461 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -40,7 +40,8 @@ FisherSnedecor, Gamma, Geometric, Gumbel, HalfCauchy, HalfNormal, Independent, Laplace, LogisticNormal, - LogNormal, Multinomial, MultivariateNormal, + LogNormal, LowRankMultivariateNormal, + Multinomial, MultivariateNormal, Normal, OneHotCategorical, Pareto, Poisson, RelaxedBernoulli, RelaxedOneHotCategorical, StudentT, TransformedDistribution, Uniform, @@ -260,6 +261,18 @@ def is_all_nan(tensor): 'scale': torch.tensor([1e-5, 1e-5], requires_grad=True), }, ]), + Example(LowRankMultivariateNormal, [ + { + 'loc': torch.randn(5, 2, requires_grad=True), + 'cov_factor': torch.randn(5, 2, 1, requires_grad=True), + 'cov_diag': torch.tensor([2.0, 0.25], requires_grad=True), + }, + { + 'loc': torch.randn(4, 3, requires_grad=True), + 'cov_factor': torch.randn(3, 2, requires_grad=True), + 'cov_diag': torch.tensor([5.0, 1.5, 3.], requires_grad=True), + } + ]), Example(MultivariateNormal, [ { 'loc': torch.randn(5, 2, requires_grad=True), @@ -1448,6 +1461,125 @@ def test_normal_sample(self): scipy.stats.norm(loc=loc, scale=scale), 'Normal(mean={}, std={})'.format(loc, scale)) + def test_lowrank_multivariate_normal_shape(self): + mean = torch.randn(5, 3, requires_grad=True) + mean_no_batch = torch.randn(3, requires_grad=True) + mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True) + + # construct PSD covariance + cov_factor = torch.randn(3, 1, requires_grad=True) + cov_diag = torch.tensor(torch.randn(3).abs(), requires_grad=True) + + # construct batch of PSD covariances + cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True) + cov_diag_batched = torch.tensor(torch.randn(6, 5, 3).abs(), requires_grad=True) + + # ensure that sample, batch, event shapes all handled correctly + self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag) + .sample().size(), (5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag) + .sample().size(), (3,)) + self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag) + .sample().size(), (6, 5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag) + .sample((2,)).size(), (2, 5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag) + .sample((2,)).size(), (2, 3)) + self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag) + .sample((2,)).size(), (2, 6, 5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag) + .sample((2, 7)).size(), (2, 7, 5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, 
cov_diag) + .sample((2, 7)).size(), (2, 7, 3)) + self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag) + .sample((2, 7)).size(), (2, 7, 6, 5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean, cov_factor_batched, cov_diag_batched) + .sample((2, 7)).size(), (2, 7, 6, 5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor_batched, cov_diag_batched) + .sample((2, 7)).size(), (2, 7, 6, 5, 3)) + self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor_batched, cov_diag_batched) + .sample((2, 7)).size(), (2, 7, 6, 5, 3)) + + # check gradients + self._gradcheck_log_prob(LowRankMultivariateNormal, + (mean, cov_factor, cov_diag)) + self._gradcheck_log_prob(LowRankMultivariateNormal, + (mean_multi_batch, cov_factor, cov_diag)) + self._gradcheck_log_prob(LowRankMultivariateNormal, + (mean_multi_batch, cov_factor_batched, cov_diag_batched)) + + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_lowrank_multivariate_normal_log_prob(self): + mean = torch.randn(3, requires_grad=True) + cov_factor = torch.randn(3, 1, requires_grad=True) + cov_diag = torch.tensor(torch.randn(3).abs(), requires_grad=True) + cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag() + + # check that logprob values match scipy logpdf, + # and that covariance and scale_tril parameters are equivalent + dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag) + ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()) + + x = dist1.sample((10,)) + expected = ref_dist.logpdf(x.numpy()) + + self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), places=3) + + # Double-check that batched versions behave the same as unbatched + mean = torch.randn(5, 3, requires_grad=True) + cov_factor = torch.randn(5, 3, 2, requires_grad=True) + cov_diag = torch.tensor(torch.randn(5, 3).abs(), requires_grad=True) + + dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag) + dist_unbatched = [LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i]) + for i in range(mean.size(0))] + + x = dist_batched.sample((10,)) + batched_prob = dist_batched.log_prob(x) + unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t() + + self.assertEqual(batched_prob.shape, unbatched_prob.shape) + self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3) + + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_lowrank_multivariate_normal_sample(self): + set_rng_seed(0) # see Note [Randomized statistical tests] + mean = torch.randn(5, requires_grad=True) + cov_factor = torch.randn(5, 1, requires_grad=True) + cov_diag = torch.tensor(torch.randn(5).abs(), requires_grad=True) + cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag() + + self._check_sampler_sampler(LowRankMultivariateNormal(mean, cov_factor, cov_diag), + scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()), + 'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})' + .format(mean, cov_factor, cov_diag), multivariate=True) + + def test_lowrank_multivariate_normal_properties(self): + loc = torch.randn(5) + cov_factor = torch.randn(5, 2) + cov_diag = torch.tensor(torch.randn(5).abs()) + cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag() + m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag) + m2 = MultivariateNormal(loc=loc, covariance_matrix=cov) + self.assertEqual(m1.mean, m2.mean) + self.assertEqual(m1.variance, m2.variance) + 
self.assertEqual(m1.covariance_matrix, m2.covariance_matrix) + self.assertEqual(m1.scale_tril, m2.scale_tril) + self.assertEqual(m1.precision_matrix, m2.precision_matrix) + self.assertEqual(m1.entropy(), m2.entropy()) + + def test_lowrank_multivariate_normal_moments(self): + set_rng_seed(0) # see Note [Randomized statistical tests] + mean = torch.randn(5) + cov_factor = torch.randn(5, 2) + cov_diag = torch.tensor(torch.randn(5).abs()) + d = LowRankMultivariateNormal(mean, cov_factor, cov_diag) + samples = d.rsample((100000,)) + empirical_mean = samples.mean(0) + self.assertEqual(d.mean, empirical_mean, prec=0.01) + empirical_var = samples.var(0) + self.assertEqual(d.variance, empirical_var, prec=0.02) + def test_multivariate_normal_shape(self): mean = torch.randn(5, 3, requires_grad=True) mean_no_batch = torch.randn(3, requires_grad=True) @@ -1561,6 +1693,17 @@ def test_multivariate_normal_properties(self): self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0])) self.assertEqual(m.scale_tril, torch.potrf(m.covariance_matrix, upper=False)) + def test_multivariate_normal_moments(self): + set_rng_seed(0) # see Note [Randomized statistical tests] + mean = torch.randn(5) + scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5)) + d = MultivariateNormal(mean, scale_tril=scale_tril) + samples = d.rsample((100000,)) + empirical_mean = samples.mean(0) + self.assertEqual(d.mean, empirical_mean, prec=0.01) + empirical_var = samples.var(0) + self.assertEqual(d.variance, empirical_var, prec=0.05) + def test_exponential(self): rate = torch.tensor(torch.randn(5, 5).abs(), requires_grad=True) rate_1d = torch.tensor(torch.randn(1).abs(), requires_grad=True) @@ -2949,6 +3092,71 @@ def test_kl_multivariate_normal_batched(self): MultivariateNormal(loc[1], scale_tril=scale_tril[1])) self.assertEqual(expected_kl, actual_kl) + def test_kl_multivariate_normal_batched_broadcasted(self): + b = 7 # Number of batches + loc = [torch.randn(b, 3) for _ in range(0, 2)] + scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)), + transform_to(constraints.lower_cholesky)(torch.randn(3, 3))] + expected_kl = torch.stack([ + kl_divergence(MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]), + MultivariateNormal(loc[1][i], scale_tril=scale_tril[1])) for i in range(0, b)]) + actual_kl = kl_divergence(MultivariateNormal(loc[0], scale_tril=scale_tril[0]), + MultivariateNormal(loc[1], scale_tril=scale_tril[1])) + self.assertEqual(expected_kl, actual_kl) + + def test_kl_lowrank_multivariate_normal(self): + set_rng_seed(0) # see Note [Randomized statistical tests] + n = 5 # Number of tests for lowrank_multivariate_normal + for i in range(0, n): + loc = [torch.randn(4) for _ in range(0, 2)] + cov_factor = [torch.randn(4, 3) for _ in range(0, 2)] + cov_diag = [transform_to(constraints.positive)(torch.randn(4)) for _ in range(0, 2)] + covariance_matrix = [cov_factor[i].matmul(cov_factor[i].t()) + + cov_diag[i].diag() for i in range(0, 2)] + p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]) + q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]) + p_full = MultivariateNormal(loc[0], covariance_matrix[0]) + q_full = MultivariateNormal(loc[1], covariance_matrix[1]) + expected = kl_divergence(p_full, q_full) + + actual_lowrank_lowrank = kl_divergence(p, q) + actual_lowrank_full = kl_divergence(p, q_full) + actual_full_lowrank = kl_divergence(p_full, q) + + error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max() + 
self.assertLess(error_lowrank_lowrank, self.precision, '\n'.join([ + 'Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n), + 'Expected (from KL MultivariateNormal): {}'.format(expected), + 'Actual (analytic): {}'.format(actual_lowrank_lowrank), + ])) + + error_lowrank_full = torch.abs(actual_lowrank_full - expected).max() + self.assertLess(error_lowrank_full, self.precision, '\n'.join([ + 'Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n), + 'Expected (from KL MultivariateNormal): {}'.format(expected), + 'Actual (analytic): {}'.format(actual_lowrank_full), + ])) + + error_full_lowrank = torch.abs(actual_full_lowrank - expected).max() + self.assertLess(error_full_lowrank, self.precision, '\n'.join([ + 'Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n), + 'Expected (from KL MultivariateNormal): {}'.format(expected), + 'Actual (analytic): {}'.format(actual_full_lowrank), + ])) + + def test_kl_lowrank_multivariate_normal_batched(self): + b = 7 # Number of batches + loc = [torch.randn(b, 3) for _ in range(0, 2)] + cov_factor = [torch.randn(b, 3, 2) for _ in range(0, 2)] + cov_diag = [transform_to(constraints.positive)(torch.randn(b, 3)) for _ in range(0, 2)] + expected_kl = torch.stack([ + kl_divergence(LowRankMultivariateNormal(loc[0][i], cov_factor[0][i], cov_diag[0][i]), + LowRankMultivariateNormal(loc[1][i], cov_factor[1][i], cov_diag[1][i])) + for i in range(0, b)]) + actual_kl = kl_divergence(LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]), + LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])) + self.assertEqual(expected_kl, actual_kl) + def test_kl_exponential_family(self): for (p, _), (_, q) in self.finite_examples: if type(p) == type(q) and issubclass(type(p), ExponentialFamily): @@ -3289,6 +3497,10 @@ def setUp(self): LogNormal(random_var, positive_var.clamp(max=3)), scipy.stats.lognorm(s=positive_var.clamp(max=3), scale=random_var.exp()) ), + ( + LowRankMultivariateNormal(random_var, torch.zeros(20, 1), positive_var2), + scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2)) + ), ( Multinomial(10, simplex_tensor), scipy.stats.multinomial(10, simplex_tensor) @@ -3328,7 +3540,7 @@ def test_mean(self): if isinstance(pytorch_dist, (Cauchy, HalfCauchy)): # Cauchy, HalfCauchy distributions' mean is nan, skipping check continue - elif isinstance(pytorch_dist, MultivariateNormal): + elif isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)): self.assertEqual(pytorch_dist.mean, scipy_dist.mean, allow_inf=True, message=pytorch_dist) else: self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist) @@ -3341,7 +3553,7 @@ def test_variance_stddev(self): elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)): self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist) self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist) - elif isinstance(pytorch_dist, MultivariateNormal): + elif isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)): self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov), message=pytorch_dist) self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov) ** 0.5, message=pytorch_dist) else: diff --git a/test/test_jit.py b/test/test_jit.py index a20436e167188..5df9caaa4eaca 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1088,7 +1088,7 @@ def 
addmm(mat, mat1, mat2, alpha, beta): a = mat.addmm(mat1, mat2) b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0) c = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0) - d = mat.addmm(mat1, mat2, alpha=alpha, beta=beta) + d = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta)) return a + b + c + d @@ -2033,7 +2033,6 @@ def test(op, const, swap_args): scope = {} exec(code, globals(), scope) cu = torch.jit.CompilationUnit(code) - self.assertEqual(cu.func(tensor), scope['func'](tensor)) var_int = 2 @@ -3491,7 +3490,7 @@ def func(a): def test_loop_unrolling(self): def fn(x): y = FIXME_zerol() - for i in range(x): + for i in range(int(x)): y += i return y @@ -3526,7 +3525,7 @@ def test_loop_unrolling_nested(self): def fn(x): y = FIXME_zerol() for i in range(10): - for j in range(x): + for j in range(int(x)): y += j return y @@ -3538,7 +3537,7 @@ def fn(x): def test_loop_unroll_unused_counter(self): def fn(x): y = FIXME_zerol() - for i in range(x): + for i in range(int(x)): y += 1 return y @@ -3549,7 +3548,7 @@ def fn(x): def test_loop_unroll_negative(self): def fn(x): y = FIXME_zerol() - for i in range(x): + for i in range(int(x)): y += 1 return y @@ -3594,7 +3593,7 @@ def test_chunk_non_constant(self): with self.assertRaisesRegex(RuntimeError, 'argument \'chunks\' must be a constant'): @torch.jit.script def chunk_non_constant(x, y): - return x.chunk(y) + return x.chunk(int(y)) def test_unknown_builtin(self): with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'): @@ -3603,7 +3602,7 @@ def unknown_builtin(x): return x.splork(3) def test_expected_tensor_found_tuple(self): - with self.assertRaisesRegex(RuntimeError, 'expected a tensor value but found a Tuple'): + with self.assertRaisesRegex(RuntimeError, 'expected a tensor value but found'): @torch.jit.script def return_tuple_wrong(x): a = (x, x) diff --git a/test/test_torch.py b/test/test_torch.py index 7148ab69f82ec..4d3d443db39bc 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6131,8 +6131,7 @@ def test_tensor_shape_empty(self): # functions that operate over a dimension but don't reduce. @skipIfNoZeroSize def test_dim_function_empty(self): - # FIXME: enable CUDA tests. 
- devices = ['cpu'] # if not torch.cuda.is_available() else ['cpu', 'cuda'] + devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] for device in devices: shape = (0, 1, 2, 0) x = torch.randn(shape, device=device) @@ -6188,40 +6187,52 @@ def test_dim_function_empty(self): self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)]) # gather - self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64)).shape) - self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64)).shape) - larger_shape = (0, 1, 3, 0) - self.assertEqual(larger_shape, torch.gather(x, 2, torch.empty(larger_shape, dtype=torch.int64)).shape) - smaller_shape = (0, 1, 0, 0) - self.assertEqual(smaller_shape, torch.gather(x, 2, torch.empty(smaller_shape, dtype=torch.int64)).shape) + self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape) + self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape) + larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device) + self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape) + smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device) + self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape) y = torch.randn((2, 3, 4), device=device) - self.assertEqual((0, 3, 4), torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64)).shape) + self.assertEqual((0, 3, 4), + torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape) # scatter, scatter_add for dim in [0, 2]: y = torch.randn(shape, device=device) y_src = torch.randn(shape, device=device) - self.assertEqual(shape, y.scatter_(dim, torch.empty(shape, dtype=torch.int64), y_src).shape) - self.assertEqual(shape, y.scatter_add_(dim, torch.empty(shape, dtype=torch.int64), y_src).shape) + ind = torch.empty(shape, dtype=torch.int64, device=device) + self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape) + self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape) z = torch.randn((2, 3, 4), device=device) z_src = torch.randn((2, 3, 4), device=device) - self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64), z_src)) - self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64), z_src)) + self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) + self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) # index_fill, index_copy, index_add c = x.clone() - ind_empty = torch.tensor([], dtype=torch.int64) - ind_01 = torch.tensor([0, 1], dtype=torch.int64) - self.assertEqual(c, c.index_fill_(0, ind_empty, -1)) - self.assertEqual(c, c.index_fill_(2, ind_empty, -1)) - self.assertEqual(c, c.index_fill_(2, torch.tensor([0, 1], dtype=torch.int64), -1)) - self.assertEqual(c, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) - self.assertEqual(c, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) - self.assertEqual(c, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) - self.assertEqual(c, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) - self.assertEqual(c, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) - self.assertEqual(c, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) + c_clone = c.clone() + ind_empty = torch.tensor([], 
dtype=torch.int64, device=device) + ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device) + self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) + self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1)) + self.assertEqual(c_clone, c.index_fill_(2, torch.tensor([0, 1], dtype=torch.int64, device=device), -1)) + self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) + self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) + self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) + self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) + self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) + self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) + + c = torch.randn((0, 1, 2), device=device) + c_clone = c.clone() + self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) + self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) + self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device))) + self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) + self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) + self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device))) # index fill/copy/add non-empty z = torch.randn((2, 3, 4), device=device) @@ -6237,6 +6248,10 @@ def test_dim_function_empty(self): self.assertEqual(x, x.index_select(2, ind_01)) z = torch.randn((2, 3, 4), device=device) # non-empty self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape) + c = torch.randn((0, 1, 2), device=device) + self.assertEqual(c, c.index_select(0, ind_empty)) + c = torch.randn((0, 1, 2), device=device) + self.assertEqual(c, c.index_select(0, ind_empty)) @skipIfNoZeroSize def test_blas_empty(self): diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 03e0f64169614..0a3d66ee2b0ed 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -11,8 +11,9 @@ #include "torch/csrc/autograd/functions/tensor.h" #include "torch/csrc/autograd/functions/basic_ops.h" #include "torch/csrc/jit/tracer.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/symbolic_variable.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/utils/variadic.h" #include "torch/csrc/autograd/functions/utils.h" @@ -55,14 +56,13 @@ static void setattr(jit::Node* n, jit::Symbol name, std::string v) { n-> template static void setattr(jit::Node* n, jit::Symbol name, std::array v) { n->is_(name, std::vector(v.begin(), v.end())); } -template -static jit::Value* createConstant(jit::Node* n, T value) { - return n->owningGraph()->createConstant(jit::as_tensor(value))->insertBefore(n)->output(); +static jit::Value* insertConstant(jit::Node* n, jit::IValue value) { + jit::WithInsertPoint guard(n); + return insertConstant(*n->owningGraph(), std::move(value)); } -template -static void genericInsertInput(jit::Node* n, size_t idx, T value) { - n->insertInput(idx, createConstant(n, value)); +static void genericInsertInput(jit::Node* n, size_t idx, jit::IValue value) { + n->insertInput(idx, insertConstant(n, std::move(value))); } void failPositionalAttr() { @@ -78,19 +78,18 @@ static void setposattr(jit::Node* n, size_t idx, 
const char *name, const at::Int auto info = ArgumentStash::popIntList(name); for (size_t i = 0; i < info.size(); ++i) { if (info[i] != nullptr) continue; - info[i] = createConstant(n, v[i]); + info[i] = insertConstant(n, v[i]); } - jit::TensorType expected_type {at::kLong, -1, {}}; for (jit::Value* v : info) { - if (*v->type() != expected_type) { + if (*v->type() != *jit::IntType::get()) { throw std::runtime_error( "Type mismatch in setposattr for IntList. Check that your program " "is valid without tracing, and please file a bug report if it is."); } } jit::WithInsertPoint insert_point{n}; - auto symbolic_info = fmap(info); - auto size = jit::SymbolicVariable::stack(symbolic_info, 0); + auto& g = *n->owningGraph(); + auto size = g.insertNode(g.createList(jit::IntType::get(), info))->output(); n->insertInput(idx, size); } else { return genericInsertInput(n, idx, v); diff --git a/tools/autograd/templates/variable_factories.h b/tools/autograd/templates/variable_factories.h index 08308bd29d582..5ec36bcff20e1 100644 --- a/tools/autograd/templates/variable_factories.h +++ b/tools/autograd/templates/variable_factories.h @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -38,5 +39,24 @@ namespace torch { AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(TENSOR) #undef TENSOR +inline autograd::Variable from_blob( + void* data, + at::IntList sizes, + const std::function& deleter, + const at::TensorOptions& options = {}) { + at::Tensor tensor = + at::from_blob(data, sizes, deleter, options.discard_runtime_type()); + return autograd::make_variable( + tensor, /*requires_grad=*/options.requires_grad()); +} + +inline autograd::Variable from_blob( + void* data, + at::IntList sizes, + const at::TensorOptions& options = {}) { + return torch::from_blob(data, sizes, /*deleter=*/[](void*) {}, options); +} + ${function_definitions} + } // namespace torch diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py index 75a14b2d5550f..49948de7180a0 100644 --- a/tools/jit/gen_jit_dispatch.py +++ b/tools/jit/gen_jit_dispatch.py @@ -52,79 +52,74 @@ def jit_type_of(arg): typ = '{}?'.format(typ) return typ -# map from _jit type_, generated from jit_type_of to attribute used to store it -ATTR_METHOD_MAP = { - 'int': 'i', - 'float': 'f', - 'bool': 'i', - 'Scalar': 't', - 'int[]': 'is', - 'bool[]': 'is', - 'Layout': 'i', - 'Device': 'is', - 'ScalarType': 'i', -} - - -def attr_of(jit_type): - # for attributes, we dont care about the length of an array, - # so strip it from the type - jit_type = re.sub("\\[\d+\\]", "[]", jit_type) - return ATTR_METHOD_MAP[jit_type] - # map from aten 'simple_type' to the function that will cast a attribute value # to that type FROM_ATTRIBUTE = { - 'std::array': 'as_bool_array<2>', - 'std::array': 'as_bool_array<3>', - 'std::array': 'as_bool_array<4>', - 'Scalar': 'Scalar', - 'IntList': 'std::vector', - 'Layout': 'int64_t', - 'Device': 'std::vector', - 'ScalarType': 'int64_t', + 'Device': 'as_device(node->is(attr::{}))', + 'IntList': 'std::vector(node->is(attr::{}))', + 'Layout': 'static_cast(node->i(attr::{}))', + 'Scalar': 'Scalar(node->t(attr::{}))', + 'ScalarType': 'static_cast(node->i(attr::{}))', + 'Tensor': 'node->t(attr::{})', + 'bool': 'bool(node->i(attr::{}))', + 'double': 'node->f(attr::{})', + 'int64_t': 'node->i(attr::{})', + 'std::array': 'as_bool_array<2>(node->is(attr::{}))', + 'std::array': 'as_bool_array<3>(node->is(attr::{}))', + 'std::array': 'as_bool_array<4>(node->is(attr::{}))', } + +def from_attribute(arg): + simple_type = arg['simple_type'] + return 
FROM_ATTRIBUTE[simple_type].format(arg['name']) + + # map from aten 'simple_type' to the function that will turn a tensor into # that type -FROM_TENSOR = { - 'Device': 'tensor_as>', - 'ScalarType': 'tensor_as', - 'Layout': 'tensor_as', - 'IntList': 'tensor_as>', +FROM_IVALUE = { + 'Device': 'as_device({}.toIntList()->elements())', + 'IntList': '{}.toIntList()->elements()', + 'Layout': 'static_cast({}.toInt())', + 'Scalar': '{}.toScalar()', + 'ScalarType': 'static_cast({}.toInt())', + 'Tensor': '{}.toTensor()', + 'bool': '{}.toInt()', + 'double': '{}.toDouble()', + 'int64_t': '{}.toInt()', + 'std::array': 'as_bool_array<2>({}.toIntList()->elements())', + 'std::array': 'as_bool_array<3>({}.toIntList()->elements())', + 'std::array': 'as_bool_array<4>({}.toIntList()->elements())', } -def from_tensor(arg): +def from_ivalue(arg, value): simple_type = arg['simple_type'] - if simple_type in FROM_TENSOR: - return FROM_TENSOR[simple_type] - else: - return 'tensor_as<{}>'.format(arg['simple_type']) + return FROM_IVALUE[simple_type].format(value) -KW_ASSIGNMENT = CodeTemplate("""\ -auto ${name} = ${type_cast}(node->${method}(Symbol::attr("${name}")));\ -""") - -POS_ASSIGNMENT = CodeTemplate("""\ -auto ${name} = ${from_tensor}(std::move(peek(stack, ${i}, ${N})).toTensor());\ -""") +KW_ACCESS = CodeTemplate("""(node->${method}(Symbol::attr("${name}")))""") CALL_NAMESPACE = CodeTemplate("""\ -auto result = at::${name}(${args}); +auto result = at::${name}( + ${args} +); """) CALL_METHOD = CodeTemplate("""\ DeviceGuard device_guard(deviceForInputs(stack, ${num_dynamic_inputs})); -auto result = (${first}).${name}(${args}); +auto result = (${first}).${name}( + ${args} +); """) CALL_TENSOR_OPTIONS = CodeTemplate("""\ -const auto device_index = static_cast(device[1]); const auto options = TensorOptions() - .dtype(static_cast(dtype)) - .layout(static_cast(layout)) - .device({static_cast(device[0]), device_index}); -auto result = torch::${name}(${args}, options); + .dtype(${dtype}) + .layout(${layout}) + .device(${device}); +auto result = torch::${name}( + ${args}, + options +); """) # TODO (apaszke): remove the attributed codepath once we remove them @@ -133,7 +128,6 @@ def from_tensor(arg): ${kw_assignments} return Operation([=](Stack & stack) { autograd::profiler::RecordFunction record("${name}"); - ${pos_assignments} ${call} drop(stack, ${num_dynamic_inputs}); pack(stack, std::move(result)); @@ -204,13 +198,23 @@ def gen_jit_dispatch(declarations, out, template_path): ops = [] def get_invocation(decl, args, num_dynamic_inputs): + + # because the arg list can get lengthy we put them on a separate line + def pack_arguments(args): + return ',\n'.join(args) if decl.get('has_tensor_options'): - return CALL_TENSOR_OPTIONS.substitute(name=decl['name'], args=args[:-3]) + return CALL_TENSOR_OPTIONS.substitute(name=decl['name'], + args=pack_arguments(args[:-3]), + dtype=args[-3], + layout=args[-2], + device=args[-1]) elif 'namespace' in decl['method_of']: - return CALL_NAMESPACE.substitute(name=decl['name'], args=args, num_dynamic_inputs=num_dynamic_inputs) + return CALL_NAMESPACE.substitute(name=decl['name'], + args=pack_arguments(args), + num_dynamic_inputs=num_dynamic_inputs) else: return CALL_METHOD.substitute( - name=decl['name'], first=args[0], args=args[1:], + name=decl['name'], first=args[0], args=pack_arguments(args[1:]), num_dynamic_inputs=num_dynamic_inputs) def emit_decl_variant(decl, is_positional_arg, has_tensorlist): @@ -218,7 +222,6 @@ def emit_decl_variant(decl, is_positional_arg, has_tensorlist): # 
that indicates if the argument should come from the postional list # of inputs. If false, the argument comes from the constant attributes kw_assignments = [] - pos_assignments = [] arguments = [] if has_tensorlist: @@ -267,26 +270,12 @@ def emit_decl_variant(decl, is_positional_arg, has_tensorlist): .format(real_inputs, static_inputs)) elif arg['simple_type'] in default_only_types: arguments.append(arg['default']) - elif is_tensor_arg(arg): - arguments.append('std::move(peek(stack, {}, {})).toTensor()'.format(real_inputs, view_length)) + elif is_tensor_arg(arg) or is_positional_arg[i]: + value = '(std::move(peek(stack, {}, {})))'.format(real_inputs, view_length) + arguments.append(from_ivalue(arg, value)) real_inputs += 1 - elif is_positional_arg[i]: - template_kwargs = dict(from_tensor=from_tensor(arg), - name=arg['name'], - i=real_inputs, - N=view_length) - real_inputs += 1 - - assign = POS_ASSIGNMENT.substitute(**template_kwargs) - - pos_assignments.append(assign) - arguments.append(arg['name']) else: - attr_method = attr_of(jit_type_of(arg)) - simple_type = arg['simple_type'] - assign = KW_ASSIGNMENT.substitute(type_cast=FROM_ATTRIBUTE.get(simple_type, simple_type), - name=arg['name'], - method=attr_method) + assign = "auto {} = {};".format(arg['name'], from_attribute(arg)) kw_assignments.append(assign) arguments.append(arg['name']) @@ -296,9 +285,8 @@ def emit_decl_variant(decl, is_positional_arg, has_tensorlist): all_scalars = all(r['dynamic_type'] != 'TensorList' for r in returns) constructor = CONSTRUCTOR.substitute(name=decl['name'], - call=[call], # in an array so that substitute handles newlines correctly + call=call, kw_assignments=kw_assignments, - pos_assignments=pos_assignments, num_dynamic_inputs=num_dynamic_inputs) return constructor diff --git a/tools/jit/templates/register_aten_ops.cpp b/tools/jit/templates/register_aten_ops.cpp index 2f4d0558fe4cb..06ad9c2840b1c 100644 --- a/tools/jit/templates/register_aten_ops.cpp +++ b/tools/jit/templates/register_aten_ops.cpp @@ -2,7 +2,7 @@ #include "torch/csrc/autograd/profiler.h" #include "torch/csrc/jit/interned_strings.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/utils/functional.h" #include "torch/csrc/variable_tensor_functions.h" #include "torch/csrc/autograd/generated/variable_factories.h" @@ -56,6 +56,10 @@ std::array as_bool_array(const std::vector& vec) { return res; } +at::Device as_device(const std::vector& elements) { + return at::Device(static_cast(elements[0]), elements[1]); +} + RegisterOperators reg({ ${constructors} }); diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 02fd0428622c1..6b60b56fa9960 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -133,6 +133,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/import.cpp ${TORCH_SRC_DIR}/csrc/jit/interned_strings.cpp ${TORCH_SRC_DIR}/csrc/jit/interpreter.cpp + ${TORCH_SRC_DIR}/csrc/jit/constants.cpp ${TORCH_SRC_DIR}/csrc/jit/ir.cpp ${TORCH_SRC_DIR}/csrc/jit/operator.cpp ${TORCH_SRC_DIR}/csrc/jit/operator.cpp diff --git a/torch/csrc/api/include/torch/nn/cloneable.h b/torch/csrc/api/include/torch/nn/cloneable.h index 3b304a652d133..0393d4ed60e76 100644 --- a/torch/csrc/api/include/torch/nn/cloneable.h +++ b/torch/csrc/api/include/torch/nn/cloneable.h @@ -4,6 +4,9 @@ #include #include +#include +#include +#include #include #include @@ -29,7 +32,12 @@ class Cloneable : public Module { /// Performs a recursive "deep copy" of the `Module`, such that all parameters /// and submodules in the cloned module are different from 
those in the /// original module. - std::shared_ptr clone() const override { + std::shared_ptr clone( + at::optional device = at::nullopt) const override { + auto options = DefaultTensorOptions::get(); + OptionsGuard options_guard( + options.device(device.value_or(options.device()))); + const auto& self = static_cast(*this); auto copy = std::make_shared(self); copy->parameters_.clear(); @@ -43,8 +51,13 @@ class Cloneable : public Module { "Are you sure you called register_parameter() inside reset() " "and not the constructor?"); for (const auto& parameter : parameters_) { - copy->parameters_[parameter.key].data().copy_( - parameter->data(), /*non_blocking=*/true); + if (device) { + copy->parameters_[parameter.key].data().copy_( + parameter->data(), /*non_blocking=*/true); + } else { + at::detail::set_data( + copy->parameters_[parameter.key], parameter->data().clone()); + } } AT_CHECK( copy->buffers_.size() == buffers_.size(), @@ -53,8 +66,13 @@ class Cloneable : public Module { "Are you sure you called register_buffer() inside reset() " "and not the constructor?"); for (const auto& buffer : buffers_) { - copy->buffers_[buffer.key].data().copy_( - buffer->data(), /*non_blocking=*/true); + if (device) { + copy->buffers_[buffer.key].data().copy_( + buffer->data(), /*non_blocking=*/true); + } else { + at::detail::set_data( + copy->buffers_[buffer.key], buffer->data().clone()); + } } AT_CHECK( copy->children_.size() == children_.size(), @@ -63,17 +81,17 @@ class Cloneable : public Module { "Are you sure you called register_module() inside reset() " "and not the constructor?"); for (const auto& child : children_) { - copy->children_[child.key]->clone_(*child.value); + copy->children_[child.key]->clone_(*child.value, device); } return copy; } private: - void clone_(Module& other) final override { + void clone_(Module& other, at::optional device) final override { // Here we are *pretty* certain that `other's` type is `Derived` (because it // was registered under the same name as `this`), but you never know what // crazy things `reset()` does, so `dynamic_cast` just to be safe. - auto clone = std::dynamic_pointer_cast(other.clone()); + auto clone = std::dynamic_pointer_cast(other.clone(device)); AT_CHECK( clone != nullptr, "Attempted to clone submodule, but it is of a " diff --git a/torch/csrc/api/include/torch/nn/module.h b/torch/csrc/api/include/torch/nn/module.h index a80b608b52f07..fbf0ab3674f0a 100644 --- a/torch/csrc/api/include/torch/nn/module.h +++ b/torch/csrc/api/include/torch/nn/module.h @@ -40,8 +40,11 @@ class Module { const std::string& name() const noexcept; /// Performs a recursive deep copy of the module and all its registered - /// parameters, buffers and submodules. - virtual std::shared_ptr clone() const; + /// parameters, buffers and submodules, optionally setting the current device + /// to the one supplied before cloning. If no device is given, each + /// parameter and buffer will be moved to the device of its source. + virtual std::shared_ptr clone( + at::optional device = at::nullopt) const; /// Provides a means to traverse the `Module` tree. ModuleCursor modules(); @@ -144,7 +147,7 @@ class Module { template friend class detail::CursorBase; - virtual void clone_(Module& other); + virtual void clone_(Module& other, at::optional device); /// The implementation of the various `to()` methods. 
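As a rough usage sketch of the clone(device) overload declared above (illustrative only, not part of the patch; "MyNet" and the wrapper function are hypothetical names):

// Sketch only: "MyNet" stands for any Cloneable-derived module from the C++ frontend above.
std::shared_ptr<torch::nn::Module> replicate(
    const std::shared_ptr<MyNet>& net,
    at::optional<at::Device> device) {
  // With a device, every parameter and buffer is copied onto it (non_blocking);
  // with at::nullopt, each tensor is cloned onto the device it already lives on.
  return net->clone(device);
}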
template diff --git a/torch/csrc/api/include/torch/nn/modules/any.h b/torch/csrc/api/include/torch/nn/modules/any.h index 595baee2532d2..7fb6abf1a1ead 100644 --- a/torch/csrc/api/include/torch/nn/modules/any.h +++ b/torch/csrc/api/include/torch/nn/modules/any.h @@ -8,6 +8,9 @@ #include #include +#include +#include + #include #include #include @@ -54,7 +57,7 @@ class AnyModule { /// Creates a deep copy of an `AnyModule` if it contains a module, else an /// empty `AnyModule` if it is empty. - AnyModule clone() const; + AnyModule clone(at::optional device = at::nullopt) const; /// Assigns a module to the `AnyModule` (to circumvent the explicit /// constructor). @@ -246,7 +249,8 @@ struct AnyModule::Placeholder : public AnyModule::Value::Placeholder { virtual std::unique_ptr copy() const = 0; /// Returns a `Placeholder` with a deep copy of this `AnyModule`. - virtual std::unique_ptr clone() const = 0; + virtual std::unique_ptr clone( + at::optional device) const = 0; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Holder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -308,9 +312,10 @@ struct AnyModule::Holder : public AnyModule::Placeholder { return torch::make_unique(*this); } - std::unique_ptr clone() const override { + std::unique_ptr clone( + at::optional device) const override { return torch::make_unique( - std::static_pointer_cast(module->clone())); + std::static_pointer_cast(module->clone(device))); } /// The actual concrete module instance. @@ -344,9 +349,9 @@ inline AnyModule& AnyModule::operator=(const AnyModule& other) { return *this; } -inline AnyModule AnyModule::clone() const { +inline AnyModule AnyModule::clone(at::optional device) const { AnyModule clone; - clone.content_ = content_ ? content_->clone() : nullptr; + clone.content_ = content_ ? content_->clone(device) : nullptr; return clone; } diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h index 3b80d1044a2c1..bc33f8df74f75 100644 --- a/torch/csrc/api/include/torch/nn/modules/embedding.h +++ b/torch/csrc/api/include/torch/nn/modules/embedding.h @@ -26,7 +26,7 @@ class EmbeddingImpl : public torch::nn::Cloneable { Tensor forward(Tensor); EmbeddingOptions options; - Tensor table; + Tensor weight; }; TORCH_MODULE(Embedding); diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/sequential.h index 1c28656692e7f..4adf85ab06b39 100644 --- a/torch/csrc/api/include/torch/nn/modules/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/sequential.h @@ -37,10 +37,11 @@ class SequentialImpl : public Cloneable { /// Special cloning function for `Sequential` because it does not use /// `reset()`. 
- std::shared_ptr clone() const override { + std::shared_ptr clone( + at::optional device = at::nullopt) const override { auto clone = std::make_shared(); for (const auto& module : modules_) { - clone->push_back(module.clone()); + clone->push_back(module.clone(device)); } return clone; } diff --git a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h index 09df4b2b1e441..bc3b60ab1423a 100644 --- a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h +++ b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h @@ -36,14 +36,8 @@ std::vector> replicate( std::vector> replicas; replicas.reserve(devices.size()); for (const auto& device : devices) { - // Here we rely on the property tensors are never (or should never be) - // allocated on any particular device, but always the default device, e.g. - // in `torch::ones({3, 4})`, the device is unspecified and pulled from the - // current thread local default options. As such, we can here modify these - // thread local default options and thereby cause all tensors in the cloned - // module to be constructed directly on the device we want. - OptionsGuard guard(device); - replicas.push_back(std::static_pointer_cast(module->clone())); + replicas.push_back( + std::static_pointer_cast(module->clone(device))); } return replicas; } diff --git a/torch/csrc/api/src/nn/module.cpp b/torch/csrc/api/src/nn/module.cpp index f21f6c5511b60..c1fec63aa8008 100644 --- a/torch/csrc/api/src/nn/module.cpp +++ b/torch/csrc/api/src/nn/module.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -34,7 +35,7 @@ const std::string& Module::name() const noexcept { return *name_; } -std::shared_ptr Module::clone() const { +std::shared_ptr Module::clone(at::optional device) const { AT_ERROR( "clone() has not been implemented for ", name(), @@ -130,6 +131,6 @@ Tensor& Module::register_buffer(std::string name, Tensor tensor) { return buffers_.insert(std::move(name), std::move(tensor)); } -void Module::clone_(Module& other) {} +void Module::clone_(Module& other, at::optional device) {} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index 54f6f314e0531..bcf41911a42a1 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -18,13 +18,13 @@ EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) } void EmbeddingImpl::reset() { - table = register_parameter( - "table", torch::empty({options.count_, options.dimension_})); - table.data().normal_(0, 1); + weight = register_parameter( + "weight", torch::empty({options.count_, options.dimension_})); + weight.data().normal_(0, 1); } Tensor EmbeddingImpl::forward(Tensor input) { - return torch::embedding(table, /*indices=*/input); + return torch::embedding(weight, /*indices=*/input); } } // namespace nn } // namespace torch diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 7654c4ee4c4b8..911208970984d 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -21,7 +21,7 @@ namespace torch { namespace autograd { Variable::Impl::Impl(at::Tensor data, bool requires_grad, Edge gradient_edge) - : TensorImpl(VariableType::getType(data)), + : TensorImpl(VariableType::getType(data), nullptr), data_(std::move(data)), grad_fn_(std::move(gradient_edge.function)), requires_grad_(false), diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp 
index a91532f5af15d..acc8749e539a4 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -12,9 +12,9 @@ using value_map = std::unordered_map; using value_set = std::unordered_set; bool hasOneValuedInput(Node *n, torch::jit::Symbol name) { - auto maybe_t = n->get(name); + auto maybe_t = n->get(name); if (!maybe_t) return false; - return at::Scalar(*maybe_t).toDouble() == 1.0; + return maybe_t->toDouble() == 1.0; } bool isDifferentiable(Node * n) { @@ -33,6 +33,12 @@ bool isDifferentiable(Node * n) { if (!hasOneValuedInput(n, attr::alpha) || !hasOneValuedInput(n, attr::beta)) return false; } + auto isTensor = [](Value* v) { return v->type()->isSubtypeOf(*DynamicType::get()); }; + + if(!std::all_of(n->inputs().begin(), n->inputs().end(), isTensor) + || !std::all_of(n->outputs().begin(), n->outputs().end(), isTensor)) + return false; + if (n->kind() == aten::type_as && !n->inputs().at(1)->isTensor()) { return false; } @@ -90,7 +96,7 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->hasAttribute(attr::alpha)) { return {grads.at(0), grads.at(0) * at::Scalar(node->t(attr::alpha))}; } else { - return {grads.at(0), nullptr, grads.at(0) * node->input(attr::alpha)}; + return {grads.at(0), nullptr, grads.at(0) * node->namedInput(attr::alpha)}; } case aten::sub: // o = self - alpha*other @@ -99,7 +105,7 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val } else if (node->hasAttribute(attr::alpha)) { return {grads.at(0), -grads.at(0) * at::Scalar(node->t(attr::alpha))}; } else { - return {grads.at(0), nullptr, grads.at(0) * node->input(attr::alpha)}; + return {grads.at(0), nullptr, grads.at(0) * node->namedInput(attr::alpha)}; } case aten::mul: // o = self * other @@ -119,7 +125,7 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val return {grads.at(0) * (outputs.at(0))}; case aten::chunk: case aten::split: - return {SymbolicVariable::cat(grads, node->input(attr::dim))}; + return {SymbolicVariable::cat(grads, node->namedInput(attr::dim))}; case aten::t: return {grads.at(0).t()}; case aten::neg: @@ -130,7 +136,7 @@ static std::vector gradientForNode(Node* node, ArrayRef grad_val case aten::type_as: return {grads.at(0).type_as(inputs.at(0))}; case aten::unsqueeze: - return {grads.at(0).squeeze(node->input(attr::dim))}; + return {grads.at(0).squeeze(node->namedInput(attr::dim))}; case aten::mm: { SymbolicVariable dmat1, dmat2; if (auto type = inputs.at(0).value()->type()->cast()) { diff --git a/torch/csrc/jit/constants.cpp b/torch/csrc/jit/constants.cpp new file mode 100644 index 0000000000000..1c8bf928aab5d --- /dev/null +++ b/torch/csrc/jit/constants.cpp @@ -0,0 +1,87 @@ +#include "torch/csrc/jit/constants.h" +#include "torch/csrc/jit/operator.h" +#include "torch/csrc/autograd/variable.h" + +namespace torch { namespace jit { + +// IValue -> Constant node +Value* insertConstant( + Graph& g, + IValue val, + at::optional loc) { + Node * n = g.create(prim::Constant); + if(val.isTensor()) { + at::Tensor ref = std::move(val).toTensor(); + JIT_ASSERT(ref.defined()); + n->output()->inferTypeFrom(ref); // note: before t_ because of std::move(ref) + n->t_(attr::value, std::move(ref)); + } else if(val.isInt()) { + n->i_(attr::value, val.toInt()); + n->output()->setType(IntType::get()); + } else if(val.isDouble()) { + n->f_(attr::value, val.toDouble()); + n->output()->setType(FloatType::get()); + } else if(val.isIntList()) { + n->is_(attr::value, val.toIntList()->elements()); + n->output()->setType(ListType::ofInts()); + } else 
{ + throw std::runtime_error("Unsupported value kind: " + val.tagKind()); + } + if(loc) + n->setSourceLocation(std::make_shared(*loc)); + return g.insertNode(n)->output(); +} + +RegisterOperators reg({ + // Implementation of constant node, computes and IValue + Operator( + prim::Constant, + [](Node* node) -> Operation { + TypePtr type = node->output()->type(); + if(type->isSubtypeOf(*DynamicType::get())) { + auto t = autograd::make_variable(node->t(attr::value)); + return [t](Stack& stack) { + stack.push_back(t); + return 0; + }; + } else if ( + type->isSubtypeOf(*NumberType::get()) && + node->kindOf(attr::value) == AttributeKind::i) { + auto i = node->i(attr::value); + return [i](Stack& stack) { + push(stack, i); + return 0; + }; + } else if ( + type->isSubtypeOf(*NumberType::get()) && + node->kindOf(attr::value) == AttributeKind::f) { + auto f = node->f(attr::value); + return [f](Stack& stack) { + push(stack, f); + return 0; + }; + } else if(type->isSubtypeOf(*ListType::ofInts())) { + auto is = node->is(attr::value); + return [is](Stack& stack) { + push(stack, is); + return 0; + }; + } else { + std::stringstream ss; + ss << "constant literal not supported for: " << type->str(); + throw std::runtime_error(ss.str()); + } + }), +}); + +at::optional toIValue(Value* v) { + if(v->node()->kind() != prim::Constant) + return at::nullopt; + // use implemenation of prim::Constant to compute the output IValue + auto op = getOperation(v->node()); + Stack stack; + op(stack); + return stack.back(); +} + +}} diff --git a/torch/csrc/jit/constants.h b/torch/csrc/jit/constants.h new file mode 100644 index 0000000000000..35dc9f111aa82 --- /dev/null +++ b/torch/csrc/jit/constants.h @@ -0,0 +1,37 @@ +#pragma once +#include "ATen/ATen.h" +#include "torch/csrc/jit/ivalue.h" +#include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/script/lexer.h" +#include "torch/csrc/WindowsTorchApiMacro.h" + +// helpers for handling constants in the IR +// - create constant nodes from ints, floats, intlist, Tensors, and other types +// - implement primitive constant ops. +namespace torch { namespace jit { + +TORCH_API Value* insertConstant( + Graph& g, + IValue val, + at::optional loc = at::nullopt); + + +////////////////////////////////////////////////////////////////////////////////// +// Helper for retrieving constants +////////////////////////////////////////////////////////////////////////////////// + +// attempt to convert a (possibly constant) Value* into an intepreter value (IValue). +// returns at::nullopt if the Value* was not constant +TORCH_API at::optional toIValue(Value* v); + +// if a value is a constant then try to turn into type T using the +// same rules as the interpreter +template +at::optional constant_as(Value* v) { + if(auto ivalue = toIValue(v)) { + return ivalue->to(); + } + return at::nullopt; +} + +}} diff --git a/torch/csrc/jit/function_schema.h b/torch/csrc/jit/function_schema.h index ec56f6144bfca..390f89f0398c0 100644 --- a/torch/csrc/jit/function_schema.h +++ b/torch/csrc/jit/function_schema.h @@ -1,6 +1,8 @@ #pragma once #include "ATen/ATen.h" + #include "torch/csrc/jit/type.h" +#include "torch/csrc/jit/ivalue.h" namespace torch { namespace jit { @@ -12,7 +14,7 @@ struct Argument { std::string name = "", TypePtr type = nullptr, at::optional N = at::nullopt, - at::optional default_value = at::nullopt, + at::optional default_value = at::nullopt, bool kwarg_only = true) : name(std::move(name)), type(type? type : DynamicType::get()), @@ -28,8 +30,7 @@ struct Argument { // become a list. 
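To make the intent of the new constants.h helpers concrete, a minimal sketch (illustrative only, not part of the patch; assumes constants.h above is included and uses only the declarations shown there):

// Sketch only: uses insertConstant / toIValue / constant_as from constants.h above.
using namespace torch::jit;

void add_constants(Graph& g) {
  Value* four = insertConstant(g, 4);                           // output typed IntType
  Value* half = insertConstant(g, 0.5);                         // output typed FloatType
  Value* dims = insertConstant(g, std::vector<int64_t>{0, 1});  // output typed ListType::ofInts()

  // Reading back: at::nullopt unless the producing node is a prim::Constant.
  at::optional<int64_t> maybe_four = constant_as<int64_t>(four);
  at::optional<IValue> any = toIValue(half);
  (void)dims; (void)maybe_four; (void)any;
}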
at::optional N; - // encoded using as_tensor, use tensor_as to get value for attribute - at::optional default_value; + at::optional default_value; // is this only specifyable as a keyword argument? bool kwarg_only; }; diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp index 3e04369987eaf..7012a02bb23a3 100644 --- a/torch/csrc/jit/fusion_compiler.cpp +++ b/torch/csrc/jit/fusion_compiler.cpp @@ -3,7 +3,7 @@ #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/code_template.h" #include "torch/csrc/jit/resource_guard.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/utils/disallow_copy.h" #include "torch/csrc/variable_tensor_functions.h" @@ -750,10 +750,15 @@ static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp"; // actually supports it or not, so we heuristically use the host // compiler to predict if the runtime compiler supports the option we // want. This probably won't work if you're cross-compiling. +// NB: -march=native is disabled because it has caused problems where +// compiler and assembler do not agree on what native instruction they +// understand for AVX512. When we need better CPU performance this +// optimization can be re-enabled by tracking down the platforms where +// this error occurs and only selectively disabling it. static const std::string compile_string = "\"${cxx}\" -O3 -g " #ifndef __PPC64__ - "-march=native " +// "-march=native " #endif "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm"; diff --git a/torch/csrc/jit/interned_strings.cpp b/torch/csrc/jit/interned_strings.cpp index d633514256dbb..c606f01871560 100644 --- a/torch/csrc/jit/interned_strings.cpp +++ b/torch/csrc/jit/interned_strings.cpp @@ -55,8 +55,9 @@ Symbol InternedStrings::_symbol(const std::string& s) { auto pos = s.find("::"); if (pos == std::string::npos) { - throw std::runtime_error( - "all symbols must have a namespace, ::"); + std::stringstream ss; + ss << "all symbols must have a namespace, ::, but found: " << s; + throw std::runtime_error(ss.str()); } Symbol ns = _symbol("namespaces::" + s.substr(0, pos)); diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h index a0d6a7a7fef50..d279ca67da52f 100644 --- a/torch/csrc/jit/interned_strings.h +++ b/torch/csrc/jit/interned_strings.h @@ -43,6 +43,7 @@ _(prim, Undefined) \ _(prim, Starred) \ _(prim, TupleConstruct) \ _(prim, TupleUnpack) \ +_(prim, ListConstruct) \ _(prim, NumToTensor) \ _(prim, TensorToNum) \ _(prim, AutogradAdd) \ diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index 1dd6ea6c5877c..b66647c69df01 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -8,8 +8,9 @@ #include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/jit/ivalue.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/variable_tensor_functions.h" #include "torch/csrc/autograd/generated/variable_factories.h" @@ -61,14 +62,14 @@ Value* createTripCountConjunctiveCondition( // Emit initial comparison -- initial_trip_count < max_trip_count Value* initial_comparison_value = g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1)) - ->output(); + ->output()->setType(IntType::get()); // Replace initial condition with logical `and` of trip count and // initial condition Value* new_cond = g->insertNode( g->create(aten::__and__, 
{initial_comparison_value, cond}, 1)) - ->output(); + ->output()->setType(IntType::get()); return new_cond; } @@ -93,9 +94,7 @@ void desugarTripCounts(Block * b) { { WithInsertPoint guard(n); // int i = 0 - Value* initial_trip_count = - g->insertNode(g->createConstant(at::zeros({1}, at::kLong))) - ->output(); + Value* initial_trip_count = insertConstant(*g, 0); // Set up initial iteration number value for loop-carried dependency n->removeInput(0); // Input 0 is now initial termination condition, insert this after that. @@ -113,14 +112,12 @@ void desugarTripCounts(Block * b) { // increment the trip count at the end of the body. Then, emit the same // conjunctive stopping condition as above. - Value* const_one = - g->insertNode(g->createConstant(at::ones({1}, at::kLong))) - ->output(); + Value* const_one = insertConstant(*g, 1); Value* inc_trip_count = g->insertNode(g->create( - aten::add, {block_trip_count_input, const_one, const_one}, 1)) - ->output(); + aten::add, {block_trip_count_input, const_one}, 1)) + ->output()->setType(IntType::get()); body_block->insertOutput(1, inc_trip_count); Value* body_cond = createTripCountConjunctiveCondition( @@ -339,7 +336,7 @@ struct PreprocessGraph { struct ContainerTensor : public at::TensorImpl { public: ContainerTensor() - : TensorImpl(&(at::globalContext().getType(at::Backend::Undefined,at::ScalarType::Undefined))) {} + : TensorImpl(&(at::globalContext().getType(at::Backend::Undefined,at::ScalarType::Undefined)), nullptr) {} virtual ~ContainerTensor() {} virtual const char * toString() const override { @@ -401,7 +398,7 @@ struct CodeImpl { CodeImpl(std::shared_ptr& graph_) : preprocess(*graph_) { graph = preprocess.graph; - //std::cout << "into code graph:\n" << *graph << "\n"; + // std::cout << "into code graph:\n" << *graph << "\n"; insertNodesFromBlock(graph->block()); } @@ -411,7 +408,7 @@ struct CodeImpl { JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = tensor_as(pop(stack).toTensor()); + auto t = pop(stack).toInt(); return (t == 0) ? offset : 0; }; inst.debug_name = prim::JumpZ; @@ -423,7 +420,7 @@ struct CodeImpl { JIT_ASSERT(inst.debug_name == prim::Placeholder); auto offset = relativeJump(from_inst, to_inst); inst.callback = [offset](Stack & stack) { - auto t = tensor_as(pop(stack).toTensor()); + auto t = pop(stack).toInt(); return (t != 0) ? offset : 0; }; inst.debug_name = prim::JumpNZ; diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index b68cec65cbf20..b782d4ae1d2c1 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -1,8 +1,9 @@ #include "ir.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/jit/operator.h" #include "torch/csrc/autograd/function.h" +#include "torch/csrc/jit/constants.h" #include #include @@ -568,161 +569,56 @@ Value* Value::setUniqueName(const std::string & name) { return this; } -template -Value* Graph::insertConstant(T value) { - Node *n = create(prim::Constant); - insertNode(n); - auto t_value = as_tensor(value); - n->t_(attr::value, t_value.clone()); - n->output()->inferTypeFrom(t_value); - return n->output(); -} - -// This is necessary, because integral literals are of type int by default, -// and will dispatch to this function. 
-template<> -Value * Graph::insertConstant(int value) { - return insertConstant(static_cast(value)); -} - -template Value* Graph::insertConstant(int64_t value); -template Value* Graph::insertConstant(double value); -template Value* Graph::insertConstant(at::Tensor value); -template Value* Graph::insertConstant(at::IntList value); -template Value* Graph::insertConstant(at::Scalar value); - -namespace { - -// Of course any sane person would define this thing as a templated function, but -// it so happens that clang 3.8 has a pretty annoying bug which makes it complain that -// specializations are redefinitions of themselves, and so here we are. -template -struct getattr {}; - -template<> -struct getattr { - int64_t operator()(Node *n, Symbol name) { - return n->i(name); - } -}; - -template<> -struct getattr { - double operator()(Node *n, Symbol name) { - return n->f(name); - } -}; - -template<> -struct getattr { - at::Tensor operator()(Node *n, Symbol name) { - return n->t(name); - } -}; - -template<> -struct getattr> { - std::vector operator()(Node *n, Symbol name) { - return n->is(name); +std::pair findArgument(const FunctionSchema& the_schema, Symbol name) { + auto name_str = name.toUnqualString(); + for (size_t i = 0; i < the_schema.arguments.size(); ++i) { + const Argument* arg = &the_schema.arguments[i]; + if (arg->name == name_str) { + return std::make_pair(i, arg); + } } -}; - -} // anonymous namespace + throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString()); +} -template -at::optional Node::get(Symbol name) { +at::optional Node::get(Symbol name) const { // TODO (apaszke): remove. this is in here for now just so that we can ensure // we always use this in places where the node has a valid schema already // (will make next commits easier). - if (!schema_) findSchema(); - // TODO (apaszke): remove once tracer and compiler stop emitting attributes - if (hasAttributes()) { - // If it has an attribute, then it is a constant. If it's missing, it means we're - // doing an invalid lookup and it should throw anyway. 
- return getattr()(this, name); - } - auto inp = findInput(name); - const Argument & arg = inp.second; - if (!inp.first) { - return tensor_as(arg.default_value.value()); - } - Node *producer = inp.first->node(); - if (producer->kind() != prim::Constant) return at::nullopt; - auto value = producer->t(attr::value); - return tensor_as(std::move(value)); -} - -template at::optional Node::get(Symbol name); -template at::optional Node::get(Symbol name); -template at::optional Node::get(Symbol name); -template at::optional> Node::get(Symbol name); - -at::optional Node::get(Symbol name) { - // TODO (apaszke): remove once tracer and compiler stop emitting attributes if (hasAttribute(name)) { switch (kindOf(name)) { case AttributeKind::i: - return IValue{as_tensor(i(name))}; - case AttributeKind::t: - return IValue{as_tensor(t(name))}; + return IValue(i(name)); + case AttributeKind::f: + return IValue(f(name)); + case AttributeKind::t: { + // attributes are ambiguous, this might be a at::Scalar + // disambiguate via schema + at::Tensor ten = t(name); + const Argument* arg = findArgument(schema(), name).second; + if(arg->type->isSubtypeOf(*NumberType::get())) { + return IValue(at::Scalar(ten)); + } + return IValue(ten); + } case AttributeKind::is: - return IValue{as_tensor(is(name))}; + return IValue(is(name)); default: throw std::runtime_error("get() NYI"); } } - auto inp = findInput(name); - const Argument & arg = inp.second; - if (!inp.first) { - return IValue{arg.default_value.value()}; - } - Node * producer = inp.first->node(); - if (producer->kind() != prim::Constant) return at::nullopt; - auto value = producer->t(attr::value); - return IValue{std::move(value)}; -} - -Value* Node::input(Symbol name) { - // TODO (apaszke): remove once tracer and compiler stop emitting attributes - if (hasAttribute(name)) { - switch (kindOf(name)) { - case AttributeKind::i: - return owningGraph()->insertConstant(i(name)); - case AttributeKind::is: - return owningGraph()->insertConstant(is(name)); - case AttributeKind::t: - return owningGraph()->insertConstant(t(name)); - default: - throw std::runtime_error("getValue() NYI"); - } - } - auto inp = findInput(name); - if (inp.first) return inp.first; - return owningGraph()->insertConstant(inp.second.default_value.value()); + return toIValue(namedInput(name)); } -// XXX: the first coordinate can be a nullptr, which means that you should use -// the default value for this arg, because it's optional and missing -std::pair Node::findInput(Symbol name) { - if (!schema_) { - findSchema(); +Value* Node::namedInput(Symbol name) const { + if(hasAttribute(name)) { + // XXX - const cast because this really should not be modifying graph + // and once we remove attributes it no longer will + Value* v = insertConstant(const_cast(*owningGraph()), get(name).value()); + // XXX - insert point can be anywhere since modifying the graph is unexpected, + // so this is completely unsafe and needs to be gone as soon as possible. 
+ return v; } - auto name_str = name.toUnqualString(); - size_t input_i = 0; - for (size_t i = 0; i < schema_->arguments.size(); ++i) { - const auto & arg = schema_->arguments[i]; - if (hasAttributeS(arg.name)) continue; - if (arg.name == name_str) { - if (input_i < inputs().size()) { - return std::pair(input(input_i), arg); - } else { - JIT_ASSERT(arg.default_value); - return std::pair(nullptr, arg); - } - } - input_i++; - } - throw std::runtime_error(std::string("Couldn't find an argument called ") + name.toQualString()); + return input(findArgument(schema(), name).first); } bool Node::matches(const char *signature_literal, at::ArrayRef const_inputs) { @@ -733,7 +629,7 @@ bool Node::matches(const char *signature_literal, at::ArrayRef const_inp return true; } -void Node::findSchema() { +void Node::findSchema() const { schema_ = &getOperatorFor(this).schema; } diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index a2f71f702bc9b..229b14aa107d1 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1,3 +1,4 @@ + #pragma once #include "torch/csrc/jit/attributes.h" @@ -290,7 +291,8 @@ struct Node : public Attributes { // This field is effective a cache that's populated on attribute lookups and // invalidated every time we perform an operation that could potentially change // the schema. - const FunctionSchema* schema_; + // note: mutable because schema_ is effectively a cache + mutable const FunctionSchema* schema_; protected: Node(Graph * graph_, NodeKind kind_); //defined after graph public: @@ -392,14 +394,21 @@ struct Node : public Attributes { Value * input(size_t i) { return inputs_.at(i); } - const Value * input(size_t i) const { + Value * input(size_t i) const { return inputs_.at(i); } + Value* namedInput(Symbol name) const; + + at::optional get(Symbol name) const; + template - at::optional get(Symbol name); - at::optional get(Symbol name); - Value* input(Symbol name); + at::optional get(Symbol name) const { + if(auto v = get(name)) + return v->template to(); + return at::nullopt; + } + // Returns true if the value of input name is statically known bool is_constant(Symbol name) { @@ -666,15 +675,16 @@ struct Node : public Attributes { // XXX: this function is meant to be used with string literals only! 
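For reference, a small sketch of the reworked argument lookup on Node (illustrative only, not part of the patch; the wrapper function is hypothetical, while attr::dim is an existing interned symbol used elsewhere in this patch):

// Sketch only: relies on Node::get<T> and Node::namedInput from ir.h above.
bool unsqueezes_along_dim_zero(torch::jit::Node* node) {
  // get<T> yields a value only when it is statically known
  // (a legacy attribute or a prim::Constant input); otherwise at::nullopt.
  if (auto dim = node->get<int64_t>(torch::jit::attr::dim)) {
    return *dim == 0;
  }
  // namedInput always produces a Value*, inserting a constant for legacy attributes.
  torch::jit::Value* dim_value = node->namedInput(torch::jit::attr::dim);
  (void)dim_value;
  return false;
}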
bool matches(const char *signature_literal, at::ArrayRef const_inputs={}); - const FunctionSchema& schema() { - if (!schema_) findSchema(); + const FunctionSchema& schema() const { + if (!schema_) + findSchema(); return *schema_; } virtual ~Node() {} private: std::pair findInput(Symbol name); - void findSchema(); + void findSchema() const; // Lookup iterator in use list of _input i_ that corresponds to its use of _this_ use_list::iterator findUseForInput(size_t i) { auto & input_uses = inputs_[i]->uses_; @@ -974,13 +984,7 @@ friend struct Block; Node * createUndefined() { return create(prim::Undefined); } - Node * createConstant(const at::Tensor& ref) { - JIT_ASSERT(ref.defined()); - auto n = create(prim::Constant); - n->t_(attr::value, ref.clone()); - n->output()->inferTypeFrom(ref); - return n; - } + Node * createFusionGroup(int device) { auto n = create(prim::FusionGroup, 0); n->g_(attr::Subgraph,std::make_shared(scope_root_)); @@ -1002,6 +1006,25 @@ friend struct Block; } return n; } + Node* createList(const TypePtr& elem_type, at::ArrayRef values) { + auto n = create(prim::ListConstruct, values); + for(const auto & v : values) { + JIT_ASSERT(v->type()->isSubtypeOf(*elem_type)); + } + n->output()->setType(std::make_shared(elem_type)); + return n; + } + Node* createNumToTensor(Value* value) { + auto typ = value->type(); + Node * result = create(prim::NumToTensor, {value}); + result->output()->setType(TensorType::fromNumberType(typ)); + return result; + } + Node* createTensorToNum(const TypePtr& type, Value* value) { + auto* result = create(prim::TensorToNum, {value}); + result->output()->setType(type); + return result; + } Node* createPythonOp( THPObjectPtr&& pyobj, const std::string& cconv, @@ -1028,9 +1051,6 @@ friend struct Block; return r; } - template - Value * insertConstant(T value); - Node * appendNode(Node * n) { return block_->appendNode(n); } diff --git a/torch/csrc/jit/ivalue.h b/torch/csrc/jit/ivalue.h index c31436de5ab10..06187b8054cf0 100644 --- a/torch/csrc/jit/ivalue.h +++ b/torch/csrc/jit/ivalue.h @@ -87,6 +87,10 @@ using DoubleList = ConstantList; // The tag is currently 4 bytes to determine the type, and 1 byte // to mark whether that type is a subtype of at::Retainable and needs // retain/release calls. 
+ +#define TORCH_FORALL_TAGS(_) \ + _(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) + struct IValue { IValue() : payload(0) @@ -171,11 +175,15 @@ struct IValue { : tag(Tag::Int), retainable(false) { as_int = i; } + // allow you to pass literals (3, 4) without ambiguity IValue(int32_t i) : IValue(static_cast(i)) {} + IValue(bool b) + : IValue(static_cast(b)) {} bool isInt() const { return Tag::Int == tag; } + int64_t toInt() const { JIT_ASSERT(isInt()); return as_int; @@ -183,6 +191,9 @@ struct IValue { // IntList IValue(Shared v); + IValue(std::vector v); + IValue(at::ArrayRef v) + : IValue(std::vector(v.begin(), v.end())) {} bool isIntList() const { return Tag::IntList == tag; } Shared toIntList() && { JIT_ASSERT(isIntList()); @@ -193,8 +204,11 @@ struct IValue { return toRetainable(); } + std::vector copyToIntList() const; + // DoubleList IValue(Shared v); + IValue(std::vector v); bool isDoubleList() const { return Tag::DoubleList == tag; } Shared toDoubleList() && { JIT_ASSERT(isDoubleList()); @@ -209,6 +223,51 @@ struct IValue { return Tag::None == tag; } + // Scalar, which gets encoded as either an Int or a Double + IValue(at::Scalar s) + : IValue() { + if(s.isFloatingPoint()) { + *this = s.toDouble(); + } else { + *this = s.toLong(); + } + } + bool isScalar() { + return isDouble() || isInt(); + } + at::Scalar toScalar() const { + if(isDouble()) + return toDouble(); + else if(isInt()) + return toInt(); + else + throw std::runtime_error("IValue is not a Scalar"); + } + + // for debugging + std::string tagKind() { + switch(tag) { + #define DEFINE_CASE(x) case Tag::x: return #x; + TORCH_FORALL_TAGS(DEFINE_CASE) + #undef DEFINE_CASE + } + return "Invalid Tag"; + } + + // generic v.to() implementations + // that can be used in special functions like pop/push + // that use template meta-programming. + // prefer the directly named methods when you can, + // since they are simpler to understand + + // Note: if you get linker errors saying one of these is missing, + // change it to ... && = delete; and you will see better error messages for why + // However, we cannot commit this because some compiler versions barf on it. 
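Editor's sketch (illustrative, not part of the patch): round-tripping values through the tagged union declared in this hunk, using only the constructors and accessors added above. The free function and its name are assumptions for the example.

#include <vector>
#include "torch/csrc/jit/ivalue.h"

// Illustrative only; exercises the Int, Double and IntList tags.
static void ivalue_round_trip() {
  using torch::jit::IValue;
  IValue i = 3;                               // Int tag via the int32_t constructor
  IValue d = at::Scalar(1.5);                 // floating-point Scalar becomes a Double
  IValue xs = std::vector<int64_t>{1, 2, 3};  // IntList tag

  int64_t iv = i.toInt();
  at::Scalar s = d.toScalar();                // Int or Double both qualify as Scalar
  std::vector<int64_t> copy = xs.copyToIntList();

  // The generic to<T>() declared just below dispatches to the same accessors.
  double dv = d.to<double>();
  (void)iv; (void)s; (void)copy; (void)dv;
}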
+ template + T to() &&; + template + T to() const &; + private: template Shared moveToRetainable() { @@ -226,7 +285,9 @@ struct IValue { retainable = false; } enum class Tag : uint32_t { - None, Tensor, Double, Int, Tuple, IntList, DoubleList + #define DEFINE_TAG(x) x, + TORCH_FORALL_TAGS(DEFINE_TAG) + #undef DEFINE_TAG }; union { at::TensorImpl* as_tensor_impl; @@ -241,6 +302,29 @@ struct IValue { bool retainable; }; +#undef TORCH_FORALL_TAGS + + +#define DEFINE_TO(type, method_name) \ +template<> \ +inline type IValue::to() && { \ + return std::move(*this).method_name(); \ +} \ +template<> \ +inline type IValue::to() const & { \ + return this->method_name(); \ +} +DEFINE_TO(at::Tensor, toTensor) +DEFINE_TO(Shared, toTuple) +DEFINE_TO(double, toDouble) +DEFINE_TO(int64_t, toInt) +DEFINE_TO(Shared, toDoubleList) +DEFINE_TO(Shared, toIntList) +DEFINE_TO(at::Scalar, toScalar) +DEFINE_TO(bool, toInt) +DEFINE_TO(std::vector, copyToIntList) + +#undef DEFINE_TO // non-mutable list template @@ -257,6 +341,9 @@ struct ConstantList : at::Retainable { at::ArrayRef elements() const { return elements_; } + operator at::ArrayRef() const { + return elements(); + } }; inline IValue::IValue(Shared v) @@ -268,11 +355,18 @@ inline IValue::IValue(Shared v) : tag(Tag::IntList), retainable(true) { as_retainable = v.detach(); } +inline IValue::IValue(std::vector v) +: IValue(IntList::create(std::move(v))) {} inline IValue::IValue(Shared v) : tag(Tag::DoubleList), retainable(true) { as_retainable = v.detach(); } +inline IValue::IValue(std::vector v) +: IValue(DoubleList::create(std::move(v))) {} +inline std::vector IValue::copyToIntList() const { + return std::vector(toIntList()->elements()); +} }} diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 652eb90bf3797..26e314c53eaa3 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -2,7 +2,7 @@ #include "torch/csrc/jit/script/lexer.h" #include "torch/csrc/jit/script/tree.h" #include "torch/csrc/jit/operator.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/jit/script/error_report.h" namespace torch { namespace jit { @@ -51,23 +51,16 @@ struct SchemaParser { {"Layout", IntType::get() }, {"Device", ListType::ofInts() }, {"Scalar", NumberType::get() }, + {"float", FloatType::get() }, + {"int", IntType::get() }, + {"bool", IntType::get() }, // TODO: add separate bool type }; - switch(L.cur().kind) { - case TK_FLOAT: - L.next(); - return FloatType::get(); - case TK_INT: - case TK_BOOL: // TODO: add separate bool type - L.next(); - return IntType::get(); - default: - auto tok = L.expect(TK_IDENT); - auto text = tok.text(); - auto it = type_map.find(text); - if(it == type_map.end()) - throw ErrorReport(tok.range) << "unknown type specifier"; - return it->second; - } + auto tok = L.expect(TK_IDENT); + auto text = tok.text(); + auto it = type_map.find(text); + if(it == type_map.end()) + throw ErrorReport(tok.range) << "unknown type specifier"; + return it->second; } void parseType(Argument& arg) { arg.type = parseBaseType(); @@ -103,65 +96,70 @@ struct SchemaParser { parseType(arg); args.push_back(std::move(arg)); } - at::Tensor parseSingleConstant(TypeKind kind) { + IValue parseSingleConstant(TypeKind kind) { switch(L.cur().kind) { case TK_TRUE: L.next(); - return one(); + return true; case TK_FALSE: L.next(); - return zero(); - case TK_FLOAT: - L.next(); - return as_tensor(static_cast(at::kFloat)); + return false; case TK_IDENT: { auto tok = L.next(); auto text = tok.text(); - if("cpu" == text) { - 
return as_tensor(static_cast(at::Device::Type::CPU)); + if("float" == text) { + return static_cast(at::kFloat); + } else if("cpu" == text) { + return static_cast(at::Device::Type::CPU); } else if("strided" == text) { - return as_tensor(static_cast(at::kStrided)); + return static_cast(at::kStrided); } else if("ElementwiseMean" == text) { - return as_tensor(static_cast(Reduction::ElementwiseMean)); + return static_cast(Reduction::ElementwiseMean); } else { throw ErrorReport(L.cur().range) << "invalid numeric default value"; } - } default: + } + default: std::string n; if(L.nextIf('-')) n = "-" + L.expect(TK_NUMBER).text(); else n = L.expect(TK_NUMBER).text(); if(kind == TypeKind::FloatType || n.find(".") != std::string::npos || n.find("e") != std::string::npos) { - return at::full({}, std::stod(n), at::kDouble); // float? + return std::stod(n); } else { int64_t v = std::stoll(n); - return at::full({}, v, at::kLong); + return v; } } } - at::Tensor parseConstantList(TypeKind kind) { + IValue convertToList(TypeKind kind, const SourceRange& range, std::vector vs) { + switch(kind) { + case TypeKind::FloatType: + return fmap(vs, [](IValue v) { + return v.toDouble(); + }); + case TypeKind::IntType: + return fmap(vs, [](IValue v) { + return v.toInt(); + }); + default: + throw ErrorReport(range) << "lists are only supported for float or int types."; + } + } + IValue parseConstantList(TypeKind kind) { auto tok = L.expect('['); - std::vector vs; + std::vector vs; if(L.cur().kind != ']') { do { vs.push_back(parseSingleConstant(kind)); } while(L.nextIf(',')); } L.expect(']'); - if(vs.size() == 0) { - switch(kind) { - case TypeKind::FloatType: - return at::empty({}, at::kFloat); - case TypeKind::IntType: - return at::empty({}, at::kLong); - default: - throw ErrorReport(tok) << "empty lists are only supported for float or int types."; - } - } - return at::stack(vs); + return convertToList(kind, tok.range, std::move(vs)); } - at::Tensor parseTensorDefault(const SourceRange& range) { + + IValue parseTensorDefault(const SourceRange& range) { if("None" == L.expect(TK_IDENT).text()) { return at::Tensor(); } else { @@ -184,7 +182,9 @@ struct SchemaParser { if(L.cur().kind == TK_IDENT) { arg.default_value = parseTensorDefault(range); } else if(arg.N && L.cur().kind != '[') { - arg.default_value = parseSingleConstant(elem_kind->kind()).expand({*arg.N}); + IValue v = parseSingleConstant(elem_kind->kind()); + std::vector repeated(*arg.N, v); + arg.default_value = convertToList(elem_kind->kind(), range, repeated); } else { arg.default_value = parseConstantList(elem_kind->kind()); } @@ -209,14 +209,6 @@ struct SchemaParser { } Lexer L; bool kwarg_only; - static at::Tensor one() { - static at::Tensor v = at::full({}, 1, at::kLong); - return v; - } - static at::Tensor zero() { - static at::Tensor v = at::full({}, 0, at::kLong); - return v; - } }; } @@ -346,23 +338,10 @@ at::optional attributeKindOf(TypePtr type) { } bool typeMatches(TypePtr actual, TypePtr formal) { - if(actual->isSubtypeOf(*formal)) - return true; - - // XXX - this is here because we allow tensors to be used in place of numbers - // or lists of numbers in the script because of the restriction that all inputs to script must be tensors. - // Once numbers are always treated as seperate types from Tensors, this line - // should be removed, since it opens up the possibility of ambigous declarations - // dispatching to the wrong implementation. 
- if ((formal->isSubtypeOf(*NumberType::get()) || - formal->isSubtypeOf(*ListType::ofInts())) && - actual->isSubtypeOf(*DynamicType::get())) - return true; - - return false; + return actual->isSubtypeOf(*formal); } -bool Operator::matches(Node* node) const { +bool Operator::matches(const Node* node) const { if (node->kind().toQualString() != schema.name) { return false; } @@ -420,7 +399,7 @@ bool Operator::matches(Node* node) const { return true; } -std::shared_ptr findOperatorFor(Node* node) { +std::shared_ptr findOperatorFor(const Node* node) { const auto& candidates = getAllOperatorsFor(node->kind()); for(const auto& candidate : candidates) { if(candidate->matches(node)) { @@ -430,7 +409,7 @@ std::shared_ptr findOperatorFor(Node* node) { return nullptr; } -const Operator& getOperatorFor(Node* node) { +const Operator& getOperatorFor(const Node* node) { auto op = findOperatorFor(node); if(op) return *op; diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h index 74f0ec95d8bce..02202160e146f 100644 --- a/torch/csrc/jit/operator.h +++ b/torch/csrc/jit/operator.h @@ -32,7 +32,7 @@ struct TORCH_API Operator { FunctionSchema schema; - bool matches(Node* n) const; + bool matches(const Node* n) const; // Operators have different versions depending on if some inputs are encoded // as attributes or inputs. This function returns the right Operation function, // given a node encoded for one variant. @@ -45,14 +45,17 @@ struct TORCH_API Operator { return op(n); } } + bool hasAttributedVersion() const { + return op_const_attributes != nullptr; + } private: OperationCreator op; OperationCreator op_const_attributes; }; const std::vector>& getAllOperatorsFor(Symbol name); -std::shared_ptr findOperatorFor(Node* node); -const Operator& getOperatorFor(Node* node); +std::shared_ptr findOperatorFor(const Node* node); +const Operator& getOperatorFor(const Node* node); inline Operation getOperation(Node* node) { // note: getOperatorFor ensures that getOperatorFor(node).matches(node) == true diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 0d182bd8fd37c..24d6979fb80a4 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -2,6 +2,7 @@ #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/interned_strings.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/utils/functional.h" #include @@ -181,7 +182,8 @@ void BatchMMBlock(Block* block) { if (!root || root.tree_size < min_fusion_size) continue; auto matmuls = root.gatherMatMuls(); - auto type = root.node->output()->type()->expect(); + auto type_ = root.node->output()->type(); + auto type = type_->expect(); auto batch_inputs = [&](Side s, std::array cat_sizes) -> Value* { int inputs_off = s == Side::LHS ? 
0 : 1; @@ -190,7 +192,7 @@ void BatchMMBlock(Block* block) { auto inputs = fmap(matmuls, [=](Node *mm) { return mm->inputs()[inputs_off]; }); WithInsertPoint iguard { root.node }; - inputs.push_back(graph->insertConstant(cat_dim)); + inputs.push_back(insertConstant(*graph, cat_dim)); Node *cat = graph->insertNode(graph->create(aten::cat, inputs)); cat->output()->setType(type->withSizes(cat_sizes)); return cat->output(); @@ -199,7 +201,7 @@ void BatchMMBlock(Block* block) { auto lhs_batch = batch_inputs(Side::LHS, root.lhs_sizes); auto rhs_batch = batch_inputs(Side::RHS, root.rhs_sizes); Node *batch_mm = graph->create(aten::mm, {lhs_batch, rhs_batch}); - batch_mm->output()->setType(type->asShared()); + batch_mm->output()->setType(type_); batch_mm->insertBefore(root.node); root.node->output()->replaceAllUsesWith(batch_mm->output()); // NB: don't bother with cleaning up after yourself. We'll use DCE for that. diff --git a/torch/csrc/jit/passes/decompose_addmm.cpp b/torch/csrc/jit/passes/decompose_addmm.cpp index 1a0dd94f9960d..3101e12474308 100644 --- a/torch/csrc/jit/passes/decompose_addmm.cpp +++ b/torch/csrc/jit/passes/decompose_addmm.cpp @@ -1,7 +1,7 @@ #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/passes/decompose_addmm.h" #include "torch/csrc/jit/symbolic_variable.h" -#include "torch/csrc/jit/tensor_conversions.h" + namespace torch { namespace jit { @@ -16,10 +16,10 @@ static void DecomposeAddmm(Block* block) { // shape analysis and differentiation passes for those two individual ops. // Later, we will fuse together those two ops into a single addmm. if (it->kind() == aten::addmm && it->inputs().size() == 3) { - auto alpha = it->get(attr::alpha); - auto beta = it->get(attr::beta); + auto alpha = it->get(attr::alpha); + auto beta = it->get(attr::beta); if (!alpha || !beta) continue; - if (tensor_as(*alpha) != 1.0 || tensor_as(*beta) != 1.0) continue; + if (alpha->toDouble() != 1.0 || beta->toDouble() != 1.0) continue; WithInsertPoint guard(*it); diff --git a/torch/csrc/jit/passes/erase_number_types.cpp b/torch/csrc/jit/passes/erase_number_types.cpp index 03e77e32da122..91f08c0941e7c 100644 --- a/torch/csrc/jit/passes/erase_number_types.cpp +++ b/torch/csrc/jit/passes/erase_number_types.cpp @@ -1,15 +1,8 @@ #include "torch/csrc/jit/passes/erase_number_types.h" +#include "torch/csrc/jit/constants.h" namespace torch { namespace jit { -static bool isNumberTypeCast(const Value* value, const Use& use) { - auto* node = use.user; - if (node->kind() != aten::type_as) { - return false; - } - return node->inputs()[0] == value; -} - static void EraseNumberTypesOnBlock(Block* block) { for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; ++it) { @@ -18,26 +11,24 @@ static void EraseNumberTypesOnBlock(Block* block) { } switch (it->kind()) { case prim::Constant: { - it->output()->inferTypeFrom(it->t(attr::value)); - } break; - case prim::TensorToNum: { - it->output()->replaceAllUsesWith(it->inputs()[0]); - // Let DCE cleanup + // remove primitive constants, replacing with tensor equivalent + // ONNX does not support non-tensor constants + if(it->output()->type()->isSubtypeOf(*NumberType::get())) { + auto s = *constant_as(it->output()); + WithInsertPoint guard(*it); + Value* r = insertConstant(*block->owningGraph(), s.toTensor()); + it->output()->replaceAllUsesWith(r); + } } break; + case prim::TensorToNum: case prim::NumToTensor: { - auto* ten = it->output(); - for (const auto& use : ten->uses()) { - if (isNumberTypeCast(ten, use)) { - 
use.user->output()->replaceAllUsesWith(ten); - } - } - ten->replaceAllUsesWith(it->inputs()[0]); + it->output()->replaceAllUsesWith(it->inputs()[0]); // Let DCE cleanup } break; default: { for(auto o : it->outputs()) { if (o->type()->isSubtypeOf(*NumberType::get())) { - o->setType(DynamicType::get()); + o->setType(TensorType::fromNumberType(o->type())); } } } break; diff --git a/torch/csrc/jit/passes/erase_number_types.h b/torch/csrc/jit/passes/erase_number_types.h index 5ec43ce575b86..561078f92a18e 100644 --- a/torch/csrc/jit/passes/erase_number_types.h +++ b/torch/csrc/jit/passes/erase_number_types.h @@ -5,13 +5,12 @@ namespace torch { namespace jit { // Erase NumberType information. This is necessary for and only used in -// exporting to ONNX. -// +// exporting to ONNX. This pass ensures that no remaining Values have +// NumberType types, replacing them with tensors. // The following things are done to erase NumberType info: // - NumberType outputs are changed to DynamicType. -// - Any aten::type_as nodes that are added to correct Number math -// are removed because ONNX export does not support them. -// - prim::Constant nodes' outputs get assigned their default type from ir.h +// - prim::Constant nodes which are numbers get changed into 0-dim tensors of +// the corresponding type // - prim::TensorToNum, and prim::NumToTensor nodes are erased. // // The pass assumes that DCE will be called sometime after. diff --git a/torch/csrc/jit/passes/loop_unrolling.cpp b/torch/csrc/jit/passes/loop_unrolling.cpp index bbf05f198a3b6..c3a7e60cf3777 100644 --- a/torch/csrc/jit/passes/loop_unrolling.cpp +++ b/torch/csrc/jit/passes/loop_unrolling.cpp @@ -2,8 +2,9 @@ #include "torch/csrc/jit/interned_strings.h" #include "torch/csrc/jit/symbolic_variable.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/jit/passes/dead_code_elimination.h" +#include "torch/csrc/jit/constants.h" namespace torch { namespace jit { @@ -132,20 +133,35 @@ void repeatBody(Block *body, int64_t times) { EliminateDeadCode(body, false); } +//TODO(zach): we need to replace these with a generic facility for resolving overloads +// currently we cant us SymbolicVariable because it assumes we are computing on tensors +// once we have something like emitBuiltinCall usable outside of the compiler, +// we can replace these with symbolic variable +Value* intMath(Symbol sym, Value* a, Value* b) { + auto& g = *a->owningGraph(); + return g.insertNode(g.create(sym, {a, b})) + ->output() + ->setType(IntType::get()); +} +Value* intMath(Symbol sym, Value* a, int64_t b) { + return intMath(sym, a, insertConstant(*a->owningGraph(), b)); +} + // Replaces the builtin loop counter with a "mutable" variable outside of the loop. 
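Editor's sketch (illustrative, not part of the patch): the intMath helpers above build plain integer arithmetic nodes (typed IntType) without going through SymbolicVariable, which assumes tensor operands. For example, the epilogue trip count that unroll() computes further down, iters - (iters / k) * k, can be assembled as below; the function name is hypothetical.

static torch::jit::Value* epilogue_trip_count(torch::jit::Value* iters, int64_t k) {
  using namespace torch::jit;
  Value* unrolled = intMath(aten::div, iters, k);    // iters / k
  Value* consumed = intMath(aten::mul, unrolled, k); // (iters / k) * k
  return intMath(aten::sub, iters, consumed);        // leftover iterations for the epilogue
}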
void replaceLoopCounter(Node *loop) { Graph *graph = loop->owningGraph(); Block *body = loop->blocks().at(0); - Node *init_counter_node = graph->createConstant(at::CPU(at::kLong).scalarTensor(0)) - ->insertBefore(loop); - loop->insertInput(2, init_counter_node->output()); + WithInsertPoint guard(loop); + Value* init_counter = insertConstant(*graph, 0); + + loop->insertInput(2, init_counter); loop->insertOutput(0); Value * internal_counter = body->insertInput(1); body->inputs()[0]->replaceAllUsesWith(internal_counter); WithInsertPoint insertPointGuard{ body->return_node() }; - body->insertOutput(1, SymbolicVariable(internal_counter) + at::Scalar(1)); + body->insertOutput(1, intMath(aten::add, internal_counter, 1) ); } void unroll(Node *loop) { @@ -183,10 +199,10 @@ void unroll(Node *loop) { repeatBody(body, kUnrollFactor); // Change the iteration counts of both loops - SymbolicVariable iter_count = loop->inputs().at(0); - SymbolicVariable unrolled_iter_count = iter_count / kUnrollFactor; + Value* iter_count = loop->inputs().at(0); + Value* unrolled_iter_count = intMath(aten::div, iter_count, kUnrollFactor); loop->replaceInput(0, unrolled_iter_count); - loop_epilogue->replaceInput(0, iter_count - (unrolled_iter_count * kUnrollFactor)); + loop_epilogue->replaceInput(0, intMath(aten::sub, iter_count, intMath(aten::mul, unrolled_iter_count , kUnrollFactor))); } void UnrollLoops(Block *block) { diff --git a/torch/csrc/jit/passes/lower_grad_of.cpp b/torch/csrc/jit/passes/lower_grad_of.cpp index ddd31bce5b3d4..1ed492fc1dd04 100644 --- a/torch/csrc/jit/passes/lower_grad_of.cpp +++ b/torch/csrc/jit/passes/lower_grad_of.cpp @@ -10,8 +10,11 @@ void LowerGradOf(Graph& g) { // else: // outputs = undefineds WithInsertPoint guard(*it); - auto cond = g.insertNode(g.create(prim::AnyDefined, it->inputs())); - auto if_stat = g.insertNode(g.create(prim::If,{cond->output()}, it->outputs().size())); + auto cond = g.insertNode(g.create(prim::AnyDefined, it->inputs())) + ->output() + ->setType(IntType::get()); + auto if_stat = g.insertNode( + g.create(prim::If, {cond}, it->outputs().size())); if_stat->addBlock()->cloneFrom( it->blocks().at(0), [](Value* v) { return v; }); auto else_block = if_stat->addBlock(); diff --git a/torch/csrc/jit/passes/onnx/peephole.cpp b/torch/csrc/jit/passes/onnx/peephole.cpp index b3c87c914d159..fc8845a5d1860 100644 --- a/torch/csrc/jit/passes/onnx/peephole.cpp +++ b/torch/csrc/jit/passes/onnx/peephole.cpp @@ -259,6 +259,22 @@ void pushPackingPastRnn(Block *b) { newPackPadded->addInput(next->outputs()[0]); newPackPadded->addInput(n->inputs()[1]); + // See https://github.com/pytorch/pytorch/issues/9043 for a full + // description. Since PackPadded is for now treated in an + // unhygenic way, Pytorch ends up propagating an incorrect type. + // Until a long-term cleanup comes around, we can fix this by + // resetting the size to the correct value. 
+ TensorType* oldType = rnn->inputs()[0]->type()->cast(); + if (oldType) { + std::vector new_sizes; + new_sizes.push_back(oldType->sizes()[0]); + new_sizes.push_back(oldType->sizes()[1]); + new_sizes.push_back(rnn->i(attr::hidden_size)); + TensorTypePtr newType = std::make_shared( + oldType->scalarType(), oldType->device(), new_sizes); + next->outputs()[0]->setType(newType); + } + it.destroyCurrent(); } } diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index feebbcf2fd505..2ee777aee0e66 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -1,7 +1,7 @@ #include "torch/csrc/jit/passes/peephole.h" #include "torch/csrc/jit/symbolic_variable.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/jit/passes/dead_code_elimination.h" namespace torch { namespace jit { @@ -27,7 +27,7 @@ void PeepholeOptimize(Block * block) { if (node->matches("aten::expand(Tensor self, int[] size, *, int implicit) -> Tensor", /*with_const=*/attr::size)) { // x.expand(x.size()) == x - if (auto input_type = node->input(attr::self)->type()->cast()) { + if (auto input_type = node->namedInput(attr::self)->type()->cast()) { auto expanded_sizes = node->get>(attr::size); if (expanded_sizes == input_type->sizes()) { node->output()->replaceAllUsesWith(node->input()); @@ -51,7 +51,7 @@ void PeepholeOptimize(Block * block) { } else if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", /*with_const=*/attr::alpha)) { // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z - if (tensor_as(node->get(attr::alpha).value()) == 1.) { + if (node->get(attr::alpha).value().toDouble() == 1.) { // Look for mm from both sides of the add for (size_t mm_side = 0; mm_side < 2; mm_side++) { if (node->input(mm_side)->node()->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { @@ -69,6 +69,11 @@ void PeepholeOptimize(Block * block) { } } } + } else if(node->kind() == prim::TensorToNum) { + Node* input_node = node->input()->node(); + if (input_node->kind() == prim::NumToTensor) { + node->output()->replaceAllUsesWith(input_node->input()); + } } } } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 1775a57326fc5..3acc520608270 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1,6 +1,7 @@ #include "torch/csrc/jit/passes/shape_analysis.h" #include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/jit/argument_spec.h" #include "torch/csrc/jit/operator.h" @@ -21,9 +22,9 @@ struct propagation_error : std::exception {}; namespace { -void setDynamicType(Node * node) { +void setUnshapedType(Node * node) { for(auto o : node->outputs()) { - o->setType(DynamicType::get()); + o->setType(unshapedType(o->type())); } } @@ -34,47 +35,66 @@ int64_t wrapDim(int64_t dim, at::IntList sizes) { return dim; } -at::Tensor representativeTensor(const TensorType * type) { - auto backend = type->device() == -1 ? at::kCPU : at::kCUDA; - at::DeviceGuard device_guard(type->device()); - auto & attype = at::getType(backend, type->scalarType()); - return attype.tensor(type->sizes(), type->strides()).zero_(); +IValue representativeValue(Value* v) { + TypePtr type_ = v->type(); + // if the value is actually constant, just use it! + if(auto iv = toIValue(v)) { + return *iv; + } + if (TensorType* type = type_->cast()) { + auto backend = type->device() == -1 ? 
at::kCPU : at::kCUDA; + at::DeviceGuard device_guard(type->device()); + auto& attype = at::getType(backend, type->scalarType()); + return attype.tensor(type->sizes(), type->strides()).zero_(); + } else if (type_->isSubtypeOf(*FloatType::get())) { + return 0.f; + } + // we should not get here because isValidArgumentForRunning should have + // prevented it + std::stringstream ss; + ss << "unable to create representative value for: " << type_->str() + << ". File a bug report."; + throw std::runtime_error(ss.str()); } void PropagateShapeOnBlock(Block * block, bool insert_expands=true); +// for each node in the schema with type Tensor, extract the TensorType +// returns at::nullopt if any Tensor in the schema does not have a known shape +// ignores non-tensor in the list of inputs at::optional> gatherTensorTypes(Node *node) { std::vector tensor_types; - tensor_types.reserve(node->inputs().size()); - // TODO (apaszke): Remove once we stop using attributes - // XXX: we also make the exception for cat, because we need shape prop to work for it - // (we have tests). We'll have to remove the special case once we stop flattening lists into inputs. - if (node->hasAttributes() || node->kind() == aten::cat) { - std::vector inputs = node->inputs(); - if (node->kind() == aten::cat && inputs.back()->type()->isSubtypeOf(*IntType::get())) { - inputs.pop_back(); - } - for (Value *v : inputs) { - TensorType* type = v->type()->cast(); - if(!type) return at::nullopt; - tensor_types.push_back(type); - } - } else { - auto & schema = node->schema(); - auto & args = schema.arguments; - // XXX: This gets triggered for nodes that have Tensor[] as arguments. - // Those are currently very annoying to handle, because the lists are simply - // inlined into the node inputs, so we bail out from shape propagation for now. 
- if (schema.is_vararg || args.size() != node->inputs().size()) { - return at::nullopt; + + auto & schema = node->schema(); + auto & args = schema.arguments; + // can't handle varargs primitives because we don't know what should be a Tensor + if (schema.is_vararg) { + return at::nullopt; + } + size_t input_i = 0; + for (auto& arg : args) { + size_t consume_n; // how many tensors do we check for in the input list + if (arg.type->isSubtypeOf(*ListType::ofTensors())) { + // we have a list of tensor, there is only ever one list + // so we calculte how many elements must be in it by how much bigger + // or smaller the input list is compared to the arguments in the schema + consume_n = node->inputs().size() + 1 - args.size(); + } else if (arg.type->isSubtypeOf(*DynamicType::get())) { + // a single Tensor for this argument + consume_n = 1; + } else { + // this argument is not a tensor, skip it + consume_n = 0; } - for (size_t i = 0; i < node->inputs().size(); ++i) { - if (!args[i].type->isSubtypeOf(*DynamicType::get())) continue; - TensorType *type = node->input(i)->type()->cast(); - if (!type) return at::nullopt; + for(size_t j = 0; j < consume_n; j++) { + // bail out if a tensor does not have a size + TensorType *type = node->input(input_i++)->type()->cast(); + if (!type) + return at::nullopt; tensor_types.push_back(type); } } + return tensor_types; } @@ -116,36 +136,12 @@ void broadcastBinary(Node *node, std::vector& types, size_t idx1, s types[1] = node->inputs().at(idx2)->type()->expect(); } -void PropagateShapeOnNodeByRunningIt(Node* node, const std::vector& types) { +void PropagateShapeOnNodeByRunningIt(Node* node) { auto op = getOperation(node); Stack stack; - size_t types_i = 0; - // TODO (apaszke): remove once we stop using attributes - if (node->hasAttributes()) { - for (auto & type : types) { - stack.push_back(representativeTensor(type)); - } - // TODO (apaszke): remove once aten::cat is saner (see first XXX in gatherTensorTypes) - } else if (node->kind() == aten::cat) { - for (auto & type : types) { - stack.push_back(representativeTensor(type)); - } - stack.push_back(node->get(attr::dim).value()); - } else { - JIT_ASSERT(node->schema().arguments.size() == node->inputs().size()); - for (const auto & arg : node->schema().arguments) { - if (arg.type->isSubtypeOf(*DynamicType::get())) { - stack.emplace_back(representativeTensor(types[types_i++])); - } else { - auto maybe_val = node->get(Symbol::attr(arg.name)); - if (!maybe_val) { - setDynamicType(node); - return; - } - stack.push_back(std::move(*maybe_val)); - } - } + for (auto input : node->inputs()) { + stack.push_back(representativeValue(input)); } // XXX: we're not catching any exceptions from the op for now. This @@ -156,10 +152,55 @@ void PropagateShapeOnNodeByRunningIt(Node* node, const std::vector& JIT_ASSERT(stack.size() == node->outputs().size()); for (size_t i = 0; i < stack.size(); ++i) { - node->outputs()[i]->inferTypeFrom(stack[i].toTensor()); + // some ops may have mixed tensor/primitive outputs + // for primitives, we don't need to change the type because it is already + // its most constrained form. + if(stack[i].isTensor()) + node->outputs()[i]->inferTypeFrom(stack[i].toTensor()); } } +// is it ok to try to run the op +// If an input is a constant, then we assume that the input is valid +// and we can try to run it. +// Otherwise: +// Integral typed _inputs_ are often an indicator that we're indexing into +// a tensor, so we should special-case these ops in the shape propagation. 
+// Additionally, passing in a zero representative tensor into an integer +// division op causes divide-by-zero errors +// _Outputs_ must be tensors or primtives +// We will call inferTypeFrom on the tensors, and ignore the primitives. +// However, we allow primitive returns because we want to support mixed +// primitive/tensor outputs. + +bool isValidArgumentForRunning(Value* v) { + // allow constants + if(toIValue(v)) + return true; + if(TensorType* tt = v->type()->cast()) { + return !at::isIntegralType(tt->scalarType()); + } + return v->type()->isSubtypeOf(*FloatType::get()); +} +bool isValidReturnForRunning(Value* v) { + return v->type()->isSubtypeOf(*DynamicType::get()) || + v->type()->isSubtypeOf(*NumberType::get()); +} + +bool canPropagateShapeByRunningIt(Node* node) { + bool valid_args = std::all_of( + node->inputs().begin(), node->inputs().end(), isValidArgumentForRunning); + if (!valid_args) + return false; + + bool valid_returns = std::all_of( + node->outputs().begin(), node->outputs().end(), isValidReturnForRunning); + if (!valid_returns) + return false; + + return true; +} + void PropagateShapeOnNode(Node * node, bool insert_expands) { // These don't require the types, and have complicated schema. Return early after we process them. switch(node->kind()) { @@ -198,28 +239,31 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { } return; } + case prim::TensorToNum: case prim::NumToTensor: - case prim::TensorToNum: { - node->output()->setType(node->inputs()[0]->type()); - return; - } + return; // correct num type is already set case prim::Constant: { - node->output()->inferTypeFrom(node->t(attr::value)); + if(node->output()->type()->isSubtypeOf(*DynamicType::get())) { + node->output()->inferTypeFrom(node->t(attr::value)); + } return; } case prim::PythonOp: case prim::Print: case prim::Undefined: { - setDynamicType(node); + setUnshapedType(node); return; } default: break; // fall-through } + bool can_propagate_by_running = canPropagateShapeByRunningIt(node); auto maybe_tensor_types = gatherTensorTypes(node); if (!maybe_tensor_types) { - return setDynamicType(node); + if(can_propagate_by_running) + return PropagateShapeOnNodeByRunningIt(node); + return setUnshapedType(node); } auto & tensor_types = *maybe_tensor_types; @@ -344,7 +388,7 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { return; } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { if (tensor_types.at(0)->scalarType() == tensor_types.at(1)->scalarType()) { - node->output()->setType(node->input(attr::self)->type()); + node->output()->setType(node->namedInput(attr::self)->type()); } else { // This will be a copy, so the result will be contiguous node->output()->setType(tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes())); @@ -377,23 +421,16 @@ void PropagateShapeOnNode(Node * node, bool insert_expands) { std::make_shared(at::kLong, -1, dims)); return; } else if (node->kind() == onnx::Reshape) { - setDynamicType(node); + setUnshapedType(node); return; } // If we haven't managed to handle the op so far, we fall back to inferring the // shapes by doing an example run of the op (if we can). - // Integral typed inputs are often an indicator that we're indexing into - // a tensor, so we should special-case these ops in the shape propagation. 
- // Additionally, passing in a zero representative tensor into an integer - // division op causes divide-by-zero errors - bool shape_inferenceable = !std::any_of(tensor_types.begin(), tensor_types.end(), [](TensorType* t){ - return at::isIntegralType(t->scalarType()); - }); - if (shape_inferenceable) { - PropagateShapeOnNodeByRunningIt(node, tensor_types); + if (can_propagate_by_running) { + PropagateShapeOnNodeByRunningIt(node); } else { - setDynamicType(node); + setUnshapedType(node); } } @@ -402,7 +439,7 @@ void PropagateShapeOnBlock(Block * block, bool insert_expands) { try { PropagateShapeOnNode(node, insert_expands); } catch(propagation_error& e) { - setDynamicType(node); + setUnshapedType(node); } catch(std::exception & e) { if(auto sl = node->getSourceLocation()) { sl->wrapAndRethrowException(e, "operation failed shape propagation"); diff --git a/torch/csrc/jit/python_interpreter.cpp b/torch/csrc/jit/python_interpreter.cpp index 5af53c4455b12..d8baa1859a85f 100644 --- a/torch/csrc/jit/python_interpreter.cpp +++ b/torch/csrc/jit/python_interpreter.cpp @@ -9,7 +9,7 @@ #include "torch/csrc/jit/operator.h" #include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/ir.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/variable_tensor_functions.h" #include diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index f03edfe2d6b3b..92445aaf6881d 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -236,7 +236,6 @@ void initPythonIRBindings(PyObject * module_) { .def("return_node", [](Graph &g) { return g.block()->return_node(); }) - .GS(createConstant) .GS(createFusionGroup) .def("createClone",[](Graph & g, Node * n, py::object fn) { return g.createClone(n, [&](Value * e) { diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 3a2ae20850de8..010e0919f9cd0 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -7,7 +7,7 @@ #include "torch/csrc/jit/graph_executor.h" #include "torch/csrc/jit/ir.h" #include "torch/csrc/jit/operator.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/variable_tensor_functions.h" #include @@ -51,19 +51,37 @@ RegisterOperators reg({ return 0; }; }), - Operator( - prim::Constant, - [](Node* node) { - auto t = autograd::make_variable(node->t(attr::value)); - return [t](Stack& stack) { - stack.push_back(t); + prim::TensorToNum, + [](Node* node) -> Operation { + if(node->output()->type() == IntType::get()) { + return [](Stack& stack) { + at::Tensor a; + pop(stack, a); + at::DeviceGuard guard(a); + push(stack, a.toCLong()); + return 0; + }; + } else { + return [](Stack& stack) { + at::Tensor a; + pop(stack, a); + at::DeviceGuard guard(a); + push(stack, a.toCDouble()); + return 0; + }; + } + }), + Operator( + prim::NumToTensor, + [](Node* node) -> Operation { + return [](Stack& stack) { + at::Scalar s; + pop(stack, s); + push(stack, autograd::make_variable(s.toTensor())); return 0; }; }), - - Operator(prim::NumToTensor, noop), - Operator(prim::TensorToNum, noop), Operator( prim::Undefined, [](Node* node) { @@ -121,11 +139,12 @@ RegisterOperators reg({ onnx::Reshape, [](Node* node) { return [=](Stack& stack) { - auto shape = pop(stack).toTensor().contiguous(); - auto input = pop(stack).toTensor(); + at::Tensor input, shape; + pop(stack, input, shape); + shape = shape.contiguous(); JIT_ASSERT(shape.ndimension() == 1); at::IntList shape_list(shape.data(), shape.size(0)); - 
stack.push_back(input.reshape(shape_list)); + push(stack, input.reshape(shape_list)); return 0; }; }), @@ -150,8 +169,6 @@ RegisterOperators reg({ prim::AnyDefined, [](Node* node) { size_t num_inputs = node->inputs().size(); - auto true_ = at::full({}, 1, at::kLong); - auto false_ = at::full({}, 0, at::kLong); return [=](Stack& stack) { bool result = false; for (const IValue& t : last(stack, num_inputs)) { @@ -161,7 +178,7 @@ RegisterOperators reg({ } } drop(stack, num_inputs); - stack.push_back(result ? true_ : false_); + stack.push_back(result); return 0; }; }), @@ -170,8 +187,8 @@ RegisterOperators reg({ prim::AutogradAdd, [](Node* node) { return [=](Stack& stack) { - auto a = pop(stack).toTensor(); - auto b = pop(stack).toTensor(); + at::Tensor a, b; + pop(stack, a, b); if (!a.defined()) stack.push_back(b); else if (!b.defined()) @@ -181,5 +198,151 @@ RegisterOperators reg({ return 0; }; }), + Operator( + prim::ListConstruct, + [](Node* node) -> Operation { + size_t num_inputs = node->inputs().size(); + ListType* lt = node->output()->type()->expect(); + if(IntType::get() == lt->getElementType()) { + return [=](Stack& stack) { + auto inputs = peekSlice(stack, 0, num_inputs, num_inputs); + std::vector vals = fmap(inputs, [](const IValue& v) { + return v.toInt(); + }); + drop(stack, num_inputs); + push(stack, std::move(vals)); + return 0; + }; + } else if(FloatType::get() == lt->getElementType()) { + return [=](Stack& stack) { + auto inputs = peekSlice(stack, 0, num_inputs, num_inputs); + std::vector vals = fmap(inputs, [](const IValue& v) { + return v.toDouble(); + }); + drop(stack, num_inputs); + push(stack, std::move(vals)); + return 0; + }; + } else { + std::stringstream ss; + ss << "unsupported list type: " << *lt->getElementType(); + throw std::runtime_error(ss.str()); + } + }), +}); + +// define implementations for primitive number ops +#define DEFINE_GENERIC_OP(aten_op, op, float_result) \ + Operator( \ + #aten_op "(int a, int b) -> int", \ + [](Node* node) { \ + return [=](Stack& stack) { \ + int64_t a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + return 0; \ + }; \ + }), \ + Operator( \ + #aten_op "(float a, float b) -> " #float_result, [](Node* node) { \ + return [=](Stack& stack) { \ + double a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + return 0; \ + }; \ + }), + +#define DEFINE_INT_OP(aten_op, op) \ + Operator(#aten_op "(int a, int b) -> int", [](Node* node) { \ + return [=](Stack& stack) { \ + int64_t a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + return 0; \ + }; \ + }), + +#define DEFINE_BINARY_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, float) +#define DEFINE_COMPARISON_OP(aten_op, op) DEFINE_GENERIC_OP(aten_op, op, int) + +// define helpers for where aten is missing scalar overloads +// note: it would be better to define these in a standard library as +// script functions and have the compiler substitute them in +// however, we need to add type annotations to the parser in order for us +// to move them there. +// e.g. s + t ==> t + s +// e.g. 
s - d == -d + s + +#define DEFINE_ST_OP(aten_op, reverse_exp) \ + Operator("aten::" #aten_op "(Scalar a, Tensor b) -> Tensor", [](Node* node) { \ + return [=](Stack& stack) { \ + at::Scalar a; \ + at::Tensor b; \ + pop(stack, a, b); \ + at::DeviceGuard guard(b); \ + push(stack, reverse_exp); \ + return 0; \ + }; \ + }), + +RegisterOperators reg2({ + DEFINE_BINARY_OP(aten::add, a + b) + DEFINE_BINARY_OP(aten::sub, a - b) + DEFINE_BINARY_OP(aten::mul, a * b) + DEFINE_BINARY_OP(aten::div, a / b) + DEFINE_BINARY_OP(aten::pow, static_cast(pow(a, b))) + + DEFINE_COMPARISON_OP(aten::ne, a != b) + DEFINE_COMPARISON_OP(aten::eq, a == b) + DEFINE_COMPARISON_OP(aten::lt, a < b) + DEFINE_COMPARISON_OP(aten::gt, a > b) + DEFINE_COMPARISON_OP(aten::le, a <= b) + DEFINE_COMPARISON_OP(aten::ge, a >= b) + + DEFINE_INT_OP(aten::__and__, a&& b) + DEFINE_INT_OP(aten::__or__, a || b) + + Operator( + "aten::neg(int a) -> int", + [](Node* node) { + return [=](Stack& stack) { + push(stack, -pop(stack).toInt()); + return 0; + }; + }), + Operator( + "aten::neg(float a) -> float", + [](Node* node) { + return [=](Stack& stack) { + push(stack, -pop(stack).toDouble()); + return 0; + }; + }), + Operator( + "aten::__not__(int a) -> int", + [](Node* node) { + return [=](Stack& stack) { + push(stack, !pop(stack).toInt()); + return 0; + }; + }), + + // commutative + DEFINE_ST_OP(mul, at::mul(b, a)) + DEFINE_ST_OP(add, at::add(b, a)) + DEFINE_ST_OP(ne, at::ne(b, a)) + DEFINE_ST_OP(eq, at::eq(b, a)) + + // comparisons, reverse the condition + DEFINE_ST_OP(lt, b > a) + DEFINE_ST_OP(le, b >= a) + DEFINE_ST_OP(gt, b < a) + DEFINE_ST_OP(ge, b <= a) + + // rsub + DEFINE_ST_OP(sub, at::add(b.neg(), a)) + // rdiv + DEFINE_ST_OP(div, at::mul(at::reciprocal(b), a)) }); }}} // torch::jit::anon diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index df3ff8151c6c1..b12bb7b58d016 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -6,7 +6,8 @@ #include "torch/csrc/jit/script/parser.h" #include "torch/csrc/utils/object_ptr.h" #include "torch/csrc/jit/operator.h" -#include "torch/csrc/jit/tensor_conversions.h" + +#include "torch/csrc/jit/constants.h" #include "ATen/optional.h" @@ -24,18 +25,90 @@ using ValueTable = std::unordered_map; using AttributeMap = std::unordered_map; using ListAttributeMap = std::unordered_map>; -// what type will this have in the interpreter, ignoring extra static information -// in particular Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) 
-static TypePtr interpreterType(const TypePtr& type) { - if(TupleType* t = type->cast()) { - return std::make_shared(fmap(t->elements(), interpreterType)); - } else if(type->kind() == TypeKind::TensorType) { - return DynamicType::get(); - } else { - return type; +struct NoneValue : SugaredValue { + NoneValue() {} + virtual std::string kind() const override { + return "None"; + } +}; + +struct PrintValue : public SugaredValue { + std::string kind() const override { + return "print"; } + std::shared_ptr call( + SourceRange loc, + Method & m, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) override { + auto& g = *m.graph(); + if (!attributes.empty()) + throw ErrorReport(loc) << "print doesn't accept any keyword arguments"; + auto values = toValues(inputs); + ensureTensors(loc, values); + g.insertNode(g.create(prim::Print, values, 0) + ->setSourceLocation(std::make_shared(loc))); + return std::make_shared(); + } +}; + +static Value* numToTensor(const SourceRange& loc, Value* value) { + auto& graph = *value->owningGraph(); + auto n = graph.insertNode(graph.createNumToTensor(value)) + ->setSourceLocation(std::make_shared(loc)); + return n->output(); +} + +static Value* tensorToNum( + const SourceRange& loc, + Value* value, + const TypePtr type) { + auto& graph = *value->owningGraph(); + auto* result = graph.insertNode(graph.createTensorToNum(type, value)) + ->setSourceLocation(std::make_shared(loc)) + ->output(); + return result; } +// expressions like int(x) +struct CastValue : public SugaredValue { + CastValue(TypePtr type) + : type(type) {} + std::string kind() const override { + std::stringstream ss; + ss << "<" << type->str() << " cast primitive>"; + return ss.str(); + } + std::shared_ptr call( + SourceRange loc, + Method & m, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) override { + if (!attributes.empty()) + throw ErrorReport(loc) << "casts do not accept any keyword arguments"; + if (inputs.size() != 1) + throw ErrorReport(loc) << "expected a single argument for cast"; + auto values = toValues(inputs); + Value* input = values.at(0); + if(!input->type()->isSubtypeOf(*type)) { + if(*type == *DynamicType::get()) { + if(!input->type()->isSubtypeOf(*NumberType::get())) { + throw ErrorReport(loc) << "expected a number"; + } + input = numToTensor(loc, input); + } else { + ensureTensors(loc, values); + input = tensorToNum(loc, values.at(0), type); + } + } + return std::make_shared(input); + } +private: + TypePtr type; +}; + // Auxiliary data structure for desugaring variable binding into our always // explicitly scoped language as we descend down // nested control structures in the frontend (which themselves don't introduce @@ -170,7 +243,7 @@ struct Environment { throw ErrorReport(loc) << "Cannot re-assign '" << name << "' because it has type " << value->kind() << " and " << name << " is not a first-class value. 
Only reassignments to first-class values are allowed"; } - if(!as_simple_value->type()->isSubtypeOf(*interpreterType(simple_parent->type()))) { + if(!as_simple_value->type()->isSubtypeOf(*unshapedType(simple_parent->type()))) { throw ErrorReport(loc) << "variable '" << name << "' previously has type " << simple_parent->type()->str() << " but is now being assigned to a value of type " << as_simple_value->type()->str(); } @@ -194,6 +267,21 @@ struct Environment { retval = resolver(ident); } + if(!retval) { + static std::unordered_map globals = { + {"print", std::make_shared()}, + {"float", std::make_shared(FloatType::get())}, + {"int", std::make_shared(IntType::get())}, + {"bool", std::make_shared(IntType::get())}, + // todo(zach): remove when we can correctly export torch.full via ONNX + // or we have implicit conversion that can convert numbers to tensors + {"_to_tensor", std::make_shared(DynamicType::get()) }, + }; + auto it = globals.find(ident); + if(it != globals.end()) + retval = it->second; + } + if (!retval && required) { throw ErrorReport(range) << "undefined value " << ident; } @@ -247,15 +335,9 @@ std::shared_ptr packOutputs(Graph& g, at::ArrayRef values) return std::make_shared(g.insertNode(g.createTuple(values))->output()); } -Value* createConstant(Graph& g, const SourceRange& loc, const at::Tensor& val) { - auto n = g.createConstant(val); - n->setSourceLocation(std::make_shared(loc)); - return g.insertNode(n)->output(); -} - Value* createNumber(Graph& g, const SourceRange& loc, const at::Tensor& val) { JIT_ASSERT(val.numel() == 1); - auto* output = createConstant(g, loc, val); + auto* output = insertConstant(g, val, loc); if (val.type().scalarType() == at::kLong) { output->setType(IntType::get()); } else if (val.type().scalarType() == at::kFloat) { @@ -277,7 +359,7 @@ Value* createStack(Graph& g, const SourceRange& loc, at::ArrayRef inputs auto values = fmap(inputs, [&](Value* v) { return v->node()->t(attr::value); }); - return createConstant(g, loc, at::stack(values)); + return insertConstant(g, at::stack(values), loc); } return g.insertNode(g.create(aten::stack, inputs) ->i_(attr::dim, 0) @@ -288,18 +370,10 @@ static bool isTensorSubtype(Value* v) { return v->type()->isSubtypeOf(*DynamicType::get()); } -static bool isNumberSubtype(const Value* v) { - return v->type()->isSubtypeOf(*NumberType::get()); -} - -static bool isNumberSubtype(const TypePtr& type) { - return type->isSubtypeOf(*NumberType::get()); -} - at::optional> getIntListAttribute(at::optional N, Value* input) { - auto list = constant_as>(input); + auto list = constant_as>(input); if(list) - return list; + return std::vector(*list); // broadcast IntList[3] with value 4 -> {4, 4, 4} if(!N) return at::nullopt; @@ -343,10 +417,10 @@ void liftConstantAttributes(const FunctionSchema& schema, Node* node) { attributes.f_(Symbol::attr(arg.name), *r); } break; case TypeKind::NumberType: { - auto r = constant_as(input); + auto r = constant_as(input); if(!r) return; - attributes.t_(Symbol::attr(arg.name), *r); + attributes.t_(Symbol::attr(arg.name), r->toTensor()); } break; case TypeKind::ListType: { auto elem = arg.type->expect()->getElementType(); @@ -385,33 +459,6 @@ at::ArrayRef createTupleUnpack(Value* v) { return g.insertNode(g.createTupleUnpack(v))->outputs(); } - -static Value* numToTensor( - const SourceRange& loc, - Graph& graph, - Value* value) { - JIT_ASSERT(isNumberSubtype(value)); - auto* result = graph.insertNode(graph.create(prim::NumToTensor, {value}) - ->setSourceLocation(std::make_shared(loc))) - 
->output(); - result->setType(DynamicType::get()); - return result; -} - -static Value* tensorToNum( - const SourceRange& loc, - Graph& graph, - Value* value, - const TypePtr type) { - JIT_ASSERT(isTensorSubtype(value)); - JIT_ASSERT(isNumberSubtype(type)); - auto* result = graph.insertNode(graph.create(prim::TensorToNum, {value}) - ->setSourceLocation(std::make_shared(loc))) - ->output(); - result->setType(type); - return result; -} - static inline bool isIntUsedAsIntList( const Value* value, const Argument& arg) { @@ -469,7 +516,7 @@ at::optional> tryMatchSchema( positional_inputs[i] = NamedValue( loc, i, - createConstant(graph, loc, *default_value) + insertConstant(graph, *default_value, loc) ->setType(schema.arguments[i].type)); } @@ -479,34 +526,20 @@ at::optional> tryMatchSchema( NamedValue v = *positional_inputs[i]; const auto& arg = schema.arguments[i]; - // some functions that take lists of integers for fixed size arrays // also allow single ints to be passed in their place. // the single int is then repeated to the length of the list if (isIntUsedAsIntList(v.value, arg)) { std::vector repeated(*arg.N, v.value); - v.value = graph.insertNode(graph.createTuple(repeated))->output(); + v.value = graph.insertNode(graph.createList(IntType::get(), repeated))->output(); } - // Tuples of integers are created using TuplePack which we do not actually - // support in the interpreter, so we have to replace it with a - // stack call, which creates a Tensor to represent the list. + // Allow tuples that only contain integers to turn into lists of integers if(*ListType::ofInts() == *arg.type && v.value->type()->kind() == TypeKind::TupleType && v.value->type()->isSubtypeOf(*ListType::ofInts())) { auto unpacked = createTupleUnpack(v.value); - // elements are numbers so we have to convert to tensors before - // stack will be valid - auto unpacked_t = fmap(unpacked, [&](Value* e) { - return numToTensor(v.loc, graph, e); - }); - v.value = createStack(graph, loc, unpacked_t)->setType(ListType::ofInts()); - } - - // implicit conversion from Tensor to Python Number - // FIXME: remove this when we support passing numbers into script fns - if (isTensorSubtype(v.value) && isNumberSubtype(arg.type)) { - v.value = tensorToNum(loc, graph, v.value, arg.type); + v.value = graph.insertNode(graph.createList(IntType::get(), unpacked))->output(); } if(!v.value->type()->isSubtypeOf(*arg.type)) { @@ -530,7 +563,7 @@ at::optional> tryMatchSchema( static std::shared_ptr tryEmitBuiltin( - const FunctionSchema& schema, + const std::shared_ptr& op, std::stringstream& failure_messages, const SourceRange& loc, Method& method, @@ -539,7 +572,7 @@ static std::shared_ptr tryEmitBuiltin( at::ArrayRef attributes) { auto graph = method.graph(); - auto flat_inputs = tryMatchSchema(schema, loc, *graph, inputs, attributes, failure_messages); + auto flat_inputs = tryMatchSchema(op->schema, loc, *graph, inputs, attributes, failure_messages); if(!flat_inputs) return nullptr; // we successfully matched this schema, construct the node @@ -552,8 +585,6 @@ static std::shared_ptr tryEmitBuiltin( auto n = graph->insertNode(graph->create(kind, *flat_inputs, 0)) ->setSourceLocation(std::make_shared(loc)); - size_t num_outputs = schema.returns.size(); - // special case for chunk when the chunks= is known // DO NOT ADD MORE SPECIAL CASES HERE, REFACTOR INTO A FUNCTION IF // NEEDED @@ -562,13 +593,16 @@ static std::shared_ptr tryEmitBuiltin( if(!value) { throw ErrorReport(loc) << "argument 'chunks' must be a constant"; } - num_outputs = *value; + 
for(int64_t i = 0; i < *value; ++i) + n->addOutput(); + } else { + for(auto & ret : op->schema.returns) { + n->addOutput()->setType(ret.type); + } } - for(size_t i = 0; i < num_outputs; ++i) - n->addOutput(); - - liftConstantAttributes(schema, n); + if(op->hasAttributedVersion()) + liftConstantAttributes(op->schema, n); // assert that we did indeed create an op that has implementation // otherwise schema and dispatch are not in sync @@ -603,7 +637,7 @@ std::shared_ptr emitBuiltinCall( std::stringstream failure_messages; for (const std::shared_ptr& op : variants) { if (auto result = tryEmitBuiltin( - op->schema, failure_messages, loc, method, name, inputs, attributes)) { + op, failure_messages, loc, method, name, inputs, attributes)) { return result; } } @@ -619,25 +653,18 @@ std::shared_ptr emitBuiltinCall( << "for call at"; } -struct NoneValue : SugaredValue { - NoneValue() {} - virtual std::string kind() const override { - return "None"; - } -}; - static Value* ensureTensor(const SourceRange& range, Value* v) { if(!isTensorSubtype(v)) { throw ErrorReport(range) << "expected a tensor value but found a " - << *v->type(); + << v->type()->str(); } return v; } -static Value* ensureTensorOrNumber(const SourceRange& range, Value* v) { - if(!isNumberSubtype(v) && !isTensorSubtype(v)) { - throw ErrorReport(range) << "expected a Number or Tensor value but found a " - << *v->type(); +static Value* ensureInt(const SourceRange& range, Value* v) { + if(!v->type()->isSubtypeOf(*IntType::get())) { + throw ErrorReport(range) << "expected a int but found a " + << v->type()->str(); } return v; } @@ -721,9 +748,14 @@ struct to_ir { results = createTupleUnpack(result); } } - ensureTensors(return_stmt.range(), results); - for(auto r : results) { - graph->registerOutput(r); + auto range = return_stmt.range(); + for (auto& r : results) { + if(r->type()->isSubtypeOf(*NumberType::get())) { + graph->registerOutput(numToTensor(range, r)); + } else { + ensureTensor(range, r); + graph->registerOutput(r); + } returns.push_back({"", DynamicType::get()}); } } @@ -808,13 +840,15 @@ struct to_ir { } Value* emitTernaryIf(const TernaryIf& expr) { - Value* cond_value = emitExpr(expr.cond()); + Value* cond_value = emitCond(expr.cond()); Node* n = graph->insertNode(create(prim::If, expr.range(), 0)); + n->addInput(cond_value); auto* true_block = n->addBlock(); auto* false_block = n->addBlock(); + auto emit_if_expr = [this](Block* b, const Expr& expr) { pushFrame(b); WithInsertPoint guard(b); @@ -826,14 +860,33 @@ struct to_ir { emit_if_expr(true_block, expr.true_expr()); emit_if_expr(false_block, expr.false_expr()); + auto true_type = unshapedType(true_block->outputs().at(0)->type()); + auto false_type = unshapedType(false_block->outputs().at(0)->type()); + if (*true_type != *false_type) { + throw ErrorReport(expr) + << "if-expression's true branch has type " << true_type->str() + << " but false branch has type " << false_type->str(); + } + // Add op outputs - auto expr_value = n->addOutput(); // Resulting value + auto expr_value = n->addOutput()->setType(true_type); // Resulting value return expr_value; } + Value* emitCond(Expr cond) { + Value* v = emitExpr(cond, identity); + if(v->type()->isSubtypeOf(*DynamicType::get())) { + v = tensorToNum(cond.range(), v, IntType::get()); + } + if(!v->type()->isSubtypeOf(*IntType::get())) { + throw ErrorReport(cond) << "expected a tensor or integer expression for condition but found " << v->type()->str(); + } + return v; + } + void emitIf(const If& stmt) { - Value* cond_value = 
emitExpr(stmt.cond()); + Value* cond_value = emitCond(stmt.cond()); Node* n = graph->insertNode(create(prim::If, stmt.range(), 0)); n->addInput(cond_value); @@ -922,21 +975,21 @@ struct to_ir { { WithInsertPoint guard(n); if (max_trip_count) { - max_trip_count_val = emitExpr(max_trip_count.value(), ensureTensorOrNumber); + max_trip_count_val = emitExpr(max_trip_count.value(), ensureInt); } else { max_trip_count_val = - emitConst(Const::create(range, std::to_string(INT_MAX))); + insertConstant(*graph, INT_MAX, range); } if (cond) { - cond_val = emitExpr(cond.value(), ensureTensorOrNumber); + cond_val = emitCond(cond.value()); } else { - cond_val = emitBooleanConst(range, true); + cond_val = insertConstant(*graph, true, range); } } n->addInput(max_trip_count_val); n->addInput(cond_val); auto* body_block = n->addBlock(); - Value* trip_count = body_block->addInput(); // Iteration num + Value* trip_count = body_block->addInput()->setType(IntType::get()); // Iteration num { pushFrame(body_block); @@ -948,10 +1001,10 @@ struct to_ir { // Also emit the conditional if (cond) { - Value* body_cond_value = emitExpr(cond.value(), ensureTensorOrNumber); + Value* body_cond_value = emitCond(cond.value()); body_block->registerOutput(body_cond_value); } else { - Value* cond_value_dummy = emitBooleanConst(range, true); + Value* cond_value_dummy = insertConstant(*graph, true, range); body_block->registerOutput(cond_value_dummy); } @@ -1239,12 +1292,6 @@ struct to_ir { auto it = function_table.find(ident.name()); if (it != function_table.end()) { return packOutputs(*graph, method.emit_call_to(ident.range(), it->second, inputs, attributes)); - } else if (ident.name() == "print") { - if (!attributes.empty()) - throw ErrorReport(ident) << "print doesn't accept any keyword arguments"; - ensureTensors(ident.range(), toValues(inputs)); - emitNode(prim::Print, ident.range(), toValues(inputs), 0); - return std::make_shared(); } if(auto result = emitBuiltinCall(ident.range(), method, ident.name(), inputs, attributes, false)) { return result; @@ -1276,173 +1323,6 @@ struct to_ir { throw std::runtime_error("reverseComparision: unsupported NodeKind. File a bug"); } - std::vector toNamedValues( - const SourceRange& loc, - ArrayRef inputs) { - return fmap(inputs, [&](Value* v) { - return NamedValue(loc, "", v); - }); - } - - Value* emitBasicMath( - const SourceRange& loc, - Method& method, - NodeKind kind, - at::ArrayRef inputs) { - auto sugared_ptr = emitBuiltinCall( - loc, - method, - kind.toUnqualString(), - toNamedValues(loc, inputs), - /*attributes=*/{}, - /*required=*/true); - auto simple_ptr = std::dynamic_pointer_cast(sugared_ptr); - JIT_ASSERT(simple_ptr); - return simple_ptr->getValue(); - } - - // Handles binary python math ops. - Value* emitPythonMath( - const SourceRange& loc, - Method& method, - NodeKind kind, - Value* lhs, - Value* rhs) { - // Assume lhs, rhs are either IntType or FloatType. - bool lhs_is_float = lhs->type()->kind() == TypeKind::FloatType; - bool rhs_is_float = rhs->type()->kind() == TypeKind::FloatType; - JIT_ASSERT(lhs_is_float || lhs->type()->kind() == TypeKind::IntType); - JIT_ASSERT(rhs_is_float || rhs->type()->kind() == TypeKind::IntType); - - auto out_type = lhs->type(); - if (kind == aten::ge || kind == aten::le || kind == aten::eq || - kind == aten::gt || kind == aten::lt || kind == aten::ne) { - // Stand-in for bool type. - out_type = NumberType::get(); - } else { - // If the types are different, one must be FloatType. - // We should promote the other value to FloatType. 
- if (lhs_is_float != rhs_is_float) { - out_type = FloatType::get(); - } - } - - // Strategy: cast inputs to tensor, perform op, recast to number - lhs = numToTensor(loc, *graph, lhs); - rhs = numToTensor(loc, *graph, rhs); - - // FIXME: support (python) math between IntType and FloatType. - // Here, without loss of generality, let's say lhs is a float and rhs is an - // int. We should insert an aten::type_as(lhs, rhs) node into the graph. - // However, the graph fuser generally has problems working with scalar tensors - // (#8560), so we don't support this right now. - if (lhs_is_float != rhs_is_float) { - throw std::runtime_error("NYI: math between float and int. See #8560."); - } - - auto* out = emitBasicMath(loc, method, kind, { lhs, rhs }); - return tensorToNum(loc, *graph, out, out_type); - } - - // math ops between a tensor and a number require that the number be the - // the same type (ScalarType and Backend) as the tensor, because numbers - // in the JIT are represented as scalar tensors. - // This function casts the number to the same type as the tensor. - Value* emitTensorNumberMath( - const SourceRange& loc, - Method& method, - NodeKind kind, - Value* lhs, - Value* rhs) { - auto rhs_kind = rhs->type()->kind(); - JIT_ASSERT(rhs_kind == TypeKind::FloatType || rhs_kind == TypeKind::IntType); - JIT_ASSERT(isTensorSubtype(lhs)); - - rhs = numToTensor(loc, *graph, rhs); - auto args = { rhs, lhs }; - rhs = graph->insertNode(graph->create(aten::type_as, args)) - ->output(); - return emitBasicMath(loc, method, kind, { lhs, rhs }); - } - - // Handles binary math ops. - Value* emitMath( - const SourceRange& loc, - Method& method, - NodeKind kind, - ArrayRef inputs) { - JIT_ASSERT(inputs.size() == 2); - auto& lhs = inputs[0]; - auto& rhs = inputs[1]; - bool lhs_is_number = isNumberSubtype(lhs); - bool lhs_is_tensor = isTensorSubtype(lhs); - bool rhs_is_number = isNumberSubtype(rhs); - bool rhs_is_tensor = isTensorSubtype(rhs); - JIT_ASSERT(lhs_is_tensor || lhs_is_number); - JIT_ASSERT(rhs_is_tensor || rhs_is_number); - - if (lhs_is_number && rhs_is_number) { - return emitPythonMath(loc, method, kind, lhs, rhs); - } - - if (lhs_is_number && rhs_is_tensor) { - - // commutative operations: just swap the args - if (kind == aten::mul || kind == aten::add || - kind == aten::ne || kind == aten::eq) { - return emitTensorNumberMath(loc, method, kind, rhs, lhs); - - // rsub - } else if (kind == aten::sub) { - auto* node = emitNode(aten::neg, loc, { rhs }, 1); - return emitTensorNumberMath(loc, method, aten::add, node->output(), lhs); - - // rdiv - } else if (kind == aten::div) { - auto* node = emitNode(aten::reciprocal, loc, { rhs }, 1); - return emitTensorNumberMath(loc, method, aten::mul, node->output(), lhs); - - // Comparision ops: swap args and use reverse comparison - } else if (kind == aten::lt || kind == aten::le || - kind == aten::gt || kind == aten::ge) { - return emitTensorNumberMath(loc, method, - reverseComparision(kind), - rhs, lhs); - } else { - throw std::runtime_error("Unknown node kind, please file a bug report"); - } - } - - if (lhs_is_tensor && rhs_is_number) { - return emitTensorNumberMath(loc, method, kind, lhs, rhs); - } - - return emitBasicMath(loc, method, kind, inputs); - } - - // Handles unary math ops. 
- Value* emitUnaryMath( - const SourceRange& loc, - Method& method, - NodeKind kind, - ArrayRef inputs) { - JIT_ASSERT(inputs.size() == 1); - auto* in = inputs[0]; - bool in_is_number = isNumberSubtype(in); - bool in_is_tensor = isTensorSubtype(in); - JIT_ASSERT(in_is_number || in_is_tensor); - - if (in_is_tensor) { - return emitBasicMath(loc, method, kind, inputs); - } - - // Cast to tensor, perform op, recast to number - auto out_type = in->type(); - in = numToTensor(loc, *graph, in); - auto* out = emitBasicMath(loc, method, kind, { in }); - return tensorToNum(loc, *graph, out, out_type); - } - // any expression that can produce a SugaredValue is handled here // expressions that only return a single Value* are handled in emitSimpleExpr std::shared_ptr emitSugaredExpr(Expr tree, size_t n_binders) { @@ -1478,11 +1358,7 @@ struct to_ir { case TK_POW: case TK_AND: case TK_OR: - case TK_NOT: { - const auto& inputs = tree->trees(); - auto kind = getNodeKind(tree->kind(), inputs.size()); - return emitNode(kind, tree->range(), getValues(inputs), 1)->output(); - } break; + case TK_NOT: case TK_NE: case TK_EQ: case '<': @@ -1492,33 +1368,31 @@ struct to_ir { case '*': case '/': case '+': - case '-': { - const auto& inputs = tree->trees(); - auto kind = getNodeKind(tree->kind(), inputs.size()); - auto input_vals = getValues(inputs, /*maybe_unpack*/false, ensureTensorOrNumber); - return emitMath(tree->range(), method, kind, input_vals); - } + case '-': case TK_UNARY_MINUS: { const auto& inputs = tree->trees(); auto kind = getNodeKind(tree->kind(), inputs.size()); - auto input_vals = getValues(inputs, /*maybe_unpack*/false, ensureTensorOrNumber); - return emitUnaryMath(tree->range(), method, kind, input_vals); + auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false, identity); + return emitBuiltinCall( + tree->range(), + method, + kind.toUnqualString(), + named_values, + {}, + /*required=*/true) + ->asValue(tree->range(), method); } case TK_STARRED: { throw ErrorReport(tree) << "Unexpected starred expansion. 
File a bug report."; } - case TK_CAST: { - const auto cast = Cast(tree); - return emitCast(cast.input(), cast.type()); - } break; case TK_CONST: { return emitConst(Const(tree)); } break; case TK_TRUE: { - return emitBooleanConst(tree->range(), true); + return insertConstant(*graph, true, tree->range()); } break; case TK_FALSE: { - return emitBooleanConst(tree->range(), false); + return insertConstant(*graph, false, tree->range()); } break; case TK_SLICE: { const auto slice = Slice(tree); @@ -1545,60 +1419,11 @@ struct to_ir { } } - Value* emitCast(Expr input, const ScalarType& type) { - at::ScalarType t; - switch (type.kind()) { - case TK_INT: - t = at::kInt; - break; - case TK_FLOAT: - t = at::kFloat; - break; - case TK_LONG: - t = at::kLong; - break; - case TK_BOOL: - t = at::kByte; - break; - default: - throw ErrorReport(input) << "Unrecognized type: " << type; - } - return emitNode( - Symbol::aten("type_as"), - input.range(), - {emitExpr(input), createConstant(*graph, input.range(), at::ones({1}, t))}, - 1) - ->output(); - } - - Value* emitBooleanConst(SourceRange range, bool val) { - return createConstant(*graph, range, at::CPU(at::kByte).scalarTensor(val)); - } - Value* emitConst(const Const& c) { - if (c.isFloatingPoint()) { - return createNumber( - *graph, - c.range(), - at::CPU(at::kFloat).scalarTensor(c.asFloatingPoint())); - } else { - return createNumber( - *graph, - c.range(), - at::CPU(at::kLong).scalarTensor(c.asIntegral())); - } - } - - Node* emitNode( - NodeKind kind, - const SourceRange& loc, - const std::vector inputs, - size_t n_outputs) { - Node* n = graph->insertNode(create(kind, loc, n_outputs)); - for (auto* input_value : inputs) { - n->addInput(input_value); - } - return n; + if (c.isFloatingPoint()) + return insertConstant(*graph, c.asFloatingPoint(), c.range()); + else + return insertConstant(*graph, c.asIntegral(), c.range()); } // Desugars slice syntactic sugar tensor[begin:end] -> tensor.slice(begin, @@ -1610,14 +1435,14 @@ struct to_ir { Compound::create(TK_LIST, loc, std::move(inputs)); const auto input_values = getNamedValues(applyInputs->trees(), /*maybe_unpack*/false, - ensureTensorOrNumber); + identity); NamedValue tensor = input_values[0]; NamedValue begin = input_values[1]; NamedValue end = input_values[2]; NamedValue dim = NamedValue(loc, "dim", - createConstant(*graph, loc, at::CPU(at::kLong).scalarTensor(0))); + insertConstant(*graph, 0, loc)); NamedValue step = NamedValue(loc, "step", - createConstant(*graph, loc, at::CPU(at::kLong).scalarTensor(1))); + insertConstant(*graph, 1, loc)); return emitBuiltinCall( loc, method, "slice", {tensor, dim, begin, end, step}, {}, true) @@ -1632,12 +1457,12 @@ struct to_ir { Compound::create(TK_LIST, loc, std::move(inputs)); auto input_values = getNamedValues(applyInputs->trees(), /*maybe_unpack*/false, - ensureTensorOrNumber); + identity); NamedValue tensor = input_values[0]; NamedValue dim = NamedValue( loc, "dim", - createConstant(*graph, loc, at::CPU(at::kLong).scalarTensor(0))); + insertConstant(*graph, 0, loc)); NamedValue idx = input_values[1]; return emitBuiltinCall(loc, method, "select", {tensor, dim, idx}, {}, true) diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 2bb72c4a18b7a..c028b595b4937 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -4,10 +4,10 @@ #include "torch/csrc/Dtype.h" #include "torch/csrc/Layout.h" #include "torch/csrc/jit/script/compiler.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include 
"torch/csrc/jit/python_tracer.h" #include "torch/csrc/jit/pybind_utils.h" -#include "torch/csrc/jit/passes/to_batch.h" +#include "torch/csrc/jit/constants.h" #include @@ -39,12 +39,8 @@ static std::string typeString(py::handle h) { return py::str(h.get_type().attr("__name__")); } -static std::shared_ptr createConstant(SourceRange loc, Method& m, const at::Tensor& val, TypePtr typ=nullptr) { - auto n = m.graph()->createConstant(val); - if(typ) - n->output()->setType(typ); - n->setSourceLocation(std::make_shared(loc)); - return std::make_shared(m.graph()->insertNode(n)->output()); +inline std::shared_ptr toSimple(Value* v) { + return std::make_shared(v); } struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { @@ -190,24 +186,25 @@ struct VISIBILITY_HIDDEN ConstantPythonValue : public PythonValue { // f = python_constant // while ... // f = f + 1 + auto& g = *m.graph(); if(py::isinstance(self)) { - return createConstant(loc, m, at::CPU(at::kLong).scalarTensor(py::cast(self))); + return toSimple(insertConstant(g, py::cast(self), loc)); } else if(py::isinstance(self)) { - return createConstant(loc, m, at::CPU(at::kFloat).scalarTensor(py::cast(self))); + return toSimple(insertConstant(g, py::cast(self), loc)); } else if(py::isinstance(self)) { - return createConstant(loc, m, at::CPU(at::kByte).scalarTensor(py::cast(self))); + return toSimple(insertConstant(g, py::cast(self), loc)); } else if(THPDevice_Check(self.ptr())) { auto device = (THPDevice*) self.ptr(); - auto t = as_tensor({static_cast(device->device.type()), device->device.index()}); - return createConstant(loc, m, t, ListType::ofInts()); + std::vector v = {static_cast(device->device.type()), device->device.index()}; + return toSimple(insertConstant(g, std::move(v))); } else if(THPLayout_Check(self.ptr())) { auto layout = (THPLayout*) self.ptr(); const auto v = static_cast(layout->layout); - return createConstant(loc, m, at::CPU(at::kLong).scalarTensor(v), IntType::get()); + return toSimple(insertConstant(g, v, loc)); } else if(THPDtype_Check(self.ptr())) { auto dtype = (THPDtype*)(self.ptr()); const auto v = static_cast(dtype->scalar_type); - return createConstant(loc, m, at::CPU(at::kLong).scalarTensor(v), IntType::get()); + return toSimple(insertConstant(g, v, loc)); } return std::make_shared(self); } @@ -373,22 +370,22 @@ void initJitScriptBindings(PyObject* module) { .def("_set_optimized", &Module::set_optimized) .def( "_define", - [](Module& m, + [](std::shared_ptr m, const std::string& script, ResolutionCallback rcb, bool has_self) { - auto self = has_self ? std::make_shared(m.shared_from_this()) : nullptr; - return defineMethodsInModule(m, script, pythonResolver(rcb), self); + auto self = has_self ? 
std::make_shared(m) : nullptr; + return defineMethodsInModule(*m, script, pythonResolver(rcb), self); }) - .def("_create_methods", [](Module& m, const std::vector& defs, const std::vector& rcbs) { + .def("_create_methods", [](std::shared_ptr m, const std::vector& defs, const std::vector& rcbs) { std::vector resolvers; for(auto & callback : rcbs) { resolvers.push_back(pythonResolver(callback)); } defineMethodsInModule( - m, + *m, defs, resolvers, - std::make_shared(m.shared_from_this())); + std::make_shared(m)); }) .def("_get_method", [](Module& self, const std::string& name) -> const Method& { diff --git a/torch/csrc/jit/script/lexer.h b/torch/csrc/jit/script/lexer.h index 7e2c81233ce76..982b2694c6edf 100644 --- a/torch/csrc/jit/script/lexer.h +++ b/torch/csrc/jit/script/lexer.h @@ -30,10 +30,6 @@ namespace script { _(TK_NEWLINE, "newline", "") \ _(TK_INDENT, "indent", "") \ _(TK_DEDENT, "dedent", "") \ - _(TK_FLOAT, "float", "float") \ - _(TK_DOUBLE, "double", "double") \ - _(TK_LONG, "long", "long") \ - _(TK_INT, "int", "int") \ _(TK_DEF, "def", "def") \ _(TK_EQUIVALENT, "equivalent", "<=>") \ _(TK_IDENT, "ident", "") \ @@ -47,7 +43,6 @@ namespace script { _(TK_RANGE_CONSTRAINT, "range_constraint", "") \ _(TK_PARAM, "param", "") \ _(TK_INFERRED, "inferred", "") \ - _(TK_BOOL, "bool", "bool") \ _(TK_ACCESS, "access", "") \ _(TK_ASSIGN, "assign", "") \ _(TK_ATTRIBUTE, "attribute", "") \ diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index fbf4575d01ad2..3494eb392bb2b 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -207,7 +207,7 @@ struct NamedParameter { std::unique_ptr parameter; }; -struct Module : public std::enable_shared_from_this { +struct Module { TH_DISALLOW_COPY_AND_ASSIGN(Module); Module() : modules("Module") diff --git a/torch/csrc/jit/script/parser.h b/torch/csrc/jit/script/parser.h index 4da07a6663c32..ed54c47b0f7f6 100644 --- a/torch/csrc/jit/script/parser.h +++ b/torch/csrc/jit/script/parser.h @@ -67,16 +67,6 @@ struct Parser { auto list = parseList('[', ',', ']', &Parser::parseExp); prefix = ListLiteral::create(list.range(), List(list)); } break; - case TK_FLOAT: - case TK_INT: - case TK_LONG: { - auto r = L.cur().range; - auto type = c(L.next().kind, r, {}); - L.expect('('); - auto exp = parseExp(); - L.expect(')'); - prefix = Cast::create(r, Type(type), Expr(exp)); - } break; default: { Ident name = parseIdent(); prefix = Var::create(name.range(), name); @@ -286,19 +276,6 @@ struct Parser { } } } - TreeRef parseScalarType() { - switch (L.cur().kind) { - case TK_INT: - case TK_FLOAT: - case TK_LONG: - case TK_DOUBLE: { - auto t = L.next(); - return c(t.kind, t.range, {}); - } - default: - return parseIdent(); - } - } TreeRef parseOptionalIdentList() { TreeRef list = nullptr; if (L.cur().kind == '(') { diff --git a/torch/csrc/jit/script/tree_views.h b/torch/csrc/jit/script/tree_views.h index 77325dfc4e587..9e5c78b01e830 100644 --- a/torch/csrc/jit/script/tree_views.h +++ b/torch/csrc/jit/script/tree_views.h @@ -51,7 +51,6 @@ namespace script { // | Not TK_NOT // | USub '-' // | Const(String value) TK_CONST -// | Cast(ScalarType type, Expr expr) TK_CAST // -- NB: x.name(y) is desugared into name(x, y) // | Apply(Ident name, List args, List kwargs) TK_APPLY // | Select(Expr base, Ident attr_name) '.' 
@@ -71,10 +70,6 @@ namespace script { // | Mul() TK_TIMES_EQ // | Div() TK_DIV_EQ // -// ScalarType = IntType() TK_INT -// | FloatType() TK_FLOAT -// | LongType() TK_LONG -// | DoubleType() TK_DOUBLE // Each subclass of TreeView should provide: // 1. Constructor that takes a TreeRef, and checks that it's of the right type. @@ -321,20 +316,6 @@ struct TensorType : public Type { } }; -struct ScalarType : public TreeView { - explicit ScalarType(const TreeRef& tree) : TreeView(tree) { - switch (tree->kind()) { - case TK_INT: - case TK_LONG: - case TK_FLOAT: - case TK_DOUBLE: - return; - default: - throw ErrorReport(tree) << kindToString(tree->kind()) << " is not a valid ScalarType"; - } - } -}; - //////////////////////////////////////////////////////////////////////////////// // Top level definitions //////////////////////////////////////////////////////////////////////////////// @@ -580,21 +561,6 @@ struct Const : public Expr { } }; -struct Cast : public Expr { - explicit Cast(const TreeRef& tree) : Expr(tree) { - tree_->match(TK_CAST); - } - ScalarType type() const { - return ScalarType(subtree(0)); - } - Expr input() const { - return Expr(subtree(1)); - } - static Cast create(const SourceRange& range, const Type& type, const Expr& input) { - return Cast(Compound::create(TK_CAST, range, {type, input})); - } -}; - struct Apply : public Expr { explicit Apply(const TreeRef& tree) : Expr(tree) { tree_->match(TK_APPLY); diff --git a/torch/csrc/jit/stack.h b/torch/csrc/jit/stack.h index e4d1d185db5be..654c87088e02a 100644 --- a/torch/csrc/jit/stack.h +++ b/torch/csrc/jit/stack.h @@ -1,6 +1,6 @@ #pragma once #include "ATen/ATen.h" -#include "torch/csrc/jit/tensor_conversions.h" + #include "torch/csrc/jit/ivalue.h" namespace torch { namespace jit { @@ -42,21 +42,37 @@ static inline IValue pop(Stack & stack) { return r; } +// variadic pop: +// int64_t a; at::Tensor b; +// pop(stack, a, b); +// equivalent to: +// b = pop(stack).toTensor(); +// a = pop(stack).toInt(); +template +static inline void pop(Stack& stack, Types&... args) { + size_t i = 0; + constexpr size_t N = sizeof...(args); + int result[N] = { + (args = std::move(peek(stack,i++, N)).template to(),0)... + }; + (void) result; + drop(stack, N); +} +template +static inline void push(Stack& stack, Types... args) { + constexpr size_t N = sizeof...(args); + int result[N] = { + (stack.push_back(std::forward(args)), 0)... + }; + (void) result; +} + // The packer here is carefully written not to make any unnecessary // copies. 
// pack takes the return values of aten functions pushes them onto the stack template inline void pack(Stack & stack, T&& v) { - stack.push_back(IValue(as_variable(std::move(v)))); -} -template<> -inline void pack(Stack & stack, at::Tensor&& v) { - stack.push_back(IValue(std::move(v))); -} - -template<> -inline void pack(Stack & stack, autograd::Variable&& v) { stack.push_back(IValue(std::move(v))); } diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h index 12417390478a4..ff9e5149068a8 100644 --- a/torch/csrc/jit/symbolic_variable.h +++ b/torch/csrc/jit/symbolic_variable.h @@ -52,11 +52,6 @@ struct SymbolicVariable { return (int64_t) i == s.toLong(); } } - // TODO (apaszke): Use this instead of attribute setters - template - SymbolicVariable insertConstant(T value) const { - return v->owningGraph()->insertConstant(std::move(value)); - } SymbolicVariable operator*(const SymbolicVariable rhs) const { return create(aten::mul, {*this, rhs})[0].typeLike(*this); } diff --git a/torch/csrc/jit/tensor_conversions.h b/torch/csrc/jit/tensor_conversions.h deleted file mode 100644 index 36bfda79056a4..0000000000000 --- a/torch/csrc/jit/tensor_conversions.h +++ /dev/null @@ -1,149 +0,0 @@ -#pragma once -#include "ATen/ATen.h" - -#include -#include -#include "torch/csrc/autograd/variable.h" - -namespace torch { namespace jit { - -////////////////////////////////////////////////////////////////////////////////// -// Tensor -> T conversion -////////////////////////////////////////////////////////////////////////////////// -struct tensor_conversion_error : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - -template -inline T tensor_as(at::Tensor t); - -namespace detail { - -template -struct tensor_as_impl {}; - -template -struct tensor_as_impl::value>::type> { - T operator()(at::Tensor&& t) { - // workaround for 1-dim 1-element pytorch tensors until zero-dim - // tensors are fully supported - if(t.ndimension() == 1 && t.size(0) == 1) { - t = t[0]; - } - return at::Scalar(t).to(); - } -}; - -template<> -struct tensor_as_impl { - bool operator()(at::Tensor&& t) { - return tensor_as(std::move(t)) != 0; - } -}; - -// this is an identity but is needed in constant_as in the compiler -template<> -struct tensor_as_impl { - at::Tensor operator()(at::Tensor&& t) { - return t; - } -}; - -template -struct tensor_as_impl> { - std::array operator()(at::Tensor&& t) { - throw tensor_conversion_error("tensor_as>: NYI"); - } -}; - -template<> -struct tensor_as_impl> { - std::vector operator()(at::Tensor&& t) { - if (t.type().scalarType() != at::ScalarType::Long) - throw tensor_conversion_error("Expected a LongTensor"); - if (t.dim() != 1) - throw tensor_conversion_error("Expected a 1D LongTensor"); - if (!t.is_contiguous()) - throw tensor_conversion_error("Expected a contiguous LongTensor"); - return std::vector(t.data(), t.data() + t.numel()); - } -}; - -template<> -struct tensor_as_impl { - at::Scalar operator()(at::Tensor&& t) { - return at::Scalar(t.view({})); - } -}; - -} - -template -inline T tensor_as(at::Tensor t) { - return detail::tensor_as_impl()(std::move(t)); -} - -////////////////////////////////////////////////////////////////////////////////// -// T -> Tensor conversion -////////////////////////////////////////////////////////////////////////////////// - -inline at::Tensor as_tensor(int64_t v) { - return at::Scalar(v).toTensor(); -} - -inline at::Tensor as_tensor(double v) { - return at::Scalar(v).toTensor(); -} - -inline at::Tensor as_tensor(bool 
v) { - return at::Scalar(v).toTensor(); -} - -inline at::Tensor as_tensor(at::IntList l) { - void* data = const_cast(reinterpret_cast(l.data())); - auto sizes = {static_cast(l.size())}; - return at::from_blob(data, sizes, at::kLong).clone(); -} - -inline at::Tensor as_tensor(const at::Scalar& s) { - return s.toTensor(); -} - -inline at::Tensor as_tensor(at::Tensor t) { - return t; -} - -template -inline at::Tensor as_tensor(std::array&& bools) { - auto r = at::empty({N}, at::kByte); - auto accessor = r.accessor(); - for(size_t i = 0; i < N; ++i) { - accessor[i] = bools[i]; - } - return r; -} - -template -inline at::Tensor as_variable(const T& t) { - return autograd::make_variable(as_tensor(t)); -} - -////////////////////////////////////////////////////////////////////////////////// -// Helper for retrieving constants -////////////////////////////////////////////////////////////////////////////////// - -// if a value is a constant then try to turn into type T using the -// same rules as the interpreter -template -at::optional constant_as(Value* v) { - if(v->node()->kind() != prim::Constant) - return at::nullopt; - auto tensor = v->node()->t(attr::value); - try { - return tensor_as(std::move(tensor)); - } catch (tensor_conversion_error& err) { - return at::nullopt; - } -} - -}} // namespace torch::jit diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index e93b3b1aeadec..4c976aeb64125 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -63,10 +63,12 @@ autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) { auto size_var = autograd::make_variable(at::Scalar(var.size(dim)).toTensor()); auto* value = getValueTrace(var); WithInsertPoint ipoint { graph->block() }; - auto* node = graph->insertNode(graph->create(aten::size, {value, graph->insertConstant(dim)})); - node->output()->inferTypeFrom(size_var); - setValueTrace(size_var, node->output()); + auto* node = graph->insertNode(graph->create(aten::size, {value, insertConstant(*graph, dim)})); + node->output()->setType(jit::IntType::get()); + auto ten = + graph->appendNode(graph->createNumToTensor(node->output()))->output(); + setValueTrace(size_var, ten); return size_var; } @@ -82,7 +84,13 @@ void ArgumentStash::stashIntListElem(const std::string& arg_name, size_t size, s JIT_ASSERT(size == list_trace.size()); JIT_ASSERT(idx < list_trace.size()); JIT_ASSERT(list_trace[idx] == nullptr); - list_trace[idx] = getValueTrace(var); + + Value* ten = getValueTrace(var); + auto& g = *ten->owningGraph(); + auto prim = g.createTensorToNum(jit::IntType::get(), ten) + ->insertAfter(ten->node()) + ->output(); + list_trace[idx] = prim; } //////////////////////////////////////////////////////////////////////////////// diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index 7198dd1aba150..2987502986929 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -1,6 +1,7 @@ #pragma once #include "torch/csrc/jit/ir.h" +#include "torch/csrc/jit/constants.h" #include "torch/csrc/assertions.h" #include "torch/csrc/WindowsTorchApiMacro.h" #include "torch/csrc/utils/functional.h" @@ -120,7 +121,7 @@ inline Value* getValueTrace(const Variable& var) { auto & value_map = getTracingState()->value_map; auto it = value_map.find(var); if (it == value_map.end()) { - Value *constant = state->graph->appendNode(state->graph->createConstant(var.data()))->output(); + Value *constant = insertConstant(*state->graph, var.data()); constant->inferTypeFrom(var.data()); it = value_map.emplace_hint(it, var, 
constant); } diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index fe06ec7833a11..403d42582f871 100644 --- a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -3,6 +3,7 @@ #include "torch/csrc/jit/interned_strings.h" #include "torch/csrc/assertions.h" #include "torch/csrc/WindowsTorchApiMacro.h" +#include "torch/csrc/utils/functional.h" #include @@ -30,8 +31,7 @@ struct Type; using TypePtr = std::shared_ptr; -struct TORCH_API Type : std::enable_shared_from_this { - +struct TORCH_API Type { private: TypeKind kind_; @@ -79,9 +79,6 @@ struct TORCH_API Type : std::enable_shared_from_this { JIT_ASSERT(T::Kind == kind()); return static_cast(this); } - std::shared_ptr asShared() { - return shared_from_this(); - } virtual ~Type() {} }; @@ -169,6 +166,15 @@ struct TORCH_API TensorType : public Type { // don't want to reveal underlying size information. return "Tensor"; } + bool numel() const { + size_t prod = 1; + for(auto s : sizes()) { + prod *= s; + } + return prod; + } + static TypePtr fromNumberType(TypePtr typ); + private: static std::vector contiguousStridesOf(at::IntList sizes) { std::vector strides(sizes.size()); @@ -321,5 +327,31 @@ struct TORCH_API IntType : public Type { TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t); +// what is the type, ignoring extra size/shape information? +// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) + +inline TypePtr unshapedType(const TypePtr& type) { + if(TupleType* t = type->cast()) { + return std::make_shared(fmap(t->elements(), unshapedType)); + } else if(ListType* t = type->cast()) { + return std::make_shared(unshapedType(t->getElementType())); + } else if(type->kind() == TypeKind::TensorType) { + return DynamicType::get(); + } else { + return type; + } +} + +inline TypePtr TensorType::fromNumberType(TypePtr typ) { + JIT_ASSERT(typ->isSubtypeOf(*NumberType::get())); + if(typ->isSubtypeOf(*IntType::get())) { + TensorType tt(at::kLong, -1, {}); + return std::make_shared(std::move(tt)); + } else if(typ->isSubtypeOf(*FloatType::get())) { + TensorType tt(at::kFloat, -1, {}); + return std::make_shared(std::move(tt)); + } + AT_ERROR("unknown number type", typ->str()); +} }} // namespace torch::jit diff --git a/torch/distributions/__init__.py b/torch/distributions/__init__.py index de8c9e0267f76..9979f96e27a50 100644 --- a/torch/distributions/__init__.py +++ b/torch/distributions/__init__.py @@ -93,6 +93,7 @@ from .laplace import Laplace from .log_normal import LogNormal from .logistic_normal import LogisticNormal +from .lowrank_multivariate_normal import LowRankMultivariateNormal from .multinomial import Multinomial from .multivariate_normal import MultivariateNormal from .normal import Normal diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index caedb3e93a133..13cda93abafc8 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -19,7 +19,10 @@ from .half_normal import HalfNormal from .laplace import Laplace from .logistic_normal import LogisticNormal -from .multivariate_normal import MultivariateNormal, _batch_mahalanobis, _batch_diag, _batch_inverse +from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet, + _batch_lowrank_mahalanobis, _batch_vector_diag) +from .multivariate_normal import (MultivariateNormal, _batch_diag, _batch_mahalanobis, + _batch_trtrs_lower) from .normal import Normal from .one_hot_categorical import OneHotCategorical from .pareto import Pareto @@ -128,9 +131,10 @@ def _batch_trace_XXT(bmat): """ 
Utility function for calculating the trace of XX^{T} with X having arbitrary trailing batch dimensions """ - mat_size = bmat.size(-1) - flat_trace = bmat.reshape(-1, mat_size * mat_size).pow(2).sum(-1) - return flat_trace.view(bmat.shape[:-2]) + n = bmat.size(-1) + m = bmat.size(-2) + flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1) + return flat_trace.reshape(bmat.shape[:-2]) def kl_divergence(p, q): @@ -290,6 +294,73 @@ def _kl_laplace_laplace(p, q): return t1 + t2 + t3 - 1 +@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal) +def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError("KL-divergence between two Low Rank Multivariate Normals with\ + different event shapes cannot be computed") + + term1 = (_batch_lowrank_logdet(q.cov_factor, q.cov_diag, q._capacitance_tril) - + _batch_lowrank_logdet(p.cov_factor, p.cov_diag, p._capacitance_tril)) + term3 = _batch_lowrank_mahalanobis(q.cov_factor, q.cov_diag, q.loc - p.loc, + q._capacitance_tril) + # Expands term2 according to + # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD) + # = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T) + qWt_qDinv = q.cov_factor.transpose(-1, -2) / q.cov_diag.unsqueeze(-2) + A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril) + term21 = (p.cov_diag / q.cov_diag).sum(-1) + term22 = _batch_trace_XXT(p.cov_factor * q.cov_diag.rsqrt().unsqueeze(-1)) + term23 = _batch_trace_XXT(A * p.cov_diag.sqrt().unsqueeze(-2)) + term24 = _batch_trace_XXT(A.matmul(p.cov_factor)) + term2 = term21 + term22 - term23 - term24 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(MultivariateNormal, LowRankMultivariateNormal) +def _kl_multivariatenormal_lowrankmultivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed") + + term1 = (_batch_lowrank_logdet(q.cov_factor, q.cov_diag, q._capacitance_tril) - + 2 * _batch_diag(p._unbroadcasted_scale_tril).log().sum(-1)) + term3 = _batch_lowrank_mahalanobis(q.cov_factor, q.cov_diag, q.loc - p.loc, + q._capacitance_tril) + # Expands term2 according to + # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T + # = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T + qWt_qDinv = q.cov_factor.transpose(-1, -2) / q.cov_diag.unsqueeze(-2) + A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril) + term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril * q.cov_diag.rsqrt().unsqueeze(-1)) + term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril)) + term2 = term21 - term22 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(LowRankMultivariateNormal, MultivariateNormal) +def _kl_lowrankmultivariatenormal_multivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed") + + term1 = (2 * _batch_diag(q._unbroadcasted_scale_tril).log().sum(-1) - + _batch_lowrank_logdet(p.cov_factor, p.cov_diag, p._capacitance_tril)) + term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) + # Expands term2 according to + # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD) + combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2], + p.cov_factor.shape[:-2]) + n = p.event_shape[0] + q_scale_tril = 
q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + p_cov_factor = p.cov_factor.expand(combined_batch_shape + (n, p.cov_factor.size(-1))) + p_cov_diag = _batch_vector_diag(p.cov_diag.sqrt()).expand(combined_batch_shape + (n, n)) + term21 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_factor, q_scale_tril)) + term22 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_diag, q_scale_tril)) + term2 = term21 + term22 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + @register_kl(MultivariateNormal, MultivariateNormal) def _kl_multivariatenormal_multivariatenormal(p, q): # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence @@ -297,10 +368,16 @@ def _kl_multivariatenormal_multivariatenormal(p, q): raise ValueError("KL-divergence between two Multivariate Normals with\ different event shapes cannot be computed") - term1 = _batch_diag(q.scale_tril).log().sum(-1) - _batch_diag(p.scale_tril).log().sum(-1) - term2 = _batch_trace_XXT(torch.matmul(_batch_inverse(q.scale_tril), p.scale_tril)) - term3 = _batch_mahalanobis(q.scale_tril, (q.loc - p.loc)) - return term1 + 0.5 * (term2 + term3 - p.event_shape[0]) + half_term1 = (_batch_diag(q._unbroadcasted_scale_tril).log().sum(-1) - + _batch_diag(p._unbroadcasted_scale_tril).log().sum(-1)) + combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2], + p._unbroadcasted_scale_tril.shape[:-2]) + n = p.event_shape[0] + q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + term2 = _batch_trace_XXT(_batch_trtrs_lower(p_scale_tril, q_scale_tril)) + term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) + return half_term1 + 0.5 * (term2 + term3 - n) @register_kl(Normal, Normal) diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py new file mode 100644 index 0000000000000..d5f3d9c98d058 --- /dev/null +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -0,0 +1,180 @@ +import math + +import torch +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.multivariate_normal import (_batch_diag, _batch_mahalanobis, _batch_mv, + _batch_potrf_lower, _batch_trtrs_lower, + _get_batch_shape) +from torch.distributions.utils import lazy_property + + +def _batch_vector_diag(bvec): + """ + Returns the diagonal matrices of a batch of vectors. + """ + n = bvec.size(-1) + bmat = bvec.new_zeros(bvec.shape + (n,)) + bmat.view(bvec.shape[:-1] + (-1,))[..., ::n + 1] = bvec + return bmat + + +def _batch_capacitance_tril(W, D): + r""" + Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W` + and a batch of vectors :math:`D`. + """ + m = W.size(-1) + Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2) + K = torch.matmul(Wt_Dinv, W).contiguous() + K.view(-1, m * m)[:, ::m + 1] += 1 # add identity matrix to K + return _batch_potrf_lower(K) + + +def _batch_lowrank_logdet(W, D, capacitance_tril): + r""" + Uses "matrix determinant lemma":: + log|W @ W.T + D| = log|C| + log|D|, + where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute + the log determinant. 
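For reference, the two standard identities that the low-rank helpers above rely on can be written out explicitly; here W is the n x k factor, D the diagonal part, and C = I + W^T D^{-1} W the capacitance matrix (this is background, not part of the patch itself):

```latex
\log\det\!\left(W W^{\top} + D\right) = \log\det(C) + \log\det(D)
\qquad \text{(matrix determinant lemma)}

\left(W W^{\top} + D\right)^{-1} = D^{-1} - D^{-1} W \, C^{-1} W^{\top} D^{-1}
\qquad \text{(Woodbury identity)}
```

Plugging the second identity into x^T (W W^T + D)^{-1} x gives x^T D^{-1} x - (W^T D^{-1} x)^T C^{-1} (W^T D^{-1} x), which is exactly the `mahalanobis_term1 - mahalanobis_term2` combination computed by `_batch_lowrank_mahalanobis`.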
+ """ + return 2 * _batch_diag(capacitance_tril).log().sum(-1) + D.log().sum(-1) + + +def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril): + r""" + Uses "Woodbury matrix identity":: + inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D), + where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared + Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`. + """ + Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2) + Wt_Dinv_x = _batch_mv(Wt_Dinv, x) + mahalanobis_term1 = (x.pow(2) / D).sum(-1) + mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x) + return mahalanobis_term1 - mahalanobis_term2 + + +class LowRankMultivariateNormal(Distribution): + r""" + Creates a multivariate normal distribution with covariance matrix having a low-rank form + parameterized by `cov_factor` and `cov_diag`:: + covariance_matrix = cov_factor @ cov_factor.T + cov_diag + + Example: + + >>> m = MultivariateNormal(torch.zeros(2), torch.tensor([1, 0]), torch.tensor([1, 1])) + >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]` + tensor([-0.2102, -0.5429]) + + Args: + loc (Tensor): mean of the distribution with shape `batch_shape + event_shape` + cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape + `batch_shape + event_shape + (rank,)` + cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape + `batch_shape + event_shape` + + Note: + The computation for determinant and inverse of covariance matrix is avoided when + `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity + `_ and + `matrix determinant lemma `_. + Thanks to these formulas, we just need to compute the determinant and inverse of + the small size "capacitance" matrix:: + capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor + """ + arg_constraints = {"loc": constraints.real, + "cov_factor": constraints.real, + "cov_diag": constraints.positive} + support = constraints.real + has_rsample = True + + def __init__(self, loc, cov_factor, cov_diag, validate_args=None): + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + event_shape = loc.shape[-1:] + if cov_factor.dim() < 2: + raise ValueError("cov_factor must be at least two-dimensional, " + "with optional leading batch dimensions") + if cov_factor.shape[-2:-1] != event_shape: + raise ValueError("cov_factor must be a batch of matrices with shape {} x m" + .format(event_shape[0])) + if cov_diag.shape[-1:] != event_shape: + raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape)) + + scale_batch_shape = _get_batch_shape(cov_factor, cov_diag) + try: + batch_shape = torch._C._infer_size(loc.shape[:-1], scale_batch_shape) + except RuntimeError: + raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}" + .format(loc.shape, cov_factor.shape, cov_diag.shape)) + + loc_shape = batch_shape + event_shape + self.loc = loc.expand(loc_shape) + self.cov_factor = cov_factor.expand(loc_shape + cov_factor.shape[-1:]) + self.cov_diag = cov_diag.expand(loc_shape) + self._capacitance_tril = _batch_capacitance_tril(self.cov_factor, self.cov_diag) + super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape, + validate_args=validate_args) + + @property + def mean(self): + return self.loc + + @property + def variance(self): + return self.cov_factor.pow(2).sum(-1) + self.cov_diag + + @lazy_property + def scale_tril(self): + # The following 
identity is used to increase the numerically computation stability + # for Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3): + # W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2 + # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1, + # hence it is well-conditioned and safe to take Cholesky decomposition. + n = self._event_shape[0] + cov_diag_sqrt_unsqueeze = self.cov_diag.sqrt().unsqueeze(-1) + Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze + K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous() + K.view(-1, n * n)[:, ::n + 1] += 1 # add identity matrix to K + return cov_diag_sqrt_unsqueeze * _batch_potrf_lower(K) + + @lazy_property + def covariance_matrix(self): + return (torch.matmul(self.cov_factor, self.cov_factor.transpose(-1, -2)) + + _batch_vector_diag(self.cov_diag)) + + @lazy_property + def precision_matrix(self): + # We use "Woodbury matrix identity" to take advantage of low rank form:: + # inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D) + # where :math:`C` is the capacitance matrix. + Wt_Dinv = self.cov_factor.transpose(-1, -2) / self.cov_diag.unsqueeze(-2) + A = _batch_trtrs_lower(Wt_Dinv, self._capacitance_tril) + return (_batch_vector_diag(self.cov_diag.reciprocal()) - + torch.matmul(A.transpose(-1, -2), A)) + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + eps_W = self.loc.new_empty(shape[:-1] + (self.cov_factor.size(-1),)).normal_() + eps_D = self.loc.new_empty(shape).normal_() + return self.loc + _batch_mv(self.cov_factor, eps_W) + self.cov_diag.sqrt() * eps_D + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + diff = value - self.loc + M = _batch_lowrank_mahalanobis(self.cov_factor, self.cov_diag, diff, + self._capacitance_tril) + log_det = _batch_lowrank_logdet(self.cov_factor, self.cov_diag, + self._capacitance_tril) + return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M) + + def entropy(self): + log_det = _batch_lowrank_logdet(self.cov_factor, self.cov_diag, + self._capacitance_tril) + H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det) + if len(self._batch_shape) == 0: + return H + else: + return H.expand(self._batch_shape) diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index d8bf4d0b4e5dd..72300ee4e0039 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -1,5 +1,4 @@ import math -from numbers import Number import torch from torch.distributions import constraints @@ -15,7 +14,7 @@ def _get_batch_shape(bmat, bvec): vec_shape = torch._C._infer_size(bvec.shape, bmat.shape[:-1]) except RuntimeError: raise ValueError("Incompatible batch shapes: vector {}, matrix {}".format(bvec.shape, bmat.shape)) - return torch.Size(vec_shape[:-1]) + return vec_shape[:-1] def _batch_mv(bmat, bvec): @@ -29,13 +28,7 @@ def _batch_mv(bmat, bvec): to a batch shape. They are not necessarily assumed to have the same batch shape, just ones which can be broadcasted. 
""" - n = bvec.size(-1) - batch_shape = _get_batch_shape(bmat, bvec) - - # to conform with `torch.bmm` interface, both bmat and bvec should have `.dim() == 3` - bmat = bmat.expand(batch_shape + (n, n)).reshape((-1, n, n)) - bvec = bvec.unsqueeze(-1).expand(batch_shape + (n, 1)).reshape((-1, n, 1)) - return torch.bmm(bmat, bvec).view(batch_shape + (n,)) + return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1) def _batch_potrf_lower(bmat): @@ -43,15 +36,15 @@ def _batch_potrf_lower(bmat): Applies a Cholesky decomposition to all matrices in a batch of arbitrary shape. """ n = bmat.size(-1) - cholesky = torch.stack([C.potrf(upper=False) for C in bmat.reshape((-1, n, n))]) - return cholesky.view(bmat.shape) + cholesky = torch.stack([m.potrf(upper=False) for m in bmat.reshape(-1, n, n)]) + return cholesky.reshape(bmat.shape) def _batch_diag(bmat): r""" Returns the diagonals of a batch of square matrices. """ - return bmat.reshape(bmat.shape[:-2] + (-1,))[..., ::bmat.size(-1) + 1] + return torch.diagonal(bmat, dim1=-2, dim2=-1) def _batch_inverse(bmat): @@ -59,22 +52,36 @@ def _batch_inverse(bmat): Returns the inverses of a batch of square matrices. """ n = bmat.size(-1) - flat_bmat = bmat.reshape(-1, n, n) - flat_inv_bmat = torch.stack([m.inverse() for m in flat_bmat], 0) - return flat_inv_bmat.view(bmat.shape) + flat_bmat_inv = torch.stack([m.inverse() for m in bmat.reshape(-1, n, n)]) + return flat_bmat_inv.reshape(bmat.shape) + + +def _batch_trtrs_lower(bb, bA): + """ + Applies `torch.trtrs` for batches of matrices. `bb` and `bA` should have + the same batch shape. + """ + flat_b = bb.reshape((-1,) + bb.shape[-2:]) + flat_A = bA.reshape((-1,) + bA.shape[-2:]) + flat_X = torch.stack([torch.trtrs(b, A, upper=False)[0] for b, A in zip(flat_b, flat_A)]) + return flat_X.reshape(bb.shape) -def _batch_mahalanobis(L, x): +def _batch_mahalanobis(bL, bx): r""" Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}` for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`. - Accepts batches for both L and x. + Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch + shape, but `bL` one should be able to broadcasted to `bx` one. """ - # TODO: use `torch.potrs` or similar once a backwards pass is implemented. 
- flat_L = L.unsqueeze(0).reshape((-1,) + L.shape[-2:]) - L_inv = torch.stack([torch.inverse(Li.t()) for Li in flat_L]).view(L.shape) - return (x.unsqueeze(-1) * L_inv).sum(-2).pow(2.0).sum(-1) + n = bx.size(-1) + bL = bL.expand(bx.shape[bx.dim() - bL.dim() + 1:] + (n,)) + flat_L = bL.reshape(-1, n, n) # shape = b x n x n + flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n + flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c + M_swap = _batch_trtrs_lower(flat_x_swap, flat_L).pow(2).sum(-2) # shape = b x c + return M_swap.t().reshape(bx.shape[:-1]) class MultivariateNormal(Distribution): @@ -120,45 +127,54 @@ class MultivariateNormal(Distribution): def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None): if loc.dim() < 1: - loc = loc.unsqueeze(0) - event_shape = torch.Size(loc.shape[-1:]) + raise ValueError("loc must be at least one-dimensional.") + event_shape = loc.shape[-1:] if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1: raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.") if scale_tril is not None: if scale_tril.dim() < 2: raise ValueError("scale_tril matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.scale_tril = scale_tril + self._unbroadcasted_scale_tril = scale_tril batch_shape = _get_batch_shape(scale_tril, loc) + self.scale_tril = scale_tril.expand(batch_shape + event_shape + event_shape) elif covariance_matrix is not None: if covariance_matrix.dim() < 2: raise ValueError("covariance_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.covariance_matrix = covariance_matrix + self._unbroadcasted_scale_tril = _batch_potrf_lower(covariance_matrix) batch_shape = _get_batch_shape(covariance_matrix, loc) + self.covariance_matrix = covariance_matrix.expand(batch_shape + event_shape + event_shape) else: if precision_matrix.dim() < 2: raise ValueError("precision_matrix must be at least two-dimensional, " "with optional leading batch dimensions") - self.precision_matrix = precision_matrix - self.covariance_matrix = _batch_inverse(precision_matrix) + covariance_matrix = _batch_inverse(precision_matrix) + self._unbroadcasted_scale_tril = _batch_potrf_lower(covariance_matrix) batch_shape = _get_batch_shape(precision_matrix, loc) - self.loc = loc + self.precision_matrix = precision_matrix.expand(batch_shape + event_shape + event_shape) + self.covariance_matrix = covariance_matrix.expand(batch_shape + event_shape + event_shape) + + self.loc = loc.expand(batch_shape + event_shape) super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args) @lazy_property def scale_tril(self): - return _batch_potrf_lower(self.covariance_matrix) + return self._unbroadcasted_scale_tril.expand( + self._batch_shape + self._event_shape + self._event_shape) @lazy_property def covariance_matrix(self): - return torch.matmul(self.scale_tril, self.scale_tril.transpose(-1, -2)) + return (torch.matmul(self._unbroadcasted_scale_tril, + self._unbroadcasted_scale_tril.transpose(-1, -2)) + .expand(self._batch_shape + self._event_shape + self._event_shape)) @lazy_property def precision_matrix(self): # TODO: use `torch.potri` on `scale_tril` once a backwards pass is implemented. 
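A minimal numerical sanity check for the broadcasting-aware `_batch_mahalanobis` above. It assumes the private helpers `_batch_potrf_lower` and `_batch_mahalanobis` are importable from `torch.distributions.multivariate_normal` as defined in this patch (internal names that may change), and the shapes are invented for illustration:

```python
import torch
from torch.distributions.multivariate_normal import (_batch_mahalanobis,
                                                      _batch_potrf_lower)

torch.manual_seed(0)
A = torch.randn(5, 3, 3)
M = A @ A.transpose(-1, -2) + 3 * torch.eye(3)  # batch of SPD matrices, shape (5, 3, 3)
L = _batch_potrf_lower(M)                       # lower Cholesky factor of each matrix
x = torch.randn(2, 5, 3)                        # extra leading dim exercises broadcasting

# Reference value: x^T M^{-1} x, computed one (vector, matrix) pair at a time.
direct = torch.stack([torch.stack([xi @ torch.inverse(Mi) @ xi
                                   for xi, Mi in zip(x_row, M)])
                      for x_row in x])

print(torch.allclose(direct, _batch_mahalanobis(L, x), atol=1e-4))  # expected: True
```

The check exploits that solving L z = x and summing z squared gives x^T (L L^T)^{-1} x, which is what the permute/reshape and batched triangular solve in the new implementation compute without materializing any inverse.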
- scale_tril_inv = _batch_inverse(self.scale_tril) - return torch.matmul(scale_tril_inv.transpose(-1, -2), scale_tril_inv) + scale_tril_inv = _batch_inverse(self._unbroadcasted_scale_tril) + return torch.matmul(scale_tril_inv.transpose(-1, -2), scale_tril_inv).expand( + self._batch_shape + self._event_shape + self._event_shape) @property def mean(self): @@ -166,26 +182,25 @@ def mean(self): @property def variance(self): - n = self.covariance_matrix.size(-1) - var = torch.stack([cov.diag() for cov in self.covariance_matrix.view(-1, n, n)]) - return var.view(self.covariance_matrix.size()[:-1]) + return self._unbroadcasted_scale_tril.pow(2).sum(-1).expand( + self._batch_shape + self._event_shape) def rsample(self, sample_shape=torch.Size()): shape = self._extended_shape(sample_shape) - eps = self.loc.new(*shape).normal_() - return self.loc + _batch_mv(self.scale_tril, eps) + eps = self.loc.new_empty(shape).normal_() + return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps) def log_prob(self, value): if self._validate_args: self._validate_sample(value) diff = value - self.loc - M = _batch_mahalanobis(self.scale_tril, diff) - log_det = _batch_diag(self.scale_tril).abs().log().sum(-1) - return -0.5 * (M + self.loc.size(-1) * math.log(2 * math.pi)) - log_det + M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff) + half_log_det = _batch_diag(self._unbroadcasted_scale_tril).log().sum(-1) + return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det def entropy(self): - log_det = _batch_diag(self.scale_tril).abs().log().sum(-1) - H = 0.5 * (1.0 + math.log(2 * math.pi)) * self._event_shape[0] + log_det + half_log_det = _batch_diag(self._unbroadcasted_scale_tril).log().sum(-1) + H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det if len(self._batch_shape) == 0: return H else: diff --git a/torch/jit/batchop.py b/torch/jit/batchop.py index cfad94a03e820..bda6a3adca3a8 100644 --- a/torch/jit/batchop.py +++ b/torch/jit/batchop.py @@ -81,7 +81,9 @@ def batch_matmul(data1, mask1, dims1, data2, mask2, dims2): @torch.jit.script -def batch_select(data, mask, dims, dim, index): +def batch_select(data, mask, dims, dim_, index_): + dim = int(dim_) + index = int(index_) # if dim == 0: # raise ValueError("Cannot select 0 dim in BatchTensor") data = data.select(dim, index) diff --git a/torch/lib/THD/master_worker/master/generic/THDTensor.cpp b/torch/lib/THD/master_worker/master/generic/THDTensor.cpp index 93dd5d4b7246a..5e579ebe6287a 100644 --- a/torch/lib/THD/master_worker/master/generic/THDTensor.cpp +++ b/torch/lib/THD/master_worker/master/generic/THDTensor.cpp @@ -1038,7 +1038,7 @@ void THDTensor_(addcdiv)(THDTensor *self, THDTensor *src1, real value, THDTensor void THDTensor_(addmv)(THDTensor *self, real beta, THDTensor *src, real alpha, THDTensor *mat, THDTensor *vec) { if ((mat->nDimension != 2) || (vec->nDimension != 1)) - THError("matrix and vector expected, got %dD, %dD", mat->nDimension, vec->nDimension); + THError("2D tensor and 1D tensor expected, got %dD, %dD tensors", mat->nDimension, vec->nDimension); if (mat->size[1] != vec->size[0]) { THDDescBuff bm = THDTensor_(sizeDesc)(mat); @@ -1047,7 +1047,7 @@ void THDTensor_(addmv)(THDTensor *self, real beta, THDTensor *src, real alpha, T } if (src->nDimension != 1) - THError("vector expected, got src: %dD", src->nDimension); + THError("1D tensor expected, got src: %dD tensor", src->nDimension); if (src->size[0] != mat->size[0]) { THDDescBuff bt = THDTensor_(sizeDesc)(src); @@ -1067,7 +1067,7 @@ 
void THDTensor_(addmv)(THDTensor *self, real beta, THDTensor *src, real alpha, T void THDTensor_(addmm)(THDTensor *self, real beta, THDTensor *src, real alpha, THDTensor *mat1, THDTensor *mat2) { if ((mat1->nDimension != 2) || (mat2->nDimension != 2)) - THError("matrices expected, got %dD, %dD tensors", mat1->nDimension, mat2->nDimension); + THError("2D tensors expected, got %dD, %dD tensors", mat1->nDimension, mat2->nDimension); if (mat1->size[1] != mat2->size[0]) { THDDescBuff bm1 = THDTensor_(sizeDesc)(mat1); @@ -1076,7 +1076,7 @@ void THDTensor_(addmm)(THDTensor *self, real beta, THDTensor *src, real alpha, T } if (src->nDimension != 2) - THError("matrix expected, got %dD tensor for t", src->nDimension); + THError("2D tensors expected, got %dD tensor for t", src->nDimension); if ((src->size[0] != mat1->size[0]) || (src->size[1] != mat2->size[1])) { THDDescBuff bt = THDTensor_(sizeDesc)(src); @@ -1246,7 +1246,7 @@ void THDTensor_(sign)(THDTensor *self, THDTensor *src) { } accreal THDTensor_(trace)(THDTensor *self) { - THArgCheck(self->nDimension == 2, 1, "expected a matrix"); + THArgCheck(self->nDimension == 2, 1, "expected a 2D tensor"); masterCommandChannel->sendMessage( packMessage(Functions::tensorTrace, self), diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 0b72453cdc1e2..4ac6ca7f88719 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -48,9 +48,11 @@ def _if_scalar_type_as(self, tensor): """ if isinstance(self, torch._C.Value): return self - else: + elif tensor.type().kind() == "TensorType": ty = tensor.type().scalarType().lower() return getattr(self, ty)() + else: + return self def _broadcast_if_scalar(x): diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 14a291eda144d..59d567f461789 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -95,9 +95,12 @@ def export(model, args, f, export_params=True, verbose=False, training=False, def _optimize_graph(graph, operator_export_type): - # run dce first to eliminate dead parts of the graph that might have been - # left behind by things like symbolic_override + # onnx only supports tensors, so we turn all out number types into tensors + torch._C._jit_pass_erase_number_types(graph) + + # run dce to eliminate dead parts of the graph that might have been + # left behind by things like symbolic_override torch._C._jit_pass_dce(graph) torch._C._jit_pass_lint(graph) @@ -168,7 +171,6 @@ def _model_to_graph(model, args, f, verbose=False, training=False, graph = method.propagate_and_assign_input_and_output_shapes( args, example_outputs, False, propagate) # Erase number types to bring the graph to a pre-NumberType state - torch._C._jit_pass_erase_number_types(graph) params = method.params() except AttributeError: # TODO: just trace it @@ -451,7 +453,9 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor elif ns == "prim": if op_name == "Constant": return g.op("Constant", value_t=n["value"]) - + elif op_name == "ListConstruct": + unsqueezed = [g.op("Unsqueeze", input, axes_i=[0]) for input in inputs] + return g.op("Concat", *unsqueezed, axis_i=0) elif op_name == "Undefined": # Undefined is not an ONNX operator; keep it as prim::Undefined # and let the exporter handle finally eliminating these
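Finally, a usage sketch for the `LowRankMultivariateNormal` distribution and the KL registrations added earlier in this patch; the parameter values are invented for illustration:

```python
import torch
from torch.distributions import LowRankMultivariateNormal, MultivariateNormal
from torch.distributions.kl import kl_divergence

# covariance = cov_factor @ cov_factor.T + diag(cov_diag), here rank 1 in 2 dimensions
loc = torch.zeros(2)
cov_factor = torch.tensor([[1.0], [0.0]])  # shape (event_size, rank): at least 2-D, as __init__ checks
cov_diag = torch.tensor([1.0, 1.0])

p = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
q = MultivariateNormal(loc, covariance_matrix=2.0 * torch.eye(2))

samples = p.rsample((4,))
print(p.log_prob(samples).shape)  # torch.Size([4])
print(kl_divergence(p, q).shape)  # torch.Size([]) -- dispatched via the new LowRank/MVN KL rule
```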