diff --git a/dotnet/src/BearerTokenProvider.cs b/dotnet/src/BearerTokenProvider.cs index 2c59da09b..923c225bc 100644 --- a/dotnet/src/BearerTokenProvider.cs +++ b/dotnet/src/BearerTokenProvider.cs @@ -7,7 +7,7 @@ namespace GitHub.Copilot; /// -/// Arguments passed to a bearer-token callback (the GetBearerToken property +/// Arguments passed to a bearer-token callback (the BearerTokenProvider property /// on / ) when the /// runtime needs a fresh bearer token for a BYOK provider. /// @@ -29,4 +29,11 @@ public sealed class ProviderTokenArgs /// provider-agnostic and forwards only the provider name. /// public required string ProviderName { get; init; } + + /// + /// Id of the session that triggered this token request. A client-level + /// shared callback registered for many sessions can use this to resolve the + /// owning session and scope token acquisition or caching per session. + /// + public required string SessionId { get; init; } } diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index dff034b79..a67eb9681 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -671,7 +671,7 @@ private CopilotSession InitializeSession( private const string DefaultBearerTokenProviderName = "default"; /// - /// Collects the per-provider GetBearerToken callbacks keyed by + /// Collects the per-provider BearerTokenProvider callbacks keyed by /// provider name for session-side registration. The singular, whole-session /// uses the implicit /// . @@ -679,18 +679,15 @@ private CopilotSession InitializeSession( private static Dictionary>> BuildBearerTokenCallbacks(SessionConfigBase config) { var callbacks = new Dictionary>>(StringComparer.Ordinal); - if (config.Provider?.GetBearerToken is { } singular) + if (config.Provider?.BearerTokenProvider is { } singular) { callbacks[DefaultBearerTokenProviderName] = singular; } if (config.Providers != null) { - foreach (var provider in config.Providers) + foreach (var provider in config.Providers.Where(provider => provider.BearerTokenProvider is not null)) { - if (provider.GetBearerToken is { } callback) - { - callbacks[provider.Name] = callback; - } + callbacks[provider.Name] = provider.BearerTokenProvider!; } } return callbacks; diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index f8e285eab..0985848e2 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -871,7 +871,7 @@ internal void RegisterAutoModeSwitchHandler(Func - /// Registers per-provider GetBearerToken callbacks for BYOK + /// Registers per-provider BearerTokenProvider callbacks for BYOK /// providers configured with managed-identity / on-demand bearer-token auth. /// /// @@ -899,7 +899,7 @@ internal void RegisterBearerTokenProviders(IReadOnlyDictionary /// Routes runtime providerToken.getToken requests to the matching - /// per-provider GetBearerToken callback registered on the session. + /// per-provider BearerTokenProvider callback registered on the session. /// private sealed class BearerTokenProviderHandler(CopilotSession session) : IProviderTokenHandler { @@ -910,7 +910,7 @@ public async Task GetTokenAsync(ProviderTokenAcquire throw new InvalidOperationException( $"No bearer-token provider registered for provider \"{request.ProviderName}\""); } - var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName }).ConfigureAwait(false); + var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName, SessionId = request.SessionId }).ConfigureAwait(false); return new ProviderTokenAcquireResult { Token = token }; } } diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index a33dff61c..5ae965781 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -2044,26 +2044,27 @@ public sealed class ProviderConfig public string? BearerToken { get; set; } /// - /// Wire-only flag, emitted automatically when is set, that tells + /// Wire-only flag, emitted automatically when is set, that tells /// the runtime to request a token over the session-scoped providerToken.getToken RPC - /// before each outbound request to this provider. Derived from ; + /// before each outbound request to this provider. Derived from ; /// internal and never part of the public API. /// [JsonInclude] [JsonPropertyName("hasBearerTokenProvider")] - internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + internal bool? HasBearerTokenProvider => BearerTokenProvider is not null ? true : null; /// /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a /// callback backed by your own identity library. Never serialized — setting it makes the SDK send /// hasBearerTokenProvider: true on the wire and answer the runtime's - /// providerToken.getToken requests. Mutually exclusive with and - /// . + /// providerToken.getToken requests. When set alongside /, this callback takes precedence: + /// the runtime applies the token it returns as the Authorization: Bearer header for each request + /// and does not send the static credential. /// [JsonIgnore] [Experimental(Diagnostics.Experimental)] - public Func>? GetBearerToken { get; set; } + public Func>? BearerTokenProvider { get; set; } /// /// Azure-specific configuration options. @@ -2198,26 +2199,27 @@ public sealed class NamedProviderConfig public string? BearerToken { get; set; } /// - /// Wire-only flag, emitted automatically when is set, that tells + /// Wire-only flag, emitted automatically when is set, that tells /// the runtime to request a token over the session-scoped providerToken.getToken RPC - /// before each outbound request to this provider. Derived from ; + /// before each outbound request to this provider. Derived from ; /// internal and never part of the public API. /// [JsonInclude] [JsonPropertyName("hasBearerTokenProvider")] - internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + internal bool? HasBearerTokenProvider => BearerTokenProvider is not null ? true : null; /// /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a /// callback backed by your own identity library. Never serialized — setting it makes the SDK send /// hasBearerTokenProvider: true on the wire and answer the runtime's - /// providerToken.getToken requests. Mutually exclusive with and - /// . + /// providerToken.getToken requests. When set alongside /, this callback takes precedence: + /// the runtime applies the token it returns as the Authorization: Bearer header for each request + /// and does not send the static credential. /// [JsonIgnore] [Experimental(Diagnostics.Experimental)] - public Func>? GetBearerToken { get; set; } + public Func>? BearerTokenProvider { get; set; } /// /// Azure-specific configuration options. diff --git a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs index 3f869a437..5973dc61c 100644 --- a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs +++ b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs @@ -13,7 +13,7 @@ namespace GitHub.Copilot.Test.E2E; /// /// End-to-end coverage for the experimental BYOK bearer-token-provider surface -/// (GetBearerToken on a provider config). The callback stays entirely on +/// (BearerTokenProvider on a provider config). The callback stays entirely on /// the SDK/client side: the SDK strips it from the wire config, sets the /// hasBearerTokenProvider flag, and the runtime calls back over the /// session-scoped providerToken.getToken RPC before each outbound model @@ -81,7 +81,7 @@ private static async Task RunTurnAsync( { await session.SendAndWaitAsync(new MessageOptions { Prompt = prompt }); } - catch + catch (InvalidOperationException) { // The handler always 404s the BYOK endpoint, so the turn errors after // the token-bearing request was already captured. Expected. @@ -110,7 +110,7 @@ public async Task Applies_The_Callbacks_Token_As_The_Authorization_Header() Type = "openai", WireApi = "completions", BaseUrl = PrimaryBaseUrl, - GetBearerToken = _ => + BearerTokenProvider = _ => { Interlocked.Increment(ref calls); return Task.FromResult(sentinel); @@ -149,7 +149,7 @@ public async Task Re_Acquires_A_Fresh_Token_For_Each_Request() BaseUrl = PrimaryBaseUrl, // A distinct token per acquisition proves the runtime re-invokes // the callback per request rather than caching a previous token. - GetBearerToken = _ => + BearerTokenProvider = _ => { var n = Interlocked.Increment(ref calls); return Task.FromResult($"rotating-token-{n}"); @@ -189,6 +189,9 @@ Func> MakeCallback(string providerName) => // The runtime forwards the requesting provider's name so the client // can dispatch to the right credential. Assert.Equal(providerName, args.ProviderName); + // The runtime also forwards the owning session id so a + // client-level shared callback can resolve the session. + Assert.False(string.IsNullOrEmpty(args.SessionId)); acquiredFor.Add(providerName); return Task.FromResult(tokenByProvider[providerName]); }; @@ -205,7 +208,7 @@ Func> MakeCallback(string providerName) => Type = "openai", WireApi = "completions", BaseUrl = RedBaseUrl, - GetBearerToken = MakeCallback("red"), + BearerTokenProvider = MakeCallback("red"), }, new() { @@ -213,7 +216,7 @@ Func> MakeCallback(string providerName) => Type = "openai", WireApi = "completions", BaseUrl = BlueBaseUrl, - GetBearerToken = MakeCallback("blue"), + BearerTokenProvider = MakeCallback("blue"), }, }; var models = new List @@ -240,7 +243,7 @@ Func> MakeCallback(string providerName) => /// The runtime invokes for every model-layer HTTP /// request. Requests aimed at a fake BYOK host (*.invalid) are captured — /// recording the Authorization header the runtime applied after calling -/// the provider's GetBearerToken callback over the session-scoped +/// the provider's BearerTokenProvider callback over the session-scoped /// providerToken.getToken RPC — and answered with a synthetic 404 /// (a non-retryable status, so each outbound model request yields exactly one /// capture). Every other request (CAPI bootstrap: model catalog, policy, …) is @@ -262,13 +265,14 @@ protected override Task SendRequestAsync(HttpRequestMessage ? string.Join(", ", values) : null)); - return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + var response = new HttpResponseMessage(HttpStatusCode.NotFound) { Content = new StringContent( "{\"error\":{\"message\":\"fake byok endpoint\"}}", System.Text.Encoding.UTF8, "application/json"), - }); + }; + return Task.FromResult(response); } // CAPI bootstrap (model catalog, policy, …) — answered off-network. diff --git a/go/client.go b/go/client.go index 7cdb4fbad..970f04642 100644 --- a/go/client.go +++ b/go/client.go @@ -57,18 +57,18 @@ import ( // whole-session [ProviderConfig]. Named providers are keyed by their own Name. const defaultBearerTokenProviderName = "default" -// collectBearerTokenProviders gathers the per-provider [GetBearerToken] callbacks +// collectBearerTokenProviders gathers the per-provider [BearerTokenProvider] callbacks // from the singular provider and any named providers, keyed by provider name. The // singular provider uses the implicit name "default"; named providers use their // own Name. Returns nil when no callbacks are configured. -func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]GetBearerToken { - callbacks := make(map[string]GetBearerToken) - if provider != nil && provider.GetBearerToken != nil { - callbacks[defaultBearerTokenProviderName] = provider.GetBearerToken +func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]BearerTokenProvider { + callbacks := make(map[string]BearerTokenProvider) + if provider != nil && provider.BearerTokenProvider != nil { + callbacks[defaultBearerTokenProviderName] = provider.BearerTokenProvider } for i := range providers { - if providers[i].GetBearerToken != nil { - callbacks[providers[i].Name] = providers[i].GetBearerToken + if providers[i].BearerTokenProvider != nil { + callbacks[providers[i].Name] = providers[i].BearerTokenProvider } } if len(callbacks) == 0 { diff --git a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go index 6a6e5cbc2..2f298596d 100644 --- a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go +++ b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go @@ -37,7 +37,7 @@ type capturedBYOKRequest struct { // byokCapturingRoundTripper stands in for a real HTTP upstream. It records the // `Authorization` header the runtime applied (after calling the provider's -// GetBearerToken callback over the session-scoped `providerToken.getToken` RPC) +// BearerTokenProvider callback over the session-scoped `providerToken.getToken` RPC) // for every request aimed at a fake `.invalid` BYOK host, answering them with a // synthetic 404 (a non-retryable status, so each outbound model request yields // exactly one capture). Every other request (CAPI bootstrap: model catalog, @@ -96,7 +96,7 @@ func (rt *byokCapturingRoundTripper) reset() { } // TestBYOKBearerTokenProvider is end-to-end coverage for the experimental BYOK -// bearer-token-provider surface (GetBearerToken on a provider config). The +// bearer-token-provider surface (BearerTokenProvider on a provider config). The // callback stays entirely on the SDK/client side: the SDK strips it from the // wire config, sets the `hasBearerTokenProvider` flag, and the runtime calls // back over the session-scoped `providerToken.getToken` RPC before each outbound @@ -151,11 +151,11 @@ func TestBYOKBearerTokenProvider(t *testing.T) { } providers := []copilot.NamedProviderConfig{{ - Name: "mi", - Type: "openai", - WireAPI: "completions", - BaseURL: byokPrimaryBaseURL, - GetBearerToken: getBearerToken, + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + BearerTokenProvider: getBearerToken, }} models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} @@ -189,11 +189,11 @@ func TestBYOKBearerTokenProvider(t *testing.T) { } providers := []copilot.NamedProviderConfig{{ - Name: "mi", - Type: "openai", - WireAPI: "completions", - BaseURL: byokPrimaryBaseURL, - GetBearerToken: getBearerToken, + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + BearerTokenProvider: getBearerToken, }} models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} @@ -227,13 +227,18 @@ func TestBYOKBearerTokenProvider(t *testing.T) { } var mu sync.Mutex var acquiredFor []string - makeCallback := func(providerName string) copilot.GetBearerToken { + makeCallback := func(providerName string) copilot.BearerTokenProvider { return func(args copilot.ProviderTokenArgs) (string, error) { // The runtime forwards the requesting provider's name so the // client can dispatch to the right credential. if args.ProviderName != providerName { t.Errorf("Expected providerName %q, got %q", providerName, args.ProviderName) } + // The runtime also forwards the owning session id so a + // client-level shared callback can resolve the session. + if args.SessionID == "" { + t.Errorf("Expected a non-empty session id in token args") + } mu.Lock() acquiredFor = append(acquiredFor, providerName) mu.Unlock() @@ -243,18 +248,18 @@ func TestBYOKBearerTokenProvider(t *testing.T) { providers := []copilot.NamedProviderConfig{ { - Name: "red", - Type: "openai", - WireAPI: "completions", - BaseURL: byokRedBaseURL, - GetBearerToken: makeCallback("red"), + Name: "red", + Type: "openai", + WireAPI: "completions", + BaseURL: byokRedBaseURL, + BearerTokenProvider: makeCallback("red"), }, { - Name: "blue", - Type: "openai", - WireAPI: "completions", - BaseURL: byokBlueBaseURL, - GetBearerToken: makeCallback("blue"), + Name: "blue", + Type: "openai", + WireAPI: "completions", + BaseURL: byokBlueBaseURL, + BearerTokenProvider: makeCallback("blue"), }, } models := []copilot.ProviderModelConfig{ diff --git a/go/session.go b/go/session.go index d92466d8e..851157ba8 100644 --- a/go/session.go +++ b/go/session.go @@ -77,7 +77,7 @@ type Session struct { elicitationMu sync.RWMutex canvasHandler CanvasHandler canvasMu sync.RWMutex - bearerTokenProviders map[string]GetBearerToken + bearerTokenProviders map[string]BearerTokenProvider bearerTokenMu sync.RWMutex openCanvases []rpc.OpenCanvasInstance openCanvasesMu sync.RWMutex @@ -183,7 +183,7 @@ func (s *Session) getCanvasHandler() CanvasHandler { return s.canvasHandler } -// registerBearerTokenProviders installs per-provider [GetBearerToken] callbacks +// registerBearerTokenProviders installs per-provider [BearerTokenProvider] callbacks // for BYOK providers configured with managed-identity / on-demand bearer-token // auth, keyed by provider name. // @@ -192,10 +192,10 @@ func (s *Session) getCanvasHandler() CanvasHandler { // runtime needs a token it issues a session-scoped `providerToken.getToken` // request, which the session's provider-token adapter routes to the matching // per-provider callback. -func (s *Session) registerBearerTokenProviders(providers map[string]GetBearerToken) { +func (s *Session) registerBearerTokenProviders(providers map[string]BearerTokenProvider) { s.bearerTokenMu.Lock() defer s.bearerTokenMu.Unlock() - s.bearerTokenProviders = make(map[string]GetBearerToken, len(providers)) + s.bearerTokenProviders = make(map[string]BearerTokenProvider, len(providers)) for name, callback := range providers { if callback == nil { continue @@ -204,7 +204,7 @@ func (s *Session) registerBearerTokenProviders(providers map[string]GetBearerTok } } -func (s *Session) getBearerTokenProvider(providerName string) GetBearerToken { +func (s *Session) getBearerTokenProvider(providerName string) BearerTokenProvider { s.bearerTokenMu.RLock() defer s.bearerTokenMu.RUnlock() return s.bearerTokenProviders[providerName] @@ -229,7 +229,7 @@ func (a *providerTokenClientSessionAdapter) GetToken(request *rpc.ProviderTokenA if callback == nil { return nil, providerTokenJSONRPCError(fmt.Sprintf("No bearer-token provider registered for provider %q", request.ProviderName)) } - token, err := callback(ProviderTokenArgs{ProviderName: request.ProviderName}) + token, err := callback(ProviderTokenArgs{ProviderName: request.ProviderName, SessionID: request.SessionID}) if err != nil { return nil, providerTokenJSONRPCError(err.Error()) } diff --git a/go/types.go b/go/types.go index 52a06e110..8a7df3c46 100644 --- a/go/types.go +++ b/go/types.go @@ -1564,7 +1564,7 @@ type ResumeSessionConfig struct { ExpAssignments any } -// ProviderTokenArgs carries the context passed to a [GetBearerToken] callback +// ProviderTokenArgs carries the context passed to a [BearerTokenProvider] callback // when the runtime needs a fresh bearer token for a BYOK provider. // // Experimental: ProviderTokenArgs is part of the experimental managed-identity / @@ -1579,9 +1579,15 @@ type ProviderTokenArgs struct { // The callback closes over its own token scope/audience; the runtime is // provider-agnostic and forwards only the provider name. ProviderName string + + // SessionID is the id of the session that triggered this token request. A + // client-level shared callback registered for many sessions can use this to + // resolve the owning session and scope token acquisition or caching per + // session. + SessionID string } -// GetBearerToken is a per-provider callback that resolves a bearer token on +// BearerTokenProvider is a per-provider callback that resolves a bearer token on // demand, returning the raw token string (without the "Bearer " prefix). The // Copilot SDK itself takes no Azure dependency: the consumer supplies this // callback backed by their own identity library (for example azidentity's @@ -1589,10 +1595,10 @@ type ProviderTokenArgs struct { // outbound model request. The runtime does no caching of its own, so the callback // (or the identity library it wraps) owns token caching and refresh. // -// Experimental: GetBearerToken is part of the experimental managed-identity / +// Experimental: BearerTokenProvider is part of the experimental managed-identity / // bearer-token-provider surface and may change or be removed in future SDK or CLI // releases. -type GetBearerToken func(args ProviderTokenArgs) (string, error) +type BearerTokenProvider func(args ProviderTokenArgs) (string, error) type ProviderConfig struct { // Type is the provider type: "openai", "azure", or "anthropic". Defaults to "openai". @@ -1634,20 +1640,24 @@ type ProviderConfig struct { // tokens. When hit, the model stops generating and returns a truncated // response. MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - // GetBearerToken resolves a bearer token on demand for this provider + // BearerTokenProvider resolves a bearer token on demand for this provider // (managed-identity / on-demand auth). When set, the SDK strips the callback // from the wire config and instead sends `hasBearerTokenProvider: true`; the // runtime calls back over the session-scoped `providerToken.getToken` RPC // before each outbound model request and applies the returned token as the // Authorization header. Never serialized. // + // When set alongside APIKey/BearerToken, this callback takes precedence: the + // runtime applies the token it returns as the Authorization: Bearer header for + // each request and does not send the static credential. + // // Experimental: part of the experimental managed-identity / bearer-token-provider // surface and may change or be removed in future SDK or CLI releases. - GetBearerToken GetBearerToken `json:"-"` + BearerTokenProvider BearerTokenProvider `json:"-"` } // MarshalJSON serializes the provider config, deriving the wire-only -// `hasBearerTokenProvider` flag from the presence of [ProviderConfig.GetBearerToken]. +// `hasBearerTokenProvider` flag from the presence of [ProviderConfig.BearerTokenProvider]. // The non-serializable callback never crosses the RPC boundary; the runtime only // learns that a token provider exists and forwards the provider name back when it // needs a token. @@ -1657,7 +1667,7 @@ func (p ProviderConfig) MarshalJSON() ([]byte, error) { wire HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` }{wire: wire(p)} - if p.GetBearerToken != nil { + if p.BearerTokenProvider != nil { aux.HasBearerTokenProvider = Bool(true) } return json.Marshal(aux) @@ -1715,21 +1725,25 @@ type NamedProviderConfig struct { Azure *AzureProviderOptions `json:"azure,omitempty"` // Headers are custom HTTP headers included in all outbound provider requests. Headers map[string]string `json:"headers,omitempty"` - // GetBearerToken resolves a bearer token on demand for this provider + // BearerTokenProvider resolves a bearer token on demand for this provider // (managed-identity / on-demand auth). When set, the SDK strips the callback // from the wire config and instead sends `hasBearerTokenProvider: true`; the // runtime calls back over the session-scoped `providerToken.getToken` RPC // before each outbound model request and applies the returned token as the // Authorization header. Never serialized. // + // When set alongside APIKey/BearerToken, this callback takes precedence: the + // runtime applies the token it returns as the Authorization: Bearer header for + // each request and does not send the static credential. + // // Experimental: part of the experimental managed-identity / bearer-token-provider // surface and may change or be removed in future SDK or CLI releases. - GetBearerToken GetBearerToken `json:"-"` + BearerTokenProvider BearerTokenProvider `json:"-"` } // MarshalJSON serializes the named provider config, deriving the wire-only // `hasBearerTokenProvider` flag from the presence of -// [NamedProviderConfig.GetBearerToken]. The non-serializable callback never +// [NamedProviderConfig.BearerTokenProvider]. The non-serializable callback never // crosses the RPC boundary; the runtime only learns that a token provider exists // and forwards the provider name back when it needs a token. func (p NamedProviderConfig) MarshalJSON() ([]byte, error) { @@ -1738,7 +1752,7 @@ func (p NamedProviderConfig) MarshalJSON() ([]byte, error) { wire HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` }{wire: wire(p)} - if p.GetBearerToken != nil { + if p.BearerTokenProvider != nil { aux.HasBearerTokenProvider = Bool(true) } return json.Marshal(aux) diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index 9e0391594..90f76b6df 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -74,7 +74,7 @@ import com.github.copilot.rpc.ExitPlanModeRequest; import com.github.copilot.rpc.ExitPlanModeResult; import com.github.copilot.rpc.ElicitationSchema; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.GetMessagesResponse; import com.github.copilot.rpc.HookInvocation; import com.github.copilot.rpc.InputOptions; @@ -169,7 +169,7 @@ public final class CopilotSession implements AutoCloseable { private final Set> eventHandlers = ConcurrentHashMap.newKeySet(); private final Map toolHandlers = new ConcurrentHashMap<>(); private final Map commandHandlers = new ConcurrentHashMap<>(); - private final Map bearerTokenProviders = new ConcurrentHashMap<>(); + private final Map bearerTokenProviders = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); @@ -1358,7 +1358,7 @@ void registerElicitationHandler(ElicitationHandler handler) { * @param providers * the callbacks keyed by provider name */ - void registerBearerTokenProviders(Map providers) { + void registerBearerTokenProviders(Map providers) { bearerTokenProviders.clear(); if (providers != null) { bearerTokenProviders.putAll(providers); @@ -1372,7 +1372,7 @@ void registerBearerTokenProviders(Map providers) { * the provider name * @return the registered callback, or {@code null} if none is registered */ - GetBearerToken getBearerTokenProvider(String providerName) { + BearerTokenProvider getBearerTokenProvider(String providerName) { return bearerTokenProviders.get(providerName); } diff --git a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java index b62e8c582..9a42a8e22 100644 --- a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java +++ b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java @@ -19,7 +19,7 @@ import com.github.copilot.generated.SessionEvent; import com.github.copilot.rpc.AutoModeSwitchRequest; import com.github.copilot.rpc.ExitPlanModeRequest; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.ProviderTokenArgs; import com.github.copilot.rpc.PermissionRequestResult; import com.github.copilot.rpc.PermissionRequestResultKind; @@ -321,14 +321,15 @@ private void handleProviderTokenGetToken(JsonRpcClient rpc, String requestId, Js return; } - GetBearerToken provider = session.getBearerTokenProvider(providerName); + BearerTokenProvider provider = session.getBearerTokenProvider(providerName); if (provider == null) { rpc.sendErrorResponse(requestIdLong, -32603, "No bearer-token provider registered for provider " + providerName); return; } - CompletableFuture tokenFuture = provider.getToken(new ProviderTokenArgs(providerName)); + CompletableFuture tokenFuture = provider + .getToken(new ProviderTokenArgs(providerName, sessionId)); if (tokenFuture == null) { rpc.sendErrorResponse(requestIdLong, -32603, "Bearer-token provider returned null future for provider " + providerName); diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index 943894d28..6000bdef8 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -13,7 +13,7 @@ import com.github.copilot.rpc.CreateSessionRequest; import com.github.copilot.rpc.ProviderConfig; import com.github.copilot.rpc.NamedProviderConfig; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.CommandWireDefinition; import com.github.copilot.rpc.ResumeSessionConfig; import com.github.copilot.rpc.ResumeSessionRequest; @@ -335,7 +335,7 @@ static void configureSession(CopilotSession session, SessionConfig config) { if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } - Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), config.getProviders()); if (!bearerTokenProviders.isEmpty()) { session.registerBearerTokenProviders(bearerTokenProviders); @@ -382,7 +382,7 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } - Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), config.getProviders()); if (!bearerTokenProviders.isEmpty()) { session.registerBearerTokenProviders(bearerTokenProviders); @@ -398,17 +398,17 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) } } - private static Map collectBearerTokenProviders(ProviderConfig provider, + private static Map collectBearerTokenProviders(ProviderConfig provider, List providers) { - Map bearerTokenProviders = new HashMap<>(); - if (provider != null && provider.getGetBearerToken() != null) { - bearerTokenProviders.put("default", provider.getGetBearerToken()); + Map bearerTokenProviders = new HashMap<>(); + if (provider != null && provider.getBearerTokenProvider() != null) { + bearerTokenProviders.put("default", provider.getBearerTokenProvider()); } if (providers != null) { for (NamedProviderConfig namedProvider : providers) { if (namedProvider != null && namedProvider.getName() != null - && namedProvider.getGetBearerToken() != null) { - bearerTokenProviders.put(namedProvider.getName(), namedProvider.getGetBearerToken()); + && namedProvider.getBearerTokenProvider() != null) { + bearerTokenProviders.put(namedProvider.getName(), namedProvider.getBearerTokenProvider()); } } } diff --git a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java b/java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java similarity index 87% rename from java/src/main/java/com/github/copilot/rpc/GetBearerToken.java rename to java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java index 27ec7f09c..7b37925aa 100644 --- a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java +++ b/java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java @@ -20,13 +20,13 @@ * Experimental. This managed-identity surface may change or be * removed in future SDK or CLI releases. * - * @see ProviderConfig#setGetBearerToken(GetBearerToken) - * @see NamedProviderConfig#setGetBearerToken(GetBearerToken) + * @see ProviderConfig#setBearerTokenProvider(BearerTokenProvider) + * @see NamedProviderConfig#setBearerTokenProvider(BearerTokenProvider) * @since 1.0.0 */ @CopilotExperimental @FunctionalInterface -public interface GetBearerToken { +public interface BearerTokenProvider { /** * Gets a bearer token for the provider identified by {@code args}. diff --git a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java index 2bdf2678f..e3b090019 100644 --- a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java @@ -61,7 +61,7 @@ public class NamedProviderConfig { private String bearerToken; @JsonIgnore - private GetBearerToken getBearerToken; + private BearerTokenProvider bearerTokenProvider; @JsonProperty("azure") private AzureOptions azure; @@ -221,8 +221,8 @@ public NamedProviderConfig setBearerToken(String bearerToken) { * * @return the bearer-token provider callback, or {@code null} if not set */ - public GetBearerToken getGetBearerToken() { - return getBearerToken; + public BearerTokenProvider getBearerTokenProvider() { + return bearerTokenProvider; } /** @@ -234,19 +234,19 @@ public GetBearerToken getGetBearerToken() { * RPC before each model request. Return the raw token without a {@code Bearer } * prefix. * - * @param getBearerToken + * @param bearerTokenProvider * the bearer-token provider callback * @return this config for method chaining */ - public NamedProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { - this.getBearerToken = getBearerToken; + public NamedProviderConfig setBearerTokenProvider(BearerTokenProvider bearerTokenProvider) { + this.bearerTokenProvider = bearerTokenProvider; return this; } @JsonProperty("hasBearerTokenProvider") @JsonInclude(JsonInclude.Include.NON_NULL) Boolean hasBearerTokenProviderWireFlag() { - return getBearerToken != null ? Boolean.TRUE : null; + return bearerTokenProvider != null ? Boolean.TRUE : null; } /** diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java index ae59e7ead..3d6faba34 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java @@ -57,7 +57,7 @@ public class ProviderConfig { private String bearerToken; @JsonIgnore - private GetBearerToken getBearerToken; + private BearerTokenProvider bearerTokenProvider; @JsonProperty("azure") private AzureOptions azure; @@ -230,8 +230,8 @@ public ProviderConfig setBearerToken(String bearerToken) { * * @return the bearer-token provider callback, or {@code null} if not set */ - public GetBearerToken getGetBearerToken() { - return getBearerToken; + public BearerTokenProvider getBearerTokenProvider() { + return bearerTokenProvider; } /** @@ -243,19 +243,19 @@ public GetBearerToken getGetBearerToken() { * RPC before each model request. Return the raw token without a {@code Bearer } * prefix. * - * @param getBearerToken + * @param bearerTokenProvider * the bearer-token provider callback * @return this config for method chaining */ - public ProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { - this.getBearerToken = getBearerToken; + public ProviderConfig setBearerTokenProvider(BearerTokenProvider bearerTokenProvider) { + this.bearerTokenProvider = bearerTokenProvider; return this; } @JsonProperty("hasBearerTokenProvider") @JsonInclude(JsonInclude.Include.NON_NULL) Boolean hasBearerTokenProviderWireFlag() { - return getBearerToken != null ? Boolean.TRUE : null; + return bearerTokenProvider != null ? Boolean.TRUE : null; } /** diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java index 3866cc0ad..009734ad1 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java @@ -17,13 +17,9 @@ @CopilotExperimental public class ProviderTokenArgs { - private String providerName; + private final String providerName; - /** - * Creates an empty argument object. - */ - public ProviderTokenArgs() { - } + private final String sessionId; /** * Creates argument object for the named provider. @@ -32,9 +28,12 @@ public ProviderTokenArgs() { * the name of the BYOK provider needing a token; {@code "default"} * for the singular whole-session provider, otherwise the named * provider's {@code name} + * @param sessionId + * the id of the session that triggered this token request */ - public ProviderTokenArgs(String providerName) { + public ProviderTokenArgs(String providerName, String sessionId) { this.providerName = providerName; + this.sessionId = sessionId; } /** @@ -50,14 +49,15 @@ public String getProviderName() { } /** - * Sets the name of the BYOK provider needing a token. + * Gets the id of the session that triggered this token request. + *

