Skip to content

Commit f0583cd

Browse files
committed
Merge remote-tracking branch 'rocm_upstream/upstream' into ifu
2 parents 7ead1f1 + 1d3f650 commit f0583cd

File tree

133 files changed

+1264
-812
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

133 files changed

+1264
-812
lines changed

.jenkins/pytorch/build.sh

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,6 @@ fi
102102
# Add the test binaries so that they won't be git clean'ed away
103103
git add -f build/bin
104104

105-
# Test C FFI plugins
106-
# cffi install doesn't work for Python 3.7
107-
if [[ "$BUILD_ENVIRONMENT" != *pynightly* ]]; then
108-
# TODO: Don't run this here
109-
pip install cffi
110-
git clone https://github.com/pytorch/extension-ffi.git
111-
pushd extension-ffi/script
112-
python build.py
113-
popd
114-
fi
115-
116105
# Test documentation build
117106
if [[ "$BUILD_ENVIRONMENT" == *xenial-cuda8-cudnn6-py3* ]]; then
118107
pushd docs

.jenkins/pytorch/enabled-configs.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ pytorch-macos-10.13-cuda9.2-cudnn7-py3-build
4040
pytorch-docker-build-test
4141
short-perf-test-cpu
4242
short-perf-test-gpu
43-
py2-clang3.8-rocm1.7.1-ubuntu16.04-build
44-
py2-clang3.8-rocm1.7.1-ubuntu16.04-test
43+
py2-clang7-rocmdeb-ubuntu16.04-build
44+
py2-clang7-rocmdeb-ubuntu16.04-test
4545
pytorch-ppc64le-cuda9.2-cudnn7-py3-build
4646
pytorch-ppc64le-cuda9.2-cudnn7-py3-test
4747
pytorch-ppc64le-cuda9.1-cudnn7-py3-build

aten/src/ATen/core/TensorImpl.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,20 @@ IntList TensorImpl::sizes() const {
4545
}
4646

4747
IntList TensorImpl::strides() const {
48+
AT_ASSERTM(strides_.size() == sizes_.size(),
49+
"Caffe2 tensors don't (yet) have meaningful strides and cannot "
50+
"be used in PyTorch.");
4851
return strides_;
4952
}
5053

