From 9107aefecad5d2666ab8412274db6bca6c1f8ac9 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Tue, 9 Jan 2024 14:56:37 +0100 Subject: [PATCH 01/12] Start retry for vanished backends --- cmd/lingo/main.go | 3 +- pkg/endpoints/endpoints.go | 36 +++++++++++++-- pkg/endpoints/manager.go | 16 +++++-- pkg/proxy/handler.go | 29 ++++++++++-- pkg/proxy/handler_test.go | 83 ++++++++++++++++++++++++++++++++++ tests/integration/main_test.go | 3 +- 6 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 pkg/proxy/handler_test.go diff --git a/cmd/lingo/main.go b/cmd/lingo/main.go index 7bfec8f0..d128013d 100644 --- a/cmd/lingo/main.go +++ b/cmd/lingo/main.go @@ -119,11 +119,10 @@ func run() error { metricsRegistry := prometheus.WrapRegistererWithPrefix("lingo_", metrics.Registry) queue.NewMetricsCollector(queueManager).MustRegister(metricsRegistry) - endpointManager, err := endpoints.NewManager(mgr) + endpointManager, err := endpoints.NewManager(mgr, queueManager.UpdateQueueSizeForReplicas) if err != nil { return fmt.Errorf("setting up endpoint manager: %w", err) } - endpointManager.EndpointSizeCallback = queueManager.UpdateQueueSizeForReplicas // The autoscaling leader will scrape other lingo instances. // Exclude this instance from being scraped by itself. endpointManager.ExcludePods[hostname] = struct{}{} diff --git a/pkg/endpoints/endpoints.go b/pkg/endpoints/endpoints.go index 0e51083a..6312fbc4 100644 --- a/pkg/endpoints/endpoints.go +++ b/pkg/endpoints/endpoints.go @@ -2,7 +2,9 @@ package endpoints import ( "context" + "errors" "fmt" + "strings" "sync" "sync/atomic" ) @@ -16,7 +18,8 @@ func newEndpointGroup() *endpointGroup { } type endpoint struct { - inFlight *atomic.Int64 + inFlight *atomic.Int64 + terminated chan struct{} } type endpointGroup struct { @@ -104,12 +107,13 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32) g.ports = ports for ip := range ips { if _, ok := g.endpoints[ip]; !ok { - g.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}} + g.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}, terminated: make(chan struct{})} } } - for ip := range g.endpoints { + for ip, endpoint := range g.endpoints { if _, ok := ips[ip]; !ok { delete(g.endpoints, ip) + close(endpoint.terminated) } } g.mtx.Unlock() @@ -127,3 +131,29 @@ func (g *endpointGroup) broadcastEndpoints() { close(g.bcast) g.bcast = make(chan struct{}) } + +func (e *endpointGroup) AddInflight(addr string, cancelRequest context.CancelFunc) (func(), error) { + tokens := strings.Split(addr, ":") + if len(tokens) != 2 { + return nil, errors.New("unsupported address format") + } + e.mtx.RLock() + endpoint, ok := e.endpoints[tokens[0]] + e.mtx.RUnlock() + if !ok { + return nil, errors.New("unsupported endpoint address") + } + endpoint.inFlight.Add(1) + done := make(chan struct{}) + go func() { + select { + case <-done: + case <-endpoint.terminated: + cancelRequest() + } + }() + return func() { + close(done) + endpoint.inFlight.Add(-1) + }, nil +} diff --git a/pkg/endpoints/manager.go b/pkg/endpoints/manager.go index bd2d3e39..2b1fa63b 100644 --- a/pkg/endpoints/manager.go +++ b/pkg/endpoints/manager.go @@ -12,11 +12,12 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -func NewManager(mgr ctrl.Manager) (*Manager, error) { +func NewManager(mgr ctrl.Manager, endpointSizeCallback func(deploymentName string, replicas int)) (*Manager, error) { r := &Manager{} r.Client = mgr.GetClient() r.endpoints = map[string]*endpointGroup{} r.ExcludePods = map[string]struct{}{} + r.EndpointSizeCallback = endpointSizeCallback if err := r.SetupWithManager(mgr); err != nil { return nil, err } @@ -86,6 +87,11 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, } } + r.SetEndpoints(serviceName, ips, ports) + return ctrl.Result{}, nil +} + +func (r *Manager) SetEndpoints(serviceName string, ips map[string]struct{}, ports map[string]int32) { priorLen := r.getEndpoints(serviceName).lenIPs() r.getEndpoints(serviceName).setIPs(ips, ports) @@ -95,8 +101,6 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, // replicas by something else. r.EndpointSizeCallback(serviceName, len(ips)) } - - return ctrl.Result{}, nil } func (r *Manager) getEndpoints(service string) *endpointGroup { @@ -122,3 +126,9 @@ func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string func (r *Manager) GetAllHosts(service, portName string) []string { return r.getEndpoints(service).getAllHosts(portName) } + +func (r *Manager) RegisterInFlight(ctx context.Context, service string, hostAddr string) (context.Context, func(), error) { + ctx, cancel := context.WithCancel(ctx) + completed, err := r.getEndpoints(service).AddInflight(hostAddr, cancel) + return ctx, completed, err +} diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index a51d9d3b..d6a4fbe3 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -16,20 +16,24 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" - "github.com/substratusai/lingo/pkg/deployments" "github.com/substratusai/lingo/pkg/endpoints" "github.com/substratusai/lingo/pkg/queue" ) +type deploymentSource interface { + ResolveDeployment(model string) (string, bool) + AtLeastOne(deploymentName string) +} + // Handler serves http requests for end-clients. // It is also responsible for triggering scale-from-zero. type Handler struct { - Deployments *deployments.Manager + Deployments deploymentSource Endpoints *endpoints.Manager Queues *queue.Manager } -func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler { +func NewHandler(deployments deploymentSource, endpoints *endpoints.Manager, queues *queue.Manager) *Handler { return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues} } @@ -108,7 +112,24 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // TODO: Avoid creating new reverse proxies for each request. // TODO: Consider implementing a round robin scheme. log.Printf("Proxying request to host %v: %v\n", host, id) - newReverseProxy(host).ServeHTTP(w, proxyRequest) + middleware := withCancelDeadTargets(h.Endpoints, deploy, host) + middleware(newReverseProxy(host)).ServeHTTP(w, proxyRequest) +} + +func withCancelDeadTargets(endpoints *endpoints.Manager, deploy string, host string) func(other http.Handler) http.HandlerFunc { + return func(other http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + newCtx, done, err := endpoints.RegisterInFlight(r.Context(), deploy, host) + if err != nil { + log.Printf("error registering in-flight request: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer done() + + other.ServeHTTP(w, r.Clone(newCtx)) + } + } } // parseModel parses the model name from the request diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go new file mode 100644 index 00000000..83d44fcf --- /dev/null +++ b/pkg/proxy/handler_test.go @@ -0,0 +1,83 @@ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substratusai/lingo/pkg/endpoints" + "github.com/substratusai/lingo/pkg/queue" +) + +func TestProxy(t *testing.T) { + specs := map[string]struct { + request *http.Request + expCode int + }{ + "ok": { + request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)), + expCode: http.StatusBadGateway, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + deplMgr := mockDeploymentSource{ + ResolveDeploymentFn: func(model string) (string, bool) { + return "my-deployment", true + }, + AtLeastOneFn: func(deploymentName string) {}, + } + em, err := endpoints.NewManager(&fakeManager{}, func(deploymentName string, replicas int) {}) + require.NoError(t, err) + em.SetEndpoints("my-deployment", map[string]struct{}{"my-ip": {}}, map[string]int32{"my-port": 8080}) + h := NewHandler(deplMgr, em, queue.NewManager(10)) + + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + em.SetEndpoints("my-deployment", map[string]struct{}{"my-other-ip": {}}, map[string]int32{"my-other-port": 8080}) + time.Sleep(time.Millisecond) + w.WriteHeader(999) + })) + recorder := httptest.NewRecorder() + + AdditionalProxyRewrite = func(r *httputil.ProxyRequest) { + r.SetURL(&url.URL{Scheme: "http", Host: svr.Listener.Addr().String()}) + } + + // when + // newCtx, cancel := context.WithCancel(spec.request.Context()) + // cancel() + // newCtx, _ := context.WithTimeout(spec.request.Context(), time.Nanosecond) + newCtx := context.Background() + + h.ServeHTTP(recorder, spec.request.Clone(newCtx)) + // then + assert.Equal(t, spec.expCode, recorder.Code) + }) + } +} + +type mockDeploymentSource struct { + ResolveDeploymentFn func(model string) (string, bool) + AtLeastOneFn func(deploymentName string) +} + +func (m mockDeploymentSource) ResolveDeployment(model string) (string, bool) { + if m.ResolveDeploymentFn == nil { + panic("not expected to be called") + } + return m.ResolveDeploymentFn(model) +} + +func (m mockDeploymentSource) AtLeastOne(deploymentName string) { + if m.AtLeastOneFn == nil { + panic("not expected to be called") + } + m.AtLeastOneFn(deploymentName) +} diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 74697559..7ec22fb7 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -81,9 +81,8 @@ func TestMain(m *testing.M) { const concurrencyPerReplica = 1 queueManager = queue.NewManager(concurrencyPerReplica) - endpointManager, err := endpoints.NewManager(mgr) + endpointManager, err := endpoints.NewManager(mgr, queueManager.UpdateQueueSizeForReplicas) requireNoError(err) - endpointManager.EndpointSizeCallback = queueManager.UpdateQueueSizeForReplicas deploymentManager, err := deployments.NewManager(mgr) requireNoError(err) From 4424ef710935f700f1e055049e2430d66691de91 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Wed, 10 Jan 2024 10:40:34 +0100 Subject: [PATCH 02/12] Add retry middleware --- cmd/lingo/main.go | 4 +- pkg/proxy/metrics.go | 7 ++- pkg/proxy/metrics_test.go | 9 ++-- pkg/proxy/middleware.go | 77 ++++++++++++++++++++++++++++++ pkg/proxy/middleware_test.go | 92 ++++++++++++++++++++++++++++++++++++ 5 files changed, 183 insertions(+), 6 deletions(-) create mode 100644 pkg/proxy/middleware.go create mode 100644 pkg/proxy/middleware_test.go diff --git a/cmd/lingo/main.go b/cmd/lingo/main.go index d128013d..4e19f0c7 100644 --- a/cmd/lingo/main.go +++ b/cmd/lingo/main.go @@ -113,6 +113,7 @@ func run() error { if err != nil { return fmt.Errorf("getting hostname: %v", err) } + le := leader.NewElection(clientset, hostname, namespace) queueManager := queue.NewManager(concurrencyPerReplica) @@ -152,7 +153,8 @@ func run() error { go autoscaler.Start() proxy.MustRegister(metricsRegistry) - proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager) + var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager) + proxyHandler = proxy.NewRetryMiddleware(3, proxyHandler) proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler} statsHandler := &stats.Handler{ diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go index 5cd229e7..45d20e8e 100644 --- a/pkg/proxy/metrics.go +++ b/pkg/proxy/metrics.go @@ -12,8 +12,13 @@ var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Buckets: prometheus.DefBuckets, }, []string{"model", "status_code"}) +var totalRetries = prometheus.NewCounter(prometheus.CounterOpts{ + Name: "http_request_retry_total", + Help: "Number of HTTP request retries.", +}) + func MustRegister(r prometheus.Registerer) { - r.MustRegister(httpDuration) + r.MustRegister(httpDuration, totalRetries) } // captureStatusResponseWriter is a custom HTTP response writer that captures the status code. diff --git a/pkg/proxy/metrics_test.go b/pkg/proxy/metrics_test.go index 5940e445..cf0d5c3e 100644 --- a/pkg/proxy/metrics_test.go +++ b/pkg/proxy/metrics_test.go @@ -65,10 +65,11 @@ func TestMetrics(t *testing.T) { assert.Equal(t, spec.expCode, recorder.Code) gathered, err := registry.Gather() require.NoError(t, err) - require.Len(t, gathered, 1) - require.Len(t, gathered[0].Metric, 1) - assert.NotEmpty(t, gathered[0].Metric[0].GetHistogram().GetSampleCount()) - assert.Equal(t, spec.expLabels, toMap(gathered[0].Metric[0].Label)) + require.Len(t, gathered, 2) + require.Equal(t, "http_response_time_seconds", *gathered[1].Name) + require.Len(t, gathered[1].Metric, 1) + assert.NotEmpty(t, gathered[1].Metric[0].GetHistogram().GetSampleCount()) + assert.Equal(t, spec.expLabels, toMap(gathered[1].Metric[0].Label)) }) } } diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go new file mode 100644 index 00000000..fd09de85 --- /dev/null +++ b/pkg/proxy/middleware.go @@ -0,0 +1,77 @@ +package proxy + +import ( + "bytes" + "log" + "math/rand" + "net/http" + "time" +) + +var _ http.Handler = &RetryMiddleware{} + +type RetryMiddleware struct { + other http.Handler + MaxRetries int + rSource *rand.Rand +} + +func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware { + if maxRetries < 1 { + panic("invalid retries") + } + return &RetryMiddleware{ + other: other, + MaxRetries: maxRetries, + rSource: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + var capturedResp *responseBuffer + for i := 0; ; i++ { + capturedResp = &responseBuffer{ + header: make(http.Header), + body: bytes.NewBuffer([]byte{}), + } + // call next handler in chain + r.other.ServeHTTP(capturedResp, request.Clone(request.Context())) + + if i == r.MaxRetries || // max retries reached + request.Context().Err() != nil || // abort early on timeout, context cancel + capturedResp.status != http.StatusBadGateway && + capturedResp.status != http.StatusServiceUnavailable { + break + } + totalRetries.Inc() + // Exponential backoff + jitter := time.Duration(r.rSource.Intn(50)) + time.Sleep(time.Millisecond*time.Duration(1< Date: Thu, 11 Jan 2024 10:06:58 +0100 Subject: [PATCH 03/12] Better naming --- pkg/proxy/middleware.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index fd09de85..93715957 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -11,9 +11,9 @@ import ( var _ http.Handler = &RetryMiddleware{} type RetryMiddleware struct { - other http.Handler - MaxRetries int - rSource *rand.Rand + nextHandler http.Handler + MaxRetries int + rSource *rand.Rand } func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware { @@ -21,9 +21,9 @@ func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware { panic("invalid retries") } return &RetryMiddleware{ - other: other, - MaxRetries: maxRetries, - rSource: rand.New(rand.NewSource(time.Now().UnixNano())), + nextHandler: other, + MaxRetries: maxRetries, + rSource: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -35,7 +35,7 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req body: bytes.NewBuffer([]byte{}), } // call next handler in chain - r.other.ServeHTTP(capturedResp, request.Clone(request.Context())) + r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context())) if i == r.MaxRetries || // max retries reached request.Context().Err() != nil || // abort early on timeout, context cancel From 9b99c6c5788f77a61b08fea1d69a891b594879c8 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Mon, 15 Jan 2024 13:52:50 +0100 Subject: [PATCH 04/12] Apply review feedback; refactorings --- pkg/endpoints/endpoints.go | 13 +---- pkg/endpoints/manager.go | 6 +-- pkg/proxy/handler.go | 19 ++++--- pkg/proxy/handler_test.go | 14 ++--- pkg/proxy/middleware.go | 102 +++++++++++++++++++++++++---------- pkg/proxy/middleware_test.go | 6 +++ 6 files changed, 100 insertions(+), 60 deletions(-) diff --git a/pkg/endpoints/endpoints.go b/pkg/endpoints/endpoints.go index 6312fbc4..c3910e3b 100644 --- a/pkg/endpoints/endpoints.go +++ b/pkg/endpoints/endpoints.go @@ -132,28 +132,19 @@ func (g *endpointGroup) broadcastEndpoints() { g.bcast = make(chan struct{}) } -func (e *endpointGroup) AddInflight(addr string, cancelRequest context.CancelFunc) (func(), error) { +func (e *endpointGroup) AddInflight(addr string) (func(), error) { tokens := strings.Split(addr, ":") if len(tokens) != 2 { return nil, errors.New("unsupported address format") } e.mtx.RLock() + defer e.mtx.RUnlock() endpoint, ok := e.endpoints[tokens[0]] - e.mtx.RUnlock() if !ok { return nil, errors.New("unsupported endpoint address") } endpoint.inFlight.Add(1) - done := make(chan struct{}) - go func() { - select { - case <-done: - case <-endpoint.terminated: - cancelRequest() - } - }() return func() { - close(done) endpoint.inFlight.Add(-1) }, nil } diff --git a/pkg/endpoints/manager.go b/pkg/endpoints/manager.go index 2b1fa63b..414f888b 100644 --- a/pkg/endpoints/manager.go +++ b/pkg/endpoints/manager.go @@ -127,8 +127,6 @@ func (r *Manager) GetAllHosts(service, portName string) []string { return r.getEndpoints(service).getAllHosts(portName) } -func (r *Manager) RegisterInFlight(ctx context.Context, service string, hostAddr string) (context.Context, func(), error) { - ctx, cancel := context.WithCancel(ctx) - completed, err := r.getEndpoints(service).AddInflight(hostAddr, cancel) - return ctx, completed, err +func (r *Manager) RegisterInFlight(service string, hostAddr string) (func(), error) { + return r.getEndpoints(service).AddInflight(hostAddr) } diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index d6a4fbe3..72762953 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -109,17 +109,24 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } log.Printf("Got host: %v, id: %v\n", host, id) + done, err := h.Endpoints.RegisterInFlight(deploy, host) + if err != nil { + log.Printf("error registering in-flight request: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer done() + + log.Printf("Proxying request to host %v: %v\n", host, id) // TODO: Avoid creating new reverse proxies for each request. // TODO: Consider implementing a round robin scheme. - log.Printf("Proxying request to host %v: %v\n", host, id) - middleware := withCancelDeadTargets(h.Endpoints, deploy, host) - middleware(newReverseProxy(host)).ServeHTTP(w, proxyRequest) + newReverseProxy(host).ServeHTTP(w, proxyRequest) } -func withCancelDeadTargets(endpoints *endpoints.Manager, deploy string, host string) func(other http.Handler) http.HandlerFunc { +func withInflightCounted(endpoints *endpoints.Manager, deploy string, host string) func(other http.Handler) http.HandlerFunc { return func(other http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - newCtx, done, err := endpoints.RegisterInFlight(r.Context(), deploy, host) + done, err := endpoints.RegisterInFlight(deploy, host) if err != nil { log.Printf("error registering in-flight request: %v", err) w.WriteHeader(http.StatusInternalServerError) @@ -127,7 +134,7 @@ func withCancelDeadTargets(endpoints *endpoints.Manager, deploy string, host str } defer done() - other.ServeHTTP(w, r.Clone(newCtx)) + other.ServeHTTP(w, r) } } } diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go index 83d44fcf..36fbf08b 100644 --- a/pkg/proxy/handler_test.go +++ b/pkg/proxy/handler_test.go @@ -1,14 +1,12 @@ package proxy import ( - "context" "net/http" "net/http/httptest" "net/http/httputil" "net/url" "strings" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -23,7 +21,7 @@ func TestProxy(t *testing.T) { }{ "ok": { request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)), - expCode: http.StatusBadGateway, + expCode: http.StatusOK, }, } for name, spec := range specs { @@ -41,8 +39,7 @@ func TestProxy(t *testing.T) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { em.SetEndpoints("my-deployment", map[string]struct{}{"my-other-ip": {}}, map[string]int32{"my-other-port": 8080}) - time.Sleep(time.Millisecond) - w.WriteHeader(999) + w.WriteHeader(http.StatusOK) })) recorder := httptest.NewRecorder() @@ -51,12 +48,7 @@ func TestProxy(t *testing.T) { } // when - // newCtx, cancel := context.WithCancel(spec.request.Context()) - // cancel() - // newCtx, _ := context.WithTimeout(spec.request.Context(), time.Nanosecond) - newCtx := context.Background() - - h.ServeHTTP(recorder, spec.request.Clone(newCtx)) + h.ServeHTTP(recorder, spec.request) // then assert.Equal(t, spec.expCode, recorder.Code) }) diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index 93715957..2c614fec 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -1,8 +1,7 @@ package proxy import ( - "bytes" - "log" + "io" "math/rand" "net/http" "time" @@ -28,19 +27,19 @@ func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware { } func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - var capturedResp *responseBuffer + var capturedResp *responseWriterDelegator for i := 0; ; i++ { - capturedResp = &responseBuffer{ - header: make(http.Header), - body: bytes.NewBuffer([]byte{}), + capturedResp = &responseWriterDelegator{ + ResponseWriter: writer, + headerBuf: make(http.Header), + discardErrResp: i < r.MaxRetries && + request.Context().Err() == nil, // abort early on timeout, context cancel } // call next handler in chain r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context())) - if i == r.MaxRetries || // max retries reached - request.Context().Err() != nil || // abort early on timeout, context cancel - capturedResp.status != http.StatusBadGateway && - capturedResp.status != http.StatusServiceUnavailable { + if !capturedResp.discardErrResp || // max retries reached + !isRetryableStatusCode(capturedResp.statusCode) { break } totalRetries.Inc() @@ -48,30 +47,77 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req jitter := time.Duration(r.rSource.Intn(50)) time.Sleep(time.Millisecond*time.Duration(1<= 100 && status < 200) + } + if r.discardErrResp && isRetryableStatusCode(status) { + return + } + // copy header values to target + for k, vals := range r.headerBuf { + for _, val := range vals { + r.ResponseWriter.Header().Add(k, val) + } + } + r.ResponseWriter.WriteHeader(status) } -func (rb *responseBuffer) Write(data []byte) (int, error) { - return rb.body.Write(data) +func (r *responseWriterDelegator) Write(data []byte) (int, error) { + // ensure header is set. default is 200 in Go stdlib + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + if r.discardErrResp && isRetryableStatusCode(r.statusCode) { + return io.Discard.Write(data) + } else { + return r.ResponseWriter.Write(data) + } +} + +func (r *responseWriterDelegator) ReadFrom(re io.Reader) (int64, error) { + // ensure header is set. default is 200 in Go stdlib + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + if r.discardErrResp && isRetryableStatusCode(r.statusCode) { + return io.Copy(io.Discard, re) + } else { + return r.ResponseWriter.(io.ReaderFrom).ReadFrom(re) + } +} + +func (r *responseWriterDelegator) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } } diff --git a/pkg/proxy/middleware_test.go b/pkg/proxy/middleware_test.go index fbeea7b1..d11f1a11 100644 --- a/pkg/proxy/middleware_test.go +++ b/pkg/proxy/middleware_test.go @@ -42,6 +42,12 @@ func TestServeHTTP(t *testing.T) { respStatus: http.StatusBadGateway, expRetries: 3, }, + "not buffered on 100": { + context: func() context.Context { return context.TODO() }, + maxRetries: 3, + respStatus: http.StatusContinue, + expRetries: 0, + }, "context cancelled": { context: func() context.Context { ctx, cancel := context.WithCancel(context.TODO()) From e71df9f5dc012929b4bcefe93d880581c6d6d086 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Mon, 15 Jan 2024 14:01:52 +0100 Subject: [PATCH 05/12] Minor cleanup --- pkg/proxy/handler.go | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index 72762953..c7b18c6e 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -123,22 +123,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { newReverseProxy(host).ServeHTTP(w, proxyRequest) } -func withInflightCounted(endpoints *endpoints.Manager, deploy string, host string) func(other http.Handler) http.HandlerFunc { - return func(other http.Handler) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - done, err := endpoints.RegisterInFlight(deploy, host) - if err != nil { - log.Printf("error registering in-flight request: %v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer done() - - other.ServeHTTP(w, r) - } - } -} - // parseModel parses the model name from the request // returns empty string when none found or an error for failures on the proxy request object func parseModel(r *http.Request) (string, *http.Request, error) { From b74a3950fce7f729183e20f3453cb815ace3f756 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Tue, 16 Jan 2024 11:28:58 +0100 Subject: [PATCH 06/12] Lazy buffer request body for retry + reuse; refactorings --- pkg/proxy/handler.go | 14 +++- pkg/proxy/metrics.go | 25 ++++++- pkg/proxy/middleware.go | 63 +++++++++++++++- tests/integration/integration_test.go | 102 ++++++++++++++++++++------ tests/integration/main_test.go | 2 +- 5 files changed, 172 insertions(+), 34 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index c7b18c6e..c2419e60 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -129,10 +129,16 @@ func parseModel(r *http.Request) (string, *http.Request, error) { if model := r.Header.Get("X-Model"); model != "" { return model, r, nil } - // parse request body for model name, ignore errors - body, err := io.ReadAll(r.Body) - if err != nil { - return "", r, nil + var body []byte + if mb, ok := r.Body.(*lazyBodyCapturer); ok && mb.capturedBody != nil { + body = mb.capturedBody + } else { + // parse request body for model name, ignore errors + var err error + body, err = io.ReadAll(r.Body) + if err != nil { + return "", r, nil + } } var payload struct { diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go index 45d20e8e..48d204c8 100644 --- a/pkg/proxy/metrics.go +++ b/pkg/proxy/metrics.go @@ -1,6 +1,7 @@ package proxy import ( + "io" "net/http" "github.com/prometheus/client_golang/prometheus" @@ -24,14 +25,30 @@ func MustRegister(r prometheus.Registerer) { // captureStatusResponseWriter is a custom HTTP response writer that captures the status code. type captureStatusResponseWriter struct { http.ResponseWriter - statusCode int + statusCode int + wroteHeader bool } func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter { return &captureStatusResponseWriter{ResponseWriter: responseWriter} } -func (srw *captureStatusResponseWriter) WriteHeader(code int) { - srw.statusCode = code - srw.ResponseWriter.WriteHeader(code) +func (c *captureStatusResponseWriter) WriteHeader(code int) { + c.wroteHeader = true + c.statusCode = code + c.ResponseWriter.WriteHeader(code) +} + +func (c *captureStatusResponseWriter) Write(b []byte) (int, error) { + if !c.wroteHeader { + c.WriteHeader(http.StatusOK) + } + return c.ResponseWriter.Write(b) +} + +func (c *captureStatusResponseWriter) ReadFrom(re io.Reader) (int64, error) { + if !c.wroteHeader { + c.WriteHeader(http.StatusOK) + } + return c.ResponseWriter.(io.ReaderFrom).ReadFrom(re) } diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index 2c614fec..e0b420e3 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -1,6 +1,8 @@ package proxy import ( + "bytes" + "errors" "io" "math/rand" "net/http" @@ -27,6 +29,11 @@ func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware { } func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + lazyBody := &lazyBodyCapturer{ + reader: request.Body, + buf: bytes.NewBuffer([]byte{}), + } + request.Body = lazyBody var capturedResp *responseWriterDelegator for i := 0; ; i++ { capturedResp = &responseWriterDelegator{ @@ -36,8 +43,12 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req request.Context().Err() == nil, // abort early on timeout, context cancel } // call next handler in chain - r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context())) - + req, err := http.NewRequestWithContext(request.Context(), request.Method, request.URL.String(), lazyBody) + if err != nil { + panic(err) + } + r.nextHandler.ServeHTTP(capturedResp, req) + lazyBody.Capture() if !capturedResp.discardErrResp || // max retries reached !isRetryableStatusCode(capturedResp.statusCode) { break @@ -121,3 +132,51 @@ func (r *responseWriterDelegator) Flush() { f.Flush() } } + +var ( + _ io.ReadCloser = &lazyBodyCapturer{} + _ io.WriterTo = &lazyBodyCapturer{} +) + +type lazyBodyCapturer struct { + reader io.ReadCloser + capturedBody []byte + buf *bytes.Buffer + allRead bool +} + +func (m *lazyBodyCapturer) Read(p []byte) (int, error) { + if m.allRead { + return m.reader.Read(p) + } + n, err := io.TeeReader(m.reader, m.buf).Read(p) + if errors.Is(err, io.EOF) { + m.allRead = true + } + return n, err +} + +func (m *lazyBodyCapturer) Close() error { + return m.reader.Close() +} + +func (m *lazyBodyCapturer) WriteTo(w io.Writer) (int64, error) { + if m.allRead { + return m.reader.(io.WriterTo).WriteTo(w) + } + n, err := m.reader.(io.WriterTo).WriteTo(io.MultiWriter(w, m.buf)) + if errors.Is(err, io.EOF) { + m.allRead = true + } + return n, err +} + +func (m *lazyBodyCapturer) Capture() { + m.allRead = true + if m.buf != nil { + m.capturedBody = m.buf.Bytes() + m.buf = nil + } else { + m.reader = io.NopCloser(bytes.NewReader(m.capturedBody)) + } +} diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index 10fb5e45..bd4f25e7 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -43,17 +43,7 @@ func TestScaleUpAndDown(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -103,17 +93,7 @@ func TestHandleModelUndeployment(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -132,7 +112,7 @@ func TestHandleModelUndeployment(t *testing.T) { require.NoError(t, testK8sClient.Delete(testCtx, deploy)) // Check that the deployment was deleted - err = testK8sClient.Get(testCtx, client.ObjectKey{ + err := testK8sClient.Get(testCtx, client.ObjectKey{ Namespace: deploy.Namespace, Name: deploy.Name, }, deploy) @@ -151,6 +131,82 @@ func TestHandleModelUndeployment(t *testing.T) { wg.Wait() } +func TestRetryMiddleware(t *testing.T) { + const modelName = "test-model-c" + deploy := testDeployment(modelName) + require.NoError(t, testK8sClient.Create(testCtx, deploy)) + + // Wait for deployment mapping to sync. + time.Sleep(3 * time.Second) + backendRequests := &atomic.Int32{} + var serverCodes []int + testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + i := backendRequests.Add(1) + code := serverCodes[i-1] + t.Logf("Serving request from testBackend: %d; code: %d\n", i, code) + w.WriteHeader(code) + })) + + // Mock an EndpointSlice. + withMockEndpointSlice(t, testBackend, modelName) + + specs := map[string]struct { + serverCodes []int + expResultCode int + expBackendHits int32 + }{ + "max retries - succeeds": { + serverCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusOK}, + expResultCode: http.StatusOK, + expBackendHits: 4, + }, + "max retries - fails": { + serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway}, + expResultCode: http.StatusBadGateway, + expBackendHits: 4, + }, + "non retryable error code": { + serverCodes: []int{http.StatusNotImplemented}, + expResultCode: http.StatusNotImplemented, + expBackendHits: 1, + }, + "200 status code": { + serverCodes: []int{http.StatusOK}, + expResultCode: http.StatusOK, + expBackendHits: 1, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + // setup + serverCodes = spec.serverCodes + backendRequests.Store(0) + + // when single request sent + var wg sync.WaitGroup + sendRequest(t, &wg, modelName, spec.expResultCode) + wg.Wait() + + // then + require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit with retries") + }) + } +} + +func withMockEndpointSlice(t *testing.T, testBackend *httptest.Server, modelName string) { + testBackendURL, err := url.Parse(testBackend.URL) + require.NoError(t, err) + testBackendPort, err := strconv.Atoi(testBackendURL.Port()) + require.NoError(t, err) + require.NoError(t, testK8sClient.Create(testCtx, + endpointSlice( + modelName, + testBackendURL.Hostname(), + int32(testBackendPort), + ), + )) +} + func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) { require.EventuallyWithT(t, func(t *assert.CollectT) { err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy) diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 7ec22fb7..36d8987c 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -109,7 +109,7 @@ func TestMain(m *testing.M) { Endpoints: endpointManager, Queues: queueManager, } - testServer = httptest.NewServer(handler) + testServer = httptest.NewServer(proxy.NewRetryMiddleware(3, handler)) defer testServer.Close() go func() { From 296c6c41e77528fe2082ac2ac93dd6049c49322d Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Tue, 16 Jan 2024 12:09:01 +0100 Subject: [PATCH 07/12] Configurable number of retries --- cmd/lingo/main.go | 4 +- pkg/proxy/middleware.go | 69 ++++++++++++++++++--------- tests/integration/integration_test.go | 20 +++++--- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/cmd/lingo/main.go b/cmd/lingo/main.go index 4e19f0c7..06b1e0a6 100644 --- a/cmd/lingo/main.go +++ b/cmd/lingo/main.go @@ -71,11 +71,13 @@ func run() error { var metricsAddr string var probeAddr string var concurrencyPerReplica int + var maxRetriesOnErr int flag.StringVar(&metricsAddr, "metrics-bind-address", ":8082", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") flag.IntVar(&concurrencyPerReplica, "concurrency", concurrency, "the number of simultaneous requests that can be processed by each replica") flag.IntVar(&scaleDownDelay, "scale-down-delay", scaleDownDelay, "seconds to wait before scaling down") + flag.IntVar(&maxRetriesOnErr, "max-retries", 0, "max number of retries on a http error code: 502,503,504") opts := zap.Options{ Development: true, } @@ -154,7 +156,7 @@ func run() error { proxy.MustRegister(metricsRegistry) var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager) - proxyHandler = proxy.NewRetryMiddleware(3, proxyHandler) + proxyHandler = proxy.NewRetryMiddleware(maxRetriesOnErr, proxyHandler) proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler} statsHandler := &stats.Handler{ diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index e0b420e3..73fe3872 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -12,19 +12,34 @@ import ( var _ http.Handler = &RetryMiddleware{} type RetryMiddleware struct { - nextHandler http.Handler - MaxRetries int - rSource *rand.Rand + nextHandler http.Handler + maxRetries int + rSource *rand.Rand + retryStatusCodes map[int]struct{} } -func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware { - if maxRetries < 1 { - panic("invalid retries") +// NewRetryMiddleware creates a new HTTP middleware that adds retry functionality. +// It takes the maximum number of retries, the next handler in the middleware chain, +// and an optional list of retryable status codes. +// If the maximum number of retries is 0, it returns the next handler without adding any retries. +// If the list of retryable status codes is empty, it uses a default set of status codes (http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout). +// The function creates a RetryMiddleware struct with the given parameters and returns it as an http.Handler. +func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes ...int) http.Handler { + if maxRetries == 0 { + return other + } + if len(optRetryStatusCodes) == 0 { + optRetryStatusCodes = []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout} + } + statusCodeIndex := make(map[int]struct{}, len(optRetryStatusCodes)) + for _, c := range optRetryStatusCodes { + statusCodeIndex[c] = struct{}{} } return &RetryMiddleware{ - nextHandler: other, - MaxRetries: maxRetries, - rSource: rand.New(rand.NewSource(time.Now().UnixNano())), + nextHandler: other, + maxRetries: maxRetries, + retryStatusCodes: statusCodeIndex, + rSource: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -34,12 +49,12 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req buf: bytes.NewBuffer([]byte{}), } request.Body = lazyBody - var capturedResp *responseWriterDelegator for i := 0; ; i++ { - capturedResp = &responseWriterDelegator{ - ResponseWriter: writer, - headerBuf: make(http.Header), - discardErrResp: i < r.MaxRetries && + capturedResp := &responseWriterDelegator{ + isRetryableStatusCode: r.isRetryableStatusCode, + ResponseWriter: writer, + headerBuf: make(http.Header), + discardErrResp: i < r.maxRetries && request.Context().Err() == nil, // abort early on timeout, context cancel } // call next handler in chain @@ -50,7 +65,7 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req r.nextHandler.ServeHTTP(capturedResp, req) lazyBody.Capture() if !capturedResp.discardErrResp || // max retries reached - !isRetryableStatusCode(capturedResp.statusCode) { + !r.isRetryableStatusCode(capturedResp.statusCode) { break } totalRetries.Inc() @@ -60,10 +75,9 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req } } -func isRetryableStatusCode(status int) bool { - return status == http.StatusBadGateway || - status == http.StatusServiceUnavailable || - status == http.StatusGatewayTimeout +func (r RetryMiddleware) isRetryableStatusCode(status int) bool { + _, ok := r.retryStatusCodes[status] + return ok } var ( @@ -71,13 +85,17 @@ var ( _ io.ReaderFrom = &responseWriterDelegator{} ) +// responseWriterDelegator represents a wrapper around http.ResponseWriter that provides additional +// functionalities for handling response writing. Depending on the status code and discard settings, +// the heeader + content on write is skipped so that it can be re-used on retry. type responseWriterDelegator struct { http.ResponseWriter headerBuf http.Header wroteHeader bool statusCode int // always writes to responseWriter when false - discardErrResp bool + discardErrResp bool + isRetryableStatusCode func(status int) bool } func (r *responseWriterDelegator) Header() http.Header { @@ -91,7 +109,7 @@ func (r *responseWriterDelegator) WriteHeader(status int) { // any 1xx informational response should be written r.discardErrResp = r.discardErrResp && !(status >= 100 && status < 200) } - if r.discardErrResp && isRetryableStatusCode(status) { + if r.discardErrResp && r.isRetryableStatusCode(status) { return } // copy header values to target @@ -103,12 +121,17 @@ func (r *responseWriterDelegator) WriteHeader(status int) { r.ResponseWriter.WriteHeader(status) } +// Write writes data to the response. +// If the response header has not been set, it sets the default status code to 200. +// When the status code qualifies for a retry, no content is written. +// +// It returns the number of bytes written and any error encountered. func (r *responseWriterDelegator) Write(data []byte) (int, error) { // ensure header is set. default is 200 in Go stdlib if !r.wroteHeader { r.WriteHeader(http.StatusOK) } - if r.discardErrResp && isRetryableStatusCode(r.statusCode) { + if r.discardErrResp && r.isRetryableStatusCode(r.statusCode) { return io.Discard.Write(data) } else { return r.ResponseWriter.Write(data) @@ -120,7 +143,7 @@ func (r *responseWriterDelegator) ReadFrom(re io.Reader) (int64, error) { if !r.wroteHeader { r.WriteHeader(http.StatusOK) } - if r.discardErrResp && isRetryableStatusCode(r.statusCode) { + if r.discardErrResp && r.isRetryableStatusCode(r.statusCode) { return io.Copy(io.Discard, re) } else { return r.ResponseWriter.(io.ReaderFrom).ReadFrom(re) diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index bd4f25e7..3ba2ce18 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -3,6 +3,7 @@ package integration import ( "bytes" "fmt" + "io" "log" "net/http" "net/http/httptest" @@ -145,6 +146,8 @@ func TestRetryMiddleware(t *testing.T) { code := serverCodes[i-1] t.Logf("Serving request from testBackend: %d; code: %d\n", i, code) w.WriteHeader(code) + _, err := w.Write([]byte(strconv.Itoa(code))) + require.NoError(t, err) })) // Mock an EndpointSlice. @@ -183,11 +186,9 @@ func TestRetryMiddleware(t *testing.T) { backendRequests.Store(0) // when single request sent - var wg sync.WaitGroup - sendRequest(t, &wg, modelName, spec.expResultCode) - wg.Wait() - - // then + gotBody := <-sendRequest(t, &sync.WaitGroup{}, modelName, spec.expResultCode) + // then only the last body is written + assert.Equal(t, strconv.Itoa(spec.expResultCode), gotBody) require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit with retries") }) } @@ -222,9 +223,10 @@ func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, exp } } -func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) { +func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) <-chan string { t.Helper() wg.Add(1) + bodyRespChan := make(chan string, 1) go func() { defer wg.Done() @@ -235,7 +237,13 @@ func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int res, err := testHTTPClient.Do(req) require.NoError(t, err) require.Equal(t, expCode, res.StatusCode) + got, err := io.ReadAll(res.Body) + _ = res.Body.Close() + require.NoError(t, err) + bodyRespChan <- string(got) + close(bodyRespChan) }() + return bodyRespChan } func completeRequests(c chan struct{}, n int) { From 4299373519a9afccdbed5c55c160774e32e0be52 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Tue, 16 Jan 2024 16:31:33 +0100 Subject: [PATCH 08/12] Remove panic --- pkg/proxy/middleware.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index 73fe3872..e8a5820f 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "io" + "log" "math/rand" "net/http" "time" @@ -60,7 +61,9 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req // call next handler in chain req, err := http.NewRequestWithContext(request.Context(), request.Method, request.URL.String(), lazyBody) if err != nil { - panic(err) + log.Printf("clone request: %v", err) + writer.WriteHeader(http.StatusInternalServerError) + return } r.nextHandler.ServeHTTP(capturedResp, req) lazyBody.Capture() From d0b684a26643404a9efeb0ae90069d6facc8a467 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Wed, 17 Jan 2024 11:53:22 +0100 Subject: [PATCH 09/12] Fixes and tests --- pkg/proxy/middleware.go | 138 +++++++++++++++++--------- pkg/proxy/middleware_test.go | 115 ++++++++++++++++++++- tests/integration/integration_test.go | 2 +- 3 files changed, 204 insertions(+), 51 deletions(-) diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index e8a5820f..6964a84a 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "io" - "log" "math/rand" "net/http" "time" @@ -44,35 +43,35 @@ func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes } } +type xResponseWriter interface { + http.ResponseWriter + discardedResponse() bool + capturedStatusCode() int +} +type xBodyCapturer interface { + io.ReadCloser + Capture() +} + func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - lazyBody := &lazyBodyCapturer{ - reader: request.Body, - buf: bytes.NewBuffer([]byte{}), - } + lazyBody := newLazyBodyCapturer(request.Body) request.Body = lazyBody for i := 0; ; i++ { - capturedResp := &responseWriterDelegator{ - isRetryableStatusCode: r.isRetryableStatusCode, - ResponseWriter: writer, - headerBuf: make(http.Header), - discardErrResp: i < r.maxRetries && - request.Context().Err() == nil, // abort early on timeout, context cancel - } + discardErrResp := i < r.maxRetries && request.Context().Err() == nil + capturedResp := newResponseWriterDelegator(writer, r.isRetryableStatusCode, discardErrResp) // call next handler in chain - req, err := http.NewRequestWithContext(request.Context(), request.Method, request.URL.String(), lazyBody) - if err != nil { - log.Printf("clone request: %v", err) - writer.WriteHeader(http.StatusInternalServerError) - return - } - r.nextHandler.ServeHTTP(capturedResp, req) - lazyBody.Capture() - if !capturedResp.discardErrResp || // max retries reached - !r.isRetryableStatusCode(capturedResp.statusCode) { + reqClone := request.Clone(request.Context()) // also copies the reference to the lazy body capturer + r.nextHandler.ServeHTTP(capturedResp, reqClone) + + if !capturedResp.discardedResponse() || // max retries reached or context error + !r.isRetryableStatusCode(capturedResp.capturedStatusCode()) { break } + // setup for retry + lazyBody.Capture() + totalRetries.Inc() - // Exponential backoff + // exponential backoff jitter := time.Duration(r.rSource.Intn(50)) time.Sleep(time.Millisecond*time.Duration(1< Date: Wed, 17 Jan 2024 12:21:43 +0100 Subject: [PATCH 10/12] Better status code capturing --- pkg/proxy/handler.go | 2 +- pkg/proxy/metrics.go | 29 +++++++++++++++++++---- pkg/proxy/metrics_test.go | 45 ++++++++++++++++++++++++++++++++++++ pkg/proxy/middleware.go | 11 +++++---- pkg/proxy/middleware_test.go | 4 ++-- 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index c2419e60..e224f4a9 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -42,7 +42,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w) w = captureStatusRespWriter timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v) + httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.CapturedStatusCode())).Observe(v) })) defer timer.ObserveDuration() diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go index 48d204c8..a9ef3b32 100644 --- a/pkg/proxy/metrics.go +++ b/pkg/proxy/metrics.go @@ -22,15 +22,32 @@ func MustRegister(r prometheus.Registerer) { r.MustRegister(httpDuration, totalRetries) } -// captureStatusResponseWriter is a custom HTTP response writer that captures the status code. +// statusCodeCapturer is an interface that extends the http.ResponseWriter interface and provides a method for reading the status code of an HTTP response. +type statusCodeCapturer interface { + http.ResponseWriter + CapturedStatusCode() int +} + +// captureStatusResponseWriter is a custom HTTP response writer that implements statusCodeCapturer type captureStatusResponseWriter struct { http.ResponseWriter statusCode int wroteHeader bool } -func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter { - return &captureStatusResponseWriter{ResponseWriter: responseWriter} +func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) statusCodeCapturer { + if o, ok := responseWriter.(statusCodeCapturer); ok { // nothing to do as code is captured already + return o + } + c := &captureStatusResponseWriter{ResponseWriter: responseWriter} + if _, ok := responseWriter.(io.ReaderFrom); ok { + return &captureStatusResponseWriterWithReadFrom{captureStatusResponseWriter: c} + } + return c +} + +func (c *captureStatusResponseWriter) CapturedStatusCode() int { + return c.statusCode } func (c *captureStatusResponseWriter) WriteHeader(code int) { @@ -46,7 +63,11 @@ func (c *captureStatusResponseWriter) Write(b []byte) (int, error) { return c.ResponseWriter.Write(b) } -func (c *captureStatusResponseWriter) ReadFrom(re io.Reader) (int64, error) { +type captureStatusResponseWriterWithReadFrom struct { + *captureStatusResponseWriter +} + +func (c *captureStatusResponseWriterWithReadFrom) ReadFrom(re io.Reader) (int64, error) { if !c.wroteHeader { c.WriteHeader(http.StatusOK) } diff --git a/pkg/proxy/metrics_test.go b/pkg/proxy/metrics_test.go index cf0d5c3e..b43b1606 100644 --- a/pkg/proxy/metrics_test.go +++ b/pkg/proxy/metrics_test.go @@ -1,6 +1,7 @@ package proxy import ( + "io" "net/http" "net/http/httptest" "strings" @@ -83,6 +84,50 @@ func TestMetricsViaLinter(t *testing.T) { require.Empty(t, problems) } +func TestCaptureStatusCodeResponseWriters(t *testing.T) { + specs := map[string]struct { + rspWriter http.ResponseWriter + expType any + write func(t *testing.T, r http.ResponseWriter, content string) + }{ + "implements statusCodeCapturer": { + rspWriter: &responseWriterDelegator{headerBuf: make(http.Header), ResponseWriter: httptest.NewRecorder()}, + expType: &responseWriterDelegator{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + r.WriteHeader(200) + }, + }, + "implements io.ReaderFrom": { + rspWriter: &testResponseWriter{ResponseRecorder: httptest.NewRecorder()}, + expType: &captureStatusResponseWriterWithReadFrom{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + n, err := r.(io.ReaderFrom).ReadFrom(strings.NewReader(content)) + require.NoError(t, err) + assert.Equal(t, len(content), int(n)) + }, + }, + "default": { + rspWriter: httptest.NewRecorder(), + expType: &captureStatusResponseWriter{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + n, err := r.Write([]byte(content)) + require.NoError(t, err) + assert.Equal(t, len(content), n) + }, + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + instance := newCaptureStatusCodeResponseWriter(spec.rspWriter) + require.IsType(t, spec.expType, instance) + spec.write(t, instance, "foo") + gotCode := instance.CapturedStatusCode() + assert.Equal(t, http.StatusOK, gotCode) + }) + } +} + func toMap(s []*io_prometheus_client.LabelPair) map[string]string { r := make(map[string]string, len(s)) for _, v := range s { diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index 6964a84a..2fb1ae68 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -46,7 +46,7 @@ func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes type xResponseWriter interface { http.ResponseWriter discardedResponse() bool - capturedStatusCode() int + CapturedStatusCode() int } type xBodyCapturer interface { io.ReadCloser @@ -64,7 +64,7 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req r.nextHandler.ServeHTTP(capturedResp, reqClone) if !capturedResp.discardedResponse() || // max retries reached or context error - !r.isRetryableStatusCode(capturedResp.capturedStatusCode()) { + !r.isRetryableStatusCode(capturedResp.CapturedStatusCode()) { break } // setup for retry @@ -83,8 +83,9 @@ func (r RetryMiddleware) isRetryableStatusCode(status int) bool { } var ( - _ http.Flusher = &responseWriterDelegator{} - _ io.ReaderFrom = &xResponseWriterDelegator{} + _ http.Flusher = &responseWriterDelegator{} + _ io.ReaderFrom = &xResponseWriterDelegator{} + _ statusCodeCapturer = &responseWriterDelegator{} ) // responseWriterDelegator represents a wrapper around http.ResponseWriter that provides additional @@ -118,7 +119,7 @@ func (r *responseWriterDelegator) discardedResponse() bool { return r.discardErrResp } -func (r *responseWriterDelegator) capturedStatusCode() int { +func (r *responseWriterDelegator) CapturedStatusCode() int { return r.statusCode } diff --git a/pkg/proxy/middleware_test.go b/pkg/proxy/middleware_test.go index 452fee12..18bc88d7 100644 --- a/pkg/proxy/middleware_test.go +++ b/pkg/proxy/middleware_test.go @@ -130,7 +130,7 @@ func TestWriteDelegatorReadFrom(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(myTestContent), int(n)) assert.Equal(t, myTestContent, rec.Body.String()) - assert.Equal(t, http.StatusOK, d.capturedStatusCode()) + assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) // scenario: discard on error enabled rec = &testResponseWriter{ResponseRecorder: httptest.NewRecorder()} @@ -142,7 +142,7 @@ func TestWriteDelegatorReadFrom(t *testing.T) { require.NoError(t, err) assert.Equal(t, len(myTestContent), int(n)) assert.Equal(t, "", rec.Body.String()) - assert.Equal(t, http.StatusOK, d.capturedStatusCode()) + assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) // scenario: not implementing io.ReaderFrom d = newResponseWriterDelegator(httptest.NewRecorder(), func(int) bool { return true }, false) From 9011105892f91f755d56b58f86820861dd44d13d Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Thu, 18 Jan 2024 09:57:31 +0100 Subject: [PATCH 11/12] Review feedback --- pkg/endpoints/endpoints.go | 8 +++----- pkg/proxy/handler.go | 1 + pkg/proxy/middleware.go | 4 +--- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pkg/endpoints/endpoints.go b/pkg/endpoints/endpoints.go index c3910e3b..b9e03f6c 100644 --- a/pkg/endpoints/endpoints.go +++ b/pkg/endpoints/endpoints.go @@ -18,8 +18,7 @@ func newEndpointGroup() *endpointGroup { } type endpoint struct { - inFlight *atomic.Int64 - terminated chan struct{} + inFlight *atomic.Int64 } type endpointGroup struct { @@ -107,13 +106,12 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32) g.ports = ports for ip := range ips { if _, ok := g.endpoints[ip]; !ok { - g.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}, terminated: make(chan struct{})} + g.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}} } } - for ip, endpoint := range g.endpoints { + for ip := range g.endpoints { if _, ok := ips[ip]; !ok { delete(g.endpoints, ip) - close(endpoint.terminated) } } g.mtx.Unlock() diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index e224f4a9..be6741d8 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -131,6 +131,7 @@ func parseModel(r *http.Request) (string, *http.Request, error) { } var body []byte if mb, ok := r.Body.(*lazyBodyCapturer); ok && mb.capturedBody != nil { + // reuse buffer body = mb.capturedBody } else { // parse request body for model name, ignore errors diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index 2fb1ae68..ad41cc51 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -14,7 +14,6 @@ var _ http.Handler = &RetryMiddleware{} type RetryMiddleware struct { nextHandler http.Handler maxRetries int - rSource *rand.Rand retryStatusCodes map[int]struct{} } @@ -39,7 +38,6 @@ func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes nextHandler: other, maxRetries: maxRetries, retryStatusCodes: statusCodeIndex, - rSource: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -72,7 +70,7 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req totalRetries.Inc() // exponential backoff - jitter := time.Duration(r.rSource.Intn(50)) + jitter := time.Duration(rand.Intn(50)) time.Sleep(time.Millisecond*time.Duration(1< Date: Mon, 22 Jan 2024 10:52:55 +0100 Subject: [PATCH 12/12] Simplify middleware --- pkg/proxy/handler.go | 6 +- pkg/proxy/metrics.go | 12 +-- pkg/proxy/metrics_test.go | 10 ++- pkg/proxy/middleware.go | 167 ++++++++++++++++++++--------------- pkg/proxy/middleware_test.go | 15 ++-- 5 files changed, 121 insertions(+), 89 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index be6741d8..4acfb57e 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -39,7 +39,7 @@ func NewHandler(deployments deploymentSource, endpoints *endpoints.Manager, queu func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var modelName string - captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w) + captureStatusRespWriter := NewCaptureStatusCodeResponseWriter(w) w = captureStatusRespWriter timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.CapturedStatusCode())).Observe(v) @@ -130,9 +130,9 @@ func parseModel(r *http.Request) (string, *http.Request, error) { return model, r, nil } var body []byte - if mb, ok := r.Body.(*lazyBodyCapturer); ok && mb.capturedBody != nil { + if mb, ok := r.Body.(CapturedBodySource); ok && mb.IsCaptured() { // reuse buffer - body = mb.capturedBody + body = mb.GetBody() } else { // parse request body for model name, ignore errors var err error diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go index a9ef3b32..94a75031 100644 --- a/pkg/proxy/metrics.go +++ b/pkg/proxy/metrics.go @@ -22,21 +22,21 @@ func MustRegister(r prometheus.Registerer) { r.MustRegister(httpDuration, totalRetries) } -// statusCodeCapturer is an interface that extends the http.ResponseWriter interface and provides a method for reading the status code of an HTTP response. -type statusCodeCapturer interface { +// CaptureStatusCodeResponseWriter is an interface that extends the http.ResponseWriter interface and provides a method for reading the status code of an HTTP response. +type CaptureStatusCodeResponseWriter interface { http.ResponseWriter - CapturedStatusCode() int + StatusCodeCapturer } -// captureStatusResponseWriter is a custom HTTP response writer that implements statusCodeCapturer +// captureStatusResponseWriter is a custom HTTP response writer that implements CaptureStatusCodeResponseWriter type captureStatusResponseWriter struct { http.ResponseWriter statusCode int wroteHeader bool } -func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) statusCodeCapturer { - if o, ok := responseWriter.(statusCodeCapturer); ok { // nothing to do as code is captured already +func NewCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) CaptureStatusCodeResponseWriter { + if o, ok := responseWriter.(CaptureStatusCodeResponseWriter); ok { // nothing to do as code is captured already return o } c := &captureStatusResponseWriter{ResponseWriter: responseWriter} diff --git a/pkg/proxy/metrics_test.go b/pkg/proxy/metrics_test.go index b43b1606..a3c1a7ea 100644 --- a/pkg/proxy/metrics_test.go +++ b/pkg/proxy/metrics_test.go @@ -91,8 +91,12 @@ func TestCaptureStatusCodeResponseWriters(t *testing.T) { write func(t *testing.T, r http.ResponseWriter, content string) }{ "implements statusCodeCapturer": { - rspWriter: &responseWriterDelegator{headerBuf: make(http.Header), ResponseWriter: httptest.NewRecorder()}, - expType: &responseWriterDelegator{}, + rspWriter: &discardableResponseWriter{ + headerBuf: make(http.Header), + ResponseWriter: httptest.NewRecorder(), + isDiscardable: func(status int) bool { return false }, + }, + expType: &discardableResponseWriter{}, write: func(t *testing.T, r http.ResponseWriter, content string) { r.WriteHeader(200) }, @@ -119,7 +123,7 @@ func TestCaptureStatusCodeResponseWriters(t *testing.T) { for name, spec := range specs { t.Run(name, func(t *testing.T) { - instance := newCaptureStatusCodeResponseWriter(spec.rspWriter) + instance := NewCaptureStatusCodeResponseWriter(spec.rspWriter) require.IsType(t, spec.expType, instance) spec.write(t, instance, "foo") gotCode := instance.CapturedStatusCode() diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go index ad41cc51..50dc6bc6 100644 --- a/pkg/proxy/middleware.go +++ b/pkg/proxy/middleware.go @@ -41,28 +41,25 @@ func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes } } -type xResponseWriter interface { - http.ResponseWriter - discardedResponse() bool - CapturedStatusCode() int -} -type xBodyCapturer interface { - io.ReadCloser - Capture() -} - +// ServeHTTP handles the HTTP request by capturing the request body, calling the next handler in the chain, and retrying if necessary. +// It captures the request body using a LazyBodyCapturer, and sets a captured response writer using NewDiscardableResponseWriter. +// It retries the request if the response was discarded and the response status code is retryable. +// It uses exponential backoff for retries with a random jitter. +// The maximum number of retries is determined by the maxRetries field of RetryMiddleware. func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - lazyBody := newLazyBodyCapturer(request.Body) + lazyBody := NewLazyBodyCapturer(request.Body) request.Body = lazyBody for i := 0; ; i++ { - discardErrResp := i < r.maxRetries && request.Context().Err() == nil - capturedResp := newResponseWriterDelegator(writer, r.isRetryableStatusCode, discardErrResp) + withoutRetry := i == r.maxRetries || request.Context().Err() != nil + if withoutRetry { + r.nextHandler.ServeHTTP(writer, request) + return + } + capturedResp := NewDiscardableResponseWriter(writer, r.isRetryableStatusCode) // call next handler in chain - reqClone := request.Clone(request.Context()) // also copies the reference to the lazy body capturer - r.nextHandler.ServeHTTP(capturedResp, reqClone) + r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context())) // clone also copies the reference to the lazy body capturer - if !capturedResp.discardedResponse() || // max retries reached or context error - !r.isRetryableStatusCode(capturedResp.CapturedStatusCode()) { + if !r.isRetryableStatusCode(capturedResp.CapturedStatusCode()) { break } // setup for retry @@ -76,65 +73,89 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req } func (r RetryMiddleware) isRetryableStatusCode(status int) bool { + // any 1xx informational response should be written + if status >= 100 && status < 200 { + return false + } _, ok := r.retryStatusCodes[status] return ok } +type ( + StatusCodeCapturer interface { + CapturedStatusCode() int + } + XResponseWriter interface { + http.ResponseWriter + StatusCodeCapturer + } + + // CapturedBodySource represents an interface for capturing and retrieving the body of an HTTP request. + CapturedBodySource interface { + IsCaptured() bool + GetBody() []byte + } + + XBodyCapturer interface { + io.ReadCloser + CapturedBodySource + Capture() + } +) + var ( - _ http.Flusher = &responseWriterDelegator{} - _ io.ReaderFrom = &xResponseWriterDelegator{} - _ statusCodeCapturer = &responseWriterDelegator{} + _ http.Flusher = &discardableResponseWriter{} + _ io.ReaderFrom = &discardableResponseWriterWithReaderFrom{} + _ CaptureStatusCodeResponseWriter = &discardableResponseWriter{} ) -// responseWriterDelegator represents a wrapper around http.ResponseWriter that provides additional -// functionalities for handling response writing. Depending on the status code and discard settings, -// the header + content on write is skipped so that it can be re-used on retry. -type responseWriterDelegator struct { +// discardableResponseWriter represents a wrapper around http.ResponseWriter that provides additional +// functionalities for handling response writing. Depending on the status code, +// the header + content on write is skipped so that it can be re-used on retry or written to the underlying +// response writer. +type discardableResponseWriter struct { http.ResponseWriter - headerBuf http.Header - wroteHeader bool - statusCode int - // always writes to responseWriter when false - discardErrResp bool - isRetryableStatusCode func(status int) bool -} - -// newResponseWriterDelegator constructor -func newResponseWriterDelegator(writer http.ResponseWriter, isRetryableStatusCode func(status int) bool, discardErrResp bool) xResponseWriter { - d := &responseWriterDelegator{ - isRetryableStatusCode: isRetryableStatusCode, - ResponseWriter: writer, - headerBuf: make(http.Header), - discardErrResp: discardErrResp, // abort early on timeout, context cancel + headerBuf http.Header + wroteHeader bool + statusCode int + immediateWrite bool + isDiscardable func(status int) bool +} + +// NewDiscardableResponseWriter creates a new instance of the response writer delegator. +// It takes a http.ResponseWriter and a function to determine by the status code if content should be written or discarded (for retry). +func NewDiscardableResponseWriter(writer http.ResponseWriter, isDiscardable func(status int) bool) XResponseWriter { + d := &discardableResponseWriter{ + isDiscardable: isDiscardable, + ResponseWriter: writer, + headerBuf: make(http.Header), + immediateWrite: false, } if _, ok := writer.(io.ReaderFrom); ok { - return &xResponseWriterDelegator{responseWriterDelegator: d} + return &discardableResponseWriterWithReaderFrom{discardableResponseWriter: d} } return d } -func (r *responseWriterDelegator) discardedResponse() bool { - return r.discardErrResp -} - -func (r *responseWriterDelegator) CapturedStatusCode() int { +func (r *discardableResponseWriter) CapturedStatusCode() int { return r.statusCode } -func (r *responseWriterDelegator) Header() http.Header { +func (r *discardableResponseWriter) Header() http.Header { return r.headerBuf } -func (r *responseWriterDelegator) WriteHeader(status int) { +// WriteHeader sets the response status code and writes the response header to the underlying http.ResponseWriter or +// discards it based on the result of the isDiscardable call. +func (r *discardableResponseWriter) WriteHeader(status int) { r.statusCode = status if !r.wroteHeader { r.wroteHeader = true - // any 1xx informational response should be written - r.discardErrResp = r.discardErrResp && !(status >= 100 && status < 200) } - if r.discardErrResp && r.isRetryableStatusCode(status) { + if r.isDiscardable(status) { return } + r.immediateWrite = true // copy header values to target for k, vals := range r.headerBuf { for _, val := range vals { @@ -146,46 +167,42 @@ func (r *responseWriterDelegator) WriteHeader(status int) { // Write writes data to the response. // If the response header has not been set, it sets the default status code to 200. -// When the status code qualifies for a retry, no content is written. +// Based on the status code, the content is either written or discarded. // // It returns the number of bytes written and any error encountered. -func (r *responseWriterDelegator) Write(data []byte) (int, error) { +func (r *discardableResponseWriter) Write(data []byte) (int, error) { // ensure header is set. default is 200 in Go stdlib if !r.wroteHeader { r.WriteHeader(http.StatusOK) } - if r.discardErrResp && r.isRetryableStatusCode(r.statusCode) { - return io.Discard.Write(data) - } else { + if r.immediateWrite { return r.ResponseWriter.Write(data) } + return io.Discard.Write(data) } -func (r *responseWriterDelegator) Flush() { +func (r *discardableResponseWriter) Flush() { if f, ok := r.ResponseWriter.(http.Flusher); ok { f.Flush() } } -// xResponseWriterDelegator provides the same functionalities as responseWriterDelegator but also implements the +// discardableResponseWriterWithReaderFrom provides the same functionalities as discardableResponseWriter but also implements the // io.ReaderFrom interface. -// The ReadFrom method ensures that the header is set before reading from the reader. -// In case discardErrResp is true and the response status code is retryable, the content is discarded. -// Otherwise, it calls the ReadFrom method of the underlying response writer and returns the result. -type xResponseWriterDelegator struct { - *responseWriterDelegator +// Based on the status code, the content is either written or discarded. +type discardableResponseWriterWithReaderFrom struct { + *discardableResponseWriter } -func (r *xResponseWriterDelegator) ReadFrom(re io.Reader) (int64, error) { +func (r *discardableResponseWriterWithReaderFrom) ReadFrom(re io.Reader) (int64, error) { // ensure header is set. default is 200 in Go stdlib if !r.wroteHeader { r.WriteHeader(http.StatusOK) } - if r.discardErrResp && r.isRetryableStatusCode(r.statusCode) { - return io.Copy(io.Discard, re) - } else { + if r.immediateWrite { return r.ResponseWriter.(io.ReaderFrom).ReadFrom(re) } + return io.Copy(io.Discard, re) } var ( @@ -203,8 +220,8 @@ type lazyBodyCapturer struct { allRead bool } -// newLazyBodyCapturer constructor -func newLazyBodyCapturer(body io.ReadCloser) xBodyCapturer { +// NewLazyBodyCapturer constructor. +func NewLazyBodyCapturer(body io.ReadCloser) XBodyCapturer { c := &lazyBodyCapturer{ reader: body, buf: bytes.NewBuffer([]byte{}), @@ -230,15 +247,27 @@ func (m *lazyBodyCapturer) Close() error { return m.reader.Close() } +// Capture marks the body as fully captured. +// The captured body data can be read via GetBody. func (m *lazyBodyCapturer) Capture() { m.allRead = true - if m.buf != nil { + if !m.IsCaptured() { m.capturedBody = m.buf.Bytes() m.buf = nil } m.reader = io.NopCloser(bytes.NewReader(m.capturedBody)) } +// IsCaptured returns true when a body was captured. +func (m *lazyBodyCapturer) IsCaptured() bool { + return m.capturedBody != nil +} + +// GetBody returns the captured byte slice. Value is nil when not captured, yet. +func (m *lazyBodyCapturer) GetBody() []byte { + return m.capturedBody +} + type lazyBodyCapturerWriteTo struct { *lazyBodyCapturer } diff --git a/pkg/proxy/middleware_test.go b/pkg/proxy/middleware_test.go index 18bc88d7..9c2d7bda 100644 --- a/pkg/proxy/middleware_test.go +++ b/pkg/proxy/middleware_test.go @@ -121,9 +121,9 @@ func TestServeHTTP(t *testing.T) { func TestWriteDelegatorReadFrom(t *testing.T) { const myTestContent = `my body content` + // scenario: with non retry status code rec := &testResponseWriter{ResponseRecorder: httptest.NewRecorder()} - // scenario: discard on error disabled - d := newResponseWriterDelegator(rec, func(int) bool { return true }, false) + d := NewDiscardableResponseWriter(rec, func(int) bool { return false }) // when n, err := d.(io.ReaderFrom).ReadFrom(strings.NewReader(myTestContent)) // then the content is written @@ -132,10 +132,9 @@ func TestWriteDelegatorReadFrom(t *testing.T) { assert.Equal(t, myTestContent, rec.Body.String()) assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) - // scenario: discard on error enabled + // scenario: with a retry status code rec = &testResponseWriter{ResponseRecorder: httptest.NewRecorder()} - // with discard on error disabled - d = newResponseWriterDelegator(rec, func(int) bool { return true }, true) + d = NewDiscardableResponseWriter(rec, func(int) bool { return true }) // when n, err = d.(io.ReaderFrom).ReadFrom(strings.NewReader(myTestContent)) // then the content is not written @@ -145,14 +144,14 @@ func TestWriteDelegatorReadFrom(t *testing.T) { assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) // scenario: not implementing io.ReaderFrom - d = newResponseWriterDelegator(httptest.NewRecorder(), func(int) bool { return true }, false) + d = NewDiscardableResponseWriter(httptest.NewRecorder(), func(int) bool { return true }) _, ok := d.(io.ReaderFrom) require.False(t, ok) } func TestLazyBodyCapturer(t *testing.T) { const myTestContent = "my-test-content" - c := newLazyBodyCapturer(io.NopCloser(strings.NewReader(myTestContent))) + c := NewLazyBodyCapturer(io.NopCloser(strings.NewReader(myTestContent))) var buf bytes.Buffer n, err := c.(io.WriterTo).WriteTo(&buf) require.NoError(t, err) @@ -168,7 +167,7 @@ func TestLazyBodyCapturer(t *testing.T) { assert.Equal(t, myTestContent, buf.String()) // scenario: source reader does not implement WriteTo - c = newLazyBodyCapturer(testReader{strings.NewReader(myTestContent)}) + c = NewLazyBodyCapturer(testReader{strings.NewReader(myTestContent)}) // then instance also does not implement it _, ok := c.(io.WriterTo) require.False(t, ok)