@@ -380,7 +380,38 @@ private static unsafe void TransparentAwaitValueTask(ValueTask valueTask)
380380
381381 [ BypassReadyToRun ]
382382 [ MethodImpl ( MethodImplOptions . NoInlining | MethodImplOptions . Async ) ]
383- private static unsafe void TransparentAwaitValueTaskOfT < T > ( ValueTask < T ? > valueTask )
383+ private static unsafe void AwaitValueTaskSource ( object source , short token )
384+ {
385+ ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState ;
386+ Continuation ? sentinelContinuation = state . SentinelContinuation ??= new Continuation ( ) ;
387+
388+ ValueTaskContinuation ? vtsCont = state . CachedValueTaskContinuation ;
389+ if ( vtsCont != null )
390+ {
391+ state . CachedValueTaskContinuation = null ;
392+ }
393+ else
394+ {
395+ vtsCont = new ValueTaskContinuation ( ) ;
396+ }
397+
398+ Debug . Assert ( source != null ) ;
399+ vtsCont . Initialize ( source , token ) ;
400+
401+ // We only need to capture flags.
402+ // If needed, VTS will use the scheduling context captured in the "state".
403+ CaptureContinuationContextFlags ( ref vtsCont . Flags , state . CurrentThread ! ) ;
404+
405+ sentinelContinuation . Next = vtsCont ;
406+ state . StackState ->ValueTaskContinuation = vtsCont ;
407+
408+ state . CaptureContexts ( ) ;
409+ AsyncSuspend ( vtsCont ) ;
410+ }
411+
412+ [ BypassReadyToRun ]
413+ [ MethodImpl ( MethodImplOptions . NoInlining | MethodImplOptions . Async ) ]
414+ private static unsafe void TransparentAwaitValueTaskOfT < T > ( ValueTask < T > valueTask )
384415 {
385416 ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState ;
386417 Continuation ? sentinelContinuation = state . SentinelContinuation ??= new Continuation ( ) ;
@@ -405,6 +436,37 @@ private static unsafe void TransparentAwaitValueTaskOfT<T>(ValueTask<T?> valueTa
405436 AsyncSuspend ( vtsCont ) ;
406437 }
407438
439+ [ BypassReadyToRun ]
440+ [ MethodImpl ( MethodImplOptions . NoInlining | MethodImplOptions . Async ) ]
441+ private static unsafe void AwaitValueTaskSourceOfT < T > ( object source , short token )
442+ {
443+ ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState ;
444+ Continuation ? sentinelContinuation = state . SentinelContinuation ??= new Continuation ( ) ;
445+
446+ ValueTaskContinuation ? vtsCont = state . CachedValueTaskContinuation ;
447+ if ( vtsCont != null )
448+ {
449+ state . CachedValueTaskContinuation = null ;
450+ }
451+ else
452+ {
453+ vtsCont = new ValueTaskContinuation ( ) ;
454+ }
455+
456+ Debug . Assert ( source != null ) ;
457+ vtsCont . Initialize < T > ( source , token ) ;
458+
459+ // We only need to capture flags.
460+ // If needed, VTS will use the scheduling context captured in the "state".
461+ CaptureContinuationContextFlags ( ref vtsCont . Flags , state . CurrentThread ! ) ;
462+
463+ sentinelContinuation . Next = vtsCont ;
464+ state . StackState ->ValueTaskContinuation = vtsCont ;
465+
466+ state . CaptureContexts ( ) ;
467+ AsyncSuspend ( vtsCont ) ;
468+ }
469+
408470 /// <summary>
409471 /// Used by internal thunks that implement awaiting on Task.
410472 /// </summary>
@@ -493,23 +555,25 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
493555
494556 // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter.
495557 // These never have special continuation context handling.
558+ // Except for the scenario with ValueTaskContinuation that wraps ValueTaskSource
559+ // which can capture continuation context flags.
496560 const ContinuationFlags continueFlags =
497561 ContinuationFlags . ContinueOnCapturedSynchronizationContext |
498562 ContinuationFlags . ContinueOnThreadPool |
499563 ContinuationFlags . ContinueOnCapturedTaskScheduler ;
500564
501- Debug . Assert ( ( headContinuation . Flags & continueFlags ) == 0 ) ;
502-
503565 SetContinuationState ( headContinuation ) ;
504566
505567 try
506568 {
507569 if ( stackState ->CriticalNotifier is { } critNotifier )
508570 {
571+ Debug . Assert ( ( headContinuation . Flags & continueFlags ) == 0 ) ;
509572 critNotifier . UnsafeOnCompleted ( GetContinuationAction ( ) ) ;
510573 }
511574 else if ( stackState ->TaskNotifier is { } taskNotifier )
512575 {
576+ Debug . Assert ( ( headContinuation . Flags & continueFlags ) == 0 ) ;
513577 // Runtime async callable wrapper for task returning
514578 // method. This implements the context transparent
515579 // forwarding and makes these wrappers minimal cost.
@@ -525,6 +589,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
525589 Debug . Assert ( source != null ) ;
526590 if ( source is Task t )
527591 {
592+ Debug . Assert ( ( headContinuation . Flags & continueFlags ) == 0 ) ;
528593 if ( ! t . TryAddCompletionAction ( this ) )
529594 {
530595 ThreadPool . UnsafeQueueUserWorkItemInternal ( this , preferLocal : true ) ;
@@ -541,17 +606,18 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
541606 // the continuation chain builds from the innermost frame out and at the time when the
542607 // notifier is created we do not know yet if the caller wants to continue on a context.
543608
544- // Skip to a nontransparent/user continuation. Such continuaton must exist.
609+ // Skip to a nontransparent/user continuation. Such continuation must exist.
545610 // Since we see a VTS notifier, something was directly or indirectly
546- // awaiting an async thunk for a ValueTask-returning method.
547- // That can only happen in nontransparent/user code.
548- Continuation nextUserContinuation = valueTaskSourceCont . Next ! ;
549- while ( ( nextUserContinuation . Flags & continueFlags ) == 0 && nextUserContinuation . Next != null )
611+ // awaiting either an async thunk for a ValueTask-returning method or
612+ // the direct AsyncHelpers.Await(ValueTask/ValueTask<T>) path.
613+ // In either case, that can only happen in nontransparent/user code.
614+ Continuation contWithContinueFlags = valueTaskSourceCont ;
615+ while ( ( contWithContinueFlags . Flags & continueFlags ) == 0 && contWithContinueFlags . Next != null )
550616 {
551- nextUserContinuation = nextUserContinuation . Next ;
617+ contWithContinueFlags = contWithContinueFlags . Next ;
552618 }
553619
554- ContinuationFlags continuationFlags = nextUserContinuation . Flags ;
620+ ContinuationFlags continuationFlags = contWithContinueFlags . Flags ;
555621 const ContinuationFlags continueOnContextFlags =
556622 ContinuationFlags . ContinueOnCapturedSynchronizationContext |
557623 ContinuationFlags . ContinueOnCapturedTaskScheduler ;
@@ -564,7 +630,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
564630 }
565631
566632 // Clear continuation flags, so that continuation runs transparently
567- nextUserContinuation . Flags &= ~ continueFlags ;
633+ contWithContinueFlags . Flags &= ~ continueFlags ;
568634
569635 valueTaskSourceCont . OnCompletedValueTaskSource (
570636 source ,
@@ -576,6 +642,7 @@ internal unsafe bool HandleSuspended(ref RuntimeAsyncAwaitState state)
576642 }
577643 else
578644 {
645+ Debug . Assert ( ( headContinuation . Flags & continueFlags ) == 0 ) ;
579646 Debug . Assert ( stackState ->Notifier != null ) ;
580647 stackState ->Notifier ! . OnCompleted ( GetContinuationAction ( ) ) ;
581648 }
@@ -1117,7 +1184,7 @@ private static void RestoreContextsOnSuspension(bool resumed, ExecutionContext?
11171184 }
11181185 }
11191186
1120- private static void CaptureContinuationContext ( ref object continuationContext , ref ContinuationFlags flags )
1187+ private static void CaptureContinuationContext ( ref object ? continuationContext , ref ContinuationFlags flags )
11211188 {
11221189 SynchronizationContext ? syncCtx = Thread . CurrentThreadAssumedInitialized . _synchronizationContext ;
11231190 if ( syncCtx != null && syncCtx . GetType ( ) != typeof ( SynchronizationContext ) )
@@ -1138,6 +1205,26 @@ private static void CaptureContinuationContext(ref object continuationContext, r
11381205 flags |= ContinuationFlags . ContinueOnThreadPool ;
11391206 }
11401207
1208+ // Same as above, but only captures flags
1209+ private static void CaptureContinuationContextFlags ( ref ContinuationFlags flags , Thread currentThread )
1210+ {
1211+ SynchronizationContext ? syncCtx = currentThread . _synchronizationContext ;
1212+ if ( syncCtx != null && syncCtx . GetType ( ) != typeof ( SynchronizationContext ) )
1213+ {
1214+ flags |= ContinuationFlags . ContinueOnCapturedSynchronizationContext ;
1215+ return ;
1216+ }
1217+
1218+ TaskScheduler ? sched = TaskScheduler . InternalCurrent ;
1219+ if ( sched != null && sched != TaskScheduler . Default )
1220+ {
1221+ flags |= ContinuationFlags . ContinueOnCapturedTaskScheduler ;
1222+ return ;
1223+ }
1224+
1225+ flags |= ContinuationFlags . ContinueOnThreadPool ;
1226+ }
1227+
11411228 // Finish suspension in the common case of a custom await or for a ConfigureAwait(false) task await:
11421229 // - Capture current ExecutionContext into the continuation
11431230 // - Restore ExecutionContext and SynchronizationContext to the current Thread object
0 commit comments