Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spike: Add optional retry middleware #51

Closed
wants to merge 12 commits into from
7 changes: 4 additions & 3 deletions cmd/lingo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,17 @@ func run() error {
if err != nil {
return fmt.Errorf("getting hostname: %v", err)
}

le := leader.NewElection(clientset, hostname, namespace)

queueManager := queue.NewManager(concurrencyPerReplica)
metricsRegistry := prometheus.WrapRegistererWithPrefix("lingo_", metrics.Registry)
queue.NewMetricsCollector(queueManager).MustRegister(metricsRegistry)

endpointManager, err := endpoints.NewManager(mgr)
endpointManager, err := endpoints.NewManager(mgr, queueManager.UpdateQueueSizeForReplicas)
if err != nil {
return fmt.Errorf("setting up endpoint manager: %w", err)
}
endpointManager.EndpointSizeCallback = queueManager.UpdateQueueSizeForReplicas
// The autoscaling leader will scrape other lingo instances.
// Exclude this instance from being scraped by itself.
endpointManager.ExcludePods[hostname] = struct{}{}
Expand Down Expand Up @@ -153,7 +153,8 @@ func run() error {
go autoscaler.Start()

proxy.MustRegister(metricsRegistry)
proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager)
var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager)
proxyHandler = proxy.NewRetryMiddleware(3, proxyHandler)
alpe marked this conversation as resolved.
Show resolved Hide resolved
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}

statsHandler := &stats.Handler{
Expand Down
27 changes: 24 additions & 3 deletions pkg/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package endpoints

import (
"context"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
)
Expand All @@ -16,7 +18,8 @@ func newEndpointGroup() *endpointGroup {
}

type endpoint struct {
inFlight *atomic.Int64
inFlight *atomic.Int64
terminated chan struct{}
alpe marked this conversation as resolved.
Show resolved Hide resolved
}

type endpointGroup struct {
Expand Down Expand Up @@ -104,12 +107,13 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32)
g.ports = ports
for ip := range ips {
if _, ok := g.endpoints[ip]; !ok {
g.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}}
g.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}, terminated: make(chan struct{})}
}
}
for ip := range g.endpoints {
for ip, endpoint := range g.endpoints {
if _, ok := ips[ip]; !ok {
delete(g.endpoints, ip)
close(endpoint.terminated)
}
}
g.mtx.Unlock()
Expand All @@ -127,3 +131,20 @@ func (g *endpointGroup) broadcastEndpoints() {
close(g.bcast)
g.bcast = make(chan struct{})
}

// AddInflight marks one more request as in flight on the endpoint that addr
// points at and hands back a function that undoes the increment.
//
// addr must consist of exactly two colon-separated parts ("host:port"); only
// the host part is used, as the key into the group's endpoint map. An error is
// returned when addr has any other shape or when the host is not a current
// endpoint of the group; the returned function is nil in both cases.
// NOTE(review): a bracketed IPv6 "[::1]:8080" contains more than one colon and
// is therefore rejected by this check.
//
// The release function decrements on every call, so it must be called exactly
// once. It operates on the counter captured at registration time, so it stays
// usable after the endpoint has been dropped from the group.
func (g *endpointGroup) AddInflight(addr string) (func(), error) {
	tokens := strings.Split(addr, ":")
	if len(tokens) != 2 {
		return nil, errors.New("unsupported address format")
	}
	host := tokens[0]

	// Read lock only: the map is looked up, not modified; the counter itself
	// is atomic.
	g.mtx.RLock()
	defer g.mtx.RUnlock()
	ep, ok := g.endpoints[host]
	if !ok {
		return nil, errors.New("unsupported endpoint address")
	}
	ep.inFlight.Add(1)
	return func() {
		ep.inFlight.Add(-1)
	}, nil
}
14 changes: 11 additions & 3 deletions pkg/endpoints/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
)

func NewManager(mgr ctrl.Manager) (*Manager, error) {
func NewManager(mgr ctrl.Manager, endpointSizeCallback func(deploymentName string, replicas int)) (*Manager, error) {
r := &Manager{}
r.Client = mgr.GetClient()
r.endpoints = map[string]*endpointGroup{}
r.ExcludePods = map[string]struct{}{}
r.EndpointSizeCallback = endpointSizeCallback
if err := r.SetupWithManager(mgr); err != nil {
return nil, err
}
Expand Down Expand Up @@ -86,6 +87,11 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result,
}
}

