Skip to content

Commit dc7d02b

Browse files
Add Rust BYOK get_bearer_token glue + e2e
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 8b8ce9c commit dc7d02b

7 files changed

Lines changed: 707 additions & 2 deletions

File tree

rust/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ pub mod hooks;
2222
mod jsonrpc;
2323
/// Permission-policy helpers that produce a [`handler::PermissionHandler`].
2424
pub mod permission;
25+
/// BYOK bearer-token provider callbacks.
26+
pub mod provider_token;
27+
mod provider_token_dispatch;
2528
/// GitHub Copilot CLI binary resolution (env var, embedded, dev cache).
2629
pub(crate) mod resolve;
2730
mod router;
@@ -72,6 +75,7 @@ pub(crate) use jsonrpc::{
7275
JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes,
7376
};
7477
pub use mode::{BUILTIN_TOOLS_ISOLATED, ClientMode, ToolSet};
78+
pub use provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs};
7579

7680
/// Re-exported JSON-RPC internals for integration tests (requires `test-support` feature).
7781
#[cfg(feature = "test-support")]

rust/src/provider_token.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) Microsoft Corporation. All rights reserved.
3+
*--------------------------------------------------------------------------------------------*/
4+
5+
//! BYOK bearer-token provider callbacks.
6+
//!
7+
//! <div class="warning">
8+
//!
9+
//! **Experimental.** These types are part of an experimental wire-protocol
10+
//! surface and may change or be removed in future SDK or CLI releases.
11+
//!
12+
//! </div>
13+
14+
use std::future::Future;
15+
16+
use async_trait::async_trait;
17+
18+
/// Arguments passed to a BYOK bearer-token provider callback.
19+
///
20+
/// <div class="warning">
21+
///
22+
/// **Experimental.** This type is part of an experimental wire-protocol
23+
/// surface and may change or be removed in future SDK or CLI releases.
24+
///
25+
/// </div>
26+
#[derive(Debug, Clone, PartialEq, Eq)]
27+
pub struct ProviderTokenArgs {
28+
/// Name of the BYOK provider needing a token.
29+
///
30+
/// This is `"default"` for the singular whole-session provider, otherwise
31+
/// the named provider's `name`.
32+
pub provider_name: String,
33+
}
34+
35+
/// Error returned by a [`BearerTokenProvider`].
36+
///
37+
/// <div class="warning">
38+
///
39+
/// **Experimental.** This type is part of an experimental wire-protocol
40+
/// surface and may change or be removed in future SDK or CLI releases.
41+
///
42+
/// </div>
43+
#[derive(Debug, Clone, PartialEq, Eq)]
44+
pub struct BearerTokenError {
45+
message: String,
46+
}
47+
48+
impl BearerTokenError {
49+
/// Construct a bearer-token error with a human-readable message.
50+
pub fn message(message: impl Into<String>) -> Self {
51+
Self {
52+
message: message.into(),
53+
}
54+
}
55+
56+
/// Return the human-readable error message.
57+
pub fn as_str(&self) -> &str {
58+
&self.message
59+
}
60+
}
61+
62+
impl std::fmt::Display for BearerTokenError {
63+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64+
f.write_str(&self.message)
65+
}
66+
}
67+
68+
impl std::error::Error for BearerTokenError {}
69+
70+
impl From<String> for BearerTokenError {
71+
fn from(message: String) -> Self {
72+
Self::message(message)
73+
}
74+
}
75+
76+
impl From<&str> for BearerTokenError {
77+
fn from(message: &str) -> Self {
78+
Self::message(message)
79+
}
80+
}
81+
82+
/// Provider-side callback used to acquire bearer tokens for BYOK providers.
83+
///
84+
/// <div class="warning">
85+
///
86+
/// **Experimental.** This trait is part of an experimental wire-protocol
87+
/// surface and may change or be removed in future SDK or CLI releases.
88+
///
89+
/// </div>
90+
#[async_trait]
91+
pub trait BearerTokenProvider: Send + Sync {
92+
/// Acquire a bearer token without the `Bearer ` prefix.
93+
async fn get_token(&self, args: ProviderTokenArgs) -> Result<String, BearerTokenError>;
94+
}
95+
96+
#[async_trait]
97+
impl<F, Fut> BearerTokenProvider for F
98+
where
99+
F: Fn(ProviderTokenArgs) -> Fut + Send + Sync,
100+
Fut: Future<Output = Result<String, BearerTokenError>> + Send,
101+
{
102+
async fn get_token(&self, args: ProviderTokenArgs) -> Result<String, BearerTokenError> {
103+
(self)(args).await
104+
}
105+
}
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*---------------------------------------------------------------------------------------------
2+
* Copyright (c) Microsoft Corporation. All rights reserved.
3+
*--------------------------------------------------------------------------------------------*/
4+
5+
//! Inbound `providerToken.*` JSON-RPC request dispatch helpers.
6+
7+
use std::collections::HashMap;
8+
use std::sync::Arc;
9+
10+
use serde::Serialize;
11+
use serde_json::Value;
12+
use tracing::warn;
13+
14+
use crate::generated::api_types::{
15+
ProviderTokenAcquireRequest, ProviderTokenAcquireResult, rpc_methods,
16+
};
17+
use crate::provider_token::{BearerTokenError, BearerTokenProvider, ProviderTokenArgs};
18+
use crate::{Client, JsonRpcRequest, JsonRpcResponse, error_codes};
19+
20+
async fn respond<T: Serialize>(client: &Client, request_id: u64, result: T) {
21+
let value = match serde_json::to_value(&result) {
22+
Ok(value) => value,
23+
Err(error) => {
24+
warn!(error = %error, "failed to serialize provider token response");
25+
send_error(
26+
client,
27+
request_id,
28+
error_codes::INTERNAL_ERROR,
29+
"serialization failure",
30+
)
31+
.await;
32+
return;
33+
}
34+
};
35+
36+
let _ = client
37+
.send_response(&JsonRpcResponse {
38+
jsonrpc: "2.0".to_string(),
39+
id: request_id,
40+
result: Some(value),
41+
error: None,
42+
})
43+
.await;
44+
}
45+
46+
async fn send_error(client: &Client, request_id: u64, code: i32, message: &str) {
47+
let _ = client
48+
.send_response(&JsonRpcResponse {
49+
jsonrpc: "2.0".to_string(),
50+
id: request_id,
51+
result: None,
52+
error: Some(crate::JsonRpcError {
53+
code,
54+
message: message.to_string(),
55+
data: None,
56+
}),
57+
})
58+
.await;
59+
}
60+
61+
async fn parse_params<T: serde::de::DeserializeOwned>(
62+
client: &Client,
63+
request: &JsonRpcRequest,
64+
) -> Option<T> {
65+
let params = request
66+
.params
67+
.as_ref()
68+
.cloned()
69+
.unwrap_or(Value::Object(serde_json::Map::new()));
70+
match serde_json::from_value(params) {
71+
Ok(params) => Some(params),
72+
Err(error) => {
73+
send_error(
74+
client,
75+
request.id,
76+
error_codes::INVALID_PARAMS,
77+
&format!("invalid params: {error}"),
78+
)
79+
.await;
80+
None
81+
}
82+
}
83+
}
84+
85+
fn token_provider_or_err(
86+
providers: &HashMap<String, Arc<dyn BearerTokenProvider>>,
87+
provider_name: &str,
88+
) -> Result<Arc<dyn BearerTokenProvider>, BearerTokenError> {
89+
providers.get(provider_name).cloned().ok_or_else(|| {
90+
BearerTokenError::message(format!(
91+
"No bearer-token provider installed for BYOK provider {provider_name:?}"
92+
))
93+
})
94+
}
95+
96+
async fn get_token(
97+
client: &Client,
98+
providers: &HashMap<String, Arc<dyn BearerTokenProvider>>,
99+
request: JsonRpcRequest,
100+
) {
101+
let Some(params) = parse_params::<ProviderTokenAcquireRequest>(client, &request).await else {
102+
return;
103+
};
104+
105+
let token_provider = match token_provider_or_err(providers, &params.provider_name) {
106+
Ok(provider) => provider,
107+
Err(error) => {
108+
send_error(
109+
client,
110+
request.id,
111+
error_codes::INTERNAL_ERROR,
112+
&error.to_string(),
113+
)
114+
.await;
115+
return;
116+
}
117+
};
118+
119+
match token_provider
120+
.get_token(ProviderTokenArgs {
121+
provider_name: params.provider_name,
122+
})
123+
.await
124+
{
125+
Ok(token) => respond(client, request.id, ProviderTokenAcquireResult { token }).await,
126+
Err(error) => {
127+
send_error(
128+
client,
129+
request.id,
130+
error_codes::INTERNAL_ERROR,
131+
&format!("Bearer-token provider failed: {error}"),
132+
)
133+
.await;
134+
}
135+
}
136+
}
137+
138+
pub(crate) async fn dispatch(
139+
client: &Client,
140+
providers: &HashMap<String, Arc<dyn BearerTokenProvider>>,
141+
request: JsonRpcRequest,
142+
) {
143+
let method = request.method.as_str();
144+
match method {
145+
rpc_methods::PROVIDERTOKEN_GETTOKEN => get_token(client, providers, request).await,
146+
_ => {
147+
warn!(method = %method, "unknown providerToken.* method");
148+
send_error(
149+
client,
150+
request.id,
151+
error_codes::METHOD_NOT_FOUND,
152+
&format!("unknown method: {method}"),
153+
)
154+
.await;
155+
}
156+
}
157+
}

