Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dotnet/src/BearerTokenProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace GitHub.Copilot;

/// <summary>
/// Arguments passed to a bearer-token callback (the <c>GetBearerToken</c> property
/// Arguments passed to a bearer-token callback (the <c>BearerTokenProvider</c> property
/// on <see cref="ProviderConfig"/> / <see cref="NamedProviderConfig"/>) when the
/// runtime needs a fresh bearer token for a BYOK provider.
/// </summary>
Expand All @@ -29,4 +29,11 @@ public sealed class ProviderTokenArgs
/// provider-agnostic and forwards only the provider name.
/// </remarks>
public required string ProviderName { get; init; }

/// <summary>
/// Id of the session that triggered this token request. A client-level
/// shared callback registered for many sessions can use this to resolve the
/// owning session and scope token acquisition or caching per session.
/// </summary>
public required string SessionId { get; init; }
Comment thread
SteveSandersonMS marked this conversation as resolved.
}
11 changes: 4 additions & 7 deletions dotnet/src/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -671,26 +671,23 @@ private CopilotSession InitializeSession(
private const string DefaultBearerTokenProviderName = "default";

/// <summary>
/// Collects the per-provider <c>GetBearerToken</c> callbacks keyed by
/// Collects the per-provider <c>BearerTokenProvider</c> callbacks keyed by
/// provider name for session-side registration. The singular, whole-session
/// <see cref="ProviderConfig"/> uses the implicit
/// <see cref="DefaultBearerTokenProviderName"/>.
/// </summary>
private static Dictionary<string, Func<ProviderTokenArgs, Task<string>>> BuildBearerTokenCallbacks(SessionConfigBase config)
{
var callbacks = new Dictionary<string, Func<ProviderTokenArgs, Task<string>>>(StringComparer.Ordinal);
if (config.Provider?.GetBearerToken is { } singular)
if (config.Provider?.BearerTokenProvider is { } singular)
{
callbacks[DefaultBearerTokenProviderName] = singular;
}
if (config.Providers != null)
{
foreach (var provider in config.Providers)
foreach (var provider in config.Providers.Where(provider => provider.BearerTokenProvider is not null))
{
if (provider.GetBearerToken is { } callback)
{
callbacks[provider.Name] = callback;
}
callbacks[provider.Name] = provider.BearerTokenProvider!;
}
}
return callbacks;
Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ internal void RegisterAutoModeSwitchHandler(Func<AutoModeSwitchRequest, AutoMode
}

/// <summary>
/// Registers per-provider <c>GetBearerToken</c> callbacks for BYOK
/// Registers per-provider <c>BearerTokenProvider</c> callbacks for BYOK
/// providers configured with managed-identity / on-demand bearer-token auth.
/// </summary>
/// <remarks>
Expand Down Expand Up @@ -899,7 +899,7 @@ internal void RegisterBearerTokenProviders(IReadOnlyDictionary<string, Func<Prov

/// <summary>
/// Routes runtime <c>providerToken.getToken</c> requests to the matching
/// per-provider <c>GetBearerToken</c> callback registered on the session.
/// per-provider <c>BearerTokenProvider</c> callback registered on the session.
/// </summary>
private sealed class BearerTokenProviderHandler(CopilotSession session) : IProviderTokenHandler
{
Expand All @@ -910,7 +910,7 @@ public async Task<ProviderTokenAcquireResult> GetTokenAsync(ProviderTokenAcquire
throw new InvalidOperationException(
$"No bearer-token provider registered for provider \"{request.ProviderName}\"");
}
var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName }).ConfigureAwait(false);
var token = await callback(new ProviderTokenArgs { ProviderName = request.ProviderName, SessionId = request.SessionId }).ConfigureAwait(false);
return new ProviderTokenAcquireResult { Token = token };
}
}
Expand Down
26 changes: 14 additions & 12 deletions dotnet/src/Types.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2044,26 +2044,27 @@ public sealed class ProviderConfig
public string? BearerToken { get; set; }

