Skip to content

Commit f800768

Browse files
Address review feedback on HTTP request callback support (+ cross-SDK parity) (#1775)
1 parent cafa530 commit f800768

7 files changed

Lines changed: 229 additions & 75 deletions

File tree

dotnet/src/CopilotRequestHandler.cs

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*--------------------------------------------------------------------------------------------*/
44

55
using GitHub.Copilot.Rpc;
6+
using System.Buffers;
67
using System.Collections.Concurrent;
78
using System.Diagnostics.CodeAnalysis;
89
using System.Net.WebSockets;
@@ -76,13 +77,10 @@ public readonly struct CopilotWebSocketMessage(ReadOnlyMemory<byte> data, bool i
7677
public bool IsBinary { get; } = isBinary;
7778

7879
/// <summary>Decodes the payload as UTF-8 text.</summary>
79-
public string GetText() => Encoding.UTF8.GetString(Data.ToArray());
80+
public string GetText() => Encoding.UTF8.GetString(Data.Span);
8081

8182
/// <summary>Creates a text message from a UTF-8 string.</summary>
82-
public static CopilotWebSocketMessage Text(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false);
83-
84-
/// <summary>Creates a binary message from raw bytes.</summary>
85-
public static CopilotWebSocketMessage Binary(ReadOnlyMemory<byte> data) => new(data, isBinary: true);
83+
public static CopilotWebSocketMessage FromText(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false);
8684
}
8785

8886
/// <summary>
@@ -253,7 +251,12 @@ internal override async Task OpenAsync()
253251
await socket.ConnectAsync(ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false);
254252
_upstream = socket;
255253
_pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken);
256-
_responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token), _pumpCts.Token);
254+
255+
// Start the pump without a cancellation token on Task.Run itself: if the
256+
// linked token is already cancelled, we still want PumpResponsesAsync to
257+
// run so its cleanup (closing the upstream and finalising the response)
258+
// executes rather than the task being cancelled before it ever starts.
259+
_responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token));
257260
}
258261

259262
/// <summary>
@@ -270,10 +273,10 @@ public override Task SendRequestMessageAsync(CopilotWebSocketMessage message)
270273

271274
var type = message.IsBinary ? WebSocketMessageType.Binary : WebSocketMessageType.Text;
272275
return _upstream.SendAsync(
273-
new ArraySegment<byte>(message.Data.ToArray()),
276+
message.Data,
274277
type,
275278
endOfMessage: true,
276-
Context.CancellationToken);
279+
Context.CancellationToken).AsTask();
277280
}
278281

279282
/// <inheritdoc />
@@ -346,34 +349,41 @@ await CloseAsync(new CopilotWebSocketCloseStatus
346349

347350
private static async Task<CopilotWebSocketMessage?> ReceiveMessageAsync(WebSocket socket, CancellationToken cancellationToken)
348351
{
349-
var buffer = new byte[16 * 1024];
350-
using var assembled = new MemoryStream();
351-
WebSocketReceiveResult result;
352-
do
352+
var buffer = ArrayPool<byte>.Shared.Rent(16 * 1024);
353+
try
353354
{
354-
try
355-
{
356-
result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), cancellationToken).ConfigureAwait(false);
357-
}
358-
catch (OperationCanceledException)
359-
{
360-
return null;
361-
}
362-
catch (WebSocketException)
355+
using var assembled = new MemoryStream();
356+
ValueWebSocketReceiveResult result;
357+
do
363358
{
364-
return null;
365-
}
359+
try
360+
{
361+
result = await socket.ReceiveAsync(buffer.AsMemory(), cancellationToken).ConfigureAwait(false);
362+
}
363+
catch (OperationCanceledException)
364+
{
365+
return null;
366+
}
367+
catch (WebSocketException)
368+
{
369+
return null;
370+
}
366371

367-
if (result.MessageType == WebSocketMessageType.Close)
368-
{
369-
return null;
372+
if (result.MessageType == WebSocketMessageType.Close)
373+
{
374+
return null;
375+
}
376+
377+
assembled.Write(buffer, 0, result.Count);
370378
}
379+
while (!result.EndOfMessage);
371380

372-
assembled.Write(buffer, 0, result.Count);
381+
return new CopilotWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary);
382+
}
383+
finally
384+
{
385+
ArrayPool<byte>.Shared.Return(buffer);
373386
}
374-
while (!result.EndOfMessage);
375-
376-
return new CopilotWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary);
377387
}
378388

