diff --git a/api_client.go b/api_client.go index be46112c..22e568f1 100644 --- a/api_client.go +++ b/api_client.go @@ -19,11 +19,14 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "iter" "log" "math" + "math/rand" + "net" "net/http" "net/textproto" "net/url" @@ -71,7 +74,7 @@ func sendStreamRequest[T responseStream[R], R any](ctx context.Context, ac *apiC } req = req.WithContext(requestContext) - resp, err := doRequest(ac, req) + resp, err := doRequestWithRetry(ac, req, httpOptions.RetryOptions) if err != nil { return err } @@ -102,7 +105,7 @@ func sendRequest(ctx context.Context, ac *apiClient, path string, method string, } req = req.WithContext(requestContext) - resp, err := doRequest(ac, req) + resp, err := doRequestWithRetry(ac, req, httpOptions.RetryOptions) if err != nil { return nil, err } @@ -115,13 +118,13 @@ func sendRequest(ctx context.Context, ac *apiClient, path string, method string, func downloadFile(ctx context.Context, ac *apiClient, path string, httpOptions *HTTPOptions) ([]byte, error) { // The client and request timeout are not used for downloadFile. // TODO(b/427540996): implement timeout. - req, _, err := buildRequest(ctx, ac, path, nil, http.MethodGet, httpOptions) + req, patchedOptions, err := buildRequest(ctx, ac, path, nil, http.MethodGet, httpOptions) if err != nil { return nil, err } req = req.WithContext(ctx) - resp, err := doRequest(ac, req) + resp, err := doRequestWithRetry(ac, req, patchedOptions.RetryOptions) if err != nil { return nil, err } @@ -216,6 +219,9 @@ func patchHTTPOptions(options, patchOptions HTTPOptions) (*HTTPOptions, error) { if patchOptions.ExtraBody != nil { copyOption.ExtraBody = patchOptions.ExtraBody } + if patchOptions.RetryOptions != nil { + copyOption.RetryOptions = patchOptions.RetryOptions + } // Request timeout config overrides client timeout config. // So we need a pointer type so that we know the request timeout // is explicitly set or not. 
@@ -412,6 +418,120 @@ func doRequest(ac *apiClient, req *http.Request) (*http.Response, error) { return resp, nil } +// Default retry settings. +// See https://cloud.google.com/storage/docs/retry-strategy. +const ( + defaultRetryAttempts = 5 + defaultRetryInitialDelay = time.Second + defaultRetryMaxDelay = 60 * time.Second + defaultRetryExpBase = 2.0 + defaultRetryJitter = time.Second +) + +var defaultRetryHTTPStatusCodes = []int{ + http.StatusRequestTimeout, + http.StatusTooManyRequests, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout, +} + +func resolvedRetryOptions(opts *HTTPRetryOptions) *HTTPRetryOptions { + if opts == nil { + return nil + } + resolved := *opts + if resolved.Attempts == 0 { + resolved.Attempts = defaultRetryAttempts + } + if resolved.Attempts <= 1 { + return nil + } + if resolved.InitialDelay <= 0 { + resolved.InitialDelay = defaultRetryInitialDelay + } + if resolved.MaxDelay <= 0 { + resolved.MaxDelay = defaultRetryMaxDelay + } + if resolved.ExpBase <= 0 { + resolved.ExpBase = defaultRetryExpBase + } + if resolved.Jitter < 0 { + resolved.Jitter = 0 + } else if resolved.Jitter == 0 { + resolved.Jitter = defaultRetryJitter + } + if len(resolved.HTTPStatusCodes) == 0 { + resolved.HTTPStatusCodes = defaultRetryHTTPStatusCodes + } + return &resolved +} + +func backoffDelay(opts *HTTPRetryOptions, retryNum int) time.Duration { + delay := float64(opts.InitialDelay) * math.Pow(opts.ExpBase, float64(retryNum-1)) + delay += rand.Float64() * float64(opts.Jitter) + if maxD := float64(opts.MaxDelay); delay > maxD { + delay = maxD + } + return time.Duration(delay) +} + +func isRetriableTransportErr(err error) bool { + var netErr net.Error + return errors.As(err, &netErr) +} + +func doRequestWithRetry(ac *apiClient, req *http.Request, retryOpts *HTTPRetryOptions) (*http.Response, error) { + resolved := resolvedRetryOptions(retryOpts) + if resolved == nil { + return doRequest(ac, req) + 
} + // http.NewRequest sets GetBody for bytes.Buffer/Reader bodies so the body + // can be rewound on retry. Skip retry if the body is non-empty and not + // rewindable to avoid re-sending an empty payload. + canRewindBody := req.Body == nil || req.Body == http.NoBody || req.GetBody != nil + + var resp *http.Response + var lastErr error + for attempt := 1; attempt <= resolved.Attempts; attempt++ { + if attempt > 1 { + if !canRewindBody { + return resp, lastErr + } + if req.GetBody != nil { + body, gerr := req.GetBody() + if gerr != nil { + return nil, fmt.Errorf("doRequestWithRetry: rewinding body: %w", gerr) + } + req.Body = body + } + select { + case <-req.Context().Done(): + return nil, req.Context().Err() + case <-time.After(backoffDelay(resolved, attempt-1)): + } + } + resp, lastErr = doRequest(ac, req) + if lastErr != nil { + if attempt == resolved.Attempts || !isRetriableTransportErr(lastErr) { + return resp, lastErr + } + continue + } + if httpStatusOk(resp) { + return resp, nil + } + if !slices.Contains(resolved.HTTPStatusCodes, resp.StatusCode) || attempt == resolved.Attempts { + return resp, nil + } + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + lastErr = fmt.Errorf("doRequestWithRetry: retriable status %d", resp.StatusCode) + } + return resp, lastErr +} + func deserializeUnaryResponse(resp *http.Response) (map[string]any, error) { if !httpStatusOk(resp) { return nil, newAPIError(resp) diff --git a/retry_test.go b/retry_test.go new file mode 100644 index 00000000..7b1dc110 --- /dev/null +++ b/retry_test.go @@ -0,0 +1,358 @@ +package genai + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +func fastRetry(attempts int) *HTTPRetryOptions { + return &HTTPRetryOptions{ + Attempts: attempts, + InitialDelay: time.Millisecond, + MaxDelay: 5 * time.Millisecond, + ExpBase: 2, + Jitter: time.Millisecond, + } +} + +func TestResolvedRetryOptions(t *testing.T) { + tests := []struct { + 
desc string + in *HTTPRetryOptions + wantNil bool + wantAttempts int + wantInitial time.Duration + wantMaxDelay time.Duration + wantExpBase float64 + wantJitter time.Duration + wantCodesLen int + }{ + { + desc: "nil returns nil", + in: nil, + wantNil: true, + }, + { + desc: "attempts=1 disables retry", + in: &HTTPRetryOptions{Attempts: 1}, + wantNil: true, + }, + { + desc: "defaults applied for empty options", + in: &HTTPRetryOptions{}, + wantAttempts: defaultRetryAttempts, + wantInitial: defaultRetryInitialDelay, + wantMaxDelay: defaultRetryMaxDelay, + wantExpBase: defaultRetryExpBase, + wantJitter: defaultRetryJitter, + wantCodesLen: len(defaultRetryHTTPStatusCodes), + }, + { + desc: "user values preserved", + in: &HTTPRetryOptions{ + Attempts: 7, + InitialDelay: 2 * time.Second, + MaxDelay: 30 * time.Second, + ExpBase: 3, + Jitter: 500 * time.Millisecond, + HTTPStatusCodes: []int{500, 503}, + }, + wantAttempts: 7, + wantInitial: 2 * time.Second, + wantMaxDelay: 30 * time.Second, + wantExpBase: 3, + wantJitter: 500 * time.Millisecond, + wantCodesLen: 2, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + got := resolvedRetryOptions(tt.in) + if tt.wantNil { + if got != nil { + t.Fatalf("resolvedRetryOptions() = %#v, want nil", got) + } + return + } + if got == nil { + t.Fatal("resolvedRetryOptions() = nil, want non-nil") + } + if got.Attempts != tt.wantAttempts || got.InitialDelay != tt.wantInitial || + got.MaxDelay != tt.wantMaxDelay || got.ExpBase != tt.wantExpBase || + got.Jitter != tt.wantJitter || len(got.HTTPStatusCodes) != tt.wantCodesLen { + t.Errorf("resolvedRetryOptions() = %#v", got) + } + }) + } +} + +func TestBackoffDelay(t *testing.T) { + tests := []struct { + desc string + opts *HTTPRetryOptions + n int + want time.Duration + }{ + { + desc: "first retry uses initial delay", + opts: &HTTPRetryOptions{InitialDelay: 100 * time.Millisecond, MaxDelay: 10 * time.Second, ExpBase: 2}, + n: 1, + want: 100 * time.Millisecond, + }, + 
{ + desc: "second retry doubles", + opts: &HTTPRetryOptions{InitialDelay: 100 * time.Millisecond, MaxDelay: 10 * time.Second, ExpBase: 2}, + n: 2, + want: 200 * time.Millisecond, + }, + { + desc: "third retry doubles again", + opts: &HTTPRetryOptions{InitialDelay: 100 * time.Millisecond, MaxDelay: 10 * time.Second, ExpBase: 2}, + n: 3, + want: 400 * time.Millisecond, + }, + { + desc: "capped by max delay", + opts: &HTTPRetryOptions{InitialDelay: time.Second, MaxDelay: 2 * time.Second, ExpBase: 10}, + n: 5, + want: 2 * time.Second, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if got := backoffDelay(tt.opts, tt.n); got != tt.want { + t.Errorf("backoffDelay() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestBackoffDelay_CapsJitter(t *testing.T) { + opts := &HTTPRetryOptions{ + InitialDelay: time.Second, + MaxDelay: time.Second, + ExpBase: 2, + Jitter: 5 * time.Second, + } + for i := 0; i < 50; i++ { + if got := backoffDelay(opts, 1); got > time.Second { + t.Fatalf("backoffDelay() = %v, want <= MaxDelay (1s)", got) + } + } +} + +func TestDoRequestWithRetry(t *testing.T) { + tests := []struct { + desc string + retryOptions *HTTPRetryOptions + serverHandler func(calls *int32) http.HandlerFunc + wantCalls int32 + wantErr bool + wantStatus int // expected APIError.Code when wantErr is true + }{ + { + desc: "no retry options, single attempt on 5xx", + retryOptions: nil, + serverHandler: func(calls *int32) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(calls, 1) + w.WriteHeader(http.StatusServiceUnavailable) + } + }, + wantCalls: 1, + wantErr: true, + wantStatus: http.StatusServiceUnavailable, + }, + { + desc: "retries until success", + retryOptions: fastRetry(5), + serverHandler: func(calls *int32) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if atomic.AddInt32(calls, 1) < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + 
w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + } + }, + wantCalls: 3, + wantErr: false, + }, + { + desc: "exhausts attempts on persistent retriable status", + retryOptions: fastRetry(3), + serverHandler: func(calls *int32) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(calls, 1) + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":{"code":502,"message":"bad gateway"}}`)) + } + }, + wantCalls: 3, + wantErr: true, + wantStatus: http.StatusBadGateway, + }, + { + desc: "non-retriable status returns immediately", + retryOptions: fastRetry(5), + serverHandler: func(calls *int32) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(calls, 1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"code":400,"message":"bad"}}`)) + } + }, + wantCalls: 1, + wantErr: true, + wantStatus: http.StatusBadRequest, + }, + { + desc: "custom status codes trigger retry", + retryOptions: func() *HTTPRetryOptions { + o := fastRetry(4) + o.HTTPStatusCodes = []int{http.StatusTeapot} + return o + }(), + serverHandler: func(calls *int32) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if atomic.AddInt32(calls, 1) < 2 { + w.WriteHeader(http.StatusTeapot) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + } + }, + wantCalls: 2, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + var calls int32 + ts := httptest.NewServer(tt.serverHandler(&calls)) + defer ts.Close() + + ac := &apiClient{ + clientConfig: &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + HTTPClient: ts.Client(), + }, + } + + _, err := sendRequest(context.Background(), ac, "foo", http.MethodPost, + map[string]any{"k": "v"}, + &HTTPOptions{BaseURL: ts.URL, RetryOptions: tt.retryOptions}) + + if (err != nil) != tt.wantErr { + t.Fatalf("sendRequest() error = %v, wantErr 
%v", err, tt.wantErr) + } + if tt.wantErr { + apiErr, ok := err.(APIError) + if !ok { + t.Fatalf("want APIError, got %T: %v", err, err) + } + if apiErr.Code != tt.wantStatus { + t.Errorf("APIError.Code = %d, want %d", apiErr.Code, tt.wantStatus) + } + } + if got := atomic.LoadInt32(&calls); got != tt.wantCalls { + t.Errorf("calls = %d, want %d", got, tt.wantCalls) + } + }) + } +} + +func TestDoRequestWithRetry_ContextCancellation(t *testing.T) { + var calls int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer ts.Close() + + ac := &apiClient{ + clientConfig: &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: ts.URL}, + HTTPClient: ts.Client(), + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + opts := &HTTPRetryOptions{ + Attempts: 10, + InitialDelay: 200 * time.Millisecond, + MaxDelay: time.Second, + ExpBase: 2, + } + _, err := sendRequest(ctx, ac, "foo", http.MethodPost, + map[string]any{"k": "v"}, + &HTTPOptions{BaseURL: ts.URL, RetryOptions: opts}) + if err == nil { + t.Fatal("sendRequest() returned nil error, want cancellation error") + } + if got := atomic.LoadInt32(&calls); got > 2 { + t.Errorf("calls = %d, want at most 2 before cancel", got) + } +} + +func TestDoRequestWithRetry_TransportError(t *testing.T) { + // Stand up a server then close it so connections fail. 
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + url := ts.URL + ts.Close() + + ac := &apiClient{ + clientConfig: &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: url}, + HTTPClient: &http.Client{Timeout: 50 * time.Millisecond}, + }, + } + _, err := sendRequest(context.Background(), ac, "foo", http.MethodPost, + map[string]any{"k": "v"}, + &HTTPOptions{BaseURL: url, RetryOptions: fastRetry(3)}) + if err == nil { + t.Fatal("sendRequest() returned nil error, want transport error") + } + if !strings.Contains(err.Error(), "doRequest") { + t.Errorf("error = %q, want doRequest-wrapped error", err.Error()) + } +} + +func TestBuildRequest_BodyIsRewindable(t *testing.T) { + ac := &apiClient{ + clientConfig: &ClientConfig{ + HTTPOptions: HTTPOptions{BaseURL: "http://example.com"}, + HTTPClient: &http.Client{}, + }, + } + req, _, err := buildRequest(context.Background(), ac, "foo", + map[string]any{"k": "v"}, http.MethodPost, + &HTTPOptions{BaseURL: "http://example.com"}) + if err != nil { + t.Fatalf("buildRequest() error = %v", err) + } + if req.GetBody == nil { + t.Fatal("req.GetBody = nil, want non-nil so retry can rewind body") + } + body, err := req.GetBody() + if err != nil { + t.Fatalf("GetBody() error = %v", err) + } + defer body.Close() + buf := make([]byte, 64) + n, _ := body.Read(buf) + if !strings.Contains(string(buf[:n]), `"k"`) { + t.Errorf("rewound body = %q, want it to contain key", string(buf[:n])) + } +} diff --git a/types.go b/types.go index 451ea648..a135ae99 100644 --- a/types.go +++ b/types.go @@ -1649,6 +1649,27 @@ type HTTPOptions struct { // It is executed after ExtraBody has been merged, offering more advanced // control over the request body than the static ExtraBody. ExtrasRequestProvider ExtrasRequestProvider `json:"-"` + // Optional. RetryOptions configures automatic retries on transient HTTP + // failures. If nil, no retry is performed. 
+ RetryOptions *HTTPRetryOptions `json:"retryOptions,omitempty"` +} + +// HTTP retry options to be used in each of the requests. +type HTTPRetryOptions struct { + // Optional. Maximum number of attempts, including the original request. + // If 1 or negative, no retries are performed. If 0 (not specified), defaults to 5. + Attempts int `json:"attempts,omitempty"` + // Optional. Initial delay before the first retry. If not specified, defaults to 1 second. + InitialDelay time.Duration `json:"initialDelay,omitempty"` + // Optional. Maximum delay between retries. If not specified, defaults to 60 seconds. + MaxDelay time.Duration `json:"maxDelay,omitempty"` + // Optional. Multiplier by which the delay grows after each attempt. If not specified, defaults to 2. + ExpBase float64 `json:"expBase,omitempty"` + // Optional. Maximum random jitter added to each delay. If not specified, defaults to 1 second; a negative value disables jitter. + Jitter time.Duration `json:"jitter,omitempty"` + // Optional. HTTP status codes that should trigger a retry. If not specified, + // defaults to 408, 429, 500, 502, 503, 504. + HTTPStatusCodes []int `json:"httpStatusCodes,omitempty"` } // ExtrasRequestProvider provides a way to dynamically modify the request body