rust/src/session.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use crate::handler::{
2121
PermissionHandler, PermissionResult, UserInputHandler, UserInputResponse,
2222
};
2323
use crate::hooks::SessionHooks;
24+
use crate::provider_token::BearerTokenProvider;
2425
use crate::session_fs::SessionFsProvider;
2526
use crate::trace_context::inject_trace_context;
2627
use crate::transforms::SystemMessageTransform;
@@ -893,6 +894,7 @@ impl Client {
893894
let command_handlers = build_command_handler_map(runtime.commands.as_deref());
894895
let canvas_handler = runtime.canvas_handler.take();
895896
let session_fs_provider = runtime.session_fs_provider.take();
897+
let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers);
896898
if self.inner.session_fs_configured && session_fs_provider.is_none() {
897899
return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into());
898900
}
@@ -1011,6 +1013,7 @@ impl Client {
10111013
command_handlers,
10121014
canvas_handler,
10131015
session_fs_provider,
1016+
bearer_token_providers,
10141017
channels,
10151018
idle_waiter.clone(),
10161019
capabilities.clone(),
@@ -1149,6 +1152,7 @@ impl Client {
11491152
let command_handlers = build_command_handler_map(runtime.commands.as_deref());
11501153
let canvas_handler = runtime.canvas_handler.take();
11511154
let session_fs_provider = runtime.session_fs_provider.take();
1155+
let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers);
11521156
if self.inner.session_fs_configured && session_fs_provider.is_none() {
11531157
return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into());
11541158
}
@@ -1183,6 +1187,7 @@ impl Client {
11831187
command_handlers,
11841188
canvas_handler,
11851189
session_fs_provider,
1190+
bearer_token_providers,
11861191
channels,
11871192
idle_waiter.clone(),
11881193
capabilities.clone(),
@@ -1391,6 +1396,7 @@ fn spawn_event_loop(
13911396
command_handlers: Arc<CommandHandlerMap>,
13921397
canvas_handler: Option<Arc<dyn CanvasHandler>>,
13931398
session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
1399+
bearer_token_providers: HashMap<String, Arc<dyn BearerTokenProvider>>,
13941400
channels: crate::router::SessionChannels,
13951401
idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
13961402
capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
@@ -1432,6 +1438,7 @@ fn spawn_event_loop(
14321438
transforms: transforms.as_deref(),
14331439
canvas_handler: canvas_handler.as_ref(),
14341440
session_fs_provider: session_fs_provider.as_ref(),
1441+
bearer_token_providers: &bearer_token_providers,
14351442
};
14361443
handle_request(&session_id, ctx, request).await;
14371444
}
@@ -2010,6 +2017,7 @@ struct RequestDispatchContext<'a> {
20102017
transforms: Option<&'a dyn SystemMessageTransform>,
20112018
canvas_handler: Option<&'a Arc<dyn CanvasHandler>>,
20122019
session_fs_provider: Option<&'a Arc<dyn SessionFsProvider>>,
2020+
bearer_token_providers: &'a HashMap<String, Arc<dyn BearerTokenProvider>>,
20132021
}
20142022

20152023
/// Process a JSON-RPC request from the CLI.
@@ -2025,6 +2033,7 @@ async fn handle_request(
20252033
let transforms = ctx.transforms;
20262034
let canvas_handler = ctx.canvas_handler;
20272035
let session_fs_provider = ctx.session_fs_provider;
2036+
let bearer_token_providers = ctx.bearer_token_providers;
20282037

20292038
if request.method.starts_with("sessionFs.") {
20302039
crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await;
@@ -2036,6 +2045,11 @@ async fn handle_request(
20362045
return;
20372046
}
20382047

2048+
if request.method == crate::generated::api_types::rpc_methods::PROVIDERTOKEN_GETTOKEN {
2049+
crate::provider_token_dispatch::dispatch(client, bearer_token_providers, request).await;
2050+
return;
2051+
}
2052+
20392053
match request.method.as_str() {
20402054
"hooks.invoke" => {
20412055
let params = request.params.as_ref();

0 commit comments

Comments
 (0)