379389
private static async Task CloseWebSocketQuietlyAsync(WebSocket socket)
@@ -430,13 +440,34 @@ public class CopilotRequestHandler
430440
{
431441
private static readonly HttpClient s_sharedHttpClient = new();
432442

443+
private readonly HttpClient _httpClient;
444+
445+
/// <summary>
/// Creates a handler whose upstream requests are issued through the shared,
/// process-wide <see cref="HttpClient"/>.
/// </summary>
public CopilotRequestHandler()
    : this(httpClient: null)
{
}
453+
454+
/// <summary>
/// Creates a handler that issues upstream requests through <paramref name="httpClient"/>,
/// falling back to the shared process-wide instance when it is <see langword="null"/>.
/// </summary>
/// <param name="httpClient">
/// The <see cref="HttpClient"/> to use, or <see langword="null"/> to use the shared instance.
/// </param>
public CopilotRequestHandler(HttpClient? httpClient) =>
    _httpClient = httpClient ?? s_sharedHttpClient;
463+
433464
/// <summary>
434465
/// Issue the upstream HTTP request. Override to mutate the request before
435466
/// calling <c>base</c>, mutate the returned response after, or replace the
436467
/// call entirely.
437468
/// </summary>
438469
protected virtual Task<HttpResponseMessage> SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx) =>
439-
s_sharedHttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken);
470+
_httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken);
440471

441472
/// <summary>
442473
/// Open the upstream WebSocket connection. Override to return a custom
@@ -464,7 +495,7 @@ private async Task HandleHttpAsync(LlmInferenceExchange exchange)
464495

465496
private static async Task<HttpRequestMessage> BuildHttpRequestAsync(LlmInferenceExchange exchange)
466497
{
467-
var method = new HttpMethod(exchange.Method.ToUpperInvariant());
498+
var method = new HttpMethod(exchange.Method);
468499
var message = new HttpRequestMessage(method, exchange.Context.Url);
469500

470501
var hasBody = method != HttpMethod.Get && method != HttpMethod.Head;
@@ -499,18 +530,10 @@ await exchange.StartResponseAsync(
499530
HeadersToMultiMap(response)).ConfigureAwait(false);
500531

501532
var ct = exchange.Context.CancellationToken;
502-
#if NETSTANDARD2_0
503-
using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
504-
#else
505533
using var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false);
506-
#endif
507534
var buffer = new byte[16 * 1024];
508535
int read;
509-
#if NETSTANDARD2_0
510-
while ((read = await stream.ReadAsync(buffer, 0, buffer.Length, ct).ConfigureAwait(false)) > 0)
511-
#else
512536
while ((read = await stream.ReadAsync(buffer.AsMemory(), ct).ConfigureAwait(false)) > 0)
513-
#endif
514537
{
515538
await exchange.WriteResponseAsync(new ReadOnlyMemory<byte>(buffer, 0, read)).ConfigureAwait(false);
516539
}
@@ -579,7 +602,7 @@ private static async Task<byte[]> DrainAsync(IAsyncEnumerable<ReadOnlyMemory<byt
579602
{
580603
if (chunk.Length > 0)
581604
{
582-
buffer.Write(chunk.ToArray(), 0, chunk.Length);
605+
buffer.Write(chunk.Span);
583606
}
584607
}
585608

dotnet/src/Polyfills/DownlevelExtensions.cs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,25 @@ public async ValueTask ReadExactlyAsync(Memory<byte> buffer, Threading.Cancellat
376376
totalRead += bytesRead;
377377
}
378378
}
379+
380+
/// <summary>
/// Span-based <c>Write</c>: copies <paramref name="buffer"/> into a pooled array and
/// forwards it to the array-based <c>Write(byte[], int, int)</c> overload of the stream.
/// </summary>
/// <param name="buffer">The bytes to write. An empty span writes nothing.</param>
public void Write(ReadOnlySpan<byte> buffer)
{
    int count = buffer.Length;
    if (count != 0)
    {
        // The rented array may be longer than requested, so only the first
        // `count` bytes are ever written to the stream.
        byte[] scratch = ArrayPool<byte>.Shared.Rent(count);
        try
        {
            buffer.CopyTo(scratch);
            stream.Write(scratch, 0, count);
        }
        finally
        {
            ArrayPool<byte>.Shared.Return(scratch);
        }
    }
}
379398
}
380399

