Skip to content

Commit

Permalink
Better concurrent request handling for model host address (#38)
Browse files Browse the repository at this point in the history
Like in #36 the reconcile may be affected by external requests. This
refactoring helps by reducing lock conflicts

* Handle request context timeout
* Optimise concurrent access in endpoints type by separating r/w lock
and notification lock
* Added some tests and Go doc

I have also added a benchmark that shows that the new rwlock is ~30%
faster than before on my box. But this is all within ns and does not
really matter:
```
new: BenchmarkEndpointGroup-12    	 7667690	       154.6 ns/op
old: BenchmarkEndpointGroup-12    	 4968279	       234.2 ns/op
```
The key benefit of this PR is handling request timeout
  • Loading branch information
alpe authored Jan 11, 2024
1 parent e51569a commit cbfa863
Show file tree
Hide file tree
Showing 9 changed files with 291 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ jobs:
uses: actions/setup-go@v4
with:
go-version: '>=1.21.0'
- name: Run race tests
run: make test-race
- name: Run integration tests
run: make test-integration
14 changes: 14 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
ENVTEST_K8S_VERSION = 1.27.1

.PHONY: test
test: test-unit

.PHONY: test-all
test-all: test-race test-integration

.PHONY: test-unit
test-unit:
go test -mod=readonly ./pkg/...

.PHONY: test-race
test-race:
go test -mod=readonly -race ./pkg/...

.PHONY: test-integration
test-integration: envtest
KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./tests/integration -v

Expand Down
2 changes: 1 addition & 1 deletion pkg/autoscaler/autoscaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (a *Autoscaler) Start() {
log.Println("Calculating scales for all")

// TODO: Remove hardcoded Service lookup by name "lingo".
otherLingoEndpoints := a.Endpoints.GetAllHosts(context.Background(), "lingo", "stats")
otherLingoEndpoints := a.Endpoints.GetAllHosts("lingo", "stats")

stats, errs := aggregateStats(stats.Stats{
ActiveRequests: a.Queues.TotalCounts(),
Expand Down
68 changes: 50 additions & 18 deletions pkg/endpoints/endpoints.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package endpoints

import (
"context"
"fmt"
"sync"
"sync/atomic"
Expand All @@ -10,7 +11,7 @@ func newEndpointGroup() *endpointGroup {
e := &endpointGroup{}
e.ports = make(map[string]int32)
e.endpoints = make(map[string]endpoint)
e.active = sync.NewCond(&e.mtx)
e.bcast = make(chan struct{})
return e
}

Expand All @@ -19,20 +20,37 @@ type endpoint struct {
}

type endpointGroup struct {
mtx sync.RWMutex
ports map[string]int32
endpoints map[string]endpoint
active *sync.Cond
mtx sync.Mutex
}

func (e *endpointGroup) getHost(portName string) string {
e.mtx.Lock()
defer e.mtx.Unlock()
bmtx sync.RWMutex
bcast chan struct{} // closed when there's a broadcast
}

// getBestHost returns the best host for the given port name. It blocks until there are available endpoints
// in the endpoint group.
//
// It selects the host with the minimum in-flight requests among all the available endpoints.
// The host is returned as a string in the format "IP:Port".
//
// Parameters:
// - portName: The name of the port for which the best host needs to be determined.
//
// Returns:
// - string: The best host with the minimum in-flight requests.
func (e *endpointGroup) getBestHost(ctx context.Context, portName string) (string, error) {
e.mtx.RLock()
// await endpoints exists
for len(e.endpoints) == 0 {
e.active.Wait()
e.mtx.RUnlock()
select {
case <-e.awaitEndpoints():
case <-ctx.Done():
return "", ctx.Err()
}
e.mtx.RLock()
}

var bestIP string
port := e.getPort(portName)
var minInFlight int
Expand All @@ -43,13 +61,19 @@ func (e *endpointGroup) getHost(portName string) string {
minInFlight = inFlight
}
}
e.mtx.RUnlock()
return fmt.Sprintf("%s:%v", bestIP, port), nil
}

return fmt.Sprintf("%s:%v", bestIP, port)
func (e *endpointGroup) awaitEndpoints() chan struct{} {
e.bmtx.RLock()
defer e.bmtx.RUnlock()
return e.bcast
}

func (e *endpointGroup) getAllHosts(portName string) []string {
e.mtx.Lock()
defer e.mtx.Unlock()
e.mtx.RLock()
defer e.mtx.RUnlock()

var hosts []string
port := e.getPort(portName)
Expand All @@ -70,15 +94,13 @@ func (e *endpointGroup) getPort(portName string) int32 {
}

func (g *endpointGroup) lenIPs() int {
g.mtx.Lock()
defer g.mtx.Unlock()
g.mtx.RLock()
defer g.mtx.RUnlock()
return len(g.endpoints)
}

func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32) {
g.mtx.Lock()
defer g.mtx.Unlock()

g.ports = ports
for ip := range ips {
if _, ok := g.endpoints[ip]; !ok {
Expand All @@ -90,8 +112,18 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32)
delete(g.endpoints, ip)
}
}
g.mtx.Unlock()

if len(g.endpoints) > 0 {
g.active.Broadcast()
// notify waiting requests
if len(ips) > 0 {
g.broadcastEndpoints()
}
}

func (g *endpointGroup) broadcastEndpoints() {
g.bmtx.Lock()
defer g.bmtx.Unlock()

close(g.bcast)
g.bcast = make(chan struct{})
}
117 changes: 117 additions & 0 deletions pkg/endpoints/endpoints_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package endpoints

import (
"context"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"k8s.io/apimachinery/pkg/util/rand"
)

func TestConcurrentAccess(t *testing.T) {
const myService = "myService"
const myPort = "myPort"

testCases := map[string]struct {
readerCount int
writerCount int
}{
"lot of reader": {readerCount: 1_000, writerCount: 1},
"lot of writer": {readerCount: 1, writerCount: 1_000},
"lot of both": {readerCount: 1_000, writerCount: 1_000},
}
for name, spec := range testCases {
randomReadFn := []func(g *endpointGroup){
func(g *endpointGroup) { g.getBestHost(nil, myPort) },
func(g *endpointGroup) { g.getAllHosts(myPort) },
func(g *endpointGroup) { g.lenIPs() },
}
t.Run(name, func(t *testing.T) {
// setup endpoint with one service so that requests are not waiting
endpoint := newEndpointGroup()
endpoint.setIPs(
map[string]struct{}{myService: {}},
map[string]int32{myPort: 1},
)

var startWg, doneWg sync.WaitGroup
startTogether := func(n int, f func()) {
startWg.Add(n)
doneWg.Add(n)
for i := 0; i < n; i++ {
go func() {
startWg.Done()
startWg.Wait()
f()
doneWg.Done()
}()
}
}
// when
startTogether(spec.readerCount, func() { randomReadFn[rand.Intn(len(randomReadFn)-1)](endpoint) })
startTogether(spec.writerCount, func() {
endpoint.setIPs(
map[string]struct{}{rand.String(1): {}},
map[string]int32{rand.String(1): 1},
)
})
doneWg.Wait()
})
}
}

func TestBlockAndWaitForEndpoints(t *testing.T) {
var completed atomic.Int32
var startWg, doneWg sync.WaitGroup
startTogether := func(n int, f func()) {
startWg.Add(n)
doneWg.Add(n)
for i := 0; i < n; i++ {
go func() {
startWg.Done()
startWg.Wait()
f()
completed.Add(1)
doneWg.Done()
}()
}
}
endpoint := newEndpointGroup()
ctx := context.TODO()
startTogether(100, func() {
endpoint.getBestHost(ctx, rand.String(4))
})
startWg.Wait()

// when broadcast triggered
endpoint.setIPs(
map[string]struct{}{rand.String(4): {}},
map[string]int32{rand.String(4): 1},
)
// then
doneWg.Wait()
assert.Equal(t, int32(100), completed.Load())
}

func TestAbortOnCtxCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())

var startWg, doneWg sync.WaitGroup
startWg.Add(1)
doneWg.Add(1)
go func(t *testing.T) {
startWg.Wait()
endpoint := newEndpointGroup()
_, err := endpoint.getBestHost(ctx, rand.String(4))
require.Error(t, err)
doneWg.Done()
}(t)
startWg.Done()
cancel()

doneWg.Wait()
}
20 changes: 20 additions & 0 deletions pkg/endpoints/endponts_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package endpoints

import (
"context"
"testing"
)

func BenchmarkEndpointGroup(b *testing.B) {
e := newEndpointGroup()
e.setIPs(map[string]struct{}{"10.0.0.1": {}}, map[string]int32{"testPort": 1})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := e.getBestHost(context.Background(), "testPort")
if err != nil {
b.Fatal(err)
}
}
})
}
11 changes: 8 additions & 3 deletions pkg/endpoints/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,15 @@ func (r *Manager) getEndpoints(service string) *endpointGroup {
return e
}