r.SetEndpoints(serviceName, ips, ports)
return ctrl.Result{}, nil
}

func (r *Manager) SetEndpoints(serviceName string, ips map[string]struct{}, ports map[string]int32) {
priorLen := r.getEndpoints(serviceName).lenIPs()
r.getEndpoints(serviceName).setIPs(ips, ports)

Expand All @@ -95,8 +101,6 @@ func (r *Manager) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result,
// replicas by something else.
r.EndpointSizeCallback(serviceName, len(ips))
}

return ctrl.Result{}, nil
}

func (r *Manager) getEndpoints(service string) *endpointGroup {
Expand All @@ -122,3 +126,7 @@ func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string
func (r *Manager) GetAllHosts(service, portName string) []string {
return r.getEndpoints(service).getAllHosts(portName)
}

// RegisterInFlight looks up the endpoint group for service and forwards
// hostAddr to its AddInflight method, returning that call's release function
// and error unchanged.
func (r *Manager) RegisterInFlight(service string, hostAddr string) (func(), error) {
	group := r.getEndpoints(service)
	return group.AddInflight(hostAddr)
}
36 changes: 32 additions & 4 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,24 @@ import (
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"

"github.com/substratusai/lingo/pkg/deployments"
"github.com/substratusai/lingo/pkg/endpoints"
"github.com/substratusai/lingo/pkg/queue"
)

// deploymentSource is the set of deployment operations the proxy Handler
// depends on. Declaring it here lets the Handler accept any implementation,
// including a test double.
type deploymentSource interface {
	// AtLeastOne is called with a deployment name.
	// NOTE(review): presumably ensures the deployment has at least one
	// replica (scale-from-zero) — confirm against the implementation.
	AtLeastOne(deploymentName string)

	// ResolveDeployment takes a model name and returns a string and a bool.
	// NOTE(review): presumably the deployment name and whether the model was
	// found — confirm against the implementation.
	ResolveDeployment(model string) (string, bool)
}

// Handler serves http requests for end-clients.
// It is also responsible for triggering scale-from-zero.
type Handler struct {
Deployments *deployments.Manager
Deployments deploymentSource
Endpoints *endpoints.Manager
Queues *queue.Manager
}

func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler {
func NewHandler(deployments deploymentSource, endpoints *endpoints.Manager, queues *queue.Manager) *Handler {
return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues}
}

Expand Down Expand Up @@ -105,12 +109,36 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
log.Printf("Got host: %v, id: %v\n", host, id)

done, err := h.Endpoints.RegisterInFlight(deploy, host)
if err != nil {
log.Printf("error registering in-flight request: %v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer done()

log.Printf("Proxying request to host %v: %v\n", host, id)
// TODO: Avoid creating new reverse proxies for each request.
// TODO: Consider implementing a round robin scheme.
log.Printf("Proxying request to host %v: %v\n", host, id)
newReverseProxy(host).ServeHTTP(w, proxyRequest)
}

// withInflightCounted builds a middleware constructor bound to one deployment
// and host. The handler it produces registers an in-flight request with the
// endpoints manager before delegating to the wrapped handler and releases the
// registration once the wrapped handler returns. If registration fails, the
// error is logged, a 500 status is written, and the wrapped handler is not
// invoked.
func withInflightCounted(endpoints *endpoints.Manager, deploy string, host string) func(other http.Handler) http.HandlerFunc {
	wrap := func(next http.Handler) http.HandlerFunc {
		counted := func(rw http.ResponseWriter, req *http.Request) {
			release, regErr := endpoints.RegisterInFlight(deploy, host)
			if regErr == nil {
				// Release only after the downstream handler has finished.
				defer release()
				next.ServeHTTP(rw, req)
				return
			}
			log.Printf("error registering in-flight request: %v", regErr)
			rw.WriteHeader(http.StatusInternalServerError)
		}
		return counted
	}
	return wrap
}

// parseModel parses the model name from the request
// returns empty string when none found or an error for failures on the proxy request object
func parseModel(r *http.Request) (string, *http.Request, error) {
Expand Down
75 changes: 75 additions & 0 deletions pkg/proxy/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package proxy

import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substratusai/lingo/pkg/endpoints"
"github.com/substratusai/lingo/pkg/queue"
)

// TestProxy drives Handler.ServeHTTP against a local httptest backend and
// checks the status code recorded for the client.
//
// The backend handler replaces the deployment's endpoint set while the proxied
// request is still being served, so the handler's deferred in-flight release
// runs for an endpoint that is no longer part of the group.
func TestProxy(t *testing.T) {
	specs := map[string]struct {
		request *http.Request
		expCode int
	}{
		"ok": {
			request: httptest.NewRequest(http.MethodGet, "/", strings.NewReader(`{"model":"my_model"}`)),
			expCode: http.StatusOK,
		},
	}
	for name, spec := range specs {
		t.Run(name, func(t *testing.T) {
			deplMgr := mockDeploymentSource{
				ResolveDeploymentFn: func(model string) (string, bool) {
					return "my-deployment", true
				},
				AtLeastOneFn: func(deploymentName string) {},
			}
			em, err := endpoints.NewManager(&fakeManager{}, func(deploymentName string, replicas int) {})
			require.NoError(t, err)
			em.SetEndpoints("my-deployment", map[string]struct{}{"my-ip": {}}, map[string]int32{"my-port": 8080})
			h := NewHandler(deplMgr, em, queue.NewManager(10))

			svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Swap the endpoint set mid-request.
				em.SetEndpoints("my-deployment", map[string]struct{}{"my-other-ip": {}}, map[string]int32{"my-other-port": 8080})
				w.WriteHeader(http.StatusOK)
			}))
			// Shut the backend down so its listener and goroutines do not
			// outlive the subtest.
			defer svr.Close()
			recorder := httptest.NewRecorder()

			// AdditionalProxyRewrite is package-level state; restore it so the
			// override does not leak into other tests.
			prevRewrite := AdditionalProxyRewrite
			defer func() { AdditionalProxyRewrite = prevRewrite }()
			AdditionalProxyRewrite = func(r *httputil.ProxyRequest) {
				r.SetURL(&url.URL{Scheme: "http", Host: svr.Listener.Addr().String()})
			}

			// when
			h.ServeHTTP(recorder, spec.request)
			// then
			assert.Equal(t, spec.expCode, recorder.Code)
		})
	}
}

