Skip to content

Commit e7101da

Browse files
authored
[Offload] Copy loaded images into managed storage (#158748)
Summary: Currently we have this `__tgt_device_image` indirection which just takes a reference to some pointers. This was all find and good when the only usage of this was from a section of GPU code that came from an ELF constant section. However, we have expanded beyond that and now need to worry about managing lifetimes. We have code that references the image even after it was loaded internally. This patch changes the implementation to instaed copy the memory buffer and manage it locally. This PR reworks the JIT and other image handling to directly manage its own memory. We now don't need to duplicate this behavior externally at the Offload API level. Also we actually free these if the user unloads them. Upside, less likely to crash and burn. Downside, more latency when loading an image.
1 parent 311d78f commit e7101da

File tree

8 files changed

+89
-198
lines changed

8 files changed

+89
-198
lines changed

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,11 @@ struct ol_event_impl_t {
157157

158158
struct ol_program_impl_t {
159159
ol_program_impl_t(plugin::DeviceImageTy *Image,
160-
std::unique_ptr<llvm::MemoryBuffer> ImageData,
161-
const __tgt_device_image &DeviceImage)
162-
: Image(Image), ImageData(std::move(ImageData)),
163-
DeviceImage(DeviceImage) {}
160+
llvm::MemoryBufferRef DeviceImage)
161+
: Image(Image), DeviceImage(DeviceImage) {}
164162
plugin::DeviceImageTy *Image;
165-
std::unique_ptr<llvm::MemoryBuffer> ImageData;
166163
std::mutex SymbolListMutex;
167-
__tgt_device_image DeviceImage;
164+
llvm::MemoryBufferRef DeviceImage;
168165
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
169166
llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
170167
};
@@ -891,28 +888,14 @@ Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
891888
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
892889
size_t ProgDataSize, ol_program_handle_t *Program) {
893890
// Make a copy of the program binary in case it is released by the caller.
894-
auto ImageData = MemoryBuffer::getMemBufferCopy(
895-
StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize));
896-
897-
auto DeviceImage = __tgt_device_image{
898-
const_cast<char *>(ImageData->getBuffer().data()),
899-
const_cast<char *>(ImageData->getBuffer().data()) + ProgDataSize, nullptr,
900-
nullptr};
901-
902-
ol_program_handle_t Prog =
903-
new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage);
904-
905-
auto Res =
906-
Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage);
907-
if (!Res) {
908-
delete Prog;
891+
StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
892+
Expected<plugin::DeviceImageTy *> Res =
893+
Device->Device->loadBinary(Device->Device->Plugin, Buffer);
894+
if (!Res)
909895
return Res.takeError();
910-
}
911-
assert(*Res != nullptr && "loadBinary returned nullptr");
912-
913-
Prog->Image = *Res;
914-
*Program = Prog;
896+
assert(*Res && "loadBinary returned nullptr");
915897

898+
*Program = new ol_program_impl_t(*Res, (*Res)->getMemoryBuffer());
916899
return Error::success();
917900
}
918901

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -464,8 +464,8 @@ struct AMDGPUMemoryManagerTy : public DeviceAllocatorTy {
464464
struct AMDGPUDeviceImageTy : public DeviceImageTy {
465465
/// Create the AMDGPU image with the id and the target image pointer.
466466
AMDGPUDeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
467-
const __tgt_device_image *TgtImage)
468-
: DeviceImageTy(ImageId, Device, TgtImage) {}
467+
std::unique_ptr<MemoryBuffer> &&TgtImage)
468+
: DeviceImageTy(ImageId, Device, std::move(TgtImage)) {}
469469

470470
/// Prepare and load the executable corresponding to the image.
471471
Error loadExecutable(const AMDGPUDeviceTy &Device);
@@ -2160,7 +2160,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
21602160
AMDGPUDeviceImageTy &AMDImage = static_cast<AMDGPUDeviceImageTy &>(*Image);
21612161

21622162
// Unload the executable of the image.
2163-
return AMDImage.unloadExecutable();
2163+
if (auto Err = AMDImage.unloadExecutable())
2164+
return Err;
2165+
2166+
// Destroy the associated memory and invalidate the object.
2167+
Plugin.free(Image);
2168+
return Error::success();
21642169
}
21652170

21662171
/// Deinitialize the device and release its resources.
@@ -2183,18 +2188,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
21832188

