diff --git a/middleware/middleware.go b/middleware/middleware.go index b9e95e3..a5612d5 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -5,10 +5,10 @@ import ( "net/http" ) -//go:generate moq -out mocks/repeater.go -pkg mocks -skip-ensure -fmt goimports . RepeaterSvc -//go:generate moq -out mocks/circuit_breaker.go -pkg mocks -skip-ensure -fmt goimports . CircuitBreakerSvc -//go:generate moq -out mocks/logger.go -pkg mocks -skip-ensure -fmt goimports logger Service:LoggerSvc -//go:generate moq -out mocks/cache.go -pkg mocks -skip-ensure -fmt goimports cache Service:CacheSvc +//go:generate moq -out mocks/repeater.go -pkg mocks -skip-ensure -with-resets -fmt goimports . RepeaterSvc +//go:generate moq -out mocks/circuit_breaker.go -pkg mocks -skip-ensure -with-resets -fmt goimports . CircuitBreakerSvc +//go:generate moq -out mocks/logger.go -pkg mocks -skip-ensure -with-resets -fmt goimports logger Service:LoggerSvc +//go:generate moq -out mocks/cache.go -pkg mocks -skip-ensure -with-resets -fmt goimports cache Service:CacheSvc // RoundTripperHandler is a type for middleware handler type RoundTripperHandler func(http.RoundTripper) http.RoundTripper diff --git a/middleware/mocks/cache.go b/middleware/mocks/cache.go index 86d65f2..f9ede72 100644 --- a/middleware/mocks/cache.go +++ b/middleware/mocks/cache.go @@ -9,19 +9,19 @@ import ( // CacheSvc is a mock implementation of cache.Service. // -// func TestSomethingThatUsesService(t *testing.T) { +// func TestSomethingThatUsesService(t *testing.T) { // -// // make and configure a mocked cache.Service -// mockedService := &CacheSvc{ -// GetFunc: func(key string, fn func() (interface{}, error)) (interface{}, error) { -// panic("mock out the Get method") -// }, -// } +// // make and configure a mocked cache.Service +// mockedService := &CacheSvc{ +// GetFunc: func(key string, fn func() (interface{}, error)) (interface{}, error) { +// panic("mock out the Get method") +// }, +// } // -// // use mockedService in code that requires cache.Service -// // and then make assertions. +// // use mockedService in code that requires cache.Service +// // and then make assertions. // -// } +// } type CacheSvc struct { // GetFunc mocks the Get method. GetFunc func(key string, fn func() (interface{}, error)) (interface{}, error) @@ -59,7 +59,8 @@ func (mock *CacheSvc) Get(key string, fn func() (interface{}, error)) (interface // GetCalls gets all the calls that were made to Get. // Check the length with: -// len(mockedService.GetCalls()) +// +// len(mockedService.GetCalls()) func (mock *CacheSvc) GetCalls() []struct { Key string Fn func() (interface{}, error) @@ -73,3 +74,17 @@ func (mock *CacheSvc) GetCalls() []struct { mock.lockGet.RUnlock() return calls } + +// ResetGetCalls reset all the calls that were made to Get. +func (mock *CacheSvc) ResetGetCalls() { + mock.lockGet.Lock() + mock.calls.Get = nil + mock.lockGet.Unlock() +} + +// ResetCalls reset all the calls that were made to all mocked methods. +func (mock *CacheSvc) ResetCalls() { + mock.lockGet.Lock() + mock.calls.Get = nil + mock.lockGet.Unlock() +} diff --git a/middleware/mocks/circuit_breaker.go b/middleware/mocks/circuit_breaker.go index f9f6c15..6e0d0c9 100644 --- a/middleware/mocks/circuit_breaker.go +++ b/middleware/mocks/circuit_breaker.go @@ -9,19 +9,19 @@ import ( // CircuitBreakerSvcMock is a mock implementation of middleware.CircuitBreakerSvc. // -// func TestSomethingThatUsesCircuitBreakerSvc(t *testing.T) { +// func TestSomethingThatUsesCircuitBreakerSvc(t *testing.T) { // -// // make and configure a mocked middleware.CircuitBreakerSvc -// mockedCircuitBreakerSvc := &CircuitBreakerSvcMock{ -// ExecuteFunc: func(req func() (interface{}, error)) (interface{}, error) { -// panic("mock out the Execute method") -// }, -// } +// // make and configure a mocked middleware.CircuitBreakerSvc +// mockedCircuitBreakerSvc := &CircuitBreakerSvcMock{ +// ExecuteFunc: func(req func() (interface{}, error)) (interface{}, error) { +// panic("mock out the Execute method") +// }, +// } // -// // use mockedCircuitBreakerSvc in code that requires middleware.CircuitBreakerSvc -// // and then make assertions. +// // use mockedCircuitBreakerSvc in code that requires middleware.CircuitBreakerSvc +// // and then make assertions. // -// } +// } type CircuitBreakerSvcMock struct { // ExecuteFunc mocks the Execute method. ExecuteFunc func(req func() (interface{}, error)) (interface{}, error) @@ -55,7 +55,8 @@ func (mock *CircuitBreakerSvcMock) Execute(req func() (interface{}, error)) (int // ExecuteCalls gets all the calls that were made to Execute. // Check the length with: -// len(mockedCircuitBreakerSvc.ExecuteCalls()) +// +// len(mockedCircuitBreakerSvc.ExecuteCalls()) func (mock *CircuitBreakerSvcMock) ExecuteCalls() []struct { Req func() (interface{}, error) } { @@ -67,3 +68,17 @@ func (mock *CircuitBreakerSvcMock) ExecuteCalls() []struct { mock.lockExecute.RUnlock() return calls } + +// ResetExecuteCalls reset all the calls that were made to Execute. +func (mock *CircuitBreakerSvcMock) ResetExecuteCalls() { + mock.lockExecute.Lock() + mock.calls.Execute = nil + mock.lockExecute.Unlock() +} + +// ResetCalls reset all the calls that were made to all mocked methods. +func (mock *CircuitBreakerSvcMock) ResetCalls() { + mock.lockExecute.Lock() + mock.calls.Execute = nil + mock.lockExecute.Unlock() +} diff --git a/middleware/mocks/logger.go b/middleware/mocks/logger.go index 153fe31..112a581 100644 --- a/middleware/mocks/logger.go +++ b/middleware/mocks/logger.go @@ -9,19 +9,19 @@ import ( // LoggerSvc is a mock implementation of logger.Service. // -// func TestSomethingThatUsesService(t *testing.T) { +// func TestSomethingThatUsesService(t *testing.T) { // -// // make and configure a mocked logger.Service -// mockedService := &LoggerSvc{ -// LogfFunc: func(format string, args ...interface{}) { -// panic("mock out the Logf method") -// }, -// } +// // make and configure a mocked logger.Service +// mockedService := &LoggerSvc{ +// LogfFunc: func(format string, args ...interface{}) { +// panic("mock out the Logf method") +// }, +// } // -// // use mockedService in code that requires logger.Service -// // and then make assertions. +// // use mockedService in code that requires logger.Service +// // and then make assertions. // -// } +// } type LoggerSvc struct { // LogfFunc mocks the Logf method. LogfFunc func(format string, args ...interface{}) @@ -59,7 +59,8 @@ func (mock *LoggerSvc) Logf(format string, args ...interface{}) { // LogfCalls gets all the calls that were made to Logf. // Check the length with: -// len(mockedService.LogfCalls()) +// +// len(mockedService.LogfCalls()) func (mock *LoggerSvc) LogfCalls() []struct { Format string Args []interface{} @@ -73,3 +74,17 @@ func (mock *LoggerSvc) LogfCalls() []struct { mock.lockLogf.RUnlock() return calls } + +// ResetLogfCalls reset all the calls that were made to Logf. +func (mock *LoggerSvc) ResetLogfCalls() { + mock.lockLogf.Lock() + mock.calls.Logf = nil + mock.lockLogf.Unlock() +} + +// ResetCalls reset all the calls that were made to all mocked methods. +func (mock *LoggerSvc) ResetCalls() { + mock.lockLogf.Lock() + mock.calls.Logf = nil + mock.lockLogf.Unlock() +} diff --git a/middleware/mocks/repeater.go b/middleware/mocks/repeater.go index b430bb5..63335ed 100644 --- a/middleware/mocks/repeater.go +++ b/middleware/mocks/repeater.go @@ -10,19 +10,19 @@ import ( // RepeaterSvcMock is a mock implementation of middleware.RepeaterSvc. // -// func TestSomethingThatUsesRepeaterSvc(t *testing.T) { +// func TestSomethingThatUsesRepeaterSvc(t *testing.T) { // -// // make and configure a mocked middleware.RepeaterSvc -// mockedRepeaterSvc := &RepeaterSvcMock{ -// DoFunc: func(ctx context.Context, fun func() error, errs ...error) error { -// panic("mock out the Do method") -// }, -// } +// // make and configure a mocked middleware.RepeaterSvc +// mockedRepeaterSvc := &RepeaterSvcMock{ +// DoFunc: func(ctx context.Context, fun func() error, errs ...error) error { +// panic("mock out the Do method") +// }, +// } // -// // use mockedRepeaterSvc in code that requires middleware.RepeaterSvc -// // and then make assertions. +// // use mockedRepeaterSvc in code that requires middleware.RepeaterSvc +// // and then make assertions. // -// } +// } type RepeaterSvcMock struct { // DoFunc mocks the Do method. DoFunc func(ctx context.Context, fun func() error, errs ...error) error @@ -64,7 +64,8 @@ func (mock *RepeaterSvcMock) Do(ctx context.Context, fun func() error, errs ...e // DoCalls gets all the calls that were made to Do. // Check the length with: -// len(mockedRepeaterSvc.DoCalls()) +// +// len(mockedRepeaterSvc.DoCalls()) func (mock *RepeaterSvcMock) DoCalls() []struct { Ctx context.Context Fun func() error @@ -80,3 +81,17 @@ func (mock *RepeaterSvcMock) DoCalls() []struct { mock.lockDo.RUnlock() return calls } + +// ResetDoCalls reset all the calls that were made to Do. +func (mock *RepeaterSvcMock) ResetDoCalls() { + mock.lockDo.Lock() + mock.calls.Do = nil + mock.lockDo.Unlock() +} + +// ResetCalls reset all the calls that were made to all mocked methods. +func (mock *RepeaterSvcMock) ResetCalls() { + mock.lockDo.Lock() + mock.calls.Do = nil + mock.lockDo.Unlock() +} diff --git a/middleware/mocks/roundtripper.go b/middleware/mocks/roundtripper.go index 5b33ed7..b7ea114 100644 --- a/middleware/mocks/roundtripper.go +++ b/middleware/mocks/roundtripper.go @@ -21,3 +21,8 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (r *RoundTripper) Calls() int { return int(atomic.LoadInt32(&r.calls)) } + +// ResetCalls resets calls counter +func (r *RoundTripper) ResetCalls() { + atomic.StoreInt32(&r.calls, 0) +} diff --git a/middleware/repeater.go b/middleware/repeater.go index e2f91b0..7bdac1b 100644 --- a/middleware/repeater.go +++ b/middleware/repeater.go @@ -29,6 +29,11 @@ func Repeater(repeater RepeaterSvc, failOnCodes ...int) RoundTripperHandler { if err != nil { return err } + // no explicit codes provided, fail on any 4xx or 5xx + if len(failOnCodes) == 0 && resp.StatusCode >= 400 { + return errors.New(resp.Status) + } + // fail on provided codes only for _, fc := range failOnCodes { if resp.StatusCode == fc { return errors.New(resp.Status) diff --git a/middleware/repeater_test.go b/middleware/repeater_test.go index 70d24e8..f3752ac 100644 --- a/middleware/repeater_test.go +++ b/middleware/repeater_test.go @@ -89,19 +89,30 @@ func TestRepeater_FailedStatus(t *testing.T) { return err }} - { - h := Repeater(repeater, 300, 400, 401) - + t.Run("no codes", func(t *testing.T) { + rmock.ResetCalls() + h := Repeater(repeater) req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) require.NoError(t, err) _, err = h(rmock).RoundTrip(req) require.EqualError(t, err, "repeater: 400 Bad Request") - } + assert.Equal(t, 5, rmock.Calls()) + }) - assert.Equal(t, 5, rmock.Calls()) + t.Run("with codes", func(t *testing.T) { + rmock.ResetCalls() + h := Repeater(repeater, 300, 400, 401) + req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) + require.NoError(t, err) + + _, err = h(rmock).RoundTrip(req) + require.EqualError(t, err, "repeater: 400 Bad Request") + assert.Equal(t, 5, rmock.Calls()) + }) - { + t.Run("no codes, no match", func(t *testing.T) { + rmock.ResetCalls() h := Repeater(repeater, 300, 401) req, err := http.NewRequest("GET", "http://example.com/blah", http.NoBody) @@ -110,6 +121,7 @@ func TestRepeater_FailedStatus(t *testing.T) { resp, err := h(rmock).RoundTrip(req) require.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) - } + assert.Equal(t, 1, rmock.Calls()) + }) }