bearerTokenProviders = new HashMap<>();
+ if (provider != null && provider.getBearerTokenProvider() != null) {
+ bearerTokenProviders.put("default", provider.getBearerTokenProvider());
}
if (providers != null) {
for (NamedProviderConfig namedProvider : providers) {
if (namedProvider != null && namedProvider.getName() != null
- && namedProvider.getGetBearerToken() != null) {
- bearerTokenProviders.put(namedProvider.getName(), namedProvider.getGetBearerToken());
+ && namedProvider.getBearerTokenProvider() != null) {
+ bearerTokenProviders.put(namedProvider.getName(), namedProvider.getBearerTokenProvider());
}
}
}
diff --git a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java b/java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java
similarity index 87%
rename from java/src/main/java/com/github/copilot/rpc/GetBearerToken.java
rename to java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java
index 27ec7f09c..7b37925aa 100644
--- a/java/src/main/java/com/github/copilot/rpc/GetBearerToken.java
+++ b/java/src/main/java/com/github/copilot/rpc/BearerTokenProvider.java
@@ -20,13 +20,13 @@
* Experimental. This managed-identity surface may change or be
* removed in future SDK or CLI releases.
*
- * @see ProviderConfig#setGetBearerToken(GetBearerToken)
- * @see NamedProviderConfig#setGetBearerToken(GetBearerToken)
+ * @see ProviderConfig#setBearerTokenProvider(BearerTokenProvider)
+ * @see NamedProviderConfig#setBearerTokenProvider(BearerTokenProvider)
* @since 1.0.0
*/
@CopilotExperimental
@FunctionalInterface
-public interface GetBearerToken {
+public interface BearerTokenProvider {
/**
* Gets a bearer token for the provider identified by {@code args}.
diff --git a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java
index 2bdf2678f..e3b090019 100644
--- a/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java
+++ b/java/src/main/java/com/github/copilot/rpc/NamedProviderConfig.java
@@ -61,7 +61,7 @@ public class NamedProviderConfig {
private String bearerToken;
@JsonIgnore
- private GetBearerToken getBearerToken;
+ private BearerTokenProvider bearerTokenProvider;
@JsonProperty("azure")
private AzureOptions azure;
@@ -221,8 +221,8 @@ public NamedProviderConfig setBearerToken(String bearerToken) {
*
* @return the bearer-token provider callback, or {@code null} if not set
*/
- public GetBearerToken getGetBearerToken() {
- return getBearerToken;
+ public BearerTokenProvider getBearerTokenProvider() {
+ return bearerTokenProvider;
}
/**
@@ -234,19 +234,19 @@ public GetBearerToken getGetBearerToken() {
* RPC before each model request. Return the raw token without a {@code Bearer }
* prefix.
*
- * @param getBearerToken
+ * @param bearerTokenProvider
* the bearer-token provider callback
* @return this config for method chaining
*/
- public NamedProviderConfig setGetBearerToken(GetBearerToken getBearerToken) {
- this.getBearerToken = getBearerToken;
+ public NamedProviderConfig setBearerTokenProvider(BearerTokenProvider bearerTokenProvider) {
+ this.bearerTokenProvider = bearerTokenProvider;
return this;
}
@JsonProperty("hasBearerTokenProvider")
@JsonInclude(JsonInclude.Include.NON_NULL)
Boolean hasBearerTokenProviderWireFlag() {
- return getBearerToken != null ? Boolean.TRUE : null;
+ return bearerTokenProvider != null ? Boolean.TRUE : null;
}
/**
diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java
index ae59e7ead..3d6faba34 100644
--- a/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java
+++ b/java/src/main/java/com/github/copilot/rpc/ProviderConfig.java
@@ -57,7 +57,7 @@ public class ProviderConfig {
private String bearerToken;
@JsonIgnore
- private GetBearerToken getBearerToken;
+ private BearerTokenProvider bearerTokenProvider;
@JsonProperty("azure")
private AzureOptions azure;
@@ -230,8 +230,8 @@ public ProviderConfig setBearerToken(String bearerToken) {
*
* @return the bearer-token provider callback, or {@code null} if not set
*/
- public GetBearerToken getGetBearerToken() {
- return getBearerToken;
+ public BearerTokenProvider getBearerTokenProvider() {
+ return bearerTokenProvider;
}
/**
@@ -243,19 +243,19 @@ public GetBearerToken getGetBearerToken() {
* RPC before each model request. Return the raw token without a {@code Bearer }
* prefix.
*
- * @param getBearerToken
+ * @param bearerTokenProvider
* the bearer-token provider callback
* @return this config for method chaining
*/
- public ProviderConfig setGetBearerToken(GetBearerToken getBearerToken) {
- this.getBearerToken = getBearerToken;
+ public ProviderConfig setBearerTokenProvider(BearerTokenProvider bearerTokenProvider) {
+ this.bearerTokenProvider = bearerTokenProvider;
return this;
}
@JsonProperty("hasBearerTokenProvider")
@JsonInclude(JsonInclude.Include.NON_NULL)
Boolean hasBearerTokenProviderWireFlag() {
- return getBearerToken != null ? Boolean.TRUE : null;
+ return bearerTokenProvider != null ? Boolean.TRUE : null;
}
/**
diff --git a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java
index 3866cc0ad..009734ad1 100644
--- a/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java
+++ b/java/src/main/java/com/github/copilot/rpc/ProviderTokenArgs.java
@@ -17,13 +17,9 @@
@CopilotExperimental
public class ProviderTokenArgs {
- private String providerName;
+ private final String providerName;
- /**
- * Creates an empty argument object.
- */
- public ProviderTokenArgs() {
- }
+ private final String sessionId;
/**
* Creates argument object for the named provider.
@@ -32,9 +28,12 @@ public ProviderTokenArgs() {
* the name of the BYOK provider needing a token; {@code "default"}
* for the singular whole-session provider, otherwise the named
* provider's {@code name}
+ * @param sessionId
+ * the id of the session that triggered this token request
*/
- public ProviderTokenArgs(String providerName) {
+ public ProviderTokenArgs(String providerName, String sessionId) {
this.providerName = providerName;
+ this.sessionId = sessionId;
}
/**
@@ -50,14 +49,15 @@ public String getProviderName() {
}
/**
- * Sets the name of the BYOK provider needing a token.
+ * Gets the id of the session that triggered this token request.
+ *
+ * A client-level shared callback registered for many sessions can use this to
+ * resolve the owning session and scope token acquisition or caching per
+ * session.
*
- * @param providerName
- * the provider name
- * @return this args instance for method chaining
+ * @return the session id
*/
- public ProviderTokenArgs setProviderName(String providerName) {
- this.providerName = providerName;
- return this;
+ public String getSessionId() {
+ return sessionId;
}
}
diff --git a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java
index 253ce136c..b035bd54d 100644
--- a/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java
+++ b/java/src/test/java/com/github/copilot/ByokBearerTokenProviderE2ETest.java
@@ -34,7 +34,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-import com.github.copilot.rpc.GetBearerToken;
+import com.github.copilot.rpc.BearerTokenProvider;
import com.github.copilot.rpc.MessageOptions;
import com.github.copilot.rpc.NamedProviderConfig;
import com.github.copilot.rpc.PermissionHandler;
@@ -43,9 +43,9 @@
/**
* End-to-end coverage for the experimental BYOK bearer-token-provider surface
- * ({@code getBearerToken} on a provider config). The callback stays entirely on
- * the SDK/client side: the SDK keeps it off the wire, sends only the
- * {@code hasBearerTokenProvider} flag, and the runtime calls back over the
+ * ({@code BearerTokenProvider} on a provider config). The callback stays
+ * entirely on the SDK/client side: the SDK keeps it off the wire, sends only
+ * the {@code hasBearerTokenProvider} flag, and the runtime calls back over the
* session-scoped {@code providerToken.getToken} RPC before each outbound model
* request.
*/
@@ -82,13 +82,13 @@ void resetHandler() {
void appliesCallbackTokenAsAuthorizationHeader() throws Exception {
String sentinel = "sentinel-bearer-token-abc123";
AtomicInteger calls = new AtomicInteger();
- GetBearerToken getBearerToken = args -> {
+ BearerTokenProvider tokenProvider = args -> {
calls.incrementAndGet();
return CompletableFuture.completedFuture(sentinel);
};
List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai")
- .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken));
+ .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setBearerTokenProvider(tokenProvider));
List models = List
.of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o"));
@@ -102,11 +102,11 @@ void appliesCallbackTokenAsAuthorizationHeader() throws Exception {
@Test
void reacquiresFreshTokenForEachRequest() throws Exception {
AtomicInteger calls = new AtomicInteger();
- GetBearerToken getBearerToken = args -> CompletableFuture
+ BearerTokenProvider tokenProvider = args -> CompletableFuture
.completedFuture("rotating-token-" + calls.incrementAndGet());
List providers = List.of(new NamedProviderConfig().setName("mi").setType("openai")
- .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setGetBearerToken(getBearerToken));
+ .setWireApi("completions").setBaseUrl(PRIMARY_BASE_URL).setBearerTokenProvider(tokenProvider));
List models = List
.of(new ProviderModelConfig().setId("default").setProvider("mi").setWireModel("byok-gpt-4o"));
@@ -124,15 +124,19 @@ void reacquiresFreshTokenForEachRequest() throws Exception {
@Test
void dispatchesTokenAcquisitionPerProvider() throws Exception {
List acquiredFor = new ArrayList<>();
- GetBearerToken redCallback = args -> {
+ BearerTokenProvider redCallback = args -> {
assertEquals("red", args.getProviderName(), "Expected providerName to be forwarded");
+ assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(),
+ "Expected a non-empty session id in token args");
synchronized (acquiredFor) {
acquiredFor.add("red");
}
return CompletableFuture.completedFuture("token-for-red");
};
- GetBearerToken blueCallback = args -> {
+ BearerTokenProvider blueCallback = args -> {
assertEquals("blue", args.getProviderName(), "Expected providerName to be forwarded");
+ assertTrue(args.getSessionId() != null && !args.getSessionId().isEmpty(),
+ "Expected a non-empty session id in token args");
synchronized (acquiredFor) {
acquiredFor.add("blue");
}
@@ -141,9 +145,9 @@ void dispatchesTokenAcquisitionPerProvider() throws Exception {
List providers = List.of(
new NamedProviderConfig().setName("red").setType("openai").setWireApi("completions")
- .setBaseUrl(RED_BASE_URL).setGetBearerToken(redCallback),
+ .setBaseUrl(RED_BASE_URL).setBearerTokenProvider(redCallback),
new NamedProviderConfig().setName("blue").setType("openai").setWireApi("completions")
- .setBaseUrl(BLUE_BASE_URL).setGetBearerToken(blueCallback));
+ .setBaseUrl(BLUE_BASE_URL).setBearerTokenProvider(blueCallback));
List models = List.of(
new ProviderModelConfig().setId("default").setProvider("red").setWireModel("byok-gpt-4o"),
new ProviderModelConfig().setId("default").setProvider("blue").setWireModel("byok-gpt-4o"));
diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts
index e0315b211..53686a6ca 100644
--- a/nodejs/src/client.ts
+++ b/nodejs/src/client.ts
@@ -51,7 +51,7 @@ import type {
ExitPlanModeResult,
ForegroundSessionInfo,
GetAuthStatusResponse,
- GetBearerToken,
+ BearerTokenProvider,
GetStatusResponse,
InternalRuntimeConnection,
LargeToolOutputConfig,
@@ -161,17 +161,17 @@ function toJsonSchema(parameters: Tool["parameters"]): Record |
const DEFAULT_PROVIDER_NAME = "default";
/** Wire-safe singular provider config carrying the `hasBearerTokenProvider` flag. */
-type WireProviderConfig = Omit & {
+type WireProviderConfig = Omit & {
hasBearerTokenProvider?: boolean;
};
/** Wire-safe named provider config carrying the `hasBearerTokenProvider` flag. */
-type WireNamedProviderConfig = Omit & {
+type WireNamedProviderConfig = Omit & {
hasBearerTokenProvider?: boolean;
};
/**
- * Strips the non-serializable {@link GetBearerToken} callbacks from the singular
+ * Strips the non-serializable {@link BearerTokenProvider} callbacks from the singular
* and named provider configs before they cross the RPC boundary, replacing each
* with a `hasBearerTokenProvider: true` wire flag. The callback closes over its
* own token scope/audience, so nothing scope-related crosses the wire — the
@@ -185,14 +185,14 @@ function extractBearerTokenProviders(
): {
wireProvider: WireProviderConfig | undefined;
wireProviders: WireNamedProviderConfig[] | undefined;
- callbacks: Map;
+ callbacks: Map;
} {
- const callbacks = new Map();
+ const callbacks = new Map();
let wireProvider: WireProviderConfig | undefined = provider;
- if (provider?.getBearerToken) {
- const { getBearerToken, ...rest } = provider;
- callbacks.set(DEFAULT_PROVIDER_NAME, getBearerToken);
+ if (provider?.bearerTokenProvider) {
+ const { bearerTokenProvider, ...rest } = provider;
+ callbacks.set(DEFAULT_PROVIDER_NAME, bearerTokenProvider);
wireProvider = {
...rest,
hasBearerTokenProvider: true,
@@ -200,11 +200,11 @@ function extractBearerTokenProviders(
}
let wireProviders: WireNamedProviderConfig[] | undefined = providers;
- if (providers?.some((p) => p.getBearerToken)) {
+ if (providers?.some((p) => p.bearerTokenProvider)) {
wireProviders = providers.map((p) => {
- if (!p.getBearerToken) return p;
- const { getBearerToken, ...rest } = p;
- callbacks.set(p.name, getBearerToken);
+ if (!p.bearerTokenProvider) return p;
+ const { bearerTokenProvider, ...rest } = p;
+ callbacks.set(p.name, bearerTokenProvider);
return {
...rest,
hasBearerTokenProvider: true,
@@ -1305,7 +1305,7 @@ export class CopilotClient {
const useServerGeneratedId = config.cloud != null && callerSessionId == null;
const localSessionId = useServerGeneratedId ? undefined : (callerSessionId ?? randomUUID());
- // Strip non-serializable getBearerToken callbacks from provider configs,
+ // Strip non-serializable bearerTokenProvider callbacks from provider configs,
// replacing them with a wire flag; keep the callbacks for session-side
// registration so the runtime can call back to acquire tokens.
const {
diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts
index 740a7bc89..eebf9add5 100644
--- a/nodejs/src/index.ts
+++ b/nodejs/src/index.ts
@@ -84,7 +84,7 @@ export type {
MCPHTTPServerConfig,
MCPServerConfig,
DefaultAgentConfig,
- GetBearerToken,
+ BearerTokenProvider,
MessageOptions,
ModelBilling,
ModelBillingTokenPrices,
diff --git a/nodejs/src/session.ts b/nodejs/src/session.ts
index d87d2b9de..8bf9589c3 100644
--- a/nodejs/src/session.ts
+++ b/nodejs/src/session.ts
@@ -26,7 +26,7 @@ import type {
ExitPlanModeHandler,
ExitPlanModeRequest,
ExitPlanModeResult,
- GetBearerToken,
+ BearerTokenProvider,
UiInputOptions,
MessageOptions,
PermissionHandler,
@@ -121,7 +121,7 @@ export class CopilotSession {
new Map();
private toolHandlers: Map = new Map();
private canvases: Map = new Map();
- private bearerTokenProviders: Map = new Map();
+ private bearerTokenProviders: Map = new Map();
private commandHandlers: Map = new Map();
private permissionHandler?: PermissionHandler;
private userInputHandler?: UserInputHandler;
@@ -798,7 +798,7 @@ export class CopilotSession {
}
/**
- * Registers per-provider {@link GetBearerToken} callbacks for BYOK providers
+ * Registers per-provider {@link BearerTokenProvider} callbacks for BYOK providers
* configured with managed-identity / on-demand bearer-token auth.
*
* The runtime never receives the callback itself; the SDK strips it from the
@@ -809,7 +809,7 @@ export class CopilotSession {
* @param providers - Map of provider name → callback, or undefined/empty to clear.
* @internal This method is called internally when creating/resuming a session.
*/
- registerBearerTokenProviders(providers?: Map): void {
+ registerBearerTokenProviders(providers?: Map): void {
this.bearerTokenProviders.clear();
if (!providers || providers.size === 0) {
delete this.clientSessionApis.providerToken;
@@ -830,6 +830,7 @@ export class CopilotSession {
}
const token = await callback({
providerName: params.providerName,
+ sessionId: params.sessionId,
});
return { token };
},
diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts
index 61d9ca06d..e354bd821 100644
--- a/nodejs/src/types.ts
+++ b/nodejs/src/types.ts
@@ -2215,7 +2215,7 @@ export interface ResumeSessionConfig extends SessionConfigBase {
}
/**
- * Arguments passed to a {@link GetBearerToken} callback when the runtime needs a
+ * Arguments passed to a {@link BearerTokenProvider} callback when the runtime needs a
* fresh bearer token for a BYOK provider.
*
* @experimental Part of the experimental managed-identity / bearer-token-provider
@@ -2230,7 +2230,15 @@ export interface ProviderTokenArgs {
* The callback closes over its own token scope/audience; the runtime is
* provider-agnostic and forwards only the provider name.
*/
- providerName: string;
+ readonly providerName: string;
+
+ /**
+ * Id of the session that triggered this token request. A client-level shared
+ * callback registered for many sessions can use this to resolve the owning
+ * session (e.g. via the client's session lookup) to scope token acquisition
+ * or caching per session.
+ */
+ readonly sessionId: string;
}
/**
@@ -2245,7 +2253,7 @@ export interface ProviderTokenArgs {
* @experimental Part of the experimental managed-identity / bearer-token-provider
* surface and may change or be removed in future SDK or CLI releases.
*/
-export type GetBearerToken = (args: ProviderTokenArgs) => Promise;
+export type BearerTokenProvider = (args: ProviderTokenArgs) => Promise;
/**
* Configuration for a custom API provider.
@@ -2294,12 +2302,14 @@ export interface ProviderConfig {
* When set, the SDK keeps this function client-side (it is never serialized)
* and the runtime calls back into this client to acquire a token before each
* outbound request. The runtime does no caching of its own, so the callback
- * owns token caching and refresh. Mutually exclusive with {@link apiKey} /
- * {@link bearerToken}.
+ * owns token caching and refresh. When set alongside {@link apiKey} /
+ * {@link bearerToken}, this callback takes precedence: the runtime applies
+ * the token it returns as the `Authorization: Bearer` header for each
+ * request and does not send the static credential.
*
* @experimental
*/
- getBearerToken?: GetBearerToken;
+ bearerTokenProvider?: BearerTokenProvider;
/**
* Azure-specific options
@@ -2397,12 +2407,14 @@ export interface NamedProviderConfig {
* When set, the SDK keeps this function client-side (it is never serialized)
* and the runtime calls back into this client to acquire a token before each
* outbound request. The runtime does no caching of its own, so the callback
- * owns token caching and refresh. Mutually exclusive with {@link apiKey} /
- * {@link bearerToken}.
+ * owns token caching and refresh. When set alongside {@link apiKey} /
+ * {@link bearerToken}, this callback takes precedence: the runtime applies
+ * the token it returns as the `Authorization: Bearer` header for each
+ * request and does not send the static credential.
*
* @experimental
*/
- getBearerToken?: GetBearerToken;
+ bearerTokenProvider?: BearerTokenProvider;
/**
* Azure-specific options.
diff --git a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts
index 228b7a022..c528fb23d 100644
--- a/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts
+++ b/nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts
@@ -6,7 +6,7 @@ import { beforeEach, describe, expect, it } from "vitest";
import { approveAll, CopilotRequestHandler } from "../../src/index.js";
import type {
CopilotRequestContext,
- GetBearerToken,
+ BearerTokenProvider,
NamedProviderConfig,
ProviderModelConfig,
} from "../../src/index.js";
@@ -40,7 +40,7 @@ const BLUE_BASE_URL = `https://${BLUE_HOST}/v1`;
* The runtime invokes {@link sendRequest} for every model-layer HTTP request it
* would otherwise issue. We capture the ones aimed at a fake BYOK host —
* recording the `Authorization` header the runtime applied after calling the
- * provider's `getBearerToken` callback over the session-scoped
+ * provider's `bearerTokenProvider` callback over the session-scoped
* `providerToken.getToken` RPC — and answer them with a synthetic `404` (a
* non-retryable status, so each outbound model request yields exactly one
* capture). Every other request (CAPI bootstrap: model catalog, policy, …) is
@@ -89,7 +89,7 @@ class CapturingRequestHandler extends CopilotRequestHandler {
/**
* End-to-end coverage for the experimental BYOK bearer-token-provider surface
- * (`getBearerToken` on a provider config). The callback stays entirely on the
+ * (`bearerTokenProvider` on a provider config). The callback stays entirely on the
* SDK/client side: the SDK strips it from the wire config, sets the
* `hasBearerTokenProvider` flag, and the runtime calls back over the session-scoped
* `providerToken.getToken` RPC before each outbound model request, applying the
@@ -144,7 +144,7 @@ describe("BYOK bearer-token provider", async () => {
it("applies the callback's token as the Authorization header", async () => {
const SENTINEL = "sentinel-bearer-token-abc123";
let calls = 0;
- const getBearerToken: GetBearerToken = async () => {
+ const getBearerToken: BearerTokenProvider = async () => {
calls += 1;
return SENTINEL;
};
@@ -155,7 +155,7 @@ describe("BYOK bearer-token provider", async () => {
type: "openai",
wireApi: "completions",
baseUrl: PRIMARY_BASE_URL,
- getBearerToken,
+ bearerTokenProvider: getBearerToken,
},
];
const models: ProviderModelConfig[] = [
@@ -172,7 +172,7 @@ describe("BYOK bearer-token provider", async () => {
it("re-acquires a fresh token for each request (no runtime caching)", async () => {
let calls = 0;
- const getBearerToken: GetBearerToken = async () => {
+ const getBearerToken: BearerTokenProvider = async () => {
calls += 1;
// A distinct token per acquisition proves the runtime re-invokes the
// callback per request rather than caching a previous token.
@@ -185,7 +185,7 @@ describe("BYOK bearer-token provider", async () => {
type: "openai",
wireApi: "completions",
baseUrl: PRIMARY_BASE_URL,
- getBearerToken,
+ bearerTokenProvider: getBearerToken,
},
];
const models: ProviderModelConfig[] = [
@@ -211,11 +211,15 @@ describe("BYOK bearer-token provider", async () => {
};
const acquiredFor: string[] = [];
const makeCallback =
- (providerName: string): GetBearerToken =>
+ (providerName: string): BearerTokenProvider =>
async (args) => {
// The runtime forwards the requesting provider's name so the client
// can dispatch to the right credential.
expect(args.providerName).toBe(providerName);
+ // The runtime also forwards the owning session id so a
+ // client-level shared callback can resolve the session.
+ expect(typeof args.sessionId).toBe("string");
+ expect(args.sessionId.length).toBeGreaterThan(0);
acquiredFor.push(providerName);
return tokenByProvider[providerName];
};
@@ -226,14 +230,14 @@ describe("BYOK bearer-token provider", async () => {
type: "openai",
wireApi: "completions",
baseUrl: RED_BASE_URL,
- getBearerToken: makeCallback("red"),
+ bearerTokenProvider: makeCallback("red"),
},
{
name: "blue",
type: "openai",
wireApi: "completions",
baseUrl: BLUE_BASE_URL,
- getBearerToken: makeCallback("blue"),
+ bearerTokenProvider: makeCallback("blue"),
},
];
const models: ProviderModelConfig[] = [
diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py
index 1e7a3afb1..ff13d47de 100644
--- a/python/copilot/__init__.py
+++ b/python/copilot/__init__.py
@@ -86,6 +86,7 @@
AutoModeSwitchHandler,
AutoModeSwitchRequest,
AutoModeSwitchResponse,
+ BearerTokenProvider,
CommandContext,
CommandDefinition,
CopilotSession,
@@ -100,7 +101,6 @@
ExitPlanModeHandler,
ExitPlanModeRequest,
ExitPlanModeResult,
- GetBearerToken,
InfiniteSessionConfig,
InputOptions,
LargeToolOutputConfig,
@@ -216,7 +216,7 @@
"ExtensionInfo",
"CopilotWebSocketForwarder",
"GetAuthStatusResponse",
- "GetBearerToken",
+ "BearerTokenProvider",
"GetStatusResponse",
"InfiniteSessionConfig",
"InputOptions",
diff --git a/python/copilot/client.py b/python/copilot/client.py
index ebfdcf992..c7d11d12b 100644
--- a/python/copilot/client.py
+++ b/python/copilot/client.py
@@ -81,6 +81,7 @@
)
from .session import (
AutoModeSwitchHandler,
+ BearerTokenProvider,
CommandDefinition,
ContextTier,
CopilotSession,
@@ -89,7 +90,6 @@
DefaultAgentConfig,
ElicitationHandler,
ExitPlanModeHandler,
- GetBearerToken,
InfiniteSessionConfig,
LargeToolOutputConfig,
MCPServerConfig,
@@ -180,8 +180,8 @@ def _capi_session_options_to_wire(options: CapiSessionOptions) -> dict[str, Any]
def _collect_bearer_token_callbacks(
provider: ProviderConfig | None,
providers: list[NamedProviderConfig] | None,
-) -> dict[str, GetBearerToken]:
- """Collect per-provider ``get_bearer_token`` callbacks keyed by provider name.
+) -> dict[str, BearerTokenProvider]:
+ """Collect per-provider ``bearer_token_provider`` callbacks keyed by provider name.
The singular, whole-session ``provider`` uses the implicit
``_DEFAULT_BEARER_TOKEN_PROVIDER_NAME``; ``providers`` entries use their own
@@ -189,14 +189,14 @@ def _collect_bearer_token_callbacks(
``hasBearerTokenProvider: true`` instead and the runtime calls back over
``providerToken.getToken``.
"""
- callbacks: dict[str, GetBearerToken] = {}
+ callbacks: dict[str, BearerTokenProvider] = {}
if provider is not None:
- singular = provider.get("get_bearer_token")
+ singular = provider.get("bearer_token_provider")
if singular is not None:
callbacks[_DEFAULT_BEARER_TOKEN_PROVIDER_NAME] = singular
if providers:
for named in providers:
- callback = named.get("get_bearer_token")
+ callback = named.get("bearer_token_provider")
if callback is not None:
callbacks[named["name"]] = callback
return callbacks
@@ -3266,7 +3266,7 @@ def _convert_provider_to_wire_format(
wire_provider["transport"] = provider["transport"]
if "bearer_token" in provider:
wire_provider["bearerToken"] = provider["bearer_token"]
- if provider.get("get_bearer_token") is not None:
+ if provider.get("bearer_token_provider") is not None:
wire_provider["hasBearerTokenProvider"] = True
if "headers" in provider:
wire_provider["headers"] = provider["headers"]
@@ -3304,7 +3304,7 @@ def _convert_named_provider_to_wire_format(
wire["apiKey"] = provider["api_key"]
if "bearer_token" in provider:
wire["bearerToken"] = provider["bearer_token"]
- if provider.get("get_bearer_token") is not None:
+ if provider.get("bearer_token_provider") is not None:
wire["hasBearerTokenProvider"] = True
if "headers" in provider:
wire["headers"] = provider["headers"]
diff --git a/python/copilot/session.py b/python/copilot/session.py
index 94fba994a..0dc569f25 100644
--- a/python/copilot/session.py
+++ b/python/copilot/session.py
@@ -1080,7 +1080,7 @@ class AzureProviderOptions(TypedDict, total=False):
class ProviderTokenArgs(TypedDict):
- """Arguments passed to a :data:`GetBearerToken` callback when the runtime
+ """Arguments passed to a :data:`BearerTokenProvider` callback when the runtime
needs a fresh bearer token for a BYOK provider.
**Experimental.** Part of the bearer-token-provider surface and may change or
@@ -1092,6 +1092,11 @@ class ProviderTokenArgs(TypedDict):
# ``NamedProviderConfig`` entries it is ``NamedProviderConfig.name``.
provider_name: str
+ # Id of the session that triggered this token request. A client-level shared
+ # callback registered for many sessions can use this to resolve the owning
+ # session and scope token acquisition or caching per session.
+ session_id: str
+
# Per-request callback that resolves a bearer token on demand for a BYOK
# provider (for example via Azure Managed Identity). The Copilot SDK takes no
@@ -1099,7 +1104,7 @@ class ProviderTokenArgs(TypedDict):
# Never serialized — setting it makes the SDK send ``hasBearerTokenProvider`` on
# the wire and answer the runtime's ``providerToken.getToken`` requests. May be
# sync or async.
-GetBearerToken = Callable[[ProviderTokenArgs], str | Awaitable[str]]
+BearerTokenProvider = Callable[[ProviderTokenArgs], str | Awaitable[str]]
class ProviderConfig(TypedDict, total=False):
@@ -1142,8 +1147,10 @@ class ProviderConfig(TypedDict, total=False):
# provider (for example via Azure Managed Identity). Never serialized — the
# SDK sends hasBearerTokenProvider: true on the wire and answers the
# runtime's providerToken.getToken requests with this callback's result.
- # Mutually exclusive with api_key and bearer_token.
- get_bearer_token: GetBearerToken
+ # When set alongside api_key/bearer_token, this callback takes precedence: the
+ # runtime applies the token it returns as the Authorization: Bearer header for
+ # each request and does not send the static credential.
+ bearer_token_provider: BearerTokenProvider
class NamedProviderConfig(TypedDict, total=False):
@@ -1172,9 +1179,11 @@ class NamedProviderConfig(TypedDict, total=False):
headers: dict[str, str]
# Per-request bearer-token callback for this named BYOK provider. Never
# serialized; the SDK sends hasBearerTokenProvider: true and answers the
- # runtime's providerToken.getToken requests. Mutually exclusive with api_key
- # and bearer_token.
- get_bearer_token: GetBearerToken
+ # runtime's providerToken.getToken requests. When set alongside
+ # api_key/bearer_token, this callback takes precedence: the runtime applies
+ # the token it returns as the Authorization: Bearer header for each request
+ # and does not send the static credential.
+ bearer_token_provider: BearerTokenProvider
class ProviderModelConfig(TypedDict, total=False):
@@ -1248,7 +1257,7 @@ def _canvas_handler_error(err: Exception) -> JsonRpcError:
class _BearerTokenProviderAdapter:
"""Routes runtime ``providerToken.getToken`` requests to the matching
- per-provider :data:`GetBearerToken` callback registered on the session.
+ per-provider :data:`BearerTokenProvider` callback registered on the session.
The runtime calls this once per outbound request for a BYOK provider that
declared ``hasBearerTokenProvider: true``; it does no caching, so the SDK
@@ -1268,7 +1277,10 @@ async def get_token(self, params: ProviderTokenAcquireRequest) -> ProviderTokenA
-32603,
f"No bearer-token provider registered for provider: {provider_name!r}",
)
- args: ProviderTokenArgs = {"provider_name": provider_name}
+ args: ProviderTokenArgs = {
+ "provider_name": provider_name,
+ "session_id": params.session_id,
+ }
result = callback(args)
if inspect.isawaitable(result):
result = await result
@@ -1340,7 +1352,7 @@ def __init__(
self._transform_callbacks_lock = threading.Lock()
self._command_handlers: dict[str, CommandHandler] = {}
self._command_handlers_lock = threading.Lock()
- self._bearer_token_providers: dict[str, GetBearerToken] = {}
+ self._bearer_token_providers: dict[str, BearerTokenProvider] = {}
self._bearer_token_providers_lock = threading.Lock()
self._elicitation_handler: ElicitationHandler | None = None
self._elicitation_handler_lock = threading.Lock()
@@ -2082,7 +2094,9 @@ def _register_commands(self, commands: list[CommandDefinition] | None) -> None:
for cmd in commands:
self._command_handlers[cmd.name] = cmd.handler
- def _register_bearer_token_providers(self, providers: dict[str, GetBearerToken] | None) -> None:
+ def _register_bearer_token_providers(
+ self, providers: dict[str, BearerTokenProvider] | None
+ ) -> None:
"""Register per-provider bearer-token callbacks for this session.
The runtime never receives the callbacks themselves; the SDK strips them
diff --git a/python/e2e/test_byok_bearer_token_provider_e2e.py b/python/e2e/test_byok_bearer_token_provider_e2e.py
index 28f9e0586..37dfbc009 100644
--- a/python/e2e/test_byok_bearer_token_provider_e2e.py
+++ b/python/e2e/test_byok_bearer_token_provider_e2e.py
@@ -5,7 +5,7 @@
"""E2E coverage for the experimental BYOK bearer-token-provider surface.
Mirrors ``nodejs/test/e2e/byok_bearer_token_provider.e2e.test.ts``. A BYOK
-provider config may carry a ``get_bearer_token`` callback; the callback stays
+provider config may carry a ``bearer_token_provider`` callback; the callback stays
entirely on the SDK/client side. The SDK strips it from the wire config, sets
the ``hasBearerTokenProvider`` flag, and the runtime calls back over the
session-scoped ``providerToken.getToken`` RPC before each outbound model
@@ -31,7 +31,7 @@
import pytest_asyncio
from copilot import CopilotRequestContext, CopilotRequestHandler
-from copilot.session import GetBearerToken, PermissionHandler
+from copilot.session import BearerTokenProvider, PermissionHandler
from ._copilot_request_helpers import build_isolated_client, build_non_inference_response
from .testharness import E2ETestContext
@@ -56,7 +56,7 @@ class _CapturingRequestHandler(CopilotRequestHandler):
The runtime invokes :meth:`send_request` for every model-layer HTTP request.
Requests aimed at a fake BYOK host are captured — recording the
``Authorization`` header the runtime applied after calling the provider's
- ``get_bearer_token`` callback over ``providerToken.getToken`` — and answered
+ ``bearer_token_provider`` callback over ``providerToken.getToken`` — and answered
with a synthetic ``404`` (non-retryable, so each outbound model request
yields exactly one capture). Every other request (CAPI bootstrap: model
catalog, policy, …) is fabricated locally so no real network or CAPI proxy
@@ -126,6 +126,7 @@ async def _run_turn(client, providers, models, selection_id: str, prompt: str) -
try:
await session.send_and_wait(prompt)
except Exception:
+ # The fake BYOK endpoint intentionally errors after capture.
pass
finally:
try:
@@ -154,7 +155,7 @@ async def get_bearer_token(args) -> str:
"type": "openai",
"wire_api": "completions",
"base_url": PRIMARY_BASE_URL,
- "get_bearer_token": get_bearer_token,
+ "bearer_token_provider": get_bearer_token,
}
]
models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}]
@@ -185,7 +186,7 @@ async def get_bearer_token(args) -> str:
"type": "openai",
"wire_api": "completions",
"base_url": PRIMARY_BASE_URL,
- "get_bearer_token": get_bearer_token,
+ "bearer_token_provider": get_bearer_token,
}
]
models = [{"id": "default", "provider": "mi", "wire_model": "byok-gpt-4o"}]
@@ -208,11 +209,14 @@ async def test_dispatches_token_acquisition_per_provider(self, bearer_fixture):
token_by_provider = {"red": "token-for-red", "blue": "token-for-blue"}
acquired_for: list[str] = []
- def make_callback(provider_name: str) -> GetBearerToken:
+ def make_callback(provider_name: str) -> BearerTokenProvider:
async def callback(args) -> str:
# The runtime forwards the requesting provider's name so the
# client can dispatch to the right credential.
assert args["provider_name"] == provider_name
+ # The runtime also forwards the owning session id so a
+ # client-level shared callback can resolve the session.
+ assert isinstance(args["session_id"], str) and args["session_id"]
acquired_for.append(provider_name)
return token_by_provider[provider_name]
@@ -224,14 +228,14 @@ async def callback(args) -> str:
"type": "openai",
"wire_api": "completions",
"base_url": RED_BASE_URL,
- "get_bearer_token": make_callback("red"),
+ "bearer_token_provider": make_callback("red"),
},
{
"name": "blue",
"type": "openai",
"wire_api": "completions",
"base_url": BLUE_BASE_URL,
- "get_bearer_token": make_callback("blue"),
+ "bearer_token_provider": make_callback("blue"),
},
]
models = [
diff --git a/rust/src/provider_token.rs b/rust/src/provider_token.rs
index f92715006..a8b75f196 100644
--- a/rust/src/provider_token.rs
+++ b/rust/src/provider_token.rs
@@ -30,6 +30,13 @@ pub struct ProviderTokenArgs {
/// This is `"default"` for the singular whole-session provider, otherwise
/// the named provider's `name`.
pub provider_name: String,
+
+ /// Id of the session that triggered this token request.
+ ///
+ /// A client-level shared callback registered for many sessions can use this
+ /// to resolve the owning session and scope token acquisition or caching per
+ /// session.
+ pub session_id: String,
}
/// Error returned by a [`BearerTokenProvider`].
diff --git a/rust/src/provider_token_dispatch.rs b/rust/src/provider_token_dispatch.rs
index c100443cd..0631260a4 100644
--- a/rust/src/provider_token_dispatch.rs
+++ b/rust/src/provider_token_dispatch.rs
@@ -119,6 +119,7 @@ async fn get_token(
match token_provider
.get_token(ProviderTokenArgs {
provider_name: params.provider_name,
+ session_id: params.session_id.into_inner(),
})
.await
{
diff --git a/rust/src/types.rs b/rust/src/types.rs
index d62e20f70..75408db02 100644
--- a/rust/src/types.rs
+++ b/rust/src/types.rs
@@ -1053,7 +1053,7 @@ pub struct ProviderConfig {
/// **Experimental.** Callback used to acquire a bearer token before each
/// outbound request to this provider.
#[serde(skip)]
- pub get_bearer_token: Option>,
+ pub bearer_token_provider: Option>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub(crate) has_bearer_token_provider: Option,
/// Azure-specific options.
@@ -1097,8 +1097,8 @@ impl std::fmt::Debug for ProviderConfig {
.field("api_key", &self.api_key)
.field("bearer_token", &self.bearer_token)
.field(
- "get_bearer_token",
- &self.get_bearer_token.as_ref().map(|_| ""),
+ "bearer_token_provider",
+ &self.bearer_token_provider.as_ref().map(|_| ""),
)
.field("has_bearer_token_provider", &self.has_bearer_token_provider)
.field("azure", &self.azure)
@@ -1158,8 +1158,8 @@ impl ProviderConfig {
///
/// **Experimental.** This method is part of an experimental wire-protocol
/// surface and may change or be removed in a future release.
- pub fn with_get_bearer_token(mut self, provider: Arc) -> Self {
- self.get_bearer_token = Some(provider);
+ pub fn with_bearer_token_provider(mut self, provider: Arc) -> Self {
+ self.bearer_token_provider = Some(provider);
self
}
@@ -1291,7 +1291,7 @@ pub struct NamedProviderConfig {
/// **Experimental.** Callback used to acquire a bearer token before each
/// outbound request to this provider.
#[serde(skip)]
- pub get_bearer_token: Option>,
+ pub bearer_token_provider: Option>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub(crate) has_bearer_token_provider: Option,
/// Azure-specific options.
@@ -1312,8 +1312,8 @@ impl std::fmt::Debug for NamedProviderConfig {
.field("api_key", &self.api_key)
.field("bearer_token", &self.bearer_token)
.field(
- "get_bearer_token",
- &self.get_bearer_token.as_ref().map(|_| ""),
+ "bearer_token_provider",
+ &self.bearer_token_provider.as_ref().map(|_| ""),
)
.field("has_bearer_token_provider", &self.has_bearer_token_provider)
.field("azure", &self.azure)
@@ -1363,8 +1363,8 @@ impl NamedProviderConfig {
///
/// **Experimental.** This method is part of an experimental wire-protocol
/// surface and may change or be removed in a future release.
- pub fn with_get_bearer_token(mut self, provider: Arc) -> Self {
- self.get_bearer_token = Some(provider);
+ pub fn with_bearer_token_provider(mut self, provider: Arc) -> Self {
+ self.bearer_token_provider = Some(provider);
self
}
@@ -1388,7 +1388,7 @@ fn prepare_bearer_token_providers(
let mut bearer_token_providers = HashMap::new();
if let Some(provider) = provider.as_mut()
- && let Some(token_provider) = provider.get_bearer_token.take()
+ && let Some(token_provider) = provider.bearer_token_provider.take()
{
provider.has_bearer_token_provider = Some(true);
bearer_token_providers.insert("default".to_string(), token_provider);
@@ -1396,7 +1396,7 @@ fn prepare_bearer_token_providers(
if let Some(providers) = providers.as_mut() {
for provider in providers {
- if let Some(token_provider) = provider.get_bearer_token.take() {
+ if let Some(token_provider) = provider.bearer_token_provider.take() {
provider.has_bearer_token_provider = Some(true);
bearer_token_providers.insert(provider.name.clone(), token_provider);
}
diff --git a/rust/tests/e2e/byok_bearer_token_provider.rs b/rust/tests/e2e/byok_bearer_token_provider.rs
index c3cd9ef4b..fc3ef89d9 100644
--- a/rust/tests/e2e/byok_bearer_token_provider.rs
+++ b/rust/tests/e2e/byok_bearer_token_provider.rs
@@ -155,7 +155,7 @@ async fn callback_token_is_applied_as_authorization_header() {
NamedProviderConfig::new("mi", PRIMARY_BASE_URL)
.with_provider_type("openai")
.with_wire_api("completions")
- .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| {
+ .with_bearer_token_provider(Arc::new(move |_args: ProviderTokenArgs| {
let callback_calls = callback_calls.clone();
async move {
callback_calls.fetch_add(1, Ordering::SeqCst);
@@ -202,7 +202,7 @@ async fn reacquires_a_fresh_token_for_each_request() {
NamedProviderConfig::new("mi", PRIMARY_BASE_URL)
.with_provider_type("openai")
.with_wire_api("completions")
- .with_get_bearer_token(Arc::new(move |_args: ProviderTokenArgs| {
+ .with_bearer_token_provider(Arc::new(move |_args: ProviderTokenArgs| {
let callback_calls = callback_calls.clone();
async move {
let call = callback_calls.fetch_add(1, Ordering::SeqCst) + 1;
@@ -266,10 +266,14 @@ async fn dispatches_token_acquisition_per_provider() {
NamedProviderConfig::new(name, base_url)
.with_provider_type("openai")
.with_wire_api("completions")
- .with_get_bearer_token(Arc::new(move |args: ProviderTokenArgs| {
+ .with_bearer_token_provider(Arc::new(move |args: ProviderTokenArgs| {
let acquired_for = acquired_for.clone();
async move {
assert_eq!(args.provider_name, name);
+ assert!(
+ !args.session_id.is_empty(),
+ "expected a non-empty session id in token args"
+ );
acquired_for.lock().unwrap().push(name.to_string());
Ok::<_, BearerTokenError>(token.to_string())
}