-
Notifications
You must be signed in to change notification settings - Fork 14.2k
[Offload] Add MAX_WORK_GROUP_SIZE
device info query
#143718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-offload Author: Ross Brunton (RossBrunton) ChangesThis adds a new device info query for the maximum workgroup/block size Full diff: https://github.com/llvm/llvm-project/pull/143718.diff 6 Files Affected:
diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td
index 7674da0438c29..e2e6d6e452671 100644
--- a/offload/liboffload/API/Common.td
+++ b/offload/liboffload/API/Common.td
@@ -148,6 +148,16 @@ def : Struct {
];
}
+def : Struct {
+ let name = "ol_range_t";
+ let desc = "A three element vector";
+ let members = [
+ StructMember<"size_t", "x", "X">,
+ StructMember<"size_t", "y", "Y">,
+ StructMember<"size_t", "z", "Z">,
+ ];
+}
+
def : Function {
let name = "olInit";
let desc = "Perform initialization of the Offload library and plugins";
diff --git a/offload/liboffload/API/Device.td b/offload/liboffload/API/Device.td
index 4abc24f3ba27f..1c7d1aaee8d59 100644
--- a/offload/liboffload/API/Device.td
+++ b/offload/liboffload/API/Device.td
@@ -31,7 +31,8 @@ def : Enum {
TaggedEtor<"PLATFORM", "ol_platform_handle_t", "the platform associated with the device">,
TaggedEtor<"NAME", "char[]", "Device name">,
TaggedEtor<"VENDOR", "char[]", "Device vendor">,
- TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">
+ TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">,
+ TaggedEtor<"MAX_WORK_GROUP_SIZE", "ol_range_t", "Maximum work group size in each dimension">,
];
}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index d2b331905ab77..b89845c387e65 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -228,16 +228,13 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet);
// Find the info if it exists under any of the given names
- auto GetInfo = [&](std::vector<std::string> Names) {
- InfoQueueTy DevInfo;
- if (Device == HostDevice())
- return std::string("Host");
-
+ auto FindInfo = [&](InfoQueueTy &DevInfo, std::vector<std::string> &Names)
+ -> std::optional<decltype(DevInfo.getQueue().begin())> {
if (!Device->Device)
- return std::string("");
+ return std::nullopt;
if (auto Err = Device->Device->obtainInfoImpl(DevInfo))
- return std::string("");
+ return std::nullopt;
for (auto Name : Names) {
auto InfoKeyMatches = [&](const InfoQueueTy::InfoQueueEntryTy &Info) {
@@ -247,11 +244,50 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
DevInfo.getQueue().end(), InfoKeyMatches);
if (Item != std::end(DevInfo.getQueue())) {
- return Item->Value;
+ return Item;
}
}
- return std::string("");
+ return std::nullopt;
+ };
+ auto GetInfoString = [&](std::vector<std::string> Names) {
+ InfoQueueTy DevInfo;
+
+ if (auto Item = FindInfo(DevInfo, Names)) {
+ return (*Item)->Value.c_str();
+ } else {
+ return "";
+ }
+ };
+ auto GetInfoXyz = [&](std::vector<std::string> Names) {
+ InfoQueueTy DevInfo;
+
+ if (auto Item = FindInfo(DevInfo, Names)) {
+ auto Iter = *Item;
+ ol_range_t Out{0, 0, 0};
+ auto Level = Iter->Level + 1;
+
+ while ((++Iter)->Level == Level) {
+ switch (Iter->Key[0]) {
+ case 'x':
+ Out.x = std::stoi(Iter->Value);
+ break;
+ case 'y':
+ Out.y = std::stoi(Iter->Value);
+ break;
+ case 'z':
+ Out.z = std::stoi(Iter->Value);
+ break;
+ default:
+ // Ignore any extra values
+ (void)0;
+ }
+ }
+
+ return Out;
+ } else {
+ return ol_range_t{0, 0, 0};
+ }
};
switch (PropName) {
@@ -261,12 +297,21 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST)
: ReturnValue(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
- return ReturnValue(GetInfo({"Device Name"}).c_str());
+ if (Device == HostDevice())
+ return ReturnValue("Host");
+ return ReturnValue(GetInfoString({"Device Name"}));
case OL_DEVICE_INFO_VENDOR:
- return ReturnValue(GetInfo({"Vendor Name"}).c_str());
+ if (Device == HostDevice())
+ return ReturnValue("Host");
+ return ReturnValue(GetInfoString({"Vendor Name"}));
case OL_DEVICE_INFO_DRIVER_VERSION:
+ if (Device == HostDevice())
+ return ReturnValue("Host");
return ReturnValue(
- GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str());
+ GetInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
+ return ReturnValue(GetInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
+ "Maximum Block Dimensions" /*CUDA*/}));
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getDeviceInfo enum '%i' is invalid", PropName);
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index a964ff09d0f6e..d1189688a90a3 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -213,6 +213,11 @@ template <typename T> inline void printTagged(llvm::raw_ostream &os, const void
"enum {0} value);\n",
EnumRec{R}.getName());
}
+ for (auto *R : Records.getAllDerivedDefinitions("Struct")) {
+ OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, "
+ "const struct {0} param);\n",
+ StructRec{R}.getName());
+ }
OS << "\n";
// Create definitions
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
index 0247744911eaa..ef7baf9e91275 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp
@@ -77,6 +77,15 @@ TEST_P(olGetDeviceInfoTest, SuccessDriverVersion) {
ASSERT_EQ(std::strlen(DriverVersion.data()), Size - 1);
}
+TEST_P(olGetDeviceInfoTest, SuccessMaxWorkGroupSize) {
+ ol_range_t Value{0, 0, 0};
+ ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE,
+ sizeof(Value), &Value));
+ ASSERT_GT(Value.x, 0);
+ ASSERT_GT(Value.y, 0);
+ ASSERT_GT(Value.z, 0);
+}
+
TEST_P(olGetDeviceInfoTest, InvalidNullHandleDevice) {
ol_device_type_t DeviceType;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
index edd2704a722dd..a2caad8650c79 100644
--- a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp
@@ -44,6 +44,14 @@ TEST_P(olGetDeviceInfoSizeTest, SuccessDriverVersion) {
ASSERT_NE(Size, 0ul);
}
+TEST_P(olGetDeviceInfoSizeTest, SuccessMaxWorkGroupSize) {
+ size_t Size = 0;
+ ASSERT_SUCCESS(
+ olGetDeviceInfoSize(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE, &Size));
+ ASSERT_EQ(Size, sizeof(ol_range_t));
+ ASSERT_EQ(Size, sizeof(size_t) * 3);
+}
+
TEST_P(olGetDeviceInfoSizeTest, InvalidNullHandle) {
size_t Size = 0;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just some small comments
@@ -261,12 +297,21 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, | |||
return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST) | |||
: ReturnValue(OL_DEVICE_TYPE_GPU); | |||
case OL_DEVICE_INFO_NAME: | |||
return ReturnValue(GetInfo({"Device Name"}).c_str()); | |||
if (Device == HostDevice()) | |||
return ReturnValue("Host"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we should do something to avoid having to special-case the host device everywhere in this function. Maybe a separate function that we delegate to when Device == HostDevice()
. It could return an UNSUPPORTED_ENUM
error for most properties since I don't think we really want to implement all of them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you be okay if I did that as a separate change after/alongside this? It feels like a separate thing that could involve lots of discussions.
return std::string(""); | ||
return std::nullopt; | ||
}; | ||
auto GetInfoString = [&](std::vector<std::string> Names) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should pass Names
by reference here and in GetInfoXyz
(I think this is my mistake from the original implementation)
Is there a reason we can't use StringRef
and SmallVector
here as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like trying to pass it by reference results in us being unable to use {"Device Name"}
in the literal list (because it isn't an lvalue reference).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, that makes sense, never mind then
return ""; | ||
} | ||
}; | ||
auto GetInfoXyz = [&](llvm::SmallVector<StringRef> Names) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we copying this by value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We want to use this function like return ReturnValue(GetInfoString({"Vendor Name"}));
where the argument is passed as an initialiser list. Therefore it can't be an lvalue.
default: | ||
// Ignore any extra values | ||
(void)0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to void zero instead of just doing nothing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C++ (before 23, apparently) requires labels to be attached to statements, hence the no-op statement.
But I'll just move the comment out of the switch entirely, since it's kinda confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this ever actually be hit? Just make it __builtin_unreachable()
otherwise, just make it break;
immediately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It can if a plugin (correctly or otherwise) reports a value outwith x, y or z.
if (Device == HostDevice()) | ||
return std::string("Host"); | ||
|
||
auto FindInfo = [&](InfoQueueTy &DevInfo, llvm::SmallVector<StringRef> &Names) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a little confusing, what exactly are we looking up here? I figured this would be a hash table or something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not; it's a list of basically-strings with nesting.
It's a mess, and both me and Callum want to improve it in some way.
while ((++Iter)->Level == Level) { | ||
switch (Iter->Key[0]) { | ||
case 'x': | ||
Out.x = std::stoi(Iter->Value); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LLVM has integer conversions that can fail correctly. Do we need to worry about invalid values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure. This value is reported by the plugins which should always return valid integers. Should liboffload be in charge of validating values that plugins return, or is it enough to say that plugins returning invalid values is UB?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be very difficult for people to reason with a meta library if it reported arbitrary UB with the underlying implementation
This adds a new device info query for the maximum workgroup/block size for each dimension. Since this returns three values, a new `ol_range_t` type was added as an `{x, y, z}` triplet. Device info handling and struct printing was also updated to handle it.
This adds a new device info query for the maximum workgroup/block size
for each dimension. Since this returns three values, a new
ol_range_t
type was added as an
{x, y, z}
triplet. Device info handling andstruct printing was also updated to handle it.