Skip to content

Commit 9e7031e

Browse files
committed
Factor common code out of C-API peer access functions
1 parent 8563cb0 commit 9e7031e

File tree

1 file changed

+41
-69
lines changed

1 file changed

+41
-69
lines changed

libsyclinterface/source/dpctl_sycl_device_interface.cpp

Lines changed: 41 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,24 @@ DPCTLDevice_GetCompositeDevice(__dpctl_keep const DPCTLSyclDeviceRef DRef)
903903
return nullptr;
904904
}
905905

906+
bool _CallPeerAccess(device dev, device peer)
907+
{
908+
auto BE1 = dev.get_backend();
909+
auto BE2 = peer.get_backend();
910+
911+
if ((BE1 != sycl::backend::ext_oneapi_level_zero &&
912+
BE1 != sycl::backend::ext_oneapi_cuda &&
913+
BE1 != sycl::backend::ext_oneapi_hip) ||
914+
(BE2 != sycl::backend::ext_oneapi_level_zero &&
915+
BE2 != sycl::backend::ext_oneapi_cuda &&
916+
BE2 != sycl::backend::ext_oneapi_hip) ||
917+
(dev == peer))
918+
{
919+
return false;
920+
}
921+
return true;
922+
}
923+
906924
bool DPCTLDevice_CanAccessPeer(__dpctl_keep const DPCTLSyclDeviceRef DRef,
907925
__dpctl_keep const DPCTLSyclDeviceRef PDRef,
908926
DPCTLPeerAccessType PT)
@@ -911,33 +929,13 @@ bool DPCTLDevice_CanAccessPeer(__dpctl_keep const DPCTLSyclDeviceRef DRef,
911929
auto D = unwrap<device>(DRef);
912930
auto PD = unwrap<device>(PDRef);
913931
if (D && PD) {
914-
auto BE1 = D->get_backend();
915-
auto BE2 = PD->get_backend();
916-
917-
if (BE1 != sycl::backend::ext_oneapi_level_zero &&
918-
BE1 != sycl::backend::ext_oneapi_cuda &&
919-
BE1 != sycl::backend::ext_oneapi_hip)
920-
{
921-
std::ostringstream os;
922-
os << "Backend " << BE1 << " does not support peer access";
923-
error_handler(os.str(), __FILE__, __func__, __LINE__);
924-
return false;
925-
}
926-
927-
if (BE2 != sycl::backend::ext_oneapi_level_zero &&
928-
BE2 != sycl::backend::ext_oneapi_cuda &&
929-
BE2 != sycl::backend::ext_oneapi_hip)
930-
{
931-
std::ostringstream os;
932-
os << "Backend " << BE2 << " does not support peer access";
933-
error_handler(os.str(), __FILE__, __func__, __LINE__);
934-
return false;
935-
}
936-
try {
937-
canAccess = D->ext_oneapi_can_access_peer(
938-
*PD, DPCTL_DPCTLPeerAccessTypeToSycl(PT));
939-
} catch (std::exception const &e) {
940-
error_handler(e, __FILE__, __func__, __LINE__);
932+
if (_CallPeerAccess(*D, *PD)) {
933+
try {
934+
canAccess = D->ext_oneapi_can_access_peer(
935+
*PD, DPCTL_DPCTLPeerAccessTypeToSycl(PT));
936+
} catch (std::exception const &e) {
937+
error_handler(e, __FILE__, __func__, __LINE__);
938+
}
941939
}
942940
}
943941
return canAccess;
@@ -949,31 +947,18 @@ void DPCTLDevice_EnablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
949947
auto D = unwrap<device>(DRef);
950948
auto PD = unwrap<device>(PDRef);
951949
if (D && PD) {
952-
auto BE1 = D->get_backend();
953-
auto BE2 = PD->get_backend();
954-
955-
if (BE1 != sycl::backend::ext_oneapi_level_zero &&
956-
BE1 != sycl::backend::ext_oneapi_cuda &&
957-
BE1 != sycl::backend::ext_oneapi_hip)
958-
{
959-
std::ostringstream os;
960-
os << "Backend " << BE1 << " does not support peer access";
961-
error_handler(os.str(), __FILE__, __func__, __LINE__);
950+
if (_CallPeerAccess(*D, *PD)) {
951+
try {
952+
D->ext_oneapi_enable_peer_access(*PD);
953+
} catch (std::exception const &e) {
954+
error_handler(e, __FILE__, __func__, __LINE__);
955+
}
962956
}
963-
964-
if (BE2 != sycl::backend::ext_oneapi_level_zero &&
965-
BE2 != sycl::backend::ext_oneapi_cuda &&
966-
BE2 != sycl::backend::ext_oneapi_hip)
967-
{
957+
else {
968958
std::ostringstream os;
969-
os << "Backend " << BE2 << " does not support peer access";
959+
os << "Given devices do not support peer access";
970960
error_handler(os.str(), __FILE__, __func__, __LINE__);
971961
}
972-
try {
973-
D->ext_oneapi_enable_peer_access(*PD);
974-
} catch (std::exception const &e) {
975-
error_handler(e, __FILE__, __func__, __LINE__);
976-
}
977962
}
978963
return;
979964
}
@@ -984,31 +969,18 @@ void DPCTLDevice_DisablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
984969
auto D = unwrap<device>(DRef);
985970
auto PD = unwrap<device>(PDRef);
986971
if (D && PD) {
987-
auto BE1 = D->get_backend();
988-
auto BE2 = PD->get_backend();
989-
990-
if (BE1 != sycl::backend::ext_oneapi_level_zero &&
991-
BE1 != sycl::backend::ext_oneapi_cuda &&
992-
BE1 != sycl::backend::ext_oneapi_hip)
993-
{
994-
std::ostringstream os;
995-
os << "Backend " << BE1 << " does not support peer access";
996-
error_handler(os.str(), __FILE__, __func__, __LINE__);
972+
if (_CallPeerAccess(*D, *PD)) {
973+
try {
974+
D->ext_oneapi_disable_peer_access(*PD);
975+
} catch (std::exception const &e) {
976+
error_handler(e, __FILE__, __func__, __LINE__);
977+
}
997978
}
998-
999-
if (BE2 != sycl::backend::ext_oneapi_level_zero &&
1000-
BE2 != sycl::backend::ext_oneapi_cuda &&
1001-
BE2 != sycl::backend::ext_oneapi_hip)
1002-
{
979+
else {
1003980
std::ostringstream os;
1004-
os << "Backend " << BE2 << " does not support peer access";
981+
os << "Given devices do not support peer access";
1005982
error_handler(os.str(), __FILE__, __func__, __LINE__);
1006983
}
1007-
try {
1008-
D->ext_oneapi_disable_peer_access(*PD);
1009-
} catch (std::exception const &e) {
1010-
error_handler(e, __FILE__, __func__, __LINE__);
1011-
}
1012984
}
1013985
return;
1014986
}

0 commit comments

Comments
 (0)