Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spike: Add optional retry middleware #51

Closed
wants to merge 12 commits into from
9 changes: 6 additions & 3 deletions cmd/lingo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ func run() error {
var metricsAddr string
var probeAddr string
var concurrencyPerReplica int
var maxRetriesOnErr int

flag.StringVar(&metricsAddr, "metrics-bind-address", ":8082", "The address the metric endpoint binds to.")
flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
flag.IntVar(&concurrencyPerReplica, "concurrency", concurrency, "the number of simultaneous requests that can be processed by each replica")
flag.IntVar(&scaleDownDelay, "scale-down-delay", scaleDownDelay, "seconds to wait before scaling down")
flag.IntVar(&maxRetriesOnErr, "max-retries", 0, "max number of retries on a http error code: 502,503,504")
opts := zap.Options{
Development: true,
}
Expand Down Expand Up @@ -113,17 +115,17 @@ func run() error {
if err != nil {
return fmt.Errorf("getting hostname: %v", err)
}

le := leader.NewElection(clientset, hostname, namespace)

queueManager := queue.NewManager(concurrencyPerReplica)
metricsRegistry := prometheus.WrapRegistererWithPrefix("lingo_", metrics.Registry)
queue.NewMetricsCollector(queueManager).MustRegister(metricsRegistry)

endpointManager, err := endpoints.NewManager(mgr)
endpointManager, err := endpoints.NewManager(mgr, queueManager.UpdateQueueSizeForReplicas)
if err != nil {
return fmt.Errorf("setting up endpoint manager: %w", err)
}
endpointManager.EndpointSizeCallback = queueManager.UpdateQueueSizeForReplicas
// The autoscaling leader will scrape other lingo instances.
// Exclude this instance from being scraped by itself.
endpointManager.ExcludePods[hostname] = struct{}{}
Expand Down Expand Up @@ -153,7 +155,8 @@ func run() error {
go autoscaler.Start()

proxy.MustRegister(metricsRegistry)
proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager)
var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager)
proxyHandler = proxy.NewRetryMiddleware(maxRetriesOnErr, proxyHandler)
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}

statsHandler := &stats.Handler{
Expand Down
19 changes: 19 additions & 0 deletions pkg/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package endpoints

import (
"context"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
)
Expand Down Expand Up @@ -127,3 +129,20 @@ func (g *endpointGroup) broadcastEndpoints() {
close(g.bcast)
g.bcast = make(chan struct{})
}

func (e *endpointGroup) AddInflight(addr string) (func(), error) {
tokens := strings.Split(addr, ":")
if len(tokens) != 2 {
return nil, errors.New("unsupported address format")
}
e.mtx.RLock()
defer e.mtx.RUnlock()
endpoint, ok := e.endpoints[tokens[0]]
if !ok {
return nil, errors.New("unsupported endpoint address")
}
endpoint.inFlight.Add(1)
return func() {
endpoint.inFlight.Add(-1)
}, nil
}
14 changes: 11 additions & 3 deletions pkg/endpoints/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
)

func NewManager(mgr ctrl.Manager) (*Manager, error) {
func NewManager(mgr ctrl.Manager, endpointSizeCallback func(deploymentName string, replicas int)) (*Manager, error) {
r := &Manager{}
r.Client = mgr.GetClient()
r.endpoints = map[string]*endpointGroup{}
r.ExcludePods = map[string]struct{}{}
r.EndpointSizeCallback = endpointSizeCallback
if err := r.SetupWithManager(mgr); err != nil {
return nil, err
}
Expand Down Expand Up @@ -86,6 +87,11 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result,
}
}

r.SetEndpoints(serviceName, ips, ports)
return ctrl.Result{}, nil
}

func (r *Manager) SetEndpoints(serviceName string, ips map[string]struct{}, ports map[string]int32) {
priorLen := r.getEndpoints(serviceName).lenIPs()
r.getEndpoints(serviceName).setIPs(ips, ports)

Expand All @@ -95,8 +101,6 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result,
// replicas by something else.
r.EndpointSizeCallback(serviceName, len(ips))
}

