
Commit 8476972

Merge remote-tracking branch 'upstream/master' into ifu

2 parents: cf78d09 + 6d6655e


47 files changed: +746 −460 lines

aten/src/ATen/CPUTypeDefault.cpp

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+#include <ATen/CPUTypeDefault.h>
+
+#include <ATen/Context.h>
+#include <ATen/CPUGenerator.h>
+
+namespace at {
+
+Allocator* CPUTypeDefault::allocator() const {
+  return getCPUAllocator();
+}
+
+Device CPUTypeDefault::getDeviceFromPtr(void * data) const {
+  return DeviceType::CPU;
+}
+
+std::unique_ptr<Generator> CPUTypeDefault::generator() const {
+  return std::unique_ptr<Generator>(new CPUGenerator(&at::globalContext()));
+}
+
+} // namespace at
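For context on the new hooks, a minimal usage sketch (not part of the diff): it assumes the existing at::CPU accessor from Context.h and the raw_allocate/raw_deallocate helpers on at::Allocator.

#include <ATen/ATen.h>

int main() {
  at::Type& cpu = at::CPU(at::kFloat);
  // Routed through CPUTypeDefault::allocator() -> getCPUAllocator().
  at::Allocator* alloc = cpu.allocator();
  void* buf = alloc->raw_allocate(16 * sizeof(float));
  // CPUTypeDefault::getDeviceFromPtr() ignores the pointer: CPU memory is CPU.
  at::Device dev = cpu.getDeviceFromPtr(buf);
  AT_ASSERT(dev.type() == at::DeviceType::CPU);
  alloc->raw_deallocate(buf);
  return 0;
}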

aten/src/ATen/CPUTypeDefault.h

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/TypeDefault.h>
+
+namespace at {
+
+struct AT_API CPUTypeDefault : public TypeDefault {
+  CPUTypeDefault(TensorTypeId type_id, bool is_variable, bool is_undefined)
+      : TypeDefault(type_id, is_variable, is_undefined) {}
+  Allocator* allocator() const override;
+  Device getDeviceFromPtr(void * data) const override;
+  std::unique_ptr<Generator> generator() const override;
+};
+
+} // namespace at

aten/src/ATen/Context.cpp

Lines changed: 4 additions & 0 deletions

@@ -118,4 +118,8 @@ Type& getMaybeVariableType(const TensorImpl* impl) {
       backend, impl->scalar_type(), impl->is_variable());
 }
 
+Allocator* getCPUAllocator() {
+  return getTHDefaultAllocator();
+}
+
 }

aten/src/ATen/Context.h

Lines changed: 2 additions & 0 deletions

@@ -158,6 +158,8 @@ static inline Type& getNonVariableType(DeviceType p, ScalarType s) {
 AT_API Type& getMaybeVariableType(TensorOptions options);
 AT_API Type& getMaybeVariableType(const TensorImpl*);
 
+AT_API Allocator* getCPUAllocator();
+
 static inline Type& CPU(ScalarType s) {
   return getNonVariableType(Backend::CPU, s);
 }

aten/src/ATen/UndefinedType.cpp

Lines changed: 10 additions & 8 deletions

@@ -11,9 +11,14 @@ ScalarType UndefinedType::scalarType() const {
 Backend UndefinedType::backend() const {
   return Backend::Undefined;
 }
-bool UndefinedType::is_cuda() const { return false; }
-bool UndefinedType::is_sparse() const { return false; }
-bool UndefinedType::is_distributed() const { return false; }
+
+Allocator* UndefinedType::allocator() const {
+  AT_ERROR("allocator not defined for UndefinedType");
+}
+
+Device UndefinedType::getDeviceFromPtr(void*) const {
+  AT_ERROR("getDeviceFromPtr not defined for UndefinedType");
+}
 
 Storage UndefinedType::storage(bool resizable) const {
   AT_ERROR("storage not defined for UndefinedType");
@@ -38,8 +43,9 @@ std::unique_ptr<Generator> UndefinedType::generator() const {
 }
 
 const char * UndefinedType::toString() const {
-  return UndefinedType::typeString();
+  return "UndefinedType";
 }
+
 TypeID UndefinedType::ID() const {
   return TypeID::Undefined;
 }
@@ -61,10 +67,6 @@ Type & UndefinedType::toScalarType(ScalarType s) const {
   AT_ERROR("toScalarType not implemented for UndefinedType to non-UndefinedType");
 }
 
-const char * UndefinedType::typeString() {
-  return "UndefinedType";
-}
-
 Tensor & UndefinedType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
   AT_ERROR("s_copy not defined for UndefinedType");
 }

aten/src/ATen/UndefinedType.h

Lines changed: 2 additions & 4 deletions

@@ -15,9 +15,8 @@ struct UndefinedType final : public TypeDefault {
   explicit UndefinedType();
   virtual ScalarType scalarType() const override;
   virtual Backend backend() const override;
-  virtual bool is_cuda() const override;
-  virtual bool is_sparse() const override;
-  virtual bool is_distributed() const override;
+  virtual Allocator* allocator() const override;
+  virtual Device getDeviceFromPtr(void* data) const override;
   virtual Storage storage(bool resizable = false) const override;
   virtual Storage storage(size_t size, bool resizable = false) const override;
   virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
@@ -28,7 +27,6 @@ struct UndefinedType final : public TypeDefault {
   virtual Type & toBackend(Backend b) const override;
   virtual Type & toScalarType(ScalarType s) const override;
   virtual TypeID ID() const override;
-  static const char * typeString();
   virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;

aten/src/ATen/cuda/CUDAContext.cpp

Lines changed: 6 additions & 2 deletions

@@ -1,5 +1,5 @@
 #include "ATen/cuda/CUDAContext.h"
-#include "THC/THCGeneral.h"
+#include "THC/THCGeneral.hpp"
 
 namespace at { namespace cuda {
 
@@ -45,6 +45,10 @@ void uncheckedSetCurrentCUDAStream(CUDAStream stream) {
   detail::CUDAStream_uncheckedSetStream(stream.internals());
 }
 
+Allocator* getCUDADeviceAllocator() {
+  return at::globalContext().getTHCState()->cudaDeviceAllocator;
+}
+
 /* Handles */
 #ifndef __HIP_PLATFORM_HCC__
 cusparseHandle_t getCurrentCUDASparseHandle() {
@@ -54,4 +58,4 @@ void uncheckedSetCurrentCUDAStream(CUDAStream stream) {
 
 } // namespace cuda
 
-} // namespace at
+} // namespace at

aten/src/ATen/cuda/CUDAContext.h

Lines changed: 2 additions & 0 deletions

@@ -54,6 +54,8 @@ AT_API CUDAStream getCurrentCUDAStream(int64_t device = -1);
 AT_API void setCurrentCUDAStream(CUDAStream stream);
 AT_API void uncheckedSetCurrentCUDAStream(CUDAStream stream);
 
+AT_API Allocator* getCUDADeviceAllocator();
+
 /* Handles */
 #ifndef __HIP_PLATFORM_HCC__
 AT_API cusparseHandle_t getCurrentCUDASparseHandle();

aten/src/ATen/cuda/CUDADevice.h

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#pragma once
+
+#include "ATen/cuda/Exceptions.h"
+
+#include "cuda.h"
+
+namespace at {
+namespace cuda {
+
+inline Device getDeviceFromPtr(void* ptr) {
+  struct cudaPointerAttributes attr;
+  AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
+  return {DeviceType::CUDA, attr.device};
+}
+
+}} // namespace at::cuda
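A small sketch of the new helper in use (not part of the diff; assumes a CUDA build and the standard CUDA runtime calls cudaMalloc/cudaFree):

#include <ATen/ATen.h>
#include <ATen/cuda/CUDADevice.h>
#include <cuda_runtime.h>

int main() {
  void* p = nullptr;
  AT_CUDA_CHECK(cudaMalloc(&p, 256));  // allocates on the current device
  // cudaPointerGetAttributes recovers the owning device from the pointer.
  at::Device d = at::cuda::getDeviceFromPtr(p);
  AT_ASSERT(d.type() == at::DeviceType::CUDA);
  AT_CUDA_CHECK(cudaFree(p));
  return 0;
}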
aten/src/ATen/cuda/CUDATypeDefault.cpp

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+#include <ATen/cuda/CUDATypeDefault.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDADevice.h>
+#include <ATen/CUDAGenerator.h>
+
+namespace at {
+
+Allocator* CUDATypeDefault::allocator() const {
+  return cuda::getCUDADeviceAllocator();
+}
+Device CUDATypeDefault::getDeviceFromPtr(void * data) const {
+  return cuda::getDeviceFromPtr(data);
+}
+std::unique_ptr<Generator> CUDATypeDefault::generator() const {
+  return std::unique_ptr<Generator>(new CUDAGenerator(&at::globalContext()));
+}
+
+} // namespace at

aten/src/ATen/cuda/CUDATypeDefault.h

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#pragma once
+#include <ATen/TypeDefault.h>
+#include <ATen/cuda/ATenCUDAGeneral.h>
+
+namespace at {
+
+struct AT_CUDA_API CUDATypeDefault : public TypeDefault {
+  CUDATypeDefault(TensorTypeId type_id, bool is_variable, bool is_undefined)
+      : TypeDefault(type_id, is_variable, is_undefined) {}
+
+  Allocator* allocator() const override;
+  Device getDeviceFromPtr(void * data) const override;
+  std::unique_ptr<Generator> generator() const override;
+};
+
+} // namespace at

aten/src/ATen/cudnn/Descriptors.h

Lines changed: 1 addition & 1 deletion

@@ -257,7 +257,7 @@ struct AT_CUDA_API DropoutDescriptor
     AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
     AT_ASSERT(type.is_cuda());
     AT_ASSERT(type.scalarType() == kByte);
-    state = at::empty({static_cast<int64_t>(state_size)}, type);
+    state = at::empty({static_cast<int64_t>(state_size)}, type.options());
     AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
   }
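The one-line change above migrates this call site from the Type& overload of at::empty to the TensorOptions overload; a minimal sketch of the new pattern, assuming only what the diff shows (type.options() yielding options accepted by at::empty):

// Hypothetical helper mirroring the call site above.
at::Tensor make_dropout_state(const at::Type& type, size_t state_size) {
  // Build the state tensor through TensorOptions derived from the Type.
  return at::empty({static_cast<int64_t>(state_size)}, type.options());
}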

aten/src/ATen/gen.py

Lines changed: 2 additions & 0 deletions

@@ -256,6 +256,8 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
         ]
         env['extra_cuda_headers'] = ['#include <ATen/cuda/CUDAHalf.cuh>']
         env['extra_cuda_headers'].append('#include <ATen/DeviceGuard.h>')
+        env['extra_cuda_headers'].append('#include <ATen/cuda/CUDADevice.h>')
+        env['extra_cuda_headers'].append('#include <ATen/cuda/CUDATypeDefault.h>')
         sname = '' if scalar_name == "Float" else scalar_name
         env['THType'] = 'Cuda{}'.format(sname)
         env['THStorage'] = 'THCuda{}Storage'.format(sname)

aten/src/ATen/native/Linear.cpp

Lines changed: 58 additions & 0 deletions

@@ -457,4 +457,62 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
   return output;
 }
 
+// implements tensordot, a matrix-multiplication-like contraction over the
+// dimensions given in the two dimension lists
+Tensor tensordot(const Tensor& input1, const Tensor& input2, IntList dims1, IntList dims2) {
+  AT_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
+  int64_t csize = 1;  // total size of the contracted dimensions
+  Tensor t1 = input1;
+  Tensor t2 = input2;
+  for (size_t i = 0; i < dims1.size(); i++) {
+    int s1 = input1.size(dims1[i]);
+    int s2 = input2.size(dims2[i]);
+    if (s2 == 1) {  // broadcasted dimensions can be summed right away
+      t1 = t1.sum(dims1[i], true);
+    } else if (s1 == 1) {
+      t2 = t2.sum(dims2[i], true);
+    } else {
+      AT_CHECK(s1 == s2, "contracted dimensions need to match, but first has size ", s1, " in dim ", dims1[i],
+               " and second has size ", s2, " in dim ", dims2[i]);
+      csize *= s1;
+    }
+  }
+
+  auto cdims1 = dim_list_to_bitset(dims1, input1.dim());
+  auto cdims2 = dim_list_to_bitset(dims2, input2.dim());
+  std::vector<int64_t> p1, p2, rsizes;  // p1, p2: input permutations, rsizes: sizes of the result
+  p1.reserve(input1.dim());
+  p2.reserve(input2.dim());
+  rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
+  int64_t size1 = 1;  // number of non-contracted elements in input1
+  int64_t size2 = 1;  // number of non-contracted elements in input2
+
+  // fill the permutations and compute sizes
+  for (int64_t i = 0; i < input1.dim(); i++) {
+    if (!cdims1[i]) {
+      p1.emplace_back(i);
+      size1 *= t1.size(i);
+      rsizes.emplace_back(t1.size(i));
+    }
+  }
+  for (size_t i = 0; i < dims1.size(); i++) {
+    p1.emplace_back(dims1[i]);
+  }
+  for (size_t i = 0; i < dims2.size(); i++) {
+    p2.emplace_back(dims2[i]);
+  }
+  for (int64_t i = 0; i < input2.dim(); i++) {
+    if (!cdims2[i]) {
+      p2.emplace_back(i);
+      size2 *= t2.size(i);
+      rsizes.emplace_back(t2.size(i));
+    }
+  }
+  // permute and reshape for matrix multiplication
+  t1 = t1.permute(p1).reshape({size1, csize});
+  t2 = t2.permute(p2).reshape({csize, size2});
+  // multiply and reshape to target size
+  return at::mm(t1, t2).reshape(rsizes);
+}
+
 }} // namespace at::native
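To make the contraction concrete, here is a usage sketch; exposing the function as at::tensordot is an assumption (it would come through the generated bindings, not this file), and the at::mm reference follows the permute/reshape scheme in the code:

#include <ATen/ATen.h>

int main() {
  // a: 3x4x5, b: 4x3x2; contract a's dims {1, 0} against b's dims {0, 1}.
  at::Tensor a = at::randn({3, 4, 5});
  at::Tensor b = at::randn({4, 3, 2});
  at::Tensor c = at::tensordot(a, b, {1, 0}, {0, 1});  // result shape: 5x2

  // Reference: move the non-contracted dim of a to the front, flatten both
  // operands to matrices of shape {5, 12} and {12, 2}, then one at::mm.
  at::Tensor ref = at::mm(a.permute({2, 1, 0}).reshape({5, 12}),
                          b.reshape({12, 2}));
  AT_ASSERT(at::allclose(c, ref));
  return 0;
}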
