Commit 7106d21

swolchok authored and facebook-github-bot committed
[PyTorch] Add native fast path for transformer encoder inference (pytorch#76333)
Summary:
Pull Request resolved: pytorch#76333

The current PyTorch multi-head attention and transformer implementations are slow. This should speed them up for inference.

ghstack-source-id: 154737857

(Note: this ignores all push blocking failures!)

Test Plan: CI

Reviewed By: cpuhrsch

Differential Revision: D35239925

fbshipit-source-id: 5a7eb8ff79bc6afb4b7d45075ddb2a24a6e2df28
1 parent b941d10 commit 7106d21
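
For orientation, below is a rough Python sketch of how the new _transformer_encoder_layer_fwd op declared in native_functions.yaml (further down in this diff) could be driven directly, with weights pulled from an existing nn.TransformerEncoderLayer. This sketch is not part of the commit: the batch-first (B, S, E) input layout and the module-attribute-to-argument mapping are assumptions for illustration; only the argument order follows the schema added here, and it is not guaranteed to run unchanged at this exact commit.

# Hypothetical usage sketch (not part of this commit).
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True).eval()
src = torch.randn(2, 16, 256)  # assumed (batch, seq, embed) layout

with torch.inference_mode():
    out = torch.ops.aten._transformer_encoder_layer_fwd(
        src,
        256,                              # embed_dim
        8,                                # num_heads
        layer.self_attn.in_proj_weight,   # qkv_weight, shape (3E, E)
        layer.self_attn.in_proj_bias,     # qkv_bias, shape (3E,)
        layer.self_attn.out_proj.weight,  # proj_weight
        layer.self_attn.out_proj.bias,    # proj_bias
        False,                            # use_gelu (this layer uses relu)
        layer.norm_first,                 # norm_first
        layer.norm1.eps,                  # eps
        layer.norm1.weight, layer.norm1.bias,
        layer.norm2.weight, layer.norm2.bias,
        layer.linear1.weight, layer.linear1.bias,
        layer.linear2.weight, layer.linear2.bias,
        None,                             # mask
    )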

File tree

16 files changed: +1815 −30 lines

BUILD.bazel

Lines changed: 8 additions & 0 deletions
@@ -228,6 +228,11 @@ filegroup(
     ),
 )
 
+filegroup(
+    name = "aten_native_transformers_cpp",
+    srcs = glob(["aten/src/ATen/native/transformers/*.cpp"]),
+)
+
 filegroup(
     name = "aten_native_mkl_cpp",
     srcs = glob(["aten/src/ATen/native/mkl/*.cpp", "aten/src/ATen/mkl/*.cpp"]),
@@ -278,6 +283,7 @@ filegroup(
             "aten/src/ATen/native/miopen/*.cpp",
             "aten/src/ATen/native/nested/cuda/*.cpp",
             "aten/src/ATen/native/sparse/cuda/*.cpp",
+            "aten/src/ATen/native/transformers/cuda/*.cpp",
             "aten/src/THC/*.cpp",
         ],
     ),
@@ -292,6 +298,7 @@ filegroup(
         "aten/src/ATen/native/nested/cuda/*.cu",
         "aten/src/ATen/native/quantized/cuda/*.cu",
         "aten/src/ATen/native/sparse/cuda/*.cu",
+        "aten/src/ATen/native/transformers/cuda/*.cu",
     ]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
     # It's a bit puzzling to me why it's not necessary to declare the
     # target that generates these sources...
@@ -393,6 +400,7 @@ cc_library(
         ":aten_native_quantized_cpp",
         ":aten_native_sparse_cpp",
         ":aten_native_nested_cpp",
+        ":aten_native_transformers_cpp",
        ":aten_native_xnnpack",
        ":aten_src_ATen_config",
    ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),

aten/src/ATen/CMakeLists.txt

Lines changed: 10 additions & 2 deletions
@@ -105,6 +105,7 @@ file(GLOB native_quantized_cpp
   "native/quantized/*.cpp"
   "native/quantized/cpu/*.cpp")
 file(GLOB native_nested_cpp "native/nested/*.cpp")
+file(GLOB native_transformers_cpp "native/transformers/*.cpp")
 
 file(GLOB native_h "native/*.h")
 file(GLOB native_ao_sparse_h
@@ -128,6 +129,8 @@ file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
 file(GLOB native_quantized_cuda_cu "native/quantized/cuda/*.cu")
 file(GLOB native_quantized_cuda_cpp "native/quantized/cuda/*.cpp")
 file(GLOB native_quantized_cudnn_cpp "native/quantized/cudnn/*.cpp")
+file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
+file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
 
 file(GLOB native_hip_hip "native/hip/*.hip")
 file(GLOB native_hip_cpp "native/hip/*.cpp")
@@ -140,6 +143,8 @@ file(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip")
 file(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
 file(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
 file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
+file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip")
+file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp")
 file(GLOB native_utils_cpp "native/utils/*.cpp")
 
 # XNNPACK
@@ -162,6 +167,7 @@ else()
     all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
     ${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
     ${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
+    ${native_transformers_cpp}
     ${native_utils_cpp} ${native_xnnpack} ${generated_sources} ${core_generated_sources}
     ${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${ATen_NNAPI_SRCS} ${cpu_kernel_cpp}
   )
@@ -205,6 +211,7 @@ if(USE_CUDA)
     ${native_nested_cuda_cu}
     ${native_sparse_cuda_cu}
     ${native_quantized_cuda_cu}
+    ${native_transformers_cuda_cu}
     ${cuda_generated_sources}
   )
   list(APPEND ATen_CUDA_CPP_SRCS
@@ -216,6 +223,7 @@ if(USE_CUDA)
     ${native_quantized_cuda_cpp}
     ${native_quantized_cudnn_cpp}
     ${native_sparse_cuda_cpp}
+    ${native_transformers_cuda_cpp}
   )
   set(ATen_CUDA_LINALG_SRCS ${native_cuda_linalg_cpp})
   if(NOT BUILD_LAZY_CUDA_LINALG)
@@ -238,9 +246,9 @@ endif()
 
 if(USE_ROCM)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
-  set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip})
+  set(ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} ${native_hip_hip} ${native_nested_hip_hip} ${native_sparse_hip_hip} ${native_quantized_hip_hip} ${native_transformers_hip_hip})
   # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
-  set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
+  set(all_hip_cpp ${native_nested_hip_cpp} ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${native_transformers_hip_cpp} ${hip_cpp} ${native_hip_cpp} ${native_hip_linalg_cpp} ${cuda_generated_sources} ${ATen_HIP_SRCS})
   set(all_hip_cpp ${native_miopen_cpp} ${native_cudnn_hip_cpp} ${miopen_cpp} ${all_hip_cpp})
 endif()
 

aten/src/ATen/native/native_functions.yaml

Lines changed: 17 additions & 0 deletions
@@ -4662,6 +4662,12 @@
 
 - func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor
 
+# Fused implementation detail for transformers. Adds in-projection bias to QKV and divides Q by sqrt(D/num_heads).
+- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU, NestedTensorCPU: transform_bias_rescale_qkv_cpu
+    CUDA, NestedTensorCUDA: transform_bias_rescale_qkv_cuda
+
 - func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
   device_check: NoCheck  # cpu_nested_shape_example will always be on CPU
   dispatch:
@@ -11602,3 +11608,14 @@
   variants: method
   dispatch:
     NestedTensorCPU, NestedTensorCUDA: NestedTensor_layer_norm
+
+# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
+- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: transformer_encoder_layer_forward
+
+- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: native_multi_head_attention
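
The comment on _transform_bias_rescale_qkv above describes the fusion: add the in-projection bias to the packed QKV tensor and divide Q by sqrt(D/num_heads). Below is a minimal un-fused reference sketch of that description; it is not part of the commit, and the per-head (B, num_heads, T, head_dim) output layout is an assumption for illustration rather than a guarantee about the fused kernel's exact output.

# Un-fused sketch: add qkv_bias to the packed QKV and rescale Q by 1/sqrt(head_dim).
# qkv: (B, T, 3*E), qkv_bias: (3*E,). Output layout here is assumed, not taken from the diff.
import math
import torch

def transform_bias_rescale_qkv_reference(qkv, qkv_bias, num_heads):
    B, T, three_E = qkv.shape
    E = three_E // 3
    head_dim = E // num_heads
    q, k, v = (qkv + qkv_bias).chunk(3, dim=-1)   # each (B, T, E)
    # split heads: (B, T, E) -> (B, num_heads, T, head_dim)
    q = q.reshape(B, T, num_heads, head_dim).transpose(1, 2)
    k = k.reshape(B, T, num_heads, head_dim).transpose(1, 2)
    v = v.reshape(B, T, num_heads, head_dim).transpose(1, 2)
    q = q / math.sqrt(head_dim)                   # divide Q by sqrt(D / num_heads)
    return q, k, v

qkv = torch.randn(2, 5, 3 * 64)
qkv_bias = torch.randn(3 * 64)
q, k, v = transform_bias_rescale_qkv_reference(qkv, qkv_bias, num_heads=8)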

aten/src/ATen/native/nested/NestedTensorMath.cpp

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+#include <ATen/native/nested/NestedTensorMath.h>
+
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
 #include <ATen/NamedTensorUtils.h>
Lines changed: 5 additions & 1 deletion
@@ -1,13 +1,17 @@
 #pragma once
 
+#include <c10/macros/Macros.h>
+
+#include <vector>
+
 namespace at {
 namespace native {
 struct NestedTensorImpl;
 
 // TODO: cache this and only do it once per NestedTensor
 int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
 
-std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
+TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl& nt);
 
 } // namespace native
 } // namespace at

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,12 @@ Tensor nested_from_padded_cuda(
     const Tensor& sizes,
     bool do_transform_0213) {
   if (padded.dim() > 1 && padded.dim() < 5) {
+    if (padded.dtype() != kFloat && padded.dtype() != kHalf) {
+      TORCH_WARN_ONCE(
+          "nested_from_padded CUDA kernels only support fp32/fp16; falling "
+          "back to slower generic kernel");
+      return at::native::nested_from_padded_generic(padded, sizes, do_transform_0213);
+    }
     TORCH_CHECK(
         (padded.dim() == 4 && do_transform_0213) ||
         (padded.dim() == 3 && !do_transform_0213),
