Merge from upstream #112

Merged: 20 commits, Aug 10, 2018

Commits (20)

b70b706  Keep kEps in one place to make sure they are consistent (#10334)  (bairdzhang, Aug 9, 2018)
0ace3b3  Remove outdated note about CUDA vs ROCm.  (iotamudelta, Aug 9, 2018)
52d85be  Deal with undefined tensors in unbind backward (#9995)  (t-vi, Aug 9, 2018)
4c06050  Dropout does now compile on ROCm w/ the recent rocRAND integration (and …  (iotamudelta, Aug 9, 2018)
cf13fd7  Merge remote-tracking branch 'rocm_upstream/master'  (iotamudelta, Aug 9, 2018)
e967fa9  Fix THTensor_nElement for scalars.  (gchanan, Aug 9, 2018)
7d53c87  Move maybeZeroDim to TH, change condition so it doesn't turn off scal…  (gchanan, Aug 9, 2018)
99b10ad  Fix compile flags for MSVC  (peterjc123, Aug 9, 2018)
3fa1c10  Avoid std::thread ctor "cannot resolve" error (#10381)  (ssnl, Aug 9, 2018)
cc5b47f  Fix the logic for PATH guess on Windows  (peterjc123, Aug 9, 2018)
b43beec  Fix bincount for empty input (#9757)  (vishwakftw, Aug 9, 2018)
64a6003  Don't copy on clamp, clamp_out (#10352)  (t-vi, Aug 9, 2018)
18d2fcd  Fix performance of DistributedSampler per #8958  (Aug 9, 2018)
209af45  Back out "[pytorch][PR] Fix bincount for empty input"  (gchanan, Aug 9, 2018)
b1e3239  Fix some backwards definitions wrt keepdim. (#10382)  (gchanan, Aug 9, 2018)
0950d7a  support list slicing (#10318)  (suo, Aug 10, 2018)
e9ad743  Use serialization container in ir import export (#10394)  (Aug 10, 2018)
7ac442c  Merge remote-tracking branch 'rocm_upstream/master'  (iotamudelta, Aug 10, 2018)
6260298  Merge remote-tracking branch 'upstream/master'  (iotamudelta, Aug 10, 2018)
30ebd41  Merge remote-tracking branch 'rocm_upstream/master'  (iotamudelta, Aug 10, 2018)

15 changes: 9 additions & 6 deletions aten/src/ATen/Declarations.cwrap
@@ -2267,39 +2267,42 @@
- THTensor* other
]]
[[
name: _th_clamp_
name: _th_clamp
cname: clamp
variants:
- method
- function
return: argument 0
arguments:
- THTensor* self
- arg: THTensor* result
output: True
- THTensor* self
- real min
- real max
]]
[[
name: _th_clamp_min_
name: _th_clamp_min
cname: cmaxValue
variants:
- method
- function
return: argument 0
arguments:
- THTensor* self
- arg: THTensor* result
output: True
- THTensor* self
- real min
]]
[[
name: _th_clamp_max_
name: _th_clamp_max
cname: cminValue
variants:
- method
- function
return: argument 0
arguments:
- THTensor* self
- arg: THTensor* result
output: True
- THTensor* self
- real max
]]
3 changes: 1 addition & 2 deletions aten/src/ATen/TensorImpl.cpp
@@ -96,8 +96,7 @@ int64_t TensorImpl::dim() const {

TensorImpl* TensorImpl::maybe_zero_dim(bool condition_when_zero_dim) {
AT_CHECK(tensor, "TensorImpl without THTensor in maybe_zero_dim");
bool is_zero_dim = condition_when_zero_dim && tensor->sizes().size() == 1 && tensor->size(0) == 1;
THTensor_setIsZeroDim(tensor, is_zero_dim);
THTensor_maybe_zero_dim(tensor, condition_when_zero_dim);
return this;
}

26 changes: 10 additions & 16 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -47,11 +47,11 @@ Tensor clamp_min(const Tensor& self, Scalar min) {

Tensor& _clamp__cpu(Tensor& self, Scalar min, Scalar max) {
if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
return _th_clamp_(self, min, max);
return _th_clamp_out(self, self, min, max);
} else if (std::isnan(min.toDouble())) {
return _th_clamp_max_(self, max);
return _th_clamp_max_out(self, self, max);
} else if (std::isnan(max.toDouble())) {
return _th_clamp_min_(self, min);
return _th_clamp_min_out(self, self, min);
} else {
return self;
}
@@ -62,36 +62,30 @@ Tensor& _clamp_out_cpu(
const Tensor& self,
Scalar min,
Scalar max) {
result.resize_(self.sizes());
result.copy_(self);
if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
_th_clamp_(result, min, max);
_th_clamp_out(result, self, min, max);
} else if (std::isnan(min.toDouble())) {
_th_clamp_max_(result, max);
_th_clamp_max_out(result, self, max);
} else if (std::isnan(max.toDouble())) {
_th_clamp_min_(result, min);
_th_clamp_min_out(result, self, min);
}
return result;
}

Tensor& _clamp_max__cpu(Tensor& self, Scalar max) {
return _th_clamp_max_(self, max);
return _th_clamp_max_out(self, self, max);
}

Tensor& _clamp_max_out_cpu(Tensor& result, const Tensor& self, Scalar max) {
result.resize_(self.sizes());
result.copy_(self);
return _th_clamp_max_(result, max);
return _th_clamp_max_out(result, self, max);
}

Tensor& _clamp_min__cpu(Tensor& self, Scalar min) {
return _th_clamp_min_(self, min);
return _th_clamp_min_out(self, self, min);
}

Tensor& _clamp_min_out_cpu(Tensor& result, const Tensor& self, Scalar min) {
result.resize_(self.sizes());
result.copy_(self);
return _th_clamp_min_(result, min);
return _th_clamp_min_out(result, self, min);
}

Tensor& fill_(Tensor& self, Scalar value) {
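
Aside (not part of the diff): the change above writes the clamped values of self straight into result through the new _th_clamp*_out calls, instead of first resizing and copying self into result and then clamping in place; NaN marks a bound that was not supplied. A minimal standalone C++ sketch of that idea, with made-up names (this is not the ATen implementation):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// clamp_out: clamp `self` into `out` in one pass; a NaN bound means "no bound".
void clamp_out(std::vector<double>& out, const std::vector<double>& self,
               double min, double max) {
  out.resize(self.size());
  const bool has_min = !std::isnan(min);
  const bool has_max = !std::isnan(max);
  for (size_t i = 0; i < self.size(); ++i) {
    double v = self[i];
    if (has_min) v = std::max(v, min);
    if (has_max) v = std::min(v, max);
    out[i] = v;  // single write into `out`; no intermediate copy of `self`
  }
}

int main() {
  std::vector<double> x = {-2.0, 0.5, 3.0}, y;
  clamp_out(y, x, /*min=*/NAN, /*max=*/1.0);   // only an upper bound given
  for (double v : y) std::printf("%g ", v);    // prints: -2 0.5 1
  std::printf("\n");
  return 0;
}
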
26 changes: 10 additions & 16 deletions aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
@@ -4,11 +4,11 @@ namespace at { namespace native {

Tensor& _clamp__cuda(Tensor& self, Scalar min, Scalar max) {
if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
return _th_clamp_(self, min, max);
return _th_clamp_out(self, self, min, max);
} else if (std::isnan(min.toDouble())) {
return _th_clamp_max_(self, max);
return _th_clamp_max_out(self, self, max);
} else if (std::isnan(max.toDouble())) {
return _th_clamp_min_(self, min);
return _th_clamp_min_out(self, self, min);
} else {
return self;
}
@@ -19,36 +19,30 @@ Tensor& _clamp_out_cuda(
const Tensor& self,
Scalar min,
Scalar max) {
result.resize_(self.sizes());
result.copy_(self);
if (!std::isnan(min.toDouble()) && !std::isnan(max.toDouble())) {
_th_clamp_(result, min, max);
_th_clamp_out(result, self, min, max);
} else if (std::isnan(min.toDouble())) {
_th_clamp_max_(result, max);
_th_clamp_max_out(result, self, max);
} else if (std::isnan(max.toDouble())) {
_th_clamp_min_(result, min);
_th_clamp_min_out(result, self, min);
}
return result;
}

Tensor& _clamp_max__cuda(Tensor& self, Scalar max) {
return _th_clamp_max_(self, max);
return _th_clamp_max_out(self, self, max);
}

Tensor& _clamp_max_out_cuda(Tensor& result, const Tensor& self, Scalar max) {
result.resize_(self.sizes());
result.copy_(self);
return _th_clamp_max_(result, max);
return _th_clamp_max_out(result, self, max);
}

Tensor& _clamp_min__cuda(Tensor& self, Scalar min) {
return _th_clamp_min_(self, min);
return _th_clamp_min_out(self, self, min);
}

Tensor& _clamp_min_out_cuda(Tensor& result, const Tensor& self, Scalar min) {
result.resize_(self.sizes());
result.copy_(self);
return _th_clamp_min_(result, min);
return _th_clamp_min_out(result, self, min);
}

// These are just forwarding stubs
3 changes: 0 additions & 3 deletions aten/src/ATen/native/cuda/Loops.cuh
@@ -23,9 +23,6 @@

namespace at { namespace native {

// NOTE: CUDA requires func_t to be passed by value, while ROCm fails to compile
// unless it's passed as a const reference.

template<int nt, int vt, typename func_t>
__launch_bounds__(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
3 changes: 1 addition & 2 deletions aten/src/ATen/test/tbb_init_test.cpp
@@ -4,14 +4,13 @@
#include "test_seed.h"
#include <thread>

using namespace at;

// This checks whether threads can see the global
// numbers of threads set and also whether the scheduler
// will throw an exception when multiple threads call
// their first parallel construct.
void test(int given_num_threads) {
auto t = ones({1000 * 1000}, CPU(kFloat));
auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
if (given_num_threads >= 0) {
ASSERT(at::get_num_threads() == given_num_threads);
} else {
6 changes: 5 additions & 1 deletion aten/src/TH/THTensor.hpp
@@ -127,7 +127,6 @@ inline THStorage* THTensor_getStoragePtr(const THTensor* tensor) {
return tensor->storage_;
}


inline bool THTensor_isZeroDim(const THTensor *tensor) {
return tensor->is_zero_dim_;
}
@@ -136,6 +135,11 @@ inline void THTensor_setIsZeroDim(THTensor *tensor, bool is_zero_dim) {
tensor->is_zero_dim_ = is_zero_dim;
}

inline void THTensor_maybe_zero_dim(THTensor *tensor, bool condition_when_zero_dim) {
bool is_zero_dim = (condition_when_zero_dim && tensor->sizes().size() == 1 && tensor->size(0) == 1) || tensor->dim() == 0;
THTensor_setIsZeroDim(tensor, is_zero_dim);
}

// [NOTE: nDimension vs nDimensionLegacyNoScalars vs nDimensionLegacyAll]
// nDimension corresponds to the "true" ATen dimension. TODO: implement.
// nDimensionLegacyNoScalars correpsonds to the ATen dimension, except scalars are viewed as 1-dimensional tensors.
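
Aside (not part of the diff): as far as I can tell from the two hunks above, the point of consolidating the check in THTensor_maybe_zero_dim is that a tensor which is already zero-dimensional stays zero-dimensional, and only a size-[1] one-dimensional tensor can be newly marked as a scalar when the caller's condition holds. A small standalone sketch of that condition (hypothetical names, not the TH code):

#include <cstdio>
#include <vector>

// Returns the new "is zero-dim" flag: either the caller's condition holds and
// the tensor looks like a wrapped scalar (one dimension of size 1), or the
// tensor is already a scalar, in which case the flag is never turned off.
bool maybe_zero_dim(bool condition_when_zero_dim,
                    const std::vector<long>& sizes,
                    bool currently_zero_dim) {
  const bool wrapped_scalar =
      condition_when_zero_dim && sizes.size() == 1 && sizes[0] == 1;
  return wrapped_scalar || currently_zero_dim;
}

int main() {
  std::printf("%d\n", maybe_zero_dim(true,  {1}, false));  // 1: becomes a scalar
  std::printf("%d\n", maybe_zero_dim(false, {1}, false));  // 0: stays 1-d
  std::printf("%d\n", maybe_zero_dim(false, {},  true));   // 1: stays a scalar
  std::printf("%d\n", maybe_zero_dim(true,  {3}, false));  // 0: a real vector
  return 0;
}
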
2 changes: 1 addition & 1 deletion aten/src/TH/generic/THTensor.cpp
@@ -570,7 +570,7 @@ ptrdiff_t THTensor_(nElement)(const THTensor *self)
{
ptrdiff_t nElement = 1;
int d;
for(d = 0; d < THTensor_nDimensionLegacyAll(self); d++)
for(d = 0; d < THTensor_nDimension(self); d++)
nElement *= self->size(d);
return nElement;
}
2 changes: 1 addition & 1 deletion aten/src/THC/THCTensor.cpp
@@ -301,7 +301,7 @@ ptrdiff_t THCTensor_nElement(THCState *state, const THCTensor *self) {
{
ptrdiff_t nElement = 1;
int d;
for(d = 0; d < THTensor_nDimensionLegacyAll(self); d++)
for(d = 0; d < THTensor_nDimension(self); d++)
nElement *= self->size(d);
return nElement;
}
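
Aside (not part of the diff): presumably the point of switching both loops above to THTensor_nDimension is that, with the true dimension, a 0-dimensional scalar makes the loop body run zero times, so the empty product correctly reports one element. A standalone sketch of just that counting logic (not the TH code):

#include <cstdio>

// Element count as the product of sizes over the true number of dimensions;
// for a 0-dimensional tensor the product is empty and equals 1.
long n_element(const long* sizes, int ndimension) {
  long n = 1;
  for (int d = 0; d < ndimension; d++) {
    n *= sizes[d];
  }
  return n;
}

int main() {
  const long sizes[] = {2, 3};
  std::printf("%ld\n", n_element(sizes, 2));    // 6
  std::printf("%ld\n", n_element(nullptr, 0));  // 1: a scalar holds one element
  return 0;
}
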
6 changes: 2 additions & 4 deletions caffe2/operators/normalize_op.cc
@@ -12,7 +12,6 @@ void NormalizeOp<T, Context>::DoNormalize(
const int m,
const int n,
const int sf) {
const T kEps = 1e-12f;
using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
using StridedVec =
Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
@@ -23,7 +22,7 @@
auto base = (i / sf) * sf * m + (i % sf);
ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
auto norm = xVec.template lpNorm<2>();
norm = std::max(norm, kEps);
norm = std::max(norm, kEps_);
StridedVec yVec(yData + base, 1, m, InnerStride(sf));
yVec = xVec / norm;
}
@@ -37,7 +36,6 @@ void NormalizeGradientOp<T, Context>::DoNormalize(
const int m,
const int n,
const int sf) {
const T kEps = 1e-12f;
using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
using StridedVec =
Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
@@ -51,7 +49,7 @@

auto row_sum = xVec.dot(gOutVec);
auto row_norm = xVec.template lpNorm<2>();
row_norm = std::max(row_norm, kEps);
row_norm = std::max(row_norm, kEps_);
auto row_norm_3 = pow(row_norm, 3);
StridedVec gInVec(gInData + base, 1, m, InnerStride(sf));
gInVec = (gOutVec / row_norm) - ((xVec / row_norm_3) * row_sum);
4 changes: 4 additions & 0 deletions caffe2/operators/normalize_op.h
@@ -5,6 +5,8 @@
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

#define KEPS 1e-12f

namespace caffe2 {

template <typename T, class Context>
@@ -31,6 +33,7 @@ class NormalizeOp final : public Operator<Context> {
}

private:
const T kEps_ = KEPS;
void
DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
};
@@ -62,6 +65,7 @@ class NormalizeGradientOp final : public Operator<Context> {
}

private:
const T kEps_ = KEPS;
void DoNormalize(
const T* xData,
const T* gOutData,
13 changes: 7 additions & 6 deletions caffe2/operators/normalize_ops.cu
@@ -11,8 +11,8 @@ __global__ void NormalizeKernel(
const int n,
const int sf,
const float* xData,
float* yData) {
const float kEps = 1e-12f;
float* yData,
const float kEps) {
typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
__shared__ BlockReduce::TempStorage temp_storage;

@@ -45,8 +45,8 @@ __global__ void NormalizeGradientKernel(
const int SF,
const float* in_mat,
const float* grad_out_mat,
float* grad_mat) {
const float kEps = 1e-12f;
float* grad_mat,
const float kEps) {
typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
__shared__ BlockReduce::TempStorage temp_storage_sum;
__shared__ BlockReduce::TempStorage temp_storage_norm;
@@ -92,7 +92,7 @@ void NormalizeOp<float, CUDAContext>::DoNormalize(
min(n, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(m, n, sf, xData, yData);
context_.cuda_stream()>>>(m, n, sf, xData, yData, kEps_);
}

template <>
@@ -117,7 +117,8 @@ bool NormalizeGradientOp<float, CUDAContext>::RunOnDevice() {
SF,
X.data<float>(),
dY.data<float>(),
dX->template mutable_data<float>());
dX->template mutable_data<float>(),
kEps_);
return true;
}

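
Aside (not part of the diff): moving the epsilon into the operator (kEps_ = KEPS) and passing it into the CUDA kernels keeps the forward and gradient paths on CPU and GPU using the same guard, which is what the "Keep kEps in one place" commit is after. The guard itself is the usual trick for safe L2 normalization, roughly as in this standalone sketch (not the Caffe2 code):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

constexpr float kEps = 1e-12f;  // same role as the shared KEPS above

// Divide by max(norm, kEps) so an all-zero row yields zeros instead of NaNs.
std::vector<float> l2_normalize(const std::vector<float>& x) {
  float sq = 0.f;
  for (float v : x) sq += v * v;
  const float norm = std::max(std::sqrt(sq), kEps);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] / norm;
  return y;
}

int main() {
  for (float v : l2_normalize({3.f, 4.f})) std::printf("%g ", v);  // 0.6 0.8
  std::printf("\n");
  for (float v : l2_normalize({0.f, 0.f})) std::printf("%g ", v);  // 0 0, no NaN
  std::printf("\n");
  return 0;
}
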
12 changes: 6 additions & 6 deletions cmake/Dependencies.cmake
@@ -931,12 +931,12 @@ if (BUILD_ATEN)
OPTION(NDEBUG "disable asserts (WARNING: this may result in silent UB e.g. with out-of-bound indices)")
IF (NOT NDEBUG)
MESSAGE(STATUS "Removing -DNDEBUG from compile flags")
STRING(REPLACE "-DNDEBUG" "" CMAKE_C_FLAGS "" ${CMAKE_C_FLAGS})
STRING(REPLACE "-DNDEBUG" "" CMAKE_C_FLAGS_DEBUG "" ${CMAKE_C_FLAGS_DEBUG})
STRING(REPLACE "-DNDEBUG" "" CMAKE_C_FLAGS_RELEASE "" ${CMAKE_C_FLAGS_RELEASE})
STRING(REPLACE "-DNDEBUG" "" CMAKE_CXX_FLAGS "" ${CMAKE_CXX_FLAGS})
STRING(REPLACE "-DNDEBUG" "" CMAKE_CXX_FLAGS_DEBUG "" ${CMAKE_CXX_FLAGS_DEBUG})
STRING(REPLACE "-DNDEBUG" "" CMAKE_CXX_FLAGS_RELEASE "" ${CMAKE_CXX_FLAGS_RELEASE})
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_C_FLAGS "" ${CMAKE_C_FLAGS})
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_C_FLAGS_DEBUG "" ${CMAKE_C_FLAGS_DEBUG})
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_C_FLAGS_RELEASE "" ${CMAKE_C_FLAGS_RELEASE})
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_CXX_FLAGS "" ${CMAKE_CXX_FLAGS})
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_CXX_FLAGS_DEBUG "" ${CMAKE_CXX_FLAGS_DEBUG})
STRING(REGEX REPLACE "[-/]DNDEBUG" "" CMAKE_CXX_FLAGS_RELEASE "" ${CMAKE_CXX_FLAGS_RELEASE})
ENDIF()

# OpenMP support?
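
Aside (not part of the diff): given the "Fix compile flags for MSVC" commit, the likely motivation for switching to a regex is that MSVC's default release flags spell the define as /DNDEBUG while GCC and Clang use -DNDEBUG, and "[-/]DNDEBUG" strips both spellings. A quick standalone C++ check of the pattern (illustration only, not part of the build):

#include <cstdio>
#include <regex>
#include <string>

int main() {
  const std::regex ndebug("[-/]DNDEBUG");
  const std::string msvc = std::regex_replace(std::string("/O2 /DNDEBUG /MD"), ndebug, "");
  const std::string gnu  = std::regex_replace(std::string("-O2 -DNDEBUG -fPIC"), ndebug, "");
  std::printf("'%s'\n", msvc.c_str());  // '/O2  /MD'
  std::printf("'%s'\n", gnu.c_str());   // '-O2  -fPIC'
  return 0;
}
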
10 changes: 10 additions & 0 deletions test/test_autograd.py
@@ -1304,6 +1304,15 @@ def test_unbind(self):
grad = torch.randn(3, 10, 10)
torch.autograd.backward([x, y, z], grad.unbind())
self.assertEqual(stacked.grad.data, grad)
# check that it works with only one gradient provided (#9977)
for i in range(3):
stacked = torch.randn(3, 10, 10, requires_grad=True)
outs = stacked.unbind()
gi = grad.unbind()[i]
g, = torch.autograd.grad(outs[i], stacked, gi)
g_expected = torch.stack([gi if j == i else torch.zeros_like(gi)
for j in range(3)], dim=0)
self.assertEqual(g, g_expected)

def test_put(self):
root = torch.randn(4, 5, requires_grad=True)
@@ -2953,6 +2962,7 @@ class dont_convert(tuple):
('zero_', (S, S, S), NO_ARGS),
('zero_', (), NO_ARGS, 'scalar'),
('logsumexp', (S, S), (1,)),
('logsumexp', (), (0,), 'scalar'),
('norm', (S, S), (2,)),
('norm', (S, S), (0,), '0'),
('norm', (S, S), (0.5,), '0_5'),
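
Aside (not part of the diff): the new test_unbind check above drives a gradient through only one of the unbound outputs, so the other outputs' gradients arrive undefined in unbind's backward (which stacks the per-slice gradients back together); the fix treats those as zeros of the right shape, matching the g_expected the test builds. A standalone sketch of that idea (not the autograd code):

#include <cstdio>
#include <optional>
#include <vector>

using Slice = std::vector<double>;

// Backward of an "unbind"-like op: stack per-slice gradients back together,
// substituting zeros for slices whose gradient was never produced.
std::vector<Slice> unbind_backward(const std::vector<std::optional<Slice>>& grads,
                                   size_t slice_len) {
  std::vector<Slice> stacked;
  for (const auto& g : grads) {
    stacked.push_back(g ? *g : Slice(slice_len, 0.0));  // undefined -> zeros
  }
  return stacked;
}

int main() {
  // Only the gradient for output 1 is defined, as in the new test above.
  std::vector<std::optional<Slice>> grads = {std::nullopt, Slice{1, 2, 3}, std::nullopt};
  for (const auto& row : unbind_backward(grads, 3)) {
    for (double v : row) std::printf("%g ", v);
    std::printf("\n");  // prints: 0 0 0 / 1 2 3 / 0 0 0
  }
  return 0;
}
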