// mockDeploymentSource is a function-backed test double: each method forwards
// to the corresponding *Fn field and panics when that field is nil, so a test
// fails loudly on any call it did not set up.
type mockDeploymentSource struct {
	ResolveDeploymentFn func(model string) (string, bool)
	AtLeastOneFn        func(deploymentName string)
}

// ResolveDeployment forwards to ResolveDeploymentFn and panics if it is unset.
func (m mockDeploymentSource) ResolveDeployment(model string) (string, bool) {
	if fn := m.ResolveDeploymentFn; fn != nil {
		return fn(model)
	}
	panic("not expected to be called")
}

// AtLeastOne forwards to AtLeastOneFn and panics if it is unset.
func (m mockDeploymentSource) AtLeastOne(deploymentName string) {
	fn := m.AtLeastOneFn
	if fn == nil {
		panic("not expected to be called")
	}
	fn(deploymentName)
}
7 changes: 6 additions & 1 deletion pkg/proxy/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Buckets: prometheus.DefBuckets,
}, []string{"model", "status_code"})

// totalRetries is the counter tracking the number of HTTP request retries.
var totalRetries = prometheus.NewCounter(
	prometheus.CounterOpts{
		Help: "Number of HTTP request retries.",
		Name: "http_request_retry_total",
	},
)

func MustRegister(r prometheus.Registerer) {
r.MustRegister(httpDuration)
r.MustRegister(httpDuration, totalRetries)
}

// captureStatusResponseWriter is a custom HTTP response writer that captures the status code.
Expand Down
9 changes: 5 additions & 4 deletions pkg/proxy/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ func TestMetrics(t *testing.T) {
assert.Equal(t, spec.expCode, recorder.Code)
gathered, err := registry.Gather()
require.NoError(t, err)
require.Len(t, gathered, 1)
require.Len(t, gathered[0].Metric, 1)
assert.NotEmpty(t, gathered[0].Metric[0].GetHistogram().GetSampleCount())
assert.Equal(t, spec.expLabels, toMap(gathered[0].Metric[0].Label))
require.Len(t, gathered, 2)
require.Equal(t, "http_response_time_seconds", *gathered[1].Name)
require.Len(t, gathered[1].Metric, 1)
assert.NotEmpty(t, gathered[1].Metric[0].GetHistogram().GetSampleCount())
assert.Equal(t, spec.expLabels, toMap(gathered[1].Metric[0].Label))
})
}
}
Expand Down
Loading
Loading