21842189
virtual Error callGlobalConstructors(GenericPluginTy &Plugin,
21852190
DeviceImageTy &Image) override {
2186-
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
2187-
if (Handler.isSymbolInImage(*this, Image, "amdgcn.device.fini"))
2188-
Image.setPendingGlobalDtors();
2189-
21902191
return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/true);
21912192
}
21922193

21932194
virtual Error callGlobalDestructors(GenericPluginTy &Plugin,
21942195
DeviceImageTy &Image) override {
2195-
if (Image.hasPendingGlobalDtors())
2196-
return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
2197-
return Plugin::success();
2196+
return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
21982197
}
21992198

22002199
uint64_t getStreamBusyWaitMicroseconds() const { return OMPX_StreamBusyWait; }
@@ -2321,11 +2320,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
23212320
}
23222321

23232322
/// Load the binary image into the device and allocate an image object.
2324-
Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
2325-
int32_t ImageId) override {
2323+
Expected<DeviceImageTy *>
2324+
loadBinaryImpl(std::unique_ptr<MemoryBuffer> &&TgtImage,
2325+
int32_t ImageId) override {
23262326
// Allocate and initialize the image object.
23272327
AMDGPUDeviceImageTy *AMDImage = Plugin.allocate<AMDGPUDeviceImageTy>();
2328-
new (AMDImage) AMDGPUDeviceImageTy(ImageId, *this, TgtImage);
2328+
new (AMDImage) AMDGPUDeviceImageTy(ImageId, *this, std::move(TgtImage));
23292329

23302330
// Load the HSA executable.
23312331
if (Error Err = AMDImage->loadExecutable(*this))
@@ -3105,7 +3105,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
31053105
// Perform a quick check for the named kernel in the image. The kernel
31063106
// should be created by the 'amdgpu-lower-ctor-dtor' pass.
31073107
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
3108-
if (IsCtor && !Handler.isSymbolInImage(*this, Image, KernelName))
3108+
if (!Handler.isSymbolInImage(*this, Image, KernelName))
31093109
return Plugin::success();
31103110

31113111
// Allocate and construct the AMDGPU kernel.

offload/plugins-nextgen/common/include/JIT.h

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,27 +51,22 @@ struct JITEngine {
5151
/// Run jit compilation if \p Image is a bitcode image, otherwise simply
5252
/// return \p Image. It is expected to return a memory buffer containing the
5353
/// generated device image that could be loaded to the device directly.
54-
Expected<const __tgt_device_image *>
55-
process(const __tgt_device_image &Image,
56-
target::plugin::GenericDeviceTy &Device);
57-
58-
/// Remove \p Image from the jit engine's cache
59-
void erase(const __tgt_device_image &Image,
60-
target::plugin::GenericDeviceTy &Device);
54+
Expected<std::unique_ptr<MemoryBuffer>>
55+
process(StringRef Image, target::plugin::GenericDeviceTy &Device);
6156

6257
private:
6358
/// Compile the bitcode image \p Image and generate the binary image that can
6459
/// be loaded to the target device of the triple \p Triple architecture \p
6560
/// MCpu. \p PostProcessing will be called after codegen to handle cases such
6661
/// as assembler as an external tool.
67-
Expected<const __tgt_device_image *>
68-
compile(const __tgt_device_image &Image, const std::string &ComputeUnitKind,
62+
Expected<std::unique_ptr<MemoryBuffer>>
63+
compile(StringRef Image, const std::string &ComputeUnitKind,
6964
PostProcessingFn PostProcessing);
7065

7166
/// Create or retrieve the object image file from the file system or via
7267
/// compilation of the \p Image.
7368
Expected<std::unique_ptr<MemoryBuffer>>
74-
getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
69+
getOrCreateObjFile(StringRef Image, LLVMContext &Ctx,
7570
const std::string &ComputeUnitKind);
7671

7772
/// Run backend, which contains optimization and code generation.
@@ -92,14 +87,6 @@ struct JITEngine {
9287
struct ComputeUnitInfo {
9388
/// LLVM Context in which the modules will be constructed.
9489
LLVMContext Context;
95-
96-
/// A map of embedded IR images to the buffer used to store JITed code
97-
DenseMap<const __tgt_device_image *, std::unique_ptr<MemoryBuffer>>
98-
JITImages;
99-
100-
/// A map of embedded IR images to JITed images.
101-
DenseMap<const __tgt_device_image *, std::unique_ptr<__tgt_device_image>>
102-
TgtImageMap;
10390
};
10491

10592
/// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -306,60 +306,36 @@ class DeviceImageTy {
306306
/// not unique between different device; they may overlap.
307307
int32_t ImageId;
308308

309-
/// The pointer to the raw __tgt_device_image.
310-
const __tgt_device_image *TgtImage;
311-
const __tgt_device_image *TgtImageBitcode;
309+
/// The managed image data.
310+
std::unique_ptr<MemoryBuffer> Image;
312311

313312
/// Reference to the device this image is loaded on.
314313
GenericDeviceTy &Device;
315314

316-
/// If this image has any global destructors that much be called.
317-
/// FIXME: This is only required because we currently have no invariants
318-
/// towards the lifetime of the underlying image. We should either copy
319-
/// the image into memory locally or erase the pointers after init.
320-
bool PendingGlobalDtors;
321-
322315
public:
316+
virtual ~DeviceImageTy() = default;
317+
323318
DeviceImageTy(int32_t Id, GenericDeviceTy &Device,
324-
const __tgt_device_image *Image)
325-
: ImageId(Id), TgtImage(Image), TgtImageBitcode(nullptr), Device(Device),
326-
PendingGlobalDtors(false) {
327-
assert(TgtImage && "Invalid target image");
328-
}
319+
std::unique_ptr<MemoryBuffer> &&Image)
320+
: ImageId(Id), Image(std::move(Image)), Device(Device) {}
329321

330322
/// Get the image identifier within the device.
331323
int32_t getId() const { return ImageId; }
332324

333325
/// Get the device that this image is loaded onto.
334326
GenericDeviceTy &getDevice() const { return Device; }
335327

336-
/// Get the pointer to the raw __tgt_device_image.
337-
const __tgt_device_image *getTgtImage() const { return TgtImage; }
338-
339-
void setTgtImageBitcode(const __tgt_device_image *TgtImageBitcode) {
340-
this->TgtImageBitcode = TgtImageBitcode;
341-
}
342-
343-
const __tgt_device_image *getTgtImageBitcode() const {
344-
return TgtImageBitcode;
345-
}
346-
347328
/// Get the image starting address.
348-
void *getStart() const { return TgtImage->ImageStart; }
329+
const void *getStart() const { return Image->getBufferStart(); }
349330

350331
/// Get the image size.
351-
size_t getSize() const {
352-
return utils::getPtrDiff(TgtImage->ImageEnd, TgtImage->ImageStart);
353-
}
332+
size_t getSize() const { return Image->getBufferSize(); }
354333

355334
/// Get a memory buffer reference to the whole image.
356335
MemoryBufferRef getMemoryBuffer() const {
357336
return MemoryBufferRef(StringRef((const char *)getStart(), getSize()),
358337
"Image");
359338
}
360-
/// Accessors to the boolean value
361-
bool setPendingGlobalDtors() { return PendingGlobalDtors = true; }
362-
bool hasPendingGlobalDtors() const { return PendingGlobalDtors; }
363339
};
364340

365341
/// Class implementing common functionalities of offload kernels. Each plugin
@@ -831,9 +807,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
831807

832808
/// Load the binary image into the device and return the target table.
833809
Expected<DeviceImageTy *> loadBinary(GenericPluginTy &Plugin,
834-
const __tgt_device_image *TgtImage);
810+
StringRef TgtImage);
835811
virtual Expected<DeviceImageTy *>
836-
loadBinaryImpl(const __tgt_device_image *TgtImage, int32_t ImageId) = 0;
812+
loadBinaryImpl(std::unique_ptr<MemoryBuffer> &&TgtImage, int32_t ImageId) = 0;
837813

838814
/// Unload a previously loaded Image from the device
839815
Error unloadBinary(DeviceImageTy *Image);

offload/plugins-nextgen/common/src/JIT.cpp

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,6 @@ using namespace omp::target;
4949

5050
namespace {
5151

52-
bool isImageBitcode(const __tgt_device_image &Image) {
53-
StringRef Binary(reinterpret_cast<const char *>(Image.ImageStart),
54-
utils::getPtrDiff(Image.ImageEnd, Image.ImageStart));
55-
56-
return identify_magic(Binary) == file_magic::bitcode;
57-
}
58-
5952
Expected<std::unique_ptr<Module>>
6053
createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
6154
LLVMContext &Context) {
@@ -66,12 +59,10 @@ createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
6659
"failed to create module");
6760
return std::move(Mod);
6861
}
69-
Expected<std::unique_ptr<Module>>
70-
createModuleFromImage(const __tgt_device_image &Image, LLVMContext &Context) {
71-
StringRef Data((const char *)Image.ImageStart,
72-
utils::getPtrDiff(Image.ImageEnd, Image.ImageStart));
62+
Expected<std::unique_ptr<Module>> createModuleFromImage(StringRef Image,
63+
LLVMContext &Context) {
7364
std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
74-
Data, /*BufferName=*/"", /*RequiresNullTerminator=*/false);
65+
Image, /*BufferName=*/"", /*RequiresNullTerminator=*/false);
7566
return createModuleFromMemoryBuffer(MB, Context);
7667
}
7768

