From 1ad70c359b2b212aaf01f3e4a56d0978d25ae391 Mon Sep 17 00:00:00 2001
From: Yousif <753751+yousifh@users.noreply.github.com>
Date: Thu, 18 Aug 2022 16:23:23 -0400
Subject: [PATCH 1/3] Improve performance of enqueueing tasks

Add an in-memory cache to keep track of all the queues. Use this cache
to avoid sending an SADD since after the first call, that extra network
call isn't necessary.

The cache will expire every 10 secs so for cases where the queue is
deleted from asynq:queues set, it can be added again next time a task is
enqueued to it.
---
 internal/rdb/rdb.go      | 90 +++++++++++++++++++++++++++++++++-------
 internal/rdb/rdb_test.go | 48 +++++++++++++++++++++
 2 files changed, 122 insertions(+), 16 deletions(-)

diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go
index 74c140ed..ecbb4004 100644
--- a/internal/rdb/rdb.go
+++ b/internal/rdb/rdb.go
@@ -9,6 +9,7 @@ import (
 	"context"
 	"fmt"
 	"math"
+	"sync"
 	"time"
 
 	"github.com/google/uuid"
@@ -26,15 +27,18 @@ const LeaseDuration = 30 * time.Second
 
 // RDB is a client interface to query and mutate task queues.
 type RDB struct {
-	client redis.UniversalClient
-	clock  timeutil.Clock
+	client      redis.UniversalClient
+	clock       timeutil.Clock
+	queuesCache map[string]time.Time
+	mu          sync.Mutex
 }
 
 // NewRDB returns a new instance of RDB.
 func NewRDB(client redis.UniversalClient) *RDB {
 	return &RDB{
-		client: client,
-		clock:  timeutil.NewRealClock(),
+		client:      client,
+		clock:       timeutil.NewRealClock(),
+		queuesCache: map[string]time.Time{},
 	}
 }
 
@@ -112,9 +116,16 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
-		return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+	err = r.runIfNeeded(msg.Queue, func() error {
+		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
+			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+		}
+		return nil
+	})
+	if err != nil {
+		return err
 	}
+
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
 		base.PendingKey(msg.Queue),
@@ -134,6 +145,23 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
 	return nil
 }
 
+func (r *RDB) runIfNeeded(queue string, fn func() error) error {
+	r.mu.Lock()
+	now := r.clock.Now()
+	expiration := r.queuesCache[queue]
+	expired := now.After(expiration)
+	if expired {
+		r.queuesCache[queue] = now.Add(10 * time.Second)
+	}
+	r.mu.Unlock()
+
+	if expired {
+		return fn()
+	}
+
+	return nil
+}
+
 // enqueueUniqueCmd enqueues the task message if the task is unique.
 //
 // KEYS[1] -> unique key
@@ -174,8 +202,14 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time
 	if err != nil {
 		return errors.E(op, errors.Internal, "cannot encode task message: %v", err)
 	}
