diff --git a/broker.go b/broker.go index 2ac4745c..29b095ea 100644 --- a/broker.go +++ b/broker.go @@ -103,6 +103,10 @@ type PublishOptions struct { ClientInfo *ClientInfo // Tags to set Publication.Tags. Tags map[string]string + // IdempotencyKey is an optional key for idempotent publish. Broker implementation + // may cache these keys for some time to prevent duplicate publications. In this case + // the returned result is the same as from the previous publication with the same key. + IdempotencyKey string } // Broker is responsible for PUB/SUB mechanics. diff --git a/broker_memory.go b/broker_memory.go index efb789bf..aa295137 100644 --- a/broker_memory.go +++ b/broker_memory.go @@ -33,6 +33,12 @@ type MemoryBroker struct { closeOnce sync.Once closeCh chan struct{} + + resultKeyExpSeconds int64 + nextExpireCheck int64 + resultExpireQueue priority.Queue + resultCache map[string]StreamPosition + resultCacheMu sync.RWMutex } var _ Broker = (*MemoryBroker)(nil) @@ -42,6 +48,8 @@ type MemoryBrokerConfig struct{} const numPubLocks = 4096 +const idempotentResulExpireSeconds = 30 + // NewMemoryBroker initializes MemoryBroker. func NewMemoryBroker(n *Node, _ MemoryBrokerConfig) (*MemoryBroker, error) { pubLocks := make(map[int]*sync.Mutex, numPubLocks) @@ -50,10 +58,12 @@ func NewMemoryBroker(n *Node, _ MemoryBrokerConfig) (*MemoryBroker, error) { } closeCh := make(chan struct{}) b := &MemoryBroker{ - node: n, - historyHub: newHistoryHub(n.config.HistoryMetaTTL, closeCh), - pubLocks: pubLocks, - closeCh: closeCh, + node: n, + historyHub: newHistoryHub(n.config.HistoryMetaTTL, closeCh), + pubLocks: pubLocks, + closeCh: closeCh, + resultCache: map[string]StreamPosition{}, + resultKeyExpSeconds: idempotentResulExpireSeconds, } return b, nil } @@ -61,6 +71,7 @@ func NewMemoryBroker(n *Node, _ MemoryBrokerConfig) (*MemoryBroker, error) { // Run runs memory broker. func (b *MemoryBroker) Run(h BrokerEventHandler) error { b.eventHandler = h + go b.expireResultCache() b.historyHub.runCleanups() return nil } @@ -84,6 +95,15 @@ func (b *MemoryBroker) Publish(ch string, data []byte, opts PublishOptions) (Str mu.Lock() defer mu.Unlock() + if opts.IdempotencyKey != "" { + b.resultCacheMu.RLock() + if res, ok := b.resultCache[opts.IdempotencyKey]; ok { + b.resultCacheMu.RUnlock() + return res, nil + } + b.resultCacheMu.RUnlock() + } + pub := &Publication{ Data: data, Info: opts.ClientInfo, @@ -95,9 +115,57 @@ func (b *MemoryBroker) Publish(ch string, data []byte, opts PublishOptions) (Str return StreamPosition{}, err } pub.Offset = streamTop.Offset + if opts.IdempotencyKey != "" { + b.saveToResultCache(opts.IdempotencyKey, streamTop) + } return streamTop, b.eventHandler.HandlePublication(ch, pub, streamTop) } - return StreamPosition{}, b.eventHandler.HandlePublication(ch, pub, StreamPosition{}) + streamPosition := StreamPosition{} + if opts.IdempotencyKey != "" { + b.saveToResultCache(opts.IdempotencyKey, streamPosition) + } + return streamPosition, b.eventHandler.HandlePublication(ch, pub, StreamPosition{}) +} + +func (b *MemoryBroker) saveToResultCache(key string, sp StreamPosition) { + b.resultCacheMu.Lock() + b.resultCache[key] = sp + expireAt := time.Now().Unix() + b.resultKeyExpSeconds + heap.Push(&b.resultExpireQueue, &priority.Item{Value: key, Priority: expireAt}) + if b.nextExpireCheck == 0 || b.nextExpireCheck > expireAt { + b.nextExpireCheck = expireAt + } + b.resultCacheMu.Unlock() +} + +func (b *MemoryBroker) expireResultCache() { + var nextExpireCheck int64 + for { + select { + case <-time.After(time.Second): + case <-b.closeCh: + return + } + b.resultCacheMu.Lock() + if b.nextExpireCheck == 0 || b.nextExpireCheck > time.Now().Unix() { + b.resultCacheMu.Unlock() + continue + } + nextExpireCheck = 0 + for b.resultExpireQueue.Len() > 0 { + item := heap.Pop(&b.resultExpireQueue).(*priority.Item) + expireAt := item.Priority + if expireAt > time.Now().Unix() { + heap.Push(&b.resultExpireQueue, item) + nextExpireCheck = expireAt + break + } + key := item.Value + delete(b.resultCache, key) + } + b.nextExpireCheck = nextExpireCheck + b.resultCacheMu.Unlock() + } } // PublishJoin - see Broker interface description. diff --git a/broker_memory_test.go b/broker_memory_test.go index beb33aeb..4f303aa4 100644 --- a/broker_memory_test.go +++ b/broker_memory_test.go @@ -113,6 +113,48 @@ func TestMemoryBrokerPublishHistory(t *testing.T) { require.Equal(t, 1, len(pubs)) } +func TestMemoryBrokerPublishIdempotent(t *testing.T) { + e := testMemoryBroker() + defer func() { _ = e.node.Shutdown(context.Background()) }() + + require.NotEqual(t, nil, e.historyHub) + + // Test publish with history and with idempotency key. + sp1, err := e.Publish("channel", testPublicationData(), PublishOptions{ + HistorySize: 4, + HistoryTTL: time.Second, + IdempotencyKey: "test", + }) + require.NoError(t, err) + pubs, _, err := e.History("channel", HistoryOptions{ + Filter: HistoryFilter{ + Limit: -1, + Since: nil, + }, + }) + require.NoError(t, err) + require.Equal(t, 1, len(pubs)) + + // Publish with same key. + sp2, err := e.Publish("channel", testPublicationData(), PublishOptions{ + HistorySize: 4, + HistoryTTL: time.Second, + IdempotencyKey: "test", + }) + require.NoError(t, err) + pubs, _, err = e.History("channel", HistoryOptions{ + Filter: HistoryFilter{ + Limit: -1, + Since: nil, + }, + }) + require.NoError(t, err) + require.Equal(t, 1, len(pubs)) + + // Make sure stream positions match. + require.Equal(t, sp1, sp2) +} + func TestMemoryEngineSubscribeUnsubscribe(t *testing.T) { e := testMemoryBroker() defer func() { _ = e.node.Shutdown(context.Background()) }() diff --git a/broker_redis.go b/broker_redis.go index 3e939eb2..356261d4 100644 --- a/broker_redis.go +++ b/broker_redis.go @@ -67,21 +67,22 @@ type shardWrapper struct { // By default, Redis >= 5 required (due to the fact RedisBroker uses STREAM data structure // to keep publication history for a channel). type RedisBroker struct { - controlRound uint64 - node *Node - sharding bool - config RedisBrokerConfig - shards []*shardWrapper - historyListScript *rueidis.Lua - historyStreamScript *rueidis.Lua - addHistoryListScript *rueidis.Lua - addHistoryStreamScript *rueidis.Lua - shardChannel string - messagePrefix string - controlChannel string - nodeChannel string - closeOnce sync.Once - closeCh chan struct{} + controlRound uint64 + node *Node + sharding bool + config RedisBrokerConfig + shards []*shardWrapper + publishIdempotentScript *rueidis.Lua + historyListScript *rueidis.Lua + historyStreamScript *rueidis.Lua + addHistoryListScript *rueidis.Lua + addHistoryStreamScript *rueidis.Lua + shardChannel string + messagePrefix string + controlChannel string + nodeChannel string + closeOnce sync.Once + closeCh chan struct{} } // RedisBrokerConfig is a config for Broker. @@ -173,15 +174,16 @@ func NewRedisBroker(n *Node, config RedisBrokerConfig) (*RedisBroker, error) { } b := &RedisBroker{ - node: n, - config: config, - shards: shardWrappers, - sharding: len(config.Shards) > 1, - historyStreamScript: rueidis.NewLuaScript(historyStreamSource), - historyListScript: rueidis.NewLuaScript(historyListSource), - addHistoryStreamScript: rueidis.NewLuaScript(addHistoryStreamSource), - addHistoryListScript: rueidis.NewLuaScript(addHistoryListSource), - closeCh: make(chan struct{}), + node: n, + config: config, + shards: shardWrappers, + sharding: len(config.Shards) > 1, + publishIdempotentScript: rueidis.NewLuaScript(publishIdempotentSource), + historyStreamScript: rueidis.NewLuaScript(historyStreamSource), + historyListScript: rueidis.NewLuaScript(historyListSource), + addHistoryStreamScript: rueidis.NewLuaScript(addHistoryStreamSource), + addHistoryListScript: rueidis.NewLuaScript(addHistoryListSource), + closeCh: make(chan struct{}), } b.shardChannel = config.Prefix + redisPubSubShardChannelSuffix b.messagePrefix = config.Prefix + redisClientChannelPrefix @@ -222,6 +224,9 @@ func NewRedisBroker(n *Node, config RedisBrokerConfig) (*RedisBroker, error) { } var ( + //go:embed internal/redis_lua/broker_publish_idempotent.lua + publishIdempotentSource string + //go:embed internal/redis_lua/broker_history_add_list.lua addHistoryListSource string @@ -614,14 +619,50 @@ func (b *RedisBroker) publish(s *shardWrapper, ch string, data []byte, opts Publ publishCommand = "spublish" } + idempotencyKey := opts.IdempotencyKey + var resultKey channelID + var resultExpire string + if idempotencyKey != "" { + resultKey = b.resultCacheKey(s.shard, ch) + resultExpire = strconv.Itoa(idempotentResulExpireSeconds) + } + if opts.HistorySize <= 0 || opts.HistoryTTL <= 0 { var resp rueidis.RedisResult if useShardedPublish { - cmd := s.shard.client.B().Spublish().Channel(string(publishChannel)).Message(convert.BytesToString(byteMessage)).Build() - resp = s.shard.client.Do(context.Background(), cmd) + if resultKey == "" { + cmd := s.shard.client.B().Spublish().Channel(string(publishChannel)).Message(convert.BytesToString(byteMessage)).Build() + resp = s.shard.client.Do(context.Background(), cmd) + } else { + resp = b.publishIdempotentScript.Exec( + context.Background(), + s.shard.client, + []string{string(resultKey)}, + []string{ + convert.BytesToString(byteMessage), + string(publishChannel), + publishCommand, + resultExpire, + }, + ) + } } else { - cmd := s.shard.client.B().Publish().Channel(string(publishChannel)).Message(convert.BytesToString(byteMessage)).Build() - resp = s.shard.client.Do(context.Background(), cmd) + if resultKey == "" { + cmd := s.shard.client.B().Publish().Channel(string(publishChannel)).Message(convert.BytesToString(byteMessage)).Build() + resp = s.shard.client.Do(context.Background(), cmd) + } else { + resp = b.publishIdempotentScript.Exec( + context.Background(), + s.shard.client, + []string{string(resultKey)}, + []string{ + convert.BytesToString(byteMessage), + string(publishChannel), + publishCommand, + resultExpire, + }, + ) + } } return StreamPosition{}, resp.Error() } @@ -651,7 +692,7 @@ func (b *RedisBroker) publish(s *shardWrapper, ch string, data []byte, opts Publ replies, err := script.Exec( context.Background(), s.shard.client, - []string{string(streamKey), string(historyMetaKey)}, + []string{string(streamKey), string(historyMetaKey), string(resultKey)}, []string{ convert.BytesToString(byteMessage), strconv.Itoa(size), @@ -660,6 +701,7 @@ func (b *RedisBroker) publish(s *shardWrapper, ch string, data []byte, opts Publ strconv.Itoa(historyMetaTTLSeconds), strconv.FormatInt(time.Now().Unix(), 10), publishCommand, + resultExpire, }, ).ToArray() if err != nil { @@ -858,6 +900,17 @@ func (b *RedisBroker) nodeChannelID(nodeID string) channelID { return channelID(b.config.Prefix + redisNodeChannelPrefix + nodeID) } +func (b *RedisBroker) resultCacheKey(s *RedisShard, ch string) channelID { + if s.useCluster { + if b.config.numClusterShards > 0 { + ch = "{" + strconv.Itoa(consistentIndex(ch, b.config.numClusterShards)) + "}." + ch + } else { + ch = "{" + ch + "}" + } + } + return channelID(b.config.Prefix + ".result." + ch) +} + func (b *RedisBroker) historyListKey(s *RedisShard, ch string) channelID { if s.useCluster { if b.config.numClusterShards > 0 { diff --git a/broker_redis_test.go b/broker_redis_test.go index d9986488..68ead481 100644 --- a/broker_redis_test.go +++ b/broker_redis_test.go @@ -289,6 +289,62 @@ func TestRedisBroker(t *testing.T) { } } +func TestRedisBrokerPublishIdempotent(t *testing.T) { + for _, tt := range redisTests { + t.Run(tt.Name, func(t *testing.T) { + node := testNode(t) + + b := newTestRedisBroker(t, node, tt.UseStreams, tt.UseCluster) + defer func() { _ = node.Shutdown(context.Background()) }() + defer stopRedisBroker(b) + + _, err := b.Publish("channel", testPublicationData(), PublishOptions{ + IdempotencyKey: "publish_no_history", + }) + require.NoError(t, err) + + _, err = b.Publish("channel", testPublicationData(), PublishOptions{ + IdempotencyKey: "publish_no_history", + }) + require.NoError(t, err) + + rawData := []byte("{}") + + // test adding history + sp1, err := b.Publish("channel", rawData, PublishOptions{ + HistorySize: 4, + HistoryTTL: time.Second, + IdempotencyKey: "publish_with_history", + }) + require.NoError(t, err) + pubs, _, err := b.History("channel", HistoryOptions{ + Filter: HistoryFilter{ + Limit: -1, + }, + }) + require.NoError(t, err) + require.Equal(t, 1, len(pubs)) + + // test publish with history and same idempotency key. + sp2, err := b.Publish("channel", rawData, PublishOptions{ + HistorySize: 4, + HistoryTTL: time.Second, + IdempotencyKey: "publish_with_history", + }) + require.NoError(t, err) + pubs, _, err = b.History("channel", HistoryOptions{ + Filter: HistoryFilter{ + Limit: -1, + }, + }) + require.NoError(t, err) + require.Equal(t, 1, len(pubs)) + + require.Equal(t, sp1, sp2) + }) + } +} + func TestRedisCurrentPosition(t *testing.T) { for _, tt := range redisTests { t.Run(tt.Name, func(t *testing.T) { @@ -580,8 +636,6 @@ func randString(n int) string { // We just expect +-equal distribution and keeping most of chans on // the same shard after resharding. func TestRedisConsistentIndex(t *testing.T) { - - rand.Seed(time.Now().UnixNano()) numChans := 10000 numShards := 10 chans := make([]string, numChans) diff --git a/internal/redis_lua/broker_history_add_list.lua b/internal/redis_lua/broker_history_add_list.lua index a4959a4f..cabb7de6 100644 --- a/internal/redis_lua/broker_history_add_list.lua +++ b/internal/redis_lua/broker_history_add_list.lua @@ -1,5 +1,6 @@ local list_key = KEYS[1] local meta_key = KEYS[2] +local result_key = KEYS[3] local message_payload = ARGV[1] local ltrim_right_bound = ARGV[2] local list_ttl = ARGV[3] @@ -7,6 +8,15 @@ local channel = ARGV[4] local meta_expire = ARGV[5] local new_epoch_if_empty = ARGV[6] local publish_command = ARGV[7] +local result_key_expire = ARGV[8] + +if result_key ~= '' then + local stream_meta = redis.call("hmget", result_key, "e", "s") + local result_epoch, result_offset = stream_meta[1], stream_meta[2] + if result_epoch ~= false then + return {result_offset, result_epoch} + end +end local current_epoch = redis.call("hget", meta_key, "e") if current_epoch == false then @@ -27,6 +37,11 @@ redis.call("expire", list_key, list_ttl) if channel ~= '' then redis.call(publish_command, channel, payload) + + if result_key ~= '' then + redis.call("hset", result_key, "e", current_epoch, "s", top_offset) + redis.call("expire", result_key, result_key_expire) + end end return {top_offset, current_epoch} diff --git a/internal/redis_lua/broker_history_add_stream.lua b/internal/redis_lua/broker_history_add_stream.lua index 482707f2..c84f9efe 100644 --- a/internal/redis_lua/broker_history_add_stream.lua +++ b/internal/redis_lua/broker_history_add_stream.lua @@ -1,5 +1,6 @@ local stream_key = KEYS[1] local meta_key = KEYS[2] +local result_key = KEYS[3] local message_payload = ARGV[1] local stream_size = ARGV[2] local stream_ttl = ARGV[3] @@ -7,6 +8,15 @@ local channel = ARGV[4] local meta_expire = ARGV[5] local new_epoch_if_empty = ARGV[6] local publish_command = ARGV[7] +local result_key_expire = ARGV[8] + +if result_key ~= '' then + local stream_meta = redis.call("hmget", result_key, "e", "s") + local result_epoch, result_offset = stream_meta[1], stream_meta[2] + if result_epoch ~= false then + return {result_offset, result_epoch} + end +end local current_epoch = redis.call("hget", meta_key, "e") if current_epoch == false then @@ -26,6 +36,11 @@ redis.call("expire", stream_key, stream_ttl) if channel ~= '' then local payload = "__" .. "p1:" .. top_offset .. ":" .. current_epoch .. "__" .. message_payload redis.call(publish_command, channel, payload) + + if result_key ~= '' then + redis.call("hset", result_key, "e", current_epoch, "s", top_offset) + redis.call("expire", result_key, result_key_expire) + end end return {top_offset, current_epoch} diff --git a/internal/redis_lua/broker_publish_idempotent.lua b/internal/redis_lua/broker_publish_idempotent.lua new file mode 100644 index 00000000..d42a200d --- /dev/null +++ b/internal/redis_lua/broker_publish_idempotent.lua @@ -0,0 +1,20 @@ +local result_key = KEYS[1] +local payload = ARGV[1] +local channel = ARGV[2] +local publish_command = ARGV[3] +local result_key_expire = ARGV[4] + +if result_key ~= '' then + local stream_meta = redis.call("hmget", result_key, "e", "s") + local result_epoch, result_offset = stream_meta[1], stream_meta[2] + if result_epoch ~= false then + return {result_offset, result_epoch} + end +end + +redis.call(publish_command, channel, payload) + +if result_key ~= '' then + redis.call("hset", result_key, "e", "") + redis.call("expire", result_key, result_key_expire) +end