From c9626e633a9b5d799fa933f659f3c8fc3e42c575 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 25 Jun 2026 09:33:53 +0100 Subject: [PATCH 1/4] Add sessionId to BYOK provider-token callback args across all SDKs Surface the owning session id on the hand-written ProviderTokenArgs in every language SDK (Rust, Node, Go, Python, Java, .NET) and wire the dispatch glue to populate it from the generated request (which the runtime already injects on the wire for session-scoped RPCs). A client-level shared getBearerToken callback registered across many sessions can now resolve which session triggered the request and scope token acquisition or caching accordingly. No runtime changes required: the runtime already merges sessionId into the wire params for every clientSession-scoped method, and codegen already includes it on each generated ProviderTokenAcquireRequest. Only the consumer-facing args type and the glue that builds it from the request were missing the field. Per-language e2e tests now assert a non-empty sessionId is surfaced. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/BearerTokenProvider.cs | 7 ++++ dotnet/src/Session.cs | 2 +- .../E2E/ByokBearerTokenProviderE2ETests.cs | 3 ++ .../byok_bearer_token_provider_e2e_test.go | 5 +++ go/session.go | 2 +- go/types.go | 6 ++++ .../github/copilot/RpcHandlerDispatcher.java | 2 +- .../github/copilot/rpc/ProviderTokenArgs.java | 32 ++++++++++++++++++- .../ByokBearerTokenProviderE2ETest.java | 4 +++ nodejs/src/session.ts | 1 + nodejs/src/types.ts | 8 +++++ .../byok_bearer_token_provider.e2e.test.ts | 4 +++ python/copilot/session.py | 10 +++++- .../test_byok_bearer_token_provider_e2e.py | 3 ++ rust/src/provider_token.rs | 7 ++++ rust/src/provider_token_dispatch.rs | 1 + rust/tests/e2e/byok_bearer_token_provider.rs | 4 +++ 17 files changed, 96 insertions(+), 5 deletions(-) diff --git a/dotnet/src/BearerTokenProvider.cs b/dotnet/src/BearerTokenProvider.cs index 2c59da09b..1ed11198f 100644 --- a/dotnet/src/BearerTokenProvider.cs +++ b/dotnet/src/BearerTokenProvider.cs @@ -29,4 +29,11 @@ public sealed class ProviderTokenArgs /// provider-agnostic and forwards only the provider name. /// public required string ProviderName { get; init; } + + /// + /// Id of the session that triggered this token request. A client-level + /// shared callback registered for many sessions can use this to resolve the + /// owning session and scope token acquisition or caching per session. + /// + public required string SessionId { get; init; } } diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index f8e285eab..5d5661e65 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -910,7 +910,7 @@ public async Task GetTokenAsync(ProviderTokenAcquire throw new InvalidOperationException( $"No bearer-token provider registered for provider \"{request.ProviderName}\""); } - var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName }).ConfigureAwait(false); + var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName, SessionId = request.SessionId }).ConfigureAwait(false); return new ProviderTokenAcquireResult { Token = token }; } } diff --git a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs index 3f869a437..0cb0e6e16 100644 --- a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs +++ b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs @@ -189,6 +189,9 @@ Func> MakeCallback(string providerName) => // The runtime forwards the requesting provider's name so the client // can dispatch to the right credential. Assert.Equal(providerName, args.ProviderName); + // The runtime also forwards the owning session id so a + // client-level shared callback can resolve the session. + Assert.False(string.IsNullOrEmpty(args.SessionId)); acquiredFor.Add(providerName); return Task.FromResult(tokenByProvider[providerName]); }; diff --git a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go index 6a6e5cbc2..db9864a4b 100644 --- a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go +++ b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go @@ -234,6 +234,11 @@ func TestBYOKBearerTokenProvider(t *testing.T) { if args.ProviderName != providerName { t.Errorf("Expected providerName %q, got %q", providerName, args.ProviderName) } + // The runtime also forwards the owning session id so a + // client-level shared callback can resolve the session. + if args.SessionID == "" { + t.Errorf("Expected a non-empty session id in token args") + } mu.Lock() acquiredFor = append(acquiredFor, providerName) mu.Unlock() diff --git a/go/session.go b/go/session.go index d92466d8e..da82c6a78 100644 --- a/go/session.go +++ b/go/session.go @@ -229,7 +229,7 @@ func (a *providerTokenClientSessionAdapter) GetToken(request *rpc.ProviderTokenA if callback == nil { return nil, providerTokenJSONRPCError(fmt.Sprintf("No bearer-token provider registered for provider %q", request.ProviderName)) } - token, err := callback(ProviderTokenArgs{ProviderName: request.ProviderName}) + token, err := callback(ProviderTokenArgs{ProviderName: request.ProviderName, SessionID: request.SessionID}) if err != nil { return nil, providerTokenJSONRPCError(err.Error()) } diff --git a/go/types.go b/go/types.go index 52a06e110..a7007c513 100644 --- a/go/types.go +++ b/go/types.go @@ -1579,6 +1579,12 @@ type ProviderTokenArgs struct { // The callback closes over its own token scope/audience; the runtime is // provider-agnostic and forwards only the provider name. ProviderName string + + // SessionID is the id of the session that triggered this token request. A + // client-level shared callback registered for many sessions can use this to + // resolve the owning session and scope token acquisition or caching per + // session. + SessionID string } // GetBearerToken is a per-provider callback that resolves a bearer token on diff --git a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java index b62e8c582..6b39ea027 100644 --- a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java +++ b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java @@ -328,7 +328,7 @@ private void handleProviderTokenGetToken(JsonRpcClient rpc, String requestId, Js return; } - CompletableFuture tokenFuture = provider.getToken(new ProviderTokenArgs(providerName)); + CompletableFuture tokenFuture = provider.getToken(new ProviderTokenArgs(providerName, sessionId)); if (tokenFuture == null) { rpc.sendErrorResponse(requestIdLong, -32603, "Bearer-token provider returned null future for provider " + providerName); diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java index 3866cc0ad..1515c9c5e 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java @@ -19,6 +19,8 @@ public class ProviderTokenArgs { private String providerName; + private String sessionId; + /** * Creates an empty argument object. */ @@ -32,9 +34,12 @@ public ProviderTokenArgs() { * the name of the BYOK provider needing a token; {@code "default"} * for the singular whole-session provider, otherwise the named * provider's {@code name} + * @param sessionId + * the id of the session that triggered this token request */ - public ProviderTokenArgs(String providerName) { + public ProviderTokenArgs(String providerName, String sessionId) { this.providerName = providerName; + this.sessionId = sessionId; } /** @@ -60,4 +65,29 @@ public ProviderTokenArgs setProviderName(String providerName) { this.providerName = providerName; return this; } + + /** + * Gets the id of the session that triggered this token request. + *

