Skip to content

Commit a6c08a3

Browse files
committed
[UR] Refactor urDevicePartition to use desc struct
1 parent b939be2 commit a6c08a3

File tree

21 files changed

+579
-353
lines changed

21 files changed

+579
-353
lines changed

include/ur.py

Lines changed: 57 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ class ur_structure_type_v(IntEnum):
225225
PROGRAM_NATIVE_PROPERTIES = 23 ## ::ur_program_native_properties_t
226226
SAMPLER_NATIVE_PROPERTIES = 24 ## ::ur_sampler_native_properties_t
227227
QUEUE_NATIVE_DESC = 25 ## ::ur_queue_native_desc_t
228+
DEVICE_PARTITION_DESC = 26 ## ::ur_device_partition_desc_t
228229

229230
class ur_structure_type_t(c_int):
230231
def __str__(self):
@@ -502,15 +503,15 @@ class ur_device_info_v(IntEnum):
502503
PREFERRED_INTEROP_USER_SYNC = 74 ## [::ur_bool_t] prefer user synchronization when sharing object with
503504
## other API
504505
PARENT_DEVICE = 75 ## [::ur_device_handle_t] return parent device handle
505-
PARTITION_PROPERTIES = 76 ## [::ur_device_partition_property_t[]] Returns an array of partition
506-
## types supported by the device
506+
PARTITION_PROPERTIES = 76 ## [::ur_device_partition_t[]] Returns an array of partition types
507+
## supported by the device
507508
PARTITION_MAX_SUB_DEVICES = 77 ## [uint32_t] maximum number of sub-devices when the device is
508509
## partitioned
509510
PARTITION_AFFINITY_DOMAIN = 78 ## [::ur_device_affinity_domain_flags_t] Returns a bit-field of the
510511
## supported affinity domains for partitioning.
511512
## If the device does not support any affinity domains, then 0 will be returned.
512-
PARTITION_TYPE = 79 ## [::ur_device_partition_property_t[]] return an array of
513-
## ::ur_device_partition_property_t for properties specified in
513+
PARTITION_TYPE = 79 ## [::ur_device_partition_desc_t[]] return an array of
514+
## ::ur_device_partition_desc_t for properties specified in
514515
## ::urDevicePartition
515516
MAX_NUM_SUB_GROUPS = 80 ## [uint32_t] max number of sub groups
516517
SUB_GROUP_INDEPENDENT_FORWARD_PROGRESS = 81 ## [::ur_bool_t] support sub group independent forward progress
@@ -569,9 +570,33 @@ def __str__(self):
569570

570571

571572
###############################################################################
572-
## @brief Device partition property type
573-
class ur_device_partition_property_t(c_intptr_t):
574-
pass
573+
## @brief Device affinity domain
574+
class ur_device_affinity_domain_flags_v(IntEnum):
575+
NUMA = UR_BIT(0) ## Split the device into sub devices comprised of compute units that
576+
## share a NUMA node.
577+
L4_CACHE = UR_BIT(1) ## Split the device into sub devices comprised of compute units that
578+
## share a level 4 data cache.
579+
L3_CACHE = UR_BIT(2) ## Split the device into sub devices comprised of compute units that
580+
## share a level 3 data cache.
581+
L2_CACHE = UR_BIT(3) ## Split the device into sub devices comprised of compute units that
582+
## share a level 2 data cache.
583+
L1_CACHE = UR_BIT(4) ## Split the device into sub devices comprised of compute units that
584+
## share a level 1 data cache.
585+
NEXT_PARTITIONABLE = UR_BIT(5) ## Split the device along the next partitionable affinity domain.
586+
## The implementation shall find the first level along which the device
587+
## or sub device may be further subdivided in the order:
588+
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_NUMA,
589+
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L4_CACHE,
590+
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L3_CACHE,
591+
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L2_CACHE,
592+
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L1_CACHE,
593+
## and partition the device into sub devices comprised of compute units
594+
## that share memory subsystems at this level.
595+
596+
class ur_device_affinity_domain_flags_t(c_int):
597+
def __str__(self):
598+
return hex(self.value)
599+
575600

576601
###############################################################################
577602
## @brief Partition Properties
@@ -587,6 +612,29 @@ def __str__(self):
587612
return str(ur_device_partition_v(self.value))
588613

589614

