diff --git a/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj b/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj
index 10b04a41383d42..560eef2279e8fa 100644
--- a/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj
+++ b/src/coreclr/System.Private.CoreLib/System.Private.CoreLib.csproj
@@ -206,6 +206,7 @@
+
diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs
index bf742b7a505f66..29948314acafbb 100644
--- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs
+++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs
@@ -23,6 +23,8 @@ internal enum ContinuationFlags
ContinueOnCapturedSynchronizationContext = 1 << 1,
ContinueOnCapturedTaskScheduler = 1 << 2,
+ AllContinuationFlags = ContinueOnThreadPool | ContinueOnCapturedSynchronizationContext | ContinueOnCapturedTaskScheduler,
+
// The flags encode where in the continuation various members are stored.
// If the encoded index is 0, it means no such member is present.
// Otherwise the exact offset of the member is computed as
@@ -213,7 +215,7 @@ private ref struct RuntimeAsyncStackState
public ICriticalNotifyCompletion? CriticalNotifier;
public INotifyCompletion? Notifier;
public ValueTaskContinuation? ValueTaskContinuation;
- public Task? TaskNotifier;
+ public TaskContinuation? TaskContinuation;
// When we suspend in the leaf, the contexts are captured into these fields.
public ExecutionContext? LeafExecutionContext;
@@ -256,6 +258,7 @@ private unsafe struct RuntimeAsyncAwaitState
{
public Continuation? SentinelContinuation;
public ValueTaskContinuation? CachedValueTaskContinuation;
+ public TaskContinuation? CachedTaskContinuation;
// We cache the thread here to avoid unnecessary repeated TLS lookups.
public Thread? CurrentThread;
@@ -470,18 +473,61 @@ private static unsafe void AwaitValueTaskSourceOfT(object source, short token
///
/// Used by internal thunks that implement awaiting on Task.
///
- /// Task whose completion we are awaiting.
+ /// Task whose completion we are awaiting.
+ [BypassReadyToRun]
+ [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
+ private static unsafe void TransparentAwait(Task task)
+ {
+ ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
+ Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();
+
+ TaskContinuation? taskCont = state.CachedTaskContinuation;
+ if (taskCont != null)
+ {
+ state.CachedTaskContinuation = null;
+ }
+ else
+ {
+ taskCont = new TaskContinuation();
+ }
+
+ taskCont.Initialize(task);
+
+ sentinelContinuation.Next = taskCont;
+ state.StackState->TaskContinuation = taskCont;
+
+ state.CaptureContexts();
+ AsyncSuspend(taskCont);
+ }
+
+ ///
+ /// Used by internal thunks that implement awaiting on Task.
+ ///
+ /// Task whose completion we are awaiting.
[BypassReadyToRun]
[MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)]
- private static unsafe void TransparentAwait(Task t)
+ private static unsafe void TransparentAwaitOfT(Task task)
{
ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState;
Continuation? sentinelContinuation = state.SentinelContinuation ??= new Continuation();
- state.StackState->TaskNotifier = t;
+ TaskContinuation? taskCont = state.CachedTaskContinuation;
+ if (taskCont != null)
+ {
+ state.CachedTaskContinuation = null;
+ }
+ else
+ {
+ taskCont = new TaskContinuation();
+ }
+
+ taskCont.Initialize(task);
+
+ sentinelContinuation.Next = taskCont;
+ state.StackState->TaskContinuation = taskCont;
state.CaptureContexts();
- AsyncSuspend(sentinelContinuation);
+ AsyncSuspend(taskCont);
}
// Represents execution of a chain of suspended and resuming runtime
@@ -553,14 +599,6 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
Continuation headContinuation = sentinelContinuation.Next!;
sentinelContinuation.Next = null;
- // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
- // These never have special continuation context handling.
- // Except for the scenario with ValueTaskContinuation that wraps ValueTaskSource
- // which can capture continuation context flags.
- const ContinuationFlags continueFlags =
- ContinuationFlags.ContinueOnCapturedSynchronizationContext |
- ContinuationFlags.ContinueOnThreadPool |
- ContinuationFlags.ContinueOnCapturedTaskScheduler;
SetContinuationState(headContinuation);
@@ -568,16 +606,22 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
{
if (stackState->CriticalNotifier is { } critNotifier)
{
- Debug.Assert((headContinuation.Flags & continueFlags) == 0);
+ // Result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
+ // These never have special continuation context handling.
+ Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
critNotifier.UnsafeOnCompleted(GetContinuationAction());
}
- else if (stackState->TaskNotifier is { } taskNotifier)
+ else if (stackState->TaskContinuation is { } taskCont)
{
- Debug.Assert((headContinuation.Flags & continueFlags) == 0);
+ Debug.Assert(headContinuation == taskCont);
+ // Similarly for transparent awwaits we do not expect
+ // any continuation flags.
+ Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
// Runtime async callable wrapper for task returning
// method. This implements the context transparent
// forwarding and makes these wrappers minimal cost.
- if (!taskNotifier.TryAddCompletionAction(this))
+ Debug.Assert(taskCont.Task != null);
+ if (!taskCont.Task.TryAddCompletionAction(this))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true);
}
@@ -589,7 +633,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
Debug.Assert(source != null);
if (source is Task t)
{
- Debug.Assert((headContinuation.Flags & continueFlags) == 0);
+ Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
if (!t.TryAddCompletionAction(this))
{
ThreadPool.UnsafeQueueUserWorkItemInternal(this, preferLocal: true);
@@ -612,7 +656,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
// the direct AsyncHelpers.Await(ValueTask/ValueTask) path.
// In either case, that can only happen in nontransparent/user code.
Continuation contWithContinueFlags = valueTaskSourceCont;
- while ((contWithContinueFlags.Flags & continueFlags) == 0 && contWithContinueFlags.Next != null)
+ while ((contWithContinueFlags.Flags & ContinuationFlags.AllContinuationFlags) == 0 && contWithContinueFlags.Next != null)
{
contWithContinueFlags = contWithContinueFlags.Next;
}
@@ -630,7 +674,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
}
// Clear continuation flags, so that continuation runs transparently
- contWithContinueFlags.Flags &= ~continueFlags;
+ contWithContinueFlags.Flags &= ~ContinuationFlags.AllContinuationFlags;
valueTaskSourceCont.OnCompletedValueTaskSource(
source,
@@ -642,7 +686,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
}
else
{
- Debug.Assert((headContinuation.Flags & continueFlags) == 0);
+ Debug.Assert((headContinuation.Flags & ContinuationFlags.AllContinuationFlags) == 0);
Debug.Assert(stackState->Notifier != null);
stackState->Notifier!.OnCompleted(GetContinuationAction());
}
diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.TaskContinuation.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.TaskContinuation.cs
new file mode 100644
index 00000000000000..9c8b4b2512a2fd
--- /dev/null
+++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.TaskContinuation.cs
@@ -0,0 +1,82 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics;
+using System.Threading.Tasks;
+
+namespace System.Runtime.CompilerServices
+{
+ public static partial class AsyncHelpers
+ {
+ private sealed unsafe class TaskContinuation : Continuation
+ {
+ internal Task? Task;
+ private delegate* _getResult;
+
+ public TaskContinuation()
+ {
+ ResumeInfo = (ResumeInfo*)Unsafe.AsPointer(in TaskContinuationResume.ResumeInfo);
+ }
+
+ public void GetResult(ref byte returnValue)
+ {
+ Debug.Assert(Task != null);
+
+ // Avoid retaining the task. The call below may throw.
+ Task task = Task;
+ Task = null;
+
+ _getResult(task, ref returnValue);
+ }
+
+ public void Initialize(Task task)
+ {
+ Task = task;
+ _getResult = &GetResult;
+ }
+
+ public void Initialize(Task task)
+ {
+ Task = task;
+ _getResult = &GetResult;
+ }
+
+ private static void GetResult(Task task, ref byte result)
+ {
+ TaskAwaiter.ValidateEnd(task);
+ }
+
+ private static void GetResult(Task task, ref byte result)
+ {
+ Debug.Assert(task is Task);
+
+ Task taskOfT = Unsafe.As>(ref task);
+ TaskAwaiter.ValidateEnd(taskOfT);
+ Unsafe.As(ref result) = taskOfT.ResultOnSuccess;
+ }
+
+ private static class TaskContinuationResume
+ {
+ [FixedAddressValueType]
+ public static readonly ResumeInfo ResumeInfo = new ResumeInfo
+ {
+ DiagnosticIP = null,
+ Resume = &ResumeTaskContinuation,
+ };
+
+ private static Continuation? ResumeTaskContinuation(Continuation cont, ref byte result)
+ {
+ var taskCont = (TaskContinuation)cont;
+ taskCont.Next = null;
+
+ Debug.Assert((taskCont.Flags & ContinuationFlags.AllContinuationFlags) == 0);
+
+ t_runtimeAsyncAwaitState.CachedTaskContinuation = taskCont;
+
+ taskCont.GetResult(ref result);
+ return null;
+ }
+ }
+ }
+ }
+}
diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs
index 7d5af875d08a50..7e6693599738f8 100644
--- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs
+++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.ValueTaskContinuation.cs
@@ -106,12 +106,7 @@ private static class ValueTaskContinuationResume
var vtsCont = (ValueTaskContinuation)cont;
vtsCont.Next = null;
- const ContinuationFlags continueFlags =
- ContinuationFlags.ContinueOnCapturedSynchronizationContext |
- ContinuationFlags.ContinueOnThreadPool |
- ContinuationFlags.ContinueOnCapturedTaskScheduler;
-
- Debug.Assert((vtsCont.Flags & continueFlags) == 0);
+ Debug.Assert((vtsCont.Flags & ContinuationFlags.AllContinuationFlags) == 0);
t_runtimeAsyncAwaitState.CachedValueTaskContinuation = vtsCont;
diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj b/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj
index f687deaff880bc..3108c8dea494d1 100644
--- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj
+++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System.Private.CoreLib.csproj
@@ -64,6 +64,7 @@
+
diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs
index 8dac68d3d4bb6a..4773b38bc9aebd 100644
--- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs
+++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs
@@ -341,6 +341,7 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
// Task path
TypeDesc taskType = taskReturningMethodReturnType;
MethodDesc completedTaskResultMethod;
+ MethodDesc transparentAwaitMethod;
if (!taskReturningMethodReturnType.HasInstantiation)
{
@@ -348,6 +349,9 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
completedTaskResultMethod = context.SystemModule
.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("CompletedTask"u8, null);
+ transparentAwaitMethod = context.SystemModule
+ .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
+ .GetKnownMethod("TransparentAwait"u8, null);
}
else
{
@@ -357,7 +361,12 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
MethodDesc completedTaskResultMethodOpen = context.SystemModule
.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
.GetKnownMethod("CompletedTaskResult"u8, null);
+ MethodDesc transparentAwaitMethodOpen = context.SystemModule
+ .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
+ .GetKnownMethod("TransparentAwaitOfT"u8, null);
+
completedTaskResultMethod = completedTaskResultMethodOpen.MakeInstantiatedMethod(new Instantiation(logicalReturnType));
+ transparentAwaitMethod = transparentAwaitMethodOpen.MakeInstantiatedMethod(new Instantiation(logicalReturnType));
}
ILLocalVariable taskLocal = emitter.NewLocal(taskType);
@@ -373,9 +382,8 @@ public static MethodIL EmitAsyncMethodThunk(MethodDesc asyncMethod, MethodDesc t
codestream.Emit(ILOpcode.brtrue, getResultLabel);
codestream.EmitLdLoc(taskLocal);
- codestream.Emit(ILOpcode.call, emitter.NewToken(
- context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8)
- .GetKnownMethod("TransparentAwait"u8, null)));
+ codestream.Emit(ILOpcode.call, emitter.NewToken(context.GetCoreLibEntryPoint("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8, "TailAwait"u8, null)));
+ codestream.Emit(ILOpcode.call, emitter.NewToken(transparentAwaitMethod));
codestream.EmitLabel(getResultLabel);
codestream.EmitLdLoc(taskLocal);
diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp
index c3dfa0a60fc84a..ed80ce9a6c19e0 100644
--- a/src/coreclr/vm/asyncthunks.cpp
+++ b/src/coreclr/vm/asyncthunks.cpp
@@ -487,7 +487,7 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig
// Task task = other(arg);
// if (!task.IsCompleted)
// {
- // // Magic function which will suspend the current run of async methods
+ // TailAwait();
// AsyncHelpers.TransparentAwait(task);
// }
// return AsyncHelpers.CompletedTaskResult(task);
@@ -595,12 +595,17 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig
MethodTable* pMTTask;
int completedTaskResultToken;
+ int transparentAwaitToken;
+
if (msig.IsReturnTypeVoid())
{
pMTTask = CoreLibBinder::GetClass(CLASS__TASK);
MethodDesc* pMDCompletedTask = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK);
+ MethodDesc* pMDTransparentAwait = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT);
+
completedTaskResultToken = pCode->GetToken(pMDCompletedTask);
+ transparentAwaitToken = pCode->GetToken(pMDTransparentAwait);
}
else
{
@@ -608,24 +613,33 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pTaskReturningVariant, MetaSig
pMTTask = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskOpen->GetModule(), pMTTaskOpen->GetCl(), Instantiation(&thLogicalRetType, 1)).GetMethodTable();
MethodDesc* pMDCompletedTaskResult = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK_RESULT);
+ MethodDesc* pMDTransparentAwait = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT_OF_T);
+
pMDCompletedTaskResult = FindOrCreateAssociatedMethodDesc(pMDCompletedTaskResult, pMDCompletedTaskResult->GetMethodTable(), FALSE, Instantiation(&thLogicalRetType, 1), FALSE);
+ pMDTransparentAwait = FindOrCreateAssociatedMethodDesc(pMDTransparentAwait, pMDTransparentAwait->GetMethodTable(), FALSE, Instantiation(&thLogicalRetType, 1), FALSE);
+
completedTaskResultToken = GetTokenForGenericMethodCallWithAsyncReturnType(pCode, pMDCompletedTaskResult);
+ transparentAwaitToken = GetTokenForGenericMethodCallWithAsyncReturnType(pCode, pMDTransparentAwait);
}
LocalDesc taskLocalDesc(pMTTask);
DWORD taskLocal = pCode->NewLocal(taskLocalDesc);
ILCodeLabel* pGetResultLabel = pCode->NewCodeLabel();
- // Store task returned by actual user func or by ValueTask.AsTask
+ // Store task returned by actual user func
pCode->EmitSTLOC(taskLocal);
+ // Did it already complete?
pCode->EmitLDLOC(taskLocal);
pCode->EmitCALL(METHOD__TASK__GET_ISCOMPLETED, 1, 1);
pCode->EmitBRTRUE(pGetResultLabel);
+ // No, so tail await to TransparentAwait
pCode->EmitLDLOC(taskLocal);
- pCode->EmitCALL(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT, 1, 0);
+ pCode->EmitCALL(METHOD__ASYNC_HELPERS__TAIL_AWAIT, 0, 0);
+ pCode->EmitCALL(transparentAwaitToken, 1, 0);
+ // Yes, so just get the result
pCode->EmitLabel(pGetResultLabel);
pCode->EmitLDLOC(taskLocal);
pCode->EmitCALL(completedTaskResultToken, 1, msig.IsReturnTypeVoid() ? 0 : 1);
diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h
index 476878a901f333..cc515fd5f610e8 100644
--- a/src/coreclr/vm/corelib.h
+++ b/src/coreclr/vm/corelib.h
@@ -716,6 +716,7 @@ DEFINE_METHOD(ASYNC_HELPERS, VALUETASK_FROM_EXCEPTION, ValueTaskFromExcepti
DEFINE_METHOD(ASYNC_HELPERS, VALUETASK_FROM_EXCEPTION_1, ValueTaskFromException, GM_Exception_RetValueTaskOfT)
DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT, TransparentAwait, NoSig)
+DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT_OF_T, TransparentAwaitOfT, NoSig)
DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT_VALUE_TASK, TransparentAwaitValueTask, NoSig)
DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT_VALUE_TASK_OF_T, TransparentAwaitValueTaskOfT, NoSig)
DEFINE_METHOD(ASYNC_HELPERS, COMPLETED_TASK_RESULT, CompletedTaskResult, NoSig)