+ * A client-level shared callback registered for many sessions can use this to + * resolve the owning session and scope token acquisition or caching per + * session. * - * @param providerName - * the provider name - * @return this args instance for method chaining + * @return the session id */ - public ProviderTokenArgs setProviderName(String providerName) { - this.providerName = providerName; - return this; + public String getSessionId() { + return sessionId; } } diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java index 253ce136c..b035bd54d 100644 --- a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java +++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java @@ -34,7 +34,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.MessageOptions; import com.github.copilot.rpc.NamedProviderConfig; import com.github.copilot.rpc.PermissionHandler; @@ -43,9 +43,9 @@ /** * End-to-end coverage for the experimental BYOK bearer-token-provider surface - * ({@code getBearerToken} on a provider config). The callback stays entirely on - * the SDK/client side: the SDK keeps it off the wire, sends only the - * {@code hasBearerTokenProvider} flag, and the runtime calls back over the + * ({@code BearerTokenProvider} on a provider config). The callback stays + * entirely on the SDK/client side: the SDK keeps it off the wire, sends only + * the {@code hasBearerTokenProvider} flag, and the runtime calls back over the * session-scoped {@code providerToken.getToken} RPC before each outbound model * request. */ @@ -82,13 +82,13 @@ void resetHandler() { void appliesCallbackTokenAsAuthorizationHeader() throws Exception { String sentinel = "sentinel-bearer-token-abc123"; AtomicInteger calls = new AtomicInteger(); - GetBearerToken getBearerToken = args -> { + BearerTokenProvider tokenProvider = args -> { calls.incrementAndGet(); return CompletableFuture.completedFuture(sentinel); }; List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") - .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setBearerTokenProvider(tokenProvider)); List models = List .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); @@ -102,11 +102,11 @@ void appliesCallbackTokenAsAuthorizationHeader() throws Exception { @Test void reacquiresFreshTokenForEachRequest() throws Exception { AtomicInteger calls = new AtomicInteger(); - GetBearerToken getBearerToken = args -> CompletableFuture + BearerTokenProvider tokenProvider = args -> CompletableFuture .completedFuture("rotating-token-" + calls.incrementAndGet()); List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") - .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setBearerTokenProvider(tokenProvider)); List models = List .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); @@ -124,15 +124,19 @@ void reacquiresFreshTokenForEachRequest() throws Exception { @Test void dispatchesTokenAcquisitionPerProvider() throws Exception { List acquiredFor = new ArrayList<>(); - GetBearerToken redCallback = args -> { + BearerTokenProvider redCallback = args -> { assertEquals("red", args.getProviderName(), "Expected providerName to be forwarded"); + assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(), + "Expected a non-empty session id in token args"); synchronized (acquiredFor) { acquiredFor.add("red"); } return CompletableFuture.completedFuture("token-for-red"); }; - GetBearerToken blueCallback = args -> { + BearerTokenProvider blueCallback = args -> { assertEquals("blue", args.getProviderName(), "Expected providerName to be forwarded"); + assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(), + "Expected a non-empty session id in token args"); synchronized (acquiredFor) { acquiredFor.add("blue"); } @@ -141,9 +145,9 @@ void dispatchesTokenAcquisitionPerProvider() throws Exception { List providers = List.of( new NamedProviderConfig().setName("red").setType("openai").setWireApi("completions") - .setBaseUrl(RED_BASE_URL).setGetBearerToken(redCallback), + .setBaseUrl(RED_BASE_URL).setBearerTokenProvider(redCallback), new NamedProviderConfig().setName("blue").setType("openai").setWireApi("completions") - .setBaseUrl(BLUE_BASE_URL).setGetBearerToken(blueCallback)); + .setBaseUrl(BLUE_BASE_URL).setBearerTokenProvider(blueCallback)); List models = List.of( new ProviderModelConfig().setId("default").setProvider("red").setWireModel("byok-gpt-4o"), new ProviderModelConfig().setId("default").setProvider("blue").setWireModel("byok-gpt-4o")); diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index e0315b211..53686a6ca 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -51,7 +51,7 @@ import type { ExitPlanModeResult, ForegroundSessionInfo, GetAuthStatusResponse, - GetBearerToken, + BearerTokenProvider, GetStatusResponse, InternalRuntimeConnection, LargeToolOutputConfig, @@ -161,17 +161,17 @@ function toJsonSchema(parameters: Tool["parameters"]): Record | const DEFAULT_PROVIDER_NAME = "default"; /** Wire-safe singular provider config carrying the `hasBearerTokenProvider` flag. */ -type WireProviderConfig = Omit & { +type WireProviderConfig = Omit & { hasBearerTokenProvider?: boolean; }; /** Wire-safe named provider config carrying the `hasBearerTokenProvider` flag. */ -type WireNamedProviderConfig = Omit & { +type WireNamedProviderConfig = Omit & { hasBearerTokenProvider?: boolean; }; /** - * Strips the non-serializable {@link GetBearerToken} callbacks from the singular + * Strips the non-serializable {@link BearerTokenProvider} callbacks from the singular * and named provider configs before they cross the RPC boundary, replacing each * with a `hasBearerTokenProvider: true` wire flag. The callback closes over its * own token scope/audience, so nothing scope-related crosses the wire — the @@ -185,14 +185,14 @@ function extractBearerTokenProviders( ): { wireProvider: WireProviderConfig | undefined; wireProviders: WireNamedProviderConfig[] | undefined; - callbacks: Map; + callbacks: Map; } { - const callbacks = new Map(); + const callbacks = new Map(); let wireProvider: WireProviderConfig | undefined = provider; - if (provider?.getBearerToken) { - const { getBearerToken, ...rest } = provider; - callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken); + if (provider?.bearerTokenProvider) { + const { bearerTokenProvider, ...rest } = provider; + callbacks.set(DEFAULT_PROVIDER_NAME, bearerTokenProvider); wireProvider = { ...rest, hasBearerTokenProvider: true, @@ -200,11 +200,11 @@ function extractBearerTokenProviders( } let wireProviders: WireNamedProviderConfig[] | undefined = providers; - if (providers?.some((p) => p.getBearerToken)) { + if (providers?.some((p) => p.bearerTokenProvider)) { wireProviders = providers.map((p) => { - if (!p.getBearerToken) return p; - const { getBearerToken, ...rest } = p; - callbacks.set(p.name, getBearerToken); + if (!p.bearerTokenProvider) return p; + const { bearerTokenProvider, ...rest } = p; + callbacks.set(p.name, bearerTokenProvider); return { ...rest, hasBearerTokenProvider: true, @@ -1305,7 +1305,7 @@ export class CopilotClient { const useServerGeneratedId = config.cloud != null && callerSessionId == null; const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID()); - // Strip non-serializable getBearerToken callbacks from provider configs, + // Strip non-serializable bearerTokenProvider callbacks from provider configs, // replacing them with a wire flag; keep the callbacks for session-side // registration so the runtime can call back to acquire tokens. const { diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 740a7bc89..eebf9add5 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -84,7 +84,7 @@ export type { MCPHTTPServerConfig, MCPServerConfig, DefaultAgentConfig, - GetBearerToken, + BearerTokenProvider, MessageOptions, ModelBilling, ModelBillingTokenPrices, diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index d87d2b9de..8bf9589c3 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -26,7 +26,7 @@ import type { ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, - GetBearerToken, + BearerTokenProvider, UiInputOptions, MessageOptions, PermissionHandler, @@ -121,7 +121,7 @@ export class CopilotSession { new Map(); private toolHandlers: Map = new Map(); private canvases: Map = new Map(); - private bearerTokenProviders: Map = new Map(); + private bearerTokenProviders: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; private userInputHandler?: UserInputHandler; @@ -798,7 +798,7 @@ export class CopilotSession { } /** - * Registers per-provider {@link GetBearerToken} callbacks for BYOK providers + * Registers per-provider {@link BearerTokenProvider} callbacks for BYOK providers * configured with managed-identity / on-demand bearer-token auth. * * The runtime never receives the callback itself; the SDK strips it from the @@ -809,7 +809,7 @@ export class CopilotSession { * @param providers - Map of provider name → callback, or undefined/empty to clear. * @internal This method is called internally when creating/resuming a session. */ - registerBearerTokenProviders(providers?: Map): void { + registerBearerTokenProviders(providers?: Map): void { this.bearerTokenProviders.clear(); if (!providers || providers.size === 0) { delete this.clientSessionApis.providerToken; @@ -830,6 +830,7 @@ export class CopilotSession { } const token = await callback({ providerName: params.providerName, + sessionId: params.sessionId, }); return { token }; }, diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 61d9ca06d..e354bd821 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -2215,7 +2215,7 @@ export interface ResumeSessionConfig extends SessionConfigBase { } /** - * Arguments passed to a {@link GetBearerToken} callback when the runtime needs a + * Arguments passed to a {@link BearerTokenProvider} callback when the runtime needs a * fresh bearer token for a BYOK provider. * * @experimental Part of the experimental managed-identity / bearer-token-provider @@ -2230,7 +2230,15 @@ export interface ProviderTokenArgs { * The callback closes over its own token scope/audience; the runtime is * provider-agnostic and forwards only the provider name. */ - providerName: string; + readonly providerName: string; + + /** + * Id of the session that triggered this token request. A client-level shared + * callback registered for many sessions can use this to resolve the owning + * session (e.g. via the client's session lookup) to scope token acquisition + * or caching per session. + */ + readonly sessionId: string; } /** @@ -2245,7 +2253,7 @@ export interface ProviderTokenArgs { * @experimental Part of the experimental managed-identity / bearer-token-provider * surface and may change or be removed in future SDK or CLI releases. */ -export type GetBearerToken = (args: ProviderTokenArgs) => Promise; +export type BearerTokenProvider = (args: ProviderTokenArgs) => Promise; /** * Configuration for a custom API provider. @@ -2294,12 +2302,14 @@ export interface ProviderConfig { * When set, the SDK keeps this function client-side (it is never serialized) * and the runtime calls back into this client to acquire a token before each * outbound request. The runtime does no caching of its own, so the callback - * owns token caching and refresh. Mutually exclusive with {@link apiKey} / - * {@link bearerToken}. + * owns token caching and refresh. When set alongside {@link apiKey} / + * {@link bearerToken}, this callback takes precedence: the runtime applies + * the token it returns as the `Authorization: Bearer` header for each + * request and does not send the static credential. * * @experimental */ - getBearerToken?: GetBearerToken; + bearerTokenProvider?: BearerTokenProvider; /** * Azure-specific options @@ -2397,12 +2407,14 @@ export interface NamedProviderConfig { * When set, the SDK keeps this function client-side (it is never serialized) * and the runtime calls back into this client to acquire a token before each * outbound request. The runtime does no caching of its own, so the callback - * owns token caching and refresh. Mutually exclusive with {@link apiKey} / - * {@link bearerToken}. + * owns token caching and refresh. When set alongside {@link apiKey} / + * {@link bearerToken}, this callback takes precedence: the runtime applies + * the token it returns as the `Authorization: Bearer` header for each + * request and does not send the static credential. * * @experimental */ - getBearerToken?: GetBearerToken; + bearerTokenProvider?: BearerTokenProvider; /** * Azure-specific options. diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts index 228b7a022..c528fb23d 100644 --- a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts +++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts @@ -6,7 +6,7 @@ import { beforeEach, describe, expect, it } from "vitest"; import { approveAll, CopilotRequestHandler } from "../../src/index.js"; import type { CopilotRequestContext, - GetBearerToken, + BearerTokenProvider, NamedProviderConfig, ProviderModelConfig, } from "../../src/index.js"; @@ -40,7 +40,7 @@ const BLUE_BASE_URL = `https://${BLUE_HOST}/v1`; * The runtime invokes {@link sendRequest} for every model-layer HTTP request it * would otherwise issue. We capture the ones aimed at a fake BYOK host — * recording the `Authorization` header the runtime applied after calling the - * provider's `getBearerToken` callback over the session-scoped + * provider's `bearerTokenProvider` callback over the session-scoped * `providerToken.getToken` RPC — and answer them with a synthetic `404` (a * non-retryable status, so each outbound model request yields exactly one * capture). Every other request (CAPI bootstrap: model catalog, policy, …) is @@ -89,7 +89,7 @@ class CapturingRequestHandler extends CopilotRequestHandler { /** * End-to-end coverage for the experimental BYOK bearer-token-provider surface - * (`getBearerToken` on a provider config). The callback stays entirely on the + * (`bearerTokenProvider` on a provider config). The callback stays entirely on the * SDK/client side: the SDK strips it from the wire config, sets the * `hasBearerTokenProvider` flag, and the runtime calls back over the session-scoped * `providerToken.getToken` RPC before each outbound model request, applying the @@ -144,7 +144,7 @@ describe("BYOK bearer-token provider", async () => { it("applies the callback's token as the Authorization header", async () => { const SENTINEL = "sentinel-bearer-token-abc123"; let calls = 0; - const getBearerToken: GetBearerToken = async () => { + const getBearerToken: BearerTokenProvider = async () => { calls += 1; return SENTINEL; }; @@ -155,7 +155,7 @@ describe("BYOK bearer-token provider", async () => { type: "openai", wireApi: "completions", baseUrl: PRIMARY_BASE_URL, - getBearerToken, + bearerTokenProvider: getBearerToken, }, ]; const models: ProviderModelConfig[] = [ @@ -172,7 +172,7 @@ describe("BYOK bearer-token provider", async () => { it("re-acquires a fresh token for each request (no runtime caching)", async () => { let calls = 0; - const getBearerToken: GetBearerToken = async () => { + const getBearerToken: BearerTokenProvider = async () => { calls += 1; // A distinct token per acquisition proves the runtime re-invokes the // callback per request rather than caching a previous token. @@ -185,7 +185,7 @@ describe("BYOK bearer-token provider", async () => { type: "openai", wireApi: "completions", baseUrl: PRIMARY_BASE_URL, - getBearerToken, + bearerTokenProvider: getBearerToken, }, ]; const models: ProviderModelConfig[] = [ @@ -211,11 +211,15 @@ describe("BYOK bearer-token provider", async () => { }; const acquiredFor: string[] = []; const makeCallback = - (providerName: string): GetBearerToken => + (providerName: string): BearerTokenProvider => async (args) => { // The runtime forwards the requesting provider's name so the client // can dispatch to the right credential. expect(args.providerName).toBe(providerName); + // The runtime also forwards the owning session id so a + // client-level shared callback can resolve the session. + expect(typeof args.sessionId).toBe("string"); + expect(args.sessionId.length).toBeGreaterThan(0); acquiredFor.push(providerName); return tokenByProvider[providerName]; }; @@ -226,14 +230,14 @@ describe("BYOK bearer-token provider", async () => { type: "openai", wireApi: "completions", baseUrl: RED_BASE_URL, - getBearerToken: makeCallback("red"), + bearerTokenProvider: makeCallback("red"), }, { name: "blue", type: "openai", wireApi: "completions", baseUrl: BLUE_BASE_URL, - getBearerToken: makeCallback("blue"), + bearerTokenProvider: makeCallback("blue"), }, ]; const models: ProviderModelConfig[] = [ diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 1e7a3afb1..ff13d47de 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -86,6 +86,7 @@ AutoModeSwitchHandler, AutoModeSwitchRequest, AutoModeSwitchResponse, + BearerTokenProvider, CommandContext, CommandDefinition, CopilotSession, @@ -100,7 +101,6 @@ ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, - GetBearerToken, InfiniteSessionConfig, InputOptions, LargeToolOutputConfig, @@ -216,7 +216,7 @@ "ExtensionInfo", "CopilotWebSocketForwarder", "GetAuthStatusResponse", - "GetBearerToken", + "BearerTokenProvider", "GetStatusResponse", "InfiniteSessionConfig", "InputOptions", diff --git a/python/copilot/client.py b/python/copilot/client.py index ebfdcf992..c7d11d12b 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -81,6 +81,7 @@ ) from .session import ( AutoModeSwitchHandler, + BearerTokenProvider, CommandDefinition, ContextTier, CopilotSession, @@ -89,7 +90,6 @@ DefaultAgentConfig, ElicitationHandler, ExitPlanModeHandler, - GetBearerToken, InfiniteSessionConfig, LargeToolOutputConfig, MCPServerConfig, @@ -180,8 +180,8 @@ def _capi_session_options_to_wire(options: CapiSessionOptions) -> dict[str, Any] def _collect_bearer_token_callbacks( provider: ProviderConfig | None, providers: list[NamedProviderConfig] | None, -) -> dict[str, GetBearerToken]: - """Collect per-provider ``get_bearer_token`` callbacks keyed by provider name. +) -> dict[str, BearerTokenProvider]: + """Collect per-provider ``bearer_token_provider`` callbacks keyed by provider name. The singular, whole-session ``provider`` uses the implicit ``_DEFAULT_BEARER_TOKEN_PROVIDER_NAME``; ``providers`` entries use their own @@ -189,14 +189,14 @@ def _collect_bearer_token_callbacks( ``hasBearerTokenProvider: true`` instead and the runtime calls back over ``providerToken.getToken``. """ - callbacks: dict[str, GetBearerToken] = {} + callbacks: dict[str, BearerTokenProvider] = {} if provider is not None: - singular = provider.get("get_bearer_token") + singular = provider.get("bearer_token_provider") if singular is not None: callbacks[_DEFAULT_BEARER_TOKEN_PROVIDER_NAME] = singular if providers: for named in providers: - callback = named.get("get_bearer_token") + callback = named.get("bearer_token_provider") if callback is not None: callbacks[named["name"]] = callback return callbacks @@ -3266,7 +3266,7 @@ def _convert_provider_to_wire_format( wire_provider["transport"] = provider["transport"] if "bearer_token" in provider: wire_provider["bearerToken"] = provider["bearer_token"] - if provider.get("get_bearer_token") is not None: + if provider.get("bearer_token_provider") is not None: wire_provider["hasBearerTokenProvider"] = True if "headers" in provider: wire_provider["headers"] = provider["headers"] @@ -3304,7 +3304,7 @@ def _convert_named_provider_to_wire_format( wire["apiKey"] = provider["api_key"] if "bearer_token" in provider: wire["bearerToken"] = provider["bearer_token"] - if provider.get("get_bearer_token") is not None: + if provider.get("bearer_token_provider") is not None: wire["hasBearerTokenProvider"] = True if "headers" in provider: wire["headers"] = provider["headers"] diff --git a/python/copilot/session.py b/python/copilot/session.py index 94fba994a..0dc569f25 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -1080,7 +1080,7 @@ class AzureProviderOptions(TypedDict, total=False): class ProviderTokenArgs(TypedDict): - """Arguments passed to a :data:`GetBearerToken` callback when the runtime + """Arguments passed to a :data:`BearerTokenProvider` callback when the runtime needs a fresh bearer token for a BYOK provider. **Experimental.** Part of the bearer-token-provider surface and may change or @@ -1092,6 +1092,11 @@ class ProviderTokenArgs(TypedDict): # ``NamedProviderConfig`` entries it is ``NamedProviderConfig.name``. provider_name: str + # Id of the session that triggered this token request. A client-level shared + # callback registered for many sessions can use this to resolve the owning + # session and scope token acquisition or caching per session. + session_id: str + # Per-request callback that resolves a bearer token on demand for a BYOK # provider (for example via Azure Managed Identity). The Copilot SDK takes no @@ -1099,7 +1104,7 @@ class ProviderTokenArgs(TypedDict): # Never serialized — setting it makes the SDK send ``hasBearerTokenProvider`` on # the wire and answer the runtime's ``providerToken.getToken`` requests. May be # sync or async. -GetBearerToken = Callable[[ProviderTokenArgs], str | Awaitable[str]] +BearerTokenProvider = Callable[[ProviderTokenArgs], str | Awaitable[str]] class ProviderConfig(TypedDict, total=False): @@ -1142,8 +1147,10 @@ class ProviderConfig(TypedDict, total=False): # provider (for example via Azure Managed Identity). Never serialized — the # SDK sends hasBearerTokenProvider: true on the wire and answers the # runtime's providerToken.getToken requests with this callback's result. - # Mutually exclusive with api_key and bearer_token. - get_bearer_token: GetBearerToken + # When set alongside api_key/bearer_token, this callback takes precedence: the + # runtime applies the token it returns as the Authorization: Bearer header for + # each request and does not send the static credential. + bearer_token_provider: BearerTokenProvider class NamedProviderConfig(TypedDict, total=False): @@ -1172,9 +1179,11 @@ class NamedProviderConfig(TypedDict, total=False): headers: dict[str, str] # Per-request bearer-token callback for this named BYOK provider. Never # serialized; the SDK sends hasBearerTokenProvider: true and answers the - # runtime's providerToken.getToken requests. Mutually exclusive with api_key - # and bearer_token. - get_bearer_token: GetBearerToken + # runtime's providerToken.getToken requests. When set alongside + # api_key/bearer_token, this callback takes precedence: the runtime applies + # the token it returns as the Authorization: Bearer header for each request + # and does not send the static credential. + bearer_token_provider: BearerTokenProvider class ProviderModelConfig(TypedDict, total=False): @@ -1248,7 +1257,7 @@ def _canvas_handler_error(err: Exception) -> JsonRpcError: class _BearerTokenProviderAdapter: """Routes runtime ``providerToken.getToken`` requests to the matching - per-provider :data:`GetBearerToken` callback registered on the session. + per-provider :data:`BearerTokenProvider` callback registered on the session. The runtime calls this once per outbound request for a BYOK provider that declared ``hasBearerTokenProvider: true``; it does no caching, so the SDK @@ -1268,7 +1277,10 @@ async def get_token(self, params: ProviderTokenAcquireRequest) -> ProviderTokenA -32603, f"No bearer-token provider registered for provider: {provider_name!r}", ) - args: ProviderTokenArgs = {"provider_name": provider_name} + args: ProviderTokenArgs = { + "provider_name": provider_name, + "session_id": params.session_id, + } result = callback(args) if inspect.isawaitable(result): result = await result @@ -1340,7 +1352,7 @@ def __init__( self._transform_callbacks_lock = threading.Lock() self._command_handlers: dict[str, CommandHandler] = {} self._command_handlers_lock = threading.Lock() - self._bearer_token_providers: dict[str, GetBearerToken] = {} + self._bearer_token_providers: dict[str, BearerTokenProvider] = {} self._bearer_token_providers_lock = threading.Lock() self._elicitation_handler: ElicitationHandler | None = None self._elicitation_handler_lock = threading.Lock() @@ -2082,7 +2094,9 @@ def _register_commands(self, commands: list[CommandDefinition] | None) -> None: for cmd in commands: self._command_handlers[cmd.name] = cmd.handler - def _register_bearer_token_providers(self, providers: dict[str, GetBearerToken] | None) -> None: + def _register_bearer_token_providers( + self, providers: dict[str, BearerTokenProvider] | None + ) -> None: """Register per-provider bearer-token callbacks for this session. The runtime never receives the callbacks themselves; the SDK strips them diff --git a/python/e2e/test_byok_bearer_token_provider_e2e.py b/python/e2e/test_byok_bearer_token_provider_e2e.py index 28f9e0586..37dfbc009 100644 --- a/python/e2e/test_byok_bearer_token_provider_e2e.py +++ b/python/e2e/test_byok_bearer_token_provider_e2e.py @@ -5,7 +5,7 @@ """E2E coverage for the experimental BYOK bearer-token-provider surface. Mirrors ``nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts``. A BYOK -provider config may carry a ``get_bearer_token`` callback; the callback stays +provider config may carry a ``bearer_token_provider`` callback; the callback stays entirely on the SDK/client side. The SDK strips it from the wire config, sets the ``hasBearerTokenProvider`` flag, and the runtime calls back over the session-scoped ``providerToken.getToken`` RPC before each outbound model @@ -31,7 +31,7 @@ import pytest_asyncio from copilot import CopilotRequestContext, CopilotRequestHandler -from copilot.session import GetBearerToken, PermissionHandler +from copilot.session import BearerTokenProvider, PermissionHandler from ._copilot_request_helpers import build_isolated_client, build_non_inference_response from .testharness import E2ETestContext @@ -56,7 +56,7 @@ class _CapturingRequestHandler(CopilotRequestHandler): The runtime invokes :meth:`send_request` for every model-layer HTTP request. Requests aimed at a fake BYOK host are captured — recording the ``Authorization`` header the runtime applied after calling the provider's - ``get_bearer_token`` callback over ``providerToken.getToken`` — and answered + ``bearer_token_provider`` callback over ``providerToken.getToken`` — and answered with a synthetic ``404`` (non-retryable, so each outbound model request yields exactly one capture). Every other request (CAPI bootstrap: model catalog, policy, …) is fabricated locally so no real network or CAPI proxy @@ -126,6 +126,7 @@ async def _run_turn(client, providers, models, selection_id: str, prompt: str) - try: await session.send_and_wait(prompt) except Exception: + # The fake BYOK endpoint intentionally errors after capture. pass finally: try: @@ -154,7 +155,7 @@ async def get_bearer_token(args) -> str: "type": "openai", "wire_api": "completions", "base_url": PRIMARY_BASE_URL, - "get_bearer_token": get_bearer_token, + "bearer_token_provider": get_bearer_token, } ] models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] @@ -185,7 +186,7 @@ async def get_bearer_token(args) -> str: "type": "openai", "wire_api": "completions", "base_url": PRIMARY_BASE_URL, - "get_bearer_token": get_bearer_token, + "bearer_token_provider": get_bearer_token, } ] models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] @@ -208,11 +209,14 @@ async def test_dispatches_token_acquisition_per_provider(self, bearer_fixture): token_by_provider = {"red": "token-for-red", "blue": "token-for-blue"} acquired_for: list[str] = [] - def make_callback(provider_name: str) -> GetBearerToken: + def make_callback(provider_name: str) -> BearerTokenProvider: async def callback(args) -> str: # The runtime forwards the requesting provider's name so the # client can dispatch to the right credential. assert args["provider_name"] == provider_name + # The runtime also forwards the owning session id so a + # client-level shared callback can resolve the session. + assert isinstance(args["session_id"], str) and args["session_id"] acquired_for.append(provider_name) return token_by_provider[provider_name] @@ -224,14 +228,14 @@ async def callback(args) -> str: "type": "openai", "wire_api": "completions", "base_url": RED_BASE_URL, - "get_bearer_token": make_callback("red"), + "bearer_token_provider": make_callback("red"), }, { "name": "blue", "type": "openai", "wire_api": "completions", "base_url": BLUE_BASE_URL, - "get_bearer_token": make_callback("blue"), + "bearer_token_provider": make_callback("blue"), }, ] models = [ diff --git a/rust/src/provider_token.rs b/rust/src/provider_token.rs index f92715006..a8b75f196 100644 --- a/rust/src/provider_token.rs +++ b/rust/src/provider_token.rs @@ -30,6 +30,13 @@ pub struct ProviderTokenArgs { /// This is `"default"` for the singular whole-session provider, otherwise /// the named provider's `name`. pub provider_name: String, + + /// Id of the session that triggered this token request. + /// + /// A client-level shared callback registered for many sessions can use this + /// to resolve the owning session and scope token acquisition or caching per + /// session. + pub session_id: String, } /// Error returned by a [`BearerTokenProvider`]. diff --git a/rust/src/provider_token_dispatch.rs b/rust/src/provider_token_dispatch.rs index c100443cd..0631260a4 100644 --- a/rust/src/provider_token_dispatch.rs +++ b/rust/src/provider_token_dispatch.rs @@ -119,6 +119,7 @@ async fn get_token( match token_provider .get_token(ProviderTokenArgs { provider_name: params.provider_name, + session_id: params.session_id.into_inner(), }) .await { diff --git a/rust/src/types.rs b/rust/src/types.rs index d62e20f70..75408db02 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1053,7 +1053,7 @@ pub struct ProviderConfig { /// **Experimental.** Callback used to acquire a bearer token before each /// outbound request to this provider. #[serde(skip)] - pub get_bearer_token: Option>, + pub bearer_token_provider: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. @@ -1097,8 +1097,8 @@ impl std::fmt::Debug for ProviderConfig { .field("api_key", &self.api_key) .field("bearer_token", &self.bearer_token) .field( - "get_bearer_token", - &self.get_bearer_token.as_ref().map(|_| ""), + "bearer_token_provider", + &self.bearer_token_provider.as_ref().map(|_| ""), ) .field("has_bearer_token_provider", &self.has_bearer_token_provider) .field("azure", &self.azure) @@ -1158,8 +1158,8 @@ impl ProviderConfig { /// /// **Experimental.** This method is part of an experimental wire-protocol /// surface and may change or be removed in a future release. - pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { - self.get_bearer_token = Some(provider); + pub fn with_bearer_token_provider(mut self, provider: Arc) -> Self { + self.bearer_token_provider = Some(provider); self } @@ -1291,7 +1291,7 @@ pub struct NamedProviderConfig { /// **Experimental.** Callback used to acquire a bearer token before each /// outbound request to this provider. #[serde(skip)] - pub get_bearer_token: Option>, + pub bearer_token_provider: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. @@ -1312,8 +1312,8 @@ impl std::fmt::Debug for NamedProviderConfig { .field("api_key", &self.api_key) .field("bearer_token", &self.bearer_token) .field( - "get_bearer_token", - &self.get_bearer_token.as_ref().map(|_| ""), + "bearer_token_provider", + &self.bearer_token_provider.as_ref().map(|_| ""), ) .field("has_bearer_token_provider", &self.has_bearer_token_provider) .field("azure", &self.azure) @@ -1363,8 +1363,8 @@ impl NamedProviderConfig { /// /// **Experimental.** This method is part of an experimental wire-protocol /// surface and may change or be removed in a future release. - pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { - self.get_bearer_token = Some(provider); + pub fn with_bearer_token_provider(mut self, provider: Arc) -> Self { + self.bearer_token_provider = Some(provider); self } @@ -1388,7 +1388,7 @@ fn prepare_bearer_token_providers( let mut bearer_token_providers = HashMap::new(); if let Some(provider) = provider.as_mut() - && let Some(token_provider) = provider.get_bearer_token.take() + && let Some(token_provider) = provider.bearer_token_provider.take() { provider.has_bearer_token_provider = Some(true); bearer_token_providers.insert("default".to_string(), token_provider); @@ -1396,7 +1396,7 @@ fn prepare_bearer_token_providers( if let Some(providers) = providers.as_mut() { for provider in providers { - if let Some(token_provider) = provider.get_bearer_token.take() { + if let Some(token_provider) = provider.bearer_token_provider.take() { provider.has_bearer_token_provider = Some(true); bearer_token_providers.insert(provider.name.clone(), token_provider); } diff --git a/rust/tests/e2e/byok_bearer_token_provider.rs b/rust/tests/e2e/byok_bearer_token_provider.rs index c3cd9ef4b..fc3ef89d9 100644 --- a/rust/tests/e2e/byok_bearer_token_provider.rs +++ b/rust/tests/e2e/byok_bearer_token_provider.rs @@ -155,7 +155,7 @@ async fn callback_token_is_applied_as_authorization_header() { NamedProviderConfig::new("mi", PRIMARY_BASE_URL) .with_provider_type("openai") .with_wire_api("completions") - .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + .with_bearer_token_provider(Arc::new(move |_args: ProviderTokenArgs| { let callback_calls = callback_calls.clone(); async move { callback_calls.fetch_add(1, Ordering::SeqCst); @@ -202,7 +202,7 @@ async fn reacquires_a_fresh_token_for_each_request() { NamedProviderConfig::new("mi", PRIMARY_BASE_URL) .with_provider_type("openai") .with_wire_api("completions") - .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + .with_bearer_token_provider(Arc::new(move |_args: ProviderTokenArgs| { let callback_calls = callback_calls.clone(); async move { let call = callback_calls.fetch_add(1, Ordering::SeqCst) + 1; @@ -266,10 +266,14 @@ async fn dispatches_token_acquisition_per_provider() { NamedProviderConfig::new(name, base_url) .with_provider_type("openai") .with_wire_api("completions") - .with_get_bearer_token(Arc::new(move |args: ProviderTokenArgs| { + .with_bearer_token_provider(Arc::new(move |args: ProviderTokenArgs| { let acquired_for = acquired_for.clone(); async move { assert_eq!(args.provider_name, name); + assert!( + !args.session_id.is_empty(), + "expected a non-empty session id in token args" + ); acquired_for.lock().unwrap().push(name.to_string()); Ok::<_, BearerTokenError>(token.to_string()) }