@@ -35,11 +35,28 @@ filterP2PDevices(ur_device_handle_t hSourceDevice,
35
35
}
36
36
37
37
static std::vector<std::vector<ur_device_handle_t >>
38
- populateP2PDevices (size_t maxDevices,
39
- const std::vector<ur_device_handle_t > &devices) {
40
- std::vector<std::vector<ur_device_handle_t >> p2pDevices (maxDevices);
38
+ populateP2PDevices (const std::vector<ur_device_handle_t > &devices) {
39
+ std::vector<ur_device_handle_t > allDevices;
40
+ std::function<void (ur_device_handle_t )> collectDeviceAndSubdevices =
41
+ [&allDevices, &collectDeviceAndSubdevices](ur_device_handle_t device) {
42
+ allDevices.push_back (device);
43
+ for (auto &subDevice : device->SubDevices ) {
44
+ collectDeviceAndSubdevices (subDevice);
45
+ }
46
+ };
47
+
41
48
for (auto &device : devices) {
42
- p2pDevices[device->Id .value ()] = filterP2PDevices (device, devices);
49
+ collectDeviceAndSubdevices (device);
50
+ }
51
+
52
+ uint64_t maxDeviceId = 0 ;
53
+ for (auto &device : allDevices) {
54
+ maxDeviceId = std::max (maxDeviceId, device->Id .value ());
55
+ }
56
+
57
+ std::vector<std::vector<ur_device_handle_t >> p2pDevices (maxDeviceId + 1 );
58
+ for (auto &device : allDevices) {
59
+ p2pDevices[device->Id .value ()] = filterP2PDevices (device, allDevices);
43
60
}
44
61
return p2pDevices;
45
62
}
@@ -83,8 +100,7 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
83
100
nativeEventsPool(this , std::make_unique<v2::provider_normal>(
84
101
this , v2::QUEUE_IMMEDIATE,
85
102
v2::EVENT_FLAGS_PROFILING_ENABLED)),
86
- p2pAccessDevices(populateP2PDevices(
87
- phDevices[0 ]->Platform->getNumDevices (), this->hDevices)),
103
+ p2pAccessDevices(populateP2PDevices(this ->hDevices)),
88
104
defaultUSMPool(this , nullptr ), asyncPool(this , nullptr ) {}
89
105
90
106
ur_result_t ur_context_handle_t_::retain () {
0 commit comments