Skip to content

Commit

Permalink
Merge pull request #591 from matheusd/new-message-api
Browse files Browse the repository at this point in the history
multi: Make NewMessage() usable for creating messages for reading
  • Loading branch information
lthibault authored Dec 27, 2024
2 parents e9d7785 + dcb3395 commit 776eaca
Show file tree
Hide file tree
Showing 28 changed files with 418 additions and 668 deletions.
6 changes: 3 additions & 3 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestPromiseFulfill(t *testing.T) {
t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -75,7 +75,7 @@ func TestPromiseFulfill(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -99,7 +99,7 @@ func TestPromiseFulfill(t *testing.T) {
h := new(dummyHook)
c := NewClient(h)
defer c.Release()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{PointerCount: 3})
Expand Down
20 changes: 19 additions & 1 deletion arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ func (ssa *SingleSegmentArena) Release() {
type MultiSegmentArena struct {
segs []Segment

// rawData is set when the individual segments were all demuxed from
// the passed raw data slice.
rawData []byte

// bp is the bufferpool assotiated with this arena's segments if it was
// initialized for writing.
bp *bufferpool.Pool
Expand All @@ -175,6 +179,7 @@ func MultiSegment(b [][]byte) *MultiSegmentArena {
if b == nil {
msa := multiSegmentPool.Get().(*MultiSegmentArena)
msa.fromPool = true
msa.bp = &bufferpool.Default
return msa
}
return multiSegment(b)
Expand All @@ -190,6 +195,14 @@ func MultiSegment(b [][]byte) *MultiSegmentArena {
// Calling Release is optional; if not done the garbage collector
// will release the memory per usual.
func (msa *MultiSegmentArena) Release() {
// When this was demuxed from a single slice, return the entire slice.
if msa.rawData != nil && msa.bp != nil {
zeroSlice(msa.rawData)
msa.bp.Put(msa.rawData)
msa.bp = nil
}
msa.rawData = nil

for i := range msa.segs {
if msa.bp != nil {
zeroSlice(msa.segs[i].data)
Expand Down Expand Up @@ -236,7 +249,10 @@ var multiSegmentPool = sync.Pool{

// demuxArena slices data into a multi-segment arena. It assumes that
// len(data) >= hdr.totalSize().
func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte) error {
//
// bp should point to the bufferpool which will receive back data once the
// arena is released. It may be nil if this should not be returned anywhere.
func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte, bp *bufferpool.Pool) error {
maxSeg := hdr.maxSegment()
if int64(maxSeg) > int64(maxInt-1) {
return errors.New("number of segments overflows int")
Expand All @@ -261,6 +277,8 @@ func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte) error {
msa.segs[i].id = i
}

msa.rawData = data
msa.bp = bp
return nil
}

Expand Down
9 changes: 8 additions & 1 deletion canonical.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ import (
// for equivalent structs, even as the schema evolves. The blob is
// suitable for hashing or signing.
func Canonicalize(s Struct) ([]byte, error) {
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
if !s.IsValid() {
// Ensure compatbility to existing behavior: even if the struct
// is not valid, at least the root pointer is allocated and
// returned as canonical. Without this,
// TestCanonicalize/Struct{} fails.
if _, err := msg.allocRootPointerSpace(); err != nil {
return nil, err
}
return seg.Data(), nil
}
root, err := NewRootStruct(seg, canonicalStructSize(s))
Expand Down
20 changes: 10 additions & 10 deletions canonical_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "empty struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{})
return s
},
want: []byte{0xfc, 0xff, 0xff, 0xff, 0, 0, 0, 0},
}, {
name: "zero data, zero pointer struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{DataSize: 8, PointerCount: 1})
return s
},
want: []byte{0xfc, 0xff, 0xff, 0xff, 0, 0, 0, 0},
}, {
name: "one word data struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{DataSize: 8, PointerCount: 1})
s.SetUint16(0, 0xbeef)
return s
Expand All @@ -47,7 +47,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "two pointers to zero structs",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 2})
e1, _ := NewStruct(seg, ObjectSize{DataSize: 8})
e2, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -63,7 +63,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "pointer to interface",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 2})
iface := NewInterface(seg, 1)
s.SetPtr(0, iface.ToPtr())
Expand All @@ -76,7 +76,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "int list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewInt8List(seg, 5)
s.SetPtr(0, l.ToPtr())
Expand All @@ -95,7 +95,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero int list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewInt8List(seg, 5)
s.SetPtr(0, l.ToPtr())
Expand All @@ -110,7 +110,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 8, PointerCount: 1}, 2)
s.SetPtr(0, l.ToPtr())
Expand All @@ -133,7 +133,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 16, PointerCount: 2}, 3)
s.SetPtr(0, l.ToPtr())
Expand All @@ -148,7 +148,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero-length struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 16, PointerCount: 2}, 0)
s.SetPtr(0, l.ToPtr())
Expand Down
32 changes: 9 additions & 23 deletions capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,8 @@ func (dr *dummyReturner) AllocResults(sz ObjectSize) (Struct, error) {
if dr.s.IsValid() {
return Struct{}, errors.New("AllocResults called multiple times")
}
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
return Struct{}, err
}
_, seg := NewSingleSegmentMessage(nil)
var err error
dr.s, err = NewRootStruct(seg, sz)
return dr.s, err
}
Expand All @@ -377,10 +375,7 @@ func (dr *dummyReturner) ReleaseResults() {
}