5154
bool TensorImpl::compute_contiguous() const {
5255
bool is_contiguous = true;
5356
if (is_empty())
5457
return is_contiguous;
58+
if (strides_.empty()) {
59+
// Special case for Caffe2 tensors which don't have strides set.
60+
return true;
61+
}
5562
int64_t z = 1;
5663
for (int64_t d = dim() - 1; d >= 0; d--) {
5764
if (size(d) != 1) {
@@ -82,6 +89,9 @@ int64_t TensorImpl::size(int64_t d) const {
8289
}
8390

8491
int64_t TensorImpl::stride(int64_t d) const {
92+
AT_ASSERTM(strides_.size() == sizes_.size(),
93+
"Caffe2 tensors don't (yet) have meaningful strides and cannot "
94+
"be used in PyTorch.");
8595
d = at::maybe_wrap_dim(d, dim(), false);
8696
return strides_[d];
8797
}

aten/src/ATen/core/TensorImpl.h

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "ATen/core/LegacyTypeDispatch.h"
1111
#include "ATen/core/Backend.h"
1212
#include "ATen/core/context_base.h"
13+
#include "ATen/core/WrapDimMinimal.h"
1314

1415
#include "caffe2/core/allocator.h"
1516
#include "caffe2/core/common.h"
@@ -89,16 +90,6 @@ inline int64_t size_between_dim_(int k, int l, IntList dims) {
8990
return r;
9091
}
9192

92-
// Wrap around axis_index if it is negative, s.t., -1 is the last dim
93-
inline int canonical_axis_index_(int axis_index, int ndims) {
94-
CAFFE_ENFORCE_GE(axis_index, -ndims);
95-
CAFFE_ENFORCE_LT(axis_index, ndims);
96-
if (axis_index < 0) {
97-
return axis_index + ndims;
98-
}
99-
return axis_index;
100-
}
101-
10293
/**
10394
* The low-level representation of a tensor, which contains a storage
10495
* (which contains the actual data) and metadata (e.g., sizes and strides)
@@ -291,13 +282,13 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
291282
}
292283

293284
virtual void set_size(int64_t dim, int64_t new_size) {
294-
sizes_[dim] = new_size;
285+
sizes_.at(dim) = new_size;
295286
refresh_numel();
296287
refresh_contiguous();
297288
}
298289

299290
virtual void set_stride(int64_t dim, int64_t new_stride) {
300-
strides_[dim] = new_stride;
291+
strides_.at(dim) = new_stride;
301292
refresh_numel();
302293
refresh_contiguous();
303294
}
@@ -374,6 +365,10 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
374365
return storage_.device_type();
375366
}
376367

368+
at::Device GetDevice() const {
369+
return storage_.device();
370+
}
371+
377372
/**
378373
* The static context of a tensor intuitively represents the device
379374
* type of a tensor; e.g., a CPU tensor is associated with the
@@ -385,18 +380,6 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
385380
return ::caffe2::get_static_context(device_type());
386381
}
387382

388-
/* @brief
389-
* Create a context that has the same device_type
390-
* as the tensor.
391-
* Note that this doesn't support passing in argument
392-
* TODO(jerryzh): move this to a global registry
393-
* that can create context for us, and then eliminate
394-
* this method.
395-
*/
396-
std::unique_ptr<at::BaseContext> CreateContext() const {
397-
return GetStaticContext()->CreateContext();
398-
}
399-
400383
/**
401384
* @brief Copies the data from a source tensor, with a contex provided to
402385
* carry out the underlying memcpy operation. This method respects
@@ -438,8 +421,12 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
438421
// knows how to copy between CPU and that context
439422
if (src.device_type() != ::at::DeviceType::CPU || device_type() == ::at::DeviceType::CPU) {
440423
if (!context) {
441-
src.CreateContext()->CopyBytesToDevice(
442-
numel() * itemsize(), src.data(), raw_mutable_data(data_type_), device_type());
424+
CreateContext(src.GetDevice())
425+
->CopyBytesToDevice(
426+
numel() * itemsize(),
427+
src.data(),
428+
raw_mutable_data(data_type_),
429+
device_type());
443430
} else {
444431
CAFFE_ENFORCE(
445432
context->device_type() == src.device_type(),
@@ -451,8 +438,11 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
451438
// In case source context is CPU, and target context is non-CPU
452439
// We'll have to create a Context from target and perform the
453440
// copy using that context
454-
CreateContext()->CopyBytesFromCPU(
455-
numel() * itemsize(), src.data(), raw_mutable_data(data_type_));
441+
CreateContext(GetDevice())
442+
->CopyBytesFromCPU(
443+
numel() * itemsize(),
444+
src.data(),
445+
raw_mutable_data(data_type_));
456446
}
457447
}
458448
}
@@ -874,14 +864,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
874864
}
875865

876866
inline void update_to_contiguous_strides() {
877-
strides_.resize(sizes_.size());
878-
if (dim() > 0) {
879-
int last_idx = dim() - 1;
880-
strides_[last_idx] = 1;
881-
for (auto i = last_idx - 1; i >= 0; --i) {
882-
strides_[i] = strides_[i + 1] * std::max<int64_t>(sizes_[i + 1], 1);
883-
}
884-
}
867+
strides_.resize(0);
885868
is_contiguous_ = true;
886869
}
887870

aten/src/ATen/core/WrapDimMinimal.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,10 @@ static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wr
2020
return dim;
2121
}
2222

23+
// Wrap around axis_index if it is negative, s.t., -1 is the last dim
24+
// This is the "Caffe2" name
25+
static inline int canonical_axis_index_(int axis_index, int ndims) {
26+
return maybe_wrap_dim(axis_index, ndims, false);
27+
}
28+
2329
}

aten/src/ATen/core/context_base.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
#include <ATen/core/context_base.h>
22

3+
namespace at {
4+
5+
C10_DEFINE_TYPED_REGISTRY(
6+
ContextRegistry,
7+
at::DeviceType,
8+
at::BaseContext,
9+
std::unique_ptr,
10+
at::Device);
11+
12+
} // namespace at
13+
314
namespace caffe2 {
415

516
// TODO: rename context.h -> context_cpu.h & context_base.h -> context.h

aten/src/ATen/core/context_base.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
#include <memory>
77
#include <unordered_map>
88

9-
#include <ATen/core/DeviceType.h>
9+
#include <ATen/core/ATenGeneral.h>
10+
#include <ATen/core/Device.h>
1011
#include <ATen/core/Error.h>
1112
#include <ATen/core/UniqueVoidPtr.h>
1213
#include <ATen/core/typeid.h>
13-
#include <ATen/core/ATenGeneral.h>
14+
#include <c10/util/Registry.h>
1415

1516
namespace caffe2 {
1617
class Event;
@@ -31,11 +32,6 @@ class CAFFE2_API BaseStaticContext {
3132

3233
virtual std::pair<void*, DeleterFnPtr> New(size_t nbytes) const = 0;
3334

34-
virtual std::unique_ptr<BaseContext> CreateContext() = 0;
35-
36-
virtual std::unique_ptr<BaseContext> CreateContext(
37-
const caffe2::DeviceOption&) = 0;
38-
3935
virtual DeviceType GetDeviceType() = 0;
4036

4137
/*
@@ -184,6 +180,22 @@ class CAFFE2_API BaseContext {
184180
}
185181
};
186182

183+
// Context constructor registry
184+
C10_DECLARE_TYPED_REGISTRY(
185+
ContextRegistry,
186+
at::DeviceType,
187+
at::BaseContext,
188+
std::unique_ptr,
189+
at::Device);
190+
191+
#define REGISTER_CONTEXT(type, ...) \
192+
C10_REGISTER_TYPED_CLASS(ContextRegistry, type, __VA_ARGS__)
193+
194+
inline std::unique_ptr<at::BaseContext> CreateContext(
195+
const at::Device& device) {
196+
return at::ContextRegistry()->Create(device.type(), device);
197+
}
198+
187199
} // namespace at
188200

189201
namespace caffe2 {

aten/src/ATen/function_wrapper.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,10 @@ def TypedDict(name, attrs, total=True): # type: ignore
107107
# NB: As far as ezyang can tell, we don't *have* to codegen this,
108108
# because we will inherit it from the TYPE_METHOD_DEFINITION_CONCRETE in
109109
# the superclass. But it doesn't seem to be harmful.
110-
#
111-
# TODO: self_ty is a hack to make things work for native methods which need to
112-
# take a dtype, but also need to dispatch differently for different types.
113-
# Eliminate it at some point.
114110
TYPE_DERIVED_DEFINITION_NATIVE = CodeTemplate("""\
115111
${return_type} ${Type}::${api_name}(${type_method_formals}) const {
116112
${device_guard_declaration}
117-
const auto& self_ty = *this;
118-
(void)self_ty;
119-
${return_call} at::native::${native_type_method_dispatch}(/* actuals */ ${actuals});
113+
${return_call} at::native::${native_type_method_dispatch}(/* actuals */ ${type_derived_call_actuals});
120114
}
121115
""")
122116
TYPE_DERIVED_DEFINITION_NATIVE_MISSING = CodeTemplate("""\
@@ -1574,8 +1568,15 @@ def process_native(option):
15741568
TYPE_DERIVED_DEFINITION_NATIVE_MISSING.substitute(env))
15751569
else:
15761570
option['native_type_method_dispatch'] = native_dispatch
1571+
type_derived_call_actuals = []
1572+
for actual, arg in zip(option['actuals'], option['arguments']):
1573+
if arg.get('is_type_dispatched', False):
1574+
type_derived_call_actuals.append('*this')
1575+
else:
1576+
type_derived_call_actuals.append(actual)
15771577
type_object_definitions.append(
1578-
TYPE_DERIVED_DEFINITION_NATIVE.substitute(env))
1578+
TYPE_DERIVED_DEFINITION_NATIVE.substitute(
1579+
env, type_derived_call_actuals=type_derived_call_actuals))
15791580