return ctrl.Result{}, nil
}

func (r *Manager) getEndpoints(service string) *endpointGroup {
Expand All @@ -122,3 +126,7 @@ func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string
func (r *Manager) GetAllHosts(service, portName string) []string {
return r.getEndpoints(service).getAllHosts(portName)
}

func (r *Manager) RegisterInFlight(service string, hostAddr string) (func(), error) {
return r.getEndpoints(service).AddInflight(hostAddr)
}
39 changes: 29 additions & 10 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,33 @@ import (
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"

"github.com/substratusai/lingo/pkg/deployments"
"github.com/substratusai/lingo/pkg/endpoints"
"github.com/substratusai/lingo/pkg/queue"
)

type deploymentSource interface {
ResolveDeployment(model string) (string, bool)
AtLeastOne(deploymentName string)
}

// Handler serves http requests for end-clients.
// It is also responsible for triggering scale-from-zero.
type Handler struct {
Deployments *deployments.Manager
Deployments deploymentSource
Endpoints *endpoints.Manager
Queues *queue.Manager
}

func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler {
func NewHandler(deployments deploymentSource, endpoints *endpoints.Manager, queues *queue.Manager) *Handler {
return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues}
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var modelName string
captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w)
captureStatusRespWriter := NewCaptureStatusCodeResponseWriter(w)
w = captureStatusRespWriter
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v)
httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.CapturedStatusCode())).Observe(v)
}))
defer timer.ObserveDuration()

Expand Down Expand Up @@ -105,9 +109,17 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
log.Printf("Got host: %v, id: %v\n", host, id)

done, err := h.Endpoints.RegisterInFlight(deploy, host)
if err != nil {
log.Printf("error registering in-flight request: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer done()

log.Printf("Proxying request to host %v: %v\n", host, id)
// TODO: Avoid creating new reverse proxies for each request.
// TODO: Consider implementing a round robin scheme.
log.Printf("Proxying request to host %v: %v\n", host, id)
newReverseProxy(host).ServeHTTP(w, proxyRequest)
}

Expand All @@ -117,10 +129,17 @@ func parseModel(r *http.Request) (string, *http.Request, error) {
if model := r.Header.Get("X-Model"); model != "" {
return model, r, nil
}
// parse request body for model name, ignore errors
body, err := io.ReadAll(r.Body)
if err != nil {
return "", r, nil
var body []byte
if mb, ok := r.Body.(CapturedBodySource); ok && mb.IsCaptured() {
// reuse buffer
body = mb.GetBody()
} else {
// parse request body for model name, ignore errors
var err error
body, err = io.ReadAll(r.Body)
if err != nil {
return "", r, nil
}
}

var payload struct {
Expand Down
75 changes: 75 additions & 0 deletions pkg/proxy/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package proxy

import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substratusai/lingo/pkg/endpoints"
"github.com/substratusai/lingo/pkg/queue"
)

func TestProxy(t *testing.T) {
specs := map[string]struct {
request *http.Request
expCode int
}{
"ok": {
request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)),
expCode: http.StatusOK,
},
}
for name, spec := range specs {
t.Run(name, func(t *testing.T) {
deplMgr := mockDeploymentSource{
ResolveDeploymentFn: func(model string) (string, bool) {
return "my-deployment", true
},
AtLeastOneFn: func(deploymentName string) {},
}
em, err := endpoints.NewManager(&fakeManager{}, func(deploymentName string, replicas int) {})
require.NoError(t, err)
em.SetEndpoints("my-deployment", map[string]struct{}{"my-ip": {}}, map[string]int32{"my-port": 8080})
h := NewHandler(deplMgr, em, queue.NewManager(10))

svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
em.SetEndpoints("my-deployment", map[string]struct{}{"my-other-ip": {}}, map[string]int32{"my-other-port": 8080})
w.WriteHeader(http.StatusOK)
}))
recorder := httptest.NewRecorder()

AdditionalProxyRewrite = func(r *httputil.ProxyRequest) {
r.SetURL(&url.URL{Scheme: "http", Host: svr.Listener.Addr().String()})
}

// when
h.ServeHTTP(recorder, spec.request)
// then
assert.Equal(t, spec.expCode, recorder.Code)
})
}
}

