Merge from upstream #29

Merged Jul 12, 2018 (37 commits)

Changes from all commits:

bbeae24 (orionr, Jul 10, 2018): Fix Eigen issue on OS X with CUDA and nvcc compile (#9270)
8e6e809 (Jul 10, 2018): Revert D8768025: [pytorch][PR] Fix Eigen issue on OS X with CUDA and …
d867757 (peterjc123, Jul 10, 2018): Fix CUDA 8 build for Windows (#9300)
efefd1d (zdevito, Jul 10, 2018): Unify aten_dispatch and aten_schema into a single operator abstractio…
e9e47ce (cpuhrsch, Jul 10, 2018): Vectorize sigmoid (#8612)
0a67910 (zdevito, Jul 10, 2018): Fix missing accept file changes
ea18692 (Jokeren, Jul 10, 2018): Change depthwise convolution bandwidth formula (#9317)
a47a30b (Jul 10, 2018): Implement grid_sampler in aten (#8929)
b4c6645 (iotamudelta, Jul 11, 2018): Add pyHIPIFY scripts needed for ROCm transpilation to PyTorch (#8812)
00aeb0b (cpuhrsch, Jul 11, 2018): Privatize values for vec256 (#9321)
fb9f9c9 (hl475, Jul 11, 2018): Implement Sinh and Cosh (#9213)
04a7fc1 (houseroad, Jul 11, 2018): Add Upsample support in C2 onnx backend for opset 1
b2a74d1 (t-vi, Jul 11, 2018): document torch.utils.dlpack (#9343)
01cffaa (zdevito, Jul 11, 2018): fix extra output in generate_code.py (#9339)
748a90d (viswanathgs, Jul 11, 2018): BBoxTransform op: Add support for rotated boxes (#8952)
05559b4 (zou3519, Jul 11, 2018): Accumulate MSELoss reduce=True into accreal instead of real (#9287)
b9f575f (apaszke, Jul 11, 2018): Remove legacy code from the JIT (#9323)
8da936a (JerryShih, Jul 11, 2018): Fix the build break for python3.7 PyUnicode_AsUTF8AndSize() prototype…
491f317 (viswanathgs, Jul 11, 2018): NMS util for rotated boxes (#8954)
9126f95 (viswanathgs, Jul 11, 2018): GenerateProposals and BoxWithNMSLimit ops: Add support for rotated bo…
c2dd90c (viswanathgs, Jul 11, 2018): Add angle normalization for rotated boxes (#9056)
18a9752 (goldsborough, Jul 11, 2018): Add explicit to conversions (#9336)
80380f6 (Jul 11, 2018): Fix to make ONNXIFI flow work (#9340)
7d8b532 (Jul 11, 2018): Fix CUDA build failures (#9347)
cbcf452 (xiaomengy, Jul 11, 2018): Move tanh function to math (#9328)
7f33ec5 (orionr, Jul 11, 2018): Fix Eigen issue on OS X with CUDA and nvcc compile (#9350)
8253947 (sunnieshang, Jul 11, 2018): Make error message more informative (#9352)
94bc4c6 (Jul 11, 2018): Ensure pending tasks are finished in case of failure (#9290)
153e2e9 (goldsborough, Jul 12, 2018): Make Sequential ref-counted (#9151)
1a8e826 (houseroad, Jul 12, 2018): Skip the count_include_pad in average pool for now (#9365)
e30ff68 (Ac2zoom, Jul 12, 2018): Add Hardtanh Export (#8804)
a487b08 (ChunliF, Jul 12, 2018): AutoBatching - IR transformation(basic operators) (#9198)
7f38ea4 (heslami, Jul 12, 2018): Remove unused feature: num PS tuning
00b4b47 (ssnl, Jul 12, 2018): fix unsqueeze doc (#9374)
aeccec7 (pietern, Jul 12, 2018): In Gloo backend use ring reduction by default (#9309)
e186377 (apaszke, Jul 12, 2018): Guard gloo algorithm creation with DeviceGuard (#9371)
e57fe61 (iotamudelta, Jul 12, 2018): Merge remote-tracking branch 'upstream/master'
44 changes: 23 additions & 21 deletions .clang-tidy
@@ -3,36 +3,38 @@
(This diff alphabetizes and de-duplicates the checks list; the old list repeated some entries, e.g. -google-runtime-references. The resulting configuration:)
Checks: '
*
,modernize-*
,-cert-err58-cpp
,-cert-err60-cpp
,-clang-diagnostic-*
,-cppcoreguidelines-owning-memory
,-cppcoreguidelines-pro-bounds-array-to-pointer-decay
,-cppcoreguidelines-pro-bounds-constant-array-index
,-cppcoreguidelines-pro-type-static-cast-downcast
,-cppcoreguidelines-pro-type-vararg
,-cppcoreguidelines-special-member-functions
,-fuchsia-*
,-google-build-using-namespace
,-google-explicit-constructor
,-google-readability-braces-around-statements
,-google-readability-namespace-comments
,-google-readability-todo
,-google-runtime-references
,-hicpp-braces-around-statements
,-hicpp-explicit-conversions
,-hicpp-no-array-decay
,-hicpp-special-member-functions
,-hicpp-vararg
,-llvm-header-guard
,-llvm-namespace-comment
,-misc-unused-parameters
,-modernize-make-unique
,-modernize-use-default-member-init
,-performance-unnecessary-value-param
,-readability-braces-around-statements
,-readability-else-after-return
,-readability-named-parameter
,clang-analyzer-*
'
WarningsAsErrors: ''
HeaderFilterRegex: 'torch/csrc/'
1 change: 1 addition & 0 deletions .jenkins/pytorch/test.sh
@@ -70,6 +70,7 @@ test_aten() {
# put the dynamic libraries somewhere where the dynamic linker can find them.
# This is a bit of a hack.
ln -s "$TORCH_LIB_PATH"/libcaffe2* build/bin
ln -s "$TORCH_LIB_PATH"/libnccl* build/bin
ls build/bin
aten/tools/run_tests.sh build/bin
fi
13 changes: 8 additions & 5 deletions aten/src/ATen/CPUApplyUtils.h
@@ -253,16 +253,15 @@ apply_op(int64_t numel, int64_t offset, const Op& op, Args... iters) {
}
}


inline void apply_kernel(){};

// TODO: Deal elegantly with 0-dim tensors. iters.strides_ of 0-dim
// strided_tensor_iter will be of size 0 for dim 0 and iters.strides_[iters.dim_
// - 1] will index at -1. C++14 integer_sequence could be of use here.
template <typename Op, typename... Args>
inline void
apply_kernel(int64_t numel, int64_t offset, const Op& op, Args... iters) {
// For 0-dim tensors
if (numel == 1 && max_dim(iters...) == 0) {
op(1, iters.data_..., iters.strides_[iters.dim_ - 1]...);
return;
}
if (offset > 0)
forward(offset, iters...);
int64_t size = std::min(numel, max_iterate_size(iters...));
@@ -284,6 +283,10 @@ inline void
CPU_tensor_parallel_kernel_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
if (!_apply_preamble({tensor1, tensor2}))
return;
if (tensor1.numel() == 1) {
op(1, tensor1.data<scalar1>(), tensor2.data<scalar2>(), 0, 0);
return;
}
if (tensor1.ndimension() < 8 && tensor2.ndimension() < 8) {
parallel_for(
0,
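The early return added above special-cases 0-dim tensors: a 0-dim tensor holds exactly one element and has no dimensions to stride over, so the kernel is invoked once directly. A minimal standalone sketch of the same idea, with a hypothetical apply_unary helper standing in for ATen's real iterator machinery:

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the iterator machinery: applies `op` to `numel`
// elements at `data`, special-casing the 0-dim (scalar) tensor.
template <typename scalar_t, typename Op>
void apply_unary(int64_t numel, int64_t ndim, scalar_t* data, const Op& op) {
  if (numel == 1 && ndim == 0) {
    // 0-dim tensor: one element, nothing to stride over, so call the kernel
    // once with a stride of 0 (mirrors the diff's early return).
    op(int64_t(1), data, int64_t(0));
    return;
  }
  // General contiguous case (the real code walks strided iterators).
  op(numel, data, int64_t(1));
}

int main() {
  double scalar = 3.0;
  auto negate = [](int64_t n, double* p, int64_t stride) {
    for (int64_t i = 0; i < n; i++) p[i * stride] = -p[i * stride];
  };
  apply_unary(/*numel=*/1, /*ndim=*/0, &scalar, negate);
  std::printf("%f\n", scalar);  // prints -3.000000
}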
16 changes: 1 addition & 15 deletions aten/src/ATen/Declarations.cwrap
@@ -1114,24 +1114,10 @@
- THTensor* self
]]
[[
name: sigmoid_
name: _th_sigmoid
types:
- floating_point
backends:
- CPU
- CUDA
cname: sigmoid
return: self
arguments:
- THTensor* self
- THTensor* self
]]
[[
name: sigmoid
types:
- floating_point
backends:
- CPU
- CUDA
cname: sigmoid
variants:
6 changes: 3 additions & 3 deletions aten/src/ATen/cpu/vec256/intrinsics.h
@@ -4,10 +4,10 @@
/* Microsoft C/C++-compatible compiler */
#include <intrin.h>
#if _MSC_VER <= 1900
#define _mm256_extract_epi64(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
#define _mm256_extract_epi32(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
#define _mm256_extract_epi8(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
#endif
#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
/* GCC-compatible compiler, targeting x86/x86-64 */
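The fix matters because the old MSVC emulation routed every element width through _mm_extract_epi16, so extracting a 32- or 64-bit lane silently returned only 16 bits of it. A small check of the corrected 64-bit path (compile with AVX enabled, e.g. /arch:AVX or -mavx; on non-MSVC compilers _mm256_extract_epi64 is the real intrinsic):

#include <immintrin.h>
#include <cstdio>

int main() {
  // Lanes, low to high: 0x100000001, 2, 3, 4. The old epi16-based macro
  // would have returned only the low 16 bits (1) for lane 0.
  __m256i v = _mm256_set_epi64x(4, 3, 2, 0x100000001LL);
  long long lane0 = _mm256_extract_epi64(v, 0);
  std::printf("0x%llx\n", lane0);  // 0x100000001 with the fixed macro
  return 0;
}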
38 changes: 26 additions & 12 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -23,8 +23,10 @@ namespace {
// emulates vectorized types
template <class T>
struct Vec256 {
static constexpr int size = 32 / sizeof(T);
private:
T values[32 / sizeof(T)] = {0};
public:
static constexpr int size = 32 / sizeof(T);
Vec256() {}
Vec256(T val) {
for (int i = 0; i != size; i++) {
@@ -37,9 +39,9 @@ struct Vec256 {
Vec256 vec;
for (int64_t i = 0; i < size; i++) {
if (mask & 0x01) {
vec.values[i] = b[i];
vec[i] = b[i];
} else {
vec.values[i] = a[i];
vec[i] = a[i];
}
mask = mask >> 1;
}
@@ -49,9 +51,9 @@ struct Vec256 {
Vec256 vec;
for (int64_t i = 0; i < size; i++) {
if (i < count) {
vec.values[i] = b.values[i];
vec[i] = b[i];
} else {
vec.values[i] = a.values[i];
vec[i] = a[i];
}
}
return vec;
@@ -69,17 +71,23 @@ struct Vec256 {
void store(void* ptr, int count = size) const {
std::memcpy(ptr, values, count * sizeof(T));
}
const T& operator[](int idx) const {
return values[idx];
}
T& operator[](int idx) {
return values[idx];
}
Vec256<T> map(T (*f)(T)) const {
Vec256<T> ret;
for (int64_t i = 0; i != size; i++) {
ret.values[i] = f(values[i]);
ret[i] = f(values[i]);
}
return ret;
}
Vec256<T> abs() const {
Vec256<T> ret;
for (int64_t i = 0; i < size; i++) {
ret.values[i] = values[i] < 0 ? -values[i] : values[i];
ret[i] = values[i] < 0 ? -values[i] : values[i];
}
return ret;
}
Expand Down Expand Up @@ -125,6 +133,9 @@ struct Vec256 {
Vec256<T> floor() const {
return map(std::floor);
}
Vec256<T> neg() const {
return map([](T x) { return -x; });
}
Vec256<T> round() const {
return map(std::round);
}
@@ -146,6 +157,9 @@ struct Vec256 {
Vec256<T> sqrt() const {
return map(std::sqrt);
}
Vec256<T> reciprocal() const {
return map([](T x) { return (T)(1) / x; });
}
Vec256<T> rsqrt() const {
return map([](T x) { return 1 / std::sqrt(x); });
}
@@ -154,39 +168,39 @@ struct Vec256 {
template <class T> Vec256<T> operator+(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size; i++) {
c.values[i] = a.values[i] + b.values[i];
c[i] = a[i] + b[i];
}
return c;
}

template <class T> Vec256<T> operator-(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size; i++) {
c.values[i] = a.values[i] - b.values[i];
c[i] = a[i] - b[i];
}
return c;
}

template <class T> Vec256<T> operator*(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size; i++) {
c.values[i] = a.values[i] * b.values[i];
c[i] = a[i] * b[i];
}
return c;
}

template <class T> Vec256<T> operator/(const Vec256<T> &a, const Vec256<T> &b) __ubsan_ignore_float_divide_by_zero__ {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size; i++) {
c.values[i] = a.values[i] / b.values[i];
c[i] = a[i] / b[i];
}
return c;
}

template <class T> Vec256<T> max(const Vec256<T> &a, const Vec256<T> &b) {
Vec256<T> c = Vec256<T>();
for (int i = 0; i != Vec256<T>::size; i++) {
c.values[i] = std::max(a.values[i], b.values[i]);
c[i] = std::max(a[i], b[i]);
}
return c;
}
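With values now private, the only element access on the generic (non-AVX) Vec256<T> is the new operator[]; the AVX specializations below delete it, pushing callers toward load/store round-trips that the compiler can vectorize. A short usage sketch, assuming a translation unit where this generic template (not an AVX specialization) is selected:

// Fill lanes through the new operator[], transform, and store back out.
float demo() {
  Vec256<float> v;                          // lanes zero-initialized
  for (int i = 0; i != Vec256<float>::size; i++) {
    v[i] = float(i + 1);                    // element access via operator[]
  }
  Vec256<float> r = v.neg().reciprocal();   // lane i holds -1.0f / (i + 1)
  float out[Vec256<float>::size];
  r.store(out);                             // copy lanes to a plain array
  return out[0];                            // -1.0f
}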
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_double.h
@@ -13,9 +13,10 @@ namespace {
#if defined(__AVX__) && !defined(_MSC_VER)

template <> class Vec256<double> {
private:
__m256d values;
public:
static constexpr int size = 4;
__m256d values;
Vec256() {}
Vec256(__m256d v) : values(v) {}
Vec256(double val) {
@@ -61,6 +62,8 @@ template <> class Vec256<double> {
std::memcpy(ptr, tmp_values, count * sizeof(double));
}
}
const double& operator[](int idx) const = delete;
double& operator[](int idx) = delete;
Vec256<double> map(double (*f)(double)) const {
__at_align32__ double tmp[4];
store(tmp);
@@ -121,6 +124,9 @@ template <> class Vec256<double> {
Vec256<double> floor() const {
return _mm256_floor_pd(values);
}
Vec256<double> neg() const {
return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
}
Vec256<double> round() const {
return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
@@ -136,6 +142,9 @@ template <> class Vec256<double> {
Vec256<double> sqrt() const {
return _mm256_sqrt_pd(values);
}
Vec256<double> reciprocal() const {
return _mm256_div_pd(_mm256_set1_pd(1), values);
}
Vec256<double> rsqrt() const {
return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
}
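The neg() implementation is the classic sign-bit trick: -0.0 is all zeros except the IEEE-754 sign bit, so a single XOR flips the sign of all four doubles at once, with no arithmetic and no special-casing of NaN or infinity. A scalar sketch of the same bit manipulation:

#include <cstdint>
#include <cstring>

// Negate a double by XOR-ing its sign bit, mirroring _mm256_xor_pd above.
double negate_via_sign_bit(double x) {
  std::uint64_t bits;
  std::memcpy(&bits, &x, sizeof bits);   // type-pun safely via memcpy
  bits ^= 0x8000000000000000ULL;         // the bit pattern of -0.0
  std::memcpy(&x, &bits, sizeof bits);
  return x;                              // == -x for every double
}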
11 changes: 10 additions & 1 deletion aten/src/ATen/cpu/vec256/vec256_float.h
@@ -13,9 +13,10 @@ namespace {
#if defined(__AVX__) && !defined(_MSC_VER)

template <> class Vec256<float> {
private:
__m256 values;
public:
static constexpr int64_t size = 8;
__m256 values;
Vec256() {}
Vec256(__m256 v) : values(v) {}
Vec256(float val) {
@@ -66,6 +67,8 @@ template <> class Vec256<float> {
std::memcpy(ptr, tmp_values, count * sizeof(float));
}
}
const float& operator[](int idx) const = delete;
float& operator[](int idx) = delete;
Vec256<float> map(float (*f)(float)) const {
__at_align32__ float tmp[8];
store(tmp);
@@ -126,6 +129,9 @@ template <> class Vec256<float> {
Vec256<float> floor() const {
return _mm256_floor_ps(values);
}
Vec256<float> neg() const {
return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
}
Vec256<float> round() const {
return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
@@ -141,6 +147,9 @@ template <> class Vec256<float> {
Vec256<float> sqrt() const {
return _mm256_sqrt_ps(values);
}
Vec256<float> reciprocal() const {
return _mm256_div_ps(_mm256_set1_ps(1), values);
}
Vec256<float> rsqrt() const {
return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
}
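Taken together, the neg() and reciprocal() additions in these headers are the building blocks for the "Vectorize sigmoid" commit (e9e47ce, #8612). A minimal sketch of how they compose into sigmoid(x) = 1 / (1 + exp(-x)), assuming the full header's exp() method (present in Vec256 but not shown in this diff) and the operator+ defined in vec256_base.h; the actual kernel lives elsewhere in ATen:

// Hypothetical composition of the primitives added in this PR.
template <typename T>
Vec256<T> vec_sigmoid(const Vec256<T>& x) {
  return (Vec256<T>(T(1)) + x.neg().exp()).reciprocal();
}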