func (r *Manager) GetHost(ctx context.Context, service, portName string) string {
return r.getEndpoints(service).getHost(portName)
// AwaitHostAddress returns the host address with the lowest number of in-flight requests. It will block until the host address
// becomes available or the context times out.
//
// It returns a string in the format "host:port" or error on timeout
func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) {
return r.getEndpoints(service).getBestHost(ctx, portName)
}

func (r *Manager) GetAllHosts(ctx context.Context, service, portName string) []string {
// GetAllHosts retrieves the list of all hosts for a given service and port.
func (r *Manager) GetAllHosts(service, portName string) []string {
return r.getEndpoints(service).getAllHosts(portName)
}
59 changes: 59 additions & 0 deletions pkg/endpoints/manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package endpoints

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAwaitBestHost(t *testing.T) {
const myService = "myService"
const myPort = "myPort"

manager := &Manager{endpoints: make(map[string]*endpointGroup, 1)}
manager.getEndpoints(myService).
setIPs(map[string]struct{}{myService: {}}, map[string]int32{myPort: 1})

testCases := map[string]struct {
service string
portName string
timeout time.Duration
expErr error
}{
"all good": {
service: myService,
portName: myPort,
timeout: time.Millisecond,
},
"unknown port - returns default if only 1": {
service: myService,
portName: "unknownPort",
timeout: time.Millisecond,
},
"unknown service - blocks until timeout": {
service: "unknownService",
portName: myPort,
timeout: time.Millisecond,
expErr: context.DeadlineExceeded,
},
// not covered: unknown port with multiple ports on entrypoint
}

for name, spec := range testCases {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), spec.timeout)
defer cancel()

gotHost, gotErr := manager.AwaitHostAddress(ctx, spec.service, spec.portName)
if spec.expErr != nil {
require.ErrorIs(t, spec.expErr, gotErr)
return
}
require.NoError(t, gotErr)
assert.Equal(t, myService+":1", gotHost)
})
}
}
Loading

0 comments on commit cbfa863

Please sign in to comment.