@@ -238,7 +229,7 @@ JITEngine::backend(Module &M, const std::string &ComputeUnitKind,
238229
}
239230

240231
Expected<std::unique_ptr<MemoryBuffer>>
241-
JITEngine::getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
232+
JITEngine::getOrCreateObjFile(StringRef Image, LLVMContext &Ctx,
242233
const std::string &ComputeUnitKind) {
243234

244235
// Check if the user replaces the module at runtime with a finished object.
@@ -277,58 +268,28 @@ JITEngine::getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
277268
return backend(*Mod, ComputeUnitKind, JITOptLevel);
278269
}
279270

280-
Expected<const __tgt_device_image *>
281-
JITEngine::compile(const __tgt_device_image &Image,
282-
const std::string &ComputeUnitKind,
271+
Expected<std::unique_ptr<MemoryBuffer>>
272+
JITEngine::compile(StringRef Image, const std::string &ComputeUnitKind,
283273
PostProcessingFn PostProcessing) {
284274
std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
285275

286-
// Check if we JITed this image for the given compute unit kind before.
287-
ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
288-
if (CUI.TgtImageMap.contains(&Image))
289-
return CUI.TgtImageMap[&Image].get();
290-
291-
auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind);
276+
LLVMContext Ctz;
277+
auto ObjMBOrErr = getOrCreateObjFile(Image, Ctz, ComputeUnitKind);
292278
if (!ObjMBOrErr)
293279
return ObjMBOrErr.takeError();
294280

295-
auto ImageMBOrErr = PostProcessing(std::move(*ObjMBOrErr));
296-
if (!ImageMBOrErr)
297-
return ImageMBOrErr.takeError();
298-
299-
CUI.JITImages.insert({&Image, std::move(*ImageMBOrErr)});
300-
auto &ImageMB = CUI.JITImages[&Image];
301-
CUI.TgtImageMap.insert({&Image, std::make_unique<__tgt_device_image>()});
302-
auto &JITedImage = CUI.TgtImageMap[&Image];
303-
*JITedImage = Image;
304-
JITedImage->ImageStart = const_cast<char *>(ImageMB->getBufferStart());
305-
JITedImage->ImageEnd = const_cast<char *>(ImageMB->getBufferEnd());
306-
307-
return JITedImage.get();
281+
return PostProcessing(std::move(*ObjMBOrErr));
308282
}
309283