type mockDeploymentSource struct {
ResolveDeploymentFn func(model string) (string, bool)
AtLeastOneFn func(deploymentName string)
}

func (m mockDeploymentSource) ResolveDeployment(model string) (string, bool) {
if m.ResolveDeploymentFn == nil {
panic("not expected to be called")
}
return m.ResolveDeploymentFn(model)
}

func (m mockDeploymentSource) AtLeastOne(deploymentName string) {
if m.AtLeastOneFn == nil {
panic("not expected to be called")
}
m.AtLeastOneFn(deploymentName)
}
59 changes: 51 additions & 8 deletions pkg/proxy/metrics.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package proxy

import (
"io"
"net/http"

"github.com/prometheus/client_golang/prometheus"
Expand All @@ -12,21 +13,63 @@ var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Buckets: prometheus.DefBuckets,
}, []string{"model", "status_code"})

var totalRetries = prometheus.NewCounter(prometheus.CounterOpts{
Name: "http_request_retry_total",
Help: "Number of HTTP request retries.",
})

func MustRegister(r prometheus.Registerer) {
r.MustRegister(httpDuration)
r.MustRegister(httpDuration, totalRetries)
}

// CaptureStatusCodeResponseWriter is an interface that extends the http.ResponseWriter interface and provides a method for reading the status code of an HTTP response.
type CaptureStatusCodeResponseWriter interface {
http.ResponseWriter
StatusCodeCapturer
}

// captureStatusResponseWriter is a custom HTTP response writer that captures the status code.
// captureStatusResponseWriter is a custom HTTP response writer that implements CaptureStatusCodeResponseWriter
type captureStatusResponseWriter struct {
http.ResponseWriter
statusCode int
statusCode int
wroteHeader bool
}

func NewCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) CaptureStatusCodeResponseWriter {
if o, ok := responseWriter.(CaptureStatusCodeResponseWriter); ok { // nothing to do as code is captured already
return o
}
c := &captureStatusResponseWriter{ResponseWriter: responseWriter}
if _, ok := responseWriter.(io.ReaderFrom); ok {
return &captureStatusResponseWriterWithReadFrom{captureStatusResponseWriter: c}
}
return c
}

func (c *captureStatusResponseWriter) CapturedStatusCode() int {
return c.statusCode
}

func (c *captureStatusResponseWriter) WriteHeader(code int) {
c.wroteHeader = true
c.statusCode = code
c.ResponseWriter.WriteHeader(code)
}

func (c *captureStatusResponseWriter) Write(b []byte) (int, error) {
if !c.wroteHeader {
c.WriteHeader(http.StatusOK)
}
return c.ResponseWriter.Write(b)
}

func newCaptureStatusCodeResponseWriter(responseWriter http.ResponseWriter) *captureStatusResponseWriter {
return &captureStatusResponseWriter{ResponseWriter: responseWriter}
type captureStatusResponseWriterWithReadFrom struct {
*captureStatusResponseWriter
}

func (srw *captureStatusResponseWriter) WriteHeader(code int) {
srw.statusCode = code
srw.ResponseWriter.WriteHeader(code)
func (c *captureStatusResponseWriterWithReadFrom) ReadFrom(re io.Reader) (int64, error) {
if !c.wroteHeader {
c.WriteHeader(http.StatusOK)
}
return c.ResponseWriter.(io.ReaderFrom).ReadFrom(re)
}
Loading
Loading