Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 33 additions & 15 deletions src/_api_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,12 @@ export class ApiClient implements GeminiNextGenAPIClientAdapter {
url.toString(),
request.abortSignal,
);
return this.unaryApiCall(url, requestInit, request.httpMethod);
return this.unaryApiCall(
url,
requestInit,
request.httpMethod,
patchedHttpOptions,
);
}

private patchHttpOptions(
Expand Down Expand Up @@ -504,7 +509,12 @@ export class ApiClient implements GeminiNextGenAPIClientAdapter {
url.toString(),
request.abortSignal,
);
return this.streamApiCall(url, requestInit, request.httpMethod);
return this.streamApiCall(
url,
requestInit,
request.httpMethod,
patchedHttpOptions,
);
}

private async includeExtraHttpOptionsToRequestInit(
Expand Down Expand Up @@ -579,11 +589,16 @@ export class ApiClient implements GeminiNextGenAPIClientAdapter {
url: URL,
requestInit: RequestInit,
httpMethod: 'GET' | 'POST' | 'PATCH' | 'DELETE',
httpOptions: types.HttpOptions,
): Promise<types.HttpResponse> {
return this.apiCall(url.toString(), {
...requestInit,
method: httpMethod,
})
return this.apiCall(
url.toString(),
{
...requestInit,
method: httpMethod,
},
httpOptions,
)
.then(async (response) => {
await throwErrorIfNotOK(response);
return new types.HttpResponse(response);
Expand All @@ -601,11 +616,16 @@ export class ApiClient implements GeminiNextGenAPIClientAdapter {
url: URL,
requestInit: RequestInit,
httpMethod: 'GET' | 'POST' | 'PATCH' | 'DELETE',
httpOptions: types.HttpOptions,
): Promise<AsyncGenerator<types.HttpResponse>> {
return this.apiCall(url.toString(), {
...requestInit,
method: httpMethod,
})
return this.apiCall(
url.toString(),
{
...requestInit,
method: httpMethod,
},
httpOptions,
)
.then(async (response) => {
await throwErrorIfNotOK(response);
return this.processStreamResponse(response);
Expand Down Expand Up @@ -724,15 +744,13 @@ export class ApiClient implements GeminiNextGenAPIClientAdapter {
private async apiCall(
url: string,
requestInit: RequestInit,
httpOptions: types.HttpOptions,
): Promise<Response> {
if (
!this.clientOptions.httpOptions ||
!this.clientOptions.httpOptions.retryOptions
) {
if (!httpOptions.retryOptions) {
return fetch(url, requestInit);
}

const retryOptions = this.clientOptions.httpOptions.retryOptions;
const retryOptions = httpOptions.retryOptions;
const runFetch = async () => {
const response = await fetch(url, requestInit);

Expand Down
99 changes: 99 additions & 0 deletions test/unit/api_client_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,75 @@ describe('ApiClient', () => {
expect(fetchSpy).toHaveBeenCalledTimes(2);
});

it('should retry requests if retry options are set on the request', async () => {
const client = new ApiClient({
auth: new FakeAuth(),
project: 'vertex-project',
location: 'vertex-location',
vertexai: true,
apiVersion: 'v1beta1',
uploader: new CrossUploader(),
downloader: new CrossDownloader(),
});
const fetchSpy = spyOn(global, 'fetch').and.returnValue(
Promise.resolve(
new Response(
JSON.stringify({'error': 'Internal Server Error'}),
fetch500Options,
),
),
);
await client
.request({
path: 'test-path',
httpMethod: 'POST',
httpOptions: {
retryOptions: {
attempts: 2,
},
},
})
.catch(() => {});
expect(fetchSpy).toHaveBeenCalledTimes(2);
});

it('should let request retry options override client retry options', async () => {
const client = new ApiClient({
auth: new FakeAuth(),
project: 'vertex-project',
location: 'vertex-location',
vertexai: true,
apiVersion: 'v1beta1',
httpOptions: {
retryOptions: {
attempts: 3,
},
},
uploader: new CrossUploader(),
downloader: new CrossDownloader(),
});
const fetchSpy = spyOn(global, 'fetch').and.returnValue(
Promise.resolve(
new Response(
JSON.stringify({'error': 'Internal Server Error'}),
fetch500Options,
),
),
);
await client
.request({
path: 'test-path',
httpMethod: 'POST',
httpOptions: {
retryOptions: {
attempts: 1,
},
},
})
.catch(() => {});
expect(fetchSpy).toHaveBeenCalledTimes(1);
});

it('should not retry requests if retry options are not set', async () => {
const client = new ApiClient({
auth: new FakeAuth(),
Expand Down Expand Up @@ -1998,6 +2067,36 @@ describe('ApiClient', () => {
'https://custom-request-base-url.googleapis.com/v1alpha/test-path?alt=sse',
);
});
it('should retry requestStream if retry options are set on the request', async () => {
const client = new ApiClient({
auth: new FakeAuth('test-api-key'),
apiKey: 'test-api-key',
uploader: new CrossUploader(),
downloader: new CrossDownloader(),
});
const fetchSpy = spyOn(global, 'fetch').and.returnValue(
Promise.resolve(
new Response(
JSON.stringify({'error': 'Internal Server Error'}),
fetch500Options,
),
),
);

await client
.requestStream({
path: 'test-path',
httpMethod: 'POST',
httpOptions: {
retryOptions: {
attempts: 2,
},
},
})
.catch(() => {});

expect(fetchSpy).toHaveBeenCalledTimes(2);
});
it('should set bearer token for vertexai', async () => {
const client = new ApiClient({
auth: new FakeAuth(),
Expand Down