Skip to content

Commit

Permalink
Add retry middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Jan 10, 2024
1 parent 4b04877 commit e9486b5
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 6 deletions.
4 changes: 3 additions & 1 deletion cmd/lingo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func run() error {
if err != nil {
return fmt.Errorf("getting hostname: %v", err)
}

le := leader.NewElection(clientset, hostname, namespace)

queueManager := queue.NewManager(concurrencyPerReplica)
Expand Down Expand Up @@ -148,7 +149,8 @@ func run() error {
go autoscaler.Start()

proxy.MustRegister(metricsRegistry)
proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager)
var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager)
proxyHandler = proxy.NewRetryMiddleware(3, proxyHandler)
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}

statsHandler := &stats.Handler{
Expand Down
7 changes: 6 additions & 1 deletion pkg/proxy/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Buckets: prometheus.DefBuckets,
}, []string{"model", "status_code"})

var totalRetries = prometheus.NewCounter(prometheus.CounterOpts{
Name: "http_request_retry_total",
Help: "Number of HTTP request retries.",
})

func MustRegister(r prometheus.Registerer) {
r.MustRegister(httpDuration)
r.MustRegister(httpDuration, totalRetries)
}

// captureStatusResponseWriter is a custom HTTP response writer that captures the status code.
Expand Down
9 changes: 5 additions & 4 deletions pkg/proxy/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ func TestMetrics(t *testing.T) {
assert.Equal(t, spec.expCode, recorder.Code)
gathered, err := registry.Gather()
require.NoError(t, err)
require.Len(t, gathered, 1)
require.Len(t, gathered[0].Metric, 1)
assert.NotEmpty(t, gathered[0].Metric[0].GetHistogram().GetSampleCount())
assert.Equal(t, spec.expLabels, toMap(gathered[0].Metric[0].Label))
require.Len(t, gathered, 2)
require.Equal(t, "http_response_time_seconds", *gathered[1].Name)
require.Len(t, gathered[1].Metric, 1)
assert.NotEmpty(t, gathered[1].Metric[0].GetHistogram().GetSampleCount())
assert.Equal(t, spec.expLabels, toMap(gathered[1].Metric[0].Label))
})
}
}
Expand Down
77 changes: 77 additions & 0 deletions pkg/proxy/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package proxy

import (
"bytes"
"log"
"math/rand"
"net/http"
"time"
)

var _ http.Handler = &RetryMiddleware{}

type RetryMiddleware struct {
other http.Handler
MaxRetries int
rSource *rand.Rand
}

func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware {
if maxRetries < 1 {
panic("invalid retries")
}
return &RetryMiddleware{
other: other,
MaxRetries: maxRetries,
rSource: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}

func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
var capturedResp *responseBuffer
for i := 0; ; i++ {
capturedResp = &responseBuffer{
header: make(http.Header),
body: bytes.NewBuffer([]byte{}),
}
// call next handler in chain
r.other.ServeHTTP(capturedResp, request.Clone(request.Context()))

if i == r.MaxRetries || // max retries reached
request.Context().Err() != nil || // abort early on timeout, context cancel
capturedResp.status != http.StatusBadGateway &&
capturedResp.status != http.StatusServiceUnavailable {
break
}
totalRetries.Inc()
// Exponential backoff
jitter := time.Duration(r.rSource.Intn(50))
time.Sleep(time.Millisecond*time.Duration(1<<uint(i)) + jitter)
}
for k, v := range capturedResp.header {
writer.Header()[k] = v
}
writer.WriteHeader(capturedResp.status)
if _, err := capturedResp.body.WriteTo(writer); err != nil {
log.Printf("response write: %v", err)
}
}

type responseBuffer struct {
header http.Header
body *bytes.Buffer
status int
}

func (rb *responseBuffer) Header() http.Header {
return rb.header
}

func (r *responseBuffer) WriteHeader(status int) {
r.status = status
r.header = r.Header().Clone()
}

func (rb *responseBuffer) Write(data []byte) (int, error) {
return rb.body.Write(data)
}
92 changes: 92 additions & 0 deletions pkg/proxy/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package proxy

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestServeHTTP(t *testing.T) {
specs := map[string]struct {
context func() context.Context
maxRetries int
respStatus int
expRetries int
}{
"no retry on 200": {
context: func() context.Context { return context.TODO() },
maxRetries: 3,
respStatus: http.StatusOK,
expRetries: 0,
},
"no retry on 500": {
context: func() context.Context { return context.TODO() },
maxRetries: 3,
respStatus: http.StatusInternalServerError,
expRetries: 0,
},
"max retries on 503": {
context: func() context.Context { return context.TODO() },
maxRetries: 3,
respStatus: http.StatusServiceUnavailable,
expRetries: 3,
},
"max retries on 502": {
context: func() context.Context { return context.TODO() },
maxRetries: 3,
respStatus: http.StatusBadGateway,
expRetries: 3,
},
"context cancelled": {
context: func() context.Context {
ctx, cancel := context.WithCancel(context.TODO())
cancel()
return ctx
},
maxRetries: 3,
respStatus: http.StatusBadGateway,
expRetries: 0,
},
}
for name, spec := range specs {
t.Run(name, func(t *testing.T) {
counterBefore := counterValue(t, totalRetries)
req, err := http.NewRequestWithContext(spec.context(), "GET", "/test", nil)
require.NoError(t, err)

respRecorder := httptest.NewRecorder()

var counter int
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
counter++
w.WriteHeader(spec.respStatus)
})

// when
middleware := NewRetryMiddleware(spec.maxRetries, testHandler)
middleware.ServeHTTP(respRecorder, req)

// then
resp := respRecorder.Result()
require.Equal(t, spec.respStatus, resp.StatusCode)
assert.Equal(t, spec.expRetries, counter-1)

assert.Equal(t, spec.expRetries, int(counterValue(t, totalRetries)-counterBefore))
})
}
}

func counterValue(t *testing.T, counter prometheus.Counter) float64 {
registry := prometheus.NewPedanticRegistry()
registry.MustRegister(counter)
gathered, err := registry.Gather()
require.NoError(t, err)
require.Len(t, gathered, 1)
require.Len(t, gathered[0].Metric, 1)
return gathered[0].Metric[0].GetCounter().GetValue()
}

0 comments on commit e9486b5

Please sign in to comment.