381400
private static async ValueTask<int> ReadAsyncSlow(Stream stream, Memory<byte> buffer, Threading.CancellationToken cancellationToken)
@@ -646,3 +665,125 @@ public async Task<T> WaitAsync(TimeSpan timeout, CancellationToken cancellationT
646665
}
647666
}
648667
}
668+
669+
namespace System.Text
{
    /// <summary>
    /// Span-accepting helpers for <see cref="Encoding"/>, implemented on top of the
    /// array-based overloads.
    /// </summary>
    internal static class DownlevelEncodingExtensions
    {
        extension(Encoding encoding)
        {
            /// <summary>Decodes <paramref name="bytes"/> into a string using this encoding.</summary>
            /// <param name="bytes">The encoded bytes. An empty span yields <see cref="string.Empty"/>.</param>
            /// <returns>The decoded string.</returns>
            public string GetString(ReadOnlySpan<byte> bytes)
            {
                int byteCount = bytes.Length;
                if (byteCount == 0)
                {
                    return string.Empty;
                }

                // Stage the span in a pooled array so the array-based overload can be used.
                // The rented array may be oversized, hence the explicit byteCount.
                byte[] scratch = ArrayPool<byte>.Shared.Rent(byteCount);
                try
                {
                    bytes.CopyTo(scratch);
                    return encoding.GetString(scratch, 0, byteCount);
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(scratch);
                }
            }
        }
    }
}
696+
697+
namespace System.Net.Http
{
    /// <summary>
    /// Cancellation-token-accepting helpers for <see cref="HttpContent"/>, implemented on
    /// top of the parameterless overloads.
    /// </summary>
    internal static class DownlevelHttpContentExtensions
    {
        extension(HttpContent content)
        {
            /// <summary>
            /// Returns the content as a stream, observing <paramref name="cancellationToken"/>
            /// only before the read is started.
            /// </summary>
            /// <param name="cancellationToken">
            /// Token checked once up front. It cannot interrupt the underlying read once started.
            /// </param>
            /// <returns>
            /// A canceled task when the token is already cancelled; otherwise the task returned
            /// by the parameterless <c>ReadAsStreamAsync()</c>.
            /// </returns>
            public Task<IO.Stream> ReadAsStreamAsync(Threading.CancellationToken cancellationToken)
            {
                // The underlying netstandard2.0 ReadAsStreamAsync() can't be cancelled, so the
                // token is only honoured when it is already cancelled. Surface that through the
                // returned task rather than throwing synchronously: callers that store or
                // compose the task (instead of awaiting it on the spot) then observe a canceled
                // task, and awaiting callers still see an OperationCanceledException.
                if (cancellationToken.IsCancellationRequested)
                {
                    return Task.FromCanceled<IO.Stream>(cancellationToken);
                }

                return content.ReadAsStreamAsync();
            }
        }
    }
}
713+
714+
namespace System.Net.WebSockets
{
    /// <summary>
    /// Polyfill for the <c>System.Net.WebSockets.ValueWebSocketReceiveResult</c>
    /// struct, which is unavailable on .NET Standard 2.0.
    /// </summary>
    internal readonly struct ValueWebSocketReceiveResult
    {
        /// <summary>Number of bytes received into the caller's buffer.</summary>
        public int Count { get; }

        /// <summary>Type of the frame that was received.</summary>
        public WebSocketMessageType MessageType { get; }

        /// <summary>Whether this frame completes the current message.</summary>
        public bool EndOfMessage { get; }

        /// <summary>Creates a result from its three components.</summary>
        public ValueWebSocketReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage)
        {
            Count = count;
            MessageType = messageType;
            EndOfMessage = endOfMessage;
        }
    }

    /// <summary>
    /// <c>Memory</c>-based <see cref="WebSocket"/> send/receive helpers implemented on top of
    /// the <see cref="ArraySegment{T}"/>-based overloads.
    /// </summary>
    internal static class DownlevelWebSocketExtensions
    {
        extension(WebSocket socket)
        {
            /// <summary>
            /// Sends <paramref name="buffer"/> over the socket. Array-backed memory is passed
            /// straight through; any other memory is staged in a pooled array first.
            /// </summary>
            public ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, Threading.CancellationToken cancellationToken) =>
                Runtime.InteropServices.MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> backing)
                    ? new ValueTask(socket.SendAsync(backing, messageType, endOfMessage, cancellationToken))
                    : SendAsyncSlow(socket, buffer, messageType, endOfMessage, cancellationToken);

            /// <summary>
            /// Receives a frame into <paramref name="buffer"/> and reports its size, type and
            /// end-of-message flag.
            /// </summary>
            public ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte> buffer, Threading.CancellationToken cancellationToken)
            {
                return ReceiveAsyncCore(socket, buffer, cancellationToken);
            }
        }

        // Send path for memory that is not backed by an array: copy into a pooled array,
        // send exactly buffer.Length bytes of it, and always give the array back.
        private static async ValueTask SendAsyncSlow(WebSocket socket, ReadOnlyMemory<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, Threading.CancellationToken cancellationToken)
        {
            int length = buffer.Length;
            byte[] scratch = ArrayPool<byte>.Shared.Rent(length);
            try
            {
                buffer.CopyTo(scratch);
                var outgoing = new ArraySegment<byte>(scratch, 0, length);
                await socket.SendAsync(outgoing, messageType, endOfMessage, cancellationToken).ConfigureAwait(false);
            }
            finally
            {
                ArrayPool<byte>.Shared.Return(scratch);
            }
        }

        // Receive path: read directly into array-backed memory when possible; otherwise
        // read into a pooled array and copy the received bytes into the caller's buffer.
        private static async ValueTask<ValueWebSocketReceiveResult> ReceiveAsyncCore(WebSocket socket, Memory<byte> buffer, Threading.CancellationToken cancellationToken)
        {
            bool arrayBacked = Runtime.InteropServices.MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> target);
            if (!arrayBacked)
            {
                byte[] scratch = ArrayPool<byte>.Shared.Rent(buffer.Length);
                try
                {
                    // Limit the receive window to buffer.Length: the rented array may be larger.
                    var window = new ArraySegment<byte>(scratch, 0, buffer.Length);
                    WebSocketReceiveResult staged = await socket.ReceiveAsync(window, cancellationToken).ConfigureAwait(false);
                    new ReadOnlyMemory<byte>(scratch, 0, staged.Count).CopyTo(buffer);
                    return ToValueResult(staged);
                }
                finally
                {
                    ArrayPool<byte>.Shared.Return(scratch);
                }
            }

            WebSocketReceiveResult direct = await socket.ReceiveAsync(target, cancellationToken).ConfigureAwait(false);
            return ToValueResult(direct);
        }

        // Projects the class-based receive result onto the struct polyfill.
        private static ValueWebSocketReceiveResult ToValueResult(WebSocketReceiveResult received) =>
            new ValueWebSocketReceiveResult(received.Count, received.MessageType, received.EndOfMessage);
    }
}