615+
###############################################################################
616+
## @brief Device partition value.
617+
class ur_device_partition_value_t(Structure):
618+
_fields_ = [
619+
("equally", c_ulong), ## [in] Number of compute units per sub-device when partitioning with
620+
## ::UR_DEVICE_PARTITION_EQUALLY.
621+
("count", c_ulong), ## [in] Number of compute units in a sub-device when partitioning with
622+
## ::UR_DEVICE_PARTITION_BY_COUNTS.
623+
("affinity_domain", ur_device_affinity_domain_flags_t) ## [in] The affinity domain to partition for when partitioning with
624+
## ::UR_DEVICE_PARTITION_BY_AFFINITY_DOMAIN.
625+
]
626+
627+
###############################################################################
628+
## @brief Device partition description
629+
class ur_device_partition_desc_t(Structure):
630+
_fields_ = [
631+
("stype", ur_structure_type_t), ## [in] type of this structure, must be
632+
## ::UR_STRUCTURE_TYPE_DEVICE_PARTITION_DESC
633+
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
634+
("type", ur_device_partition_t), ## [in] The partitioning type to be used.
635+
("value", ur_device_partition_value_t) ## [in] The partitioning value.
636+
]
637+
590638
###############################################################################
591639
## @brief FP capabilities
592640
class ur_device_fp_capability_flags_v(IntEnum):
@@ -639,35 +687,6 @@ def __str__(self):
639687
return hex(self.value)
640688

641689

642-
###############################################################################
643-
## @brief Device affinity domain
644-
class ur_device_affinity_domain_flags_v(IntEnum):
645-
NUMA = UR_BIT(0) ## Split the device into sub devices comprised of compute units that
646-
## share a NUMA node.
647-
L4_CACHE = UR_BIT(1) ## Split the device into sub devices comprised of compute units that
648-
## share a level 4 data cache.
649-
L3_CACHE = UR_BIT(2) ## Split the device into sub devices comprised of compute units that
650-
## share a level 3 data cache.
651-
L2_CACHE = UR_BIT(3) ## Split the device into sub devices comprised of compute units that
652-
## share a level 2 data cache.
653-
L1_CACHE = UR_BIT(4) ## Split the device into sub devices comprised of compute units that
654-
## share a level 1 data cache.
655-
NEXT_PARTITIONABLE = UR_BIT(5) ## Split the device along the next partitionable affinity domain.
656-
## The implementation shall find the first level along which the device
657-
## or sub device may be further subdivided in the order:
658-
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_NUMA,
659-
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L4_CACHE,
660-
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L3_CACHE,
661-
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L2_CACHE,
662-
## ::UR_DEVICE_AFFINITY_DOMAIN_FLAG_L1_CACHE,
663-
## and partition the device into sub devices comprised of compute units
664-
## that share memory subsystems at this level.
665-
666-
class ur_device_affinity_domain_flags_t(c_int):
667-
def __str__(self):
668-
return hex(self.value)
669-
670-
671690
###############################################################################
672691
## @brief Native device creation properties
673692
class ur_device_native_properties_t(Structure):
@@ -2805,9 +2824,9 @@ class ur_usm_dditable_t(Structure):
28052824
###############################################################################
28062825
## @brief Function-pointer for urDevicePartition
28072826
if __use_win_types:
2808-
_urDevicePartition_t = WINFUNCTYPE( ur_result_t, ur_device_handle_t, POINTER(ur_device_partition_property_t), c_ulong, POINTER(ur_device_handle_t), POINTER(c_ulong) )
2827+
_urDevicePartition_t = WINFUNCTYPE( ur_result_t, ur_device_handle_t, POINTER(ur_device_partition_desc_t), c_size_t, c_ulong, POINTER(ur_device_handle_t), POINTER(c_ulong) )
28092828
else:
2810-
_urDevicePartition_t = CFUNCTYPE( ur_result_t, ur_device_handle_t, POINTER(ur_device_partition_property_t), c_ulong, POINTER(ur_device_handle_t), POINTER(c_ulong) )
2829+
_urDevicePartition_t = CFUNCTYPE( ur_result_t, ur_device_handle_t, POINTER(ur_device_partition_desc_t), c_size_t, c_ulong, POINTER(ur_device_handle_t), POINTER(c_ulong) )
28112830

28122831
###############################################################################
28132832
## @brief Function-pointer for urDeviceSelectBinary

0 commit comments

Comments
 (0)