15801581
for declaration in declarations:
15811582
for option in declaration['options']:

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,8 +2068,8 @@
20682068
SparseCPU: hspmm_sparse_cpu
20692069
SparseCUDA: hspmm_sparse_cuda
20702070

2071-
# This "raw copy" doesn't handle conversions NOR does it handle non-blocking.
2072-
- func: raw_copy_sparse_(Tensor self, Tensor src) -> Tensor
2071+
- func: copy_sparse_to_sparse_(Tensor self, Tensor src, bool non_blocking=false) -> Tensor
2072+
variants: function
20732073
dispatch:
20742074
SparseCPU: copy_sparse_
20752075
SparseCUDA: copy_sparse_

aten/src/ATen/native/sparse/SparseTensor.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ SparseTensor new_with_tensor_and_size_sparse(const LongTensor& indices, const Te
204204

205205
SparseTensor clone_sparse(const SparseTensor& self) {
206206
SparseTensor other = new_with_dims_and_size_sparse(self.type(), self._sparseDims(), self._denseDims(), self.sizes());
207-
_copy_into_sparse(other, _get_sparse_impl(self)->indices(), _get_sparse_impl(self)->values());
207+
_copy_into_sparse(other, _get_sparse_impl(self)->indices(), _get_sparse_impl(self)->values(), true);
208208
_get_sparse_impl(other)->set_coalesced(self.is_coalesced());
209209
return other;
210210
}
@@ -243,11 +243,11 @@ Tensor sparse_to_dense(const SparseTensor& self) {
243243
return dst.add_(self);
244244
}
245245

246-
SparseTensor& copy_sparse_(SparseTensor& self, const SparseTensor& src) {
246+
SparseTensor& copy_sparse_(SparseTensor& self, const SparseTensor& src, bool non_blocking) {
247247
if (isSameTensor(self, src)) return self;
248248
_get_sparse_impl(self)->resize_(src._sparseDims(), src._denseDims(), src.sizes());
249249
// NB: This seems to copy the underlying full indices/values buffer
250-
_copy_into_sparse(self, _get_sparse_impl(src)->indices(), _get_sparse_impl(src)->values());
250+
_copy_into_sparse(self, _get_sparse_impl(src)->indices(), _get_sparse_impl(src)->values(), non_blocking);
251251
_get_sparse_impl(self)->set_coalesced(src.is_coalesced());
252252
return self;
253253
}

aten/src/ATen/native/sparse/SparseTensorMath.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ SparseTensor& log1p_out_sparse(SparseTensor& r, const SparseTensor& t) {
9898
r.is_coalesced(), "log1p: in-place on uncoalesced tensors is not supported yet!");
9999
}
100100
else {
101-
r = raw_copy_sparse_(r, t.coalesce());
101+
copy_sparse_to_sparse_(r, t.coalesce());
102102
}
103103
r._values().log1p_();
104104
return r;
@@ -192,7 +192,7 @@ SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const S
192192
AT_CHECK(t.sizes().equals(src.sizes()), "add: expected sizes of 'self' and 'other' to match, but ", t.sizes(), " != ", src.sizes());
193193

194194
if (src._nnz() == 0) {
195-
return raw_copy_sparse_(r, t);
195+
return copy_sparse_to_sparse_(r, t);
196196
}
197197
if (t._nnz() == 0) {
198198
return mul_out_sparse_scalar(r, src, value);

aten/src/ATen/native/sparse/SparseUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ inline void _alias_into_sparse(const SparseTensor& self, const LongTensor& indic
5050

5151
// Take indices and values and makes a (data) copy of them to put into the sparse
5252
// indices/values. This used to be called THSTensor_(_set)
53-
inline void _copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values) {
54-
_alias_into_sparse(self, indices.clone(), values.clone());
53+
inline void _copy_into_sparse(const SparseTensor& self, const LongTensor& indices, const Tensor& values, bool non_blocking) {
54+
_alias_into_sparse(self, self._indices().type().copy(indices, non_blocking), self._values().type().copy(values, non_blocking));
5555
}
5656

5757
// Does NOT make copies of indices/values

aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
348348
AT_CHECK(t.sizes().equals(src.sizes()), "add: expected 'self' and 'other' to have same size, but ", t.sizes(), " != ", src.sizes());
349349

350350
if (src._nnz() == 0) {
351-
return raw_copy_sparse_(r_, t);
351+
return copy_sparse_to_sparse_(r_, t);
352352
}
353353
if (t._nnz() == 0) {
354354
return mul_out_sparse_scalar(r_, src, value);

aten/src/ATen/templates/TypeDefault.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ namespace at {
1818

1919
Tensor & TypeDefault::copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
2020
Tensor b_src;
21-
std::tie(b_src) = expand_inplace(self, src, "copy");
21+
if (is_sparse()) b_src = src;
22+
else std::tie(b_src) = expand_inplace(self, src, "copy");
2223
return s_copy_(self, b_src, non_blocking);
2324
}
2425

@@ -28,19 +29,11 @@ Tensor TypeDefault::copy(const Tensor & src, bool non_blocking, optional<Device>
2829
device_guard.set_index(to_device.value().index());
2930
}
3031
AT_CHECK(src.defined(), "attempt to copy an undefined tensor");
31-
if (is_sparse()) {
32-
auto indices = src._indices();
33-
auto values = src._values();
34-
auto & this_dense = toBackend(is_cuda() ? Backend::CUDA : Backend::CPU);
35-
auto & this_dense_idx = this_dense.toScalarType(ScalarType::Long);
36-
auto indices_copy = this_dense_idx.copy(indices, non_blocking);
37-
auto values_copy = this_dense.copy(values, non_blocking);
38-
return _sparse_coo_tensor_unsafe(indices_copy, values_copy, src.sizes());
39-
} else {
40-
Tensor r = this->tensor(src.sizes());
41-
r.copy_(src, non_blocking);
42-
return r;
43-
}
32+
Tensor r;
33+
if (is_sparse()) r = this->native_tensor();
34+
else r = this->tensor(src.sizes());
35+
r.copy_(src, non_blocking);
36+
return r;
4437
}
4538

4639
void TypeDefault::backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const {

0 commit comments

Comments
 (0)