Skip to content

Commit e789a08

Browse files
authored
[UR][L0v2] Include subdevices when populating p2p access devices (#19772)
1 parent 198a23e commit e789a08

File tree

1 file changed

+22
-6
lines changed
  • unified-runtime/source/adapters/level_zero/v2

1 file changed

+22
-6
lines changed

unified-runtime/source/adapters/level_zero/v2/context.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,28 @@ filterP2PDevices(ur_device_handle_t hSourceDevice,
3535
}
3636

3737
static std::vector<std::vector<ur_device_handle_t>>
38-
populateP2PDevices(size_t maxDevices,
39-
const std::vector<ur_device_handle_t> &devices) {
40-
std::vector<std::vector<ur_device_handle_t>> p2pDevices(maxDevices);
38+
populateP2PDevices(const std::vector<ur_device_handle_t> &devices) {
39+
std::vector<ur_device_handle_t> allDevices;
40+
std::function<void(ur_device_handle_t)> collectDeviceAndSubdevices =
41+
[&allDevices, &collectDeviceAndSubdevices](ur_device_handle_t device) {
42+
allDevices.push_back(device);
43+
for (auto &subDevice : device->SubDevices) {
44+
collectDeviceAndSubdevices(subDevice);
45+
}
46+
};
47+
4148
for (auto &device : devices) {
42-
p2pDevices[device->Id.value()] = filterP2PDevices(device, devices);
49+
collectDeviceAndSubdevices(device);
50+
}
51+
52+
uint64_t maxDeviceId = 0;
53+
for (auto &device : allDevices) {
54+
maxDeviceId = std::max(maxDeviceId, device->Id.value());
55+
}
56+
57+
std::vector<std::vector<ur_device_handle_t>> p2pDevices(maxDeviceId + 1);
58+
for (auto &device : allDevices) {
59+
p2pDevices[device->Id.value()] = filterP2PDevices(device, allDevices);
4360
}
4461
return p2pDevices;
4562
}
@@ -83,8 +100,7 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
83100
nativeEventsPool(this, std::make_unique<v2::provider_normal>(
84101
this, v2::QUEUE_IMMEDIATE,
85102
v2::EVENT_FLAGS_PROFILING_ENABLED)),
86-
p2pAccessDevices(populateP2PDevices(
87-
phDevices[0]->Platform->getNumDevices(), this->hDevices)),
103+
p2pAccessDevices(populateP2PDevices(this->hDevices)),
88104
defaultUSMPool(this, nullptr), asyncPool(this, nullptr) {}
89105

90106
ur_result_t ur_context_handle_t_::retain() {

0 commit comments

Comments
 (0)