Skip to content

Commit 3b3d928

Browse files
committed
CUDA Kernels: Use per-operator headers (2/4)
Splitting this into multiple PRs to keep the diffs more managable. ghstack-source-id: a36d52a Pull Request resolved: pytorch#71213
1 parent 7c2103a commit 3b3d928

22 files changed

+236
-39
lines changed

aten/src/ATen/native/cuda/Dropout.cu

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/AccumulateType.h>
34
#include <ATen/cuda/CUDAGeneratorImpl.h>
5+
#include <ATen/Dispatch.h>
6+
#include <ATen/Utils.h>
47
#include <ATen/cuda/detail/IndexUtils.cuh>
58
#include <ATen/cuda/detail/TensorInfo.cuh>
69
#include <ATen/cuda/CUDAGraphsUtils.cuh>
@@ -11,6 +14,17 @@
1114
#include <ATen/native/cuda/Loops.cuh>
1215
#include <ATen/native/cuda/MemoryAccess.cuh>
1316

17+
#ifndef AT_PER_OPERATOR_HEADERS
18+
#include <ATen/Functions.h>
19+
#include <ATen/NativeFunctions.h>
20+
#else
21+
#include <ATen/ops/_masked_scale_native.h>
22+
#include <ATen/ops/empty_like.h>
23+
#include <ATen/ops/native_dropout_backward_native.h>
24+
#include <ATen/ops/ones_like.h>
25+
#include <ATen/ops/zeros_like.h>
26+
#endif
27+
1428
namespace at{
1529
namespace native{
1630

aten/src/ATen/native/cuda/Embedding.cu

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/AccumulateType.h>
4+
#include <ATen/Dispatch.h>
35
#include <ATen/TensorUtils.h>
46
#include <ATen/ceil_div.h>
57
#include <ATen/cuda/CUDAContext.h>
@@ -17,6 +19,18 @@
1719
#include <thrust/iterator/reverse_iterator.h>
1820
#endif
1921

22+
#ifndef AT_PER_OPERATOR_HEADERS
23+
#include <ATen/Functions.h>
24+
#include <ATen/NativeFunctions.h>
25+
#else
26+
#include <ATen/ops/arange.h>
27+
#include <ATen/ops/embedding_dense_backward_native.h>
28+
#include <ATen/ops/embedding_renorm_native.h>
29+
#include <ATen/ops/empty.h>
30+
#include <ATen/ops/empty_like.h>
31+
#include <ATen/ops/zeros.h>
32+
#endif
33+
2034
namespace at { namespace native {
2135

2236
namespace {

aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
23
#include <ATen/cuda/Atomic.cuh>
34
#include <ATen/cuda/CUDAContext.h>
45
#include <ATen/cuda/cub.cuh>
6+
#include <ATen/AccumulateType.h>
7+
#include <ATen/Dispatch.h>
58
#include <ATen/TensorUtils.h>
6-
#include <ATen/NativeFunctions.h>
79
#include <ATen/native/cuda/SortingCommon.cuh>
810

9-
#include <ATen/AccumulateType.h>
10-
1111
#include <c10/macros/Macros.h>
1212

1313
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
1414
#include <thrust/iterator/counting_iterator.h>
1515
#endif
1616

17+
#ifndef AT_PER_OPERATOR_HEADERS
18+
#include <ATen/Functions.h>
19+
#else
20+
#include <ATen/ops/empty.h>
21+
#include <ATen/ops/zeros.h>
22+
#endif
23+
1724
namespace at {
1825
namespace native {
1926

aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cuh

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
#include <ATen/ATen.h>
1+
#pragma once
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/cuda/Atomic.cuh>
34
#include <ATen/cuda/CUDAContext.h>
45
#include <ATen/TensorUtils.h>
5-
#include <ATen/NativeFunctions.h>
6-
7-
#pragma once
86

97
namespace at {
108
namespace native {

aten/src/ATen/native/cuda/EmbeddingBag.cu

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/AccumulateType.h>
24
#include <ATen/ceil_div.h>
5+
#include <ATen/Dispatch.h>
36
#include <ATen/cuda/Atomic.cuh>
47
#include <ATen/cuda/CUDAContext.h>
58
#include <ATen/cuda/DeviceUtils.cuh>
69
#include <ATen/TensorUtils.h>
7-
#include <ATen/NativeFunctions.h>
810

9-
#include <ATen/AccumulateType.h>
11+
#ifndef AT_PER_OPERATOR_HEADERS
12+
#include <ATen/Functions.h>
13+
#include <ATen/NativeFunctions.h>
14+
#else
15+
#include <ATen/ops/arange.h>
16+
#include <ATen/ops/empty.h>
17+
#include <ATen/ops/empty_like.h>
18+
#include <ATen/ops/zeros.h>
19+
#include <ATen/ops/_embedding_bag_native.h>
20+
#include <ATen/ops/_embedding_bag_forward_only_native.h>
21+
#include <ATen/ops/_embedding_bag_dense_backward_native.h>
22+
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
23+
#endif
1024

1125
#include <ATen/cuda/cub.cuh>
1226
#include <ATen/native/cuda/SortingCommon.cuh>

aten/src/ATen/native/cuda/Equal.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
3+
#include <ATen/NamedTensorUtils.h>
4+
5+
#ifndef AT_PER_OPERATOR_HEADERS
16
#include <ATen/NativeFunctions.h>
27
#include <ATen/CUDAFunctions.h>
3-
#include <ATen/NamedTensorUtils.h>
8+
#else
9+
#include <ATen/ops/eq_cuda_dispatch.h>
10+
#include <ATen/ops/equal_native.h>
11+
#endif
412

513
namespace at { namespace native {
614

aten/src/ATen/native/cuda/FractionalMaxPool2d.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/AccumulateType.h>
4+
#include <ATen/Dispatch.h>
35
#include <ATen/cuda/Atomic.cuh>
46
#include <ATen/cuda/CUDAContext.h>
57
#include <ATen/cuda/NumericLimits.cuh>
68
#include <ATen/cuda/detail/IndexUtils.cuh>
79
#include <ATen/cuda/detail/KernelUtils.h>
8-
#include <ATen/NativeFunctions.h>
910
#include <ATen/NumericUtils.h>
1011
#include <ATen/TensorUtils.h>
1112
#include <ATen/Utils.h>
1213
#include <c10/util/Exception.h>
1314

15+
#ifndef AT_PER_OPERATOR_HEADERS
16+
#include <ATen/NativeFunctions.h>
17+
#else
18+
#include <ATen/ops/fractional_max_pool2d_backward_native.h>
19+
#include <ATen/ops/fractional_max_pool2d_native.h>
20+
#endif
21+
1422
#include <algorithm>
1523
#include <cfloat>
1624
#include <cmath>

aten/src/ATen/native/cuda/FractionalMaxPool3d.cu

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/AccumulateType.h>
4+
#include <ATen/Dispatch.h>
35
#include <ATen/cuda/Atomic.cuh>
46
#include <ATen/cuda/CUDAContext.h>
57
#include <ATen/cuda/NumericLimits.cuh>
68
#include <ATen/cuda/detail/IndexUtils.cuh>
79
#include <ATen/cuda/detail/TensorInfo.cuh>
810
#include <ATen/cuda/detail/KernelUtils.h>
9-
#include <ATen/NativeFunctions.h>
1011
#include <ATen/NumericUtils.h>
1112
#include <ATen/TensorUtils.h>
1213
#include <ATen/Utils.h>
1314
#include <c10/util/Exception.h>
1415

16+
#ifndef AT_PER_OPERATOR_HEADERS
17+
#include <ATen/Functions.h>
18+
#include <ATen/NativeFunctions.h>
19+
#else
20+
#include <ATen/ops/empty.h>
21+
#include <ATen/ops/fractional_max_pool3d_backward_native.h>
22+
#include <ATen/ops/fractional_max_pool3d_native.h>
23+
#endif
24+
1525
#include <algorithm>
1626
#include <cfloat>
1727
#include <cmath>

aten/src/ATen/native/cuda/FunctionOfAMatrixUtilsKernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#define TORCH_ASSERT_NO_OPERATORS
12
#include <ATen/native/FunctionOfAMatrixUtils.h>
23

34
#include <ATen/Dispatch.h>

aten/src/ATen/native/cuda/GridSampler.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#pragma once
12
#include <ATen/native/cuda/KernelUtils.cuh>
23

34
namespace at { namespace native {

aten/src/ATen/native/cuda/Im2Col.cu

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/AccumulateType.h>
3-
#include <ATen/NativeFunctions.h>
4+
#include <ATen/Dispatch.h>
45
#include <ATen/TensorUtils.h>
56
#include <ATen/Utils.h>
67
#include <ATen/div_rtn.h>
@@ -10,6 +11,16 @@
1011
#include <ATen/native/cuda/im2col.cuh>
1112
#include <ATen/native/im2col_shape_check.h>
1213

14+
#ifndef AT_PER_OPERATOR_HEADERS
15+
#include <ATen/Functions.h>
16+
#include <ATen/NativeFunctions.h>
17+
#else
18+
#include <ATen/ops/empty_like.h>
19+
#include <ATen/ops/col2im_native.h>
20+
#include <ATen/ops/im2col_native.h>
21+
#include <ATen/ops/im2col_backward_native.h>
22+
#endif
23+
1324
namespace at {
1425
namespace native {
1526
namespace {

aten/src/ATen/native/cuda/IndexKernel.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
12
#include <ATen/native/cuda/IndexKernel.h>
23
#include <ATen/native/TensorAdvancedIndexing.h> // For at::native::index_out
4+
#include <ATen/core/Tensor.h>
5+
#include <ATen/core/List.h>
36
#include <ATen/ExpandUtils.h>
4-
#include <ATen/Functions.h>
57
#include <ATen/MemoryOverlap.h>
68
#include <ATen/NamedTensorUtils.h>
9+
10+
#ifndef AT_PER_OPERATOR_HEADERS
11+
#include <ATen/Functions.h>
712
#include <ATen/NativeFunctions.h>
13+
#else
14+
#include <ATen/ops/empty.h>
15+
#include <ATen/ops/masked_scatter_native.h>
16+
#include <ATen/ops/masked_select_native.h>
17+
#endif
18+
819

920
namespace at {
1021
namespace native {

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
12
#include <ATen/native/TensorAdvancedIndexing.h>
23
#include <ATen/native/IndexingUtils.h>
34

4-
#include <ATen/ATen.h>
5+
#include <ATen/core/Tensor.h>
56
#include <ATen/ceil_div.h>
6-
#include <ATen/NativeFunctions.h>
7+
#include <ATen/Dispatch.h>
78
#include <ATen/ExpandUtils.h>
89
#include <ATen/MemoryOverlap.h>
10+
#include <ATen/TensorOperators.h>
911
#include <ATen/native/TensorIterator.h>
1012
#include <ATen/native/cuda/Loops.cuh>
1113
#include <ATen/native/Resize.h>
@@ -14,6 +16,18 @@
1416
#include <ATen/cuda/Atomic.cuh>
1517
#include <ATen/cuda/CUDAUtils.h>
1618

19+
#ifndef AT_PER_OPERATOR_HEADERS
20+
#include <ATen/Functions.h>
21+
#include <ATen/NativeFunctions.h>
22+
#else
23+
#include <ATen/ops/arange.h>
24+
#include <ATen/ops/empty.h>
25+
#include <ATen/ops/empty_quantized.h>
26+
#include <ATen/ops/index_add_native.h>
27+
#include <ATen/ops/index_select_native.h>
28+
#include <ATen/ops/masked_fill_native.h>
29+
#endif
30+
1731
#include <ATen/cuda/CUDAContext.h>
1832
#include <ATen/cuda/cub.h>
1933
#include <c10/util/irange.h>

aten/src/ATen/native/cuda/LegacyThrustHelpers.cu

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/native/cuda/SortingCommon.cuh>
34
#include <ATen/cuda/cub_definitions.cuh>
45

6+
#ifndef AT_PER_OPERATOR_HEADERS
7+
#include <ATen/Functions.h>
8+
#else
9+
#include <ATen/ops/empty_like.h>
10+
#endif
11+
512
#include <ATen/cuda/ThrustAllocator.h>
613
#include <thrust/device_ptr.h>
714
#include <thrust/execution_policy.h>

aten/src/ATen/native/cuda/Loss.cu

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,28 @@
1-
#include <ATen/ATen.h>
1+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2+
#include <ATen/core/Tensor.h>
23
#include <ATen/AccumulateType.h>
3-
#include <ATen/NativeFunctions.h>
44
#include <ATen/Dispatch.h>
55
#include <ATen/cuda/detail/KernelUtils.h>
66
#include <ATen/native/TensorIterator.h>
7-
#include <aten/src/ATen/TensorUtils.h>
7+
#include <ATen/TensorUtils.h>
8+
#include <ATen/TensorOperators.h>
89
#include <ATen/cuda/detail/KernelUtils.h>
910
#include <ATen/native/cuda/Loops.cuh>
1011
#include <ATen/native/Resize.h>
1112

13+
#ifndef AT_PER_OPERATOR_HEADERS
14+
#include <ATen/Functions.h>
15+
#include <ATen/NativeFunctions.h>
16+
#else
17+
#include <ATen/ops/binary_cross_entropy_backward_native.h>
18+
#include <ATen/ops/binary_cross_entropy_native.h>
19+
#include <ATen/ops/empty_like.h>
20+
#include <ATen/ops/exp.h>
21+
#include <ATen/ops/nll_loss_backward_native.h>
22+
#include <ATen/ops/nll_loss_forward_native.h>
23+
#include <ATen/ops/squeeze.h>
24+
#endif
25+
1226
constexpr float EPSILON = 1e-12;
1327

1428
namespace {

aten/src/ATen/native/cuda/LossCTC.cu

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,32 @@
77
// Graves et al call the probabilities y, we use log_probs (also calling them inputs)
88
// A few optimizations (similar to those here, but also some I didn't take) are described in
99
// 2. Minmin Sun: http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf
10-
10+
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
1111
#include <ATen/TensorUtils.h>
1212
#include <c10/util/Exception.h>
1313
#include <c10/macros/Macros.h>
14-
#include <ATen/ATen.h>
14+
#include <ATen/core/Tensor.h>
1515
#include <ATen/Dispatch.h>
16+
#include <ATen/TensorOperators.h>
1617
#include <ATen/cuda/Atomic.cuh>
1718
#include <ATen/cuda/CUDAContext.h>
1819

20+
#ifndef AT_PER_OPERATOR_HEADERS
21+
#include <ATen/Functions.h>
22+
#include <ATen/NativeFunctions.h>
23+
#else
24+
#include <ATen/ops/_ctc_loss_backward_native.h>
25+
#include <ATen/ops/_ctc_loss_native.h>
26+
#include <ATen/ops/empty.h>
27+
#include <ATen/ops/exp.h>
28+
#include <ATen/ops/full_like.h>
29+
#include <ATen/ops/imag.h>
30+
#include <ATen/ops/logsumexp.h>
31+
#include <ATen/ops/tensor.h>
32+
#include <ATen/ops/where.h>
33+
#include <ATen/ops/zeros.h>
34+
#endif
35+
1936
#include <type_traits>
2037
#include <numeric>
2138

0 commit comments

Comments
 (0)