@@ -838,91 +838,103 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   bool is_sparse() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Sparse);
+    return key_set_.has_all(c10::sparse_ks);
   }
 
   // Whether a tensor is sparse COO or not. Use is_sparse_csr for checking CSR
   // format.
   bool is_sparse_csr() const {
-    return key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(c10::sparse_csr_ks);
   }
 
   bool is_quantized() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Quantized);
+    constexpr auto quantized_ks = DispatchKeySet(DispatchKey::Quantized);
+    return key_set_.has_all(quantized_ks);
   }
 
   bool is_meta() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has(DispatchKey::Meta);
+    constexpr auto meta_ks = DispatchKeySet(DispatchKey::Meta);
+    return key_set_.has_all(meta_ks);
   }
 
   bool is_cpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CPUBit) ||
-        key_set_.has(DispatchKey::SparseCsrCPU) ||
-        key_set_.has(DispatchKey::MkldnnCPU);
+    constexpr auto cpu_bits_ks = DispatchKeySet(BackendComponent::CPUBit) |
+        DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::MkldnnCPU});
+    return key_set_.has_any(cpu_bits_ks);
   }
 
   bool is_cuda() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::CUDABit) ||
-        key_set_.has(DispatchKey::SparseCsrCUDA);
+    constexpr auto cuda_bits_ks = DispatchKeySet(BackendComponent::CUDABit) |
+        DispatchKeySet(DispatchKey::SparseCsrCUDA);
+    return key_set_.has_any(cuda_bits_ks);
   }
 
   bool is_xpu() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::XPUBit);
+    constexpr auto xpu_ks = DispatchKeySet(BackendComponent::XPUBit);
+    return key_set_.has_all(xpu_ks);
   }
 
   bool is_xla() const {
-    return key_set_.has_backend(BackendComponent::XLABit);
+    constexpr auto xla_ks = DispatchKeySet(BackendComponent::XLABit);
+    return key_set_.has_all(xla_ks);
   }
 
   bool is_hpu() const {
-    return key_set_.has_backend(BackendComponent::HPUBit);
+    constexpr auto hpu_ks = DispatchKeySet(BackendComponent::HPUBit);
+    return key_set_.has_all(hpu_ks);
   }
 
   bool is_lazy() const {
-    return key_set_.has_backend(BackendComponent::LazyBit);
+    constexpr auto lazy_ks = DispatchKeySet(BackendComponent::LazyBit);
+    return key_set_.has_all(lazy_ks);
   }
 
   bool is_hip() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::HIPBit);
+    constexpr auto hip_ks = DispatchKeySet(BackendComponent::HIPBit);
+    return key_set_.has_all(hip_ks);
   }
 
   bool is_ve() const {
     // NB: This method is not virtual and avoid dispatches for performance
     // reasons.
-    return key_set_.has_backend(BackendComponent::VEBit);
+    constexpr auto ve_ks = DispatchKeySet(BackendComponent::VEBit);
+    return key_set_.has_all(ve_ks);
   }
 
   bool is_mkldnn() const {
-    return key_set_.has(DispatchKey::MkldnnCPU);
+    return key_set_.has_all(c10::mkldnn_ks);
   }
 
   bool is_vulkan() const {
-    return key_set_.has(DispatchKey::Vulkan);
+    constexpr auto vulkan_ks = DispatchKeySet(DispatchKey::Vulkan);
+    return key_set_.has_all(vulkan_ks);
   }
 
   bool is_metal() const {
-    return key_set_.has(DispatchKey::Metal);
+    constexpr auto metal_ks = DispatchKeySet(DispatchKey::Metal);
+    return key_set_.has_all(metal_ks);
   }
 
   bool is_mlc() const {
-    return key_set_.has(DispatchKey::MLC);
+    constexpr auto mls_ks = DispatchKeySet(DispatchKey::MLC);
+    return key_set_.has_all(mls_ks);
   }
 
   bool is_ort() const {
-    return key_set_.has(DispatchKey::ORT);
+    constexpr auto ort_ks = DispatchKeySet(DispatchKey::ORT);
+    return key_set_.has_all(ort_ks);
   }
 
   // TODO: remove this once we don't automatically enabled Autograd dispatch
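The pattern in this hunk is uniform: each per-key `key_set_.has(DispatchKey::X)` query becomes a test against a `constexpr DispatchKeySet`, so the mask is folded at compile time and the runtime check is a single bitwise AND plus compare. A minimal sketch of the idea with toy types (`ToyKeySet` and the sample bits below are illustrative stand-ins, not the real c10 API):

    #include <cstdint>

    // Toy stand-in for c10::DispatchKeySet: a 64-bit mask with constexpr ops.
    struct ToyKeySet {
      uint64_t repr = 0;
      constexpr ToyKeySet() = default;
      constexpr explicit ToyKeySet(uint64_t r) : repr(r) {}
      constexpr ToyKeySet operator|(ToyKeySet o) const {
        return ToyKeySet(repr | o.repr);
      }
      // has_all: every bit of ks is present; has_any: at least one is.
      constexpr bool has_all(ToyKeySet ks) const {
        return (repr & ks.repr) == ks.repr;
      }
      constexpr bool has_any(ToyKeySet ks) const {
        return (repr & ks.repr) != 0;
      }
    };

    constexpr ToyKeySet kSparseCPUBit(1u << 0);
    constexpr ToyKeySet kSparseCUDABit(1u << 1);
    // Union computed at compile time; the query below is one AND + compare.
    constexpr ToyKeySet toy_sparse_ks = kSparseCPUBit | kSparseCUDABit;

    bool toy_is_sparse(ToyKeySet key_set) {
      return key_set.has_any(toy_sparse_ks);
    }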
@@ -938,8 +950,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   // Invariant:
   //   Inference tensor has version_counter_.enabled() == false
   bool is_inference() {
-    bool no_ADInplaceOrView = !key_set_.has(c10::DispatchKey::ADInplaceOrView);
-    bool no_Autograd = (key_set_ & c10::autograd_dispatch_keyset).empty();
+    bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks);
+    bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset);
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         no_ADInplaceOrView == no_Autograd,
         "ADInplaceOrView and Autograd keys must be on/off at the same time.");
@@ -960,14 +972,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   Layout layout() const {
     // NB: This method is not virtual and avoid dispatches for perf.
-    if (is_sparse()) {
+    // strided is also the most common layout type, so we check for
+    // strided case first.
+    // This keyset must also be kept in sync with the logic in
+    // is_sparse() / is_sparse_csr() / is_mkldnn()
+    constexpr auto sparse_and_sparsecsr_and_mkldnn_ks =
+        c10::sparse_ks | c10::sparse_csr_ks | c10::mkldnn_ks;
+    if (!key_set_.has_any(sparse_and_sparsecsr_and_mkldnn_ks)) {
+      return kStrided;
+    } else if (is_sparse()) {
       return kSparse;
     } else if (is_sparse_csr()) {
       return kSparseCsr;
-    } else if (is_mkldnn()) {
-      return kMkldnn;
     } else {
-      return kStrided;
+      TORCH_INTERNAL_ASSERT(
+          is_mkldnn(), "There is an error in the layout calculation logic.");
+      return kMkldnn;
     }
   }
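The reordered `layout()` puts the overwhelmingly common strided case behind a single `has_any` over the union of all non-strided bits, and the trailing `TORCH_INTERNAL_ASSERT` documents that the final branch is reached only by elimination. A sketch of the same control flow over toy masks (the enum and constants are illustrative, not the real c10 values):

    #include <cassert>
    #include <cstdint>

    enum class ToyLayout { Strided, Sparse, SparseCsr, Mkldnn };

    // Illustrative stand-ins for c10::sparse_ks / sparse_csr_ks / mkldnn_ks.
    constexpr uint64_t toy_sparse_ks = 1u << 0;
    constexpr uint64_t toy_sparse_csr_ks = 1u << 1;
    constexpr uint64_t toy_mkldnn_ks = 1u << 2;

    ToyLayout toy_layout(uint64_t key_set) {
      // Fast path: test the union of all non-strided bits once; most tensors
      // are strided, so most calls return here after a single AND.
      constexpr uint64_t non_strided_ks =
          toy_sparse_ks | toy_sparse_csr_ks | toy_mkldnn_ks;
      if ((key_set & non_strided_ks) == 0) {
        return ToyLayout::Strided;
      } else if (key_set & toy_sparse_ks) {
        return ToyLayout::Sparse;
      } else if (key_set & toy_sparse_csr_ks) {
        return ToyLayout::SparseCsr;
      } else {
        // By elimination only the mkldnn bit can remain; mirrors the
        // TORCH_INTERNAL_ASSERT in the diff above.
        assert(key_set & toy_mkldnn_ks);
        return ToyLayout::Mkldnn;
      }
    }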
@@ -1053,7 +1073,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the imaginary part of the tensor should be negated
    */
   inline bool is_conj() const {
-    return key_set_.has(DispatchKey::Conjugate);
+    constexpr auto conjugate_ks = DispatchKeySet(DispatchKey::Conjugate);
+    return key_set_.has_all(conjugate_ks);
   }
 
   /**
@@ -1073,7 +1094,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the tensor is a zerotensor
    */
   inline bool _is_zerotensor() const {
-    return key_set_.has(DispatchKey::ZeroTensor);
+    constexpr auto zerotensor_ks = DispatchKeySet(DispatchKey::ZeroTensor);
+    return key_set_.has_all(zerotensor_ks);
   }
 
   /**
@@ -1093,7 +1115,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Whether or not the tensor should be negated
    */
   inline bool is_neg() const {
-    return key_set_.has(DispatchKey::Negative);
+    constexpr auto negative_ks = DispatchKeySet(DispatchKey::Negative);
+    return key_set_.has_all(negative_ks);
   }
 
   /**
@@ -1464,14 +1487,14 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   void set_python_dispatch(bool k) {
     if (k) {
-      key_set_ = key_set_.add(DispatchKey::Python).add(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_.add(c10::python_ks);
     } else {
-      key_set_ = key_set_.remove(DispatchKey::Python).remove(DispatchKey::PythonTLSSnapshot);
+      key_set_ = key_set_ - c10::python_ks;
     }
   }
 
   bool is_python_dispatch() const {
-    return key_set_.has(DispatchKey::Python);
+    return key_set_.has_all(c10::python_ks);
   }
 
   /**
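`set_python_dispatch` now flips both Python-related keys in one step: `add` with a keyset argument sets every bit in the set, and `operator-` clears them, replacing the chained per-key `add`/`remove` calls. A toy sketch of that pair of operations (again illustrative, not the real c10 implementation):

    #include <cstdint>

    struct ToyKeySet {
      uint64_t repr = 0;
      // add(ks) sets every bit in ks; operator- clears every bit in ks.
      constexpr ToyKeySet add(ToyKeySet ks) const { return {repr | ks.repr}; }
      constexpr ToyKeySet operator-(ToyKeySet ks) const { return {repr & ~ks.repr}; }
      constexpr bool has_all(ToyKeySet ks) const {
        return (repr & ks.repr) == ks.repr;
      }
    };

    // Stand-in for c10::python_ks (Python | PythonTLSSnapshot).
    constexpr ToyKeySet toy_python_ks{0b11};

    static_assert(ToyKeySet{0b100}.add(toy_python_ks).has_all(toy_python_ks), "");
    static_assert(!((ToyKeySet{0b111} - toy_python_ks).has_all(toy_python_ks)), "");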