From 7b7faf97f1fd534dcbadb47ab58bcccc309950df Mon Sep 17 00:00:00 2001 From: Garret Catron Date: Mon, 5 Nov 2018 13:21:24 -0800 Subject: [PATCH] Moved OpenCL memory allocation to runtime. Created new BackendUtils library and put collectConstants there.. Added helper function to retrieve symbol offset by value from symbolTable --- include/glow/Backends/BackendUtils.h | 39 +++ lib/Backends/BackendUtils.cpp | 53 ++++ lib/Backends/CMakeLists.txt | 7 + lib/Backends/OpenCL/OpenCL.cpp | 326 ++++++++++++------------ lib/Backends/OpenCL/OpenCL.h | 43 ++-- lib/ExecutionEngine/ExecutionEngine.cpp | 5 +- 6 files changed, 286 insertions(+), 187 deletions(-) create mode 100644 include/glow/Backends/BackendUtils.h create mode 100644 lib/Backends/BackendUtils.cpp diff --git a/include/glow/Backends/BackendUtils.h b/include/glow/Backends/BackendUtils.h new file mode 100644 index 0000000000..404a6ae4e1 --- /dev/null +++ b/include/glow/Backends/BackendUtils.h @@ -0,0 +1,39 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef GLOW_BACKENDS_BACKENDUTILS_H +#define GLOW_BACKENDS_BACKENDUTILS_H + +#include "glow/Backends/CompiledFunction.h" +#include "glow/IR/IR.h" + +namespace glow { +/// At compile time condense constants to a single block of memory. +/// This allows the graph to go away after compile time. +/// Allocates a block of memory of size \p constantMaxSize then walks the given +/// function \p F and and copies weights to their address as specified by +/// offsets contained in \p symbolTable. +uint8_t *collectConstants( + const IRFunction *F, uint64_t constantMaxSize, + const std::unordered_map + &symbolTable); +/// Helper function to retrieve offset for Value: \p v from \p symbolTable. +size_t +getValueOffset(Value *v, + const std::unordered_map + &symbolTable); +} // end namespace glow + +#endif // GLOW_BACKENDS_BACKENDUTILS_H diff --git a/lib/Backends/BackendUtils.cpp b/lib/Backends/BackendUtils.cpp new file mode 100644 index 0000000000..80fbb1628c --- /dev/null +++ b/lib/Backends/BackendUtils.cpp @@ -0,0 +1,53 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "glow/Backends/BackendUtils.h" +#include "glow/IR/Instrs.h" + +using namespace glow; +using llvm::cast; +using llvm::isa; + +uint8_t *glow::collectConstants( + const IRFunction *F, uint64_t constantMaxSize, + const std::unordered_map + &symbolTable) { + + // At compile time condense constants to a single block of memory. + // This allows the graph to go away after compile time. + uint8_t *baseConstantWeightVarsStore = + (uint8_t *)alignedAlloc(constantMaxSize, TensorAlignment); + for (auto &v : F->getGraph()->getParent()->getConstants()) { + assert(isa(F->getWeightForNode(v))); + auto *w = cast(F->getWeightForNode(v)); + auto payload = v->getPayload().getUnsafePtr(); + auto numBytes = w->getSizeInBytes(); + auto it = symbolTable.find(std::string(w->getName())); + assert(it != symbolTable.end() && "Symbol not found."); + auto addr = it->second.offset; + // Copy weight to offset. + memcpy(baseConstantWeightVarsStore + addr, payload, numBytes); + } + return baseConstantWeightVarsStore; +} + +/// Helper function, gets offset of \p v from \p symbolTable. +size_t glow::getValueOffset( + Value *v, const std::unordered_map + &symbolTable) { + auto it = symbolTable.find(std::string(v->getName())); + assert(it != symbolTable.end() && "Symbol not found."); + return it->second.offset; +} diff --git a/lib/Backends/CMakeLists.txt b/lib/Backends/CMakeLists.txt index 6c39052e99..e39fe6e148 100644 --- a/lib/Backends/CMakeLists.txt +++ b/lib/Backends/CMakeLists.txt @@ -1,5 +1,11 @@ add_library(Backends Backends.cpp) +add_library(BackendUtils BackendUtils.cpp) + +target_link_libraries(BackendUtils + PRIVATE + IR) + add_subdirectory(Interpreter) if(GLOW_WITH_OPENCL) @@ -15,6 +21,7 @@ endif() target_link_libraries(Backends PRIVATE + BackendUtils Interpreter ${linked_backends} Base diff --git a/lib/Backends/OpenCL/OpenCL.cpp b/lib/Backends/OpenCL/OpenCL.cpp index bd19099775..8959c1dc62 100644 --- a/lib/Backends/OpenCL/OpenCL.cpp +++ b/lib/Backends/OpenCL/OpenCL.cpp @@ -20,6 +20,7 @@ #include "OpenCL.h" +#include "glow/Backends/BackendUtils.h" #include "glow/CodeGen/MemoryAllocator.h" #include "glow/Graph/Graph.h" #include "glow/Graph/Nodes.h" @@ -120,8 +121,8 @@ static void addStringOption(std::vector &options, } OpenCLFunction::OpenCLFunction(std::unique_ptr F, - const Context &ctx) - : F_(std::move(F)) { + const runtime::RuntimeBundle &bundle) + : F_(std::move(F)), bundle_(bundle) { cl_uint numPlatforms{0}; cl_int err = clGetPlatformIDs(0, NULL, &numPlatforms); GLOW_ASSERT(err == CL_SUCCESS && "clGetPlatformIDs Failed."); @@ -160,7 +161,6 @@ OpenCLFunction::OpenCLFunction(std::unique_ptr F, std::string source(reinterpret_cast(kernels_cl_src), kernels_cl_src_size); createProgram(source, options, commands_); - allocateMemory(ctx); } OpenCLFunction::~OpenCLFunction() { @@ -174,7 +174,6 @@ OpenCLFunction::~OpenCLFunction() { freeDeviceBuffer(deviceBuffer_); deviceBuffer_ = nullptr; } - externalTensors_.clear(); } static std::string getKernelName(const char *baseName, ElemKind elemTy) { @@ -262,14 +261,13 @@ void setKernelArg(cl_kernel kernel, unsigned argIdx, T value) { /// Set OpenCL \p kernel arguments using the buffer operands of the /// instruction \p I. The first of these arguments should be passed to the \p -/// kernel at index \p nextKernelArgIdx. The \p tensors map provides a mapping -/// from Values to on-device buffers addresses of these values. +/// kernel at index \p nextKernelArgIdx. The \p bundle provides symbolTable, a +/// mapping from Values to on-device buffer offsets of these values. /// /// \returns the index of the last set OpenCL kernel argument. -static size_t -setKernelArgsForBuffers(cl_kernel kernel, const Instruction &I, - size_t nextKernelArgIdx, - std::unordered_map &tensors) { +static size_t setKernelArgsForBuffers(cl_kernel kernel, const Instruction &I, + size_t nextKernelArgIdx, + runtime::RuntimeBundle &bundle) { // Number of instruction operands. auto numArgs = I.getNumOperands(); // The predicate of the instruction if available. @@ -284,7 +282,8 @@ setKernelArgsForBuffers(cl_kernel kernel, const Instruction &I, if (value == predicate) continue; // The value is a buffer that should be passed as a kernel argument. - setKernelArg(kernel, kernelArgIdx, tensors[value]); + setKernelArg(kernel, kernelArgIdx, + getValueOffset(value, bundle.symbolTable)); kernelArgIdx++; } return kernelArgIdx - 1; @@ -540,10 +539,11 @@ void OpenCLFunction::executeConvolution(const OCLConvolutionInst *CC) { auto kernelName = isQuantized ? "conv_forward_mem_i8" : "conv_forward_mem"; auto kernel = createKernel(kernelName, prog); setKernelArg(kernel, 0, deviceBuffer_); - setKernelArg(kernel, 1, tensors_[input]); - setKernelArg(kernel, 2, tensors_[weights]); - setKernelArg(kernel, 3, tensors_[bias]); - setKernelArg(kernel, 4, tensors_[output]); + setKernelArg(kernel, 1, getValueOffset(input, bundle_.symbolTable)); + setKernelArg(kernel, 2, + getValueOffset(weights, bundle_.symbolTable)); + setKernelArg(kernel, 3, getValueOffset(bias, bundle_.symbolTable)); + setKernelArg(kernel, 4, getValueOffset(output, bundle_.symbolTable)); // Extra options for quantized kernel if (isQuantized) { @@ -619,11 +619,8 @@ static void topK(Tensor &outW, Tensor &indW, Tensor &inW, size_t k) { } } void OpenCLFunction::execute(Context &ctx) { - auto copiedToDeviceBytes = copyMutableWeightsToDevice(); - (void)copiedToDeviceBytes; - DEBUG_GLOW(llvm::dbgs() << "Copied " << copiedToDeviceBytes - << " bytes to OpenCL device\n"); - + copyConstantsToDevice(); + copyInputsToDevice(ctx); for (const auto &I : F_->getInstrs()) { // The kernels are named after the name of the instruction, plus the "W" // suffix to prevent name colissions for functions like 'tanh' that are also @@ -673,7 +670,7 @@ void OpenCLFunction::execute(Context &ctx) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto numMandatoryArgs = numArgs; (void)numMandatoryArgs; @@ -765,7 +762,7 @@ void OpenCLFunction::execute(Context &ctx) { // the batch is processed by a different parallel 'thread'. cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); // This is the number of elements for each slice. There are N slices in // our batch. @@ -784,7 +781,7 @@ void OpenCLFunction::execute(Context &ctx) { // the batch is processed by a different parallel 'thread'. cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); // This is the number of elements for each slice. There are N slices in // our batch. @@ -801,7 +798,7 @@ void OpenCLFunction::execute(Context &ctx) { if (auto *ET = dyn_cast(&I)) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); // Currently support tensors up to 4 dimensions. // TODO: Handle other dimensions. @@ -841,7 +838,7 @@ void OpenCLFunction::execute(Context &ctx) { if (auto *IT = dyn_cast(&I)) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); // Currently support tensors of up to 4 dimensions. // TODO: Handle other dimensions. @@ -895,7 +892,7 @@ void OpenCLFunction::execute(Context &ctx) { cl_kernel kernel = createKernel(useTiledMatMul ? tiledKernelName : kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto ddim = ShapeNHWC::fromXY(BMM->getDest()->getType()->dims()); auto ldim = ShapeNHWC::fromXY(BMM->getLHS()->getType()->dims()); @@ -938,7 +935,7 @@ void OpenCLFunction::execute(Context &ctx) { } cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto bdim = flattenCdr(BA->getBatch()->dims()); setKernelArg(kernel, numArgs + 1, bdim.first); @@ -972,7 +969,7 @@ void OpenCLFunction::execute(Context &ctx) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto bdim = flattenCdr(BRA->getBatch()->dims()); setKernelArg(kernel, numArgs + 1, bdim.first); @@ -994,7 +991,7 @@ void OpenCLFunction::execute(Context &ctx) { // the X and the Y in the output filter. cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto odim = ShapeNHWC(CC->getDest()->getType()->dims()); auto idim = ShapeNHWC(CC->getSrc()->getType()->dims()); auto pads = PaddingTLBR(CC->getPads()); @@ -1039,7 +1036,7 @@ void OpenCLFunction::execute(Context &ctx) { auto *biasGrad = CG->getBiasGrad(); cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto destGradDim = ShapeNHWC(destGrad->dims()); auto srcDim = ShapeNHWC(src->dims()); @@ -1055,12 +1052,12 @@ void OpenCLFunction::execute(Context &ctx) { setKernelArg(kernel, numArgs + 6, destGradDim); setKernelArg(kernel, numArgs + 7, filterGradDim); // Zero memory for the output buffers. - fillBuffer(deviceBuffer_, tensors_[srcGrad], srcGrad->size(), 0, - srcGrad->getElementType()); - fillBuffer(deviceBuffer_, tensors_[filterGrad], filterGrad->size(), 0, - filterGrad->getElementType()); - fillBuffer(deviceBuffer_, tensors_[biasGrad], biasGrad->size(), 0, - biasGrad->getElementType()); + fillBuffer(deviceBuffer_, getValueOffset(srcGrad, bundle_.symbolTable), + srcGrad->size(), 0, srcGrad->getElementType()); + fillBuffer(deviceBuffer_, getValueOffset(filterGrad, bundle_.symbolTable), + filterGrad->size(), 0, filterGrad->getElementType()); + fillBuffer(deviceBuffer_, getValueOffset(biasGrad, bundle_.symbolTable), + biasGrad->size(), 0, biasGrad->getElementType()); (void)filter; assert(filter->dims() == filterGrad->dims() && "Dims should be the same"); @@ -1077,7 +1074,7 @@ void OpenCLFunction::execute(Context &ctx) { // the X and the Y in the output filter. cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto odim = ShapeNHWC(PM->getDest()->getType()->dims()); auto idim = ShapeNHWC(PM->getSrc()->getType()->dims()); @@ -1102,7 +1099,7 @@ void OpenCLFunction::execute(Context &ctx) { // the X and the Y in the output filter. cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto odim = ShapeNHWC(PM->getDest()->getType()->dims()); auto idim = ShapeNHWC(PM->getSrc()->getType()->dims()); @@ -1125,7 +1122,7 @@ void OpenCLFunction::execute(Context &ctx) { if (auto *PMG = dyn_cast(&I)) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto destGradDim = ShapeNHWC(PMG->getDestGrad()->dims()); auto srcGradDim = ShapeNHWC(PMG->getSrcGrad()->dims()); @@ -1153,7 +1150,7 @@ void OpenCLFunction::execute(Context &ctx) { // the X and the Y in the output filter. cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto odim = ShapeNHWC(PA->getDest()->getType()->dims()); auto idim = ShapeNHWC(PA->getSrc()->getType()->dims()); @@ -1181,7 +1178,7 @@ void OpenCLFunction::execute(Context &ctx) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); // Temporary hack to support 3-dim transposes. // TODO: support any dimensional transposes. @@ -1215,8 +1212,8 @@ void OpenCLFunction::execute(Context &ctx) { if (src == dest) { continue; } - size_t destOff = tensors_[dest]; - size_t srcOff = tensors_[src]; + size_t destOff = getValueOffset(dest, bundle_.symbolTable); + size_t srcOff = getValueOffset(src, bundle_.symbolTable); size_t sizeInBytes = dest->getSizeInBytes(); cl_event event{nullptr}; cl_int err = clEnqueueCopyBuffer(commands_, deviceBuffer_, deviceBuffer_, @@ -1232,7 +1229,7 @@ void OpenCLFunction::execute(Context &ctx) { if (auto *GI = dyn_cast(&I)) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); unsigned_t batchDims = GI->getBatchDims(); auto *data = GI->getData(); @@ -1268,7 +1265,7 @@ void OpenCLFunction::execute(Context &ctx) { if (auto *SAI = dyn_cast(&I)) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto *data = SAI->getData(); size_t dataSliceSize = data->size() / data->dims()[0]; @@ -1329,7 +1326,7 @@ void OpenCLFunction::execute(Context &ctx) { if (auto PA = dyn_cast(&I)) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto odim = ShapeNCHW(PA->getDest()->getType()->dims()); auto idim = ShapeNCHW(PA->getSrc()->getType()->dims()); @@ -1359,7 +1356,7 @@ void OpenCLFunction::execute(Context &ctx) { if (auto *PM = dyn_cast(&I)) { cl_kernel kernel = createKernel(kernelName); setKernelArg(kernel, 0, deviceBuffer_); - auto numArgs = setKernelArgsForBuffers(kernel, I, 1, tensors_); + auto numArgs = setKernelArgsForBuffers(kernel, I, 1, bundle_); auto odim = ShapeNCHW(PM->getDest()->getType()->dims()); auto idim = ShapeNCHW(PM->getSrc()->getType()->dims()); @@ -1389,26 +1386,17 @@ void OpenCLFunction::execute(Context &ctx) { clReleaseKernel(kl.kernel_); } kernelLaunches_.clear(); - - auto copiedFromDeviceBytes = copyMutableWeightsFromDevice(); - (void)copiedFromDeviceBytes; - DEBUG_GLOW(llvm::dbgs() << "Copied " << copiedFromDeviceBytes - << " bytes from OpenCL device\n"); + copyOutputsFromDevice(ctx); } uint64_t OpenCLFunction::copyValueToDevice(const Value *v, void *buf) { uint64_t copiedBytes = 0; - auto it = tensors_.find(v); - assert(it != tensors_.end() && "Unknown value"); + auto it = bundle_.symbolTable.find(std::string(v->getName())); + assert(it != bundle_.symbolTable.end() && "Unknown value"); size_t sizeInBytes = v->getType()->getSizeInBytes(); // Issue a non-blocking command to copy the buffer to the device. if (sizeInBytes) { - if (!buf) { - Tensor *T = externalTensors_[v]; - assert(T && "Expected an external tensor"); - buf = T->getUnsafePtr(); - } - size_t valueOffset = it->second; + size_t valueOffset = it->second.offset; cl_event event{nullptr}; cl_int err = clEnqueueWriteBuffer( commands_, deviceBuffer_, /* blocking_read */ CL_FALSE, valueOffset, @@ -1425,25 +1413,20 @@ uint64_t OpenCLFunction::copyValueToDevice(const Value *v, void *buf) { uint64_t OpenCLFunction::copyValueFromDevice(const Value *v, void *buf) { uint64_t copiedBytes = 0; - auto it = tensors_.find(v); - assert(it != tensors_.end() && "Unknown value"); + auto it = bundle_.symbolTable.find(std::string(v->getName())); + assert(it != bundle_.symbolTable.end() && "Unknown value"); size_t sizeInBytes = v->getType()->getSizeInBytes(); // Issue a non-blocking command to copy the buffer from the device. if (sizeInBytes) { - if (!buf) { - Tensor *T = externalTensors_[v]; - assert(T && "Expected an external tensor"); - buf = T->getUnsafePtr(); - } - size_t valueOffset = it->second; + size_t valueOffset = it->second.offset; cl_event event{nullptr}; cl_int err = clEnqueueReadBuffer( commands_, deviceBuffer_, /* blocking_read */ CL_FALSE, valueOffset, sizeInBytes, buf, /* num_events_in_wait_list */ 0, /* event_list */ nullptr, /* event */ doProfile ? &event : nullptr); GLOW_ASSERT(err == CL_SUCCESS && "Unable to copy from the device"); - DEBUG_GLOW(llvm::dbgs() << "Copied the value from device: " - << it->first->getName() << "\n"); + DEBUG_GLOW(llvm::dbgs() + << "Copied the value from device: " << v->getName() << "\n"); if (doProfile) { kernelLaunches_.emplace_back(KernelLaunch("copyFromDevice", event)); } @@ -1452,99 +1435,130 @@ uint64_t OpenCLFunction::copyValueFromDevice(const Value *v, void *buf) { return copiedBytes; } -uint64_t OpenCLFunction::copyMutableWeightsToDevice() { - uint64_t copiedBytes = 0; - for (auto it : tensors_) { - if (!externalTensors_.count(it.first)) { - continue; - } - if (auto *W = dyn_cast(it.first)) { - if (W->getMutability() == WeightVar::MutabilityKind::Constant) - continue; - } - copiedBytes += copyValueToDevice(it.first); +void OpenCLFunction::copyConstantsToDevice() { + deviceBuffer_ = allocDeviceBuffer(bundle_.constantWeightVarsMemSize + + bundle_.mutableWeightVarsMemSize + + bundle_.activationsMemSize); + size_t sizeInBytes = bundle_.constantWeightVarsMemSize; + // Issue a non-blocking command to copy the buffer to the device. + auto buf = bundle_.constants; + size_t valueOffset = 0; + cl_event event{nullptr}; + cl_int err = clEnqueueWriteBuffer( + commands_, deviceBuffer_, /* blocking_read */ CL_FALSE, valueOffset, + sizeInBytes, buf, /* num_events_in_wait_list */ 0, + /* event_list */ nullptr, /* event */ doProfile ? &event : nullptr); + GLOW_ASSERT(err == CL_SUCCESS && "Unable to copy data to the device"); + if (doProfile) { + kernelLaunches_.emplace_back(KernelLaunch("copyToDevice", event)); } // Do it! clFinish(commands_); - return copiedBytes; } -uint64_t OpenCLFunction::copyConstantWeightsToDevice() { - uint64_t copiedBytes = 0; - for (auto it : tensors_) { - if (!externalTensors_.count(it.first)) { - continue; - } - if (auto *W = dyn_cast(it.first)) { - if (W->getMutability() != WeightVar::MutabilityKind::Constant) - continue; +void OpenCLFunction::copyInputsToDevice(const Context &ctx) { + for (auto PH : ctx.pairs()) { + auto symbolInfo = + bundle_.symbolTable.find(std::string(PH.first->getName())); + assert(symbolInfo != bundle_.symbolTable.end() && "Symbol not found"); + auto addr = symbolInfo->second.offset; + auto numBytes = symbolInfo->second.size; + // Issue a non-blocking command to copy the buffer to the device. + auto buf = PH.second->getUnsafePtr(); + cl_event event{nullptr}; + cl_int err = clEnqueueWriteBuffer( + commands_, deviceBuffer_, /* blocking_read */ CL_FALSE, addr, numBytes, + buf, /* num_events_in_wait_list */ 0, + /* event_list */ nullptr, /* event */ doProfile ? &event : nullptr); + GLOW_ASSERT(err == CL_SUCCESS && "Unable to copy data to the device"); + if (doProfile) { + kernelLaunches_.emplace_back(KernelLaunch("copyToDevice", event)); } - copiedBytes += copyValueToDevice(it.first); } // Do it! clFinish(commands_); - return copiedBytes; } -uint64_t OpenCLFunction::copyMutableWeightsFromDevice() { - size_t copiedBytes = 0; - clFinish(commands_); - - for (auto it : tensors_) { - if (!externalTensors_.count(it.first)) { - continue; - } - if (auto *W = dyn_cast(it.first)) { - if (W->getMutability() == WeightVar::MutabilityKind::Constant) - continue; +void OpenCLFunction::copyOutputsFromDevice(const Context &ctx) { + for (auto PH : ctx.pairs()) { + auto symbolInfo = + bundle_.symbolTable.find(std::string(PH.first->getName())); + assert(symbolInfo != bundle_.symbolTable.end() && "Symbol not found"); + auto addr = symbolInfo->second.offset; + auto numBytes = symbolInfo->second.size; + // Issue a non-blocking command to copy the buffer to the device. + auto buf = PH.second->getUnsafePtr(); + cl_event event{nullptr}; + cl_int err = clEnqueueReadBuffer( + commands_, deviceBuffer_, /* blocking_read */ CL_FALSE, addr, numBytes, + buf, /* num_events_in_wait_list */ 0, + /* event_list */ nullptr, /* event */ doProfile ? &event : nullptr); + GLOW_ASSERT(err == CL_SUCCESS && "Unable to copy data to the device"); + if (doProfile) { + kernelLaunches_.emplace_back(KernelLaunch("copyToDevice", event)); } - copiedBytes += copyValueFromDevice(it.first); } + // Do it! clFinish(commands_); - return copiedBytes; } -void OpenCLFunction::allocateMemory(const Context &ctx) { - // The allocator assigns device memory addresses to the buffers. +/// Computes offsets and total allocation for Constants, Placeholders, and +/// Activations to build runtime symbol table. Returns +/// RuntimeBundle. +runtime::RuntimeBundle generateRuntimeBundle(const IRFunction *F) { + // Use a single allocator. The OpenCL backend uses a single buffer on the card + // for Constants, Placeholders, and Activations, in that order. MemoryAllocator allocator("GPU", 0xFFFFFFFF); - // Register the bound locations of the variables. - for (auto &v : F_->getGraph()->getParent()->getConstants()) { - auto *w = F_->getWeightForNode(v); - assert(!externalTensors_.count(w) && "The tensor is already registered"); - externalTensors_[w] = &v->getPayload(); + /// Symbol table mapping symbol name to offset for runtime. + std::unordered_map symbolTable; + + // Compute the offsets for Constants. + for (auto &v : F->getGraph()->getParent()->getConstants()) { + assert(isa(F->getWeightForNode(v)) && "Expected WeightVar"); + auto *w = cast(F->getWeightForNode(v)); + auto numBytes = w->getSizeInBytes(); + size_t addr = allocator.allocate(numBytes, v); + runtime::RuntimeSymbolInfo symbol; + symbol.size = numBytes; + symbol.offset = addr; + symbolTable.emplace(std::string(v->getName()), symbol); } - - // Register the bound locations of the placeholders. - for (auto PH : ctx.pairs()) { - auto *w = F_->getWeightForNode(PH.first); - assert(!externalTensors_.count(w) && "The tensor is already registered"); - externalTensors_[w] = PH.second; + uint64_t constantMaxSize = allocator.getMaxMemoryUsage(); + + // Compute the offsets for Placeholders. + for (auto &v : F->getGraph()->getParent()->getPlaceholders()) { + // Get the WeightVar for each Placeholder to calculate offsets. + assert(isa(F->getWeightForNode(v)) && "Expected WeightVar"); + auto *w = cast(F->getWeightForNode(v)); + auto numBytes = w->getSizeInBytes(); + size_t addr = allocator.allocate(numBytes, w); + runtime::RuntimeSymbolInfo symbol; + symbol.offset = addr; + symbol.size = numBytes; + symbolTable.emplace(std::string(v->getName()), symbol); } - // Assign device-space addresses to the weights. - for (auto it : externalTensors_) { - Tensor *T = it.second; - size_t sizeInBytes = T->getType().getSizeInBytes(); - size_t addr = allocator.allocate(sizeInBytes, it.first); - // Associate the new buffer with the weight value. - tensors_[it.first] = addr; - } + uint64_t placeholderMaxSize = allocator.getMaxMemoryUsage() - constantMaxSize; - // Assign device-space addresses to the activations. - for (const auto &I : F_->getInstrs()) { + for (const auto &I : F->getInstrs()) { if (auto *A = llvm::dyn_cast(&I)) { auto numBytes = I.getSizeInBytes(); size_t addr = allocator.allocate(numBytes, A); - assert(!tensors_.count(A) && "Allocation already made!"); - tensors_[A] = addr; + assert(!symbolTable.count(std::string(A->getName())) && + "Allocation already made!"); + runtime::RuntimeSymbolInfo symbol; + symbol.offset = addr; + symbol.size = numBytes; + symbolTable.emplace(std::string(A->getName()), symbol); continue; } if (auto *TV = llvm::dyn_cast(&I)) { // Calculate and store the length of the offset into the base, using the // source of the tensorview. - assert(!tensors_.count(TV) && "Allocation already made!"); + assert(!symbolTable.count(std::string(TV->getName())) && + "Allocation already made!"); size_t offsetLength = TV->getOffsets().empty() ? 0 : TV->getOffsets()[0]; auto *tvSource = TV->getSrc(); if (tvSource->dims().size() > 1) { @@ -1552,41 +1566,31 @@ void OpenCLFunction::allocateMemory(const Context &ctx) { offsetLength *= tvSource->dims()[i]; } } - assert(tensors_.count(tvSource) && "Source allocation not found!"); - tensors_[TV] = - tensors_[tvSource] + (offsetLength * TV->getType()->getElementSize()); + assert(symbolTable.count(std::string(tvSource->getName())) && + "Source allocation not found!"); + runtime::RuntimeSymbolInfo symbol; + symbol.offset = getValueOffset(tvSource, symbolTable) + + (offsetLength * TV->getType()->getElementSize()); + symbol.size = TV->getSizeInBytes(); + symbolTable.emplace(std::string(TV->getName()), symbol); continue; } if (auto *D = llvm::dyn_cast(&I)) { auto *A = D->getAlloc(); - assert(tensors_.count(A) && "Invalid deallocation!"); + assert(symbolTable.count(std::string(A->getName())) && + "Invalid deallocation!"); allocator.deallocate(A); continue; } } - - // Ask the memory allocator how much memory is required. What was the high - // watermark for this program. - uint64_t requiredSpace = allocator.getMaxMemoryUsage(); - DEBUG_GLOW(llvm::dbgs() << "Allocated GPU memory block of size: " - << requiredSpace << "\n"); - - // Release the memory from the previous run. - if (deviceBuffer_) { - freeDeviceBuffer(deviceBuffer_); - deviceBuffer_ = nullptr; - } - - deviceBuffer_ = allocDeviceBuffer(requiredSpace); - // Copy constant weights just once. - copyConstantWeightsToDevice(); -} - -Tensor *OpenCLFunction::getTensor(const Value *v) const { - assert(externalTensors_.count(v) && "Unknown value"); - auto ie = externalTensors_.find(v); - return ie->second; + uint64_t activationsMaxSize = + allocator.getMaxMemoryUsage() - placeholderMaxSize - constantMaxSize; + runtime::RuntimeBundle info(constantMaxSize, placeholderMaxSize, + activationsMaxSize); + info.symbolTable = std::move(symbolTable); + info.constants = collectConstants(F, constantMaxSize, info.symbolTable); + return info; } cl_mem OpenCLFunction::allocDeviceBuffer(uint64_t size) { @@ -1602,9 +1606,9 @@ cl_mem OpenCLFunction::allocDeviceBuffer(uint64_t size) { void OpenCLFunction::freeDeviceBuffer(cl_mem buf) { clReleaseMemObject(buf); } std::unique_ptr -OCLBackend::compileIR(std::unique_ptr IR, - const Context &ctx) const { - return llvm::make_unique(std::move(IR), ctx); +OCLBackend::compileIR(std::unique_ptr IR, const Context &) const { + runtime::RuntimeBundle bundle = generateRuntimeBundle(IR.get()); + return llvm::make_unique(std::move(IR), bundle); } std::unique_ptr diff --git a/lib/Backends/OpenCL/OpenCL.h b/lib/Backends/OpenCL/OpenCL.h index 8e42976496..9f0dea327c 100644 --- a/lib/Backends/OpenCL/OpenCL.h +++ b/lib/Backends/OpenCL/OpenCL.h @@ -22,6 +22,7 @@ #include "glow/Base/Traits.h" #include "glow/Graph/Context.h" #include "glow/Graph/Node.h" +#include "glow/IR/IR.h" #include "llvm/ADT/ArrayRef.h" #include @@ -34,7 +35,6 @@ namespace glow { -class IRFunction; class OCLConvolutionInst; class Value; @@ -70,11 +70,6 @@ class OpenCLFunction final : public CompiledFunction { }; /// The IR to be executed. std::unique_ptr F_; - /// Maps values to on-device buffers. This list includes both weights and - /// activations. - std::unordered_map tensors_; - /// Maps values to Tensors, that are *not* owned by this class. - std::unordered_map externalTensors_; /// CL compute device id. cl_device_id deviceId_; /// CL compute context. @@ -90,10 +85,13 @@ class OpenCLFunction final : public CompiledFunction { cl_mem deviceBuffer_{0}; /// Information about kernel launches. std::vector kernelLaunches_; + /// Runtime bundle that contains symbol offsets and constants. + runtime::RuntimeBundle bundle_; public: /// Ctor. - explicit OpenCLFunction(std::unique_ptr F, const Context &ctx); + explicit OpenCLFunction(std::unique_ptr F, + const runtime::RuntimeBundle &bundle); /// @name CompiledFunction interface ///@{ @@ -101,28 +99,28 @@ class OpenCLFunction final : public CompiledFunction { void execute(Context &ctx) override; ///@} + /// Allocates on device buffer and copies Constant weights to device. + void copyConstantsToDevice(); + /// Copies Inputs from \p ctx to on device memory. + void copyInputsToDevice(const Context &ctx); + /// Copies outputs from device to tensors in \p ctx. + void copyOutputsFromDevice(const Context &ctx); + /// Copy Function to device, an empty function for OpenCL. + void copyFunctionToDevice(){}; + /// Allocate Mutable buffers on device, this is an empty function on OpenCL + /// because the OCL backend uses a single buffer which is allocated when + /// constants are copied to the device. + void allocateMutableBuffersOnDevice(){}; + /// Frees runtime allocations. This is an empty function for OCL. + void freeAllocations(){}; private: - /// Allocate memory for the tensors. - void allocateMemory(const Context &ctx); /// Copy the value from a device to a provided buffer. - /// If \p buf is nullptr, the payload of the underlying tensor is used. /// \returns number of copied bytes. uint64_t copyValueFromDevice(const Value *v, void *buf = nullptr); /// Copy value from the provided buffer to the device. - /// If \p buf is nullptr, the payload of the underlying tensor is used. /// \returns number of copied bytes. uint64_t copyValueToDevice(const Value *v, void *buf = nullptr); - /// Copy mutable weights to the device. - /// \returns number of copied bytes. - uint64_t copyMutableWeightsToDevice(); - /// Copy constant weights to the device. - /// \returns number of copied bytes. - uint64_t copyConstantWeightsToDevice(); - /// Copy mutable weights from the device. - /// \returns number of copied bytes. - uint64_t copyMutableWeightsFromDevice(); - /// Fill the device \p buffer with a given \p value. /// \param len number of buffer elements to be filled by the \p value. /// Elements are considered to be of the type described by \p elemKind. @@ -155,9 +153,6 @@ class OpenCLFunction final : public CompiledFunction { cl_device_id device, llvm::ArrayRef global, llvm::ArrayRef local, std::vector &kernelLaunches); - - /// \returns a pointer to the tensor that is saved under \p v. - Tensor *getTensor(const Value *v) const; }; /// This is the OpenCL backend. diff --git a/lib/ExecutionEngine/ExecutionEngine.cpp b/lib/ExecutionEngine/ExecutionEngine.cpp index 675ab336bf..32de85a187 100644 --- a/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/lib/ExecutionEngine/ExecutionEngine.cpp @@ -78,9 +78,10 @@ void ExecutionEngine::run(Context &ctx) { assert(function_ && "No function has been compiled"); // TODO call runtime functions from EE instead of in the compiled function. // copyFunctionToDevice() + // copyConstantsToDevice() // allocateMutableBuffersOnDevice() - // copyMutablesToDevice(ctx) - // copyMutablesFromDevice(ctx) + // copyInputsToDevice(ctx) + // copyOutputsFromDevice(ctx) // freeAllocations() // We are working toward moving memory allocation and initialization to // runtime. As an intermediate the runtime functions are being called within