diff --git a/api/hook_test.go b/api/hook_test.go index 1362647a..4f952bb1 100644 --- a/api/hook_test.go +++ b/api/hook_test.go @@ -114,10 +114,11 @@ func TestHookRetry(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ assert.EqualValues(t, 0, r.ContentLength) + // 503 is retriable; eventual 200 succeeds on the third attempt. if callCount == 3 { w.WriteHeader(http.StatusOK) } else { - w.WriteHeader(http.StatusBadRequest) + w.WriteHeader(http.StatusServiceUnavailable) } })) defer svr.Close() @@ -143,6 +144,66 @@ func TestHookRetry(t *testing.T) { assert.Equal(t, 3, callCount) } +func TestHookNonRetriable4xxDoesNotRetry(t *testing.T) { + var callCount int + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusBadRequest) + })) + defer svr.Close() + localhost := removeLocalhostFromPrivateIPBlock() + defer unshiftPrivateIPBlock(localhost) + + config := &conf.WebhookConfig{ + URL: svr.URL, + Retries: 3, + } + w := Webhook{ + WebhookConfig: config, + } + _, err := w.trigger() + require.Error(t, err) + + herr, ok := err.(*HTTPError) + require.True(t, ok, "expected an *HTTPError, got %T", err) + assert.Equal(t, http.StatusBadRequest, herr.Code) + + // 4xx (other than 401, which is the deny path) must not retry. 
+ assert.Equal(t, 1, callCount) +} + +func TestHookRetries429(t *testing.T) { + var callCount int + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + if callCount == 3 { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusTooManyRequests) + } + })) + defer svr.Close() + localhost := removeLocalhostFromPrivateIPBlock() + defer unshiftPrivateIPBlock(localhost) + + config := &conf.WebhookConfig{ + URL: svr.URL, + Retries: 3, + } + w := Webhook{ + WebhookConfig: config, + } + b, err := w.trigger() + defer func() { + if b != nil { + b.Close() + } + }() + require.NoError(t, err) + + assert.Equal(t, 3, callCount) +} + func TestHookTimeout(t *testing.T) { var callCount int svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/api/hooks.go b/api/hooks.go index 80242246..751e39ca 100644 --- a/api/hooks.go +++ b/api/hooks.go @@ -80,6 +80,9 @@ func (w *Webhook) trigger() (io.ReadCloser, error) { client.Transport = SafeRoundtripper(client.Transport, hooklog) for i := 0; i < w.Retries; i++ { + if i > 0 { + time.Sleep(backoffDelay(i)) + } hooklog = hooklog.WithField("attempt", i+1) hooklog.Info("Starting to perform signup hook request") @@ -134,9 +137,17 @@ func (w *Webhook) trigger() (io.ReadCloser, error) { body = rsp.Body } return body, nil - default: - rspLog.Infof("Bad response for webhook %d in %s", rsp.StatusCode, dur) } + + if rsp.StatusCode == http.StatusTooManyRequests || rsp.StatusCode >= 500 { + rspLog.Infof("Retriable response from webhook %d in %s", rsp.StatusCode, dur) + closeBody(rsp) + continue + } + + rspLog.Infof("Non-retriable response from webhook %d in %s", rsp.StatusCode, dur) + closeBody(rsp) + return nil, httpError(rsp.StatusCode, "Webhook returned status %d", rsp.StatusCode) } hooklog.Infof("Failed to process webhook for %s after %d attempts", w.URL, w.Retries) @@ -314,3 +325,13 @@ type connectionWatcher struct { func (c 
// backoffDelay returns how long to wait before retry number attempt, where
// attempt is 1 for the first retry. The delay starts at 100ms, doubles with
// every further attempt and is capped at 2s. Values of attempt below 1 yield
// a zero delay.
func backoffDelay(attempt int) time.Duration {
	const (
		baseDelay = 100 * time.Millisecond
		maxDelay  = 2 * time.Second
	)

	if attempt < 1 {
		// Shifting by a negative count panics at run time, so reject
		// out-of-range attempts up front instead of computing 1<<(attempt-1).
		return 0
	}

	// Double step by step and stop as soon as the cap is reached. This keeps
	// the intermediate value small, so a large attempt number can never
	// overflow into a zero or negative duration that would skip the cap.
	delay := baseDelay
	for i := 1; i < attempt && delay < maxDelay; i++ {
		delay *= 2
	}
	if delay > maxDelay {
		delay = maxDelay
	}
	return delay
}