From 2ab82e034f5e89c8993faaeaa69d753203162aef Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Thu, 18 Jan 2024 11:22:36 -0500 Subject: [PATCH 01/13] Add proof of concepts in hack/ --- hack/failonceserver/main.go | 32 +++++++++++++++++++++++++ hack/failserver/main.go | 10 ++++++++ hack/retryproxy/main.go | 48 +++++++++++++++++++++++++++++++++++++ hack/successserver/main.go | 9 +++++++ 4 files changed, 99 insertions(+) create mode 100644 hack/failonceserver/main.go create mode 100644 hack/failserver/main.go create mode 100644 hack/retryproxy/main.go create mode 100644 hack/successserver/main.go diff --git a/hack/failonceserver/main.go b/hack/failonceserver/main.go new file mode 100644 index 00000000..e45588de --- /dev/null +++ b/hack/failonceserver/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "net/http" + "sync" +) + +func main() { + // HTTP server that fails once and then succeeds for a given request path + var mtx sync.RWMutex + paths := map[string]bool{} + + http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mtx.RLock() + shouldSucceed := paths[r.URL.Path] + mtx.RUnlock() + + defer func() { + mtx.Lock() + paths[r.URL.Path] = true + mtx.Unlock() + }() + + if !shouldSucceed { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("failure\n")) + return + } + + w.Write([]byte("success\n")) + })) +} diff --git a/hack/failserver/main.go b/hack/failserver/main.go new file mode 100644 index 00000000..6d936d17 --- /dev/null +++ b/hack/failserver/main.go @@ -0,0 +1,10 @@ +package main + +import "net/http" + +func main() { + http.ListenAndServe(":8081", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("unavailable\n")) + })) +} diff --git a/hack/retryproxy/main.go b/hack/retryproxy/main.go new file mode 100644 index 00000000..d3e49ada --- /dev/null +++ b/hack/retryproxy/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "errors" + "log" + "net/http" + "net/http/httputil" + "net/url" +) + +func main() { + // go run ./hack/failserver + u1, err := url.Parse("http://localhost:8081") + if err != nil { + panic(err) + } + p1 := httputil.NewSingleHostReverseProxy(u1) + + // go run ./hack/successserver + u2, err := url.Parse("http://localhost:8082") + if err != nil { + panic(err) + } + p2 := httputil.NewSingleHostReverseProxy(u2) + + p1.ModifyResponse = func(r *http.Response) error { + if r.StatusCode == http.StatusServiceUnavailable { + // Returning an error will trigger the ErrorHandler. + return errRetry + } + return nil + } + + p1.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + if err == errRetry { + log.Println("retrying") + // Simulate calling the next backend. + p2.ServeHTTP(w, r) + return + } + } + + http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p1.ServeHTTP(w, r) + })) +} + +var errRetry = errors.New("retry") diff --git a/hack/successserver/main.go b/hack/successserver/main.go new file mode 100644 index 00000000..2651dd95 --- /dev/null +++ b/hack/successserver/main.go @@ -0,0 +1,9 @@ +package main + +import "net/http" + +func main() { + http.ListenAndServe(":8082", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("success\n")) + })) +} From 363aaa08d47537c473a256b6f963e6e736f6b0f5 Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Thu, 18 Jan 2024 12:37:23 -0500 Subject: [PATCH 02/13] Copy request body --- hack/failserver/main.go | 7 +++- hack/retryproxy/main.go | 69 +++++++++++++++++++++++--------------- hack/successserver/main.go | 7 +++- 3 files changed, 54 insertions(+), 29 deletions(-) diff --git a/hack/failserver/main.go b/hack/failserver/main.go index 6d936d17..e0eaaa35 100644 --- a/hack/failserver/main.go +++ b/hack/failserver/main.go @@ -1,9 +1,14 @@ package main -import "net/http" +import ( + "io" + "net/http" + "os" +) func main() { http.ListenAndServe(":8081", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(os.Stdout, r.Body) w.WriteHeader(http.StatusServiceUnavailable) w.Write([]byte("unavailable\n")) })) diff --git a/hack/retryproxy/main.go b/hack/retryproxy/main.go index d3e49ada..64534549 100644 --- a/hack/retryproxy/main.go +++ b/hack/retryproxy/main.go @@ -1,7 +1,9 @@ package main import ( + "bytes" "errors" + "io" "log" "net/http" "net/http/httputil" @@ -9,40 +11,53 @@ import ( ) func main() { - // go run ./hack/failserver - u1, err := url.Parse("http://localhost:8081") - if err != nil { - panic(err) - } - p1 := httputil.NewSingleHostReverseProxy(u1) + http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + r.Body.Close() - // go run ./hack/successserver - u2, err := url.Parse("http://localhost:8082") - if err != nil { - panic(err) - } - p2 := httputil.NewSingleHostReverseProxy(u2) + // go run ./hack/failserver + first := newReverseProxy("http://localhost:8081") - p1.ModifyResponse = func(r *http.Response) error { - if r.StatusCode == http.StatusServiceUnavailable { - // Returning an error will trigger the ErrorHandler. - return errRetry + first.ModifyResponse = func(r *http.Response) error { + if r.StatusCode == http.StatusServiceUnavailable { + // Returning an error will trigger the ErrorHandler. + return errRetry + } + return nil } - return nil - } - p1.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - if err == errRetry { - log.Println("retrying") - // Simulate calling the next backend. - p2.ServeHTTP(w, r) - return + first.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + if err == errRetry { + log.Println("retrying") + + // Simulate calling the next backend. + // go run ./hack/successserver + next := newReverseProxy("http://localhost:8082") + next.ServeHTTP(w, newReq(r, body)) + return + } } - } - http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - p1.ServeHTTP(w, r) + log.Println("serving") + first.ServeHTTP(w, newReq(r, body)) })) } var errRetry = errors.New("retry") + +func newReq(r *http.Request, body []byte) *http.Request { + clone := r.Clone(r.Context()) + clone.Body = io.NopCloser(bytes.NewReader(body)) + return clone +} + +func newReverseProxy(addr string) *httputil.ReverseProxy { + u, err := url.Parse(addr) + if err != nil { + panic(err) + } + return httputil.NewSingleHostReverseProxy(u) +} diff --git a/hack/successserver/main.go b/hack/successserver/main.go index 2651dd95..2d656a9c 100644 --- a/hack/successserver/main.go +++ b/hack/successserver/main.go @@ -1,9 +1,14 @@ package main -import "net/http" +import ( + "io" + "net/http" + "os" +) func main() { http.ListenAndServe(":8082", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(os.Stdout, r.Body) w.Write([]byte("success\n")) })) } From 5f54718e8cffcedff3f9e1672da1af67bed7fe38 Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Thu, 18 Jan 2024 13:04:39 -0500 Subject: [PATCH 03/13] Update retry proxy to include a recursive constructor --- hack/failonceserver/main.go | 6 ++- hack/retryproxy/main.go | 73 ++++++++++++++++++++++--------------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/hack/failonceserver/main.go b/hack/failonceserver/main.go index e45588de..ca7437a3 100644 --- a/hack/failonceserver/main.go +++ b/hack/failonceserver/main.go @@ -1,7 +1,9 @@ package main import ( + "io" "net/http" + "os" "sync" ) @@ -10,7 +12,9 @@ func main() { var mtx sync.RWMutex paths := map[string]bool{} - http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ListenAndServe(":8082", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(os.Stdout, r.Body) + mtx.RLock() shouldSucceed := paths[r.URL.Path] mtx.RUnlock() diff --git a/hack/retryproxy/main.go b/hack/retryproxy/main.go index 64534549..a68f62fa 100644 --- a/hack/retryproxy/main.go +++ b/hack/retryproxy/main.go @@ -3,6 +3,7 @@ package main import ( "bytes" "errors" + "fmt" "io" "log" "net/http" @@ -12,52 +13,64 @@ import ( func main() { http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println("serving") + body, err := io.ReadAll(r.Body) if err != nil { panic(err) } r.Body.Close() - // go run ./hack/failserver - first := newReverseProxy("http://localhost:8081") + fmt.Println("body:", string(body)) + + newProxy(body, 0).ServeHTTP(w, newRequest(r, body)) + })) + +} + +var errRetry = errors.New("retry") + +func newProxy(body []byte, attempt int) http.Handler { + // go run ./hack/failserver + u, err := url.Parse(getEndpoint(attempt)) + if err != nil { + panic(err) + } + proxy := httputil.NewSingleHostReverseProxy(u) - first.ModifyResponse = func(r *http.Response) error { - if r.StatusCode == http.StatusServiceUnavailable { - // Returning an error will trigger the ErrorHandler. - return errRetry - } - return nil + proxy.ModifyResponse = func(r *http.Response) error { + if r.StatusCode == http.StatusServiceUnavailable { + // Returning an error will trigger the ErrorHandler. + return errRetry } + return nil + } - first.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - if err == errRetry { - log.Println("retrying") + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + if err == errRetry { + log.Println("retrying") - // Simulate calling the next backend. - // go run ./hack/successserver - next := newReverseProxy("http://localhost:8082") - next.ServeHTTP(w, newReq(r, body)) - return - } + // Simulate calling the next backend. + // go run ./hack/successserver + newProxy(body, attempt+1).ServeHTTP(w, newRequest(r, body)) + return } + } - log.Println("serving") - first.ServeHTTP(w, newReq(r, body)) - })) + return proxy } -var errRetry = errors.New("retry") +func getEndpoint(attempt int) string { + switch attempt { + case 0: + return "http://localhost:8081" + default: + return "http://localhost:8082" + } +} -func newReq(r *http.Request, body []byte) *http.Request { +func newRequest(r *http.Request, body []byte) *http.Request { clone := r.Clone(r.Context()) clone.Body = io.NopCloser(bytes.NewReader(body)) return clone } - -func newReverseProxy(addr string) *httputil.ReverseProxy { - u, err := url.Parse(addr) - if err != nil { - panic(err) - } - return httputil.NewSingleHostReverseProxy(u) -} From a9bc8b12c6b49198b1799bf0ecc0763ab60c680a Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Tue, 23 Jan 2024 10:34:46 -0500 Subject: [PATCH 04/13] Add retries to handler --- Makefile | 2 +- hack/failonceserver/main.go | 3 + hack/retryproxy/main.go | 13 +- pkg/proxy/handler.go | 199 +++++++++--------- .../{metrics_test.go => handler_test.go} | 8 +- pkg/proxy/metrics.go | 32 --- pkg/proxy/request.go | 113 ++++++++++ 7 files changed, 232 insertions(+), 138 deletions(-) rename pkg/proxy/{metrics_test.go => handler_test.go} (95%) delete mode 100644 pkg/proxy/metrics.go create mode 100644 pkg/proxy/request.go diff --git a/Makefile b/Makefile index 1826ea5b..406b8e97 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ ENVTEST_K8S_VERSION = 1.27.1 .PHONY: test -test: test-unit test-race test-integration test-e2e +test: test-unit test-integration test-e2e .PHONY: test-unit test-unit: diff --git a/hack/failonceserver/main.go b/hack/failonceserver/main.go index ca7437a3..a063fb0d 100644 --- a/hack/failonceserver/main.go +++ b/hack/failonceserver/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "io" "net/http" "os" @@ -13,7 +14,9 @@ func main() { paths := map[string]bool{} http.ListenAndServe(":8082", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Println(r.URL.Path) io.Copy(os.Stdout, r.Body) + fmt.Println("---") mtx.RLock() shouldSucceed := paths[r.URL.Path] diff --git a/hack/retryproxy/main.go b/hack/retryproxy/main.go index a68f62fa..bf98aa70 100644 --- a/hack/retryproxy/main.go +++ b/hack/retryproxy/main.go @@ -12,6 +12,8 @@ import ( ) func main() { + var maxRetries = 1 + http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("serving") @@ -23,14 +25,14 @@ func main() { fmt.Println("body:", string(body)) - newProxy(body, 0).ServeHTTP(w, newRequest(r, body)) + newProxy(body, 0, maxRetries).ServeHTTP(w, newRequest(r, body)) })) } var errRetry = errors.New("retry") -func newProxy(body []byte, attempt int) http.Handler { +func newProxy(body []byte, attempt, maxRetries int) http.Handler { // go run ./hack/failserver u, err := url.Parse(getEndpoint(attempt)) if err != nil { @@ -47,14 +49,17 @@ func newProxy(body []byte, attempt int) http.Handler { } proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - if err == errRetry { + if err != nil && attempt < maxRetries { log.Println("retrying") // Simulate calling the next backend. // go run ./hack/successserver - newProxy(body, attempt+1).ServeHTTP(w, newRequest(r, body)) + newProxy(body, attempt+1, maxRetries).ServeHTTP(w, newRequest(r, body)) return } + + log.Printf("http: proxy error: %v", err) + w.WriteHeader(http.StatusBadGateway) } return proxy diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index a51d9d3b..9f8945d9 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -1,153 +1,126 @@ package proxy import ( - "bytes" "context" - "encoding/json" "errors" - "fmt" - "io" "log" "net/http" "net/http/httputil" "net/url" - "strconv" - "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" - "github.com/substratusai/lingo/pkg/deployments" "github.com/substratusai/lingo/pkg/endpoints" "github.com/substratusai/lingo/pkg/queue" ) +var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "http_response_time_seconds", + Help: "Duration of HTTP requests.", + Buckets: prometheus.DefBuckets, +}, []string{"model", "status_code"}) + +func MustRegister(r prometheus.Registerer) { + r.MustRegister(httpDuration) +} + // Handler serves http requests for end-clients. // It is also responsible for triggering scale-from-zero. type Handler struct { Deployments *deployments.Manager Endpoints *endpoints.Manager Queues *queue.Manager + MaxRetries int + RetryCodes map[int]struct{} +} + +func NewHandler( + deployments *deployments.Manager, + endpoints *endpoints.Manager, + queues *queue.Manager, +) *Handler { + return &Handler{ + Deployments: deployments, + Endpoints: endpoints, + Queues: queues, + } } -func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler { - return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues} +var defaultRetryCodes = map[int]struct{}{ + http.StatusBadGateway: {}, + http.StatusServiceUnavailable: {}, + http.StatusGatewayTimeout: {}, } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var modelName string - captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w) - w = captureStatusRespWriter - timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { - httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v) - })) - defer timer.ObserveDuration() - - id := uuid.New().String() - log.Printf("request: %v", r.URL) + log.Printf("url: %v", r.URL) + w.Header().Set("X-Proxy", "lingo") - var ( - proxyRequest *http.Request - err error - ) + pr := newProxyRequest(r) + defer pr.done() + // TODO: Only parse model for paths that would have a model. - modelName, proxyRequest, err = parseModel(r) - if err != nil || modelName == "" { - modelName = "unknown" - log.Printf("error reading model from request body: %v", err) - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("Bad request: unable to parse .model from JSON payload")) + if err := pr.parseModel(); err != nil { + pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err) return } - log.Println("model:", modelName) - deploy, found := h.Deployments.ResolveDeployment(modelName) - if !found { - log.Printf("deployment not found for model: %v", err) - w.WriteHeader(http.StatusNotFound) - w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName))) + log.Println("model:", pr.model) + + var backendExists bool + pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model) + if !backendExists { + pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.model) return } - h.Deployments.AtLeastOne(deploy) + // Ensure the backend is scaled to at least one Pod. + h.Deployments.AtLeastOne(pr.backendDeployment) + + log.Printf("Entering queue: %v", pr.id) - log.Println("Entering queue", id) - complete := h.Queues.EnqueueAndWait(r.Context(), deploy, id) - log.Println("Admitted into queue", id) + complete := h.Queues.EnqueueAndWait(r.Context(), pr.backendDeployment, pr.id) defer complete() - // abort when deployment was removed meanwhile - if _, exists := h.Deployments.ResolveDeployment(modelName); !exists { - log.Printf("deployment not active for model removed: %v", err) - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName))) + log.Printf("Admitted into queue: %v", pr.id) + + // After waiting for the request to be admitted, double check that the model + // still exists. It's possible that the model was deleted while waiting. + // This would lead to a long subequent wait with the host lookup. + pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model) + if !backendExists { + pr.sendErrorResponse(w, http.StatusNotFound, "model not found after being dequeued: %v", pr.model) return } - log.Println("Waiting for IPs", id) - host, err := h.Endpoints.AwaitHostAddress(r.Context(), deploy, "http") + h.proxyHTTP(w, pr) +} + +// AdditionalProxyRewrite is an injection point for modifying proxy requests. +// Used in tests. +var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {} + +func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) { + log.Printf("Waiting for host: %v", pr.id) + + host, err := h.Endpoints.AwaitHostAddress(pr.r.Context(), pr.backendDeployment, "http") if err != nil { - log.Printf("error while finding the host address %v", err) switch { case errors.Is(err, context.Canceled): - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte("Request cancelled")) + pr.sendErrorResponse(w, http.StatusInternalServerError, "request cancelled while finding host: %v", err) return case errors.Is(err, context.DeadlineExceeded): - w.WriteHeader(http.StatusGatewayTimeout) - _, _ = w.Write([]byte(fmt.Sprintf("Request timed out for model: %v", modelName))) + pr.sendErrorResponse(w, http.StatusGatewayTimeout, "request timeout while finding host: %v", err) return default: - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte("Internal server error")) + pr.sendErrorResponse(w, http.StatusGatewayTimeout, "unable to find host: %v", err) return } } - log.Printf("Got host: %v, id: %v\n", host, id) - - // TODO: Avoid creating new reverse proxies for each request. - // TODO: Consider implementing a round robin scheme. - log.Printf("Proxying request to host %v: %v\n", host, id) - newReverseProxy(host).ServeHTTP(w, proxyRequest) -} - -// parseModel parses the model name from the request -// returns empty string when none found or an error for failures on the proxy request object -func parseModel(r *http.Request) (string, *http.Request, error) { - if model := r.Header.Get("X-Model"); model != "" { - return model, r, nil - } - // parse request body for model name, ignore errors - body, err := io.ReadAll(r.Body) - if err != nil { - return "", r, nil - } - - var payload struct { - Model string `json:"model"` - } - var model string - if err := json.Unmarshal(body, &payload); err == nil { - model = payload.Model - } - - // create new request object - proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), bytes.NewReader(body)) - if err != nil { - return "", nil, fmt.Errorf("create proxy request: %w", err) - } - proxyReq.Header = r.Header - if err := proxyReq.ParseForm(); err != nil { - return "", nil, fmt.Errorf("parse proxy form: %w", err) - } - return model, proxyReq, nil -} -// AdditionalProxyRewrite is an injection point for modifying proxy requests. -// Used in tests. -var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {} + log.Printf("Got host: %v, id: %v\n", host, pr.id) -func newReverseProxy(host string) *httputil.ReverseProxy { proxy := &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(&url.URL{ @@ -158,5 +131,37 @@ func newReverseProxy(host string) *httputil.ReverseProxy { AdditionalProxyRewrite(r) }, } - return proxy + + proxy.ModifyResponse = func(r *http.Response) error { + if h.isRetryCode(r.StatusCode) { + // Returning an error will trigger the ErrorHandler. + return errors.New("retry") + } + return nil + } + + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + if err != nil && pr.attempt < h.MaxRetries { + log.Println("retrying") + pr.attempt++ + h.proxyHTTP(w, pr) + return + } + + pr.sendErrorResponse(w, http.StatusBadGateway, "proxy: exceeded retries: %v/%v", pr.attempt, h.MaxRetries) + } + + log.Printf("Proxying request to host %v: %v\n", host, pr.id) + proxy.ServeHTTP(w, pr.httpRequest()) +} + +func (h *Handler) isRetryCode(status int) bool { + var retry bool + // TODO: avoid the nil check here and set a default map in the constructor. + if h.RetryCodes != nil { + _, retry = h.RetryCodes[status] + } else { + _, retry = defaultRetryCodes[status] + } + return retry } diff --git a/pkg/proxy/metrics_test.go b/pkg/proxy/handler_test.go similarity index 95% rename from pkg/proxy/metrics_test.go rename to pkg/proxy/handler_test.go index 5940e445..2b622810 100644 --- a/pkg/proxy/metrics_test.go +++ b/pkg/proxy/handler_test.go @@ -9,7 +9,7 @@ import ( "github.com/go-logr/logr" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" - "github.com/prometheus/client_model/go" + io_prometheus_client "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substratusai/lingo/pkg/deployments" @@ -30,7 +30,7 @@ func TestMetrics(t *testing.T) { expCode int expLabels map[string]string }{ - "with mode name": { + "with model name": { request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)), expCode: http.StatusNotFound, expLabels: map[string]string{ @@ -38,11 +38,11 @@ func TestMetrics(t *testing.T) { "status_code": "404", }, }, - "unknown model name": { + "empty model name": { request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader("{}")), expCode: http.StatusBadRequest, expLabels: map[string]string{ - "model": "unknown", + "model": "", "status_code": "400", }, }, diff --git a/pkg/proxy/metrics.go b/pkg/proxy/metrics.go deleted file mode 100644 index 5cd229e7..00000000 --- a/pkg/proxy/metrics.go +++ /dev/null @@ -1,32 +0,0 @@ -package proxy - -import ( - "net/http" - - "github.com/prometheus/client_golang/prometheus" -) - -var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: "http_response_time_seconds", - Help: "Duration of HTTP requests.", - Buckets: prometheus.DefBuckets, -}, []string{"model", "status_code"}) - -func MustRegister(r prometheus.Registerer) { - r.MustRegister(httpDuration) -} - -// captureStatusResponseWriter is a custom HTTP response writer that captures the status code. -type captureStatusResponseWriter struct { - http.ResponseWriter - statusCode int -} - -func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter { - return &captureStatusResponseWriter{ResponseWriter: responseWriter} -} - -func (srw *captureStatusResponseWriter) WriteHeader(code int) { - srw.statusCode = code - srw.ResponseWriter.WriteHeader(code) -} diff --git a/pkg/proxy/request.go b/pkg/proxy/request.go new file mode 100644 index 00000000..5f1e82ab --- /dev/null +++ b/pkg/proxy/request.go @@ -0,0 +1,113 @@ +package proxy + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strconv" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" +) + +// proxyRequest keeps track of the state of a request that is to be proxied. +type proxyRequest struct { + // r is the original request. It is stored here so that is can be cloned + // and sent to the backend while preserving the original request body. + r *http.Request + // body will be stored here if the request body needed to be read + // in order to determine the model. + body []byte + + // metadata: + + id string + status int + model string + backendDeployment string + attempt int + + // metrics: + + timer *prometheus.Timer +} + +func newProxyRequest(r *http.Request) *proxyRequest { + pr := &proxyRequest{ + r: r, + id: uuid.New().String(), + status: http.StatusOK, + } + + pr.timer = prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) { + httpDuration.WithLabelValues(pr.model, strconv.Itoa(pr.status)).Observe(v) + })) + + return pr + +} + +func (p *proxyRequest) done() { + p.timer.ObserveDuration() +} + +func (pr *proxyRequest) parseModel() error { + pr.model = pr.r.Header.Get("X-Model") + if pr.model != "" { + return nil + } + + var err error + pr.body, err = io.ReadAll(pr.r.Body) + if err != nil { + return fmt.Errorf("read: %w", err) + } + + var payload struct { + Model string `json:"model"` + } + if err := json.Unmarshal(pr.body, &payload); err != nil { + return fmt.Errorf("unmarshal json: %w", err) + } + pr.model = payload.Model + + if pr.model == "" { + return fmt.Errorf("no model specified") + } + + return nil +} + +func (pr *proxyRequest) sendErrorResponse(w http.ResponseWriter, status int, format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.Printf("sending error response: %v: %v", status, msg) + + pr.setStatus(w, status) + + if status >= 500 { + // Don't leak internal error messages to the client. + msg = http.StatusText(status) + } + + if err := json.NewEncoder(w).Encode(struct { + Error string `json:"error"` + }{ + Error: msg, + }); err != nil { + log.Printf("error encoding error response: %v", err) + } +} + +func (pr *proxyRequest) setStatus(w http.ResponseWriter, code int) { + pr.status = code + w.WriteHeader(code) +} + +func (pr *proxyRequest) httpRequest() *http.Request { + clone := pr.r.Clone(pr.r.Context()) + clone.Body = io.NopCloser(bytes.NewReader(pr.body)) + return clone +} From e3e7b7db1629a121ef43f1959f9d1f2ffddf1c07 Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Thu, 1 Feb 2024 11:19:57 -0500 Subject: [PATCH 05/13] Fix recorded and returned response codes --- pkg/proxy/handler.go | 41 +++++--- pkg/proxy/handler_test.go | 208 +++++++++++++++++++++++++++----------- 2 files changed, 178 insertions(+), 71 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index 9f8945d9..dbd8d264 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -9,9 +9,6 @@ import ( "net/url" "github.com/prometheus/client_golang/prometheus" - "github.com/substratusai/lingo/pkg/deployments" - "github.com/substratusai/lingo/pkg/endpoints" - "github.com/substratusai/lingo/pkg/queue" ) var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ @@ -24,20 +21,34 @@ func MustRegister(r prometheus.Registerer) { r.MustRegister(httpDuration) } +type DeploymentManager interface { + ResolveDeployment(model string) (string, bool) + AtLeastOne(model string) +} + +type EndpointManager interface { + AwaitHostAddress(ctx context.Context, service, portName string) (string, error) +} + +type QueueManager interface { + EnqueueAndWait(ctx context.Context, deploymentName, id string) func() +} + // Handler serves http requests for end-clients. // It is also responsible for triggering scale-from-zero. type Handler struct { - Deployments *deployments.Manager - Endpoints *endpoints.Manager - Queues *queue.Manager - MaxRetries int - RetryCodes map[int]struct{} + Deployments DeploymentManager + Endpoints EndpointManager + Queues QueueManager + + MaxRetries int + RetryCodes map[int]struct{} } func NewHandler( - deployments *deployments.Manager, - endpoints *endpoints.Manager, - queues *queue.Manager, + deployments DeploymentManager, + endpoints EndpointManager, + queues QueueManager, ) *Handler { return &Handler{ Deployments: deployments, @@ -47,9 +58,10 @@ func NewHandler( } var defaultRetryCodes = map[int]struct{}{ - http.StatusBadGateway: {}, - http.StatusServiceUnavailable: {}, - http.StatusGatewayTimeout: {}, + http.StatusInternalServerError: {}, + http.StatusBadGateway: {}, + http.StatusServiceUnavailable: {}, + http.StatusGatewayTimeout: {}, } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -133,6 +145,7 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) { } proxy.ModifyResponse = func(r *http.Response) error { + pr.status = r.StatusCode if h.isRetryCode(r.StatusCode) { // Returning an error will trigger the ErrorHandler. return errors.New("retry") diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go index 2b622810..fca6a205 100644 --- a/pkg/proxy/handler_test.go +++ b/pkg/proxy/handler_test.go @@ -1,69 +1,157 @@ package proxy import ( + "context" + "fmt" + "io" "net/http" "net/http/httptest" "strings" "testing" - "github.com/go-logr/logr" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/testutil" io_prometheus_client "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/substratusai/lingo/pkg/deployments" - "k8s.io/apimachinery/pkg/runtime" - utilruntime "k8s.io/apimachinery/pkg/util/runtime" - clientgoscheme "k8s.io/client-go/kubernetes/scheme" - - ctrl "sigs.k8s.io/controller-runtime" - "sigs.k8s.io/controller-runtime/pkg/cache" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/config" - "sigs.k8s.io/controller-runtime/pkg/manager" ) -func TestMetrics(t *testing.T) { +func TestHandler(t *testing.T) { + const ( + model1 = "model1" + model2 = "model2" + + maxRetries = 3 + ) + models := map[string]string{ + model1: "deploy1", + model2: "deploy2", + } + specs := map[string]struct { - request *http.Request - expCode int - expLabels map[string]string + reqMethod string + reqPath string + reqBody string + + backendCode int + backendBody string + + expCode int + expBody string + expLabels map[string]string + expBackendRequestCount int }{ - "with model name": { - request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)), - expCode: http.StatusNotFound, + "no model": { + reqMethod: http.MethodPost, + reqPath: "/", + reqBody: "{}", + expCode: http.StatusBadRequest, + expBody: `{"error":"unable to parse model: no model specified"}` + "\n", expLabels: map[string]string{ - "model": "my_model", + "model": "", + "status_code": "400", + }, + expBackendRequestCount: 0, + }, + "model not found": { + reqMethod: http.MethodPost, + reqPath: "/", + reqBody: `{"model":"does-not-exist"}`, + expCode: http.StatusNotFound, + expBody: `{"error":"model not found: does-not-exist"}` + "\n", + expLabels: map[string]string{ + "model": "does-not-exist", "status_code": "404", }, + expBackendRequestCount: 0, }, - "empty model name": { - request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader("{}")), - expCode: http.StatusBadRequest, + "happy 200": { + reqMethod: http.MethodPost, + reqPath: "/", + reqBody: fmt.Sprintf(`{"model":%q}`, model1), + backendCode: http.StatusOK, + backendBody: `{"result":"ok"}`, + expCode: http.StatusOK, + expBody: `{"result":"ok"}`, expLabels: map[string]string{ - "model": "", + "model": model1, + "status_code": "200", + }, + expBackendRequestCount: 1, + }, + "retryable 500": { + reqMethod: http.MethodPost, + reqPath: "/", + reqBody: fmt.Sprintf(`{"model":%q}`, model1), + backendCode: http.StatusInternalServerError, + backendBody: `{"err":"oh no!"}`, + expCode: http.StatusBadGateway, + expBody: `{"error":"Bad Gateway"}` + "\n", + expLabels: map[string]string{ + "model": model1, + "status_code": "502", + }, + expBackendRequestCount: 1 + maxRetries, + }, + "not retryable 400": { + reqMethod: http.MethodPost, + reqPath: "/", + reqBody: fmt.Sprintf(`{"model":%q}`, model1), + backendCode: http.StatusBadRequest, + backendBody: `{"err":"bad request"}`, + expCode: http.StatusBadRequest, + expBody: `{"err":"bad request"}`, + expLabels: map[string]string{ + "model": model1, "status_code": "400", }, + expBackendRequestCount: 1, }, } for name, spec := range specs { t.Run(name, func(t *testing.T) { + // Register metrics from a clean slate. httpDuration.Reset() - registry := prometheus.NewPedanticRegistry() - MustRegister(registry) - - deplMgr, err := deployments.NewManager(&fakeManager{}) + metricsRegistry := prometheus.NewPedanticRegistry() + MustRegister(metricsRegistry) + + // Mock backend. + var backendRequestCount int + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendRequestCount++ + if spec.backendCode != 0 { + w.WriteHeader(spec.backendCode) + } + if spec.backendBody != "" { + _, _ = w.Write([]byte(spec.backendBody)) + } + })) + + // Setup handler. + deploys := &testDeploymentManager{models: models} + endpoints := &testEndpointManager{address: backend.Listener.Addr().String()} + queues := &testQueueManager{} + h := NewHandler(deploys, endpoints, queues) + h.MaxRetries = maxRetries + server := httptest.NewServer(h) + + // Issue request. + client := &http.Client{} + req, err := http.NewRequest(spec.reqMethod, server.URL+spec.reqPath, strings.NewReader(spec.reqBody)) + require.NoError(t, err) + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) - h := NewHandler(deplMgr, nil, nil) - recorder := httptest.NewRecorder() - // when - h.ServeHTTP(recorder, spec.request) + // Assert on response. + assert.Equal(t, spec.expCode, resp.StatusCode) + assert.Equal(t, spec.expBody, string(respBody)) + assert.Equal(t, spec.expBackendRequestCount, backendRequestCount) - // then - assert.Equal(t, spec.expCode, recorder.Code) - gathered, err := registry.Gather() + // Assert on metrics. + gathered, err := metricsRegistry.Gather() require.NoError(t, err) require.Len(t, gathered, 1) require.Len(t, gathered[0].Metric, 1) @@ -82,41 +170,47 @@ func TestMetricsViaLinter(t *testing.T) { require.Empty(t, problems) } -func toMap(s []*io_prometheus_client.LabelPair) map[string]string { - r := make(map[string]string, len(s)) - for _, v := range s { - r[v.GetName()] = v.GetValue() - } - return r +type testDeploymentManager struct { + models map[string]string } -// for test setup only -type fakeManager struct { - ctrl.Manager +func (t *testDeploymentManager) ResolveDeployment(model string) (string, bool) { + deploy, ok := t.models[model] + return deploy, ok } -func (m *fakeManager) GetCache() cache.Cache { - return nil +func (t *testDeploymentManager) AtLeastOne(model string) { + } -func (m *fakeManager) GetScheme() *runtime.Scheme { - s := runtime.NewScheme() - utilruntime.Must(clientgoscheme.AddToScheme(s)) - return s +type testEndpointManager struct { + address string + + requestedService string + requestedPort string } -func (m *fakeManager) Add(_ manager.Runnable) error { - return nil +func (t *testEndpointManager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) { + t.requestedService = service + t.requestedPort = portName + return t.address, nil } -func (m *fakeManager) GetLogger() logr.Logger { - return logr.Discard() +type testQueueManager struct { + requestedDeploymentName string + requestedID string } -func (m *fakeManager) GetControllerOptions() config.Controller { - return config.Controller{} +func (t *testQueueManager) EnqueueAndWait(ctx context.Context, deploymentName, id string) func() { + t.requestedDeploymentName = deploymentName + t.requestedID = id + return func() {} } -func (m *fakeManager) GetClient() client.Client { - return nil +func toMap(s []*io_prometheus_client.LabelPair) map[string]string { + r := make(map[string]string, len(s)) + for _, v := range s { + r[v.GetName()] = v.GetValue() + } + return r } From 5c523b9ef90231d7376c8df9b8be4729c3b5b00d Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Sat, 3 Feb 2024 09:53:09 -0500 Subject: [PATCH 06/13] Pass retry config through in main --- cmd/lingo/main.go | 2 ++ pkg/proxy/handler.go | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/cmd/lingo/main.go b/cmd/lingo/main.go index 7bfec8f0..d279c898 100644 --- a/cmd/lingo/main.go +++ b/cmd/lingo/main.go @@ -67,6 +67,7 @@ func run() error { concurrency := getEnvInt("CONCURRENCY", 100) scaleDownDelay := getEnvInt("SCALE_DOWN_DELAY", 30) + backendRetries := getEnvInt("BACKEND_RETRIES", 1) var metricsAddr string var probeAddr string @@ -154,6 +155,7 @@ func run() error { proxy.MustRegister(metricsRegistry) proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager) + proxyHandler.MaxRetries = backendRetries proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler} statsHandler := &stats.Handler{ diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index dbd8d264..2543e985 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -92,6 +92,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("Entering queue: %v", pr.id) + // Wait to until the request is admitted into the queue before proceeding with + // serving the request. complete := h.Queues.EnqueueAndWait(r.Context(), pr.backendDeployment, pr.id) defer complete() @@ -145,18 +147,22 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) { } proxy.ModifyResponse = func(r *http.Response) error { + // Record the response for metrics. pr.status = r.StatusCode + if h.isRetryCode(r.StatusCode) { // Returning an error will trigger the ErrorHandler. return errors.New("retry") } + return nil } proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { if err != nil && pr.attempt < h.MaxRetries { - log.Println("retrying") pr.attempt++ + + log.Printf("Retrying request (%v/%v): %v", pr.attempt, h.MaxRetries, pr.id) h.proxyHTTP(w, pr) return } From f7bbd289106316b75026c478ae067b0ecbbce899 Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Sat, 3 Feb 2024 10:39:13 -0500 Subject: [PATCH 07/13] Add test cases --- pkg/proxy/handler.go | 14 +++++-- pkg/proxy/handler_test.go | 88 ++++++++++++++++++++++++++------------- pkg/proxy/request.go | 4 +- 3 files changed, 72 insertions(+), 34 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index 2543e985..b5554455 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -150,15 +150,19 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) { // Record the response for metrics. pr.status = r.StatusCode - if h.isRetryCode(r.StatusCode) { + // This point is reached if a response code is received. + if h.isRetryCode(r.StatusCode) && pr.attempt < h.MaxRetries { // Returning an error will trigger the ErrorHandler. - return errors.New("retry") + return ErrRetry } return nil } proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + // This point could be reached if a bad response code was sent by the backend + // or + // if there was an issue with the connection and no response was ever received. if err != nil && pr.attempt < h.MaxRetries { pr.attempt++ @@ -167,13 +171,17 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) { return } - pr.sendErrorResponse(w, http.StatusBadGateway, "proxy: exceeded retries: %v/%v", pr.attempt, h.MaxRetries) + if !errors.Is(err, ErrRetry) { + pr.sendErrorResponse(w, http.StatusBadGateway, "proxy: exceeded retries: %v/%v", pr.attempt, h.MaxRetries) + } } log.Printf("Proxying request to host %v: %v\n", host, pr.id) proxy.ServeHTTP(w, pr.httpRequest()) } +var ErrRetry = errors.New("retry") + func (h *Handler) isRetryCode(status int) bool { var retry bool // TODO: avoid the nil check here and set a default map in the constructor. diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go index fca6a205..e4de6793 100644 --- a/pkg/proxy/handler_test.go +++ b/pkg/proxy/handler_test.go @@ -29,12 +29,12 @@ func TestHandler(t *testing.T) { } specs := map[string]struct { - reqMethod string - reqPath string - reqBody string + reqBody string + reqHeaders map[string]string - backendCode int - backendBody string + backendPanic bool + backendCode int + backendBody string expCode int expBody string @@ -42,11 +42,9 @@ func TestHandler(t *testing.T) { expBackendRequestCount int }{ "no model": { - reqMethod: http.MethodPost, - reqPath: "/", - reqBody: "{}", - expCode: http.StatusBadRequest, - expBody: `{"error":"unable to parse model: no model specified"}` + "\n", + reqBody: "{}", + expCode: http.StatusBadRequest, + expBody: `{"error":"unable to parse model: no model specified"}` + "\n", expLabels: map[string]string{ "model": "", "status_code": "400", @@ -54,20 +52,16 @@ func TestHandler(t *testing.T) { expBackendRequestCount: 0, }, "model not found": { - reqMethod: http.MethodPost, - reqPath: "/", - reqBody: `{"model":"does-not-exist"}`, - expCode: http.StatusNotFound, - expBody: `{"error":"model not found: does-not-exist"}` + "\n", + reqBody: `{"model":"does-not-exist"}`, + expCode: http.StatusNotFound, + expBody: `{"error":"model not found: does-not-exist"}` + "\n", expLabels: map[string]string{ "model": "does-not-exist", "status_code": "404", }, expBackendRequestCount: 0, }, - "happy 200": { - reqMethod: http.MethodPost, - reqPath: "/", + "happy 200 model in body": { reqBody: fmt.Sprintf(`{"model":%q}`, model1), backendCode: http.StatusOK, backendBody: `{"result":"ok"}`, @@ -79,23 +73,32 @@ func TestHandler(t *testing.T) { }, expBackendRequestCount: 1, }, + "happy 200 model in header": { + reqBody: "{}", + reqHeaders: map[string]string{"X-Model": model1}, + backendCode: http.StatusOK, + backendBody: `{"result":"ok"}`, + expCode: http.StatusOK, + expBody: `{"result":"ok"}`, + expLabels: map[string]string{ + "model": model1, + "status_code": "200", + }, + expBackendRequestCount: 1, + }, "retryable 500": { - reqMethod: http.MethodPost, - reqPath: "/", reqBody: fmt.Sprintf(`{"model":%q}`, model1), backendCode: http.StatusInternalServerError, backendBody: `{"err":"oh no!"}`, - expCode: http.StatusBadGateway, - expBody: `{"error":"Bad Gateway"}` + "\n", + expCode: http.StatusInternalServerError, + expBody: `{"err":"oh no!"}`, expLabels: map[string]string{ "model": model1, - "status_code": "502", + "status_code": "500", }, expBackendRequestCount: 1 + maxRetries, }, "not retryable 400": { - reqMethod: http.MethodPost, - reqPath: "/", reqBody: fmt.Sprintf(`{"model":%q}`, model1), backendCode: http.StatusBadRequest, backendBody: `{"err":"bad request"}`, @@ -107,6 +110,17 @@ func TestHandler(t *testing.T) { }, expBackendRequestCount: 1, }, + "good request but dropped connection": { + reqBody: fmt.Sprintf(`{"model":%q}`, model1), + backendPanic: true, + expCode: http.StatusBadGateway, + expBody: `{"error":"Bad Gateway"}` + "\n", + expLabels: map[string]string{ + "model": model1, + "status_code": "502", + }, + expBackendRequestCount: 1 + maxRetries, + }, } for name, spec := range specs { t.Run(name, func(t *testing.T) { @@ -119,6 +133,17 @@ func TestHandler(t *testing.T) { var backendRequestCount int backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendRequestCount++ + + bdy, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, spec.reqBody, string(bdy), "The request body should reach the backend") + + if spec.backendPanic { + // Panic should close connection. + // https://pkg.go.dev/net/http#Handler + panic("panicing on purpose") + } + if spec.backendCode != 0 { w.WriteHeader(spec.backendCode) } @@ -137,18 +162,21 @@ func TestHandler(t *testing.T) { // Issue request. client := &http.Client{} - req, err := http.NewRequest(spec.reqMethod, server.URL+spec.reqPath, strings.NewReader(spec.reqBody)) + req, err := http.NewRequest(http.MethodPost, server.URL, strings.NewReader(spec.reqBody)) require.NoError(t, err) + for k, v := range spec.reqHeaders { + req.Header.Add(k, v) + } resp, err := client.Do(req) - require.NoError(t, err) + require.NoError(t, err, "The client request should not fail") defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) // Assert on response. - assert.Equal(t, spec.expCode, resp.StatusCode) - assert.Equal(t, spec.expBody, string(respBody)) - assert.Equal(t, spec.expBackendRequestCount, backendRequestCount) + assert.Equal(t, spec.expCode, resp.StatusCode, "Unexpected response code to client") + assert.Equal(t, spec.expBody, string(respBody), "Unexpected response body to client") + assert.Equal(t, spec.expBackendRequestCount, backendRequestCount, "Unexpected number of requests sent to backend") // Assert on metrics. gathered, err := metricsRegistry.Gather() diff --git a/pkg/proxy/request.go b/pkg/proxy/request.go index 5f1e82ab..beb0216b 100644 --- a/pkg/proxy/request.go +++ b/pkg/proxy/request.go @@ -108,6 +108,8 @@ func (pr *proxyRequest) setStatus(w http.ResponseWriter, code int) { func (pr *proxyRequest) httpRequest() *http.Request { clone := pr.r.Clone(pr.r.Context()) - clone.Body = io.NopCloser(bytes.NewReader(pr.body)) + if pr.body != nil { + clone.Body = io.NopCloser(bytes.NewReader(pr.body)) + } return clone } From 09f6e28822a7fd2eb18e3975ffa21fb4f6ce4e9a Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Sat, 3 Feb 2024 10:41:33 -0500 Subject: [PATCH 08/13] Remove hack/ examples --- hack/failonceserver/main.go | 39 ------------------ hack/failserver/main.go | 15 ------- hack/retryproxy/main.go | 81 ------------------------------------- hack/successserver/main.go | 14 ------- 4 files changed, 149 deletions(-) delete mode 100644 hack/failonceserver/main.go delete mode 100644 hack/failserver/main.go delete mode 100644 hack/retryproxy/main.go delete mode 100644 hack/successserver/main.go diff --git a/hack/failonceserver/main.go b/hack/failonceserver/main.go deleted file mode 100644 index a063fb0d..00000000 --- a/hack/failonceserver/main.go +++ /dev/null @@ -1,39 +0,0 @@ -package main - -import ( - "fmt" - "io" - "net/http" - "os" - "sync" -) - -func main() { - // HTTP server that fails once and then succeeds for a given request path - var mtx sync.RWMutex - paths := map[string]bool{} - - http.ListenAndServe(":8082", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Println(r.URL.Path) - io.Copy(os.Stdout, r.Body) - fmt.Println("---") - - mtx.RLock() - shouldSucceed := paths[r.URL.Path] - mtx.RUnlock() - - defer func() { - mtx.Lock() - paths[r.URL.Path] = true - mtx.Unlock() - }() - - if !shouldSucceed { - w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte("failure\n")) - return - } - - w.Write([]byte("success\n")) - })) -} diff --git a/hack/failserver/main.go b/hack/failserver/main.go deleted file mode 100644 index e0eaaa35..00000000 --- a/hack/failserver/main.go +++ /dev/null @@ -1,15 +0,0 @@ -package main - -import ( - "io" - "net/http" - "os" -) - -func main() { - http.ListenAndServe(":8081", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.Copy(os.Stdout, r.Body) - w.WriteHeader(http.StatusServiceUnavailable) - w.Write([]byte("unavailable\n")) - })) -} diff --git a/hack/retryproxy/main.go b/hack/retryproxy/main.go deleted file mode 100644 index bf98aa70..00000000 --- a/hack/retryproxy/main.go +++ /dev/null @@ -1,81 +0,0 @@ -package main - -import ( - "bytes" - "errors" - "fmt" - "io" - "log" - "net/http" - "net/http/httputil" - "net/url" -) - -func main() { - var maxRetries = 1 - - http.ListenAndServe(":8080", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Println("serving") - - body, err := io.ReadAll(r.Body) - if err != nil { - panic(err) - } - r.Body.Close() - - fmt.Println("body:", string(body)) - - newProxy(body, 0, maxRetries).ServeHTTP(w, newRequest(r, body)) - })) - -} - -var errRetry = errors.New("retry") - -func newProxy(body []byte, attempt, maxRetries int) http.Handler { - // go run ./hack/failserver - u, err := url.Parse(getEndpoint(attempt)) - if err != nil { - panic(err) - } - proxy := httputil.NewSingleHostReverseProxy(u) - - proxy.ModifyResponse = func(r *http.Response) error { - if r.StatusCode == http.StatusServiceUnavailable { - // Returning an error will trigger the ErrorHandler. - return errRetry - } - return nil - } - - proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - if err != nil && attempt < maxRetries { - log.Println("retrying") - - // Simulate calling the next backend. - // go run ./hack/successserver - newProxy(body, attempt+1, maxRetries).ServeHTTP(w, newRequest(r, body)) - return - } - - log.Printf("http: proxy error: %v", err) - w.WriteHeader(http.StatusBadGateway) - } - - return proxy -} - -func getEndpoint(attempt int) string { - switch attempt { - case 0: - return "http://localhost:8081" - default: - return "http://localhost:8082" - } -} - -func newRequest(r *http.Request, body []byte) *http.Request { - clone := r.Clone(r.Context()) - clone.Body = io.NopCloser(bytes.NewReader(body)) - return clone -} diff --git a/hack/successserver/main.go b/hack/successserver/main.go deleted file mode 100644 index 2d656a9c..00000000 --- a/hack/successserver/main.go +++ /dev/null @@ -1,14 +0,0 @@ -package main - -import ( - "io" - "net/http" - "os" -) - -func main() { - http.ListenAndServe(":8082", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.Copy(os.Stdout, r.Body) - w.Write([]byte("success\n")) - })) -} From b09fd70d2bf78770de191afee5888dbbccef09f7 Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Sat, 3 Feb 2024 10:47:05 -0500 Subject: [PATCH 09/13] Add comments --- pkg/proxy/request.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/pkg/proxy/request.go b/pkg/proxy/request.go index beb0216b..5817faca 100644 --- a/pkg/proxy/request.go +++ b/pkg/proxy/request.go @@ -50,10 +50,15 @@ func newProxyRequest(r *http.Request) *proxyRequest { } +// done should be called when the original client request is complete. func (p *proxyRequest) done() { p.timer.ObserveDuration() } +// parseModel attempts to determine the model from the request. +// It first checks the "X-Model" header, and if that is not set, it +// attempts to unmarshal the request body as JSON and extract the +// .model field. func (pr *proxyRequest) parseModel() error { pr.model = pr.r.Header.Get("X-Model") if pr.model != "" { @@ -81,6 +86,9 @@ func (pr *proxyRequest) parseModel() error { return nil } +// sendErrorResponse sends an error response to the client and +// records the status code. If the status code is 5xx, the error +// message is not included in the response body. func (pr *proxyRequest) sendErrorResponse(w http.ResponseWriter, status int, format string, args ...interface{}) { msg := fmt.Sprintf(format, args...) log.Printf("sending error response: %v: %v", status, msg) @@ -106,6 +114,9 @@ func (pr *proxyRequest) setStatus(w http.ResponseWriter, code int) { w.WriteHeader(code) } +// httpRequest returns a new http.Request that is a clone of the original +// request, preserving the original request body even if it was already +// read (i.e. if the body was inspected to determine the model). func (pr *proxyRequest) httpRequest() *http.Request { clone := pr.r.Clone(pr.r.Context()) if pr.body != nil { From 0cf286608113b5ca2c7eff2cdf6f1be248403f9b Mon Sep 17 00:00:00 2001 From: Nick Stogner Date: Sat, 3 Feb 2024 10:50:02 -0500 Subject: [PATCH 10/13] Add assertion for host requests --- pkg/proxy/handler_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go index e4de6793..4d15d778 100644 --- a/pkg/proxy/handler_test.go +++ b/pkg/proxy/handler_test.go @@ -177,6 +177,7 @@ func TestHandler(t *testing.T) { assert.Equal(t, spec.expCode, resp.StatusCode, "Unexpected response code to client") assert.Equal(t, spec.expBody, string(respBody), "Unexpected response body to client") assert.Equal(t, spec.expBackendRequestCount, backendRequestCount, "Unexpected number of requests sent to backend") + assert.Equal(t, spec.expBackendRequestCount, endpoints.hostRequestCount, "Unexpected number of requests for backend hosts") // Assert on metrics. gathered, err := metricsRegistry.Gather() @@ -216,9 +217,12 @@ type testEndpointManager struct { requestedService string requestedPort string + + hostRequestCount int } func (t *testEndpointManager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) { + t.hostRequestCount++ t.requestedService = service t.requestedPort = portName return t.address, nil From c1beaf381ef3ca3c3b1170cd3df0a8429405fb63 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Wed, 31 Jan 2024 17:03:46 +0100 Subject: [PATCH 11/13] Add integration test --- tests/integration/integration_test.go | 137 +++++++++++++++++++++----- tests/integration/main_test.go | 1 + 2 files changed, 114 insertions(+), 24 deletions(-) diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index 10fb5e45..6325ba7a 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -3,6 +3,7 @@ package integration import ( "bytes" "fmt" + "io" "log" "net/http" "net/http/httptest" @@ -43,17 +44,7 @@ func TestScaleUpAndDown(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -103,17 +94,7 @@ func TestHandleModelUndeployment(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -132,7 +113,7 @@ func TestHandleModelUndeployment(t *testing.T) { require.NoError(t, testK8sClient.Delete(testCtx, deploy)) // Check that the deployment was deleted - err = testK8sClient.Get(testCtx, client.ObjectKey{ + err := testK8sClient.Get(testCtx, client.ObjectKey{ Namespace: deploy.Namespace, Name: deploy.Name, }, deploy) @@ -151,6 +132,100 @@ func TestHandleModelUndeployment(t *testing.T) { wg.Wait() } +func TestRetryMiddleware(t *testing.T) { + const modelName = "test-model-c" + deploy := testDeployment(modelName) + require.NoError(t, testK8sClient.Create(testCtx, deploy)) + + // Wait for deployment mapping to sync. + time.Sleep(3 * time.Second) + backendRequests := &atomic.Int32{} + var serverCodes []int + testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expBody := []byte(fmt.Sprintf(`{"model": %q}`, modelName)) + gotBody, err := io.ReadAll(r.Body) + require.NoError(t, err) + assert.Equal(t, expBody, gotBody) + + i := backendRequests.Add(1) + code := serverCodes[i-1] + t.Logf("Serving request from testBackend: %d; code: %d\n", i, code) + w.WriteHeader(code) + _, err = w.Write([]byte(strconv.Itoa(code))) + require.NoError(t, err) + })) + + // Mock an EndpointSlice. + withMockEndpointSlice(t, testBackend, modelName) + + specs := map[string]struct { + serverCodes []int + header []tuple + expResultCode int + expResultBody string + expBackendHits int32 + }{ + "max retries - succeeds": { + serverCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusOK}, + expResultCode: http.StatusOK, + expResultBody: "200", + expBackendHits: 4, + }, + "max retries - fails": { + serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway}, + expResultCode: http.StatusBadGateway, + expResultBody: "{\"error\":\"Bad Gateway\"}\n", // note the linebreak + expBackendHits: 4, + }, + "non retryable error code": { + serverCodes: []int{http.StatusNotImplemented}, + expResultCode: http.StatusNotImplemented, + expResultBody: "501", + expBackendHits: 1, + }, + "200 status code": { + serverCodes: []int{http.StatusOK}, + expResultCode: http.StatusOK, + expResultBody: "200", + expBackendHits: 1, + }, + "200 status code - model header": { + serverCodes: []int{http.StatusOK}, + header: []tuple{{k: "X-Model", v: modelName}}, + expResultCode: http.StatusOK, + expResultBody: "200", + expBackendHits: 1, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + // setup + serverCodes = spec.serverCodes + backendRequests.Store(0) + + // when single request sent + gotBody := <-sendRequest(t, &sync.WaitGroup{}, modelName, spec.expResultCode, spec.header...) + // then only the last body is written + assert.Equal(t, spec.expResultBody, gotBody) + require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit") + }) + } +} + +func withMockEndpointSlice(t *testing.T, testBackend *httptest.Server, modelName string) { + testBackendURL, err := url.Parse(testBackend.URL) + require.NoError(t, err) + testBackendPort, err := strconv.Atoi(testBackendURL.Port()) + require.NoError(t, err) + require.NoError(t, testK8sClient.Create(testCtx, + endpointSlice( + modelName, + testBackendURL.Hostname(), + int32(testBackendPort), + ), + )) +} + func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) { require.EventuallyWithT(t, func(t *assert.CollectT) { err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy) @@ -166,20 +241,34 @@ func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, exp } } -func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) { +type tuple struct { + k, v string +} + +func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int, headers ...tuple) <-chan string { t.Helper() wg.Add(1) + bodyRespChan := make(chan string, 1) go func() { defer wg.Done() + defer close(bodyRespChan) body := []byte(fmt.Sprintf(`{"model": %q}`, modelName)) req, err := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewReader(body)) requireNoError(err) + for _, e := range headers { + req.Header.Add(e.k, e.v) + } res, err := testHTTPClient.Do(req) require.NoError(t, err) require.Equal(t, expCode, res.StatusCode) + got, err := io.ReadAll(res.Body) + _ = res.Body.Close() + require.NoError(t, err) + bodyRespChan <- string(got) }() + return bodyRespChan } func completeRequests(c chan struct{}, n int) { diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 74697559..95f3c2d1 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -109,6 +109,7 @@ func TestMain(m *testing.M) { Deployments: deploymentManager, Endpoints: endpointManager, Queues: queueManager, + MaxRetries: 3, } testServer = httptest.NewServer(handler) defer testServer.Close() From c1418539ca2fec33455332f07770df10fe3348bc Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Mon, 5 Feb 2024 11:00:14 +0100 Subject: [PATCH 12/13] Add x-model header integration test --- tests/integration/integration_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index 6325ba7a..7e68ca57 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -174,7 +174,7 @@ func TestRetryMiddleware(t *testing.T) { "max retries - fails": { serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway}, expResultCode: http.StatusBadGateway, - expResultBody: "{\"error\":\"Bad Gateway\"}\n", // note the linebreak + expResultBody: "502", expBackendHits: 4, }, "non retryable error code": { @@ -196,6 +196,13 @@ func TestRetryMiddleware(t *testing.T) { expResultBody: "200", expBackendHits: 1, }, + "503 with model header": { + serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable}, + header: []tuple{{k: "X-Model", v: modelName}}, + expResultCode: http.StatusServiceUnavailable, + expResultBody: "503", + expBackendHits: 4, + }, } for name, spec := range specs { t.Run(name, func(t *testing.T) { From c2a30d5148601e3249ad9b5044732ed1c26b4b64 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Mon, 5 Feb 2024 11:28:15 +0100 Subject: [PATCH 13/13] Make integration test pass - naive approach --- pkg/proxy/handler.go | 4 +++- pkg/proxy/request.go | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index b5554455..f02e85a5 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -163,7 +163,9 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) { // This point could be reached if a bad response code was sent by the backend // or // if there was an issue with the connection and no response was ever received. - if err != nil && pr.attempt < h.MaxRetries { + if err != nil && + r.Context().Err() == nil && + pr.attempt < h.MaxRetries { pr.attempt++ log.Printf("Retrying request (%v/%v): %v", pr.attempt, h.MaxRetries, pr.id) diff --git a/pkg/proxy/request.go b/pkg/proxy/request.go index 5817faca..2d10604b 100644 --- a/pkg/proxy/request.go +++ b/pkg/proxy/request.go @@ -47,7 +47,6 @@ func newProxyRequest(r *http.Request) *proxyRequest { })) return pr - } // done should be called when the original client request is complete. @@ -61,16 +60,18 @@ func (p *proxyRequest) done() { // .model field. func (pr *proxyRequest) parseModel() error { pr.model = pr.r.Header.Get("X-Model") - if pr.model != "" { - return nil - } + // always buffer body var err error pr.body, err = io.ReadAll(pr.r.Body) if err != nil { return fmt.Errorf("read: %w", err) } + if pr.model != "" { + return nil + } + var payload struct { Model string `json:"model"` }