diff --git a/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj b/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj index 10b04a41383d42..560eef2279e8fa 100644 --- a/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj +++ b/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj @@ -206,6 +206,7 @@ + diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index bf742b7a505f66..29948314acafbb 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -23,6 +23,8 @@ internal enum ContinuationFlags ContinueOnCapturedSynchronizationContext = 1 << 1, ContinueOnCapturedTaskScheduler = 1 << 2, + AllContinuationFlags = ContinueOnThreadPool | ContinueOnCapturedSynchronizationContext | ContinueOnCapturedTaskScheduler, + // The flags encode where in the continuation various members are stored. // If the encoded index is 0, it means no such member is present. // Otherwise the exact offset of the member is computed as @@ -213,7 +215,7 @@ private ref struct RuntimeAsyncStackState public ICriticalNotifyCompletion? CriticalNotifier; public INotifyCompletion? Notifier; public ValueTaskContinuation? ValueTaskContinuation; - public Task? TaskNotifier; + public TaskContinuation? TaskContinuation; // When we suspend in the leaf, the contexts are captured into these fields. public ExecutionContext? LeafExecutionContext; @@ -256,6 +258,7 @@ private unsafe struct RuntimeAsyncAwaitState { public Continuation? SentinelContinuation; public ValueTaskContinuation? CachedValueTaskContinuation; + public TaskContinuation? CachedTaskContinuation; // We cache the thread here to avoid unnecessary repeated TLS lookups. public Thread? CurrentThread; @@ -470,18 +473,61 @@ private static unsafe void AwaitValueTaskSourceOfT(object source, short token /// /// Used by internal thunks that implement awaiting on Task. /// - /// Task whose completion we are awaiting. + /// Task whose completion we are awaiting. + [BypassReadyToRun] + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] + private static unsafe void TransparentAwait(Task task) + { + ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; + Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); + + TaskContinuation? taskCont = state.CachedTaskContinuation; + if (taskCont != null) + { + state.CachedTaskContinuation = null; + } + else + { + taskCont = new TaskContinuation(); + } + + taskCont.Initialize(task); + + sentinelContinuation.Next = taskCont; + state.StackState->TaskContinuation = taskCont; + + state.CaptureContexts(); + AsyncSuspend(taskCont); + } + + /// + /// Used by internal thunks that implement awaiting on Task. + /// + /// Task whose completion we are awaiting. [BypassReadyToRun] [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] - private static unsafe void TransparentAwait(Task t) + private static unsafe void TransparentAwaitOfT(Task task) { ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation(); - state.StackState->TaskNotifier = t; + TaskContinuation? taskCont = state.CachedTaskContinuation; + if (taskCont != null) + { + state.CachedTaskContinuation = null; + } + else + { + taskCont = new TaskContinuation(); + } + + taskCont.Initialize(task); + + sentinelContinuation.Next = taskCont; + state.StackState->TaskContinuation = taskCont; state.CaptureContexts(); - AsyncSuspend(sentinelContinuation); + AsyncSuspend(taskCont); } // Represents execution of a chain of suspended and resuming runtime @@ -553,14 +599,6 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) Continuation headContinuation = sentinelContinuation.Next!; sentinelContinuation.Next = null; - // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. - // These never have special continuation context handling. - // Except for the scenario with ValueTaskContinuation that wraps ValueTaskSource - // which can capture continuation context flags. - const ContinuationFlags continueFlags = - ContinuationFlags.ContinueOnCapturedSynchronizationContext | - ContinuationFlags.ContinueOnThreadPool | - ContinuationFlags.ContinueOnCapturedTaskScheduler; SetContinuationState(headContinuation); @@ -568,16 +606,22 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) { if (stackState->CriticalNotifier is { } critNotifier) { - Debug.Assert((headContinuation.Flags & continueFlags) == 0); + // Result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. + // These never have special continuation context handling. + Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0); critNotifier.UnsafeOnCompleted(GetContinuationAction()); } - else if (stackState->TaskNotifier is { } taskNotifier) + else if (stackState->TaskContinuation is { } taskCont) { - Debug.Assert((headContinuation.Flags & continueFlags) == 0); + Debug.Assert(headContinuation == taskCont); + // Similarly for transparent awwaits we do not expect + // any continuation flags. + Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0); // Runtime async callable wrapper for task returning // method. This implements the context transparent // forwarding and makes these wrappers minimal cost. - if (!taskNotifier.TryAddCompletionAction(this)) + Debug.Assert(taskCont.Task != null); + if (!taskCont.Task.TryAddCompletionAction(this)) { ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true); } @@ -589,7 +633,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) Debug.Assert(source != null); if (source is Task t) { - Debug.Assert((headContinuation.Flags & continueFlags) == 0); + Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0); if (!t.TryAddCompletionAction(this)) { ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true); @@ -612,7 +656,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) // the direct AsyncHelpers.Await(ValueTask/ValueTask) path. // In either case, that can only happen in nontransparent/user code. Continuation contWithContinueFlags = valueTaskSourceCont; - while ((contWithContinueFlags.Flags & continueFlags) == 0 && contWithContinueFlags.Next != null) + while ((contWithContinueFlags.Flags & ContinuationFlags.AllContinuationFlags) == 0 && contWithContinueFlags.Next != null) { contWithContinueFlags = contWithContinueFlags.Next; } @@ -630,7 +674,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) } // Clear continuation flags, so that continuation runs transparently - contWithContinueFlags.Flags &= ~continueFlags; + contWithContinueFlags.Flags &= ~ContinuationFlags.AllContinuationFlags; valueTaskSourceCont.OnCompletedValueTaskSource( source, @@ -642,7 +686,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state) } else { - Debug.Assert((headContinuation.Flags & continueFlags) == 0); + Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0); Debug.Assert(stackState->Notifier != null); stackState->Notifier!.OnCompleted(GetContinuationAction()); } diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.TaskContinuation.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.TaskContinuation.cs new file mode 100644 index 00000000000000..9c8b4b2512a2fd --- /dev/null +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.TaskContinuation.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Threading.Tasks; + +namespace System.Runtime.CompilerServices +{ + public static partial class AsyncHelpers + { + private sealed unsafe class TaskContinuation : Continuation + { + internal Task? Task; + private delegate* _getResult; + + public TaskContinuation() + { + ResumeInfo = (ResumeInfo*)Unsafe.AsPointer(in TaskContinuationResume.ResumeInfo); + } + + public void GetResult(ref byte returnValue) + { + Debug.Assert(Task != null); + + // Avoid retaining the task. The call below may throw. + Task task = Task; + Task = null; + + _getResult(task, ref returnValue); + } + + public void Initialize(Task task) + { + Task = task; + _getResult = &GetResult; + } + + public void Initialize(Task task) + { + Task = task; + _getResult = &GetResult; + } + + private static void GetResult(Task task, ref byte result) + { + TaskAwaiter.ValidateEnd(task); + } + + private static void GetResult(Task task, ref byte result) + { + Debug.Assert(task is Task); + + Task taskOfT = Unsafe.As>(ref task); + TaskAwaiter.ValidateEnd(taskOfT); + Unsafe.As(ref result) = taskOfT.ResultOnSuccess; + } + + private static class TaskContinuationResume + { + [FixedAddressValueType] + public static readonly ResumeInfo ResumeInfo = new ResumeInfo + { + DiagnosticIP = null, + Resume = &ResumeTaskContinuation, + }; + + private static Continuation? ResumeTaskContinuation(Continuation cont, ref byte result) + { + var taskCont = (TaskContinuation)cont; + taskCont.Next = null; + + Debug.Assert((taskCont.Flags & ContinuationFlags.AllContinuationFlags) == 0); + + t_runtimeAsyncAwaitState.CachedTaskContinuation = taskCont; + + taskCont.GetResult(ref result); + return null; + } + } + } + } +} diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs index 7d5af875d08a50..7e6693599738f8 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs @@ -106,12 +106,7 @@ private static class ValueTaskContinuationResume var vtsCont = (ValueTaskContinuation)cont; vtsCont.Next = null; - const ContinuationFlags continueFlags = - ContinuationFlags.ContinueOnCapturedSynchronizationContext | - ContinuationFlags.ContinueOnThreadPool | - ContinuationFlags.ContinueOnCapturedTaskScheduler; - - Debug.Assert((vtsCont.Flags & continueFlags) == 0); + Debug.Assert((vtsCont.Flags & ContinuationFlags.AllContinuationFlags) == 0); t_runtimeAsyncAwaitState.CachedValueTaskContinuation = vtsCont; diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj b/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj index f687deaff880bc..3108c8dea494d1 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj @@ -64,6 +64,7 @@ + diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs index 8dac68d3d4bb6a..4773b38bc9aebd 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs @@ -341,6 +341,7 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t // Task path TypeDesc taskType = taskReturningMethodReturnType; MethodDesc completedTaskResultMethod; + MethodDesc transparentAwaitMethod; if (!taskReturningMethodReturnType.HasInstantiation) { @@ -348,6 +349,9 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t completedTaskResultMethod = context.SystemModule .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) .GetKnownMethod("CompletedTask"u8, null); + transparentAwaitMethod = context.SystemModule + .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + .GetKnownMethod("TransparentAwait"u8, null); } else { @@ -357,7 +361,12 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t MethodDesc completedTaskResultMethodOpen = context.SystemModule .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) .GetKnownMethod("CompletedTaskResult"u8, null); + MethodDesc transparentAwaitMethodOpen = context.SystemModule + .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + .GetKnownMethod("TransparentAwaitOfT"u8, null); + completedTaskResultMethod = completedTaskResultMethodOpen.MakeInstantiatedMethod(new Instantiation(logicalReturnType)); + transparentAwaitMethod = transparentAwaitMethodOpen.MakeInstantiatedMethod(new Instantiation(logicalReturnType)); } ILLocalVariable taskLocal = emitter.NewLocal(taskType); @@ -373,9 +382,8 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t codestream.Emit(ILOpcode.brtrue, getResultLabel); codestream.EmitLdLoc(taskLocal); - codestream.Emit(ILOpcode.call, emitter.NewToken( - context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) - .GetKnownMethod("TransparentAwait"u8, null))); + codestream.Emit(ILOpcode.call, emitter.NewToken(context.GetCoreLibEntryPoint("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8, "TailAwait"u8, null))); + codestream.Emit(ILOpcode.call, emitter.NewToken(transparentAwaitMethod)); codestream.EmitLabel(getResultLabel); codestream.EmitLdLoc(taskLocal); diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp index c3dfa0a60fc84a..ed80ce9a6c19e0 100644 --- a/src/coreclr/vm/asyncthunks.cpp +++ b/src/coreclr/vm/asyncthunks.cpp @@ -487,7 +487,7 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig // Task task = other(arg); // if (!task.IsCompleted) // { - // // Magic function which will suspend the current run of async methods + // TailAwait(); // AsyncHelpers.TransparentAwait(task); // } // return AsyncHelpers.CompletedTaskResult(task); @@ -595,12 +595,17 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig MethodTable* pMTTask; int completedTaskResultToken; + int transparentAwaitToken; + if (msig.IsReturnTypeVoid()) { pMTTask = CoreLibBinder::GetClass(CLASS__TASK); MethodDesc* pMDCompletedTask = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK); + MethodDesc* pMDTransparentAwait = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT); + completedTaskResultToken = pCode->GetToken(pMDCompletedTask); + transparentAwaitToken = pCode->GetToken(pMDTransparentAwait); } else { @@ -608,24 +613,33 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig pMTTask = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskOpen->GetModule(), pMTTaskOpen->GetCl(), Instantiation(&thLogicalRetType, 1)).GetMethodTable(); MethodDesc* pMDCompletedTaskResult = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK_RESULT); + MethodDesc* pMDTransparentAwait = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT_OF_T); + pMDCompletedTaskResult = FindOrCreateAssociatedMethodDesc(pMDCompletedTaskResult, pMDCompletedTaskResult->GetMethodTable(), FALSE, Instantiation(&thLogicalRetType, 1), FALSE); + pMDTransparentAwait = FindOrCreateAssociatedMethodDesc(pMDTransparentAwait, pMDTransparentAwait->GetMethodTable(), FALSE, Instantiation(&thLogicalRetType, 1), FALSE); + completedTaskResultToken = GetTokenForGenericMethodCallWithAsyncReturnType(pCode, pMDCompletedTaskResult); + transparentAwaitToken = GetTokenForGenericMethodCallWithAsyncReturnType(pCode, pMDTransparentAwait); } LocalDesc taskLocalDesc(pMTTask); DWORD taskLocal = pCode->NewLocal(taskLocalDesc); ILCodeLabel* pGetResultLabel = pCode->NewCodeLabel(); - // Store task returned by actual user func or by ValueTask.AsTask + // Store task returned by actual user func pCode->EmitSTLOC(taskLocal); + // Did it already complete? pCode->EmitLDLOC(taskLocal); pCode->EmitCALL(METHOD__TASK__GET_ISCOMPLETED, 1, 1); pCode->EmitBRTRUE(pGetResultLabel); + // No, so tail await to TransparentAwait pCode->EmitLDLOC(taskLocal); - pCode->EmitCALL(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT, 1, 0); + pCode->EmitCALL(METHOD__ASYNC_HELPERS__TAIL_AWAIT, 0, 0); + pCode->EmitCALL(transparentAwaitToken, 1, 0); + // Yes, so just get the result pCode->EmitLabel(pGetResultLabel); pCode->EmitLDLOC(taskLocal); pCode->EmitCALL(completedTaskResultToken, 1, msig.IsReturnTypeVoid() ? 0 : 1); diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index 476878a901f333..cc515fd5f610e8 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -716,6 +716,7 @@ DEFINE_METHOD(ASYNC_HELPERS, VALUETASK_FROM_EXCEPTION, ValueTaskFromExcepti DEFINE_METHOD(ASYNC_HELPERS, VALUETASK_FROM_EXCEPTION_1, ValueTaskFromException, GM_Exception_RetValueTaskOfT) DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT, TransparentAwait, NoSig) +DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT_OF_T, TransparentAwaitOfT, NoSig) DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT_VALUE_TASK, TransparentAwaitValueTask, NoSig) DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT_VALUE_TASK_OF_T, TransparentAwaitValueTaskOfT, NoSig) DEFINE_METHOD(ASYNC_HELPERS, COMPLETED_TASK_RESULT, CompletedTaskResult, NoSig)