From 6d2c1be680e239ec5b1e1867bb015c8b42b52117 Mon Sep 17 00:00:00 2001 From: lonnc Date: Sat, 20 Nov 2010 17:47:00 +0000 Subject: [PATCH] Added PSubscribe methods. --- redis.go | 49 ++++++++++++------- redis_test.go | 127 ++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 154 insertions(+), 22 deletions(-) diff --git a/redis.go b/redis.go index 9dd3a9e..065866c 100644 --- a/redis.go +++ b/redis.go @@ -1277,32 +1277,36 @@ func (client *Client) Hgetall(key string, val interface{}) os.Error { //Publish/Subscribe type Message struct { - Channel string - Message []byte + ChannelMatched string + Channel string + Message []byte } // Subscribe to channels, will block until the subscribe channel is closed. -func (client *Client) Subscribe(subscribe <-chan string, unsubscribe <-chan string, messages chan<- Message) os.Error { +func (client *Client) Subscribe(subscribe <-chan string, unsubscribe <-chan string, psubscribe <-chan string, punsubscribe <-chan string, messages chan<- Message) os.Error { cmds := make(chan []string, 0) data := make(chan interface{}, 0) go func() { - CHANNELS: for { + var channel string + var cmd string + select { - case channel := <-subscribe: - if channel == "" { - break CHANNELS - } else { - cmds <- []string{"SUBSCRIBE", channel} - } + case channel = <-subscribe: + cmd = "SUBSCRIBE" + case channel = <-unsubscribe: + cmd = "UNSUBSCRIBE" + case channel = <-psubscribe: + cmd = "PSUBSCRIBE" + case channel = <-punsubscribe: + cmd = "UNPSUBSCRIBE" - case channel := <-unsubscribe: - if channel == "" { - break CHANNELS - } else { - cmds <- []string{"UNSUBSCRIBE", channel} - } + } + if channel == "" { + break + } else { + cmds <- []string{cmd, channel} } } close(cmds) @@ -1312,14 +1316,23 @@ func (client *Client) Subscribe(subscribe <-chan string, unsubscribe <-chan stri go func() { for response := range data { db := response.([][]byte) - messageType, channel, message := string(db[0]), string(db[1]), db[2] + messageType := string(db[0]) switch messageType { case "message": - messages <- Message{string(channel), message} + channel, message := string(db[1]), db[2] + messages <- Message{channel, channel, message} case "subscribe": // Ignore case "unsubscribe": // Ignore + case "pmessage": + channelMatched, channel, message := string(db[1]), string(db[2]), db[3] + messages <- Message{channelMatched, channel, message} + case "psubscribe": + // Ignore + case "punsubscribe": + // Ignore + default: // log.Printf("Unknown message '%s'", messageType) } diff --git a/redis_test.go b/redis_test.go index 71e535d..c89960d 100644 --- a/redis_test.go +++ b/redis_test.go @@ -241,15 +241,19 @@ func TestBlpopTimeout(t *testing.T) { func TestSubscribe(t *testing.T) { subscribe := make(chan string, 0) unsubscribe := make(chan string, 0) + psubscribe := make(chan string, 0) + punsubscribe := make(chan string, 0) messages := make(chan Message, 0) defer func() { close(subscribe) close(unsubscribe) + close(psubscribe) + close(punsubscribe) close(messages) }() go func() { - if err := client.Subscribe(subscribe, unsubscribe, messages); err != nil { + if err := client.Subscribe(subscribe, unsubscribe, psubscribe, punsubscribe, messages); err != nil { t.Fatal("Subscribed failed", err.String()) } }() @@ -261,14 +265,13 @@ func TestSubscribe(t *testing.T) { go func() { tick := time.Tick(10 * 1000 * 1000) // 10ms timeout := time.Tick(100 * 1000 * 1000) // 100ms - LOOP: + for { select { case <-quit: - break LOOP + return case <-timeout: t.Fatal("TestSubscribe timeout") - break LOOP case <-tick: if err := client.Publish("ccc", data); err != nil { t.Fatal("Pubish failed", err.String()) @@ -288,6 +291,122 @@ func TestSubscribe(t *testing.T) { close(subscribe) } +func TestUnsubscribe(t *testing.T) { + subscribe := make(chan string, 0) + unsubscribe := make(chan string, 0) + psubscribe := make(chan string, 0) + punsubscribe := make(chan string, 0) + messages := make(chan Message, 0) + + defer func() { + close(subscribe) + close(unsubscribe) + close(psubscribe) + close(punsubscribe) + close(messages) + }() + go func() { + if err := client.Subscribe(subscribe, unsubscribe, psubscribe, punsubscribe, messages); err != nil { + t.Fatal("Subscribed failed", err.String()) + } + }() + subscribe <- "ccc" + + data := []byte("foo") + quit := make(chan bool, 0) + defer close(quit) + go func() { + tick := time.Tick(10 * 1000 * 1000) // 10ms + + for i := 0; i < 10; i++ { + <-tick + if err := client.Publish("ccc", data); err != nil { + t.Fatal("Pubish failed", err.String()) + } + } + quit <- true + }() + + msgs := 0 + for !closed(subscribe) { + select { + case msg := <-messages: + if string(msg.Message) != string(data) { + t.Fatalf("Expected %s but got %s", string(data), string(msg.Message)) + } + + // Unsubscribe after first message + if msgs == 0 { + unsubscribe <- "ccc" + } + msgs++ + case <-quit: + // Allow for a little delay and extra async messages getting through + if msgs > 3 { + t.Fatalf("Expected to have unsubscribed after 1 message but received %d", msgs) + } + return + } + } +} + + +func TestPSubscribe(t *testing.T) { + subscribe := make(chan string, 0) + unsubscribe := make(chan string, 0) + psubscribe := make(chan string, 0) + punsubscribe := make(chan string, 0) + messages := make(chan Message, 0) + + defer func() { + close(subscribe) + close(unsubscribe) + close(psubscribe) + close(punsubscribe) + close(messages) + }() + go func() { + if err := client.Subscribe(subscribe, unsubscribe, psubscribe, punsubscribe, messages); err != nil { + t.Fatal("Subscribed failed", err.String()) + } + }() + psubscribe <- "ccc.*" + + data := []byte("foo") + quit := make(chan bool, 0) + defer close(quit) + go func() { + tick := time.Tick(10 * 1000 * 1000) // 10ms + timeout := time.Tick(100 * 1000 * 1000) // 100ms + + for { + select { + case <-quit: + return + case <-timeout: + t.Fatal("TestSubscribe timeout") + case <-tick: + if err := client.Publish("ccc.foo", data); err != nil { + t.Fatal("Pubish failed", err.String()) + } + } + } + }() + + msg := <-messages + quit <- true + if msg.Channel != "ccc.foo" { + t.Fatal("Unexpected channel name") + } + if msg.ChannelMatched != "ccc.*" { + t.Fatal("Unexpected channel name") + } + if string(msg.Message) != string(data) { + t.Fatalf("Expected %s but got %s", string(data), string(msg.Message)) + } + close(subscribe) +} + func verifyHash(t *testing.T, key string, expected map[string][]byte) { //test Hget m1 := make(map[string][]byte)