/// <summary>
/// Wire-only flag, emitted automatically when <see cref="GetBearerToken"/> is set, that tells
/// Wire-only flag, emitted automatically when <see cref="BearerTokenProvider"/> is set, that tells
/// the runtime to request a token over the session-scoped <c>providerToken.getToken</c> RPC
/// before each outbound request to this provider. Derived from <see cref="GetBearerToken"/>;
/// before each outbound request to this provider. Derived from <see cref="BearerTokenProvider"/>;
/// internal and never part of the public API.
/// </summary>
[JsonInclude]
[JsonPropertyName("hasBearerTokenProvider")]
internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null;
internal bool? HasBearerTokenProvider => BearerTokenProvider is not null ? true : null;

/// <summary>
/// Per-request callback that resolves a bearer token on demand for this BYOK provider (for
/// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a
/// callback backed by your own identity library. Never serialized — setting it makes the SDK send
/// <c>hasBearerTokenProvider: true</c> on the wire and answer the runtime's
/// <c>providerToken.getToken</c> requests. Mutually exclusive with <see cref="ApiKey"/> and
/// <see cref="BearerToken"/>.
/// <c>providerToken.getToken</c> requests. When set alongside <see cref="ApiKey"/>/<see cref="BearerToken"/>, this callback takes precedence:
/// the runtime applies the token it returns as the Authorization: Bearer header for each request
/// and does not send the static credential.
/// </summary>
[JsonIgnore]
[Experimental(Diagnostics.Experimental)]
public Func<ProviderTokenArgs, Task<string>>? GetBearerToken { get; set; }
public Func<ProviderTokenArgs, Task<string>>? BearerTokenProvider { get; set; }

/// <summary>
/// Azure-specific configuration options.
Expand Down Expand Up @@ -2198,26 +2199,27 @@ public sealed class NamedProviderConfig
public string? BearerToken { get; set; }

/// <summary>
/// Wire-only flag, emitted automatically when <see cref="GetBearerToken"/> is set, that tells
/// Wire-only flag, emitted automatically when <see cref="BearerTokenProvider"/> is set, that tells
/// the runtime to request a token over the session-scoped <c>providerToken.getToken</c> RPC
/// before each outbound request to this provider. Derived from <see cref="GetBearerToken"/>;
/// before each outbound request to this provider. Derived from <see cref="BearerTokenProvider"/>;
/// internal and never part of the public API.
/// </summary>
[JsonInclude]
[JsonPropertyName("hasBearerTokenProvider")]
internal bool? HasBearerTokenProvider => GetBearerToken is not null ? true : null;
internal bool? HasBearerTokenProvider => BearerTokenProvider is not null ? true : null;

/// <summary>
/// Per-request callback that resolves a bearer token on demand for this BYOK provider (for
/// example via Azure Managed Identity). The Copilot SDK takes no identity dependency: supply a
/// callback backed by your own identity library. Never serialized — setting it makes the SDK send
/// <c>hasBearerTokenProvider: true</c> on the wire and answer the runtime's
/// <c>providerToken.getToken</c> requests. Mutually exclusive with <see cref="ApiKey"/> and
/// <see cref="BearerToken"/>.
/// <c>providerToken.getToken</c> requests. When set alongside <see cref="ApiKey"/>/<see cref="BearerToken"/>, this callback takes precedence:
/// the runtime applies the token it returns as the Authorization: Bearer header for each request
/// and does not send the static credential.
/// </summary>
[JsonIgnore]
[Experimental(Diagnostics.Experimental)]
public Func<ProviderTokenArgs, Task<string>>? GetBearerToken { get; set; }
public Func<ProviderTokenArgs, Task<string>>? BearerTokenProvider { get; set; }

/// <summary>
/// Azure-specific configuration options.
Expand Down
22 changes: 13 additions & 9 deletions dotnet/test/E2E/ByokBearerTokenProviderE2ETests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

