diff --git a/redis.go b/redis.go index 4247586..9dd3a9e 100644 --- a/redis.go +++ b/redis.go @@ -72,6 +72,12 @@ func readBulk(reader *bufio.Reader, head string) ([]byte, os.Error) { return data, err } +func writeRequest(writer io.Writer, cmd string, args ...string) os.Error { + b := commandBytes(cmd, args...) + _, err := writer.Write(b) + return err +} + func commandBytes(cmd string, args ...string) []byte { cmdbuf := bytes.NewBufferString(fmt.Sprintf("*%d\r\n$%d\r\n%s\r\n", len(args)+1, len(cmd), cmd)) for _, s := range args { @@ -214,6 +220,83 @@ End: return data, err } +func (client *Client) sendCommands(cmdArgs <-chan []string, data chan<- interface{}) (err os.Error) { + // grab a connection from the pool + c, err := client.popCon() + + if err != nil { + goto End + } + + reader := bufio.NewReader(c) + + // Ping first to verify connection is open + err = writeRequest(c, "PING") + + // On first attempt permit a reconnection attempt + if err == os.EOF { + // Looks like we have to open a new connection + c, err = client.openConnection() + if err != nil { + goto End + } + reader = bufio.NewReader(c) + } else { + // Read Ping response + pong, err := readResponse(reader) + if pong != "PONG" { + return RedisError("Unexpected response to PING.") + } + if err != nil { + goto End + } + } + + errs := make(chan os.Error) + + go func() { + for cmdArg := range cmdArgs { + err = writeRequest(c, cmdArg[0], cmdArg[1:]...) + if err != nil { + if !closed(errs) { + errs <- err + } + break + } + } + close(errs) + }() + + go func() { + for { + response, err := readResponse(reader) + if err != nil { + if !closed(errs) { + errs <- err + } + break + } + data <- response + } + close(errs) + }() + + // Block until errs channel closes + for e := range errs { + err = e + } + +End: + + // Close client and synchronization issues are a nightmare to solve. + c.Close() + + // Push nil back onto queue + client.pushCon(nil) + + return err +} + func (client *Client) popCon() (*net.TCPConn, os.Error) { if client.pool == nil { client.pool = make(chan *net.TCPConn, MaxPoolSize) @@ -1191,6 +1274,71 @@ func (client *Client) Hgetall(key string, val interface{}) os.Error { return nil } +//Publish/Subscribe + +type Message struct { + Channel string + Message []byte +} + +// Subscribe to channels, will block until the subscribe channel is closed. +func (client *Client) Subscribe(subscribe <-chan string, unsubscribe <-chan string, messages chan<- Message) os.Error { + cmds := make(chan []string, 0) + data := make(chan interface{}, 0) + + go func() { + CHANNELS: + for { + select { + case channel := <-subscribe: + if channel == "" { + break CHANNELS + } else { + cmds <- []string{"SUBSCRIBE", channel} + } + + case channel := <-unsubscribe: + if channel == "" { + break CHANNELS + } else { + cmds <- []string{"UNSUBSCRIBE", channel} + } + } + } + close(cmds) + close(data) + }() + + go func() { + for response := range data { + db := response.([][]byte) + messageType, channel, message := string(db[0]), string(db[1]), db[2] + switch messageType { + case "message": + messages <- Message{string(channel), message} + case "subscribe": + // Ignore + case "unsubscribe": + // Ignore + default: + // log.Printf("Unknown message '%s'", messageType) + } + } + }() + + err := client.sendCommands(cmds, data) + + return err +} + +func (client *Client) Publish(channel string, val []byte) os.Error { + _, err := client.sendCommand("PUBLISH", channel, string(val)) + if err != nil { + return err + } + return nil +} + //Server commands func (client *Client) Save() os.Error { diff --git a/redis_test.go b/redis_test.go index bc911cc..71e535d 100644 --- a/redis_test.go +++ b/redis_test.go @@ -238,6 +238,56 @@ func TestBlpopTimeout(t *testing.T) { } } +func TestSubscribe(t *testing.T) { + subscribe := make(chan string, 0) + unsubscribe := make(chan string, 0) + messages := make(chan Message, 0) + + defer func() { + close(subscribe) + close(unsubscribe) + close(messages) + }() + go func() { + if err := client.Subscribe(subscribe, unsubscribe, messages); err != nil { + t.Fatal("Subscribed failed", err.String()) + } + }() + subscribe <- "ccc" + + data := []byte("foo") + quit := make(chan bool, 0) + defer close(quit) + go func() { + tick := time.Tick(10 * 1000 * 1000) // 10ms + timeout := time.Tick(100 * 1000 * 1000) // 100ms + LOOP: + for { + select { + case <-quit: + break LOOP + case <-timeout: + t.Fatal("TestSubscribe timeout") + break LOOP + case <-tick: + if err := client.Publish("ccc", data); err != nil { + t.Fatal("Pubish failed", err.String()) + } + } + } + }() + + msg := <-messages + quit <- true + if msg.Channel != "ccc" { + t.Fatal("Unexpected channel name") + } + if string(msg.Message) != string(data) { + t.Fatalf("Expected %s but got %s", string(data), string(msg.Message)) + } + close(subscribe) +} + func verifyHash(t *testing.T, key string, expected map[string][]byte) { //test Hget m1 := make(map[string][]byte)