
Commit cc612d2

DispatchKeySet perf improvements
Pull Request resolved: #72828

Reland of D34034847

ghstack-source-id: 149482806
Differential Revision: [D34227615](https://our.internmc.facebook.com/intern/diff/D34227615/)
1 parent 88dd603 commit cc612d2

4 files changed: +73 -39 lines changed

aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp

Lines changed: 2 additions & 1 deletion
@@ -160,9 +160,10 @@ Tensor MakeStridedQTensorCPU(
       allocator->allocate(size_bytes),
       allocator,
       /* resizable = */ true);
+  constexpr auto quantized_cpu_ks = at::DispatchKeySet(at::DispatchKey::QuantizedCPU);
   auto tensor = detail::make_tensor<QTensorImpl>(
       storage,
-      at::DispatchKeySet(at::DispatchKey::QuantizedCPU),
+      quantized_cpu_ks,
       dtype,
       quantizer);
   get_qtensorimpl(tensor)->set_sizes_and_strides(sizes, strides);
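Why hoisting to a named constexpr helps, as a minimal sketch rather than PyTorch's real types: if mapping a dispatch key to its bit representation involves any branching, a constexpr constant folds that work into the binary instead of redoing it on every tensor construction. The DKey enum, to_mask helper, and bit indices below are hypothetical stand-ins.

#include <cstdint>

enum class DKey { QuantizedCPU, Sparse, MkldnnCPU };

// Hypothetical stand-in for a non-trivial key -> bitmask mapping.
constexpr uint64_t to_mask(DKey k) {
  switch (k) {
    case DKey::QuantizedCPU: return 1ULL << 7;
    case DKey::Sparse:       return 1ULL << 11;
    case DKey::MkldnnCPU:    return 1ULL << 13;
  }
  return 0;
}

// Before this hunk's change: the mask was computed inside every
// MakeStridedQTensorCPU call. After: a compile-time constant that is
// loaded, not recomputed.
constexpr uint64_t quantized_cpu_ks = to_mask(DKey::QuantizedCPU);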

c10/core/DispatchKeySet.h

Lines changed: 12 additions & 0 deletions
@@ -641,6 +641,18 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
 constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
     autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);

+constexpr DispatchKeySet python_ks = DispatchKeySet({
+    DispatchKey::Python,
+    DispatchKey::PythonTLSSnapshot,
+});
+
+constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
+
+constexpr DispatchKeySet sparse_csr_ks =
+    DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA});
+
+constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);
+
 // backend dispatch keys that map to DispatchKey::AutogradOther
 // NB: keys in this set also get associated with CompositeImplicitAutograd
 constexpr DispatchKeySet autogradother_backends =
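A minimal sketch of the idea behind these named constexpr sets, assuming a simplified 64-bit keyset rather than PyTorch's actual DispatchKeySet layout: a set of keys is one bitmask, so a membership test against a precomputed mask is a single AND plus compare, however many keys the mask names. The Key enum and KeySet type below are illustrative only; the later sketches in this section reuse them.

#include <cassert>
#include <cstdint>
#include <initializer_list>

enum class Key : unsigned {
  Python,
  PythonTLSSnapshot,
  Sparse,
  SparseCsrCPU,
  SparseCsrCUDA,
  MkldnnCPU,
};

struct KeySet {
  uint64_t repr = 0;
  constexpr KeySet() = default;
  constexpr explicit KeySet(Key k) : repr(1ULL << static_cast<unsigned>(k)) {}
  constexpr KeySet(std::initializer_list<Key> ks) {
    for (Key k : ks) {
      repr |= 1ULL << static_cast<unsigned>(k);
    }
  }
  // With a constexpr mask argument, each test is one AND plus one compare,
  // regardless of how many keys the mask names.
  constexpr bool has_all(KeySet other) const {
    return (repr & other.repr) == other.repr;
  }
  constexpr bool has_any(KeySet other) const {
    return (repr & other.repr) != 0;
  }
  constexpr KeySet operator|(KeySet other) const {
    KeySet r;
    r.repr = repr | other.repr;
    return r;
  }
  constexpr KeySet operator-(KeySet other) const {
    KeySet r;
    r.repr = repr & ~other.repr;  // set subtraction as AND-NOT
    return r;
  }
};

// Mirroring the named sets added in this hunk:
constexpr KeySet python_ks({Key::Python, Key::PythonTLSSnapshot});
constexpr KeySet sparse_ks(Key::Sparse);
constexpr KeySet sparse_csr_ks({Key::SparseCsrCPU, Key::SparseCsrCUDA});
constexpr KeySet mkldnn_ks(Key::MkldnnCPU);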

c10/core/TensorImpl.cpp

Lines changed: 3 additions & 5 deletions
@@ -148,9 +148,7 @@ TensorImpl::TensorImpl(
       numel_(0),
       data_type_(data_type),
       device_opt_(storage_.device()),
-      key_set_(key_set.remove(
-          DispatchKey::Python).remove(
-          DispatchKey::PythonTLSSnapshot)) { // See [Note: Python key removal]
+      key_set_(key_set - c10::python_ks) { // See [Note: Python key removal]
   init_bitfields();
   // Inference tensor doesn't have version counter.
   if (!is_inference()) {
@@ -195,8 +193,8 @@ TensorImpl::TensorImpl(
   key_set = key_set | getAutocastRelatedKeySetFromBackend(k);

-  key_set =
-      key_set.remove(DispatchKey::Python).remove(DispatchKey::PythonTLSSnapshot); // See [Note: Python key removal]
+  // See [Note: Python key removal]
+  key_set = key_set - c10::python_ks;

   // Inference tensor doesn't have autograd related keys.
   if (inference_mode) {
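Building on the simplified KeySet sketched above, the constructor change amounts to replacing two chained remove() calls with one mask subtraction; the function below is a hypothetical usage example, not code from this commit.

// Continues the KeySet sketch above.
void strip_python_keys_example() {
  KeySet key_set({Key::Python, Key::PythonTLSSnapshot, Key::Sparse});
  key_set = key_set - python_ks;  // clears both Python keys with one AND-NOT
  assert(key_set.has_all(sparse_ks));
  assert(!key_set.has_any(python_ks));
}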

c10/core/TensorImpl.h

Lines changed: 56 additions & 33 deletions
@@ -838,91 +838,103 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   bool is_sparse() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Sparse);
+    return key_set_.has_all(c10::sparse_ks);
   }

   // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR
   // format.
   bool is_sparse_csr() const {
-    return key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(c10::sparse_csr_ks);
   }

   bool is_quantized() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Quantized);
+    constexpr auto quantized_ks = DispatchKeySet(DispatchKey::Quantized);
+    return key_set_.has_all(quantized_ks);
   }

   bool is_meta() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Meta);
+    constexpr auto meta_ks = DispatchKeySet(DispatchKey::Meta);
+    return key_set_.has_all(meta_ks);
   }

   bool is_cpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CPUBit) ||
-        key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::MkldnnCPU);
+    constexpr auto cpu_bits_ks = DispatchKeySet(BackendComponent::CPUBit) |
+        DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::MkldnnCPU});
+    return key_set_.has_any(cpu_bits_ks);
   }

   bool is_cuda() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CUDABit) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    constexpr auto cuda_bits_ks = DispatchKeySet(BackendComponent::CUDABit) |
+        DispatchKeySet(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(cuda_bits_ks);
   }

   bool is_xpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::XPUBit);
+    constexpr auto xpu_ks = DispatchKeySet(BackendComponent::XPUBit);
+    return key_set_.has_all(xpu_ks);
   }

   bool is_xla() const {
-    return key_set_.has_backend(BackendComponent::XLABit);
+    constexpr auto xla_ks = DispatchKeySet(BackendComponent::XLABit);
+    return key_set_.has_all(xla_ks);
   }

   bool is_hpu() const {
-    return key_set_.has_backend(BackendComponent::HPUBit);
+    constexpr auto hpu_ks = DispatchKeySet(BackendComponent::HPUBit);
+    return key_set_.has_all(hpu_ks);
   }

   bool is_lazy() const {
-    return key_set_.has_backend(BackendComponent::LazyBit);
+    constexpr auto lazy_ks = DispatchKeySet(BackendComponent::LazyBit);
+    return key_set_.has_all(lazy_ks);
   }

   bool is_hip() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::HIPBit);
+    constexpr auto hip_ks = DispatchKeySet(BackendComponent::HIPBit);
+    return key_set_.has_all(hip_ks);
   }

   bool is_ve() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::VEBit);
+    constexpr auto ve_ks = DispatchKeySet(BackendComponent::VEBit);
+    return key_set_.has_all(ve_ks);
   }

   bool is_mkldnn() const {
-    return key_set_.has(DispatchKey::MkldnnCPU);
+    return key_set_.has_all(c10::mkldnn_ks);
   }

   bool is_vulkan() const {
-    return key_set_.has(DispatchKey::Vulkan);
+    constexpr auto vulkan_ks = DispatchKeySet(DispatchKey::Vulkan);
+    return key_set_.has_all(vulkan_ks);
   }

   bool is_metal() const {
-    return key_set_.has(DispatchKey::Metal);
+    constexpr auto metal_ks = DispatchKeySet(DispatchKey::Metal);
+    return key_set_.has_all(metal_ks);
   }

   bool is_mlc() const {
-    return key_set_.has(DispatchKey::MLC);
+    constexpr auto mls_ks = DispatchKeySet(DispatchKey::MLC);
+    return key_set_.has_all(mls_ks);
   }

   bool is_ort() const {
-    return key_set_.has(DispatchKey::ORT);
+    constexpr auto ort_ks = DispatchKeySet(DispatchKey::ORT);
+    return key_set_.has_all(ort_ks);
   }

   // TODO: remove this once we don't automatically enabled Autograd dispatch
@@ -938,8 +950,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // Invariant:
   // Inference tensor has version_counter_.enabled() == false
   bool is_inference() {
-    bool no_ADInplaceOrView = !key_set_.has(c10::DispatchKey::ADInplaceOrView);
-    bool no_Autograd = (key_set_ & c10::autograd_dispatch_keyset).empty();
+    bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks);
+    bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         no_ADInplaceOrView == no_Autograd,
         "ADInplaceOrView and Autograd keys must be on/off at the same time.");
@@ -960,14 +972,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   Layout layout() const {
     // NB: This method is not virtual and avoid dispatches for perf.
-    if (is_sparse()) {
+    // strided is also the most common layout type, so we check for
+    // strided case first.
+    // This keyset must also be kept in sync with the logic in
+    // is_sparse() / is_sparse_csr() / is_mkldnn()
+    constexpr auto sparse_and_sparsecsr_and_mkldnn_ks =
+        c10::sparse_ks | c10::sparse_csr_ks | c10::mkldnn_ks;
+    if (!key_set_.has_any(sparse_and_sparsecsr_and_mkldnn_ks)) {
+      return kStrided;
+    } else if (is_sparse()) {
       return kSparse;
     } else if (is_sparse_csr()) {
       return kSparseCsr;
-    } else if (is_mkldnn()) {
-      return kMkldnn;
     } else {
-      return kStrided;
+      TORCH_INTERNAL_ASSERT(
+          is_mkldnn(), "There is an error in the layout calculation logic.");
+      return kMkldnn;
     }
   }
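Reusing the simplified KeySet and named sets from the sketch above, the reordered layout() has roughly this shape: one has_any over a merged constexpr mask resolves the common (strided) case before any per-layout checks run. The Layout enum and layout_of function are illustrative stand-ins, not PyTorch's real API.

enum class Layout { Strided, Sparse, SparseCsr, Mkldnn };

Layout layout_of(KeySet ks) {
  // Merged mask built once at compile time, kept in sync with the
  // per-layout checks below (mirroring the comment in the real code).
  constexpr KeySet non_strided_ks = sparse_ks | sparse_csr_ks | mkldnn_ks;
  if (!ks.has_any(non_strided_ks)) {
    return Layout::Strided;  // fast path: a single AND plus compare
  } else if (ks.has_all(sparse_ks)) {
    return Layout::Sparse;
  } else if (ks.has_any(sparse_csr_ks)) {
    return Layout::SparseCsr;
  } else {
    assert(ks.has_all(mkldnn_ks));  // mirrors the TORCH_INTERNAL_ASSERT
    return Layout::Mkldnn;
  }
}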

@@ -1053,7 +1073,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the imaginary part of the tensor should be negated
    */
   inline bool is_conj() const {
-    return key_set_.has(DispatchKey::Conjugate);
+    constexpr auto conjugate_ks = DispatchKeySet(DispatchKey::Conjugate);
+    return key_set_.has_all(conjugate_ks);
   }

   /**
@@ -1073,7 +1094,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the tensor is a zerotensor
    */
   inline bool _is_zerotensor() const {
-    return key_set_.has(DispatchKey::ZeroTensor);
+    constexpr auto zerotensor_ks = DispatchKeySet(DispatchKey::ZeroTensor);
+    return key_set_.has_all(zerotensor_ks);
   }

   /**
@@ -1093,7 +1115,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the tensor should be negated
    */
   inline bool is_neg() const {
-    return key_set_.has(DispatchKey::Negative);
+    constexpr auto negative_ks = DispatchKeySet(DispatchKey::Negative);
+    return key_set_.has_all(negative_ks);
   }

   /**
@@ -1464,14 +1487,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   void set_python_dispatch(bool k) {
     if (k) {
-      key_set_ = key_set_.add(DispatchKey::Python).add(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_.add(c10::python_ks);
     } else {
-      key_set_ = key_set_.remove(DispatchKey::Python).remove(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_ - c10::python_ks;
     }
   }

   bool is_python_dispatch() const {
-    return key_set_.has(DispatchKey::Python);
+    return key_set_.has_all(c10::python_ks);
   }

   /**
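Continuing the KeySet sketch above, toggling the paired Python keys becomes one mask operation in each direction. The real DispatchKeySet::add accepts a keyset after this change; this standalone example uses operator| for the same effect, and the function is a hypothetical illustration.

// Continues the KeySet sketch above.
void python_dispatch_example() {
  KeySet ks;
  ks = ks | python_ks;            // like set_python_dispatch(true)
  assert(ks.has_all(python_ks));  // like is_python_dispatch()
  ks = ks - python_ks;            // like set_python_dispatch(false)
  assert(!ks.has_any(python_ks));
}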
