From 8752d2b8d79255c6e6c41c054de69422710d3394 Mon Sep 17 00:00:00 2001 From: Michael Hoisie Date: Tue, 25 May 2010 22:54:31 -0700 Subject: [PATCH] Hmset and Hgetall now use reflection to read and write data to hashes. --- redis.go | 255 ++++++++++++++++++++++++++++++++++++++++++++++---- redis_test.go | 34 +++++-- 2 files changed, 263 insertions(+), 26 deletions(-) diff --git a/redis.go b/redis.go index 918d6de..5279a26 100644 --- a/redis.go +++ b/redis.go @@ -3,11 +3,13 @@ package redis import ( "bufio" "bytes" + "container/vector" "fmt" "io" "io/ioutil" "net" "os" + "reflect" "strconv" "strings" ) @@ -126,9 +128,7 @@ func readResponse(reader *bufio.Reader) (interface{}, os.Error) { } func (client *Client) rawSend(c *net.TCPConn, cmd []byte) (interface{}, os.Error) { - _, err := c.Write(cmd) - if err != nil { return nil, err } @@ -190,16 +190,14 @@ func (client *Client) sendCommand(cmd string, args []string) (data interface{}, goto End } } - data, err = client.rawSend(c, cmdbuf.Bytes()) - if err == os.EOF { c, err = client.openConnection() if err != nil { goto End } - data, err = client.rawSend(c, []byte(cmd)) + data, err = client.rawSend(c, cmdbuf.Bytes()) } End: @@ -902,16 +900,94 @@ func (client *Client) Hget(key string, field string) ([]byte, os.Error) { return data, nil } -func (client *Client) Hmset(key string, mapping map[string][]byte) os.Error { +//pretty much copy the json code from here. - args := make([]string, len(mapping)*2+1) - args[0] = key - i := 1 - for k, v := range mapping { - args[i] = k - args[i+1] = string(v) - i += 2 +func valueToString(v reflect.Value) (string, os.Error) { + if v == nil { + return "null", nil + } + + switch v := v.(type) { + case *reflect.BoolValue: + x := v.Get() + if x { + return "true", nil + } else { + return "false", nil + } + + case *reflect.IntValue: + return strconv.Itoa(v.Get()), nil + case *reflect.Int8Value: + return strconv.Itoa(int(v.Get())), nil + case *reflect.Int16Value: + return strconv.Itoa(int(v.Get())), nil + case *reflect.Int32Value: + return strconv.Itoa(int(v.Get())), nil + case *reflect.Int64Value: + return strconv.Itoa64(v.Get()), nil + + case *reflect.UintValue: + return strconv.Uitoa(v.Get()), nil + case *reflect.Uint8Value: + return strconv.Uitoa(uint(v.Get())), nil + case *reflect.Uint16Value: + return strconv.Uitoa(uint(v.Get())), nil + case *reflect.Uint32Value: + return strconv.Uitoa(uint(v.Get())), nil + case *reflect.Uint64Value: + return strconv.Uitoa64(v.Get()), nil + case *reflect.UintptrValue: + return strconv.Uitoa64(uint64(v.Get())), nil + + case *reflect.FloatValue: + return strconv.Ftoa(v.Get(), 'g', -1), nil + case *reflect.Float32Value: + return strconv.Ftoa32(v.Get(), 'g', -1), nil + case *reflect.Float64Value: + return strconv.Ftoa64(v.Get(), 'g', -1), nil + + case *reflect.StringValue: + return v.Get(), nil + case *reflect.SliceValue: + typ := v.Type().(*reflect.SliceType) + if _, ok := typ.Elem().(*reflect.Uint8Type); ok { + return string(v.Interface().([]byte)), nil + } } + return "", os.NewError("Unsupported type") +} + +func (client *Client) Hmset(key string, mapping interface{}) os.Error { + var args vector.StringVector + args.Push(key) + + switch v := reflect.NewValue(mapping).(type) { + case *reflect.MapValue: + if _, ok := v.Type().(*reflect.MapType).Key().(*reflect.StringType); !ok { + return os.NewError("Unsupported type - map key must be a string") + } + for _, k := range v.Keys() { + args.Push(k.(*reflect.StringValue).Get()) + s, err := valueToString(v.Elem(k)) + if err != nil { + return err + } + args.Push(s) + } + case *reflect.StructValue: + st := v.Type().(*reflect.StructType) + for i := 0; i < st.NumField(); i++ { + ft := st.FieldByIndex([]int{i}) + args.Push(ft.Name) + s, err := valueToString(v.FieldByIndex([]int{i})) + if err != nil { + return err + } + args.Push(s) + } + } + _, err := client.sendCommand("HMSET", args) if err != nil { return err @@ -979,19 +1055,158 @@ func (client *Client) Hvals(key string) ([][]byte, os.Error) { return res.([][]byte), nil } -func (client *Client) Hgetall(key string) (map[string][]byte, os.Error) { - res, err := client.sendCommand("HGETALL", []string{key}) +func writeTo(data []byte, val reflect.Value) os.Error { + s := string(data) + switch v := val.(type) { + case *reflect.BoolValue: + b, err := strconv.Atob(s) + if err != nil { + return err + } + v.Set(b) + case *reflect.IntValue: + i, err := strconv.Atoi(s) + if err != nil { + return err + } + v.Set(i) + case *reflect.Int8Value: + i, err := strconv.Atoi(s) + if err != nil { + return err + } + v.Set(int8(i)) + case *reflect.Int16Value: + i, err := strconv.Atoi(s) + if err != nil { + return err + } + v.Set(int16(i)) + case *reflect.Int32Value: + i, err := strconv.Atoi(s) + if err != nil { + return err + } + v.Set(int32(i)) + case *reflect.Int64Value: + i, err := strconv.Atoi64(s) + if err != nil { + return err + } + v.Set(i) + case *reflect.UintValue: + ui, err := strconv.Atoui(s) + if err != nil { + return err + } + v.Set(ui) + + case *reflect.Uint8Value: + ui, err := strconv.Atoui(s) + if err != nil { + return err + } + v.Set(uint8(ui)) + case *reflect.Uint16Value: + ui, err := strconv.Atoui(s) + if err != nil { + return err + } + v.Set(uint16(ui)) + case *reflect.Uint32Value: + ui, err := strconv.Atoui(s) + if err != nil { + return err + } + v.Set(uint32(ui)) + case *reflect.Uint64Value: + ui, err := strconv.Atoui64(s) + if err != nil { + return err + } + v.Set(ui) + case *reflect.UintptrValue: + ui, err := strconv.Atoui64(s) + if err != nil { + return err + } + v.Set(uintptr(ui)) + case *reflect.FloatValue: + f, err := strconv.Atof(s) + if err != nil { + return err + } + v.Set(f) + case *reflect.Float32Value: + f, err := strconv.Atof32(s) + if err != nil { + return err + } + v.Set(f) + case *reflect.Float64Value: + f, err := strconv.Atof64(s) + if err != nil { + return err + } + v.Set(f) + + case *reflect.StringValue: + v.Set(s) + case *reflect.SliceValue: + typ := v.Type().(*reflect.SliceType) + if _, ok := typ.Elem().(*reflect.Uint8Type); ok { + v.Set(reflect.NewValue(data).(*reflect.SliceValue)) + } + } + return nil +} + +func writeToContainer(data [][]byte, val reflect.Value) os.Error { + switch v := val.(type) { + case *reflect.PtrValue: + return writeToContainer(data, reflect.Indirect(v)) + case *reflect.InterfaceValue: + return writeToContainer(data, v.Elem()) + case *reflect.MapValue: + if _, ok := v.Type().(*reflect.MapType).Key().(*reflect.StringType); !ok { + return os.NewError("Invalid map type") + } + elemtype := v.Type().(*reflect.MapType).Elem() + for i := 0; i < len(data)/2; i++ { + mk := reflect.NewValue(string(data[i*2])) + mv := reflect.MakeZero(elemtype) + writeTo(data[i*2+1], mv) + v.SetElem(mk, mv) + } + case *reflect.StructValue: + for i := 0; i < len(data)/2; i++ { + name := string(data[i*2]) + field := v.FieldByName(name) + if field == nil { + continue + } + writeTo(data[i*2+1], field) + } + default: + return os.NewError("Invalid container type") + } + return nil +} + + +func (client *Client) Hgetall(key string, val interface{}) os.Error { + res, err := client.sendCommand("HGETALL", []string{key}) if err != nil { - return nil, err + return err } data := res.([][]byte) - ret := make(map[string][]byte, len(data)/2) - for i := 0; i < len(data)/2; i++ { - ret[string(data[i*2])] = data[i*2+1] + err = writeToContainer(data, reflect.NewValue(val)) + if err != nil { + return err } - return ret, nil + return nil } //Server commands diff --git a/redis_test.go b/redis_test.go index 863590c..ecbd898 100644 --- a/redis_test.go +++ b/redis_test.go @@ -1,6 +1,7 @@ package redis import ( + "fmt" "os" "reflect" "runtime" @@ -164,7 +165,7 @@ func verifyHash(t *testing.T, key string, expected map[string][]byte) { //test Hget m1 := make(map[string][]byte) for k, _ := range expected { - actual, err := client.Hget("h", k) + actual, err := client.Hget(key, k) if err != nil { t.Fatal("verifyHash Hget failed", err.String()) } @@ -174,18 +175,18 @@ func verifyHash(t *testing.T, key string, expected map[string][]byte) { t.Fatal("verifyHash Hget failed") } - //test Hkeys keys, err := client.Hkeys(key) if err != nil { t.Fatal("verifyHash Hkeys failed", err.String()) } if len(keys) != len(expected) { - t.Fatal("verifyHash Hkeys failed") + fmt.Printf("%v\n", keys) + t.Fatal("verifyHash Hkeys failed - length not equal") } for _, key := range keys { if expected[key] == nil { - t.Fatal("verifyHash Hkeys failed") + t.Fatal("verifyHash Hkeys failed missing key", key) } } @@ -198,8 +199,9 @@ func verifyHash(t *testing.T, key string, expected map[string][]byte) { t.Fatal("verifyHash Hvals failed") } + m2 := map[string][]byte{} //test Hgetall - m2, err := client.Hgetall(key) + err = client.Hgetall(key, m2) if err != nil { t.Fatal("verifyHash Hgetall failed", err.String()) } @@ -208,6 +210,10 @@ func verifyHash(t *testing.T, key string, expected map[string][]byte) { } } +type tt struct { + A, B, C, D, E string +} + func TestHash(t *testing.T) { //test cast keys := []string{"a", "b", "c", "d", "e"} @@ -218,7 +224,7 @@ func TestHash(t *testing.T) { //set with hset for k, v := range test { - client.Hset("h", k, v) + client.Hset("h", k, []byte(v)) } //test hset verifyHash(t, "h", test) @@ -228,8 +234,24 @@ func TestHash(t *testing.T) { //test hset verifyHash(t, "h2", test) + test3 := tt{"aaaaa", "bbbbb", "ccccc", "ddddd", "eeeee"} + + client.Hmset("h3", test3) + //verifyHash(t, "h3", test) + + var test4 tt + //test Hgetall + err := client.Hgetall("h3", &test4) + if err != nil { + t.Fatal("verifyHash Hgetall failed", err.String()) + } + if !reflect.DeepEqual(test4, test3) { + t.Fatal("verifyHash Hgetall failed") + } + client.Del("h") client.Del("h2") + client.Del("h3") } /*