Skip to content

Commit

Permalink
feat: dynamic queue concurrency
Browse files Browse the repository at this point in the history
# Conflicts:
#	internal/rdb/rdb.go
  • Loading branch information
pcmid committed Nov 20, 2024
1 parent be15ef6 commit 541a69a
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 15 deletions.
5 changes: 5 additions & 0 deletions aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func newAggregator(params aggregatorParams) *aggregator {
}
}

func (a *aggregator) resetState() {
a.done = make(chan struct{})
a.sema = make(chan struct{}, maxConcurrentAggregationChecks)
}

func (a *aggregator) shutdown() {
if a.ga == nil {
return
Expand Down
1 change: 1 addition & 0 deletions heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/google/uuid"

"github.com/hibiken/asynq/internal/base"
"github.com/hibiken/asynq/internal/log"
"github.com/hibiken/asynq/internal/timeutil"
Expand Down
8 changes: 5 additions & 3 deletions internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ import (
"sync"
"time"

"github.com/hibiken/asynq/internal/errors"
pb "github.com/hibiken/asynq/internal/proto"
"github.com/hibiken/asynq/internal/timeutil"
"github.com/redis/go-redis/v9"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/hibiken/asynq/internal/errors"
pb "github.com/hibiken/asynq/internal/proto"
"github.com/hibiken/asynq/internal/timeutil"
)

// Version of asynq library and CLI.
Expand Down Expand Up @@ -722,4 +723,5 @@ type Broker interface {
PublishCancelation(id string) error

WriteResult(qname, id string, data []byte) (n int, err error)
SetQueueConcurrency(qname string, concurrency int)
}
14 changes: 10 additions & 4 deletions internal/rdb/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ type Option func(r *RDB)

func WithQueueConcurrency(queueConcurrency map[string]int) Option {
return func(r *RDB) {
r.queueConcurrency = queueConcurrency
for qname, concurrency := range queueConcurrency {
r.queueConcurrency.Store(qname, concurrency)
}
}
}

Expand All @@ -39,7 +41,7 @@ type RDB struct {
client redis.UniversalClient
clock timeutil.Clock
queuesPublished sync.Map
queueConcurrency map[string]int
queueConcurrency sync.Map
}

// NewRDB returns a new instance of RDB.
Expand Down Expand Up @@ -271,8 +273,8 @@ func (r *RDB) Dequeue(qnames ...string) (msg *base.TaskMessage, leaseExpirationT
base.LeaseKey(qname),
}
leaseExpirationTime = r.clock.Now().Add(LeaseDuration)
queueConcurrency, ok := r.queueConcurrency[qname]
if !ok || queueConcurrency <= 0 {
queueConcurrency, ok := r.queueConcurrency.Load(qname)
if !ok || queueConcurrency.(int) <= 0 {
queueConcurrency = math.MaxInt
}
argv := []interface{}{
Expand Down Expand Up @@ -1581,3 +1583,7 @@ func (r *RDB) WriteResult(qname, taskID string, data []byte) (int, error) {
}
return len(data), nil
}

func (r *RDB) SetQueueConcurrency(qname string, concurrency int) {
r.queueConcurrency.Store(qname, concurrency)
}
7 changes: 6 additions & 1 deletion internal/testbroker/testbroker.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ import (
"sync"
"time"

"github.com/hibiken/asynq/internal/base"
"github.com/redis/go-redis/v9"

"github.com/hibiken/asynq/internal/base"
)

var errRedisDown = errors.New("testutil: redis is down")
Expand Down Expand Up @@ -297,3 +298,7 @@ func (tb *TestBroker) ReclaimStaleAggregationSets(qname string) error {
}
return tb.real.ReclaimStaleAggregationSets(qname)
}

func (tb *TestBroker) SetQueueConcurrency(qname string, concurrency int) {
tb.real.SetQueueConcurrency(qname, concurrency)
}
18 changes: 15 additions & 3 deletions processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ import (
"sync"
"time"

"golang.org/x/time/rate"

"github.com/hibiken/asynq/internal/base"
asynqcontext "github.com/hibiken/asynq/internal/context"
"github.com/hibiken/asynq/internal/errors"
"github.com/hibiken/asynq/internal/log"
"github.com/hibiken/asynq/internal/timeutil"
"golang.org/x/time/rate"
)

type processor struct {
Expand Down Expand Up @@ -57,7 +58,7 @@ type processor struct {
// channel to communicate back to the long running "processor" goroutine.
// once is used to send value to the channel only once.
done chan struct{}
once sync.Once
once *sync.Once

// quit channel is closed when the shutdown of the "processor" goroutine starts.
quit chan struct{}
Expand Down Expand Up @@ -112,6 +113,7 @@ func newProcessor(params processorParams) *processor {
errLogLimiter: rate.NewLimiter(rate.Every(3*time.Second), 1),
sema: make(chan struct{}, params.concurrency),
done: make(chan struct{}),
once: &sync.Once{},
quit: make(chan struct{}),
abort: make(chan struct{}),
errHandler: params.errHandler,
Expand Down Expand Up @@ -139,7 +141,9 @@ func (p *processor) stop() {
func (p *processor) shutdown() {
p.stop()

time.AfterFunc(p.shutdownTimeout, func() { close(p.abort) })
go func(abort chan struct{}) {
time.AfterFunc(p.shutdownTimeout, func() { close(abort) })
}(p.abort)

p.logger.Info("Waiting for all workers to finish...")
// block until all workers have released the token
Expand All @@ -149,6 +153,14 @@ func (p *processor) shutdown() {
p.logger.Info("All workers have finished")
}

func (p *processor) resetState() {
p.sema = make(chan struct{}, cap(p.sema))
p.done = make(chan struct{})
p.quit = make(chan struct{})
p.abort = make(chan struct{})
p.once = &sync.Once{}
}

func (p *processor) start(wg *sync.WaitGroup) {
wg.Add(1)
go func() {
Expand Down
85 changes: 84 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ type Server struct {

state *serverState

mu sync.RWMutex
queues map[string]int
strictPriority bool

// wait group to wait for all goroutines to finish.
wg sync.WaitGroup
forwarder *forwarder
Expand Down Expand Up @@ -481,7 +485,9 @@ func NewServerFromRedisClient(c redis.UniversalClient, cfg Config) *Server {
}
}
if len(queues) == 0 {
queues = defaultQueueConfig
for qname, p := range defaultQueueConfig {
queues[qname] = p
}
}
var qnames []string
for q := range queues {
Expand Down Expand Up @@ -610,6 +616,8 @@ func NewServerFromRedisClient(c redis.UniversalClient, cfg Config) *Server {
groupAggregator: cfg.GroupAggregator,
})
return &Server{
queues: queues,
strictPriority: cfg.StrictPriority,
logger: logger,
broker: rdb,
sharedConnection: true,
Expand Down Expand Up @@ -792,3 +800,78 @@ func (srv *Server) Ping() error {

return srv.broker.Ping()
}

func (srv *Server) AddQueue(qname string, priority, concurrency int) {
srv.mu.Lock()
defer srv.mu.Unlock()

if _, ok := srv.queues[qname]; ok {
srv.logger.Warnf("queue %s already exists, skipping", qname)
return
}

srv.queues[qname] = priority

srv.state.mu.Lock()
state := srv.state.value
srv.state.mu.Unlock()
if state == srvStateNew || state == srvStateClosed {
srv.queues[qname] = priority
return
}

srv.logger.Info("restart server...")
srv.forwarder.shutdown()
srv.processor.shutdown()
srv.recoverer.shutdown()
srv.syncer.shutdown()
srv.subscriber.shutdown()
srv.janitor.shutdown()
srv.aggregator.shutdown()
srv.healthchecker.shutdown()
srv.heartbeater.shutdown()
srv.wg.Wait()

qnames := make([]string, 0, len(srv.queues))
for q := range srv.queues {
qnames = append(qnames, q)
}
srv.broker.SetQueueConcurrency(qname, concurrency)
srv.heartbeater.queues = srv.queues
srv.recoverer.queues = qnames
srv.forwarder.queues = qnames
srv.processor.resetState()
queues := normalizeQueues(srv.queues)
orderedQueues := []string(nil)
if srv.strictPriority {
orderedQueues = sortByPriority(queues)
}
srv.processor.queueConfig = srv.queues
srv.processor.orderedQueues = orderedQueues
srv.janitor.queues = qnames
srv.aggregator.resetState()
srv.aggregator.queues = qnames

srv.heartbeater.start(&srv.wg)
srv.healthchecker.start(&srv.wg)
srv.subscriber.start(&srv.wg)
srv.syncer.start(&srv.wg)
srv.recoverer.start(&srv.wg)
srv.forwarder.start(&srv.wg)
srv.processor.start(&srv.wg)
srv.janitor.start(&srv.wg)
srv.aggregator.start(&srv.wg)

srv.logger.Info("server restarted")
}

func (srv *Server) HasQueue(qname string) bool {
srv.mu.RLock()
defer srv.mu.RUnlock()
_, ok := srv.queues[qname]
return ok
}

func (srv *Server) SetQueueConcurrency(queue string, concurrency int) {
srv.broker.SetQueueConcurrency(queue, concurrency)
}
105 changes: 102 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ func TestServerWithQueueConcurrency(t *testing.T) {
t.Fatalf("asynq: unsupported RedisConnOpt type %T", r)
}

c := NewClient(redisConnOpt)
defer c.Close()

const taskNum = 8
const serverNum = 2
tests := []struct {
Expand Down Expand Up @@ -134,6 +131,8 @@ func TestServerWithQueueConcurrency(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
var err error
testutil.FlushDB(t, r)
c := NewClient(redisConnOpt)
defer c.Close()
for i := 0; i < taskNum; i++ {
_, err = c.Enqueue(NewTask("send_email",
testutil.JSON(map[string]interface{}{"recipient_id": i + 123})))
Expand Down Expand Up @@ -173,6 +172,106 @@ func TestServerWithQueueConcurrency(t *testing.T) {
}
}

func TestServerWithDynamicQueue(t *testing.T) {
// https://github.com/go-redis/redis/issues/1029
ignoreOpt := goleak.IgnoreTopFunction("github.com/redis/go-redis/v9/internal/pool.(*ConnPool).reaper")
defer goleak.VerifyNone(t, ignoreOpt)

redisConnOpt := getRedisConnOpt(t)
r, ok := redisConnOpt.MakeRedisClient().(redis.UniversalClient)
if !ok {
t.Fatalf("asynq: unsupported RedisConnOpt type %T", r)
}

const taskNum = 8
const serverNum = 2
tests := []struct {
name string
concurrency int
queueConcurrency int
wantActiveNum int
}{
{
name: "based on client concurrency control",
concurrency: 2,
queueConcurrency: 6,
wantActiveNum: 2 * serverNum,
},
{
name: "no queue concurrency control",
concurrency: 2,
queueConcurrency: 0,
wantActiveNum: 2 * serverNum,
},
{
name: "based on queue concurrency control",
concurrency: 6,
queueConcurrency: 2,
wantActiveNum: 2 * serverNum,
},
}

// no-op handler
handle := func(ctx context.Context, task *Task) error {
time.Sleep(time.Second * 2)
return nil
}

var DynamicQueueNameFmt = "dynamic:%d:%d"
var servers [serverNum]*Server
for tcn, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var err error
testutil.FlushDB(t, r)
c := NewClient(redisConnOpt)
defer c.Close()
for i := 0; i < taskNum; i++ {
_, err = c.Enqueue(NewTask("send_email",
testutil.JSON(map[string]interface{}{"recipient_id": i + 123})),
Queue(fmt.Sprintf(DynamicQueueNameFmt, tcn, i%2)))
if err != nil {
t.Fatalf("could not enqueue a task: %v", err)
}
}

for i := 0; i < serverNum; i++ {
srv := NewServer(redisConnOpt, Config{
Concurrency: tc.concurrency,
LogLevel: testLogLevel,
QueueConcurrency: map[string]int{base.DefaultQueueName: tc.queueConcurrency},
})
err = srv.Start(HandlerFunc(handle))
if err != nil {
t.Fatal(err)
}
srv.AddQueue(fmt.Sprintf(DynamicQueueNameFmt, tcn, i), 1, tc.queueConcurrency)
servers[i] = srv
}
defer func() {
for _, srv := range servers {
srv.Shutdown()
}
}()

time.Sleep(time.Second)
inspector := NewInspector(redisConnOpt)

var tasks []*TaskInfo

for i := range servers {
qtasks, err := inspector.ListActiveTasks(fmt.Sprintf(DynamicQueueNameFmt, tcn, i))
if err != nil {
t.Fatalf("could not list active tasks: %v", err)
}
tasks = append(tasks, qtasks...)
}

if len(tasks) != tc.wantActiveNum {
t.Errorf("dynamic queue has %d active tasks, want %d", len(tasks), tc.wantActiveNum)
}
})
}
}

func TestServerRun(t *testing.T) {
// https://github.com/go-redis/redis/issues/1029
Expand Down

0 comments on commit 541a69a

Please sign in to comment.