-	if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
-		return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+	err = r.runIfNeeded(msg.Queue, func() error {
+		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
+			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+		}
+		return nil
+	})
+	if err != nil {
+		return err
 	}
 	keys := []string{
 		msg.UniqueKey,
@@ -529,8 +563,14 @@ func (r *RDB) AddToGroup(ctx context.Context, msg *base.TaskMessage, groupKey st
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
-		return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+	err = r.runIfNeeded(msg.Queue, func() error {
+		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
+			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+		}
+		return nil
+	})
+	if err != nil {
+		return err
 	}
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
@@ -591,8 +631,14 @@ func (r *RDB) AddToGroupUnique(ctx context.Context, msg *base.TaskMessage, group
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
-		return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+	err = r.runIfNeeded(msg.Queue, func() error {
+		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
+			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+		}
+		return nil
+	})
+	if err != nil {
+		return err
 	}
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
@@ -648,8 +694,14 @@ func (r *RDB) Schedule(ctx context.Context, msg *base.TaskMessage, processAt tim
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
-		return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+	err = r.runIfNeeded(msg.Queue, func() error {
+		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
+			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+		}
+		return nil
+	})
+	if err != nil {
+		return err
 	}
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
@@ -707,8 +759,14 @@ func (r *RDB) ScheduleUnique(ctx context.Context, msg *base.TaskMessage, process
 	if err != nil {
 		return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode task message: %v", err))
 	}
-	if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
-		return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+	err = r.runIfNeeded(msg.Queue, func() error {
+		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
+			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
+		}
+		return nil
+	})
+	if err != nil {
+		return err
 	}
 	keys := []string{
 		msg.UniqueKey,
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go
index 433d6fbc..909b36b8 100644
--- a/internal/rdb/rdb_test.go
+++ b/internal/rdb/rdb_test.go
@@ -160,6 +160,54 @@ func TestEnqueueTaskIdConflictError(t *testing.T) {
 	}
 }
 
+func TestEnqueueQueueCache(t *testing.T) {
+	r := setup(t)
+	defer r.Close()
+	t1 := h.NewTaskMessageWithQueue("sync1", nil, "low")
+	t2 := h.NewTaskMessageWithQueue("sync2", nil, "low")
+
+	enqueueTime := time.Now()
+	clock := timeutil.NewSimulatedClock(enqueueTime)
+	r.SetClock(clock)
+
+	err := r.Enqueue(context.Background(), t1)
+	if err != nil {
+		t.Fatalf("(*RDB).Enqueue(msg) = %v, want nil", err)
+	}
+
+	// Check queue is in the AllQueues set.
+	if !r.client.SIsMember(context.Background(), base.AllQueues, t1.Queue).Val() {
+		t.Fatalf("%q is not a member of SET %q", t1.Queue, base.AllQueues)
+	}
+
+	if _, ok := r.queuesCache[t1.Queue]; !ok {
+		t.Fatalf("%q is not cached in %v", t1.Queue, r.queuesCache)
+	}
+
+	// Move clock to ensure cache is expired
+	clock.AdvanceTime(15 * time.Second)
+
+	// Delete queue from AllQueues set to ensure it will be re-added
+	err = r.client.SRem(context.Background(), base.AllQueues, "low").Err()
+	if err != nil {
+		t.Fatalf("Redis SREM = %v, want nil", err)
+	}
+
+	err = r.Enqueue(context.Background(), t2)
+	if err != nil {
+		t.Fatalf("(*RDB).Enqueue(msg) = %v, want nil", err)
+	}
+
+	if !r.client.SIsMember(context.Background(), base.AllQueues, t2.Queue).Val() {
+		t.Fatalf("%q is not a member of SET %q", t2.Queue, base.AllQueues)
+	}
+
+	// Should be cached again
+	if expiration := r.queuesCache[t2.Queue]; expiration.Before(clock.Now()) {
+		t.Fatalf("%q cache is too old %v", t2.Queue, expiration)
+	}
+}
+
 func TestEnqueueUnique(t *testing.T) {
 	r := setup(t)
 	defer r.Close()

From 2e99b71a49786ab153281dc16e79c4e4b6643b06 Mon Sep 17 00:00:00 2001
From: Pior Bastida <pior.bastida@shopify.com>
Date: Fri, 25 Oct 2024 13:37:20 +0200
Subject: [PATCH 2/3] Use sync.Map to simplify the conditional SADD

---
 internal/rdb/rdb.go      | 72 ++++++++++------------------------------
 internal/rdb/rdb_test.go | 28 ++--------------
 2 files changed, 19 insertions(+), 81 deletions(-)

diff --git a/internal/rdb/rdb.go b/internal/rdb/rdb.go
index ecbb4004..e358b015 100644
--- a/internal/rdb/rdb.go
+++ b/internal/rdb/rdb.go
@@ -27,18 +27,16 @@ const LeaseDuration = 30 * time.Second
 
 // RDB is a client interface to query and mutate task queues.
 type RDB struct {
-	client      redis.UniversalClient
-	clock       timeutil.Clock
-	queuesCache map[string]time.Time
-	mu          sync.Mutex
+	client          redis.UniversalClient
+	clock           timeutil.Clock
+	queuesPublished sync.Map
 }
 
 // NewRDB returns a new instance of RDB.
 func NewRDB(client redis.UniversalClient) *RDB {
 	return &RDB{
-		client:      client,
-		clock:       timeutil.NewRealClock(),
-		queuesCache: map[string]time.Time{},
+		client: client,
+		clock:  timeutil.NewRealClock(),
 	}
 }
 
@@ -116,16 +114,12 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	err = r.runIfNeeded(msg.Queue, func() error {
+	if _, found := r.queuesPublished.Load(msg.Queue); !found {
 		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
 			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
 		}
-		return nil
-	})
-	if err != nil {
-		return err
+		r.queuesPublished.Store(msg.Queue, true)
 	}
-
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
 		base.PendingKey(msg.Queue),
@@ -145,23 +139,6 @@ func (r *RDB) Enqueue(ctx context.Context, msg *base.TaskMessage) error {
 	return nil
 }
 
-func (r *RDB) runIfNeeded(queue string, fn func() error) error {
-	r.mu.Lock()
-	now := r.clock.Now()
-	expiration := r.queuesCache[queue]
-	expired := now.After(expiration)
-	if expired {
-		r.queuesCache[queue] = now.Add(10 * time.Second)
-	}
-	r.mu.Unlock()
-
-	if expired {
-		return fn()
-	}
-
-	return nil
-}
-
 // enqueueUniqueCmd enqueues the task message if the task is unique.
 //
 // KEYS[1] -> unique key
@@ -202,14 +179,11 @@ func (r *RDB) EnqueueUnique(ctx context.Context, msg *base.TaskMessage, ttl time
 	if err != nil {
 		return errors.E(op, errors.Internal, "cannot encode task message: %v", err)
 	}
-	err = r.runIfNeeded(msg.Queue, func() error {
+	if _, found := r.queuesPublished.Load(msg.Queue); !found {
 		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
 			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
 		}
-		return nil
-	})
-	if err != nil {
-		return err
+		r.queuesPublished.Store(msg.Queue, true)
 	}
 	keys := []string{
 		msg.UniqueKey,
@@ -563,14 +537,11 @@ func (r *RDB) AddToGroup(ctx context.Context, msg *base.TaskMessage, groupKey st
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	err = r.runIfNeeded(msg.Queue, func() error {
+	if _, found := r.queuesPublished.Load(msg.Queue); !found {
 		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
 			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
 		}
-		return nil
-	})
-	if err != nil {
-		return err
+		r.queuesPublished.Store(msg.Queue, true)
 	}
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
@@ -631,14 +602,11 @@ func (r *RDB) AddToGroupUnique(ctx context.Context, msg *base.TaskMessage, group
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	err = r.runIfNeeded(msg.Queue, func() error {
+	if _, found := r.queuesPublished.Load(msg.Queue); !found {
 		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
 			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
 		}
-		return nil
-	})
-	if err != nil {
-		return err
+		r.queuesPublished.Store(msg.Queue, true)
 	}
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
@@ -694,14 +662,11 @@ func (r *RDB) Schedule(ctx context.Context, msg *base.TaskMessage, processAt tim
 	if err != nil {
 		return errors.E(op, errors.Unknown, fmt.Sprintf("cannot encode message: %v", err))
 	}
-	err = r.runIfNeeded(msg.Queue, func() error {
+	if _, found := r.queuesPublished.Load(msg.Queue); !found {
 		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
 			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
 		}
-		return nil
-	})
-	if err != nil {
-		return err
+		r.queuesPublished.Store(msg.Queue, true)
 	}
 	keys := []string{
 		base.TaskKey(msg.Queue, msg.ID),
@@ -759,14 +724,11 @@ func (r *RDB) ScheduleUnique(ctx context.Context, msg *base.TaskMessage, process
 	if err != nil {
 		return errors.E(op, errors.Internal, fmt.Sprintf("cannot encode task message: %v", err))
 	}
-	err = r.runIfNeeded(msg.Queue, func() error {
+	if _, found := r.queuesPublished.Load(msg.Queue); !found {
 		if err := r.client.SAdd(ctx, base.AllQueues, msg.Queue).Err(); err != nil {
 			return errors.E(op, errors.Unknown, &errors.RedisCommandError{Command: "sadd", Err: err})
 		}
-		return nil
-	})
-	if err != nil {
-		return err
+		r.queuesPublished.Store(msg.Queue, true)
 	}
 	keys := []string{
 		msg.UniqueKey,
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go
index 909b36b8..62a1e2c9 100644
--- a/internal/rdb/rdb_test.go
+++ b/internal/rdb/rdb_test.go
@@ -164,7 +164,6 @@ func TestEnqueueQueueCache(t *testing.T) {
 	r := setup(t)
 	defer r.Close()
 	t1 := h.NewTaskMessageWithQueue("sync1", nil, "low")
-	t2 := h.NewTaskMessageWithQueue("sync2", nil, "low")
 
 	enqueueTime := time.Now()
 	clock := timeutil.NewSimulatedClock(enqueueTime)
@@ -180,31 +179,8 @@ func TestEnqueueQueueCache(t *testing.T) {
 		t.Fatalf("%q is not a member of SET %q", t1.Queue, base.AllQueues)
 	}
 
-	if _, ok := r.queuesCache[t1.Queue]; !ok {
-		t.Fatalf("%q is not cached in %v", t1.Queue, r.queuesCache)
-	}
-
-	// Move clock to ensure cache is expired
-	clock.AdvanceTime(15 * time.Second)
-
-	// Delete queue from AllQueues set to ensure it will be re-added
-	err = r.client.SRem(context.Background(), base.AllQueues, "low").Err()
-	if err != nil {
-		t.Fatalf("Redis SREM = %v, want nil", err)
-	}
-
-	err = r.Enqueue(context.Background(), t2)
-	if err != nil {
-		t.Fatalf("(*RDB).Enqueue(msg) = %v, want nil", err)
-	}
-
-	if !r.client.SIsMember(context.Background(), base.AllQueues, t2.Queue).Val() {
-		t.Fatalf("%q is not a member of SET %q", t2.Queue, base.AllQueues)
-	}
-
-	// Should be cached again
-	if expiration := r.queuesCache[t2.Queue]; expiration.Before(clock.Now()) {
-		t.Fatalf("%q cache is too old %v", t2.Queue, expiration)
+	if _, ok := r.queuesPublished.Load(t1.Queue); !ok {
+		t.Fatalf("%q is not cached in queuesPublished", t1.Queue)
 	}
 }
 

From 24da63e6d2eb7e20c3acf2368569ab2752baca1b Mon Sep 17 00:00:00 2001
From: Pior Bastida <pior.bastida@shopify.com>
Date: Tue, 29 Oct 2024 09:41:31 +0100
Subject: [PATCH 3/3] Cleanup queuePublished in RemoveQueue

---
 internal/rdb/inspect.go  |  3 ++-
 internal/rdb/rdb_test.go | 29 +++++++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/internal/rdb/inspect.go b/internal/rdb/inspect.go
index cbaf4bfd..a18c4e29 100644
--- a/internal/rdb/inspect.go
+++ b/internal/rdb/inspect.go
@@ -10,9 +10,9 @@ import (
 	"strings"
 	"time"
 
-	"github.com/redis/go-redis/v9"
 	"github.com/hibiken/asynq/internal/base"
 	"github.com/hibiken/asynq/internal/errors"
+	"github.com/redis/go-redis/v9"
 	"github.com/spf13/cast"
 )
 
@@ -1832,6 +1832,7 @@ func (r *RDB) RemoveQueue(qname string, force bool) error {
 		if err := r.client.SRem(context.Background(), base.AllQueues, qname).Err(); err != nil {
 			return errors.E(op, errors.Unknown, err)
 		}
+		r.queuesPublished.Delete(qname)
 		return nil
 	case -1:
 		return errors.E(op, errors.NotFound, &errors.QueueNotEmptyError{Queue: qname})
diff --git a/internal/rdb/rdb_test.go b/internal/rdb/rdb_test.go
index 62a1e2c9..5249a29a 100644
--- a/internal/rdb/rdb_test.go
+++ b/internal/rdb/rdb_test.go
@@ -182,6 +182,35 @@ func TestEnqueueQueueCache(t *testing.T) {
 	if _, ok := r.queuesPublished.Load(t1.Queue); !ok {
 		t.Fatalf("%q is not cached in queuesPublished", t1.Queue)
 	}
+
+	t.Run("remove-queue", func(t *testing.T) {
+		err := r.RemoveQueue(t1.Queue, true)
+		if err != nil {
+			t.Errorf("(*RDB).RemoveQueue(%q, %t) = %v, want nil", t1.Queue, true, err)
+		}
+
+		if _, ok := r.queuesPublished.Load(t1.Queue); ok {
+			t.Fatalf("%q is still cached in queuesPublished", t1.Queue)
+		}
+
+		if r.client.SIsMember(context.Background(), base.AllQueues, t1.Queue).Val() {
+			t.Fatalf("%q is a member of SET %q", t1.Queue, base.AllQueues)
+		}
+
+		err = r.Enqueue(context.Background(), t1)
+		if err != nil {
+			t.Fatalf("(*RDB).Enqueue(msg) = %v, want nil", err)
+		}
+
+		// Check queue is in the AllQueues set.
+		if !r.client.SIsMember(context.Background(), base.AllQueues, t1.Queue).Val() {
+			t.Fatalf("%q is not a member of SET %q", t1.Queue, base.AllQueues)
+		}
+
+		if _, ok := r.queuesPublished.Load(t1.Queue); !ok {
+			t.Fatalf("%q is not cached in queuesPublished", t1.Queue)
+		}
+	})
 }
 
 func TestEnqueueUnique(t *testing.T) {