@@ -903,6 +903,24 @@ DPCTLDevice_GetCompositeDevice(__dpctl_keep const DPCTLSyclDeviceRef DRef)
903
903
return nullptr ;
904
904
}
905
905
906
+ bool _CallPeerAccess (device dev, device peer)
907
+ {
908
+ auto BE1 = dev.get_backend ();
909
+ auto BE2 = peer.get_backend ();
910
+
911
+ if ((BE1 != sycl::backend::ext_oneapi_level_zero &&
912
+ BE1 != sycl::backend::ext_oneapi_cuda &&
913
+ BE1 != sycl::backend::ext_oneapi_hip) ||
914
+ (BE2 != sycl::backend::ext_oneapi_level_zero &&
915
+ BE2 != sycl::backend::ext_oneapi_cuda &&
916
+ BE2 != sycl::backend::ext_oneapi_hip) ||
917
+ (dev == peer))
918
+ {
919
+ return false ;
920
+ }
921
+ return true ;
922
+ }
923
+
906
924
bool DPCTLDevice_CanAccessPeer (__dpctl_keep const DPCTLSyclDeviceRef DRef,
907
925
__dpctl_keep const DPCTLSyclDeviceRef PDRef,
908
926
DPCTLPeerAccessType PT)
@@ -911,33 +929,13 @@ bool DPCTLDevice_CanAccessPeer(__dpctl_keep const DPCTLSyclDeviceRef DRef,
911
929
auto D = unwrap<device>(DRef);
912
930
auto PD = unwrap<device>(PDRef);
913
931
if (D && PD) {
914
- auto BE1 = D->get_backend ();
915
- auto BE2 = PD->get_backend ();
916
-
917
- if (BE1 != sycl::backend::ext_oneapi_level_zero &&
918
- BE1 != sycl::backend::ext_oneapi_cuda &&
919
- BE1 != sycl::backend::ext_oneapi_hip)
920
- {
921
- std::ostringstream os;
922
- os << " Backend " << BE1 << " does not support peer access" ;
923
- error_handler (os.str (), __FILE__, __func__, __LINE__);
924
- return false ;
925
- }
926
-
927
- if (BE2 != sycl::backend::ext_oneapi_level_zero &&
928
- BE2 != sycl::backend::ext_oneapi_cuda &&
929
- BE2 != sycl::backend::ext_oneapi_hip)
930
- {
931
- std::ostringstream os;
932
- os << " Backend " << BE2 << " does not support peer access" ;
933
- error_handler (os.str (), __FILE__, __func__, __LINE__);
934
- return false ;
935
- }
936
- try {
937
- canAccess = D->ext_oneapi_can_access_peer (
938
- *PD, DPCTL_DPCTLPeerAccessTypeToSycl (PT));
939
- } catch (std::exception const &e) {
940
- error_handler (e, __FILE__, __func__, __LINE__);
932
+ if (_CallPeerAccess (*D, *PD)) {
933
+ try {
934
+ canAccess = D->ext_oneapi_can_access_peer (
935
+ *PD, DPCTL_DPCTLPeerAccessTypeToSycl (PT));
936
+ } catch (std::exception const &e) {
937
+ error_handler (e, __FILE__, __func__, __LINE__);
938
+ }
941
939
}
942
940
}
943
941
return canAccess;
@@ -949,31 +947,18 @@ void DPCTLDevice_EnablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
949
947
auto D = unwrap<device>(DRef);
950
948
auto PD = unwrap<device>(PDRef);
951
949
if (D && PD) {
952
- auto BE1 = D->get_backend ();
953
- auto BE2 = PD->get_backend ();
954
-
955
- if (BE1 != sycl::backend::ext_oneapi_level_zero &&
956
- BE1 != sycl::backend::ext_oneapi_cuda &&
957
- BE1 != sycl::backend::ext_oneapi_hip)
958
- {
959
- std::ostringstream os;
960
- os << " Backend " << BE1 << " does not support peer access" ;
961
- error_handler (os.str (), __FILE__, __func__, __LINE__);
950
+ if (_CallPeerAccess (*D, *PD)) {
951
+ try {
952
+ D->ext_oneapi_enable_peer_access (*PD);
953
+ } catch (std::exception const &e) {
954
+ error_handler (e, __FILE__, __func__, __LINE__);
955
+ }
962
956
}
963
-
964
- if (BE2 != sycl::backend::ext_oneapi_level_zero &&
965
- BE2 != sycl::backend::ext_oneapi_cuda &&
966
- BE2 != sycl::backend::ext_oneapi_hip)
967
- {
957
+ else {
968
958
std::ostringstream os;
969
- os << " Backend " << BE2 << " does not support peer access" ;
959
+ os << " Given devices do not support peer access" ;
970
960
error_handler (os.str (), __FILE__, __func__, __LINE__);
971
961
}
972
- try {
973
- D->ext_oneapi_enable_peer_access (*PD);
974
- } catch (std::exception const &e) {
975
- error_handler (e, __FILE__, __func__, __LINE__);
976
- }
977
962
}
978
963
return ;
979
964
}
@@ -984,31 +969,18 @@ void DPCTLDevice_DisablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
984
969
auto D = unwrap<device>(DRef);
985
970
auto PD = unwrap<device>(PDRef);
986
971
if (D && PD) {
987
- auto BE1 = D->get_backend ();
988
- auto BE2 = PD->get_backend ();
989
-
990
- if (BE1 != sycl::backend::ext_oneapi_level_zero &&
991
- BE1 != sycl::backend::ext_oneapi_cuda &&
992
- BE1 != sycl::backend::ext_oneapi_hip)
993
- {
994
- std::ostringstream os;
995
- os << " Backend " << BE1 << " does not support peer access" ;
996
- error_handler (os.str (), __FILE__, __func__, __LINE__);
972
+ if (_CallPeerAccess (*D, *PD)) {
973
+ try {
974
+ D->ext_oneapi_disable_peer_access (*PD);
975
+ } catch (std::exception const &e) {
976
+ error_handler (e, __FILE__, __func__, __LINE__);
977
+ }
997
978
}
998
-
999
- if (BE2 != sycl::backend::ext_oneapi_level_zero &&
1000
- BE2 != sycl::backend::ext_oneapi_cuda &&
1001
- BE2 != sycl::backend::ext_oneapi_hip)
1002
- {
979
+ else {
1003
980
std::ostringstream os;
1004
- os << " Backend " << BE2 << " does not support peer access" ;
981
+ os << " Given devices do not support peer access" ;
1005
982
error_handler (os.str (), __FILE__, __func__, __LINE__);
1006
983
}
1007
- try {
1008
- D->ext_oneapi_disable_peer_access (*PD);
1009
- } catch (std::exception const &e) {
1010
- error_handler (e, __FILE__, __func__, __LINE__);
1011
- }
1012
984
}
1013
985
return ;
1014
986
}
0 commit comments