func TestToInterface(t *testing.T) {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, seg := NewSingleSegmentMessage(nil)
tests := []struct {
ptr Ptr
in Interface
Expand All @@ -399,10 +394,7 @@ func TestToInterface(t *testing.T) {
}

func TestInterface_value(t *testing.T) {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, seg := NewSingleSegmentMessage(nil)
tests := []struct {
in Interface
val rawPointer
Expand All @@ -421,10 +413,7 @@ func TestInterface_value(t *testing.T) {
}

func TestTransform(t *testing.T) {
_, s, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, s := NewSingleSegmentMessage(nil)
root, err := NewStruct(s, ObjectSize{PointerCount: 2})
if err != nil {
t.Fatal(err)
Expand All @@ -442,7 +431,7 @@ func TestTransform(t *testing.T) {
b.SetUint64(0, 2)
a.SetPtr(0, b.ToPtr())

dmsg, d, err := NewMessage(SingleSegment(nil))
dmsg, d := NewSingleSegmentMessage(nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -675,20 +664,17 @@ func deepPointerEqual(a, b Ptr) bool {
if !a.IsValid() || !b.IsValid() {
return false
}
msgA, _, _ := NewMessage(SingleSegment(nil))
msgA, _ := NewSingleSegmentMessage(nil)
msgA.SetRoot(a)
abytes, _ := msgA.Marshal()
msgB, _, _ := NewMessage(SingleSegment(nil))
msgB, _ := NewSingleSegmentMessage(nil)
msgB.SetRoot(b)
bbytes, _ := msgB.Marshal()
return bytes.Equal(abytes, bbytes)
}

func newEmptyStruct() Struct {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
panic(err)
}
_, seg := NewSingleSegmentMessage(nil)
s, err := NewRootStruct(seg, ObjectSize{})
if err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion capnpc-go/capnpc-go.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (g *generator) defineSchemaVar() error {
}
sort.Sort(uint64Slice(ids))

msg, seg, _ := capnp.NewMessage(capnp.SingleSegment(nil))
msg, seg := capnp.NewSingleSegmentMessage(nil)
req, _ := schema.NewRootCodeGeneratorRequest(seg)
// TODO(light): find largest object size and use that to allocate list
nodes, _ := req.NewNodes(int32(len(g.nodes)))
Expand Down
7 changes: 2 additions & 5 deletions capnpc-go/fileparts.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@ func (sd *staticData) init(fileID uint64) {
}

func (sd *staticData) copyData(obj capnp.Ptr) (staticDataRef, error) {
m, _, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
return staticDataRef{}, err
}
err = m.SetRoot(obj)
m, _ := capnp.NewSingleSegmentMessage(nil)
err := m.SetRoot(obj)
if err != nil {
return staticDataRef{}, err
}
Expand Down
22 changes: 17 additions & 5 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,35 @@ func (d *Decoder) Decode() (*Message, error) {
if err != nil {
return nil, exc.WrapError("decode", err)
}

// Special case an empty message to return a new MultiSegment message
// ready for writing. This maintains compatibility to tests and older
// implementation of message and arenas.
if hdr.maxSegment() == 0 && total == 0 {
msg, _ := NewMultiSegmentMessage(nil)
return msg, nil
}

// TODO(someday): if total size is greater than can fit in one buffer,
// attempt to allocate buffer per segment.
if total > maxSize-uint64(len(hdr)) || total > uint64(maxInt) {
return nil, errors.New("decode: message too large")
}

// Read segments.
buf := bufferpool.Default.Get(int(total))
bp := &bufferpool.Default
buf := bp.Get(int(total))
if _, err := io.ReadFull(d.r, buf); err != nil {
return nil, exc.WrapError("decode: read segments", err)
}

arena := MultiSegment(nil)
if err = arena.demux(hdr, buf); err != nil {
if err = arena.demux(hdr, buf, bp); err != nil {
return nil, exc.WrapError("decode", err)
}

return &Message{Arena: arena}, nil
msg, _, err := NewMessage(arena)
return msg, err
}

func (d *Decoder) readHeader(maxSize uint64) (streamHeader, error) {
Expand Down Expand Up @@ -162,11 +173,12 @@ func Unmarshal(data []byte) (*Message, error) {
}

arena := MultiSegment(nil)
if err := arena.demux(hdr, data); err != nil {
if err := arena.demux(hdr, data, nil); err != nil {
return nil, exc.WrapError("unmarshal", err)
}

return &Message{Arena: arena}, nil
msg, _, err := NewMessage(arena)
return msg, err
}

// UnmarshalPacked reads a packed serialized stream into a message.
Expand Down
Loading

0 comments on commit 776eaca

Please sign in to comment.