+ * A client-level shared callback registered for many sessions can use this + * to resolve the owning session and scope token acquisition or caching per + * session. + * + * @return the session id + */ + public String getSessionId() { + return sessionId; + } + + /** + * Sets the id of the session that triggered this token request. + * + * @param sessionId + * the session id + * @return this args instance for method chaining + */ + public ProviderTokenArgs setSessionId(String sessionId) { + this.sessionId = sessionId; + return this; + } } diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java index 253ce136c..b7a2db80f 100644 --- a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java +++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java @@ -126,6 +126,8 @@ void dispatchesTokenAcquisitionPerProvider() throws Exception { List acquiredFor = new ArrayList<>(); GetBearerToken redCallback = args -> { assertEquals("red", args.getProviderName(), "Expected providerName to be forwarded"); + assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(), + "Expected a non-empty session id in token args"); synchronized (acquiredFor) { acquiredFor.add("red"); } @@ -133,6 +135,8 @@ void dispatchesTokenAcquisitionPerProvider() throws Exception { }; GetBearerToken blueCallback = args -> { assertEquals("blue", args.getProviderName(), "Expected providerName to be forwarded"); + assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(), + "Expected a non-empty session id in token args"); synchronized (acquiredFor) { acquiredFor.add("blue"); } diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index d87d2b9de..5501a2ead 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -830,6 +830,7 @@ export class CopilotSession { } const token = await callback({ providerName: params.providerName, + sessionId: params.sessionId, }); return { token }; }, diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 61d9ca06d..ec88bfd74 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -2231,6 +2231,14 @@ export interface ProviderTokenArgs { * provider-agnostic and forwards only the provider name. */ providerName: string; + + /** + * Id of the session that triggered this token request. A client-level shared + * callback registered for many sessions can use this to resolve the owning + * session (e.g. via the client's session lookup) to scope token acquisition + * or caching per session. + */ + sessionId: string; } /** diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts index 228b7a022..deae2fa84 100644 --- a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts +++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts @@ -216,6 +216,10 @@ describe("BYOK bearer-token provider", async () => { // The runtime forwards the requesting provider's name so the client // can dispatch to the right credential. expect(args.providerName).toBe(providerName); + // The runtime also forwards the owning session id so a + // client-level shared callback can resolve the session. + expect(typeof args.sessionId).toBe("string"); + expect(args.sessionId.length).toBeGreaterThan(0); acquiredFor.push(providerName); return tokenByProvider[providerName]; }; diff --git a/python/copilot/session.py b/python/copilot/session.py index 94fba994a..c0b52a42e 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -1092,6 +1092,11 @@ class ProviderTokenArgs(TypedDict): # ``NamedProviderConfig`` entries it is ``NamedProviderConfig.name``. provider_name: str + # Id of the session that triggered this token request. A client-level shared + # callback registered for many sessions can use this to resolve the owning + # session and scope token acquisition or caching per session. + session_id: str + # Per-request callback that resolves a bearer token on demand for a BYOK # provider (for example via Azure Managed Identity). The Copilot SDK takes no @@ -1268,7 +1273,10 @@ async def get_token(self, params: ProviderTokenAcquireRequest) -> ProviderTokenA -32603, f"No bearer-token provider registered for provider: {provider_name!r}", ) - args: ProviderTokenArgs = {"provider_name": provider_name} + args: ProviderTokenArgs = { + "provider_name": provider_name, + "session_id": params.session_id, + } result = callback(args) if inspect.isawaitable(result): result = await result diff --git a/python/e2e/test_byok_bearer_token_provider_e2e.py b/python/e2e/test_byok_bearer_token_provider_e2e.py index 28f9e0586..855fb8b78 100644 --- a/python/e2e/test_byok_bearer_token_provider_e2e.py +++ b/python/e2e/test_byok_bearer_token_provider_e2e.py @@ -213,6 +213,9 @@ async def callback(args) -> str: # The runtime forwards the requesting provider's name so the # client can dispatch to the right credential. assert args["provider_name"] == provider_name + # The runtime also forwards the owning session id so a + # client-level shared callback can resolve the session. + assert isinstance(args["session_id"], str) and args["session_id"] acquired_for.append(provider_name) return token_by_provider[provider_name] diff --git a/rust/src/provider_token.rs b/rust/src/provider_token.rs index f92715006..a8b75f196 100644 --- a/rust/src/provider_token.rs +++ b/rust/src/provider_token.rs @@ -30,6 +30,13 @@ pub struct ProviderTokenArgs { /// This is `"default"` for the singular whole-session provider, otherwise /// the named provider's `name`. pub provider_name: String, + + /// Id of the session that triggered this token request. + /// + /// A client-level shared callback registered for many sessions can use this + /// to resolve the owning session and scope token acquisition or caching per + /// session. + pub session_id: String, } /// Error returned by a [`BearerTokenProvider`]. diff --git a/rust/src/provider_token_dispatch.rs b/rust/src/provider_token_dispatch.rs index c100443cd..0631260a4 100644 --- a/rust/src/provider_token_dispatch.rs +++ b/rust/src/provider_token_dispatch.rs @@ -119,6 +119,7 @@ async fn get_token( match token_provider .get_token(ProviderTokenArgs { provider_name: params.provider_name, + session_id: params.session_id.into_inner(), }) .await { diff --git a/rust/tests/e2e/byok_bearer_token_provider.rs b/rust/tests/e2e/byok_bearer_token_provider.rs index c3cd9ef4b..5e9b28170 100644 --- a/rust/tests/e2e/byok_bearer_token_provider.rs +++ b/rust/tests/e2e/byok_bearer_token_provider.rs @@ -270,6 +270,10 @@ async fn dispatches_token_acquisition_per_provider() { let acquired_for = acquired_for.clone(); async move { assert_eq!(args.provider_name, name); + assert!( + !args.session_id.is_empty(), + "expected a non-empty session id in token args" + ); acquired_for.lock().unwrap().push(name.to_string()); Ok::<_, BearerTokenError>(token.to_string()) } From bf30ae88f2ff1f4768923d7dffafc3a1ec6f9190 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 25 Jun 2026 10:19:52 +0100 Subject: [PATCH 2/4] Rename getBearerToken callback to bearerTokenProvider; fix precedence docs and CodeQL findings Address post-merge review feedback from #1748 across all 6 SDKs: - Rename the BYOK token callback field getBearerToken/get_bearer_token/ GetBearerToken to bearerTokenProvider/bearer_token_provider/ BearerTokenProvider, and the callback type to BearerTokenProvider. The Provider suffix distinguishes the dynamic token source from the static bearerToken credential and aligns with the existing Rust trait and the SDK's *Provider value-producer precedent. In Java this also drops the double-get accessor (getGetBearerToken -> getBearerTokenProvider). - Fix docs that incorrectly described the callback and static apiKey/ bearerToken as mutually exclusive; the runtime applies precedence (the callback wins and the static credential is not sent). - Resolve 4 CodeQL findings: empty except in python; LINQ .Where filter, specific catch type, and HttpResponseMessage disposal in dotnet. Java was not build-verified locally (requires JDK 25); CI validates it. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- dotnet/src/BearerTokenProvider.cs | 2 +- dotnet/src/Client.cs | 11 ++--- dotnet/src/Session.cs | 4 +- dotnet/src/Types.cs | 26 ++++++----- .../E2E/ByokBearerTokenProviderE2ETests.cs | 19 ++++---- go/client.go | 14 +++--- .../byok_bearer_token_provider_e2e_test.go | 46 +++++++++---------- go/session.go | 10 ++-- go/types.go | 32 ++++++++----- .../com/github/copilot/CopilotSession.java | 8 ++-- .../github/copilot/RpcHandlerDispatcher.java | 4 +- .../github/copilot/SessionRequestBuilder.java | 18 ++++---- ...rerToken.java => BearerTokenProvider.java} | 6 +-- .../copilot/rpc/NamedProviderConfig.java | 14 +++--- .../github/copilot/rpc/ProviderConfig.java | 14 +++--- .../ByokBearerTokenProviderE2ETest.java | 20 ++++---- nodejs/src/client.ts | 28 +++++------ nodejs/src/index.ts | 2 +- nodejs/src/session.ts | 8 ++-- nodejs/src/types.ts | 20 ++++---- .../byok_bearer_token_provider.e2e.test.ts | 20 ++++---- python/copilot/__init__.py | 4 +- python/copilot/client.py | 16 +++---- python/copilot/session.py | 26 +++++++---- .../test_byok_bearer_token_provider_e2e.py | 17 +++---- rust/src/types.rs | 24 +++++----- rust/tests/e2e/byok_bearer_token_provider.rs | 6 +-- 27 files changed, 219 insertions(+), 200 deletions(-) rename java/src/main/java/com/github/copilot/rpc/{GetBearerToken.java => BearerTokenProvider.java} (87%) diff --git a/dotnet/src/BearerTokenProvider.cs b/dotnet/src/BearerTokenProvider.cs index 1ed11198f..923c225bc 100644 --- a/dotnet/src/BearerTokenProvider.cs +++ b/dotnet/src/BearerTokenProvider.cs @@ -7,7 +7,7 @@ namespace GitHub.Copilot; ///

-/// Arguments passed to a bearer-token callback (the GetBearerToken property +/// Arguments passed to a bearer-token callback (the BearerTokenProvider property /// on / ) when the /// runtime needs a fresh bearer token for a BYOK provider. /// diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index dff034b79..a67eb9681 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -671,7 +671,7 @@ private CopilotSession InitializeSession( private const string DefaultBearerTokenProviderName = "default"; /// - /// Collects the per-provider GetBearerToken callbacks keyed by + /// Collects the per-provider BearerTokenProvider callbacks keyed by /// provider name for session-side registration. The singular, whole-session /// uses the implicit /// . @@ -679,18 +679,15 @@ private CopilotSession InitializeSession( private static Dictionary>> BuildBearerTokenCallbacks(SessionConfigBase config) { var callbacks = new Dictionary>>(StringComparer.Ordinal); - if (config.Provider?.GetBearerToken is { } singular) + if (config.Provider?.BearerTokenProvider is { } singular) { callbacks[DefaultBearerTokenProviderName] = singular; } if (config.Providers != null) { - foreach (var provider in config.Providers) + foreach (var provider in config.Providers.Where(provider => provider.BearerTokenProvider is not null)) { - if (provider.GetBearerToken is { } callback) - { - callbacks[provider.Name] = callback; - } + callbacks[provider.Name] = provider.BearerTokenProvider!; } } return callbacks; diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 5d5661e65..0985848e2 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -871,7 +871,7 @@ internal void RegisterAutoModeSwitchHandler(Func - /// Registers per-provider GetBearerToken callbacks for BYOK + /// Registers per-provider BearerTokenProvider callbacks for BYOK /// providers configured with managed-identity / on-demand bearer-token auth. /// /// @@ -899,7 +899,7 @@ internal void RegisterBearerTokenProviders(IReadOnlyDictionary /// Routes runtime providerToken.getToken requests to the matching - /// per-provider GetBearerToken callback registered on the session. + /// per-provider BearerTokenProvider callback registered on the session. /// private sealed class BearerTokenProviderHandler(CopilotSession session) : IProviderTokenHandler { diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index a33dff61c..5ae965781 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -2044,26 +2044,27 @@ public sealed class ProviderConfig public string? BearerToken { get; set; } /// - /// Wire-only flag, emitted automatically when is set, that tells + /// Wire-only flag, emitted automatically when is set, that tells /// the runtime to request a token over the session-scoped providerToken.getToken RPC - /// before each outbound request to this provider. Derived from ; + /// before each outbound request to this provider. Derived from ; /// internal and never part of the public API. /// [JsonInclude] [JsonPropertyName("hasBearerTokenProvider")] - internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + internal bool? HasBearerTokenProvider => BearerTokenProvider is not null ? true : null; /// /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a /// callback backed by your own identity library. Never serialized — setting it makes the SDK send /// hasBearerTokenProvider: true on the wire and answer the runtime's - /// providerToken.getToken requests. Mutually exclusive with and - /// . + /// providerToken.getToken requests. When set alongside /, this callback takes precedence: + /// the runtime applies the token it returns as the Authorization: Bearer header for each request + /// and does not send the static credential. /// [JsonIgnore] [Experimental(Diagnostics.Experimental)] - public Func>? GetBearerToken { get; set; } + public Func>? BearerTokenProvider { get; set; } /// /// Azure-specific configuration options. @@ -2198,26 +2199,27 @@ public sealed class NamedProviderConfig public string? BearerToken { get; set; } /// - /// Wire-only flag, emitted automatically when is set, that tells + /// Wire-only flag, emitted automatically when is set, that tells /// the runtime to request a token over the session-scoped providerToken.getToken RPC - /// before each outbound request to this provider. Derived from ; + /// before each outbound request to this provider. Derived from ; /// internal and never part of the public API. /// [JsonInclude] [JsonPropertyName("hasBearerTokenProvider")] - internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null; + internal bool? HasBearerTokenProvider => BearerTokenProvider is not null ? true : null; /// /// Per-request callback that resolves a bearer token on demand for this BYOK provider (for /// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a /// callback backed by your own identity library. Never serialized — setting it makes the SDK send /// hasBearerTokenProvider: true on the wire and answer the runtime's - /// providerToken.getToken requests. Mutually exclusive with and - /// . + /// providerToken.getToken requests. When set alongside /, this callback takes precedence: + /// the runtime applies the token it returns as the Authorization: Bearer header for each request + /// and does not send the static credential. /// [JsonIgnore] [Experimental(Diagnostics.Experimental)] - public Func>? GetBearerToken { get; set; } + public Func>? BearerTokenProvider { get; set; } /// /// Azure-specific configuration options. diff --git a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs index 0cb0e6e16..5973dc61c 100644 --- a/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs +++ b/dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs @@ -13,7 +13,7 @@ namespace GitHub.Copilot.Test.E2E; /// /// End-to-end coverage for the experimental BYOK bearer-token-provider surface -/// (GetBearerToken on a provider config). The callback stays entirely on +/// (BearerTokenProvider on a provider config). The callback stays entirely on /// the SDK/client side: the SDK strips it from the wire config, sets the /// hasBearerTokenProvider flag, and the runtime calls back over the /// session-scoped providerToken.getToken RPC before each outbound model @@ -81,7 +81,7 @@ private static async Task RunTurnAsync( { await session.SendAndWaitAsync(new MessageOptions { Prompt = prompt }); } - catch + catch (InvalidOperationException) { // The handler always 404s the BYOK endpoint, so the turn errors after // the token-bearing request was already captured. Expected. @@ -110,7 +110,7 @@ public async Task Applies_The_Callbacks_Token_As_The_Authorization_Header() Type = "openai", WireApi = "completions", BaseUrl = PrimaryBaseUrl, - GetBearerToken = _ => + BearerTokenProvider = _ => { Interlocked.Increment(ref calls); return Task.FromResult(sentinel); @@ -149,7 +149,7 @@ public async Task Re_Acquires_A_Fresh_Token_For_Each_Request() BaseUrl = PrimaryBaseUrl, // A distinct token per acquisition proves the runtime re-invokes // the callback per request rather than caching a previous token. - GetBearerToken = _ => + BearerTokenProvider = _ => { var n = Interlocked.Increment(ref calls); return Task.FromResult($"rotating-token-{n}"); @@ -208,7 +208,7 @@ Func> MakeCallback(string providerName) => Type = "openai", WireApi = "completions", BaseUrl = RedBaseUrl, - GetBearerToken = MakeCallback("red"), + BearerTokenProvider = MakeCallback("red"), }, new() { @@ -216,7 +216,7 @@ Func> MakeCallback(string providerName) => Type = "openai", WireApi = "completions", BaseUrl = BlueBaseUrl, - GetBearerToken = MakeCallback("blue"), + BearerTokenProvider = MakeCallback("blue"), }, }; var models = new List @@ -243,7 +243,7 @@ Func> MakeCallback(string providerName) => /// The runtime invokes for every model-layer HTTP /// request. Requests aimed at a fake BYOK host (*.invalid) are captured — /// recording the Authorization header the runtime applied after calling -/// the provider's GetBearerToken callback over the session-scoped +/// the provider's BearerTokenProvider callback over the session-scoped /// providerToken.getToken RPC — and answered with a synthetic 404 /// (a non-retryable status, so each outbound model request yields exactly one /// capture). Every other request (CAPI bootstrap: model catalog, policy, …) is @@ -265,13 +265,14 @@ protected override Task SendRequestAsync(HttpRequestMessage ? string.Join(", ", values) : null)); - return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound) + var response = new HttpResponseMessage(HttpStatusCode.NotFound) { Content = new StringContent( "{\"error\":{\"message\":\"fake byok endpoint\"}}", System.Text.Encoding.UTF8, "application/json"), - }); + }; + return Task.FromResult(response); } // CAPI bootstrap (model catalog, policy, …) — answered off-network. diff --git a/go/client.go b/go/client.go index 7cdb4fbad..970f04642 100644 --- a/go/client.go +++ b/go/client.go @@ -57,18 +57,18 @@ import ( // whole-session [ProviderConfig]. Named providers are keyed by their own Name. const defaultBearerTokenProviderName = "default" -// collectBearerTokenProviders gathers the per-provider [GetBearerToken] callbacks +// collectBearerTokenProviders gathers the per-provider [BearerTokenProvider] callbacks // from the singular provider and any named providers, keyed by provider name. The // singular provider uses the implicit name "default"; named providers use their // own Name. Returns nil when no callbacks are configured. -func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]GetBearerToken { - callbacks := make(map[string]GetBearerToken) - if provider != nil && provider.GetBearerToken != nil { - callbacks[defaultBearerTokenProviderName] = provider.GetBearerToken +func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]BearerTokenProvider { + callbacks := make(map[string]BearerTokenProvider) + if provider != nil && provider.BearerTokenProvider != nil { + callbacks[defaultBearerTokenProviderName] = provider.BearerTokenProvider } for i := range providers { - if providers[i].GetBearerToken != nil { - callbacks[providers[i].Name] = providers[i].GetBearerToken + if providers[i].BearerTokenProvider != nil { + callbacks[providers[i].Name] = providers[i].BearerTokenProvider } } if len(callbacks) == 0 { diff --git a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go index db9864a4b..2f298596d 100644 --- a/go/internal/e2e/byok_bearer_token_provider_e2e_test.go +++ b/go/internal/e2e/byok_bearer_token_provider_e2e_test.go @@ -37,7 +37,7 @@ type capturedBYOKRequest struct { // byokCapturingRoundTripper stands in for a real HTTP upstream. It records the // `Authorization` header the runtime applied (after calling the provider's -// GetBearerToken callback over the session-scoped `providerToken.getToken` RPC) +// BearerTokenProvider callback over the session-scoped `providerToken.getToken` RPC) // for every request aimed at a fake `.invalid` BYOK host, answering them with a // synthetic 404 (a non-retryable status, so each outbound model request yields // exactly one capture). Every other request (CAPI bootstrap: model catalog, @@ -96,7 +96,7 @@ func (rt *byokCapturingRoundTripper) reset() { } // TestBYOKBearerTokenProvider is end-to-end coverage for the experimental BYOK -// bearer-token-provider surface (GetBearerToken on a provider config). The +// bearer-token-provider surface (BearerTokenProvider on a provider config). The // callback stays entirely on the SDK/client side: the SDK strips it from the // wire config, sets the `hasBearerTokenProvider` flag, and the runtime calls // back over the session-scoped `providerToken.getToken` RPC before each outbound @@ -151,11 +151,11 @@ func TestBYOKBearerTokenProvider(t *testing.T) { } providers := []copilot.NamedProviderConfig{{ - Name: "mi", - Type: "openai", - WireAPI: "completions", - BaseURL: byokPrimaryBaseURL, - GetBearerToken: getBearerToken, + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + BearerTokenProvider: getBearerToken, }} models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} @@ -189,11 +189,11 @@ func TestBYOKBearerTokenProvider(t *testing.T) { } providers := []copilot.NamedProviderConfig{{ - Name: "mi", - Type: "openai", - WireAPI: "completions", - BaseURL: byokPrimaryBaseURL, - GetBearerToken: getBearerToken, + Name: "mi", + Type: "openai", + WireAPI: "completions", + BaseURL: byokPrimaryBaseURL, + BearerTokenProvider: getBearerToken, }} models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}} @@ -227,7 +227,7 @@ func TestBYOKBearerTokenProvider(t *testing.T) { } var mu sync.Mutex var acquiredFor []string - makeCallback := func(providerName string) copilot.GetBearerToken { + makeCallback := func(providerName string) copilot.BearerTokenProvider { return func(args copilot.ProviderTokenArgs) (string, error) { // The runtime forwards the requesting provider's name so the // client can dispatch to the right credential. @@ -248,18 +248,18 @@ func TestBYOKBearerTokenProvider(t *testing.T) { providers := []copilot.NamedProviderConfig{ { - Name: "red", - Type: "openai", - WireAPI: "completions", - BaseURL: byokRedBaseURL, - GetBearerToken: makeCallback("red"), + Name: "red", + Type: "openai", + WireAPI: "completions", + BaseURL: byokRedBaseURL, + BearerTokenProvider: makeCallback("red"), }, { - Name: "blue", - Type: "openai", - WireAPI: "completions", - BaseURL: byokBlueBaseURL, - GetBearerToken: makeCallback("blue"), + Name: "blue", + Type: "openai", + WireAPI: "completions", + BaseURL: byokBlueBaseURL, + BearerTokenProvider: makeCallback("blue"), }, } models := []copilot.ProviderModelConfig{ diff --git a/go/session.go b/go/session.go index da82c6a78..851157ba8 100644 --- a/go/session.go +++ b/go/session.go @@ -77,7 +77,7 @@ type Session struct { elicitationMu sync.RWMutex canvasHandler CanvasHandler canvasMu sync.RWMutex - bearerTokenProviders map[string]GetBearerToken + bearerTokenProviders map[string]BearerTokenProvider bearerTokenMu sync.RWMutex openCanvases []rpc.OpenCanvasInstance openCanvasesMu sync.RWMutex @@ -183,7 +183,7 @@ func (s *Session) getCanvasHandler() CanvasHandler { return s.canvasHandler } -// registerBearerTokenProviders installs per-provider [GetBearerToken] callbacks +// registerBearerTokenProviders installs per-provider [BearerTokenProvider] callbacks // for BYOK providers configured with managed-identity / on-demand bearer-token // auth, keyed by provider name. // @@ -192,10 +192,10 @@ func (s *Session) getCanvasHandler() CanvasHandler { // runtime needs a token it issues a session-scoped `providerToken.getToken` // request, which the session's provider-token adapter routes to the matching // per-provider callback. -func (s *Session) registerBearerTokenProviders(providers map[string]GetBearerToken) { +func (s *Session) registerBearerTokenProviders(providers map[string]BearerTokenProvider) { s.bearerTokenMu.Lock() defer s.bearerTokenMu.Unlock() - s.bearerTokenProviders = make(map[string]GetBearerToken, len(providers)) + s.bearerTokenProviders = make(map[string]BearerTokenProvider, len(providers)) for name, callback := range providers { if callback == nil { continue @@ -204,7 +204,7 @@ func (s *Session) registerBearerTokenProviders(providers map[string]GetBearerTok } } -func (s *Session) getBearerTokenProvider(providerName string) GetBearerToken { +func (s *Session) getBearerTokenProvider(providerName string) BearerTokenProvider { s.bearerTokenMu.RLock() defer s.bearerTokenMu.RUnlock() return s.bearerTokenProviders[providerName] diff --git a/go/types.go b/go/types.go index a7007c513..8a7df3c46 100644 --- a/go/types.go +++ b/go/types.go @@ -1564,7 +1564,7 @@ type ResumeSessionConfig struct { ExpAssignments any } -// ProviderTokenArgs carries the context passed to a [GetBearerToken] callback +// ProviderTokenArgs carries the context passed to a [BearerTokenProvider] callback // when the runtime needs a fresh bearer token for a BYOK provider. // // Experimental: ProviderTokenArgs is part of the experimental managed-identity / @@ -1587,7 +1587,7 @@ type ProviderTokenArgs struct { SessionID string } -// GetBearerToken is a per-provider callback that resolves a bearer token on +// BearerTokenProvider is a per-provider callback that resolves a bearer token on // demand, returning the raw token string (without the "Bearer " prefix). The // Copilot SDK itself takes no Azure dependency: the consumer supplies this // callback backed by their own identity library (for example azidentity's @@ -1595,10 +1595,10 @@ type ProviderTokenArgs struct { // outbound model request. The runtime does no caching of its own, so the callback // (or the identity library it wraps) owns token caching and refresh. // -// Experimental: GetBearerToken is part of the experimental managed-identity / +// Experimental: BearerTokenProvider is part of the experimental managed-identity / // bearer-token-provider surface and may change or be removed in future SDK or CLI // releases. -type GetBearerToken func(args ProviderTokenArgs) (string, error) +type BearerTokenProvider func(args ProviderTokenArgs) (string, error) type ProviderConfig struct { // Type is the provider type: "openai", "azure", or "anthropic". Defaults to "openai". @@ -1640,20 +1640,24 @@ type ProviderConfig struct { // tokens. When hit, the model stops generating and returns a truncated // response. MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - // GetBearerToken resolves a bearer token on demand for this provider + // BearerTokenProvider resolves a bearer token on demand for this provider // (managed-identity / on-demand auth). When set, the SDK strips the callback // from the wire config and instead sends `hasBearerTokenProvider: true`; the // runtime calls back over the session-scoped `providerToken.getToken` RPC // before each outbound model request and applies the returned token as the // Authorization header. Never serialized. // + // When set alongside APIKey/BearerToken, this callback takes precedence: the + // runtime applies the token it returns as the Authorization: Bearer header for + // each request and does not send the static credential. + // // Experimental: part of the experimental managed-identity / bearer-token-provider // surface and may change or be removed in future SDK or CLI releases. - GetBearerToken GetBearerToken `json:"-"` + BearerTokenProvider BearerTokenProvider `json:"-"` } // MarshalJSON serializes the provider config, deriving the wire-only -// `hasBearerTokenProvider` flag from the presence of [ProviderConfig.GetBearerToken]. +// `hasBearerTokenProvider` flag from the presence of [ProviderConfig.BearerTokenProvider]. // The non-serializable callback never crosses the RPC boundary; the runtime only // learns that a token provider exists and forwards the provider name back when it // needs a token. @@ -1663,7 +1667,7 @@ func (p ProviderConfig) MarshalJSON() ([]byte, error) { wire HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` }{wire: wire(p)} - if p.GetBearerToken != nil { + if p.BearerTokenProvider != nil { aux.HasBearerTokenProvider = Bool(true) } return json.Marshal(aux) @@ -1721,21 +1725,25 @@ type NamedProviderConfig struct { Azure *AzureProviderOptions `json:"azure,omitempty"` // Headers are custom HTTP headers included in all outbound provider requests. Headers map[string]string `json:"headers,omitempty"` - // GetBearerToken resolves a bearer token on demand for this provider + // BearerTokenProvider resolves a bearer token on demand for this provider // (managed-identity / on-demand auth). When set, the SDK strips the callback // from the wire config and instead sends `hasBearerTokenProvider: true`; the // runtime calls back over the session-scoped `providerToken.getToken` RPC // before each outbound model request and applies the returned token as the // Authorization header. Never serialized. // + // When set alongside APIKey/BearerToken, this callback takes precedence: the + // runtime applies the token it returns as the Authorization: Bearer header for + // each request and does not send the static credential. + // // Experimental: part of the experimental managed-identity / bearer-token-provider // surface and may change or be removed in future SDK or CLI releases. - GetBearerToken GetBearerToken `json:"-"` + BearerTokenProvider BearerTokenProvider `json:"-"` } // MarshalJSON serializes the named provider config, deriving the wire-only // `hasBearerTokenProvider` flag from the presence of -// [NamedProviderConfig.GetBearerToken]. The non-serializable callback never +// [NamedProviderConfig.BearerTokenProvider]. The non-serializable callback never // crosses the RPC boundary; the runtime only learns that a token provider exists // and forwards the provider name back when it needs a token. func (p NamedProviderConfig) MarshalJSON() ([]byte, error) { @@ -1744,7 +1752,7 @@ func (p NamedProviderConfig) MarshalJSON() ([]byte, error) { wire HasBearerTokenProvider *bool `json:"hasBearerTokenProvider,omitempty"` }{wire: wire(p)} - if p.GetBearerToken != nil { + if p.BearerTokenProvider != nil { aux.HasBearerTokenProvider = Bool(true) } return json.Marshal(aux) diff --git a/java/src/main/java/com/github/copilot/CopilotSession.java b/java/src/main/java/com/github/copilot/CopilotSession.java index 9e0391594..90f76b6df 100644 --- a/java/src/main/java/com/github/copilot/CopilotSession.java +++ b/java/src/main/java/com/github/copilot/CopilotSession.java @@ -74,7 +74,7 @@ import com.github.copilot.rpc.ExitPlanModeRequest; import com.github.copilot.rpc.ExitPlanModeResult; import com.github.copilot.rpc.ElicitationSchema; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.GetMessagesResponse; import com.github.copilot.rpc.HookInvocation; import com.github.copilot.rpc.InputOptions; @@ -169,7 +169,7 @@ public final class CopilotSession implements AutoCloseable { private final Set> eventHandlers = ConcurrentHashMap.newKeySet(); private final Map toolHandlers = new ConcurrentHashMap<>(); private final Map commandHandlers = new ConcurrentHashMap<>(); - private final Map bearerTokenProviders = new ConcurrentHashMap<>(); + private final Map bearerTokenProviders = new ConcurrentHashMap<>(); private final AtomicReference permissionHandler = new AtomicReference<>(); private final AtomicReference userInputHandler = new AtomicReference<>(); private final AtomicReference elicitationHandler = new AtomicReference<>(); @@ -1358,7 +1358,7 @@ void registerElicitationHandler(ElicitationHandler handler) { * @param providers * the callbacks keyed by provider name */ - void registerBearerTokenProviders(Map providers) { + void registerBearerTokenProviders(Map providers) { bearerTokenProviders.clear(); if (providers != null) { bearerTokenProviders.putAll(providers); @@ -1372,7 +1372,7 @@ void registerBearerTokenProviders(Map providers) { * the provider name * @return the registered callback, or {@code null} if none is registered */ - GetBearerToken getBearerTokenProvider(String providerName) { + BearerTokenProvider getBearerTokenProvider(String providerName) { return bearerTokenProviders.get(providerName); } diff --git a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java index 6b39ea027..d6526e6a1 100644 --- a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java +++ b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java @@ -19,7 +19,7 @@ import com.github.copilot.generated.SessionEvent; import com.github.copilot.rpc.AutoModeSwitchRequest; import com.github.copilot.rpc.ExitPlanModeRequest; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.ProviderTokenArgs; import com.github.copilot.rpc.PermissionRequestResult; import com.github.copilot.rpc.PermissionRequestResultKind; @@ -321,7 +321,7 @@ private void handleProviderTokenGetToken(JsonRpcClient rpc, String requestId, Js return; } - GetBearerToken provider = session.getBearerTokenProvider(providerName); + BearerTokenProvider provider = session.getBearerTokenProvider(providerName); if (provider == null) { rpc.sendErrorResponse(requestIdLong, -32603, "No bearer-token provider registered for provider " + providerName); diff --git a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java index 943894d28..6000bdef8 100644 --- a/java/src/main/java/com/github/copilot/SessionRequestBuilder.java +++ b/java/src/main/java/com/github/copilot/SessionRequestBuilder.java @@ -13,7 +13,7 @@ import com.github.copilot.rpc.CreateSessionRequest; import com.github.copilot.rpc.ProviderConfig; import com.github.copilot.rpc.NamedProviderConfig; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.CommandWireDefinition; import com.github.copilot.rpc.ResumeSessionConfig; import com.github.copilot.rpc.ResumeSessionRequest; @@ -335,7 +335,7 @@ static void configureSession(CopilotSession session, SessionConfig config) { if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } - Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), config.getProviders()); if (!bearerTokenProviders.isEmpty()) { session.registerBearerTokenProviders(bearerTokenProviders); @@ -382,7 +382,7 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) if (config.getOnElicitationRequest() != null) { session.registerElicitationHandler(config.getOnElicitationRequest()); } - Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), + Map bearerTokenProviders = collectBearerTokenProviders(config.getProvider(), config.getProviders()); if (!bearerTokenProviders.isEmpty()) { session.registerBearerTokenProviders(bearerTokenProviders); @@ -398,17 +398,17 @@ static void configureSession(CopilotSession session, ResumeSessionConfig config) } } - private static Map collectBearerTokenProviders(ProviderConfig provider, + private static Map collectBearerTokenProviders(ProviderConfig provider, List providers) { - Map bearerTokenProviders = new HashMap<>(); - if (provider != null && provider.getGetBearerToken() != null) { - bearerTokenProviders.put("default", provider.getGetBearerToken()); + Map bearerTokenProviders = new HashMap<>(); + if (provider != null && provider.getBearerTokenProvider() != null) { + bearerTokenProviders.put("default", provider.getBearerTokenProvider()); } if (providers != null) { for (NamedProviderConfig namedProvider : providers) { if (namedProvider != null && namedProvider.getName() != null - && namedProvider.getGetBearerToken() != null) { - bearerTokenProviders.put(namedProvider.getName(), namedProvider.getGetBearerToken()); + && namedProvider.getBearerTokenProvider() != null) { + bearerTokenProviders.put(namedProvider.getName(), namedProvider.getBearerTokenProvider()); } } } diff --git a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java b/java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java similarity index 87% rename from java/src/main/java/com/github/copilot/rpc/GetBearerToken.java rename to java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java index 27ec7f09c..7b37925aa 100644 --- a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java +++ b/java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java @@ -20,13 +20,13 @@ * Experimental. This managed-identity surface may change or be * removed in future SDK or CLI releases. * - * @see ProviderConfig#setGetBearerToken(GetBearerToken) - * @see NamedProviderConfig#setGetBearerToken(GetBearerToken) + * @see ProviderConfig#setBearerTokenProvider(BearerTokenProvider) + * @see NamedProviderConfig#setBearerTokenProvider(BearerTokenProvider) * @since 1.0.0 */ @CopilotExperimental @FunctionalInterface -public interface GetBearerToken { +public interface BearerTokenProvider { /** * Gets a bearer token for the provider identified by {@code args}. diff --git a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java index 2bdf2678f..e3b090019 100644 --- a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java @@ -61,7 +61,7 @@ public class NamedProviderConfig { private String bearerToken; @JsonIgnore - private GetBearerToken getBearerToken; + private BearerTokenProvider bearerTokenProvider; @JsonProperty("azure") private AzureOptions azure; @@ -221,8 +221,8 @@ public NamedProviderConfig setBearerToken(String bearerToken) { * * @return the bearer-token provider callback, or {@code null} if not set */ - public GetBearerToken getGetBearerToken() { - return getBearerToken; + public BearerTokenProvider getBearerTokenProvider() { + return bearerTokenProvider; } /** @@ -234,19 +234,19 @@ public GetBearerToken getGetBearerToken() { * RPC before each model request. Return the raw token without a {@code Bearer } * prefix. * - * @param getBearerToken + * @param bearerTokenProvider * the bearer-token provider callback * @return this config for method chaining */ - public NamedProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { - this.getBearerToken = getBearerToken; + public NamedProviderConfig setBearerTokenProvider(BearerTokenProvider bearerTokenProvider) { + this.bearerTokenProvider = bearerTokenProvider; return this; } @JsonProperty("hasBearerTokenProvider") @JsonInclude(JsonInclude.Include.NON_NULL) Boolean hasBearerTokenProviderWireFlag() { - return getBearerToken != null ? Boolean.TRUE : null; + return bearerTokenProvider != null ? Boolean.TRUE : null; } /** diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java index ae59e7ead..3d6faba34 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java @@ -57,7 +57,7 @@ public class ProviderConfig { private String bearerToken; @JsonIgnore - private GetBearerToken getBearerToken; + private BearerTokenProvider bearerTokenProvider; @JsonProperty("azure") private AzureOptions azure; @@ -230,8 +230,8 @@ public ProviderConfig setBearerToken(String bearerToken) { * * @return the bearer-token provider callback, or {@code null} if not set */ - public GetBearerToken getGetBearerToken() { - return getBearerToken; + public BearerTokenProvider getBearerTokenProvider() { + return bearerTokenProvider; } /** @@ -243,19 +243,19 @@ public GetBearerToken getGetBearerToken() { * RPC before each model request. Return the raw token without a {@code Bearer } * prefix. * - * @param getBearerToken + * @param bearerTokenProvider * the bearer-token provider callback * @return this config for method chaining */ - public ProviderConfig setGetBearerToken(GetBearerToken getBearerToken) { - this.getBearerToken = getBearerToken; + public ProviderConfig setBearerTokenProvider(BearerTokenProvider bearerTokenProvider) { + this.bearerTokenProvider = bearerTokenProvider; return this; } @JsonProperty("hasBearerTokenProvider") @JsonInclude(JsonInclude.Include.NON_NULL) Boolean hasBearerTokenProviderWireFlag() { - return getBearerToken != null ? Boolean.TRUE : null; + return bearerTokenProvider != null ? Boolean.TRUE : null; } /** diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java index b7a2db80f..280e5fd24 100644 --- a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java +++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java @@ -34,7 +34,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import com.github.copilot.rpc.GetBearerToken; +import com.github.copilot.rpc.BearerTokenProvider; import com.github.copilot.rpc.MessageOptions; import com.github.copilot.rpc.NamedProviderConfig; import com.github.copilot.rpc.PermissionHandler; @@ -43,7 +43,7 @@ /** * End-to-end coverage for the experimental BYOK bearer-token-provider surface - * ({@code getBearerToken} on a provider config). The callback stays entirely on + * ({@code BearerTokenProvider} on a provider config). The callback stays entirely on * the SDK/client side: the SDK keeps it off the wire, sends only the * {@code hasBearerTokenProvider} flag, and the runtime calls back over the * session-scoped {@code providerToken.getToken} RPC before each outbound model @@ -82,13 +82,13 @@ void resetHandler() { void appliesCallbackTokenAsAuthorizationHeader() throws Exception { String sentinel = "sentinel-bearer-token-abc123"; AtomicInteger calls = new AtomicInteger(); - GetBearerToken getBearerToken = args -> { + BearerTokenProvider tokenProvider = args -> { calls.incrementAndGet(); return CompletableFuture.completedFuture(sentinel); }; List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") - .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setBearerTokenProvider(tokenProvider)); List models = List .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); @@ -102,11 +102,11 @@ void appliesCallbackTokenAsAuthorizationHeader() throws Exception { @Test void reacquiresFreshTokenForEachRequest() throws Exception { AtomicInteger calls = new AtomicInteger(); - GetBearerToken getBearerToken = args -> CompletableFuture + BearerTokenProvider tokenProvider = args -> CompletableFuture .completedFuture("rotating-token-" + calls.incrementAndGet()); List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai") - .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken)); + .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setBearerTokenProvider(tokenProvider)); List models = List .of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o")); @@ -124,7 +124,7 @@ void reacquiresFreshTokenForEachRequest() throws Exception { @Test void dispatchesTokenAcquisitionPerProvider() throws Exception { List acquiredFor = new ArrayList<>(); - GetBearerToken redCallback = args -> { + BearerTokenProvider redCallback = args -> { assertEquals("red", args.getProviderName(), "Expected providerName to be forwarded"); assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(), "Expected a non-empty session id in token args"); @@ -133,7 +133,7 @@ void dispatchesTokenAcquisitionPerProvider() throws Exception { } return CompletableFuture.completedFuture("token-for-red"); }; - GetBearerToken blueCallback = args -> { + BearerTokenProvider blueCallback = args -> { assertEquals("blue", args.getProviderName(), "Expected providerName to be forwarded"); assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(), "Expected a non-empty session id in token args"); @@ -145,9 +145,9 @@ void dispatchesTokenAcquisitionPerProvider() throws Exception { List providers = List.of( new NamedProviderConfig().setName("red").setType("openai").setWireApi("completions") - .setBaseUrl(RED_BASE_URL).setGetBearerToken(redCallback), + .setBaseUrl(RED_BASE_URL).setBearerTokenProvider(redCallback), new NamedProviderConfig().setName("blue").setType("openai").setWireApi("completions") - .setBaseUrl(BLUE_BASE_URL).setGetBearerToken(blueCallback)); + .setBaseUrl(BLUE_BASE_URL).setBearerTokenProvider(blueCallback)); List models = List.of( new ProviderModelConfig().setId("default").setProvider("red").setWireModel("byok-gpt-4o"), new ProviderModelConfig().setId("default").setProvider("blue").setWireModel("byok-gpt-4o")); diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index e0315b211..53686a6ca 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -51,7 +51,7 @@ import type { ExitPlanModeResult, ForegroundSessionInfo, GetAuthStatusResponse, - GetBearerToken, + BearerTokenProvider, GetStatusResponse, InternalRuntimeConnection, LargeToolOutputConfig, @@ -161,17 +161,17 @@ function toJsonSchema(parameters: Tool["parameters"]): Record | const DEFAULT_PROVIDER_NAME = "default"; /** Wire-safe singular provider config carrying the `hasBearerTokenProvider` flag. */ -type WireProviderConfig = Omit & { +type WireProviderConfig = Omit & { hasBearerTokenProvider?: boolean; }; /** Wire-safe named provider config carrying the `hasBearerTokenProvider` flag. */ -type WireNamedProviderConfig = Omit & { +type WireNamedProviderConfig = Omit & { hasBearerTokenProvider?: boolean; }; /** - * Strips the non-serializable {@link GetBearerToken} callbacks from the singular + * Strips the non-serializable {@link BearerTokenProvider} callbacks from the singular * and named provider configs before they cross the RPC boundary, replacing each * with a `hasBearerTokenProvider: true` wire flag. The callback closes over its * own token scope/audience, so nothing scope-related crosses the wire — the @@ -185,14 +185,14 @@ function extractBearerTokenProviders( ): { wireProvider: WireProviderConfig | undefined; wireProviders: WireNamedProviderConfig[] | undefined; - callbacks: Map; + callbacks: Map; } { - const callbacks = new Map(); + const callbacks = new Map(); let wireProvider: WireProviderConfig | undefined = provider; - if (provider?.getBearerToken) { - const { getBearerToken, ...rest } = provider; - callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken); + if (provider?.bearerTokenProvider) { + const { bearerTokenProvider, ...rest } = provider; + callbacks.set(DEFAULT_PROVIDER_NAME, bearerTokenProvider); wireProvider = { ...rest, hasBearerTokenProvider: true, @@ -200,11 +200,11 @@ function extractBearerTokenProviders( } let wireProviders: WireNamedProviderConfig[] | undefined = providers; - if (providers?.some((p) => p.getBearerToken)) { + if (providers?.some((p) => p.bearerTokenProvider)) { wireProviders = providers.map((p) => { - if (!p.getBearerToken) return p; - const { getBearerToken, ...rest } = p; - callbacks.set(p.name, getBearerToken); + if (!p.bearerTokenProvider) return p; + const { bearerTokenProvider, ...rest } = p; + callbacks.set(p.name, bearerTokenProvider); return { ...rest, hasBearerTokenProvider: true, @@ -1305,7 +1305,7 @@ export class CopilotClient { const useServerGeneratedId = config.cloud != null && callerSessionId == null; const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID()); - // Strip non-serializable getBearerToken callbacks from provider configs, + // Strip non-serializable bearerTokenProvider callbacks from provider configs, // replacing them with a wire flag; keep the callbacks for session-side // registration so the runtime can call back to acquire tokens. const { diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 740a7bc89..eebf9add5 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -84,7 +84,7 @@ export type { MCPHTTPServerConfig, MCPServerConfig, DefaultAgentConfig, - GetBearerToken, + BearerTokenProvider, MessageOptions, ModelBilling, ModelBillingTokenPrices, diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts index 5501a2ead..8bf9589c3 100644 --- a/nodejs/src/session.ts +++ b/nodejs/src/session.ts @@ -26,7 +26,7 @@ import type { ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, - GetBearerToken, + BearerTokenProvider, UiInputOptions, MessageOptions, PermissionHandler, @@ -121,7 +121,7 @@ export class CopilotSession { new Map(); private toolHandlers: Map = new Map(); private canvases: Map = new Map(); - private bearerTokenProviders: Map = new Map(); + private bearerTokenProviders: Map = new Map(); private commandHandlers: Map = new Map(); private permissionHandler?: PermissionHandler; private userInputHandler?: UserInputHandler; @@ -798,7 +798,7 @@ export class CopilotSession { } /** - * Registers per-provider {@link GetBearerToken} callbacks for BYOK providers + * Registers per-provider {@link BearerTokenProvider} callbacks for BYOK providers * configured with managed-identity / on-demand bearer-token auth. * * The runtime never receives the callback itself; the SDK strips it from the @@ -809,7 +809,7 @@ export class CopilotSession { * @param providers - Map of provider name → callback, or undefined/empty to clear. * @internal This method is called internally when creating/resuming a session. */ - registerBearerTokenProviders(providers?: Map): void { + registerBearerTokenProviders(providers?: Map): void { this.bearerTokenProviders.clear(); if (!providers || providers.size === 0) { delete this.clientSessionApis.providerToken; diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index ec88bfd74..946ef7423 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -2215,7 +2215,7 @@ export interface ResumeSessionConfig extends SessionConfigBase { } /** - * Arguments passed to a {@link GetBearerToken} callback when the runtime needs a + * Arguments passed to a {@link BearerTokenProvider} callback when the runtime needs a * fresh bearer token for a BYOK provider. * * @experimental Part of the experimental managed-identity / bearer-token-provider @@ -2253,7 +2253,7 @@ export interface ProviderTokenArgs { * @experimental Part of the experimental managed-identity / bearer-token-provider * surface and may change or be removed in future SDK or CLI releases. */ -export type GetBearerToken = (args: ProviderTokenArgs) => Promise; +export type BearerTokenProvider = (args: ProviderTokenArgs) => Promise; /** * Configuration for a custom API provider. @@ -2302,12 +2302,14 @@ export interface ProviderConfig { * When set, the SDK keeps this function client-side (it is never serialized) * and the runtime calls back into this client to acquire a token before each * outbound request. The runtime does no caching of its own, so the callback - * owns token caching and refresh. Mutually exclusive with {@link apiKey} / - * {@link bearerToken}. + * owns token caching and refresh. When set alongside {@link apiKey} / + * {@link bearerToken}, this callback takes precedence: the runtime applies + * the token it returns as the `Authorization: Bearer` header for each + * request and does not send the static credential. * * @experimental */ - getBearerToken?: GetBearerToken; + bearerTokenProvider?: BearerTokenProvider; /** * Azure-specific options @@ -2405,12 +2407,14 @@ export interface NamedProviderConfig { * When set, the SDK keeps this function client-side (it is never serialized) * and the runtime calls back into this client to acquire a token before each * outbound request. The runtime does no caching of its own, so the callback - * owns token caching and refresh. Mutually exclusive with {@link apiKey} / - * {@link bearerToken}. + * owns token caching and refresh. When set alongside {@link apiKey} / + * {@link bearerToken}, this callback takes precedence: the runtime applies + * the token it returns as the `Authorization: Bearer` header for each + * request and does not send the static credential. * * @experimental */ - getBearerToken?: GetBearerToken; + bearerTokenProvider?: BearerTokenProvider; /** * Azure-specific options. diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts index deae2fa84..c528fb23d 100644 --- a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts +++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts @@ -6,7 +6,7 @@ import { beforeEach, describe, expect, it } from "vitest"; import { approveAll, CopilotRequestHandler } from "../../src/index.js"; import type { CopilotRequestContext, - GetBearerToken, + BearerTokenProvider, NamedProviderConfig, ProviderModelConfig, } from "../../src/index.js"; @@ -40,7 +40,7 @@ const BLUE_BASE_URL = `https://${BLUE_HOST}/v1`; * The runtime invokes {@link sendRequest} for every model-layer HTTP request it * would otherwise issue. We capture the ones aimed at a fake BYOK host — * recording the `Authorization` header the runtime applied after calling the - * provider's `getBearerToken` callback over the session-scoped + * provider's `bearerTokenProvider` callback over the session-scoped * `providerToken.getToken` RPC — and answer them with a synthetic `404` (a * non-retryable status, so each outbound model request yields exactly one * capture). Every other request (CAPI bootstrap: model catalog, policy, …) is @@ -89,7 +89,7 @@ class CapturingRequestHandler extends CopilotRequestHandler { /** * End-to-end coverage for the experimental BYOK bearer-token-provider surface - * (`getBearerToken` on a provider config). The callback stays entirely on the + * (`bearerTokenProvider` on a provider config). The callback stays entirely on the * SDK/client side: the SDK strips it from the wire config, sets the * `hasBearerTokenProvider` flag, and the runtime calls back over the session-scoped * `providerToken.getToken` RPC before each outbound model request, applying the @@ -144,7 +144,7 @@ describe("BYOK bearer-token provider", async () => { it("applies the callback's token as the Authorization header", async () => { const SENTINEL = "sentinel-bearer-token-abc123"; let calls = 0; - const getBearerToken: GetBearerToken = async () => { + const getBearerToken: BearerTokenProvider = async () => { calls += 1; return SENTINEL; }; @@ -155,7 +155,7 @@ describe("BYOK bearer-token provider", async () => { type: "openai", wireApi: "completions", baseUrl: PRIMARY_BASE_URL, - getBearerToken, + bearerTokenProvider: getBearerToken, }, ]; const models: ProviderModelConfig[] = [ @@ -172,7 +172,7 @@ describe("BYOK bearer-token provider", async () => { it("re-acquires a fresh token for each request (no runtime caching)", async () => { let calls = 0; - const getBearerToken: GetBearerToken = async () => { + const getBearerToken: BearerTokenProvider = async () => { calls += 1; // A distinct token per acquisition proves the runtime re-invokes the // callback per request rather than caching a previous token. @@ -185,7 +185,7 @@ describe("BYOK bearer-token provider", async () => { type: "openai", wireApi: "completions", baseUrl: PRIMARY_BASE_URL, - getBearerToken, + bearerTokenProvider: getBearerToken, }, ]; const models: ProviderModelConfig[] = [ @@ -211,7 +211,7 @@ describe("BYOK bearer-token provider", async () => { }; const acquiredFor: string[] = []; const makeCallback = - (providerName: string): GetBearerToken => + (providerName: string): BearerTokenProvider => async (args) => { // The runtime forwards the requesting provider's name so the client // can dispatch to the right credential. @@ -230,14 +230,14 @@ describe("BYOK bearer-token provider", async () => { type: "openai", wireApi: "completions", baseUrl: RED_BASE_URL, - getBearerToken: makeCallback("red"), + bearerTokenProvider: makeCallback("red"), }, { name: "blue", type: "openai", wireApi: "completions", baseUrl: BLUE_BASE_URL, - getBearerToken: makeCallback("blue"), + bearerTokenProvider: makeCallback("blue"), }, ]; const models: ProviderModelConfig[] = [ diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 1e7a3afb1..d93f55180 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -100,7 +100,7 @@ ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, - GetBearerToken, + BearerTokenProvider, InfiniteSessionConfig, InputOptions, LargeToolOutputConfig, @@ -216,7 +216,7 @@ "ExtensionInfo", "CopilotWebSocketForwarder", "GetAuthStatusResponse", - "GetBearerToken", + "BearerTokenProvider", "GetStatusResponse", "InfiniteSessionConfig", "InputOptions", diff --git a/python/copilot/client.py b/python/copilot/client.py index ebfdcf992..c7d11d12b 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -81,6 +81,7 @@ ) from .session import ( AutoModeSwitchHandler, + BearerTokenProvider, CommandDefinition, ContextTier, CopilotSession, @@ -89,7 +90,6 @@ DefaultAgentConfig, ElicitationHandler, ExitPlanModeHandler, - GetBearerToken, InfiniteSessionConfig, LargeToolOutputConfig, MCPServerConfig, @@ -180,8 +180,8 @@ def _capi_session_options_to_wire(options: CapiSessionOptions) -> dict[str, Any] def _collect_bearer_token_callbacks( provider: ProviderConfig | None, providers: list[NamedProviderConfig] | None, -) -> dict[str, GetBearerToken]: - """Collect per-provider ``get_bearer_token`` callbacks keyed by provider name. +) -> dict[str, BearerTokenProvider]: + """Collect per-provider ``bearer_token_provider`` callbacks keyed by provider name. The singular, whole-session ``provider`` uses the implicit ``_DEFAULT_BEARER_TOKEN_PROVIDER_NAME``; ``providers`` entries use their own @@ -189,14 +189,14 @@ def _collect_bearer_token_callbacks( ``hasBearerTokenProvider: true`` instead and the runtime calls back over ``providerToken.getToken``. """ - callbacks: dict[str, GetBearerToken] = {} + callbacks: dict[str, BearerTokenProvider] = {} if provider is not None: - singular = provider.get("get_bearer_token") + singular = provider.get("bearer_token_provider") if singular is not None: callbacks[_DEFAULT_BEARER_TOKEN_PROVIDER_NAME] = singular if providers: for named in providers: - callback = named.get("get_bearer_token") + callback = named.get("bearer_token_provider") if callback is not None: callbacks[named["name"]] = callback return callbacks @@ -3266,7 +3266,7 @@ def _convert_provider_to_wire_format( wire_provider["transport"] = provider["transport"] if "bearer_token" in provider: wire_provider["bearerToken"] = provider["bearer_token"] - if provider.get("get_bearer_token") is not None: + if provider.get("bearer_token_provider") is not None: wire_provider["hasBearerTokenProvider"] = True if "headers" in provider: wire_provider["headers"] = provider["headers"] @@ -3304,7 +3304,7 @@ def _convert_named_provider_to_wire_format( wire["apiKey"] = provider["api_key"] if "bearer_token" in provider: wire["bearerToken"] = provider["bearer_token"] - if provider.get("get_bearer_token") is not None: + if provider.get("bearer_token_provider") is not None: wire["hasBearerTokenProvider"] = True if "headers" in provider: wire["headers"] = provider["headers"] diff --git a/python/copilot/session.py b/python/copilot/session.py index c0b52a42e..0dc569f25 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -1080,7 +1080,7 @@ class AzureProviderOptions(TypedDict, total=False): class ProviderTokenArgs(TypedDict): - """Arguments passed to a :data:`GetBearerToken` callback when the runtime + """Arguments passed to a :data:`BearerTokenProvider` callback when the runtime needs a fresh bearer token for a BYOK provider. **Experimental.** Part of the bearer-token-provider surface and may change or @@ -1104,7 +1104,7 @@ class ProviderTokenArgs(TypedDict): # Never serialized — setting it makes the SDK send ``hasBearerTokenProvider`` on # the wire and answer the runtime's ``providerToken.getToken`` requests. May be # sync or async. -GetBearerToken = Callable[[ProviderTokenArgs], str | Awaitable[str]] +BearerTokenProvider = Callable[[ProviderTokenArgs], str | Awaitable[str]] class ProviderConfig(TypedDict, total=False): @@ -1147,8 +1147,10 @@ class ProviderConfig(TypedDict, total=False): # provider (for example via Azure Managed Identity). Never serialized — the # SDK sends hasBearerTokenProvider: true on the wire and answers the # runtime's providerToken.getToken requests with this callback's result. - # Mutually exclusive with api_key and bearer_token. - get_bearer_token: GetBearerToken + # When set alongside api_key/bearer_token, this callback takes precedence: the + # runtime applies the token it returns as the Authorization: Bearer header for + # each request and does not send the static credential. + bearer_token_provider: BearerTokenProvider class NamedProviderConfig(TypedDict, total=False): @@ -1177,9 +1179,11 @@ class NamedProviderConfig(TypedDict, total=False): headers: dict[str, str] # Per-request bearer-token callback for this named BYOK provider. Never # serialized; the SDK sends hasBearerTokenProvider: true and answers the - # runtime's providerToken.getToken requests. Mutually exclusive with api_key - # and bearer_token. - get_bearer_token: GetBearerToken + # runtime's providerToken.getToken requests. When set alongside + # api_key/bearer_token, this callback takes precedence: the runtime applies + # the token it returns as the Authorization: Bearer header for each request + # and does not send the static credential. + bearer_token_provider: BearerTokenProvider class ProviderModelConfig(TypedDict, total=False): @@ -1253,7 +1257,7 @@ def _canvas_handler_error(err: Exception) -> JsonRpcError: class _BearerTokenProviderAdapter: """Routes runtime ``providerToken.getToken`` requests to the matching - per-provider :data:`GetBearerToken` callback registered on the session. + per-provider :data:`BearerTokenProvider` callback registered on the session. The runtime calls this once per outbound request for a BYOK provider that declared ``hasBearerTokenProvider: true``; it does no caching, so the SDK @@ -1348,7 +1352,7 @@ def __init__( self._transform_callbacks_lock = threading.Lock() self._command_handlers: dict[str, CommandHandler] = {} self._command_handlers_lock = threading.Lock() - self._bearer_token_providers: dict[str, GetBearerToken] = {} + self._bearer_token_providers: dict[str, BearerTokenProvider] = {} self._bearer_token_providers_lock = threading.Lock() self._elicitation_handler: ElicitationHandler | None = None self._elicitation_handler_lock = threading.Lock() @@ -2090,7 +2094,9 @@ def _register_commands(self, commands: list[CommandDefinition] | None) -> None: for cmd in commands: self._command_handlers[cmd.name] = cmd.handler - def _register_bearer_token_providers(self, providers: dict[str, GetBearerToken] | None) -> None: + def _register_bearer_token_providers( + self, providers: dict[str, BearerTokenProvider] | None + ) -> None: """Register per-provider bearer-token callbacks for this session. The runtime never receives the callbacks themselves; the SDK strips them diff --git a/python/e2e/test_byok_bearer_token_provider_e2e.py b/python/e2e/test_byok_bearer_token_provider_e2e.py index 855fb8b78..37dfbc009 100644 --- a/python/e2e/test_byok_bearer_token_provider_e2e.py +++ b/python/e2e/test_byok_bearer_token_provider_e2e.py @@ -5,7 +5,7 @@ """E2E coverage for the experimental BYOK bearer-token-provider surface. Mirrors ``nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts``. A BYOK -provider config may carry a ``get_bearer_token`` callback; the callback stays +provider config may carry a ``bearer_token_provider`` callback; the callback stays entirely on the SDK/client side. The SDK strips it from the wire config, sets the ``hasBearerTokenProvider`` flag, and the runtime calls back over the session-scoped ``providerToken.getToken`` RPC before each outbound model @@ -31,7 +31,7 @@ import pytest_asyncio from copilot import CopilotRequestContext, CopilotRequestHandler -from copilot.session import GetBearerToken, PermissionHandler +from copilot.session import BearerTokenProvider, PermissionHandler from ._copilot_request_helpers import build_isolated_client, build_non_inference_response from .testharness import E2ETestContext @@ -56,7 +56,7 @@ class _CapturingRequestHandler(CopilotRequestHandler): The runtime invokes :meth:`send_request` for every model-layer HTTP request. Requests aimed at a fake BYOK host are captured — recording the ``Authorization`` header the runtime applied after calling the provider's - ``get_bearer_token`` callback over ``providerToken.getToken`` — and answered + ``bearer_token_provider`` callback over ``providerToken.getToken`` — and answered with a synthetic ``404`` (non-retryable, so each outbound model request yields exactly one capture). Every other request (CAPI bootstrap: model catalog, policy, …) is fabricated locally so no real network or CAPI proxy @@ -126,6 +126,7 @@ async def _run_turn(client, providers, models, selection_id: str, prompt: str) - try: await session.send_and_wait(prompt) except Exception: + # The fake BYOK endpoint intentionally errors after capture. pass finally: try: @@ -154,7 +155,7 @@ async def get_bearer_token(args) -> str: "type": "openai", "wire_api": "completions", "base_url": PRIMARY_BASE_URL, - "get_bearer_token": get_bearer_token, + "bearer_token_provider": get_bearer_token, } ] models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] @@ -185,7 +186,7 @@ async def get_bearer_token(args) -> str: "type": "openai", "wire_api": "completions", "base_url": PRIMARY_BASE_URL, - "get_bearer_token": get_bearer_token, + "bearer_token_provider": get_bearer_token, } ] models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}] @@ -208,7 +209,7 @@ async def test_dispatches_token_acquisition_per_provider(self, bearer_fixture): token_by_provider = {"red": "token-for-red", "blue": "token-for-blue"} acquired_for: list[str] = [] - def make_callback(provider_name: str) -> GetBearerToken: + def make_callback(provider_name: str) -> BearerTokenProvider: async def callback(args) -> str: # The runtime forwards the requesting provider's name so the # client can dispatch to the right credential. @@ -227,14 +228,14 @@ async def callback(args) -> str: "type": "openai", "wire_api": "completions", "base_url": RED_BASE_URL, - "get_bearer_token": make_callback("red"), + "bearer_token_provider": make_callback("red"), }, { "name": "blue", "type": "openai", "wire_api": "completions", "base_url": BLUE_BASE_URL, - "get_bearer_token": make_callback("blue"), + "bearer_token_provider": make_callback("blue"), }, ] models = [ diff --git a/rust/src/types.rs b/rust/src/types.rs index d62e20f70..75408db02 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -1053,7 +1053,7 @@ pub struct ProviderConfig { /// **Experimental.** Callback used to acquire a bearer token before each /// outbound request to this provider. #[serde(skip)] - pub get_bearer_token: Option>, + pub bearer_token_provider: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. @@ -1097,8 +1097,8 @@ impl std::fmt::Debug for ProviderConfig { .field("api_key", &self.api_key) .field("bearer_token", &self.bearer_token) .field( - "get_bearer_token", - &self.get_bearer_token.as_ref().map(|_| ""), + "bearer_token_provider", + &self.bearer_token_provider.as_ref().map(|_| ""), ) .field("has_bearer_token_provider", &self.has_bearer_token_provider) .field("azure", &self.azure) @@ -1158,8 +1158,8 @@ impl ProviderConfig { /// /// **Experimental.** This method is part of an experimental wire-protocol /// surface and may change or be removed in a future release. - pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { - self.get_bearer_token = Some(provider); + pub fn with_bearer_token_provider(mut self, provider: Arc) -> Self { + self.bearer_token_provider = Some(provider); self } @@ -1291,7 +1291,7 @@ pub struct NamedProviderConfig { /// **Experimental.** Callback used to acquire a bearer token before each /// outbound request to this provider. #[serde(skip)] - pub get_bearer_token: Option>, + pub bearer_token_provider: Option>, #[serde(default, skip_serializing_if = "Option::is_none")] pub(crate) has_bearer_token_provider: Option, /// Azure-specific options. @@ -1312,8 +1312,8 @@ impl std::fmt::Debug for NamedProviderConfig { .field("api_key", &self.api_key) .field("bearer_token", &self.bearer_token) .field( - "get_bearer_token", - &self.get_bearer_token.as_ref().map(|_| ""), + "bearer_token_provider", + &self.bearer_token_provider.as_ref().map(|_| ""), ) .field("has_bearer_token_provider", &self.has_bearer_token_provider) .field("azure", &self.azure) @@ -1363,8 +1363,8 @@ impl NamedProviderConfig { /// /// **Experimental.** This method is part of an experimental wire-protocol /// surface and may change or be removed in a future release. - pub fn with_get_bearer_token(mut self, provider: Arc) -> Self { - self.get_bearer_token = Some(provider); + pub fn with_bearer_token_provider(mut self, provider: Arc) -> Self { + self.bearer_token_provider = Some(provider); self } @@ -1388,7 +1388,7 @@ fn prepare_bearer_token_providers( let mut bearer_token_providers = HashMap::new(); if let Some(provider) = provider.as_mut() - && let Some(token_provider) = provider.get_bearer_token.take() + && let Some(token_provider) = provider.bearer_token_provider.take() { provider.has_bearer_token_provider = Some(true); bearer_token_providers.insert("default".to_string(), token_provider); @@ -1396,7 +1396,7 @@ fn prepare_bearer_token_providers( if let Some(providers) = providers.as_mut() { for provider in providers { - if let Some(token_provider) = provider.get_bearer_token.take() { + if let Some(token_provider) = provider.bearer_token_provider.take() { provider.has_bearer_token_provider = Some(true); bearer_token_providers.insert(provider.name.clone(), token_provider); } diff --git a/rust/tests/e2e/byok_bearer_token_provider.rs b/rust/tests/e2e/byok_bearer_token_provider.rs index 5e9b28170..fc3ef89d9 100644 --- a/rust/tests/e2e/byok_bearer_token_provider.rs +++ b/rust/tests/e2e/byok_bearer_token_provider.rs @@ -155,7 +155,7 @@ async fn callback_token_is_applied_as_authorization_header() { NamedProviderConfig::new("mi", PRIMARY_BASE_URL) .with_provider_type("openai") .with_wire_api("completions") - .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + .with_bearer_token_provider(Arc::new(move |_args: ProviderTokenArgs| { let callback_calls = callback_calls.clone(); async move { callback_calls.fetch_add(1, Ordering::SeqCst); @@ -202,7 +202,7 @@ async fn reacquires_a_fresh_token_for_each_request() { NamedProviderConfig::new("mi", PRIMARY_BASE_URL) .with_provider_type("openai") .with_wire_api("completions") - .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| { + .with_bearer_token_provider(Arc::new(move |_args: ProviderTokenArgs| { let callback_calls = callback_calls.clone(); async move { let call = callback_calls.fetch_add(1, Ordering::SeqCst) + 1; @@ -266,7 +266,7 @@ async fn dispatches_token_acquisition_per_provider() { NamedProviderConfig::new(name, base_url) .with_provider_type("openai") .with_wire_api("completions") - .with_get_bearer_token(Arc::new(move |args: ProviderTokenArgs| { + .with_bearer_token_provider(Arc::new(move |args: ProviderTokenArgs| { let acquired_for = acquired_for.clone(); async move { assert_eq!(args.provider_name, name); From 903db34b45ef8098d1abdb8c1d7e323d6d83a36e Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 25 Jun 2026 10:27:39 +0100 Subject: [PATCH 3/4] Make ProviderTokenArgs immutable in Node and Java The args object is constructed solely by the SDK and only read inside the consumer callback, so it should not be mutable by consumers. - Java: make fields final, keep only the all-args constructor, and drop the no-arg constructor and the fluent setters. The runtime already builds it via new ProviderTokenArgs(providerName, sessionId). - Node: mark the interface fields readonly. .NET already uses init-only required properties and Rust passes a shared reference, so both are already immutable. Go and Python have no idiomatic readonly equivalent for these DTO shapes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../github/copilot/rpc/ProviderTokenArgs.java | 34 ++----------------- nodejs/src/types.ts | 4 +-- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java index 1515c9c5e..dde5ed86d 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java @@ -17,15 +17,9 @@ @CopilotExperimental public class ProviderTokenArgs { - private String providerName; + private final String providerName; - private String sessionId; - - /** - * Creates an empty argument object. - */ - public ProviderTokenArgs() { - } + private final String sessionId; /** * Creates argument object for the named provider. @@ -54,18 +48,6 @@ public String getProviderName() { return providerName; } - /** - * Sets the name of the BYOK provider needing a token. - * - * @param providerName - * the provider name - * @return this args instance for method chaining - */ - public ProviderTokenArgs setProviderName(String providerName) { - this.providerName = providerName; - return this; - } - /** * Gets the id of the session that triggered this token request. *

@@ -78,16 +60,4 @@ public ProviderTokenArgs setProviderName(String providerName) { public String getSessionId() { return sessionId; } - - /** - * Sets the id of the session that triggered this token request. - * - * @param sessionId - * the session id - * @return this args instance for method chaining - */ - public ProviderTokenArgs setSessionId(String sessionId) { - this.sessionId = sessionId; - return this; - } } diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 946ef7423..e354bd821 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -2230,7 +2230,7 @@ export interface ProviderTokenArgs { * The callback closes over its own token scope/audience; the runtime is * provider-agnostic and forwards only the provider name. */ - providerName: string; + readonly providerName: string; /** * Id of the session that triggered this token request. A client-level shared @@ -2238,7 +2238,7 @@ export interface ProviderTokenArgs { * session (e.g. via the client's session lookup) to scope token acquisition * or caching per session. */ - sessionId: string; + readonly sessionId: string; } /** From e1ea80c5225ef08f1c30995e0346bd84862e8068 Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 25 Jun 2026 10:29:55 +0100 Subject: [PATCH 4/4] Fix Java spotless line-wrapping and Python import order Apply the exact reflows spotless computed (javadoc wrapping in the e2e test and ProviderTokenArgs, and the provider.getToken call in the dispatcher), and move BearerTokenProvider to its correct alphabetical position in the python package exports after the getBearerToken rename. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../main/java/com/github/copilot/RpcHandlerDispatcher.java | 3 ++- .../main/java/com/github/copilot/rpc/ProviderTokenArgs.java | 4 ++-- .../com/github/copilot/ByokBearerTokenProviderE2ETest.java | 6 +++--- python/copilot/__init__.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java index d6526e6a1..9a42a8e22 100644 --- a/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java +++ b/java/src/main/java/com/github/copilot/RpcHandlerDispatcher.java @@ -328,7 +328,8 @@ private void handleProviderTokenGetToken(JsonRpcClient rpc, String requestId, Js return; } - CompletableFuture tokenFuture = provider.getToken(new ProviderTokenArgs(providerName, sessionId)); + CompletableFuture tokenFuture = provider + .getToken(new ProviderTokenArgs(providerName, sessionId)); if (tokenFuture == null) { rpc.sendErrorResponse(requestIdLong, -32603, "Bearer-token provider returned null future for provider " + providerName); diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java index dde5ed86d..009734ad1 100644 --- a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java +++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java @@ -51,8 +51,8 @@ public String getProviderName() { /** * Gets the id of the session that triggered this token request. *

- * A client-level shared callback registered for many sessions can use this - * to resolve the owning session and scope token acquisition or caching per + * A client-level shared callback registered for many sessions can use this to + * resolve the owning session and scope token acquisition or caching per * session. * * @return the session id diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java index 280e5fd24..b035bd54d 100644 --- a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java +++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java @@ -43,9 +43,9 @@ /** * End-to-end coverage for the experimental BYOK bearer-token-provider surface - * ({@code BearerTokenProvider} on a provider config). The callback stays entirely on - * the SDK/client side: the SDK keeps it off the wire, sends only the - * {@code hasBearerTokenProvider} flag, and the runtime calls back over the + * ({@code BearerTokenProvider} on a provider config). The callback stays + * entirely on the SDK/client side: the SDK keeps it off the wire, sends only + * the {@code hasBearerTokenProvider} flag, and the runtime calls back over the * session-scoped {@code providerToken.getToken} RPC before each outbound model * request. */ diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index d93f55180..ff13d47de 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -86,6 +86,7 @@ AutoModeSwitchHandler, AutoModeSwitchRequest, AutoModeSwitchResponse, + BearerTokenProvider, CommandContext, CommandDefinition, CopilotSession, @@ -100,7 +101,6 @@ ExitPlanModeHandler, ExitPlanModeRequest, ExitPlanModeResult, - BearerTokenProvider, InfiniteSessionConfig, InputOptions, LargeToolOutputConfig,