Skip to content

Commit a8255cb

Browse files
Clean up HTTP passthrough API (#1784)
1 parent f079862 commit a8255cb

9 files changed

Lines changed: 92 additions & 132 deletions

File tree

dotnet/src/CopilotRequestHandler.cs

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,40 @@ public enum CopilotRequestTransport
4141
[Experimental(Diagnostics.Experimental)]
4242
public sealed class CopilotRequestContext
4343
{
44+
/// <summary>
45+
/// Creates an instance of <see cref="CopilotRequestContext"/> by copying the values from another instance.
46+
/// </summary>
47+
/// <param name="original">A <see cref="CopilotRequestContext"/> instance to copy values from.</param>
48+
public CopilotRequestContext(CopilotRequestContext original)
49+
: this(original.RequestId, original.Url, original.Headers)
50+
{
51+
SessionId = original.SessionId;
52+
Transport = original.Transport;
53+
CancellationToken = original.CancellationToken;
54+
WebSocketResponse = original.WebSocketResponse;
55+
}
56+
57+
internal CopilotRequestContext(string requestId, string url, IReadOnlyDictionary<string, IReadOnlyList<string>> headers)
58+
{
59+
RequestId = requestId;
60+
Url = url;
61+
Headers = headers;
62+
}
63+
4464
/// <summary>Opaque runtime-minted id, stable across the request lifecycle.</summary>
45-
public required string RequestId { get; init; }
65+
public string RequestId { get; init; }
4666

4767
/// <summary>Runtime session id that triggered the request, if any.</summary>
4868
public string? SessionId { get; init; }
4969

5070
/// <summary>Transport the runtime would otherwise use.</summary>
5171
public CopilotRequestTransport Transport { get; init; }
5272

53-
/// <summary>Original request URL.</summary>
54-
public required string Url { get; init; }
73+
/// <summary>Request URL.</summary>
74+
public string Url { get; init; }
5575

56-
/// <summary>Original request headers.</summary>
57-
public required IReadOnlyDictionary<string, IReadOnlyList<string>> Headers { get; init; }
76+
/// <summary>Request headers.</summary>
77+
public IReadOnlyDictionary<string, IReadOnlyList<string>> Headers { get; init; }
5878

5979
/// <summary>
6080
/// Cancelled when the runtime aborts this in-flight request. Subclasses that
@@ -199,25 +219,17 @@ public virtual async ValueTask DisposeAsync()
199219
[Experimental(Diagnostics.Experimental)]
200220
public class CopilotWebSocketForwarder : CopilotWebSocketHandler
201221
{
202-
private readonly string _url;
203-
private readonly IReadOnlyDictionary<string, IReadOnlyList<string>> _headers;
204222
private WebSocket? _upstream;
205223
private CancellationTokenSource? _pumpCts;
206224
private Task? _responsePump;
207225

208226
/// <summary>
209227
/// Initializes a forwarding handler that will open the upstream socket on
210-
/// demand using the supplied URL/headers (or the values from
211-
/// <paramref name="context"/> when omitted).
228+
/// demand using the supplied URL/headers from <paramref name="context"/>.
212229
/// </summary>
213-
public CopilotWebSocketForwarder(
214-
CopilotRequestContext context,
215-
string? url = null,
216-
IReadOnlyDictionary<string, IReadOnlyList<string>>? headers = null)
230+
public CopilotWebSocketForwarder(CopilotRequestContext context)
217231
: base(context)
218232
{
219-
_url = url ?? context.Url;
220-
_headers = headers ?? context.Headers;
221233
}
222234

223235
/// <summary>
@@ -231,7 +243,7 @@ internal override async Task OpenAsync()
231243
}
232244

233245
var socket = new ClientWebSocket();
234-
foreach (var (name, values) in _headers)
246+
foreach (var (name, values) in Context.Headers)
235247
{
236248
if (LlmInferenceHeaders.Forbidden.Contains(name))
237249
{
@@ -248,7 +260,7 @@ internal override async Task OpenAsync()
248260
}
249261
}
250262

251-
await socket.ConnectAsync(ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false);
263+
await socket.ConnectAsync(ToWebSocketUri(Context.Url), Context.CancellationToken).ConfigureAwait(false);
252264
_upstream = socket;
253265
_pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken);
254266

@@ -855,13 +867,10 @@ public Task<LlmInferenceHttpRequestStartResult> HttpRequestStartAsync(LlmInferen
855867
// dropping those frames and hanging the body drain.
856868
var exchange = _pending.GetOrAdd(request.RequestId, id => new LlmInferenceExchange(id, _getServerRpc));
857869
exchange.Method = request.Method;
858-
exchange.Context = new CopilotRequestContext
870+
exchange.Context = new CopilotRequestContext(request.RequestId, request.Url, ToReadOnlyHeaders(request.Headers))
859871
{
860-
RequestId = request.RequestId,
861872
SessionId = request.SessionId,
862873
Transport = transport,
863-
Url = request.Url,
864-
Headers = ToReadOnlyHeaders(request.Headers),
865874
CancellationToken = exchange.Abort.Token,
866875
};
867876

dotnet/test/E2E/CopilotRequestWebSocketE2ETests.cs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ protected override Task<HttpResponseMessage> SendRequestAsync(HttpRequestMessage
116116

117117
protected override Task<CopilotWebSocketHandler> OpenWebSocketAsync(CopilotRequestContext ctx)
118118
{
119-
var wsUrl = Rewrite(new Uri(ctx.Url)).ToString();
120-
return Task.FromResult<CopilotWebSocketHandler>(new CountingForwardingWebSocketHandler(ctx, wsUrl, counters));
119+
ctx = new CopilotRequestContext(ctx) { Url = Rewrite(new Uri(ctx.Url)).ToString() };
120+
return Task.FromResult<CopilotWebSocketHandler>(new CountingForwardingWebSocketHandler(ctx, counters));
121121
}
122122

123123
private Uri Rewrite(Uri original) => new UriBuilder(original)
@@ -133,9 +133,8 @@ protected override Task<CopilotWebSocketHandler> OpenWebSocketAsync(CopilotReque
133133
/// </summary>
134134
internal sealed class CountingForwardingWebSocketHandler(
135135
CopilotRequestContext context,
136-
string url,
137136
HandlerCounters counters)
138-
: CopilotWebSocketForwarder(context, url)
137+
: CopilotWebSocketForwarder(context)
139138
{
140139
public override Task SendRequestMessageAsync(CopilotWebSocketMessage message)
141140
{

java/src/main/java/com/github/copilot/CopilotRequestContext.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ public final class CopilotRequestContext {
3939
this.cancellation = cancellation;
4040
}
4141

42+
private CopilotRequestContext(String requestId, @Nullable String sessionId, CopilotRequestTransport transport,
43+
String url, Map<String, List<String>> headers, CompletableFuture<Void> cancellation,
44+
LlmWebSocketResponseBridge webSocketResponse) {
45+
this(requestId, sessionId, transport, url, headers, cancellation);
46+
this.webSocketResponse = webSocketResponse;
47+
}
48+
4249
/**
4350
* Gets the opaque runtime-minted request id, stable across the request
4451
* lifecycle.
@@ -88,6 +95,30 @@ public Map<String, List<String>> headers() {
8895
return headers;
8996
}
9097

98+
/**
99+
* Returns a copy of this context with a different request URL.
100+
*
101+
* @param url
102+
* the replacement request URL
103+
* @return the copied context
104+
*/
105+
public CopilotRequestContext withUrl(String url) {
106+
return new CopilotRequestContext(requestId, sessionId, transport, url, headers, cancellation,
107+
webSocketResponse);
108+
}
109+
110+
/**
111+
* Returns a copy of this context with different request headers.
112+
*
113+
* @param headers
114+
* the replacement request headers
115+
* @return the copied context
116+
*/
117+
public CopilotRequestContext withHeaders(Map<String, List<String>> headers) {
118+
return new CopilotRequestContext(requestId, sessionId, transport, url, headers, cancellation,
119+
webSocketResponse);
120+
}
121+
91122
/**
92123
* A future that completes when the runtime cancels this in-flight request (for
93124
* example because the agent turn was aborted upstream). Subclasses that issue

java/src/main/java/com/github/copilot/CopilotWebSocketForwarder.java

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727
*/
2828
public class CopilotWebSocketForwarder extends CopilotWebSocketHandler {
2929

30-
private final String url;
31-
private final Map<String, List<String>> headers;
32-
3330
private volatile WebSocket webSocket;
3431

3532
/**
@@ -40,37 +37,7 @@ public class CopilotWebSocketForwarder extends CopilotWebSocketHandler {
4037
* the per-request context
4138
*/
4239
public CopilotWebSocketForwarder(CopilotRequestContext context) {
43-
this(context, context.url(), context.headers());
44-
}
45-
46-
/**
47-
* Creates a forwarding handler targeting {@code url} with the handshake headers
48-
* from {@code context}.
49-
*
50-
* @param context
51-
* the per-request context
52-
* @param url
53-
* the upstream WebSocket URL
54-
*/
55-
public CopilotWebSocketForwarder(CopilotRequestContext context, String url) {
56-
this(context, url, context.headers());
57-
}
58-
59-
/**
60-
* Creates a forwarding handler targeting {@code url} with the given handshake
61-
* headers.
62-
*
63-
* @param context
64-
* the per-request context
65-
* @param url
66-
* the upstream WebSocket URL
67-
* @param headers
68-
* the handshake headers, multi-valued
69-
*/
70-
public CopilotWebSocketForwarder(CopilotRequestContext context, String url, Map<String, List<String>> headers) {
7140
super(context);
72-
this.url = url;
73-
this.headers = headers;
7441
}
7542

7643
@Override
@@ -79,6 +46,7 @@ void open() throws Exception {
7946
return;
8047
}
8148
WebSocket.Builder builder = HttpClient.newHttpClient().newWebSocketBuilder();
49+
Map<String, List<String>> headers = context.headers();
8250
if (headers != null) {
8351
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
8452
if (CopilotRequestHandler.isForbiddenRequestHeader(entry.getKey()) || entry.getValue() == null) {
@@ -90,8 +58,8 @@ void open() throws Exception {
9058
}
9159
}
9260
try {
93-
this.webSocket = builder.buildAsync(URI.create(normalizeWebSocketScheme(url)), new ForwardingListener())
94-
.join();
61+
this.webSocket = builder
62+
.buildAsync(URI.create(normalizeWebSocketScheme(context.url())), new ForwardingListener()).join();
9563
} catch (Exception e) {
9664
throw unwrap(e);
9765
}

java/src/test/java/com/github/copilot/CopilotRequestHandlerE2ETest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ protected HttpResponse<InputStream> sendRequest(HttpRequest request, CopilotRequ
123123

124124
@Override
125125
protected CopilotWebSocketHandler openWebSocket(CopilotRequestContext rctx) {
126-
String rewritten = rewriteHost(wsBase, URI.create(rctx.url()));
127-
return new CopilotWebSocketForwarder(rctx, rewritten) {
126+
return new CopilotWebSocketForwarder(rctx.withUrl(rewriteHost(wsBase, URI.create(rctx.url())))) {
128127
@Override
129128
public void sendRequestMessage(CopilotWebSocketMessage message) throws Exception {
130129
wsRequestMessages.incrementAndGet();

nodejs/src/copilotRequestHandler.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ export interface CopilotRequestContext {
3434
readonly requestId: string;
3535
readonly sessionId?: string;
3636
readonly transport: "http" | "websocket";
37-
readonly url: string;
38-
readonly headers: LlmInferenceHeaders;
37+
url: string;
38+
headers: LlmInferenceHeaders;
3939
readonly signal: AbortSignal;
4040
}
4141

@@ -139,12 +139,10 @@ export abstract class CopilotWebSocketHandler implements AsyncDisposable {
139139
* @experimental
140140
*/
141141
export class CopilotWebSocketForwarder extends CopilotWebSocketHandler {
142-
readonly #url: string;
143142
#upstream: WebSocket | null = null;
144143

145-
constructor(context: CopilotRequestContext, url = context.url) {
144+
constructor(context: CopilotRequestContext) {
146145
super(context);
147-
this.#url = url;
148146
}
149147

150148
override sendRequestMessage(data: string | Uint8Array): void {
@@ -159,7 +157,7 @@ export class CopilotWebSocketForwarder extends CopilotWebSocketHandler {
159157
if (this.#upstream) {
160158
return;
161159
}
162-
const upstream = new WebSocket(this.#url);
160+
const upstream = new WebSocket(this.context.url);
163161
upstream.binaryType = "arraybuffer";
164162
this.#upstream = upstream;
165163
upstream.addEventListener("message", (event) => {

0 commit comments

Comments
 (0)