diff --git a/pkg/deployments/manager.go b/pkg/deployments/manager.go index a008dd70..40a295f8 100644 --- a/pkg/deployments/manager.go +++ b/pkg/deployments/manager.go @@ -115,6 +115,7 @@ func getModelsFromAnnotation(ann map[string]string) []string { } func (r *Manager) removeDeployment(req ctrl.Request) { + r.getScaler(req.Name).StopScaleDownTimer() r.scalersMtx.Lock() delete(r.scalers, req.Name) r.scalersMtx.Unlock() diff --git a/pkg/deployments/manager_test.go b/pkg/deployments/manager_test.go index 12cba718..e2da822f 100644 --- a/pkg/deployments/manager_test.go +++ b/pkg/deployments/manager_test.go @@ -4,6 +4,11 @@ import ( "context" "reflect" "testing" + "time" + + "k8s.io/apimachinery/pkg/types" + + "sigs.k8s.io/controller-runtime/pkg/reconcile" appsv1 "k8s.io/api/apps/v1" autoscalingv1 "k8s.io/api/autoscaling/v1" @@ -136,6 +141,79 @@ func TestAddDeployment(t *testing.T) { } } +func TestRemoveDeployment(t *testing.T) { + const myDeployment = "myDeployment" + specs := map[string]struct { + setup func(t *testing.T, m *Manager) + assert func(t *testing.T, m *Manager) + }{ + "single model deployment": { + setup: func(t *testing.T, m *Manager) { + m.modelToDeployment["model1"] = myDeployment + m.scalers[myDeployment] = &scaler{} + }, + assert: func(t *testing.T, m *Manager) { + assert.Len(t, m.modelToDeployment, 0) + assert.Len(t, m.scalers, 0) + }, + }, + "multi model deployment": { + setup: func(t *testing.T, m *Manager) { + m.modelToDeployment["model1"] = myDeployment + m.modelToDeployment["model2"] = myDeployment + m.modelToDeployment["other"] = "other" + m.scalers[myDeployment] = &scaler{} + m.scalers["other"] = &scaler{} + }, + assert: func(t *testing.T, m *Manager) { + assert.Equal(t, map[string]string{"other": "other"}, m.modelToDeployment) + assert.Equal(t, map[string]*scaler{"other": {}}, m.scalers) + }, + }, + "unknown deployment - ignored": { + setup: func(t *testing.T, m *Manager) { + m.modelToDeployment["other"] = "other" + m.scalers["other"] = &scaler{} + }, + assert: func(t *testing.T, m *Manager) { + assert.Equal(t, map[string]string{"other": "other"}, m.modelToDeployment) + assert.Equal(t, map[string]*scaler{"other": {}}, m.scalers) + }, + }, + "scale down timer stopped": { + setup: func(t *testing.T, m *Manager) { + m.modelToDeployment["model1"] = myDeployment + m.scalers[myDeployment] = &scaler{ + scaleDownStarted: true, + scaleDownTimer: time.AfterFunc(100*time.Millisecond, func() { + t.Fatal("scale down timer not stopped") + }), + } + }, + assert: func(t *testing.T, m *Manager) { + // wait a bit longer than the timer would need to run + time.Sleep(120 * time.Millisecond) + assert.Len(t, m.modelToDeployment, 0) + assert.Len(t, m.scalers, 0) + }, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + m := &Manager{ + scalers: make(map[string]*scaler), + modelToDeployment: make(map[string]string), + } + spec.setup(t, m) + req := reconcile.Request{NamespacedName: types.NamespacedName{Name: myDeployment}} + // when + m.removeDeployment(req) + // then + spec.assert(t, m) + }) + } +} + type partialFakeClient struct { client.Client subRes client.Object diff --git a/pkg/deployments/scaler.go b/pkg/deployments/scaler.go index 5e5b49bc..2c3fb001 100644 --- a/pkg/deployments/scaler.go +++ b/pkg/deployments/scaler.go @@ -116,6 +116,15 @@ func (s *scaler) compareScales(current, desired int32) { } } +func (s *scaler) StopScaleDownTimer() { + s.mtx.Lock() + defer s.mtx.Unlock() + if s.scaleDownTimer != nil { + s.scaleDownTimer.Stop() + } + s.scaleDownStarted = false +} + type scale struct { Current, Min, Max int32 }