diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index 9f379245a309b2..ba703738d3c7c3 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -591,19 +591,42 @@ private static ValueTask FinalizeValueTaskReturningThunk(Continuation continuati } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void RestoreExecutionContext(ExecutionContext? previousExecutionCtx) + private static void RestoreExecutionContext(ExecutionContext? previousExecCtx) { Thread thread = Thread.CurrentThreadAssumedInitialized; - ExecutionContext? currentExecutionCtx = thread._executionContext; - if (previousExecutionCtx != currentExecutionCtx) + ExecutionContext? currentExecCtx = thread._executionContext; + if (previousExecCtx != currentExecCtx) { - ExecutionContext.RestoreChangedContextToThread(thread, previousExecutionCtx, currentExecutionCtx); + ExecutionContext.RestoreChangedContextToThread(thread, previousExecCtx, currentExecCtx); } } - private static void CaptureContinuationContext(ref object context, ref CorInfoContinuationFlags flags) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void CaptureContexts(out ExecutionContext? execCtx, out SynchronizationContext? syncCtx) + { + Thread thread = Thread.CurrentThreadAssumedInitialized; + execCtx = thread._executionContext; + syncCtx = thread._synchronizationContext; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void RestoreContexts(bool suspended, ExecutionContext? previousExecCtx, SynchronizationContext? 
previousSyncCtx) + { + Thread thread = Thread.CurrentThreadAssumedInitialized; + if (!suspended && previousSyncCtx != thread._synchronizationContext) + { + thread._synchronizationContext = previousSyncCtx; + } + + ExecutionContext? currentExecCtx = thread._executionContext; + if (previousExecCtx != currentExecCtx) + { + ExecutionContext.RestoreChangedContextToThread(thread, previousExecCtx, currentExecCtx); + } + } + + private static void CaptureContinuationContext(SynchronizationContext syncCtx, ref object context, ref CorInfoContinuationFlags flags) { - SynchronizationContext? syncCtx = Thread.CurrentThreadAssumedInitialized._synchronizationContext; if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext)) { flags |= CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_CAPTURED_SYNCHRONIZATION_CONTEXT; diff --git a/src/coreclr/inc/corinfo.h b/src/coreclr/inc/corinfo.h index ff0420c799568f..b078d6c9329b1a 100644 --- a/src/coreclr/inc/corinfo.h +++ b/src/coreclr/inc/corinfo.h @@ -1747,6 +1747,8 @@ struct CORINFO_ASYNC_INFO // Method handle for AsyncHelpers.RestoreExecutionContext CORINFO_METHOD_HANDLE restoreExecutionContextMethHnd; CORINFO_METHOD_HANDLE captureContinuationContextMethHnd; + CORINFO_METHOD_HANDLE captureContextsMethHnd; + CORINFO_METHOD_HANDLE restoreContextsMethHnd; }; // Flags passed from JIT to runtime. 
diff --git a/src/coreclr/inc/jiteeversionguid.h b/src/coreclr/inc/jiteeversionguid.h index 7043a7fa2f55ac..fe1aaade4fb3b9 100644 --- a/src/coreclr/inc/jiteeversionguid.h +++ b/src/coreclr/inc/jiteeversionguid.h @@ -37,11 +37,11 @@ #include -constexpr GUID JITEEVersionIdentifier = { /* d24a67e0-9e57-4c9e-ad31-5785df2526f2 */ - 0xd24a67e0, - 0x9e57, - 0x4c9e, - {0xad, 0x31, 0x57, 0x85, 0xdf, 0x25, 0x26, 0xf2} +constexpr GUID JITEEVersionIdentifier = { /* 2d40ec46-2e41-4a8b-8349-3c1267b95821 */ + 0x2d40ec46, + 0x2e41, + 0x4a8b, + {0x83, 0x49, 0x3c, 0x12, 0x67, 0xb9, 0x58, 0x21} }; #endif // JIT_EE_VERSIONING_GUID_H diff --git a/src/coreclr/jit/async.cpp b/src/coreclr/jit/async.cpp index 6dc067140547d8..d0c2fb4369abc5 100644 --- a/src/coreclr/jit/async.cpp +++ b/src/coreclr/jit/async.cpp @@ -6,11 +6,11 @@ // machines. The following key operations are performed: // // 1. Early, after import but before inlining: for async calls that require -// ExecutionContext save/restore semantics, ExecutionContext capture and +// ExecutionContext/SynchronizationContext save/restore semantics, capture and // restore calls are inserted around the async call site. This ensures proper // context flow across await boundaries when the continuation may run on -// different threads or synchronization contexts. The captured ExecutionContext -// is stored in a temporary local and restored after the async call completes, +// different threads or synchronization contexts. The captured contexts +// are stored in temporary locals and restored after the async call completes, // with special handling for calls inside try regions using try-finally blocks. // // Later, right before lowering the actual transformation to a state machine is @@ -47,7 +47,7 @@ //------------------------------------------------------------------------ // Compiler::SaveAsyncContexts: -// Insert code to save and restore ExecutionContext around async call sites. 
+// Insert code to save and restore contexts around async call sites. // // Returns: // Suitable phase status. @@ -80,53 +80,75 @@ PhaseStatus Compiler::SaveAsyncContexts() tree = tree->AsLclVarCommon()->Data(); } - if (!tree->IsCall() || !tree->AsCall()->IsAsyncAndAlwaysSavesAndRestoresExecutionContext()) + if (!tree->IsCall()) { ValidateNoAsyncSavesNecessaryInStatement(stmt); continue; } GenTreeCall* call = tree->AsCall(); + if (!call->IsAsync()) + { + ValidateNoAsyncSavesNecessaryInStatement(stmt); + continue; + } - unsigned lclNum = lvaGrabTemp(false DEBUGARG("ExecutionContext for SaveAndRestore async call")); + const AsyncCallInfo& asyncCallInfo = call->GetAsyncInfo(); - JITDUMP("Saving ExecutionContext in V%02u around [%06u]\n", lclNum, call->gtTreeID); + // Currently we always expect that ExecutionContext and + // SynchronizationContext correlate about their save/restore + // behavior. + assert((asyncCallInfo.ExecutionContextHandling == ExecutionContextHandling::SaveAndRestore) == + asyncCallInfo.SaveAndRestoreSynchronizationContextField); - CORINFO_ASYNC_INFO* asyncInfo = eeGetAsyncInfo(); + if (asyncCallInfo.ExecutionContextHandling != ExecutionContextHandling::SaveAndRestore) + { + continue; + } - GenTreeCall* capture = gtNewCallNode(CT_USER_FUNC, asyncInfo->captureExecutionContextMethHnd, TYP_REF); - CORINFO_CALL_INFO callInfo = {}; - callInfo.hMethod = capture->gtCallMethHnd; - callInfo.methodFlags = info.compCompHnd->getMethodAttribs(callInfo.hMethod); - impMarkInlineCandidate(capture, MAKE_METHODCONTEXT(callInfo.hMethod), false, &callInfo, compInlineContext); + unsigned suspendedLclNum = + lvaGrabTemp(false DEBUGARG(printfAlloc("Suspended indicator for [%06u]", dspTreeID(call)))); + unsigned execCtxLclNum = + lvaGrabTemp(false DEBUGARG(printfAlloc("ExecutionContext for [%06u]", dspTreeID(call)))); + unsigned syncCtxLclNum = + lvaGrabTemp(false DEBUGARG(printfAlloc("SynchronizationContext for [%06u]", dspTreeID(call)))); - if 
(capture->IsInlineCandidate()) - { - Statement* captureStmt = fgNewStmtFromTree(capture); + LclVarDsc* suspendedLclDsc = lvaGetDesc(suspendedLclNum); + suspendedLclDsc->lvType = TYP_UBYTE; + suspendedLclDsc->lvHasLdAddrOp = true; - GenTreeRetExpr* retExpr = gtNewInlineCandidateReturnExpr(capture, TYP_REF); + LclVarDsc* execCtxLclDsc = lvaGetDesc(execCtxLclNum); + execCtxLclDsc->lvType = TYP_REF; + execCtxLclDsc->lvHasLdAddrOp = true; - capture->GetSingleInlineCandidateInfo()->retExpr = retExpr; - GenTree* storeCapture = gtNewTempStore(lclNum, retExpr); - Statement* storeCaptureStmt = fgNewStmtFromTree(storeCapture); + LclVarDsc* syncCtxLclDsc = lvaGetDesc(syncCtxLclNum); + syncCtxLclDsc->lvType = TYP_REF; + syncCtxLclDsc->lvHasLdAddrOp = true; - fgInsertStmtBefore(curBB, stmt, captureStmt); - fgInsertStmtBefore(curBB, stmt, storeCaptureStmt); + call->asyncInfo->SynchronizationContextLclNum = syncCtxLclNum; - JITDUMP("Inserted capture:\n"); - DISPSTMT(captureStmt); - DISPSTMT(storeCaptureStmt); - } - else - { - GenTree* storeCapture = gtNewTempStore(lclNum, capture); - Statement* storeCaptureStmt = fgNewStmtFromTree(storeCapture); + call->gtArgs.PushBack(this, NewCallArg::Primitive(gtNewLclAddrNode(suspendedLclNum, 0)) + .WellKnown(WellKnownArg::AsyncSuspendedIndicator)); - fgInsertStmtBefore(curBB, stmt, storeCaptureStmt); + JITDUMP("Saving contexts around [%06u], ExecutionContext = V%02u, SynchronizationContext = V%02u\n", + call->gtTreeID, execCtxLclNum, syncCtxLclNum); - JITDUMP("Inserted capture:\n"); - DISPSTMT(storeCaptureStmt); - } + CORINFO_ASYNC_INFO* asyncInfo = eeGetAsyncInfo(); + + GenTreeCall* capture = gtNewCallNode(CT_USER_FUNC, asyncInfo->captureContextsMethHnd, TYP_VOID); + capture->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclAddrNode(syncCtxLclNum, 0))); + capture->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclAddrNode(execCtxLclNum, 0))); + + CORINFO_CALL_INFO callInfo = {}; + callInfo.hMethod = capture->gtCallMethHnd; + 
callInfo.methodFlags = info.compCompHnd->getMethodAttribs(callInfo.hMethod); + impMarkInlineCandidate(capture, MAKE_METHODCONTEXT(callInfo.hMethod), false, &callInfo, compInlineContext); + + Statement* captureStmt = fgNewStmtFromTree(capture); + fgInsertStmtBefore(curBB, stmt, captureStmt); + + JITDUMP("Inserted capture:\n"); + DISPSTMT(captureStmt); BasicBlock* restoreBB = curBB; Statement* restoreAfterStmt = stmt; @@ -150,8 +172,10 @@ PhaseStatus Compiler::SaveAsyncContexts() #endif } - GenTreeCall* restore = gtNewCallNode(CT_USER_FUNC, asyncInfo->restoreExecutionContextMethHnd, TYP_VOID); - restore->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(lclNum))); + GenTreeCall* restore = gtNewCallNode(CT_USER_FUNC, asyncInfo->restoreContextsMethHnd, TYP_VOID); + restore->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(syncCtxLclNum))); + restore->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(execCtxLclNum))); + restore->gtArgs.PushFront(this, NewCallArg::Primitive(gtNewLclVarNode(suspendedLclNum))); callInfo = {}; callInfo.hMethod = restore->gtCallMethHnd; @@ -361,7 +385,8 @@ class AsyncLiveness void StartBlock(BasicBlock* block); void Update(GenTree* node); bool IsLive(unsigned lclNum); - void GetLiveLocals(jitstd::vector& liveLocals, unsigned fullyDefinedRetBufLcl); + template + void GetLiveLocals(jitstd::vector& liveLocals, Functor includeLocal); private: bool IsLocalCaptureUnnecessary(unsigned lclNum); @@ -539,14 +564,15 @@ bool AsyncLiveness::IsLive(unsigned lclNum) // Get live locals that should be captured at this point. 
// // Parameters: -// liveLocals - Vector to add live local information into -// fullyDefinedRetBufLcl - Local to skip even if live +// liveLocals - Vector to add live local information into +// includeLocal - Functor to check if a local should be included // -void AsyncLiveness::GetLiveLocals(jitstd::vector& liveLocals, unsigned fullyDefinedRetBufLcl) +template +void AsyncLiveness::GetLiveLocals(jitstd::vector& liveLocals, Functor includeLocal) { for (unsigned lclNum = 0; lclNum < m_numVars; lclNum++) { - if ((lclNum != fullyDefinedRetBufLcl) && IsLive(lclNum)) + if (includeLocal(lclNum) && IsLive(lclNum)) { liveLocals.push_back(LiveLocalInfo(lclNum)); } @@ -776,6 +802,8 @@ void AsyncTransformation::Transform( ContinuationLayout layout = LayOutContinuation(block, call, liveLocals); + ClearSuspendedIndicator(block, call); + CallDefinitionInfo callDefInfo = CanonicalizeCallDefinition(block, call, life); unsigned stateNum = (unsigned)m_resumptionBBs.size(); @@ -783,7 +811,7 @@ void AsyncTransformation::Transform( BasicBlock* suspendBB = CreateSuspension(block, call, stateNum, life, layout); - CreateCheckAndSuspendAfterCall(block, callDefInfo, life, suspendBB, remainder); + CreateCheckAndSuspendAfterCall(block, call, callDefInfo, life, suspendBB, remainder); BasicBlock* resumeBB = CreateResumption(block, *remainder, call, callDefInfo, stateNum, layout); @@ -808,27 +836,30 @@ void AsyncTransformation::CreateLiveSetForSuspension(BasicBlock* AsyncLiveness& life, jitstd::vector& liveLocals) { - unsigned fullyDefinedRetBufLcl = BAD_VAR_NUM; - CallArg* retbufArg = call->gtArgs.GetRetBufferArg(); - if (retbufArg != nullptr) - { - GenTree* retbuf = retbufArg->GetNode(); - if (retbuf->IsLclVarAddr()) + SmallHashTable excludedLocals(m_comp->getAllocator(CMK_Async)); + + auto visitDef = [&](const LocalDef& def) { + if (def.IsEntire) { - LclVarDsc* dsc = m_comp->lvaGetDesc(retbuf->AsLclVarCommon()); - ClassLayout* defLayout = m_comp->typGetObjLayout(call->gtRetClsHnd); - if 
(defLayout->GetSize() == dsc->lvExactSize()) - { - // This call fully defines this retbuf. There is no need to - // consider it live across the call since it is going to be - // overridden anyway. - fullyDefinedRetBufLcl = retbuf->AsLclVarCommon()->GetLclNum(); - JITDUMP(" V%02u is a fully defined retbuf and will not be considered live\n", fullyDefinedRetBufLcl); - } + JITDUMP(" V%02u is fully defined and will not be considered live\n", def.Def->GetLclNum()); + excludedLocals.AddOrUpdate(def.Def->GetLclNum(), true); } + return GenTree::VisitResult::Continue; + }; + + call->VisitLocalDefs(m_comp, visitDef); + + const AsyncCallInfo& asyncInfo = call->GetAsyncInfo(); + + if (asyncInfo.SynchronizationContextLclNum != BAD_VAR_NUM) + { + // This one is only live on the synchronous path, which liveness cannot prove + excludedLocals.AddOrUpdate(asyncInfo.SynchronizationContextLclNum, true); } - life.GetLiveLocals(liveLocals, fullyDefinedRetBufLcl); + life.GetLiveLocals(liveLocals, [&](unsigned lclNum) { + return !excludedLocals.Contains(lclNum); + }); LiftLIREdges(block, defs, liveLocals); #ifdef DEBUG @@ -1080,6 +1111,80 @@ ContinuationLayout AsyncTransformation::LayOutContinuation(BasicBlock* return layout; } +//------------------------------------------------------------------------ +// AsyncTransformation::ClearSuspendedIndicator: +// Generate IR to clear the value of the suspended indicator local. 
+// +// Parameters: +// block - Block to generate IR into +// call - The async call (not contained in "block") +// +void AsyncTransformation::ClearSuspendedIndicator(BasicBlock* block, GenTreeCall* call) +{ + CallArg* suspendedArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator); + if (suspendedArg == nullptr) + { + return; + } + + GenTree* suspended = suspendedArg->GetNode(); + if (!suspended->IsLclVarAddr() && + (!suspended->OperIs(GT_LCL_VAR) || m_comp->lvaVarAddrExposed(suspended->AsLclVarCommon()->GetLclNum()))) + { + // We will need a second use of this, so spill to a local + LIR::Use use(LIR::AsRange(block), &suspendedArg->NodeRef(), call); + use.ReplaceWithLclVar(m_comp); + suspended = use.Def(); + } + + GenTree* value = m_comp->gtNewIconNode(0); + GenTree* storeSuspended = + m_comp->gtNewStoreValueNode(TYP_UBYTE, m_comp->gtCloneExpr(suspended), value, GTF_IND_NONFAULTING); + + LIR::AsRange(block).InsertBefore(call, LIR::SeqTree(m_comp, storeSuspended)); +} + +//------------------------------------------------------------------------ +// AsyncTransformation::SetSuspendedIndicator: +// Generate IR to set the value of the suspended indicator local, and remove +// the argument from the call. 
+// +// Parameters: +// block - Block to generate IR into +// callBlock - Block containing the call +// call - The async call +// +void AsyncTransformation::SetSuspendedIndicator(BasicBlock* block, BasicBlock* callBlock, GenTreeCall* call) +{ + CallArg* suspendedArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator); + if (suspendedArg == nullptr) + { + return; + } + + GenTree* suspended = suspendedArg->GetNode(); + assert(suspended->IsLclVarAddr() || suspended->OperIs(GT_LCL_VAR)); // Ensured by ClearSuspendedIndicator + + GenTree* value = m_comp->gtNewIconNode(1); + GenTree* storeSuspended = + m_comp->gtNewStoreValueNode(TYP_UBYTE, m_comp->gtCloneExpr(suspended), value, GTF_IND_NONFAULTING); + + LIR::AsRange(block).InsertAtEnd(LIR::SeqTree(m_comp, storeSuspended)); + + call->gtArgs.RemoveUnsafe(suspendedArg); + call->asyncInfo->HasSuspensionIndicatorDef = false; + + // Avoid leaving LCL_ADDR around which will DNER the local. + if (suspended->IsLclVarAddr()) + { + LIR::AsRange(callBlock).Remove(suspended); + } + else + { + suspended->SetUnusedValue(); + } +} + //------------------------------------------------------------------------ // AsyncTransformation::CanonicalizeCallDefinition: // Put the call definition in a canonical form. This ensures that either the @@ -1225,7 +1330,7 @@ BasicBlock* AsyncTransformation::CreateSuspension( if (layout.GCRefsCount > 0) { - FillInGCPointersOnSuspension(layout, suspendBB); + FillInGCPointersOnSuspension(call, layout, suspendBB); } if (layout.DataSize > 0) @@ -1305,10 +1410,13 @@ GenTreeCall* AsyncTransformation::CreateAllocContinuationCall(AsyncLiveness& lif // parts that need to be stored. // // Parameters: +// call - The async call that is being transformed // layout - Layout information // suspendBB - Basic block to add IR to. 
// -void AsyncTransformation::FillInGCPointersOnSuspension(const ContinuationLayout& layout, BasicBlock* suspendBB) +void AsyncTransformation::FillInGCPointersOnSuspension(GenTreeCall* call, + const ContinuationLayout& layout, + BasicBlock* suspendBB) { unsigned objectArrLclNum = GetGCDataArrayVar(); @@ -1397,8 +1505,16 @@ void AsyncTransformation::FillInGCPointersOnSuspension(const ContinuationLayout& if (layout.ContinuationContextGCDataIndex != UINT_MAX) { - // Insert call AsyncHelpers.CaptureContinuationContext(ref - // newContinuation.GCData[ContinuationContextGCDataIndex], ref newContinuation.Flags). + const AsyncCallInfo& callInfo = call->GetAsyncInfo(); + assert(callInfo.SaveAndRestoreSynchronizationContextField && + (callInfo.SynchronizationContextLclNum != BAD_VAR_NUM)); + + // Insert call + // AsyncHelpers.CaptureContinuationContext( + // syncContextFromBeforeCall, + // ref newContinuation.GCData[ContinuationContextGCDataIndex], + // ref newContinuation.Flags). + GenTree* syncContextPlaceholder = m_comp->gtNewNull(); GenTree* contextElementPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); GenTree* flagsPlaceholder = m_comp->gtNewZeroConNode(TYP_BYREF); GenTreeCall* captureCall = @@ -1406,15 +1522,24 @@ void AsyncTransformation::FillInGCPointersOnSuspension(const ContinuationLayout& captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(flagsPlaceholder)); captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(contextElementPlaceholder)); + captureCall->gtArgs.PushFront(m_comp, NewCallArg::Primitive(syncContextPlaceholder)); m_comp->compCurBB = suspendBB; m_comp->fgMorphTree(captureCall); LIR::AsRange(suspendBB).InsertAtEnd(LIR::SeqTree(m_comp, captureCall)); - // Now replace contextElementPlaceholder with actual address of the context element + // Replace sync context placeholder with actual sync context from before call LIR::Use use; - bool gotUse = LIR::AsRange(suspendBB).TryGetUse(contextElementPlaceholder, &use); + bool gotUse = 
LIR::AsRange(suspendBB).TryGetUse(syncContextPlaceholder, &use); + assert(gotUse); + GenTree* syncContextLcl = m_comp->gtNewLclvNode(callInfo.SynchronizationContextLclNum, TYP_REF); + LIR::AsRange(suspendBB).InsertBefore(syncContextPlaceholder, syncContextLcl); + use.ReplaceWith(syncContextLcl); + LIR::AsRange(suspendBB).Remove(syncContextPlaceholder); + + // Replace contextElementPlaceholder with actual address of the context element + gotUse = LIR::AsRange(suspendBB).TryGetUse(contextElementPlaceholder, &use); assert(gotUse); GenTree* objectArr = m_comp->gtNewLclvNode(objectArrLclNum, TYP_REF); @@ -1426,7 +1551,7 @@ void AsyncTransformation::FillInGCPointersOnSuspension(const ContinuationLayout& use.ReplaceWith(contextElementOffset); LIR::AsRange(suspendBB).Remove(contextElementPlaceholder); - // And now replace flagsPlaceholder with actual address of the flags + // Replace flagsPlaceholder with actual address of the flags gotUse = LIR::AsRange(suspendBB).TryGetUse(flagsPlaceholder, &use); assert(gotUse); @@ -1539,12 +1664,14 @@ void AsyncTransformation::FillInDataOnSuspension(const jitstd::vectorbbNum, stateNum); + SetSuspendedIndicator(resumeBB, block, call); + // We need to restore data before we restore GC pointers, since restoring // the data may also write the GC pointer fields with nulls. 
unsigned resumeByteArrLclNum = BAD_VAR_NUM; @@ -1649,7 +1778,7 @@ BasicBlock* AsyncTransformation::CreateResumption(BasicBlock* bloc if (layout.ExceptionGCDataIndex != UINT_MAX) { - storeResultBB = RethrowExceptionOnResumption(block, remainder, resumeObjectArrLclNum, layout, resumeBB); + storeResultBB = RethrowExceptionOnResumption(block, resumeObjectArrLclNum, layout, resumeBB); } } @@ -1822,7 +1951,6 @@ void AsyncTransformation::RestoreFromGCPointersOnResumption(unsigned // // Parameters: // block - The block containing the async call -// remainder - The block that contains the IR after the (split) async call // resumeObjectArrLclNum - Local that has the continuation object's GC pointers array // layout - Layout information for the continuation object // resumeBB - Basic block to append IR to @@ -1833,7 +1961,6 @@ void AsyncTransformation::RestoreFromGCPointersOnResumption(unsigned // rethrow. // BasicBlock* AsyncTransformation::RethrowExceptionOnResumption(BasicBlock* block, - BasicBlock* remainder, unsigned resumeObjectArrLclNum, const ContinuationLayout& layout, BasicBlock* resumeBB) @@ -1851,6 +1978,7 @@ BasicBlock* AsyncTransformation::RethrowExceptionOnResumption(BasicBlock* FlowEdge* storeResultEdge = m_comp->fgAddRefPred(storeResultBB, resumeBB); assert(resumeBB->KindIs(BBJ_ALWAYS)); + BasicBlock* remainder = resumeBB->GetTarget(); m_comp->fgRemoveRefPred(resumeBB->GetTargetEdge()); resumeBB->SetCond(rethrowEdge, storeResultEdge); @@ -1981,6 +2109,9 @@ void AsyncTransformation::CopyReturnValueOnResumption(GenTreeCall* LIR::AsRange(storeResultBB).InsertAtEnd(LIR::SeqTree(m_comp, storeResultBase)); resultBase = m_comp->gtNewLclVarNode(resultBaseVar, TYP_REF); + + // Can be reallocated by above call to GetResultBaseVar + resultLcl = m_comp->lvaGetDesc(callDefInfo.DefinitionNode); } assert(callDefInfo.DefinitionNode->OperIs(GT_STORE_LCL_VAR)); @@ -2305,8 +2436,8 @@ void AsyncTransformation::CreateResumptionSwitch() } } - BBswtDesc* const swtDesc = - new 
(m_comp, CMK_BasicBlock) BBswtDesc(cases, (unsigned)numCases, succs, numUniqueSuccs, true); + BBswtDesc* const swtDesc = new (m_comp, CMK_BasicBlock) + BBswtDesc(succs, numUniqueSuccs, cases, (unsigned)numCases, /* hasDefault */ true); switchBB->SetSwitch(swtDesc); } diff --git a/src/coreclr/jit/async.h b/src/coreclr/jit/async.h index e75f2ac8d157f6..e30aaf760e6395 100644 --- a/src/coreclr/jit/async.h +++ b/src/coreclr/jit/async.h @@ -84,6 +84,8 @@ class AsyncTransformation GenTreeCall* call, jitstd::vector& liveLocals); + void ClearSuspendedIndicator(BasicBlock* block, GenTreeCall* call); + CallDefinitionInfo CanonicalizeCallDefinition(BasicBlock* block, GenTreeCall* call, AsyncLiveness& life); BasicBlock* CreateSuspension( @@ -92,20 +94,21 @@ class AsyncTransformation GenTree* prevContinuation, unsigned gcRefsCount, unsigned int dataSize); - void FillInGCPointersOnSuspension(const ContinuationLayout& layout, BasicBlock* suspendBB); - void FillInDataOnSuspension(const jitstd::vector& liveLocals, BasicBlock* suspendBB); - void CreateCheckAndSuspendAfterCall(BasicBlock* block, - const CallDefinitionInfo& callDefInfo, - AsyncLiveness& life, - BasicBlock* suspendBB, - BasicBlock** remainder); - + void FillInGCPointersOnSuspension(GenTreeCall* call, const ContinuationLayout& layout, BasicBlock* suspendBB); + void FillInDataOnSuspension(const jitstd::vector& liveLocals, BasicBlock* suspendBB); + void CreateCheckAndSuspendAfterCall(BasicBlock* block, + GenTreeCall* call, + const CallDefinitionInfo& callDefInfo, + AsyncLiveness& life, + BasicBlock* suspendBB, + BasicBlock** remainder); BasicBlock* CreateResumption(BasicBlock* block, BasicBlock* remainder, GenTreeCall* call, const CallDefinitionInfo& callDefInfo, unsigned stateNum, const ContinuationLayout& layout); + void SetSuspendedIndicator(BasicBlock* block, BasicBlock* callBlock, GenTreeCall* call); void RestoreFromDataOnResumption(unsigned resumeByteArrLclNum, const jitstd::vector& liveLocals, BasicBlock* resumeBB); 
@@ -113,7 +116,6 @@ class AsyncTransformation const ContinuationLayout& layout, BasicBlock* resumeBB); BasicBlock* RethrowExceptionOnResumption(BasicBlock* block, - BasicBlock* remainder, unsigned resumeObjectArrLclNum, const ContinuationLayout& layout, BasicBlock* resumeBB); diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h index d9e4f30f9e9985..e475d73240ad74 100644 --- a/src/coreclr/jit/compiler.h +++ b/src/coreclr/jit/compiler.h @@ -601,8 +601,7 @@ class LclVarDsc unsigned char lvIsMultiRegDest : 1; // true if this is a multireg LclVar struct that is stored from a multireg node #ifdef DEBUG - unsigned char lvHiddenBufferStructArg : 1; // True when this struct (or its field) are passed as hidden buffer - // pointer. + unsigned char lvDefinedViaAddress : 1; // True when this local may have LCL_ADDRs representing definitions #endif #ifdef FEATURE_HFA_FIELDS_PRESENT @@ -753,14 +752,14 @@ class LclVarDsc } #ifdef DEBUG - void SetHiddenBufferStructArg(char value) + void SetDefinedViaAddress(char value) { - lvHiddenBufferStructArg = value; + lvDefinedViaAddress = value; } - bool IsHiddenBufferStructArg() const + bool IsDefinedViaAddress() const { - return lvHiddenBufferStructArg; + return lvDefinedViaAddress; } #endif @@ -3741,6 +3740,7 @@ class Compiler bool gtIsTypeof(GenTree* tree, CORINFO_CLASS_HANDLE* handle = nullptr); GenTreeLclVarCommon* gtCallGetDefinedRetBufLclAddr(GenTreeCall* call); + GenTreeLclVarCommon* gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(GenTreeCall* call); //------------------------------------------------------------------------- // Functions to display the trees diff --git a/src/coreclr/jit/compiler.hpp b/src/coreclr/jit/compiler.hpp index 5d271f8c714311..c0a37daff4f6fa 100644 --- a/src/coreclr/jit/compiler.hpp +++ b/src/coreclr/jit/compiler.hpp @@ -4713,7 +4713,21 @@ GenTree::VisitResult GenTree::VisitLocalDefs(Compiler* comp, TVisitor visitor) } if (OperIs(GT_CALL)) { - GenTreeLclVarCommon* lclAddr = 
comp->gtCallGetDefinedRetBufLclAddr(AsCall()); + GenTreeCall* call = AsCall(); + if (call->IsAsync()) + { + GenTreeLclVarCommon* suspendedArg = comp->gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(call); + if (suspendedArg != nullptr) + { + bool isEntire = comp->lvaLclExactSize(suspendedArg->GetLclNum()) == 1; + if (visitor(LocalDef(suspendedArg, isEntire, suspendedArg->GetLclOffs(), 1)) == VisitResult::Abort) + { + return VisitResult::Abort; + } + } + } + + GenTreeLclVarCommon* lclAddr = comp->gtCallGetDefinedRetBufLclAddr(call); if (lclAddr != nullptr) { unsigned storeSize = comp->typGetObjLayout(AsCall()->gtRetClsHnd)->GetSize(); @@ -4755,7 +4769,17 @@ GenTree::VisitResult GenTree::VisitLocalDefNodes(Compiler* comp, TVisitor visito } if (OperIs(GT_CALL)) { - GenTreeLclVarCommon* lclAddr = comp->gtCallGetDefinedRetBufLclAddr(AsCall()); + GenTreeCall* call = AsCall(); + if (call->IsAsync()) + { + GenTreeLclVarCommon* suspendedArg = comp->gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(call); + if ((suspendedArg != nullptr) && (visitor(suspendedArg) == VisitResult::Abort)) + { + return VisitResult::Abort; + } + } + + GenTreeLclVarCommon* lclAddr = comp->gtCallGetDefinedRetBufLclAddr(call); if (lclAddr != nullptr) { return visitor(lclAddr); diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index f33f309bece870..12ef5b2f1d8895 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -1829,6 +1829,46 @@ void CallArgs::Remove(CallArg* arg) assert(!"Did not find arg to remove in CallArgs::Remove"); } +//--------------------------------------------------------------- +// RemoveUnsafe: Remove an argument from the argument list, without validation. +// +// Parameters: +// arg - The arg to remove. +// +// Remarks: +// This function will break ABI information of other arguments. The caller +// needs to know what they are doing. 
+// +void CallArgs::RemoveUnsafe(CallArg* arg) +{ + CallArg** slot = &m_lateHead; + while (*slot != nullptr) + { + if (*slot == arg) + { + *slot = arg->GetLateNext(); + break; + } + + slot = &(*slot)->LateNextRef(); + } + + slot = &m_head; + while (*slot != nullptr) + { + if (*slot == arg) + { + *slot = arg->GetNext(); + RemovedWellKnownArg(arg->GetWellKnownArg()); + return; + } + + slot = &(*slot)->NextRef(); + } + + assert(!"Did not find arg to remove in CallArgs::Remove"); +} + #ifdef TARGET_XARCH //--------------------------------------------------------------- // NeedsVzeroupper: Determines if the call needs a vzeroupper emitted before it is invoked @@ -9908,7 +9948,7 @@ GenTreeCall* Compiler::gtCloneExprCallHelper(GenTreeCall* tree) } else if (tree->IsAsync()) { - copy->asyncInfo = tree->asyncInfo; + copy->asyncInfo = new (this, CMK_Async) AsyncCallInfo(*tree->asyncInfo); } else if (tree->IsTailPrefixedCall()) { @@ -11571,9 +11611,9 @@ void Compiler::gtDispNode(GenTree* tree, IndentStack* indentStack, _In_ _In_opt_ { printf("(AX)"); // Variable has address exposed. } - if (varDsc->IsHiddenBufferStructArg()) + if (varDsc->IsDefinedViaAddress()) { - printf("(RB)"); // Variable is hidden return buffer + printf("(DA)"); // Variable is defined via address } if (varDsc->lvUnusedStruct) { @@ -13177,6 +13217,8 @@ const char* Compiler::gtGetWellKnownArgNameForArgMsg(WellKnownArg arg) return "&lcl arr"; case WellKnownArg::RuntimeMethodHandle: return "meth hnd"; + case WellKnownArg::AsyncSuspendedIndicator: + return "async susp"; default: return nullptr; } @@ -19555,7 +19597,33 @@ GenTreeLclVarCommon* Compiler::gtCallGetDefinedRetBufLclAddr(GenTreeCall* call) // This may be called very late to check validity of LIR. 
node = node->gtSkipReloadOrCopy(); - assert(node->OperIs(GT_LCL_ADDR) && lvaGetDesc(node->AsLclVarCommon())->IsHiddenBufferStructArg()); + assert(node->OperIs(GT_LCL_ADDR) && lvaGetDesc(node->AsLclVarCommon())->IsDefinedViaAddress()); + + return node->AsLclVarCommon(); +} + +//------------------------------------------------------------------------ +// gtCallGetDefinedAsyncSuspendedIndicatorLclAddr: +// Get the tree corresponding to the address of the indicator local that this call defines. +// +// Parameters: +// call - the Call node +// +// Returns: +// A tree representing the address of a local. +// +GenTreeLclVarCommon* Compiler::gtCallGetDefinedAsyncSuspendedIndicatorLclAddr(GenTreeCall* call) +{ + if (!call->IsAsync() || !call->GetAsyncInfo().HasSuspensionIndicatorDef) + { + return nullptr; + } + + CallArg* asyncSuspensionIndicatorArg = call->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator); + assert(asyncSuspensionIndicatorArg != nullptr); + GenTree* node = asyncSuspensionIndicatorArg->GetNode(); + + assert(node->OperIs(GT_LCL_ADDR) && lvaGetDesc(node->AsLclVarCommon())->IsDefinedViaAddress()); return node->AsLclVarCommon(); } diff --git a/src/coreclr/jit/gentree.h b/src/coreclr/jit/gentree.h index c618e5b84d1027..67be572a3b04a3 100644 --- a/src/coreclr/jit/gentree.h +++ b/src/coreclr/jit/gentree.h @@ -4359,8 +4359,42 @@ enum class ContinuationContextHandling // Additional async call info. struct AsyncCallInfo { - ExecutionContextHandling ExecutionContextHandling = ExecutionContextHandling::None; - ContinuationContextHandling ContinuationContextHandling = ContinuationContextHandling::None; + // The following information is used to implement the proper observable handling of `ExecutionContext`, + // `SynchronizationContext` and `TaskScheduler` in async methods. + // + // The breakdown of the handling is as follows: + // + // - For custom awaitables there is no special handling of `SynchronizationContext` or `TaskScheduler`. 
All the + // handling that exists is custom implemented by the user. In this case "ContinuationContextHandling == None" and + // "SaveAndRestoreSynchronizationContextField == false". + // + // - For custom awaitables there _is_ special handling of `ExecutionContext`: when the custom awaitable suspends, + // the JIT ensures that the `ExecutionContext` will be captured on suspension and restored when the continuation is + // running. This is represented by "ExecutionContextHandling == AsyncSaveAndRestore". + // + // - For task awaits there is special handling of `SynchronizationContext` and `TaskScheduler` in multiple ways: + // + // * The JIT ensures that `Thread.CurrentThread._synchronizationContext` is saved and restored around + // synchronously finishing calls. This is represented by "SaveAndRestoreSynchronizationContextField == true". + // + // * The JIT/runtime/BCL ensure that when the callee suspends, the caller will eventually be resumed on the + // `SynchronizationContext`/`TaskScheduler` present before the call started, depending on the configuration of the + // task await by the user. This resumption can be inlined if the `SynchronizationContext` is current when the + // continuation is about to run, and otherwise will be posted to it. This is represented by + // "ContinuationContextHandling == ContinueOnCapturedContext/ContinueOnThreadPool". + // + // * When the callee suspends, restoration of `Thread.CurrentThread._synchronizationContext` is left up to the + // custom implementation of the `SynchronizationContext`, it must not be done by the JIT. + // + // - For task awaits the runtime/BCL ensure that `Thread.CurrentThread._executionContext` is captured before the + // call and restored after it. This happens consistently regardless of whether the callee finishes synchronously or + // not. This is represented by "ExecutionContextHandling == SaveAndRestore". 
+ // + ExecutionContextHandling ExecutionContextHandling = ExecutionContextHandling::None; + ContinuationContextHandling ContinuationContextHandling = ContinuationContextHandling::None; + bool SaveAndRestoreSynchronizationContextField = false; + bool HasSuspensionIndicatorDef = false; + unsigned SynchronizationContextLclNum = BAD_VAR_NUM; }; // Return type descriptor of a GT_CALL node. @@ -4633,6 +4667,7 @@ enum class WellKnownArg : unsigned X86TailCallSpecialArg, StackArrayLocal, RuntimeMethodHandle, + AsyncSuspendedIndicator, }; #ifdef DEBUG @@ -4847,6 +4882,7 @@ class CallArgs CallArg* InsertAfterThisOrFirst(Compiler* comp, const NewCallArg& arg); void PushLateBack(CallArg* arg); void Remove(CallArg* arg); + void RemoveUnsafe(CallArg* arg); template void InternalCopyFrom(Compiler* comp, CallArgs* other, CopyNodeFunc copyFunc); @@ -5020,7 +5056,7 @@ struct GenTreeCall final : public GenTree // Only used for unmanaged calls, which cannot be tail-called CorInfoCallConvExtension unmgdCallConv; // Used for async calls - const AsyncCallInfo* asyncInfo; + AsyncCallInfo* asyncInfo; }; #if FEATURE_MULTIREG_RET @@ -5078,7 +5114,7 @@ struct GenTreeCall final : public GenTree #endif } - void SetIsAsync(const AsyncCallInfo* info) + void SetIsAsync(AsyncCallInfo* info) { assert(info != nullptr); gtCallMoreFlags |= GTF_CALL_M_ASYNC; diff --git a/src/coreclr/jit/gschecks.cpp b/src/coreclr/jit/gschecks.cpp index 75d5a365a5871f..4d7a54fcdcc2ff 100644 --- a/src/coreclr/jit/gschecks.cpp +++ b/src/coreclr/jit/gschecks.cpp @@ -410,7 +410,7 @@ void Compiler::gsParamsToShadows() shadowVarDsc->lvDoNotEnregister = varDsc->lvDoNotEnregister; #ifdef DEBUG shadowVarDsc->SetDoNotEnregReason(varDsc->GetDoNotEnregReason()); - shadowVarDsc->SetHiddenBufferStructArg(varDsc->IsHiddenBufferStructArg()); + shadowVarDsc->SetDefinedViaAddress(varDsc->IsDefinedViaAddress()); #endif if (varTypeIsStruct(type)) diff --git a/src/coreclr/jit/importercalls.cpp b/src/coreclr/jit/importercalls.cpp index 
79581b183fc04c..e1570911031abf 100644 --- a/src/coreclr/jit/importercalls.cpp +++ b/src/coreclr/jit/importercalls.cpp @@ -705,7 +705,8 @@ var_types Compiler::impImportCall(OPCODE opcode, { JITDUMP("Call is an async task await\n"); - asyncInfo.ExecutionContextHandling = ExecutionContextHandling::SaveAndRestore; + asyncInfo.ExecutionContextHandling = ExecutionContextHandling::SaveAndRestore; + asyncInfo.SaveAndRestoreSynchronizationContextField = true; if ((prefixFlags & PREFIX_TASK_AWAIT_CONTINUE_ON_CAPTURED_CONTEXT) != 0) { @@ -729,7 +730,7 @@ var_types Compiler::impImportCall(OPCODE opcode, asyncInfo.ExecutionContextHandling = ExecutionContextHandling::AsyncSaveAndRestore; } - // For tailcalls the context does not need saving/restoring: it will be + // For tailcalls the contexts do not need saving/restoring: they will be // overwritten by the caller anyway. // // More specifically, if we can show that @@ -738,7 +739,9 @@ var_types Compiler::impImportCall(OPCODE opcode, // context. We do not do that optimization yet. if (tailCallFlags != 0) { - asyncInfo.ExecutionContextHandling = ExecutionContextHandling::None; + asyncInfo.ExecutionContextHandling = ExecutionContextHandling::None; + asyncInfo.ContinuationContextHandling = ContinuationContextHandling::None; + asyncInfo.SaveAndRestoreSynchronizationContextField = false; } call->AsCall()->SetIsAsync(new (this, CMK_Async) AsyncCallInfo(asyncInfo)); diff --git a/src/coreclr/jit/lclmorph.cpp b/src/coreclr/jit/lclmorph.cpp index 7366dfe9bce1fd..e403d662b5b683 100644 --- a/src/coreclr/jit/lclmorph.cpp +++ b/src/coreclr/jit/lclmorph.cpp @@ -1462,9 +1462,9 @@ class LocalAddressVisitor final : public GenTreeVisitor unsigned lclNum = val.LclNum(); LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclNum); - GenTreeFlags defFlag = GTF_EMPTY; - GenTreeCall* callUser = (user != nullptr) && user->IsCall() ? 
user->AsCall() : nullptr; - bool hasHiddenStructArg = false; + GenTreeFlags defFlag = GTF_EMPTY; + GenTreeCall* callUser = (user != nullptr) && user->IsCall() ? user->AsCall() : nullptr; + bool escapeAddr = true; if (m_compiler->opts.compJitOptimizeStructHiddenBuffer && (callUser != nullptr) && m_compiler->IsValidLclAddr(lclNum, val.Offset())) { @@ -1484,7 +1484,7 @@ class LocalAddressVisitor final : public GenTreeVisitor (val.Node() == callUser->gtArgs.GetRetBufferArg()->GetNode())) { m_compiler->lvaSetHiddenBufferStructArg(lclNum); - hasHiddenStructArg = true; + escapeAddr = false; callUser->gtCallMoreFlags |= GTF_CALL_M_RETBUFFARG_LCLOPT; defFlag = GTF_VAR_DEF; @@ -1496,7 +1496,25 @@ class LocalAddressVisitor final : public GenTreeVisitor } } - if (!hasHiddenStructArg) + if ((callUser != nullptr) && callUser->IsAsync() && m_compiler->IsValidLclAddr(lclNum, val.Offset())) + { + CallArg* suspendedArg = callUser->gtArgs.FindWellKnownArg(WellKnownArg::AsyncSuspendedIndicator); + if ((suspendedArg != nullptr) && (val.Node() == suspendedArg->GetNode())) + { + INDEBUG(varDsc->SetDefinedViaAddress(true)); + escapeAddr = false; + defFlag = GTF_VAR_DEF; + + if ((val.Offset() != 0) || (varDsc->lvExactSize() != 1)) + { + defFlag |= GTF_VAR_USEASG; + } + + callUser->asyncInfo->HasSuspensionIndicatorDef = true; + } + } + + if (escapeAddr) { unsigned exposedLclNum = varDsc->lvIsStructField ? varDsc->lvParentLcl : lclNum; @@ -1516,7 +1534,8 @@ class LocalAddressVisitor final : public GenTreeVisitor // a ByRef to an INT32 when they actually write a SIZE_T or INT64. There are cases where // overwriting these extra 4 bytes corrupts some data (such as a saved register) that leads // to A/V. Whereas previously the JIT64 codegen did not lead to an A/V. 
- if ((callUser != nullptr) && !varDsc->lvIsParam && !varDsc->lvIsStructField && genActualTypeIsInt(varDsc)) + if ((callUser != nullptr) && !varDsc->lvIsParam && !varDsc->lvIsStructField && genActualTypeIsInt(varDsc) && + escapeAddr) { varDsc->lvQuirkToLong = true; JITDUMP("Adding a quirk for the storage size of V%02u of type %s\n", val.LclNum(), @@ -1618,7 +1637,7 @@ class LocalAddressVisitor final : public GenTreeVisitor assert(addr->TypeIs(TYP_BYREF, TYP_I_IMPL)); assert(m_compiler->lvaVarAddrExposed(lclNum) || ((m_lclAddrAssertions != nullptr) && m_lclAddrAssertions->IsMarkedForExposure(lclNum)) || - m_compiler->lvaGetDesc(lclNum)->IsHiddenBufferStructArg()); + m_compiler->lvaGetDesc(lclNum)->IsDefinedViaAddress()); if (m_compiler->IsValidLclAddr(lclNum, offset)) { diff --git a/src/coreclr/jit/lclvars.cpp b/src/coreclr/jit/lclvars.cpp index 07e584dbe441eb..f96943f2069bbe 100644 --- a/src/coreclr/jit/lclvars.cpp +++ b/src/coreclr/jit/lclvars.cpp @@ -2089,7 +2089,7 @@ void Compiler::lvaSetHiddenBufferStructArg(unsigned varNum) LclVarDsc* varDsc = lvaGetDesc(varNum); #ifdef DEBUG - varDsc->SetHiddenBufferStructArg(true); + varDsc->SetDefinedViaAddress(true); #endif if (varDsc->lvPromoted) @@ -2100,7 +2100,7 @@ void Compiler::lvaSetHiddenBufferStructArg(unsigned varNum) { noway_assert(lvaTable[i].lvIsStructField); #ifdef DEBUG - lvaTable[i].SetHiddenBufferStructArg(true); + lvaTable[i].SetDefinedViaAddress(true); #endif lvaSetVarDoNotEnregister(i DEBUGARG(DoNotEnregisterReason::HiddenBufferStructArg)); @@ -3435,7 +3435,7 @@ void Compiler::lvaMarkLclRefs(GenTree* tree, BasicBlock* block, Statement* stmt, if (tree->OperIs(GT_LCL_ADDR)) { LclVarDsc* varDsc = lvaGetDesc(tree->AsLclVarCommon()); - assert(varDsc->IsAddressExposed() || varDsc->IsHiddenBufferStructArg()); + assert(varDsc->IsAddressExposed() || varDsc->IsDefinedViaAddress()); varDsc->incRefCnts(weight, this); return; } @@ -6365,9 +6365,9 @@ void Compiler::lvaDumpEntry(unsigned lclNum, FrameLayoutState 
curState, size_t r { printf("X"); } - if (varDsc->IsHiddenBufferStructArg()) + if (varDsc->IsDefinedViaAddress()) { - printf("H"); + printf("DA"); } if (varTypeIsStruct(varDsc)) { @@ -6428,9 +6428,9 @@ void Compiler::lvaDumpEntry(unsigned lclNum, FrameLayoutState curState, size_t r { printf(" addr-exposed"); } - if (varDsc->IsHiddenBufferStructArg()) + if (varDsc->IsDefinedViaAddress()) { - printf(" hidden-struct-arg"); + printf(" defined-via-address"); } if (varDsc->lvHasLdAddrOp) { diff --git a/src/coreclr/jit/liveness.cpp b/src/coreclr/jit/liveness.cpp index 08ef18528c1867..617067e7aeefa8 100644 --- a/src/coreclr/jit/liveness.cpp +++ b/src/coreclr/jit/liveness.cpp @@ -66,9 +66,10 @@ void Compiler::fgMarkUseDef(GenTreeLclVarCommon* tree) if (compRationalIRForm && (varDsc->lvType != TYP_STRUCT) && !varTypeIsMultiReg(varDsc)) { - // If this is an enregisterable variable that is not marked doNotEnregister, + // If this is an enregisterable variable that is not marked doNotEnregister and not defined via address, // we should only see direct references (not ADDRs). 
- assert(varDsc->lvDoNotEnregister || tree->OperIs(GT_LCL_VAR, GT_STORE_LCL_VAR)); + assert(varDsc->lvDoNotEnregister || varDsc->lvDefinedViaAddress || + tree->OperIs(GT_LCL_VAR, GT_STORE_LCL_VAR)); } if (isUse && !VarSetOps::IsMember(this, fgCurDefSet, varDsc->lvVarIndex)) diff --git a/src/coreclr/jit/morph.cpp b/src/coreclr/jit/morph.cpp index 84a5de5a58f6df..11e501c2f2b13d 100644 --- a/src/coreclr/jit/morph.cpp +++ b/src/coreclr/jit/morph.cpp @@ -678,6 +678,8 @@ const char* getWellKnownArgName(WellKnownArg arg) return "StackArrayLocal"; case WellKnownArg::RuntimeMethodHandle: return "RuntimeMethodHandle"; + case WellKnownArg::AsyncSuspendedIndicator: + return "AsyncSuspendedIndicator"; } return "N/A"; @@ -1842,8 +1844,6 @@ void CallArgs::AddFinalArgsAndDetermineABIInfo(Compiler* comp, GenTreeCall* call const CORINFO_CLASS_HANDLE argSigClass = arg.GetSignatureClassHandle(); ClassLayout* argLayout = argSigClass == NO_CLASS_HANDLE ? nullptr : comp->typGetObjLayout(argSigClass); - ABIPassingInformation abiInfo; - // Some well known args have custom register assignment. // These should not affect the placement of any other args or stack space required. // Example: on AMD64 R10 and R11 are used for indirect VSD (generic interface) and cookie calls. @@ -1852,20 +1852,26 @@ void CallArgs::AddFinalArgsAndDetermineABIInfo(Compiler* comp, GenTreeCall* call if (nonStdRegNum == REG_NA) { - abiInfo = classifier.Classify(comp, argSigType, argLayout, arg.GetWellKnownArg()); + if (arg.GetWellKnownArg() == WellKnownArg::AsyncSuspendedIndicator) + { + // Represents definition of a local. Expanded out by async transformation. 
+ arg.AbiInfo = ABIPassingInformation(comp, 0); + } + else + { + arg.AbiInfo = classifier.Classify(comp, argSigType, argLayout, arg.GetWellKnownArg()); + } } else { ABIPassingSegment segment = ABIPassingSegment::InRegister(nonStdRegNum, 0, TARGET_POINTER_SIZE); - abiInfo = ABIPassingInformation::FromSegmentByValue(comp, segment); + arg.AbiInfo = ABIPassingInformation::FromSegmentByValue(comp, segment); } JITDUMP("Argument %u ABI info: ", GetIndex(&arg)); - DBEXEC(VERBOSE, abiInfo.Dump()); - - arg.AbiInfo = abiInfo; + DBEXEC(VERBOSE, arg.AbiInfo.Dump()); - for (const ABIPassingSegment& segment : abiInfo.Segments()) + for (const ABIPassingSegment& segment : arg.AbiInfo.Segments()) { if (segment.IsPassedOnStack()) { @@ -1919,6 +1925,13 @@ void CallArgs::DetermineABIInfo(Compiler* comp, GenTreeCall* call) for (CallArg& arg : Args()) { + if (arg.GetWellKnownArg() == WellKnownArg::AsyncSuspendedIndicator) + { + // Represents definition of a local. Expanded out by async transformation. + arg.AbiInfo = ABIPassingInformation(comp, 0); + continue; + } + const var_types argSigType = arg.GetSignatureType(); const CORINFO_CLASS_HANDLE argSigClass = arg.GetSignatureClassHandle(); ClassLayout* argLayout = argSigClass == NO_CLASS_HANDLE ? 
nullptr : comp->typGetObjLayout(argSigClass); diff --git a/src/coreclr/jit/promotion.cpp b/src/coreclr/jit/promotion.cpp index b4e2e284d051e4..8a7d53f27e172b 100644 --- a/src/coreclr/jit/promotion.cpp +++ b/src/coreclr/jit/promotion.cpp @@ -1126,7 +1126,7 @@ class LocalsUseVisitor : public GenTreeVisitor if (lcl->OperIs(GT_LCL_ADDR)) { - assert(user->OperIs(GT_CALL) && dsc->IsHiddenBufferStructArg() && + assert(user->OperIs(GT_CALL) && dsc->IsDefinedViaAddress() && (user->AsCall()->gtArgs.GetRetBufferArg()->GetNode() == lcl)); accessType = TYP_STRUCT; diff --git a/src/coreclr/jit/valuenum.cpp b/src/coreclr/jit/valuenum.cpp index 962287ba6a517e..01f0d4f0509a42 100644 --- a/src/coreclr/jit/valuenum.cpp +++ b/src/coreclr/jit/valuenum.cpp @@ -12543,7 +12543,7 @@ void Compiler::fgValueNumberTree(GenTree* tree) unsigned lclOffs = tree->AsLclFld()->GetLclOffs(); tree->gtVNPair.SetBoth(vnStore->VNForFunc(TYP_BYREF, VNF_PtrToLoc, vnStore->VNForIntCon(lclNum), vnStore->VNForIntPtrCon(lclOffs))); - assert(lvaGetDesc(lclNum)->IsAddressExposed() || lvaGetDesc(lclNum)->IsHiddenBufferStructArg()); + assert(lvaGetDesc(lclNum)->IsAddressExposed() || lvaGetDesc(lclNum)->IsDefinedViaAddress()); } break; diff --git a/src/coreclr/tools/superpmi/superpmi-shared/agnostic.h b/src/coreclr/tools/superpmi/superpmi-shared/agnostic.h index 1a04979bad37a5..3097e47190040e 100644 --- a/src/coreclr/tools/superpmi/superpmi-shared/agnostic.h +++ b/src/coreclr/tools/superpmi/superpmi-shared/agnostic.h @@ -205,6 +205,9 @@ struct Agnostic_CORINFO_ASYNC_INFO DWORD continuationsNeedMethodHandle; DWORDLONG captureExecutionContextMethHnd; DWORDLONG restoreExecutionContextMethHnd; + DWORDLONG captureContinuationContextMethHnd; + DWORDLONG captureContextsMethHnd; + DWORDLONG restoreContextsMethHnd; }; struct Agnostic_GetOSRInfo diff --git a/src/coreclr/tools/superpmi/superpmi-shared/methodcontext.cpp b/src/coreclr/tools/superpmi/superpmi-shared/methodcontext.cpp index 152e1299e89465..66c7b548efd8d2 100644 
--- a/src/coreclr/tools/superpmi/superpmi-shared/methodcontext.cpp +++ b/src/coreclr/tools/superpmi/superpmi-shared/methodcontext.cpp @@ -4486,6 +4486,9 @@ void MethodContext::recGetAsyncInfo(const CORINFO_ASYNC_INFO* pAsyncInfo) value.continuationsNeedMethodHandle = pAsyncInfo->continuationsNeedMethodHandle ? 1 : 0; value.captureExecutionContextMethHnd = CastHandle(pAsyncInfo->captureExecutionContextMethHnd); value.restoreExecutionContextMethHnd = CastHandle(pAsyncInfo->restoreExecutionContextMethHnd); + value.captureContinuationContextMethHnd = CastHandle(pAsyncInfo->captureContinuationContextMethHnd); + value.captureContextsMethHnd = CastHandle(pAsyncInfo->captureContextsMethHnd); + value.restoreContextsMethHnd = CastHandle(pAsyncInfo->restoreContextsMethHnd); GetAsyncInfo->Add(0, value); DEBUG_REC(dmpGetAsyncInfo(0, value)); @@ -4511,6 +4514,9 @@ void MethodContext::repGetAsyncInfo(CORINFO_ASYNC_INFO* pAsyncInfoOut) pAsyncInfoOut->continuationsNeedMethodHandle = value.continuationsNeedMethodHandle != 0; pAsyncInfoOut->captureExecutionContextMethHnd = (CORINFO_METHOD_HANDLE)value.captureExecutionContextMethHnd; pAsyncInfoOut->restoreExecutionContextMethHnd = (CORINFO_METHOD_HANDLE)value.restoreExecutionContextMethHnd; + pAsyncInfoOut->captureContinuationContextMethHnd = (CORINFO_METHOD_HANDLE)value.captureContinuationContextMethHnd; + pAsyncInfoOut->captureContextsMethHnd = (CORINFO_METHOD_HANDLE)value.captureContextsMethHnd; + pAsyncInfoOut->restoreContextsMethHnd = (CORINFO_METHOD_HANDLE)value.restoreContextsMethHnd; DEBUG_REP(dmpGetAsyncInfo(0, value)); } diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index 5b894a5312d97c..3e47ccfab8068e 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -729,6 +729,8 @@ DEFINE_METHOD(ASYNC_HELPERS, UNSAFE_AWAIT_AWAITER_1, UnsafeAwaitAwaiter, DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_EXECUTION_CONTEXT, CaptureExecutionContext, NoSig) DEFINE_METHOD(ASYNC_HELPERS, RESTORE_EXECUTION_CONTEXT, 
RestoreExecutionContext, NoSig) DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_CONTINUATION_CONTEXT, CaptureContinuationContext, NoSig) +DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_CONTEXTS, CaptureContexts, NoSig) +DEFINE_METHOD(ASYNC_HELPERS, RESTORE_CONTEXTS, RestoreContexts, NoSig) DEFINE_CLASS(SPAN_HELPERS, System, SpanHelpers) DEFINE_METHOD(SPAN_HELPERS, MEMSET, Fill, SM_RefByte_Byte_UIntPtr_RetVoid) diff --git a/src/coreclr/vm/jitinterface.cpp b/src/coreclr/vm/jitinterface.cpp index c1bd8aad1fc45d..386a585ba559a2 100644 --- a/src/coreclr/vm/jitinterface.cpp +++ b/src/coreclr/vm/jitinterface.cpp @@ -10258,6 +10258,8 @@ void CEEInfo::getAsyncInfo(CORINFO_ASYNC_INFO* pAsyncInfoOut) pAsyncInfoOut->captureExecutionContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__CAPTURE_EXECUTION_CONTEXT)); pAsyncInfoOut->restoreExecutionContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__RESTORE_EXECUTION_CONTEXT)); pAsyncInfoOut->captureContinuationContextMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__CAPTURE_CONTINUATION_CONTEXT)); + pAsyncInfoOut->captureContextsMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__CAPTURE_CONTEXTS)); + pAsyncInfoOut->restoreContextsMethHnd = CORINFO_METHOD_HANDLE(CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__RESTORE_CONTEXTS)); EE_TO_JIT_TRANSITION(); } diff --git a/src/tests/async/synchronization-context/synchronization-context.cs b/src/tests/async/synchronization-context/synchronization-context.cs index e028ee09c9b061..e3a5e2eed606b4 100644 --- a/src/tests/async/synchronization-context/synchronization-context.cs +++ b/src/tests/async/synchronization-context/synchronization-context.cs @@ -11,13 +11,13 @@ public class Async2SynchronizationContext { [Fact] - public static void TestSyncContexts() + public static void TestSyncContextContinue() { SynchronizationContext prevContext = SynchronizationContext.Current; try { 
SynchronizationContext.SetSynchronizationContext(new MySyncContext()); - TestSyncContext().GetAwaiter().GetResult(); + TestSyncContextContinueAsync().GetAwaiter().GetResult(); } finally { @@ -25,7 +25,7 @@ public static void TestSyncContexts() } } - private static async Task TestSyncContext() + private static async Task TestSyncContextContinueAsync() { MySyncContext context = (MySyncContext)SynchronizationContext.Current; await WrappedYieldToThreadPool(suspend: false); @@ -104,4 +104,113 @@ public void OnCompleted(Action continuation) public void GetResult() { } } + + [Fact] + public static void TestSyncContextSaveRestore() + { + SynchronizationContext prevContext = SynchronizationContext.Current; + try + { + SynchronizationContext.SetSynchronizationContext(new SyncContextWithoutRestore()); + TestSyncContextSaveRestoreAsync().GetAwaiter().GetResult(); + } + finally + { + SynchronizationContext.SetSynchronizationContext(prevContext); + } + } + + private static async Task TestSyncContextSaveRestoreAsync() + { + Assert.True(SynchronizationContext.Current is SyncContextWithoutRestore); + await ClearSyncContext(); + Assert.True(SynchronizationContext.Current is SyncContextWithoutRestore); + } + + private static async Task ClearSyncContext() + { + SynchronizationContext.SetSynchronizationContext(null); + } + + [Fact] + public static void TestSyncContextNotRestored() + { + SynchronizationContext prevContext = SynchronizationContext.Current; + try + { + SynchronizationContext.SetSynchronizationContext(new SyncContextWithoutRestore()); + TestSyncContextNotRestoredAsync().GetAwaiter().GetResult(); + } + finally + { + SynchronizationContext.SetSynchronizationContext(prevContext); + } + } + + private static async Task TestSyncContextNotRestoredAsync() + { + Assert.True(SynchronizationContext.Current is SyncContextWithoutRestore); + await SuspendThenClearSyncContext(); + Assert.Null(SynchronizationContext.Current); + } + + private static async Task SuspendThenClearSyncContext() 
+ { + Assert.True(SynchronizationContext.Current is SyncContextWithoutRestore); + SyncContextWithoutRestore syncCtx = (SyncContextWithoutRestore)SyncContextWithoutRestore.Current; + Assert.Equal(0, syncCtx.NumPosts); + + await Task.Yield(); + Assert.Null(SynchronizationContext.Current); + Assert.Equal(1, syncCtx.NumPosts); + } + + private class SyncContextWithoutRestore : SynchronizationContext + { + public int NumPosts; + + public override void Post(SendOrPostCallback d, object state) + { + NumPosts++; + ThreadPool.UnsafeQueueUserWorkItem(_ => + { + d(state); + }, null); + } + } + + [Fact] + public static void TestContinueOnCorrectSyncContext() + { + SynchronizationContext prevContext = SynchronizationContext.Current; + try + { + TestContinueOnCorrectSyncContextAsync().GetAwaiter().GetResult(); + } + finally + { + SynchronizationContext.SetSynchronizationContext(prevContext); + } + } + + private static async Task TestContinueOnCorrectSyncContextAsync() + { + MySyncContext context1 = new MySyncContext(); + MySyncContext context2 = new MySyncContext(); + + SynchronizationContext.SetSynchronizationContext(context1); + await SetContext(context2, suspend: false); + Assert.True(SynchronizationContext.Current == context1); + + await SetContext(context2, suspend: true); + Assert.True(SynchronizationContext.Current == context1); + } + + private static async Task SetContext(SynchronizationContext context, bool suspend) + { + SynchronizationContext.SetSynchronizationContext(context); + + if (suspend) + await Task.Yield(); + } }