diff --git a/support/http/httptest/client_expectation.go b/support/http/httptest/client_expectation.go index f056ae7ebf..474f691285 100644 --- a/support/http/httptest/client_expectation.go +++ b/support/http/httptest/client_expectation.go @@ -1,6 +1,7 @@ package httptest import ( + "fmt" "net/http" "net/url" "strconv" @@ -85,6 +86,37 @@ func (ce *ClientExpectation) ReturnStringWithHeader( return ce.Return(httpmock.ResponderFromResponse(&cResp)) } +// ReturnMultipleResults registers multiple sequential responses for a given client expectation. +// Useful for testing retries +func (ce *ClientExpectation) ReturnMultipleResults(responseSets []ResponseData) *ClientExpectation { + var allResponses []httpmock.Responder + for _, response := range responseSets { + resp := http.Response{ + Status: strconv.Itoa(response.Status), + StatusCode: response.Status, + Body: httpmock.NewRespBodyFromString(response.Body), + Header: response.Header, + } + allResponses = append(allResponses, httpmock.ResponderFromResponse(&resp)) + } + responseIndex := 0 + ce.Client.MockTransport.RegisterResponder( + ce.Method, + ce.URL, + func(req *http.Request) (*http.Response, error) { + if responseIndex >= len(allResponses) { + panic(fmt.Sprintf("no responses available")) + } + + resp := allResponses[responseIndex] + responseIndex++ + return resp(req) + }, + ) + + return ce +} + // ReturnJSONWithHeader causes this expectation to resolve to a json-based body with the provided // status code and response header. Panics when the provided body cannot be encoded to JSON. func (ce *ClientExpectation) ReturnJSONWithHeader( diff --git a/support/http/httptest/main.go b/support/http/httptest/main.go index 47a00b1991..18b986ba1b 100644 --- a/support/http/httptest/main.go +++ b/support/http/httptest/main.go @@ -67,3 +67,9 @@ func NewServer(t *testing.T, handler http.Handler) *Server { Expect: httpexpect.New(t, server.URL), } } + +type ResponseData struct { + Status int + Body string + Header http.Header +} diff --git a/utils/apiclient/client.go b/utils/apiclient/client.go index 5b39da610b..6f89443b9f 100644 --- a/utils/apiclient/client.go +++ b/utils/apiclient/client.go @@ -1,98 +1,85 @@ package apiclient import ( - "encoding/base64" "encoding/json" "fmt" "io/ioutil" "net/http" "net/url" + "time" "github.com/pkg/errors" ) -func (c *APIClient) createRequestBody(endpoint string, queryParams url.Values) (*http.Request, error) { - fullURL := c.url(endpoint, queryParams) - req, err := http.NewRequest("GET", fullURL, nil) - if err != nil { - return nil, errors.Wrap(err, "http GET request creation failed") - } - return req, nil +const ( + maxRetries = 5 + initialBackoff = 1 * time.Second +) + +func isRetryableStatusCode(statusCode int) bool { + return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable } -func (c *APIClient) callAPI(req *http.Request) (interface{}, error) { - client := c.HTTP - if client == nil { - client = &http.Client{} - } +func (c *APIClient) GetURL(endpoint string, qstr url.Values) string { + return fmt.Sprintf("%s/%s?%s", c.BaseURL, endpoint, qstr.Encode()) +} - resp, err := client.Do(req) - if err != nil { - return nil, errors.Wrap(err, "http GET request failed") +func (c *APIClient) CallAPI(reqParams RequestParams) (interface{}, error) { + if reqParams.QueryParams == nil { + reqParams.QueryParams = url.Values{} } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed with status %d", resp.StatusCode) + if reqParams.Headers == nil { + reqParams.Headers = map[string]interface{}{} } - body, err := ioutil.ReadAll(resp.Body) + url := c.GetURL(reqParams.Endpoint, reqParams.QueryParams) + reqBody, err := CreateRequestBody(reqParams.RequestType, url) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return nil, errors.Wrap(err, "http request creation failed") } - var result interface{} - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal JSON: %w", err) + SetAuthHeaders(reqBody, c.authType, c.authHeaders) + SetHeaders(reqBody, reqParams.Headers) + client := c.HTTP + if client == nil { + client = &http.Client{} } - return result, nil -} - -func setHeaders(req *http.Request, args map[string]interface{}) { - for key, value := range args { - strValue, ok := value.(string) - if !ok { - fmt.Printf("Skipping non-string value for header %s\n", key) - continue - } - - req.Header.Set(key, strValue) - } -} + var result interface{} + retries := 0 -func setAuthHeaders(req *http.Request, authType string, args map[string]interface{}) error { - switch authType { - case "basic": - username, ok := args["username"].(string) - if !ok { - return fmt.Errorf("missing or invalid username") - } - password, ok := args["password"].(string) - if !ok { - return fmt.Errorf("missing or invalid password") + for retries <= maxRetries { + resp, err := client.Do(reqBody) + if err != nil { + return nil, errors.Wrap(err, "http request failed") } - - authHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) - setHeaders(req, map[string]interface{}{ - "Authorization": authHeader, - }) - - case "api_key": - apiKey, ok := args["api_key"].(string) - if !ok { - return fmt.Errorf("missing or invalid API key") + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON: %w", err) + } + + return result, nil + } else if isRetryableStatusCode(resp.StatusCode) { + retries++ + backoffDuration := initialBackoff * time.Duration(1<