/// <summary>
/// End-to-end coverage for the experimental BYOK bearer-token-provider surface
/// (<c>GetBearerToken</c> on a provider config). The callback stays entirely on
/// (<c>BearerTokenProvider</c> on a provider config). The callback stays entirely on
/// the SDK/client side: the SDK strips it from the wire config, sets the
/// <c>hasBearerTokenProvider</c> flag, and the runtime calls back over the
/// session-scoped <c>providerToken.getToken</c> RPC before each outbound model
Expand Down Expand Up @@ -81,7 +81,7 @@
{
await session.SendAndWaitAsync(new MessageOptions { Prompt = prompt });
}
catch
catch (InvalidOperationException)
{
// The handler always 404s the BYOK endpoint, so the turn errors after
// the token-bearing request was already captured. Expected.
Expand Down Expand Up @@ -110,7 +110,7 @@
Type = "openai",
WireApi = "completions",
BaseUrl = PrimaryBaseUrl,
GetBearerToken = _ =>
BearerTokenProvider = _ =>
{
Interlocked.Increment(ref calls);
return Task.FromResult(sentinel);
Expand Down Expand Up @@ -149,7 +149,7 @@
BaseUrl = PrimaryBaseUrl,
// A distinct token per acquisition proves the runtime re-invokes
// the callback per request rather than caching a previous token.
GetBearerToken = _ =>
BearerTokenProvider = _ =>
{
var n = Interlocked.Increment(ref calls);
return Task.FromResult($"rotating-token-{n}");
Expand Down Expand Up @@ -189,6 +189,9 @@
// The runtime forwards the requesting provider's name so the client
// can dispatch to the right credential.
Assert.Equal(providerName, args.ProviderName);
// The runtime also forwards the owning session id so a
// client-level shared callback can resolve the session.
Assert.False(string.IsNullOrEmpty(args.SessionId));
acquiredFor.Add(providerName);
return Task.FromResult(tokenByProvider[providerName]);
};
Expand All @@ -205,15 +208,15 @@
Type = "openai",
WireApi = "completions",
BaseUrl = RedBaseUrl,
GetBearerToken = MakeCallback("red"),
BearerTokenProvider = MakeCallback("red"),
},
new()
{
Name = "blue",
Type = "openai",
WireApi = "completions",
BaseUrl = BlueBaseUrl,
GetBearerToken = MakeCallback("blue"),
BearerTokenProvider = MakeCallback("blue"),
},
};
var models = new List<ProviderModelConfig>
Expand All @@ -240,7 +243,7 @@
/// The runtime invokes <see cref="SendRequestAsync"/> for every model-layer HTTP
/// request. Requests aimed at a fake BYOK host (<c>*.invalid</c>) are captured —
/// recording the <c>Authorization</c> header the runtime applied after calling
/// the provider's <c>GetBearerToken</c> callback over the session-scoped
/// the provider's <c>BearerTokenProvider</c> callback over the session-scoped
/// <c>providerToken.getToken</c> RPC — and answered with a synthetic <c>404</c>
/// (a non-retryable status, so each outbound model request yields exactly one
/// capture). Every other request (CAPI bootstrap: model catalog, policy, …) is
Expand All @@ -262,13 +265,14 @@
? string.Join(", ", values)
: null));

return Task.FromResult(new HttpResponseMessage(HttpStatusCode.NotFound)
var response = new HttpResponseMessage(HttpStatusCode.NotFound)
{
Content = new StringContent(
"{\"error\":{\"message\":\"fake byok endpoint\"}}",
System.Text.Encoding.UTF8,
"application/json"),
});
};

Check warning

Code scanning / CodeQL

Missing Dispose call on local IDisposable Warning test

Disposable 'HttpResponseMessage' is created but not disposed.
Comment on lines +268 to +274
return Task.FromResult(response);
}

