diff --git a/cmd/lingo/main.go b/cmd/lingo/main.go index 7bfec8f0..b561b0d9 100644 --- a/cmd/lingo/main.go +++ b/cmd/lingo/main.go @@ -71,11 +71,13 @@ func run() error { var metricsAddr string var probeAddr string var concurrencyPerReplica int + var maxRetriesOnErr int flag.StringVar(&metricsAddr, "metrics-bind-address", ":8082", "The address the metric endpoint binds to.") flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.") flag.IntVar(&concurrencyPerReplica, "concurrency", concurrency, "the number of simultaneous requests that can be processed by each replica") flag.IntVar(&scaleDownDelay, "scale-down-delay", scaleDownDelay, "seconds to wait before scaling down") + flag.IntVar(&maxRetriesOnErr, "max-retries", 0, "max number of retries on a http error code: 502,503,504") opts := zap.Options{ Development: true, } @@ -113,17 +115,17 @@ func run() error { if err != nil { return fmt.Errorf("getting hostname: %v", err) } + le := leader.NewElection(clientset, hostname, namespace) queueManager := queue.NewManager(concurrencyPerReplica) metricsRegistry := prometheus.WrapRegistererWithPrefix("lingo_", metrics.Registry) queue.NewMetricsCollector(queueManager).MustRegister(metricsRegistry) - endpointManager, err := endpoints.NewManager(mgr) + endpointManager, err := endpoints.NewManager(mgr, queueManager.UpdateQueueSizeForReplicas) if err != nil { return fmt.Errorf("setting up endpoint manager: %w", err) } - endpointManager.EndpointSizeCallback = queueManager.UpdateQueueSizeForReplicas // The autoscaling leader will scrape other lingo instances. // Exclude this instance from being scraped by itself. 
endpointManager.ExcludePods[hostname] = struct{}{} @@ -153,7 +155,7 @@ func run() error { go autoscaler.Start() proxy.MustRegister(metricsRegistry) - proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager) + var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager, maxRetriesOnErr) proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler} statsHandler := &stats.Handler{ diff --git a/pkg/endpoints/endpoints.go b/pkg/endpoints/endpoints.go index 0e51083a..b9e03f6c 100644 --- a/pkg/endpoints/endpoints.go +++ b/pkg/endpoints/endpoints.go @@ -2,7 +2,9 @@ package endpoints import ( "context" + "errors" "fmt" + "strings" "sync" "sync/atomic" ) @@ -127,3 +129,20 @@ func (g *endpointGroup) broadcastEndpoints() { close(g.bcast) g.bcast = make(chan struct{}) } + +func (e *endpointGroup) AddInflight(addr string) (func(), error) { + tokens := strings.Split(addr, ":") + if len(tokens) != 2 { + return nil, errors.New("unsupported address format") + } + e.mtx.RLock() + defer e.mtx.RUnlock() + endpoint, ok := e.endpoints[tokens[0]] + if !ok { + return nil, errors.New("unsupported endpoint address") + } + endpoint.inFlight.Add(1) + return func() { + endpoint.inFlight.Add(-1) + }, nil +} diff --git a/pkg/endpoints/manager.go b/pkg/endpoints/manager.go index bd2d3e39..414f888b 100644 --- a/pkg/endpoints/manager.go +++ b/pkg/endpoints/manager.go @@ -12,11 +12,12 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -func NewManager(mgr ctrl.Manager) (*Manager, error) { +func NewManager(mgr ctrl.Manager, endpointSizeCallback func(deploymentName string, replicas int)) (*Manager, error) { r := &Manager{} r.Client = mgr.GetClient() r.endpoints = map[string]*endpointGroup{} r.ExcludePods = map[string]struct{}{} + r.EndpointSizeCallback = endpointSizeCallback if err := r.SetupWithManager(mgr); err != nil { return nil, err } @@ -86,6 +87,11 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) 
(ctrl.Result, } } + r.SetEndpoints(serviceName, ips, ports) + return ctrl.Result{}, nil +} + +func (r *Manager) SetEndpoints(serviceName string, ips map[string]struct{}, ports map[string]int32) { priorLen := r.getEndpoints(serviceName).lenIPs() r.getEndpoints(serviceName).setIPs(ips, ports) @@ -95,8 +101,6 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, // replicas by something else. r.EndpointSizeCallback(serviceName, len(ips)) } - - return ctrl.Result{}, nil } func (r *Manager) getEndpoints(service string) *endpointGroup { @@ -122,3 +126,7 @@ func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string func (r *Manager) GetAllHosts(service, portName string) []string { return r.getEndpoints(service).getAllHosts(portName) } + +func (r *Manager) RegisterInFlight(service string, hostAddr string) (func(), error) { + return r.getEndpoints(service).AddInflight(hostAddr) +} diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index a51d9d3b..cec4b1f6 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -16,29 +16,39 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" - "github.com/substratusai/lingo/pkg/deployments" "github.com/substratusai/lingo/pkg/endpoints" "github.com/substratusai/lingo/pkg/queue" ) +type deploymentSource interface { + ResolveDeployment(model string) (string, bool) + AtLeastOne(deploymentName string) +} + // Handler serves http requests for end-clients. // It is also responsible for triggering scale-from-zero. 
type Handler struct { - Deployments *deployments.Manager - Endpoints *endpoints.Manager - Queues *queue.Manager + Deployments deploymentSource + Endpoints *endpoints.Manager + Queues *queue.Manager + retriesOnErr int } -func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler { - return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues} +func NewHandler(deployments deploymentSource, endpoints *endpoints.Manager, queues *queue.Manager, retriesOnErr int) *Handler { + return &Handler{ + Deployments: deployments, + Endpoints: endpoints, + Queues: queues, + retriesOnErr: retriesOnErr, + } } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var modelName string - captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w) + captureStatusRespWriter := NewCaptureStatusCodeResponseWriter(w) w = captureStatusRespWriter timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v) + httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.CapturedStatusCode())).Observe(v) })) defer timer.ObserveDuration() @@ -105,10 +115,19 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } log.Printf("Got host: %v, id: %v\n", host, id) + done, err := h.Endpoints.RegisterInFlight(deploy, host) + if err != nil { + log.Printf("error registering in-flight request: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer done() + + log.Printf("Proxying request to host %v: %v\n", host, id) // TODO: Avoid creating new reverse proxies for each request. // TODO: Consider implementing a round robin scheme. 
- log.Printf("Proxying request to host %v: %v\n", host, id) - newReverseProxy(host).ServeHTTP(w, proxyRequest) + proxy := newReverseProxy(host) + NewRetryMiddleware(h.retriesOnErr, proxy).ServeHTTP(w, proxyRequest) } // parseModel parses the model name from the request @@ -117,10 +136,17 @@ func parseModel(r *http.Request) (string, *http.Request, error) { if model := r.Header.Get("X-Model"); model != "" { return model, r, nil } - // parse request body for model name, ignore errors - body, err := io.ReadAll(r.Body) - if err != nil { - return "", r, nil + var body []byte + if mb, ok := r.Body.(CapturedBodySource); ok && mb.IsCaptured() { + // reuse buffer + body = mb.GetBody() + } else { + // parse request body for model name, ignore errors + var err error + body, err = io.ReadAll(r.Body) + if err != nil { + return "", r, nil + } } var payload struct { diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go new file mode 100644 index 00000000..0442bc96 --- /dev/null +++ b/pkg/proxy/handler_test.go @@ -0,0 +1,75 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substratusai/lingo/pkg/endpoints" + "github.com/substratusai/lingo/pkg/queue" +) + +func TestProxy(t *testing.T) { + specs := map[string]struct { + request *http.Request + expCode int + }{ + "ok": { + request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)), + expCode: http.StatusOK, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + deplMgr := mockDeploymentSource{ + ResolveDeploymentFn: func(model string) (string, bool) { + return "my-deployment", true + }, + AtLeastOneFn: func(deploymentName string) {}, + } + em, err := endpoints.NewManager(&fakeManager{}, func(deploymentName string, replicas int) {}) + require.NoError(t, err) + em.SetEndpoints("my-deployment", 
map[string]struct{}{"my-ip": {}}, map[string]int32{"my-port": 8080}) + h := NewHandler(deplMgr, em, queue.NewManager(10), 1) + + svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + em.SetEndpoints("my-deployment", map[string]struct{}{"my-other-ip": {}}, map[string]int32{"my-other-port": 8080}) + w.WriteHeader(http.StatusOK) + })) + recorder := httptest.NewRecorder() + + AdditionalProxyRewrite = func(r *httputil.ProxyRequest) { + r.SetURL(&url.URL{Scheme: "http", Host: svr.Listener.Addr().String()}) + } + + // when + h.ServeHTTP(recorder, spec.request) + // then + assert.Equal(t, spec.expCode, recorder.Code) + }) + } +} + +type mockDeploymentSource struct { + ResolveDeploymentFn func(model string) (string, bool) + AtLeastOneFn func(deploymentName string) +} + +func (m mockDeploymentSource) ResolveDeployment(model string) (string, bool) { + if m.ResolveDeploymentFn == nil { + panic("not expected to be called") + } + return m.ResolveDeploymentFn(model) +} + +func (m mockDeploymentSource) AtLeastOne(deploymentName string) { + if m.AtLeastOneFn == nil { + panic("not expected to be called") + } + m.AtLeastOneFn(deploymentName) +} diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go index 5cd229e7..94a75031 100644 --- a/pkg/proxy/metrics.go +++ b/pkg/proxy/metrics.go @@ -1,6 +1,7 @@ package proxy import ( + "io" "net/http" "github.com/prometheus/client_golang/prometheus" @@ -12,21 +13,63 @@ var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Buckets: prometheus.DefBuckets, }, []string{"model", "status_code"}) +var totalRetries = prometheus.NewCounter(prometheus.CounterOpts{ + Name: "http_request_retry_total", + Help: "Number of HTTP request retries.", +}) + func MustRegister(r prometheus.Registerer) { - r.MustRegister(httpDuration) + r.MustRegister(httpDuration, totalRetries) +} + +// CaptureStatusCodeResponseWriter is an interface that extends the http.ResponseWriter interface and provides a method for 
reading the status code of an HTTP response. +type CaptureStatusCodeResponseWriter interface { + http.ResponseWriter + StatusCodeCapturer } -// captureStatusResponseWriter is a custom HTTP response writer that captures the status code. +// captureStatusResponseWriter is a custom HTTP response writer that implements CaptureStatusCodeResponseWriter type captureStatusResponseWriter struct { http.ResponseWriter - statusCode int + statusCode int + wroteHeader bool +} + +func NewCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) CaptureStatusCodeResponseWriter { + if o, ok := responseWriter.(CaptureStatusCodeResponseWriter); ok { // nothing to do as code is captured already + return o + } + c := &captureStatusResponseWriter{ResponseWriter: responseWriter} + if _, ok := responseWriter.(io.ReaderFrom); ok { + return &captureStatusResponseWriterWithReadFrom{captureStatusResponseWriter: c} + } + return c +} + +func (c *captureStatusResponseWriter) CapturedStatusCode() int { + return c.statusCode +} + +func (c *captureStatusResponseWriter) WriteHeader(code int) { + c.wroteHeader = true + c.statusCode = code + c.ResponseWriter.WriteHeader(code) +} + +func (c *captureStatusResponseWriter) Write(b []byte) (int, error) { + if !c.wroteHeader { + c.WriteHeader(http.StatusOK) + } + return c.ResponseWriter.Write(b) } -func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter { - return &captureStatusResponseWriter{ResponseWriter: responseWriter} +type captureStatusResponseWriterWithReadFrom struct { + *captureStatusResponseWriter } -func (srw *captureStatusResponseWriter) WriteHeader(code int) { - srw.statusCode = code - srw.ResponseWriter.WriteHeader(code) +func (c *captureStatusResponseWriterWithReadFrom) ReadFrom(re io.Reader) (int64, error) { + if !c.wroteHeader { + c.WriteHeader(http.StatusOK) + } + return c.ResponseWriter.(io.ReaderFrom).ReadFrom(re) } diff --git a/pkg/proxy/metrics_test.go 
b/pkg/proxy/metrics_test.go index 5940e445..79bc68ff 100644 --- a/pkg/proxy/metrics_test.go +++ b/pkg/proxy/metrics_test.go @@ -1,6 +1,7 @@ package proxy import ( + "io" "net/http" "net/http/httptest" "strings" @@ -55,7 +56,7 @@ func TestMetrics(t *testing.T) { deplMgr, err := deployments.NewManager(&fakeManager{}) require.NoError(t, err) - h := NewHandler(deplMgr, nil, nil) + h := NewHandler(deplMgr, nil, nil, 2) recorder := httptest.NewRecorder() // when @@ -65,10 +66,11 @@ func TestMetrics(t *testing.T) { assert.Equal(t, spec.expCode, recorder.Code) gathered, err := registry.Gather() require.NoError(t, err) - require.Len(t, gathered, 1) - require.Len(t, gathered[0].Metric, 1) - assert.NotEmpty(t, gathered[0].Metric[0].GetHistogram().GetSampleCount()) - assert.Equal(t, spec.expLabels, toMap(gathered[0].Metric[0].Label)) + require.Len(t, gathered, 2) + require.Equal(t, "http_response_time_seconds", *gathered[1].Name) + require.Len(t, gathered[1].Metric, 1) + assert.NotEmpty(t, gathered[1].Metric[0].GetHistogram().GetSampleCount()) + assert.Equal(t, spec.expLabels, toMap(gathered[1].Metric[0].Label)) }) } } @@ -82,6 +84,54 @@ func TestMetricsViaLinter(t *testing.T) { require.Empty(t, problems) } +func TestCaptureStatusCodeResponseWriters(t *testing.T) { + specs := map[string]struct { + rspWriter http.ResponseWriter + expType any + write func(t *testing.T, r http.ResponseWriter, content string) + }{ + "implements statusCodeCapturer": { + rspWriter: &discardableResponseWriter{ + headerBuf: make(http.Header), + ResponseWriter: httptest.NewRecorder(), + isDiscardable: func(status int) bool { return false }, + }, + expType: &discardableResponseWriter{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + r.WriteHeader(200) + }, + }, + "implements io.ReaderFrom": { + rspWriter: &testResponseWriter{ResponseRecorder: httptest.NewRecorder()}, + expType: &captureStatusResponseWriterWithReadFrom{}, + write: func(t *testing.T, r http.ResponseWriter, content 
string) { + n, err := r.(io.ReaderFrom).ReadFrom(strings.NewReader(content)) + require.NoError(t, err) + assert.Equal(t, len(content), int(n)) + }, + }, + "default": { + rspWriter: httptest.NewRecorder(), + expType: &captureStatusResponseWriter{}, + write: func(t *testing.T, r http.ResponseWriter, content string) { + n, err := r.Write([]byte(content)) + require.NoError(t, err) + assert.Equal(t, len(content), n) + }, + }, + } + + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + instance := NewCaptureStatusCodeResponseWriter(spec.rspWriter) + require.IsType(t, spec.expType, instance) + spec.write(t, instance, "foo") + gotCode := instance.CapturedStatusCode() + assert.Equal(t, http.StatusOK, gotCode) + }) + } +} + func toMap(s []*io_prometheus_client.LabelPair) map[string]string { r := make(map[string]string, len(s)) for _, v := range s { diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go new file mode 100644 index 00000000..50dc6bc6 --- /dev/null +++ b/pkg/proxy/middleware.go @@ -0,0 +1,282 @@ +package proxy + +import ( + "bytes" + "errors" + "io" + "math/rand" + "net/http" + "time" +) + +var _ http.Handler = &RetryMiddleware{} + +type RetryMiddleware struct { + nextHandler http.Handler + maxRetries int + retryStatusCodes map[int]struct{} +} + +// NewRetryMiddleware creates a new HTTP middleware that adds retry functionality. +// It takes the maximum number of retries, the next handler in the middleware chain, +// and an optional list of retryable status codes. +// If the maximum number of retries is 0, it returns the next handler without adding any retries. +// If the list of retryable status codes is empty, it uses a default set of status codes (http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout). +// The function creates a RetryMiddleware struct with the given parameters and returns it as an http.Handler. 
+func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes ...int) http.Handler { + if maxRetries == 0 { + return other + } + if len(optRetryStatusCodes) == 0 { + optRetryStatusCodes = []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout} + } + statusCodeIndex := make(map[int]struct{}, len(optRetryStatusCodes)) + for _, c := range optRetryStatusCodes { + statusCodeIndex[c] = struct{}{} + } + return &RetryMiddleware{ + nextHandler: other, + maxRetries: maxRetries, + retryStatusCodes: statusCodeIndex, + } +} + +// ServeHTTP handles the HTTP request by capturing the request body, calling the next handler in the chain, and retrying if necessary. +// It captures the request body using a LazyBodyCapturer, and sets a captured response writer using NewDiscardableResponseWriter. +// It retries the request if the response was discarded and the response status code is retryable. +// It uses exponential backoff for retries with a random jitter. +// The maximum number of retries is determined by the maxRetries field of RetryMiddleware. 
+func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + lazyBody := NewLazyBodyCapturer(request.Body) + request.Body = lazyBody + for i := 0; ; i++ { + withoutRetry := i == r.maxRetries || request.Context().Err() != nil + if withoutRetry { + r.nextHandler.ServeHTTP(writer, request) + return + } + capturedResp := NewDiscardableResponseWriter(writer, r.isRetryableStatusCode) + // call next handler in chain + r.nextHandler.ServeHTTP(capturedResp, request.Clone(request.Context())) // clone also copies the reference to the lazy body capturer + + if !r.isRetryableStatusCode(capturedResp.CapturedStatusCode()) { + break + } + // setup for retry + lazyBody.Capture() + + totalRetries.Inc() + // exponential backoff + jitter := time.Duration(rand.Intn(50)) + time.Sleep(time.Millisecond*time.Duration(1<<uint(i)) + jitter) + } +} + +func (r RetryMiddleware) isRetryableStatusCode(status int) bool { + if status >= 100 && status < 200 { + return false + } + _, ok := r.retryStatusCodes[status] + return ok +} + +type ( + StatusCodeCapturer interface { + CapturedStatusCode() int + } + XResponseWriter interface { + http.ResponseWriter + StatusCodeCapturer + } + + // CapturedBodySource represents an interface for capturing and retrieving the body of an HTTP request. + CapturedBodySource interface { + IsCaptured() bool + GetBody() []byte + } + + XBodyCapturer interface { + io.ReadCloser + CapturedBodySource + Capture() + } +) + +var ( + _ http.Flusher = &discardableResponseWriter{} + _ io.ReaderFrom = &discardableResponseWriterWithReaderFrom{} + _ CaptureStatusCodeResponseWriter = &discardableResponseWriter{} +) + +// discardableResponseWriter represents a wrapper around http.ResponseWriter that provides additional +// functionalities for handling response writing. Depending on the status code, +// the header + content on write is skipped so that it can be re-used on retry or written to the underlying +// response writer. 
+type discardableResponseWriter struct { + http.ResponseWriter + headerBuf http.Header + wroteHeader bool + statusCode int + immediateWrite bool + isDiscardable func(status int) bool +} + +// NewDiscardableResponseWriter creates a new instance of the response writer delegator. +// It takes a http.ResponseWriter and a function to determine by the status code if content should be written or discarded (for retry). +func NewDiscardableResponseWriter(writer http.ResponseWriter, isDiscardable func(status int) bool) XResponseWriter { + d := &discardableResponseWriter{ + isDiscardable: isDiscardable, + ResponseWriter: writer, + headerBuf: make(http.Header), + immediateWrite: false, + } + if _, ok := writer.(io.ReaderFrom); ok { + return &discardableResponseWriterWithReaderFrom{discardableResponseWriter: d} + } + return d +} + +func (r *discardableResponseWriter) CapturedStatusCode() int { + return r.statusCode +} + +func (r *discardableResponseWriter) Header() http.Header { + return r.headerBuf +} + +// WriteHeader sets the response status code and writes the response header to the underlying http.ResponseWriter or +// discards it based on the result of the isDiscardable call. +func (r *discardableResponseWriter) WriteHeader(status int) { + r.statusCode = status + if !r.wroteHeader { + r.wroteHeader = true + } + if r.isDiscardable(status) { + return + } + r.immediateWrite = true + // copy header values to target + for k, vals := range r.headerBuf { + for _, val := range vals { + r.ResponseWriter.Header().Add(k, val) + } + } + r.ResponseWriter.WriteHeader(status) +} + +// Write writes data to the response. +// If the response header has not been set, it sets the default status code to 200. +// Based on the status code, the content is either written or discarded. +// +// It returns the number of bytes written and any error encountered. +func (r *discardableResponseWriter) Write(data []byte) (int, error) { + // ensure header is set. 
default is 200 in Go stdlib + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + if r.immediateWrite { + return r.ResponseWriter.Write(data) + } + return io.Discard.Write(data) +} + +func (r *discardableResponseWriter) Flush() { + if f, ok := r.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// discardableResponseWriterWithReaderFrom provides the same functionalities as discardableResponseWriter but also implements the +// io.ReaderFrom interface. +// Based on the status code, the content is either written or discarded. +type discardableResponseWriterWithReaderFrom struct { + *discardableResponseWriter +} + +func (r *discardableResponseWriterWithReaderFrom) ReadFrom(re io.Reader) (int64, error) { + // ensure header is set. default is 200 in Go stdlib + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + if r.immediateWrite { + return r.ResponseWriter.(io.ReaderFrom).ReadFrom(re) + } + return io.Copy(io.Discard, re) +} + +var ( + _ io.ReadCloser = &lazyBodyCapturer{} + _ io.WriterTo = &lazyBodyCapturerWriteTo{} +) + +// lazyBodyCapturer represents a type that captures the request body lazily. +// It wraps an io.ReadCloser and provides methods for reading, closing, +// writing to an io.Writer, and capturing the body content. +type lazyBodyCapturer struct { + reader io.ReadCloser + capturedBody []byte + buf *bytes.Buffer + allRead bool +} + +// NewLazyBodyCapturer constructor. 
+func NewLazyBodyCapturer(body io.ReadCloser) XBodyCapturer { + c := &lazyBodyCapturer{ + reader: body, + buf: bytes.NewBuffer([]byte{}), + } + if _, ok := c.reader.(io.WriterTo); ok { + return &lazyBodyCapturerWriteTo{c} + } + return c +} + +func (m *lazyBodyCapturer) Read(p []byte) (int, error) { + if m.allRead { + return m.reader.Read(p) + } + n, err := io.TeeReader(m.reader, m.buf).Read(p) + if errors.Is(err, io.EOF) { + m.allRead = true + } + return n, err +} + +func (m *lazyBodyCapturer) Close() error { + return m.reader.Close() +} + +// Capture marks the body as fully captured. +// The captured body data can be read via GetBody. +func (m *lazyBodyCapturer) Capture() { + m.allRead = true + if !m.IsCaptured() { + m.capturedBody = m.buf.Bytes() + m.buf = nil + } + m.reader = io.NopCloser(bytes.NewReader(m.capturedBody)) +} + +// IsCaptured returns true when a body was captured. +func (m *lazyBodyCapturer) IsCaptured() bool { + return m.capturedBody != nil +} + +// GetBody returns the captured byte slice. Value is nil when not captured, yet. 
+func (m *lazyBodyCapturer) GetBody() []byte { + return m.capturedBody +} + +type lazyBodyCapturerWriteTo struct { + *lazyBodyCapturer +} + +func (m *lazyBodyCapturerWriteTo) WriteTo(w io.Writer) (int64, error) { + if m.allRead { + return m.reader.(io.WriterTo).WriteTo(w) + } + n, err := m.reader.(io.WriterTo).WriteTo(io.MultiWriter(w, m.buf)) + m.allRead = true + return n, err +} diff --git a/pkg/proxy/middleware_test.go b/pkg/proxy/middleware_test.go new file mode 100644 index 00000000..9c2d7bda --- /dev/null +++ b/pkg/proxy/middleware_test.go @@ -0,0 +1,204 @@ +package proxy + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServeHTTP(t *testing.T) { + myHeader := map[string][]string{"Foo": {"bar1", "bar2"}} + specs := map[string]struct { + context func() context.Context + maxRetries int + headers http.Header + respStatus int + expRetries int + }{ + "no retry on 200": { + context: func() context.Context { return context.TODO() }, + headers: myHeader, + maxRetries: 3, + respStatus: http.StatusOK, + expRetries: 0, + }, + "no retry on 500": { + context: func() context.Context { return context.TODO() }, + headers: myHeader, + maxRetries: 3, + respStatus: http.StatusInternalServerError, + expRetries: 0, + }, + "max retries on 503": { + context: func() context.Context { return context.TODO() }, + headers: myHeader, + maxRetries: 3, + respStatus: http.StatusServiceUnavailable, + expRetries: 3, + }, + "max retries on 502": { + context: func() context.Context { return context.TODO() }, + headers: myHeader, + maxRetries: 3, + respStatus: http.StatusBadGateway, + expRetries: 3, + }, + "not buffered on 100": { + context: func() context.Context { return context.TODO() }, + headers: myHeader, + maxRetries: 3, + respStatus: http.StatusContinue, + expRetries: 0, + }, + 
"context cancelled": { + context: func() context.Context { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + return ctx + }, + headers: myHeader, + maxRetries: 3, + respStatus: http.StatusBadGateway, + expRetries: 0, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + counterBefore := counterValue(t, totalRetries) + const myBody = "my-request-body" + req, err := http.NewRequestWithContext(spec.context(), "GET", "/test", strings.NewReader(myBody)) + require.NoError(t, err) + req.Header = spec.headers.Clone() + + respRecorder := httptest.NewRecorder() + + var counter int + testBackend := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + counter++ + // return all headers + for k, vals := range r.Header { + for _, v := range vals { + w.Header().Add(k, v) + } + } + w.WriteHeader(spec.respStatus) + reqBody, err := io.ReadAll(req.Body) + require.NoError(t, err) + _, err = w.Write(append(reqBody, []byte(strconv.Itoa(spec.respStatus))...)) + require.NoError(t, err) + }) + + // when + middleware := NewRetryMiddleware(spec.maxRetries, testBackend) + middleware.ServeHTTP(respRecorder, req) + + // then + resp := respRecorder.Result() + require.Equal(t, spec.respStatus, resp.StatusCode) + assert.Equal(t, spec.expRetries, counter-1) + // and headers matches + assert.Equal(t, spec.headers, resp.Header) + // and body matches + bodyRead, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + assert.Equal(t, myBody+strconv.Itoa(spec.respStatus), string(bodyRead)) + // and prometheus metric updated + assert.Equal(t, spec.expRetries, int(counterValue(t, totalRetries)-counterBefore)) + }) + } +} + +func TestWriteDelegatorReadFrom(t *testing.T) { + const myTestContent = `my body content` + + // scenario: with non retry status code + rec := &testResponseWriter{ResponseRecorder: httptest.NewRecorder()} + d := NewDiscardableResponseWriter(rec, func(int) bool { return false }) + // when + n, 
err := d.(io.ReaderFrom).ReadFrom(strings.NewReader(myTestContent)) + // then the content is written + require.NoError(t, err) + assert.Equal(t, len(myTestContent), int(n)) + assert.Equal(t, myTestContent, rec.Body.String()) + assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) + + // scenario: with a retry status code + rec = &testResponseWriter{ResponseRecorder: httptest.NewRecorder()} + d = NewDiscardableResponseWriter(rec, func(int) bool { return true }) + // when + n, err = d.(io.ReaderFrom).ReadFrom(strings.NewReader(myTestContent)) + // then the content is not written + require.NoError(t, err) + assert.Equal(t, len(myTestContent), int(n)) + assert.Equal(t, "", rec.Body.String()) + assert.Equal(t, http.StatusOK, d.CapturedStatusCode()) + + // scenario: not implementing io.ReaderFrom + d = NewDiscardableResponseWriter(httptest.NewRecorder(), func(int) bool { return true }) + _, ok := d.(io.ReaderFrom) + require.False(t, ok) +} + +func TestLazyBodyCapturer(t *testing.T) { + const myTestContent = "my-test-content" + c := NewLazyBodyCapturer(io.NopCloser(strings.NewReader(myTestContent))) + var buf bytes.Buffer + n, err := c.(io.WriterTo).WriteTo(&buf) + require.NoError(t, err) + assert.Len(t, myTestContent, int(n)) + assert.Equal(t, myTestContent, buf.String()) + // when captured + c.Capture() + // then data is buffered for second read + buf.Reset() + n, err = c.(io.WriterTo).WriteTo(&buf) + require.NoError(t, err) + assert.Equal(t, len(myTestContent), 15) + assert.Equal(t, myTestContent, buf.String()) + + // scenario: source reader does not implement WriteTo + c = NewLazyBodyCapturer(testReader{strings.NewReader(myTestContent)}) + // then instance also does not implement it + _, ok := c.(io.WriterTo) + require.False(t, ok) +} + +func counterValue(t *testing.T, counter prometheus.Counter) float64 { + registry := prometheus.NewPedanticRegistry() + registry.MustRegister(counter) + gathered, err := registry.Gather() + require.NoError(t, err) + require.Len(t, 
gathered, 1) + require.Len(t, gathered[0].Metric, 1) + return gathered[0].Metric[0].GetCounter().GetValue() +} + +type testResponseWriter struct { + *httptest.ResponseRecorder +} + +func (r *testResponseWriter) ReadFrom(re io.Reader) (int64, error) { + return r.ResponseRecorder.Body.ReadFrom(re) +} + +type testReader struct { + r io.Reader +} + +func (t testReader) Read(p []byte) (n int, err error) { + return t.r.Read(p) +} + +func (t testReader) Close() error { + return nil +} diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index 10fb5e45..003f976b 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -3,6 +3,7 @@ package integration import ( "bytes" "fmt" + "io" "log" "net/http" "net/http/httptest" @@ -43,17 +44,7 @@ func TestScaleUpAndDown(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -103,17 +94,7 @@ func TestHandleModelUndeployment(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. 
time.Sleep(3 * time.Second) @@ -132,7 +113,7 @@ func TestHandleModelUndeployment(t *testing.T) { require.NoError(t, testK8sClient.Delete(testCtx, deploy)) // Check that the deployment was deleted - err = testK8sClient.Get(testCtx, client.ObjectKey{ + err := testK8sClient.Get(testCtx, client.ObjectKey{ Namespace: deploy.Namespace, Name: deploy.Name, }, deploy) @@ -151,6 +132,82 @@ func TestHandleModelUndeployment(t *testing.T) { wg.Wait() } +func TestRetryMiddleware(t *testing.T) { + const modelName = "test-model-c" + deploy := testDeployment(modelName) + require.NoError(t, testK8sClient.Create(testCtx, deploy)) + + // Wait for deployment mapping to sync. + time.Sleep(3 * time.Second) + backendRequests := &atomic.Int32{} + var serverCodes []int + testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + i := backendRequests.Add(1) + code := serverCodes[i-1] + t.Logf("Serving request from testBackend: %d; code: %d\n", i, code) + w.WriteHeader(code) + _, err := w.Write([]byte(strconv.Itoa(code))) + require.NoError(t, err) + })) + + // Mock an EndpointSlice. 
+ withMockEndpointSlice(t, testBackend, modelName) + + specs := map[string]struct { + serverCodes []int + expResultCode int + expBackendHits int32 + }{ + "max retries - succeeds": { + serverCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusOK}, + expResultCode: http.StatusOK, + expBackendHits: 4, + }, + "max retries - fails": { + serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway}, + expResultCode: http.StatusBadGateway, + expBackendHits: 4, + }, + "non retryable error code": { + serverCodes: []int{http.StatusNotImplemented}, + expResultCode: http.StatusNotImplemented, + expBackendHits: 1, + }, + "200 status code": { + serverCodes: []int{http.StatusOK}, + expResultCode: http.StatusOK, + expBackendHits: 1, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + // setup + serverCodes = spec.serverCodes + backendRequests.Store(0) + + // when single request sent + gotBody := <-sendRequest(t, &sync.WaitGroup{}, modelName, spec.expResultCode) + // then only the last body is written + assert.Equal(t, strconv.Itoa(spec.expResultCode), gotBody) + require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit with retries") + }) + } +} + +func withMockEndpointSlice(t *testing.T, testBackend *httptest.Server, modelName string) { + testBackendURL, err := url.Parse(testBackend.URL) + require.NoError(t, err) + testBackendPort, err := strconv.Atoi(testBackendURL.Port()) + require.NoError(t, err) + require.NoError(t, testK8sClient.Create(testCtx, + endpointSlice( + modelName, + testBackendURL.Hostname(), + int32(testBackendPort), + ), + )) +} + func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) { require.EventuallyWithT(t, func(t *assert.CollectT) { err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy) @@ -166,11 
+223,13 @@ func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, exp } } -func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) { +func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) <-chan string { t.Helper() wg.Add(1) + bodyRespChan := make(chan string, 1) go func() { defer wg.Done() + defer close(bodyRespChan) body := []byte(fmt.Sprintf(`{"model": %q}`, modelName)) req, err := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewReader(body)) @@ -179,7 +238,12 @@ func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int res, err := testHTTPClient.Do(req) require.NoError(t, err) require.Equal(t, expCode, res.StatusCode) + got, err := io.ReadAll(res.Body) + _ = res.Body.Close() + require.NoError(t, err) + bodyRespChan <- string(got) }() + return bodyRespChan } func completeRequests(c chan struct{}, n int) { diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 74697559..e7f41684 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -81,9 +81,8 @@ func TestMain(m *testing.M) { const concurrencyPerReplica = 1 queueManager = queue.NewManager(concurrencyPerReplica) - endpointManager, err := endpoints.NewManager(mgr) + endpointManager, err := endpoints.NewManager(mgr, queueManager.UpdateQueueSizeForReplicas) requireNoError(err) - endpointManager.EndpointSizeCallback = queueManager.UpdateQueueSizeForReplicas deploymentManager, err := deployments.NewManager(mgr) requireNoError(err) @@ -105,11 +104,7 @@ func TestMain(m *testing.M) { autoscaler.Endpoints = endpointManager go autoscaler.Start() - handler := &proxy.Handler{ - Deployments: deploymentManager, - Endpoints: endpointManager, - Queues: queueManager, - } + handler := proxy.NewHandler(deploymentManager, endpointManager, queueManager, 3) testServer = httptest.NewServer(handler) defer testServer.Close()