go/copilot_request_handler.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,12 @@ type CopilotWebSocketMessage struct {
9090
// Text decodes the frame payload as a UTF-8 string.
9191
func (m CopilotWebSocketMessage) Text() string { return string(m.Data) }
9292

93-
// NewTextMessage creates a text-frame message from a UTF-8 string.
93+
// NewTextMessage creates a text-frame message from a UTF-8 string. Binary
94+
// frames are constructed directly with CopilotWebSocketMessage{Data: ..., Binary: true}.
9495
func NewTextMessage(text string) CopilotWebSocketMessage {
	// The string's bytes become the payload; the frame is marked as non-binary.
	payload := []byte(text)
	return CopilotWebSocketMessage{
		Data:   payload,
		Binary: false,
	}
}
9798

98-
// NewBinaryMessage creates a binary-frame message from raw bytes.
99-
func NewBinaryMessage(data []byte) CopilotWebSocketMessage {
100-
return CopilotWebSocketMessage{Data: data, Binary: true}
101-
}
102-
10399
// CopilotRequestHandler is the idiomatic handler for intercepting or replacing
104100
// LLM inference requests. HTTP requests are forwarded through Transport (an
105101
// [http.RoundTripper]); supply a custom RoundTripper to mutate the request,
@@ -227,9 +223,9 @@ func streamResponseToSink(resp *http.Response, sink *responseSink) error {
227223
for {
228224
n, readErr := resp.Body.Read(buf)
229225
if n > 0 {
230-
frame := make([]byte, n)
231-
copy(frame, buf[:n])
232-
if err := sink.writeText(frame); err != nil {
226+
// writeText copies eagerly via string(...), so the reused read
227+
// buffer can be passed directly without an extra per-chunk alloc.
228+
if err := sink.writeText(buf[:n]); err != nil {
233229
return err
234230
}
235231
}

java/src/main/java/com/github/copilot/CopilotRequestHandler.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ private static void streamResponse(HttpResponse<InputStream> response, LlmInfere
143143
int n;
144144
while ((n = body.read(buffer)) != -1) {
145145
if (n > 0) {
146-
byte[] frame = new byte[n];
147-
System.arraycopy(buffer, 0, frame, 0, n);
148-
exchange.writeResponseBinary(frame);
146+
exchange.writeResponseBinary(buffer, 0, n);
149147
}
150148
}
151149
} catch (IOException e) {

0 commit comments

Comments
 (0)