// CAPI bootstrap (model catalog, policy, …) — answered off-network.
Expand Down
14 changes: 7 additions & 7 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ import (
// whole-session [ProviderConfig]. Named providers are keyed by their own Name.
const defaultBearerTokenProviderName = "default"

// collectBearerTokenProviders gathers the per-provider [GetBearerToken] callbacks
// collectBearerTokenProviders gathers the per-provider [BearerTokenProvider] callbacks
// from the singular provider and any named providers, keyed by provider name. The
// singular provider uses the implicit name "default"; named providers use their
// own Name. Returns nil when no callbacks are configured.
func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]GetBearerToken {
callbacks := make(map[string]GetBearerToken)
if provider != nil && provider.GetBearerToken != nil {
callbacks[defaultBearerTokenProviderName] = provider.GetBearerToken
func collectBearerTokenProviders(provider *ProviderConfig, providers []NamedProviderConfig) map[string]BearerTokenProvider {
callbacks := make(map[string]BearerTokenProvider)
if provider != nil && provider.BearerTokenProvider != nil {
callbacks[defaultBearerTokenProviderName] = provider.BearerTokenProvider
}
for i := range providers {
if providers[i].GetBearerToken != nil {
callbacks[providers[i].Name] = providers[i].GetBearerToken
if providers[i].BearerTokenProvider != nil {
callbacks[providers[i].Name] = providers[i].BearerTokenProvider
}
}
if len(callbacks) == 0 {
Expand Down
51 changes: 28 additions & 23 deletions go/internal/e2e/byok_bearer_token_provider_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type capturedBYOKRequest struct {

// byokCapturingRoundTripper stands in for a real HTTP upstream. It records the
// `Authorization` header the runtime applied (after calling the provider's
// GetBearerToken callback over the session-scoped `providerToken.getToken` RPC)
// BearerTokenProvider callback over the session-scoped `providerToken.getToken` RPC)
// for every request aimed at a fake `.invalid` BYOK host, answering them with a
// synthetic 404 (a non-retryable status, so each outbound model request yields
// exactly one capture). Every other request (CAPI bootstrap: model catalog,
Expand Down Expand Up @@ -96,7 +96,7 @@ func (rt *byokCapturingRoundTripper) reset() {
}

// TestBYOKBearerTokenProvider is end-to-end coverage for the experimental BYOK
// bearer-token-provider surface (GetBearerToken on a provider config). The
// bearer-token-provider surface (BearerTokenProvider on a provider config). The
// callback stays entirely on the SDK/client side: the SDK strips it from the
// wire config, sets the `hasBearerTokenProvider` flag, and the runtime calls
// back over the session-scoped `providerToken.getToken` RPC before each outbound
Expand Down Expand Up @@ -151,11 +151,11 @@ func TestBYOKBearerTokenProvider(t *testing.T) {
}

providers := []copilot.NamedProviderConfig{{
Name: "mi",
Type: "openai",
WireAPI: "completions",
BaseURL: byokPrimaryBaseURL,
GetBearerToken: getBearerToken,
Name: "mi",
Type: "openai",
WireAPI: "completions",
BaseURL: byokPrimaryBaseURL,
BearerTokenProvider: getBearerToken,
}}
models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}}

Expand Down Expand Up @@ -189,11 +189,11 @@ func TestBYOKBearerTokenProvider(t *testing.T) {
}

providers := []copilot.NamedProviderConfig{{
Name: "mi",
Type: "openai",
WireAPI: "completions",
BaseURL: byokPrimaryBaseURL,
GetBearerToken: getBearerToken,
Name: "mi",
Type: "openai",
WireAPI: "completions",
BaseURL: byokPrimaryBaseURL,
BearerTokenProvider: getBearerToken,
}}
models := []copilot.ProviderModelConfig{{ID: "default", Provider: "mi", WireModel: "byok-gpt-4o"}}

Expand Down Expand Up @@ -227,13 +227,18 @@ func TestBYOKBearerTokenProvider(t *testing.T) {
}
var mu sync.Mutex
var acquiredFor []string
makeCallback := func(providerName string) copilot.GetBearerToken {
makeCallback := func(providerName string) copilot.BearerTokenProvider {
return func(args copilot.ProviderTokenArgs) (string, error) {
// The runtime forwards the requesting provider's name so the
// client can dispatch to the right credential.
if args.ProviderName != providerName {
t.Errorf("Expected providerName %q, got %q", providerName, args.ProviderName)
}
// The runtime also forwards the owning session id so a
// client-level shared callback can resolve the session.
if args.SessionID == "" {
t.Errorf("Expected a non-empty session id in token args")
}
mu.Lock()
acquiredFor = append(acquiredFor, providerName)
mu.Unlock()
Expand All @@ -243,18 +248,18 @@ func TestBYOKBearerTokenProvider(t *testing.T) {

providers := []copilot.NamedProviderConfig{
{
Name: "red",
Type: "openai",
WireAPI: "completions",
BaseURL: byokRedBaseURL,
GetBearerToken: makeCallback("red"),
Name: "red",
Type: "openai",
WireAPI: "completions",
BaseURL: byokRedBaseURL,
BearerTokenProvider: makeCallback("red"),
},
{
Name: "blue",
Type: "openai",
WireAPI: "completions",
BaseURL: byokBlueBaseURL,
GetBearerToken: makeCallback("blue"),
Name: "blue",
Type: "openai",
WireAPI: "completions",
BaseURL: byokBlueBaseURL,
BearerTokenProvider: makeCallback("blue"),
},
}
models := []copilot.ProviderModelConfig{
Expand Down
Loading
Loading