diff --git a/offload/liboffload/API/Device.td b/offload/liboffload/API/Device.td index 4abc24f3ba27f..94bd6cbf0e5be 100644 --- a/offload/liboffload/API/Device.td +++ b/offload/liboffload/API/Device.td @@ -31,7 +31,8 @@ def : Enum { TaggedEtor<"PLATFORM", "ol_platform_handle_t", "the platform associated with the device">, TaggedEtor<"NAME", "char[]", "Device name">, TaggedEtor<"VENDOR", "char[]", "Device vendor">, - TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version"> + TaggedEtor<"DRIVER_VERSION", "char[]", "Driver version">, + TaggedEtor<"MAX_WORK_GROUP_SIZE", "ol_dimensions_t", "Maximum work group size in each dimension">, ]; } diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 770c212d804d2..ff51f06f246d9 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -228,23 +228,48 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, ReturnHelper ReturnValue(PropSize, PropValue, PropSizeRet); // Find the info if it exists under any of the given names - auto GetInfo = [&](std::vector Names) { + auto GetInfoString = [&](std::vector Names) { if (Device == HostDevice()) - return std::string("Host"); + return "Host"; if (!Device->Device) - return std::string(""); + return ""; auto Info = Device->Device->obtainInfoImpl(); if (auto Err = Info.takeError()) - return std::string(""); + return ""; for (auto Name : Names) { if (auto Entry = Info->get(Name)) - return (*Entry)->Value; + return (*Entry)->Value.c_str(); } - return std::string(""); + return ""; + }; + auto GetInfoXyz = [&](std::vector Names) { + if (!Device->Device) + return ol_dimensions_t{0, 0, 0}; + + auto Info = Device->Device->obtainInfoImpl(); + if (auto Err = Info.takeError()) + return ol_dimensions_t{0, 0, 0}; + + for (auto Name : Names) { + if (auto Entry = Info->get(Name)) { + auto Node = *Entry; + ol_dimensions_t Out{0, 0, 0}; + + if (auto X = Node->get("x")) + Out.x = stoi((*X)->Value); + if (auto Y = Node->get("y")) + Out.y = stoi((*Y)->Value); + if (auto Z = Node->get("z")) + Out.z = stoi((*Z)->Value); + return Out; + } + } + + return ol_dimensions_t{0, 0, 0}; }; switch (PropName) { @@ -254,12 +279,15 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST) : ReturnValue(OL_DEVICE_TYPE_GPU); case OL_DEVICE_INFO_NAME: - return ReturnValue(GetInfo({"Device Name"}).c_str()); + return ReturnValue(GetInfoString({"Device Name"})); case OL_DEVICE_INFO_VENDOR: - return ReturnValue(GetInfo({"Vendor Name"}).c_str()); + return ReturnValue(GetInfoString({"Vendor Name"})); case OL_DEVICE_INFO_DRIVER_VERSION: return ReturnValue( - GetInfo({"CUDA Driver Version", "HSA Runtime Version"}).c_str()); + GetInfoString({"CUDA Driver Version", "HSA Runtime Version"})); + case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: + return ReturnValue(GetInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/, + "Maximum Block Dimensions" /*CUDA*/})); default: return createOffloadError(ErrorCode::INVALID_ENUMERATION, "getDeviceInfo enum '%i' is invalid", PropName); diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp index a964ff09d0f6e..d1189688a90a3 100644 --- a/offload/tools/offload-tblgen/PrintGen.cpp +++ b/offload/tools/offload-tblgen/PrintGen.cpp @@ -213,6 +213,11 @@ template inline void printTagged(llvm::raw_ostream &os, const void "enum {0} value);\n", EnumRec{R}.getName()); } + for (auto *R : Records.getAllDerivedDefinitions("Struct")) { + OS << formatv("inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, " + "const struct {0} param);\n", + StructRec{R}.getName()); + } OS << "\n"; // Create definitions diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp index 0247744911eaa..c534c45205993 100644 --- a/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp +++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp @@ -77,6 +77,15 @@ TEST_P(olGetDeviceInfoTest, SuccessDriverVersion) { ASSERT_EQ(std::strlen(DriverVersion.data()), Size - 1); } +TEST_P(olGetDeviceInfoTest, SuccessMaxWorkGroupSize) { + ol_dimensions_t Value{0, 0, 0}; + ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE, + sizeof(Value), &Value)); + ASSERT_GT(Value.x, 0u); + ASSERT_GT(Value.y, 0u); + ASSERT_GT(Value.z, 0u); +} + TEST_P(olGetDeviceInfoTest, InvalidNullHandleDevice) { ol_device_type_t DeviceType; ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, diff --git a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp index edd2704a722dd..a908078a25211 100644 --- a/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp +++ b/offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp @@ -44,6 +44,14 @@ TEST_P(olGetDeviceInfoSizeTest, SuccessDriverVersion) { ASSERT_NE(Size, 0ul); } +TEST_P(olGetDeviceInfoSizeTest, SuccessMaxWorkGroupSize) { + size_t Size = 0; + ASSERT_SUCCESS( + olGetDeviceInfoSize(Device, OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE, &Size)); + ASSERT_EQ(Size, sizeof(ol_dimensions_t)); + ASSERT_EQ(Size, sizeof(uint32_t) * 3); +} + TEST_P(olGetDeviceInfoSizeTest, InvalidNullHandle) { size_t Size = 0; ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,