diff --git a/sycl/source/detail/queue_impl.hpp b/sycl/source/detail/queue_impl.hpp
index 3fbf72042e998..9676b18d16134 100644
--- a/sycl/source/detail/queue_impl.hpp
+++ b/sycl/source/detail/queue_impl.hpp
@@ -646,8 +646,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
   // for in order ones.
   void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);
 
-  static ContextImplPtr getContext(queue_impl *Queue) {
-    return Queue ? Queue->getContextImplPtr() : nullptr;
+  static context_impl *getContext(queue_impl *Queue) {
+    return Queue ? &Queue->getContextImpl() : nullptr;
   }
 
   // Must be called under MMutex protection
diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp
index 835fdc535f6b7..d1ed2d30655c3 100644
--- a/sycl/source/detail/scheduler/commands.cpp
+++ b/sycl/source/detail/scheduler/commands.cpp
@@ -429,7 +429,7 @@ class DispatchHostTask {
            "Host task submissions should have an associated queue");
     interop_handle IH{MReqToMem, HostTask.MQueue,
                       HostTask.MQueue->getDeviceImpl().shared_from_this(),
-                      HostTask.MQueue->getContextImplPtr()};
+                      HostTask.MQueue->getContextImpl().shared_from_this()};
     // TODO: should all the backends that support this entry point use this
     // for host task?
     auto &Queue = HostTask.MQueue;
@@ -2677,7 +2677,7 @@ void enqueueImpKernel(
     detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
     bool KernelHasSpecialCaptures) {
   // Run OpenCL kernel
-  auto &ContextImpl = Queue.getContextImplPtr();
+  context_impl &ContextImpl = Queue.getContextImpl();
   device_impl &DeviceImpl = Queue.getDeviceImpl();
   ur_kernel_handle_t Kernel = nullptr;
   std::mutex *KernelMutex = nullptr;
@@ -2715,7 +2715,7 @@ void enqueueImpKernel(
     KernelMutex = SyclKernelImpl->getCacheMutex();
   } else {
     KernelCacheVal = detail::ProgramManager::getInstance().getOrCreateKernel(
-        *ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, NDRDesc);
+        ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, NDRDesc);
     Kernel = KernelCacheVal->MKernelHandle;
     KernelMutex = KernelCacheVal->MMutex;
     Program = KernelCacheVal->MProgramHandle;
@@ -2727,7 +2727,7 @@ void enqueueImpKernel(
 
   // Initialize device globals associated with this.
   std::vector<ur_event_handle_t> DeviceGlobalInitEvents =
-      ContextImpl->initializeDeviceGlobals(Program, Queue);
+      ContextImpl.initializeDeviceGlobals(Program, Queue);
   if (!DeviceGlobalInitEvents.empty()) {
     std::vector<ur_event_handle_t> EventsWithDeviceGlobalInits;
     EventsWithDeviceGlobalInits.reserve(RawEvents.size() +
@@ -2784,9 +2784,9 @@ ur_result_t enqueueReadWriteHostPipe(queue_impl &Queue,
 
   ur_program_handle_t Program = nullptr;
   device Device = Queue.get_device();
-  ContextImplPtr ContextImpl = Queue.getContextImplPtr();
+  context_impl &ContextImpl = Queue.getContextImpl();
   std::optional<ur_program_handle_t> CachedProgram =
-      ContextImpl->getProgramForHostPipe(Device, hostPipeEntry);
+      ContextImpl.getProgramForHostPipe(Device, hostPipeEntry);
   if (CachedProgram)
     Program = *CachedProgram;
   else {
@@ -3004,7 +3004,7 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
   // Queue is created by graph_impl before creating command to submit to
   // scheduler.
   const AdapterPtr &Adapter = MQueue->getAdapter();
-  auto ContextImpl = MQueue->getContextImplPtr();
+  context_impl &ContextImpl = MQueue->getContextImpl();
   device_impl &DeviceImpl = MQueue->getDeviceImpl();
 
   // The CUDA & HIP backends don't have the equivalent of barrier
@@ -3033,7 +3033,7 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
         false /* profilable*/
     };
     Adapter->call(
-        ContextImpl->getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
+        ContextImpl.getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
        &ChildCommandBuffer);
   }
 
@@ -3043,12 +3043,12 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
    // available if a user asks for them inside the interop task scope
    std::vector<interop_handle::ReqToMem> ReqToMem;
    const std::vector<Requirement *> &HandlerReq = HostTask->getRequirements();
-    auto ReqToMemConv = [&ReqToMem, ContextImpl](Requirement *Req) {
+    auto ReqToMemConv = [&ReqToMem, &ContextImpl](Requirement *Req) {
      const std::vector<AllocaCommandBase *> &AllocaCmds =
          Req->MSYCLMemObj->MRecord->MAllocaCommands;
 
      for (AllocaCommandBase *AllocaCmd : AllocaCmds)
-        if (ContextImpl.get() == getContext(AllocaCmd->getQueue())) {
+        if (&ContextImpl == getContext(AllocaCmd->getQueue())) {
          auto MemArg =
              reinterpret_cast<ur_mem_handle_t>(AllocaCmd->getMemAllocation());
          ReqToMem.emplace_back(std::make_pair(Req, MemArg));
@@ -3068,8 +3068,8 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
    ur_exp_command_buffer_handle_t InteropCommandBuffer =
        ChildCommandBuffer ? ChildCommandBuffer : MCommandBuffer;
    interop_handle IH{std::move(ReqToMem), MQueue,
-                      DeviceImpl.shared_from_this(), ContextImpl,
-                      InteropCommandBuffer};
+                      DeviceImpl.shared_from_this(),
+                      ContextImpl.shared_from_this(), InteropCommandBuffer};
    CommandBufferNativeCommandData CustomOpData{
        std::move(IH), HostTask->MHostTask->MInteropTask};
 
@@ -3471,7 +3471,7 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
    EnqueueNativeCommandData CustomOpData{
        interop_handle{std::move(ReqToMem), HostTask->MQueue,
                       HostTask->MQueue->getDeviceImpl().shared_from_this(),
-                       HostTask->MQueue->getContextImplPtr()},
+                       HostTask->MQueue->getContextImpl().shared_from_this()},
        HostTask->MHostTask->MInteropTask};
 
    ur_bool_t NativeCommandSupport = false;
diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp
index 3116ae898fba6..207ee7df83938 100644
--- a/sycl/source/detail/scheduler/graph_builder.cpp
+++ b/sycl/source/detail/scheduler/graph_builder.cpp
@@ -52,7 +52,7 @@ static bool IsSuitableSubReq(const Requirement *Req) {
   return Req->MIsSubBuffer;
 }
 
-static bool isOnSameContext(const ContextImplPtr Context, queue_impl *Queue) {
+static bool isOnSameContext(context_impl *Context, queue_impl *Queue) {
   // Covers case for host usage (nullptr == nullptr) and existing device
   // contexts comparison.
   return Context == queue_impl::getContext(Queue);
@@ -233,8 +233,8 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
                     "shouldn't lead to any enqueuing (no linked "
                     "alloca or exceeding the leaf limit).");
   } else
-    MemObject->MRecord.reset(new MemObjRecord{
-        queue_impl::getContext(Queue).get(), LeafLimit, AllocateDependency});
+    MemObject->MRecord.reset(new MemObjRecord{queue_impl::getContext(Queue),
+                                              LeafLimit, AllocateDependency});
 
   MMemObjs.push_back(MemObject);
   return MemObject->MRecord.get();
@@ -346,15 +346,16 @@ Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
   }
 
   AllocaCommandBase *AllocaCmdSrc =
-      findAllocaForReq(Record, Req, Record->MCurContext);
+      findAllocaForReq(Record, Req, Record->getCurContext());
   if (!AllocaCmdSrc && IsSuitableSubReq(Req)) {
     // Since no alloca command for the sub buffer requirement was found in the
     // current context, need to find a parent alloca command for it (it must be
     // there)
     auto IsSuitableAlloca = [Record](AllocaCommandBase *AllocaCmd) {
-      bool Res = isOnSameContext(Record->MCurContext, AllocaCmd->getQueue()) &&
-                 // Looking for a parent buffer alloca command
-                 AllocaCmd->getType() == Command::CommandType::ALLOCA;
+      bool Res =
+          isOnSameContext(Record->getCurContext(), AllocaCmd->getQueue()) &&
+          // Looking for a parent buffer alloca command
+          AllocaCmd->getType() == Command::CommandType::ALLOCA;
       return Res;
     };
     const auto It =
@@ -384,10 +385,9 @@ Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
     NewCmd = insertMapUnmapForLinkedCmds(AllocaCmdSrc, AllocaCmdDst, MapMode);
     Record->MHostAccess = MapMode;
   } else {
-
     if ((Req->MAccessMode == access::mode::discard_write) ||
         (Req->MAccessMode == access::mode::discard_read_write)) {
-      Record->MCurContext = Context;
+      Record->setCurContext(Context);
       return nullptr;
     } else {
       // Full copy of buffer is needed to avoid loss of data that may be caused
@@ -409,7 +409,7 @@ Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
   addNodeToLeaves(Record, NewCmd, access::mode::read_write, ToEnqueue);
   for (Command *Cmd : ToCleanUp)
     cleanupCommand(Cmd);
-  Record->MCurContext = Context;
+  Record->setCurContext(Context);
   return NewCmd;
 }
 
@@ -422,7 +422,8 @@ Command *Scheduler::GraphBuilder::remapMemoryObject(
   AllocaCommandBase *LinkedAllocaCmd = HostAllocaCmd->MLinkedAllocaCmd;
   assert(LinkedAllocaCmd && "Linked alloca command expected");
 
-  std::set<Command *> Deps = findDepsForReq(Record, Req, Record->MCurContext);
+  std::set<Command *> Deps =
+      findDepsForReq(Record, Req, Record->getCurContext());
 
   UnMapMemObject *UnMapCmd = new UnMapMemObject(
       LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
@@ -473,7 +474,7 @@ Scheduler::GraphBuilder::addCopyBack(Requirement *Req,
   std::set<Command *> Deps = findDepsForReq(Record, Req, nullptr);
   AllocaCommandBase *SrcAllocaCmd =
-      findAllocaForReq(Record, Req, Record->MCurContext);
+      findAllocaForReq(Record, Req, Record->getCurContext());
 
   auto MemCpyCmdUniquePtr = std::make_unique<MemCpyCommandHost>(
       *SrcAllocaCmd->getRequirement(), SrcAllocaCmd, *Req, &Req->MData,
@@ -525,7 +526,7 @@ Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
   AllocaCommandBase *HostAllocaCmd =
       getOrCreateAllocaForReq(Record, Req, nullptr, ToEnqueue);
 
-  if (isOnSameContext(Record->MCurContext, HostAllocaCmd->getQueue())) {
+  if (isOnSameContext(Record->getCurContext(), HostAllocaCmd->getQueue())) {
     if (!isAccessModeAllowed(Req->MAccessMode, Record->MHostAccess)) {
       remapMemoryObject(Record, Req,
                         Req->MIsSubBuffer ? (static_cast<AllocaSubBufCommand *>(
@@ -571,10 +572,8 @@ Command *Scheduler::GraphBuilder::addCGUpdateHost(
 /// 1. New and examined commands only read -> can bypass
 /// 2. New and examined commands has non-overlapping requirements -> can bypass
 /// 3. New and examined commands have different contexts -> cannot bypass
-std::set<Command *>
-Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
-                                        const Requirement *Req,
-                                        const ContextImplPtr &Context) {
+std::set<Command *> Scheduler::GraphBuilder::findDepsForReq(
+    MemObjRecord *Record, const Requirement *Req, context_impl *Context) {
   std::set<Command *> RetDeps;
   std::vector<Command *> Visited;
   const bool ReadOnlyReq = Req->MAccessMode == access::mode::read;
@@ -644,7 +643,7 @@ DepDesc Scheduler::GraphBuilder::findDepForRecord(Command *Cmd,
 // The function searches for the alloca command matching context and
 // requirement.
 AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
-    MemObjRecord *Record, const Requirement *Req, const ContextImplPtr &Context,
+    MemObjRecord *Record, const Requirement *Req, context_impl *Context,
     bool AllowConst) {
   auto IsSuitableAlloca = [&Context, Req,
                            AllowConst](AllocaCommandBase *AllocaCmd) {
@@ -663,7 +662,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
   return (Record->MAllocaCommands.end() != It) ? *It : nullptr;
 }
 
-static bool checkHostUnifiedMemory(const ContextImplPtr &Ctx) {
+static bool checkHostUnifiedMemory(context_impl *Ctx) {
   if (const char *HUMConfig = SYCLConfig<SYCL_HOST_UNIFIED_MEMORY>::get()) {
     if (std::strcmp(HUMConfig, "0") == 0)
       return Ctx == nullptr;
@@ -744,7 +743,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
         Record->MAllocaCommands.push_back(HostAllocaCmd);
         Record->MWriteLeaves.push_back(HostAllocaCmd, ToEnqueue);
         ++(HostAllocaCmd->MLeafCounter);
-        Record->MCurContext = nullptr;
+        Record->setCurContext(nullptr);
       }
     }
   } else {
@@ -768,11 +767,12 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
     bool PinnedHostMemory = MemObj->usesPinnedHostMemory();
 
     bool HostUnifiedMemoryOnNonHostDevice =
-        Queue == nullptr ? checkHostUnifiedMemory(Record->MCurContext)
-                         : HostUnifiedMemory;
+        Queue == nullptr
+            ? checkHostUnifiedMemory(Record->getCurContext())
+            : HostUnifiedMemory;
     if (PinnedHostMemory || HostUnifiedMemoryOnNonHostDevice) {
       AllocaCommandBase *LinkedAllocaCmdCand = findAllocaForReq(
-          Record, Req, Record->MCurContext, /*AllowConst=*/false);
+          Record, Req, Record->getCurContext(), /*AllowConst=*/false);
 
       // Cannot setup link if candidate is linked already
       if (LinkedAllocaCmdCand &&
@@ -812,7 +812,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
         AllocaCmd->MIsActive = false;
       } else {
         LinkedAllocaCmd->MIsActive = false;
-        Record->MCurContext = Context;
+        Record->setCurContext(Context);
 
         std::set<Command *> Deps = findDepsForReq(Record, Req, Context);
         for (Command *Dep : Deps) {
@@ -965,7 +965,7 @@ Command *Scheduler::GraphBuilder::addCG(
       AllocaCmd =
           getOrCreateAllocaForReq(Record, Req, QueueForAlloca, ToEnqueue);
 
-      isSameCtx = isOnSameContext(Record->MCurContext, QueueForAlloca);
+      isSameCtx = isOnSameContext(Record->getCurContext(), QueueForAlloca);
     }
 
     // If there is alloca command we need to check if the latest memory is in
@@ -992,7 +992,7 @@ Command *Scheduler::GraphBuilder::addCG(
         const detail::CGHostTask &HT =
             static_cast<const detail::CGHostTask &>(NewCmd->getCG());
 
-        if (!isOnSameContext(Record->MCurContext, HT.MQueue.get())) {
+        if (!isOnSameContext(Record->getCurContext(), HT.MQueue.get())) {
           NeedMemMoveToHost = true;
           MemMoveTargetQueue = HT.MQueue.get();
         }
@@ -1226,9 +1226,7 @@ Command *Scheduler::GraphBuilder::connectDepEvent(
   try {
     std::shared_ptr<detail::HostTask> HT(new detail::HostTask);
     std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
-        std::move(HT),
-        /* Queue = */ Cmd->getQueue(),
-        /* Context = */ {},
+        std::move(HT), /* Queue = */ Cmd->getQueue(), /* Context = */ nullptr,
        /* Args = */ {},
        detail::CG::StorageInitHelper(
            /* ArgsStorage = */ {}, /* AccStorage = */ {},
@@ -1302,7 +1300,7 @@ Command *Scheduler::GraphBuilder::addCommandGraphUpdate(
 
      AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue, ToEnqueue);
 
-      isSameCtx = isOnSameContext(Record->MCurContext, Queue);
+      isSameCtx = isOnSameContext(Record->getCurContext(), Queue);
     }
 
    if (!isSameCtx) {
diff --git a/sycl/source/detail/scheduler/scheduler.hpp b/sycl/source/detail/scheduler/scheduler.hpp
index 856738d324da7..b7ebfc47cf5dd 100644
--- a/sycl/source/detail/scheduler/scheduler.hpp
+++ b/sycl/source/detail/scheduler/scheduler.hpp
@@ -185,7 +185,6 @@ class event_impl;
 class context_impl;
 class DispatchHostTask;
 
-using ContextImplPtr = std::shared_ptr<context_impl>;
 using EventImplPtr = std::shared_ptr<event_impl>;
 using StreamImplPtr = std::shared_ptr<stream_impl>;
 
@@ -214,6 +213,10 @@ struct MemObjRecord {
 
   // The context which has the latest state of the memory object.
   std::shared_ptr<context_impl> MCurContext;
+  context_impl *getCurContext() { return MCurContext.get(); }
+  void setCurContext(context_impl *Ctx) {
+    MCurContext = Ctx ? Ctx->shared_from_this() : nullptr;
+  }
 
   // The mode this object can be accessed from the host (host_accessor).
   // Valid only if the current usage is on host.
@@ -688,7 +691,7 @@ class Scheduler {
     /// Finds dependencies for the requirement.
     std::set<Command *> findDepsForReq(MemObjRecord *Record,
                                        const Requirement *Req,
-                                       const ContextImplPtr &Context);
+                                       context_impl *Context);
 
     EmptyCommand *addEmptyCmd(Command *Cmd,
                               const std::vector<Requirement *> &Req,
@@ -702,7 +705,7 @@ class Scheduler {
     /// Searches for suitable alloca in memory record.
     AllocaCommandBase *findAllocaForReq(MemObjRecord *Record,
                                         const Requirement *Req,
-                                        const ContextImplPtr &Context,
+                                        context_impl *Context,
                                         bool AllowConst = true);
 
     friend class Command;
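
Note for reviewers: the ownership pattern this patch relies on is that `MemObjRecord` keeps holding its context through a `std::shared_ptr`, while callers pass around and compare raw `context_impl *`; `setCurContext` re-acquires shared ownership via `shared_from_this()`, which is only valid when the object is already managed by a `shared_ptr`. The standalone sketch below illustrates that pattern with simplified, hypothetical `Context`/`Record` stand-ins; it is not SYCL runtime code.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for context_impl: must derive from
// enable_shared_from_this for setCurContext() to work.
class Context : public std::enable_shared_from_this<Context> {};

// Hypothetical stand-in for MemObjRecord.
struct Record {
  // Ownership stays inside the record, as with MemObjRecord::MCurContext.
  std::shared_ptr<Context> CurContext;

  // Callers only see raw pointers, mirroring getCurContext()/setCurContext().
  Context *getCurContext() { return CurContext.get(); }
  void setCurContext(Context *Ctx) {
    // shared_from_this() requires *Ctx to already be owned by a shared_ptr;
    // nullptr simply clears the reference (the "host" case).
    CurContext = Ctx ? Ctx->shared_from_this() : nullptr;
  }
};

int main() {
  auto Ctx = std::make_shared<Context>(); // already shared_ptr-managed
  Record Rec;

  Rec.setCurContext(Ctx.get());
  assert(Rec.getCurContext() == Ctx.get()); // raw-pointer comparison, as in isOnSameContext()

  Rec.setCurContext(nullptr); // host usage: no device context
  assert(Rec.getCurContext() == nullptr);
  return 0;
}
```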