310-
Expected<const __tgt_device_image *>
311-
JITEngine::process(const __tgt_device_image &Image,
312-
target::plugin::GenericDeviceTy &Device) {
313-
const std::string &ComputeUnitKind = Device.getComputeUnitKind();
284+
Expected<std::unique_ptr<MemoryBuffer>>
285+
JITEngine::process(StringRef Image, target::plugin::GenericDeviceTy &Device) {
286+
assert(identify_magic(Image) == file_magic::bitcode && "Image not LLVM-IR");
314287

288+
const std::string &ComputeUnitKind = Device.getComputeUnitKind();
315289
PostProcessingFn PostProcessing = [&Device](std::unique_ptr<MemoryBuffer> MB)
316290
-> Expected<std::unique_ptr<MemoryBuffer>> {
317291
return Device.doJITPostProcessing(std::move(MB));
318292
};
319293

320-
if (isImageBitcode(Image))
321-
return compile(Image, ComputeUnitKind, PostProcessing);
322-
323-
return &Image;
324-
}
325-
326-
void JITEngine::erase(const __tgt_device_image &Image,
327-
target::plugin::GenericDeviceTy &Device) {
328-
std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
329-
const std::string &ComputeUnitKind = Device.getComputeUnitKind();
330-
ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
331-
332-
CUI.TgtImageMap.erase(&Image);
333-
CUI.JITImages.erase(&Image);
294+
return compile(Image, ComputeUnitKind, PostProcessing);
334295
}

0 commit comments

Comments
 (0)