From 8eeece33868236224d51e7362e36a68642870bd2 Mon Sep 17 00:00:00 2001 From: Chewxy Date: Sun, 19 Aug 2018 11:45:45 +1000 Subject: [PATCH] V0.9.0 (#22) # v0.9.0 # The changes made in this PR is aimed at better supporting v0.9.0 of Gorgonia itself. Along the way there are some new features and optimizations, as well as some bug fixes. The majority of the work in supporting v0.9.0 of Gorgonia is to shore up the underlying architecture to support CUDA related engines. This means moving more things to rely on `Engine` while keeping the engine interface overheads low. Additionally this also means better support for column major data layouts. * Heavier reliance on `Engine` for most functions. This allows for extensibility on the data structure. * Long standing bugbear - concepts of `RowVec` and `ColVec` has been removed (thanks to @matdodgson) - Touch points: `ap.go`, `iterator.go`, `iterator_mult.go`.`shape.go`, and the tests that were correct prior to this change have semantic meaning changes too. - **POTENTIAL TECH DEBT**: `iterator_mult.go` - the solution of filling with ones is a little too dodgy for my liking. The alternative would be to change `BroadcastStrides` which will change even more things (`Concat`, `Stack` etc) * **Optimization**: - `AP` has been depointerized in `*Dense` (thanks to @docmerlin). This reduces *some* amount of GC pointer chasing, but not all - allocation is slightly improved. (`(array).fromSliceOrArrayer`, `(array).fix()` and `(array).forcefix()` are part of the improvement around the logic of allocating data. * **Bug fixes**: - Fixes subtle errors in linear algebra functions. The result is a slightly longer function but easier to reason with. - Fixes some subtle bugs in `Concat` - see also gorgonia/gorgonia#218 - Fixed some small bugs with regards to `SampleIndex` that only show up when the slices have extreme lengths. This API should have been deprecated 2 years ago, but eh... it touched a lot of external projects. * **API changes**: - `Diag` is made available. Relies heavily on an `Engine`'s implementation - `NewFlatIterator` is unexported. - `NewAP` is unexported. - `MakeAP` is used instead. - `(Tensor).DataOrder()` is added to the definiiton of what a `Tensor` is. - `(Shape).IsScalarEquiv()` is a new method. This corresponds to the change of semantics of what a `Shape` should be. - `(Shape).CalcStrides()` is exported now. This enables users to correctly calculate strides that are consistent to what the package expects. - `(Shape).CalcStridesColMajor()` is exported as the method to calculate the strides of a Col-Major `*Dense`. * **New Interfaces**: - `NonStdEngine` is an `Engine that does not allocate using the default allocator. This allows for both embedding a `DefaultEngine` while overriding the allocation behaviour. - `Diager` - any engine that can return a tensor that only contains the diagonal values of the input - `NaNChecker` and `InfChecker` - engines that can check a tensor for NaN and Inf * **New Features**: * Added full support for colmajor tensors. (fixes #10) - TODO: colmajor iterator's prev() method (see #34) - Added serialization to Protobuf and Flatbuffers * TODO: Add example for serialization (see #35 and #36) - Added more support for sparse CS tensors. * **New Subpackages**: * `native` is a subpackage that essentially gives users a native, Go-based iterator. Basically the ability to go from a `*Dense` to a `[][]T` or `[][][]T` **without extra allocations** (for the data). This was pulled into `master` earlier, but as of v0.9.0, the generic version is available too. * **Semantic Changes**: - `Shape` has semantic changes regarding whether or not a shape is scalar. A scalar shape is defined to be `Shape{}` or `Shape{1}` only. Formerly, `Shape{1,1}` was also considered to be scalar. Now they're considered to be `ScalarEquivalent` (along with `Shape{1, 1, .... , 1}`) - A `Dtype` that is is orderable is also now comparable for equality. If `RegisterOrd` is called with a new `Dtype`, it is also automatically registered as `Eq`. * **Cosmetic Changes**: - README has been updated to point to correct doc pages --- .travis.yml | 1 + .travis/test.sh | 3 +- CONTRIBUTORS.md | 3 +- README.md | 15 +- ap.go | 157 ++- ap_test.go | 56 +- api_arith_test.go | 10 +- api_cmp_generated_test.go | 56 +- api_matop.go | 7 + api_utils.go | 12 +- array.go | 93 +- benchmark_dense_matop_test.go | 58 +- consopt.go | 76 +- defaultengine.go | 7 +- defaultengine_argmethods.go | 22 +- defaultengine_arith.go | 24 +- defaultengine_cmp.go | 36 +- defaultengine_linalg.go | 158 ++- defaultengine_mapreduce.go | 20 +- defaultengine_matop_misc.go | 130 +- defaultengine_matop_stack.go | 13 +- defaultengine_matop_transpose.go | 12 +- defaultengine_matop_transpose_inplace.go | 18 + defaultengine_misc.go | 2 +- defaultengine_prep.go | 29 +- defaultengine_unary.go | 28 +- defaultenginefloat32.go | 10 +- defaultenginefloat64.go | 10 +- dense.go | 58 +- dense_assign.go | 8 +- dense_cmp_test.go | 56 +- dense_colmajor_linalg_test.go | 483 +++++++ dense_compat.go | 8 +- dense_format.go | 4 +- dense_generated.go | 2 +- dense_io.go | 1070 +++++++-------- dense_io_test.go | 49 +- dense_linalg.go | 15 +- dense_linalg_test.go | 109 +- dense_matop.go | 23 +- dense_matop_memmove.go | 11 +- dense_matop_test.go | 120 +- dense_norms.go | 4 +- dense_svd_test.go | 29 +- engine.go | 24 + example_dense_linalg_test.go | 151 +++ example_dense_matop_test.go | 2 + example_iterator_test.go | 48 +- example_tensor_basics_test.go | 120 +- flags.go | 51 +- flags_test.go | 47 +- interfaces.go | 2 +- internal/IDLs/generated.fbs | 38 + internal/IDLs/generated.proto | 52 + internal/serialization/README.md | 33 + internal/serialization/doc.go | 2 + internal/serialization/fb/AP.go | 110 ++ internal/serialization/fb/Dense.go | 152 +++ internal/serialization/fb/MaskedDense.go | 198 +++ internal/serialization/fb/Triangle.go | 18 + internal/serialization/pb/dense.go | 45 + internal/serialization/pb/generated.pb.go | 1457 +++++++++++++++++++++ iterator.go | 80 +- iterator_mult.go | 20 +- iterator_test.go | 119 +- native/example_test.go | 7 +- native/generic.go | 72 + native/generic_test.go | 67 + perf.go | 91 +- shape.go | 47 +- shape_test.go | 22 +- sparse.go | 28 +- tensor.go | 2 +- testutils_test.go | 1 + types.go | 23 + utils.go | 9 +- 76 files changed, 5010 insertions(+), 1243 deletions(-) create mode 100644 dense_colmajor_linalg_test.go create mode 100644 example_dense_linalg_test.go create mode 100644 internal/IDLs/generated.fbs create mode 100755 internal/IDLs/generated.proto create mode 100644 internal/serialization/README.md create mode 100644 internal/serialization/doc.go create mode 100644 internal/serialization/fb/AP.go create mode 100644 internal/serialization/fb/Dense.go create mode 100644 internal/serialization/fb/MaskedDense.go create mode 100644 internal/serialization/fb/Triangle.go create mode 100644 internal/serialization/pb/dense.go create mode 100644 internal/serialization/pb/generated.pb.go create mode 100644 native/generic.go create mode 100644 native/generic_test.go diff --git a/.travis.yml b/.travis.yml index 8706540..9a3402d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ branches: go: - 1.8.x - 1.9.x + - 1.10.x - tip env: diff --git a/.travis/test.sh b/.travis/test.sh index 2e00d07..37fdd87 100644 --- a/.travis/test.sh +++ b/.travis/test.sh @@ -6,9 +6,10 @@ go test -v -a -covermode=atomic -coverprofile=test.cover . go test -tags='avx' -a -covermode=atomic -coverprofile=avx.cover . go test -tags='sse' -a -covermode=atomic -coverprofile=sse.cover . go test -tags='inplacetranspose' -a -covermode=atomic -coverprofile=inplacetranspose.cover . +go test -a -covermode=atomic -coverprofile=native.cover ./native/. # because coveralls only accepts one coverage file at one time... we combine them into one gigantic one -covers=(./test.cover ./avx.cover ./sse.cover ./inplacetranspose.cover) +covers=(./test.cover ./avx.cover ./sse.cover ./inplacetranspose.cover ./native.cover) echo "mode: set" > ./final.cover tail -q -n +2 "${covers[@]}" >> ./final.cover goveralls -coverprofile=./final.cover -service=travis-ci diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 57adb1d..d94f24c 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -4,6 +4,7 @@ * Naseer Dari (@ndari) - errors and error handling * Joe Kabaka (@kabaka0) - masked array functionality * Stuart Carnie (@stuartcarnie) - performance optimization for iterators +* Jorge Landivar (@docmerlin) - performance optimization for `*Dense` # Contributors @@ -13,8 +14,8 @@ * David Soller | @3ygun * Davor Kapsa | @dvrkps * James Michael DuPont | @h4ck3rm1k3 -* Jorge Landivar | @docmerlin * Yuanlin Lian | @alienchow +* Andrew SnodGrass | @pointlander diff --git a/README.md b/README.md index 62b3fbd..086cdc4 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ -# Package `tensor` [![GoDoc](https://godoc.org/github.com/gorgonia/tensor?status.svg)](https://godoc.org/github.com/gorgonia/tensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) # +# Package `tensor` [![GoDoc](https://godoc.org/gorgonia.org/tensor?status.svg)](https://godoc.org/gorgonia.org/tensor) [![GitHub version](https://badge.fury.io/gh/gorgonia%2Ftensor.svg)](https://badge.fury.io/gh/gorgonia%2Ftensor) [![Build Status](https://travis-ci.org/gorgonia/tensor.svg?branch=master)](https://travis-ci.org/gorgonia/tensor) [![Coverage Status](https://coveralls.io/repos/github/gorgonia/tensor/badge.svg?branch=master)](https://coveralls.io/github/gorgonia/tensor?branch=master) [![Go Report Card](https://goreportcard.com/badge/gorgonia.org/tensor)](https://goreportcard.com/report/gorgonia.org/tensor) [![unstable](http://badges.github.io/stability-badges/dist/unstable.svg)](http://github.com/badges/stability-badges)# + Package `tensor` is a package that provides efficient, generic (by some definitions of generic) n-dimensional arrays in Go. Also in this package are functions and methods that are used commonly in arithmetic, comparison and linear algebra operations. -The main purpose of this package is to support the operations required by [Gorgonia](https://github.com/chewxy/gorgonia). +The main purpose of this package is to support the operations required by [Gorgonia](https://gorgonia.org/gorgonia). ## Introduction ## In the data analysis world, [Numpy](http://http://www.numpy.org/) and [Matlab](https://www.mathworks.com/products/matlab.html) currently reign supreme. Both tools rely heavily on having performant n-dimensional arrays, or tensors. **There is an obvious need for multidimensional arrays in Go**. @@ -50,15 +51,15 @@ The `*Dense` tensor is the primary tensor and is represented by a singular flat ### Compressed Sparse Column Matrix ### -Coming soon +Documentation Coming soon ### Compressed Sparse Row Matrix ### -Coming soon +Documentation Coming soon ## Usage ## -To install: `go get -u "github.com/chewxy/gorgonia/tensor"` +To install: `go get -u "gorgonia.org/tensor"` To create a matrix with package `tensor` is easy: @@ -129,7 +130,7 @@ b.SetAt(1000, 0, 1, 2) fmt.Printf("b:\n%v", b) ``` -There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/github.com/chewxy/gorgonia/tensor) page +There is a whole laundry list of methods and functions available at the [godoc](https://godoc.org/gorgonia.org/tensor) page @@ -198,7 +199,7 @@ The above call will use `myEngine` to allocate memory instead. This is useful in ### Other failed designs ### -The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https://github.com/chewxy/gorgonia/blob/master/tensor/ALTERNATIVEDESIGNS.md) +The alternative designs can be seen in the [ALTERNATIVE DESIGNS document](https://github.com/tensor/blob/master/ALTERNATIVEDESIGNS.md) ## Generic Features ## diff --git a/ap.go b/ap.go index b4b9176..83df9c5 100644 --- a/ap.go +++ b/ap.go @@ -26,13 +26,30 @@ type AP struct { Δ Triangle } -// NewAP creates a new AP, given the shape and strides -func NewAP(shape Shape, strides []int) *AP { - ap := borrowAP() +func makeAP(size int) AP { + return AP{ + shape: Shape(BorrowInts(size)), + strides: BorrowInts(size), + } +} + +// MakeAP creates an AP, given the shape and strides. +func MakeAP(shape Shape, strides []int, o DataOrder, Δ Triangle) AP { + return AP{ + shape: shape, + strides: strides, + o: o, + Δ: Δ, + fin: true, + } +} + +// Init initalizes an already created AP with a shape and stries. +// It will panic if AP is nil. +func (ap *AP) Init(shape Shape, strides []int) { ap.shape = shape ap.strides = strides ap.fin = true - return ap } // SetShape is for very specific times when modifying the AP is necessary, such as reshaping and doing I/O related stuff @@ -46,6 +63,9 @@ func (ap *AP) SetShape(s ...int) { if !ap.fin { // scalars are a special case, we don't want to remove it completely if len(s) == 0 { + if ap.shape == nil || ap.strides == nil { + ap.shape = Shape{} + } ap.shape = ap.shape[:0] ap.strides = ap.strides[:0] return @@ -102,9 +122,54 @@ func (ap *AP) IsScalar() bool { return ap.shape.IsScalar() } // IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices func (ap *AP) IsMatrix() bool { return len(ap.shape) == 2 } -// Clone clones the *AP. Clearly. -func (ap *AP) Clone() (retVal *AP) { - retVal = BorrowAP(len(ap.shape)) +// IsZero tell us if the ap has zero size +func (ap *AP) IsZero() bool { + return len(ap.shape) == 0 && len(ap.strides) == 0 && !ap.fin && ap.o == 0 && ap.Δ == 0 +} + +// Zero zeros out an AP. +func (ap *AP) zero() { + // log.Printf("ZEROING. Called by %v", string(debug.Stack())) + + // Jorge's original implementation for zeroing a AP is as below + // but to cater for the (*Dense).fix() method of the *Dense + // a nil shape is used to signal unsetness + // so we cannot just truncate the shape even though it would be a lot more efficient + + // ap.shape = ap.shape[:0] + // ap.strides = ap.strides[:0] + ReturnInts([]int(ap.shape)) + ReturnInts(ap.strides) + ap.zeroOnly() +} + +// side effect free zeroing +func (ap *AP) zeroOnly() { + ap.shape = nil + ap.strides = nil + + ap.fin = false + ap.o = 0 + ap.Δ = 0 +} + +func (ap *AP) zeroWithDims(dims int) { + //ap.shape = BorrowInts(dims) + //ap.strides = BorrowInts(dims) + if cap(ap.shape) >= dims { + ap.shape = ap.shape[:dims] + } + ap.shape = BorrowInts(dims) + if cap(ap.strides) >= dims { + ap.strides = ap.strides[:dims] + } + ap.strides = BorrowInts(dims) +} + +// Clone clones the *AP. Clearly. It returns AP +func (ap *AP) Clone() (retVal AP) { + retVal = makeAP(cap(ap.shape)) + copy(retVal.shape, ap.shape) copy(retVal.strides, ap.strides) @@ -118,21 +183,25 @@ func (ap *AP) Clone() (retVal *AP) { return } +func (ap *AP) CloneTo(dest *AP) { + dest.shape = append(dest.shape[:0], ap.shape...) + dest.strides = append(dest.strides[:0], ap.strides...) + dest.fin = ap.fin + dest.o = ap.o + dest.Δ = ap.Δ +} + // DataOrder returns the data order of the AP. func (ap *AP) DataOrder() DataOrder { return ap.o } // C returns true if the access pattern is C-contiguous array -func (ap *AP) C() bool { - return ap.o.isRowMajor() && ap.o.isContiguous() -} +func (ap *AP) C() bool { return ap.o.IsRowMajor() && ap.o.IsContiguous() } // F returns true if the access pattern is Fortran contiguous array -func (ap *AP) F() bool { - return ap.o.isColMajor() && ap.o.isContiguous() -} +func (ap *AP) F() bool { return ap.o.IsColMajor() && ap.o.IsContiguous() } // S returns the metadata of the sliced tensor. -func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err error) { +func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err error) { if len(slices) > len(ap.shape) { // error err = errors.Errorf(dimMismatch, len(ap.shape), len(slices)) @@ -146,7 +215,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e var outerDim int order := ap.o - if ap.o.isRowMajor() || ap.IsVector() { + if ap.o.IsRowMajor() || ap.IsVector() { outerDim = 0 } else { outerDim = len(ap.shape) - 1 @@ -160,12 +229,13 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e size := ap.shape[i] var stride int - if ap.IsVector() { - // handles non-vanilla vectors - stride = ap.strides[0] - } else { - stride = ap.strides[i] - } + stride = ap.strides[i] + // if ap.IsVector() { + // // handles non-vanilla vectors + // stride = ap.strides[0] + // } else { + // stride = ap.strides[i] + // } var start, end, step int if start, end, step, err = SliceDetails(sl, size); err != nil { @@ -196,37 +266,29 @@ func (ap *AP) S(size int, slices ...Slice) (newAP *AP, ndStart, ndEnd int, err e if ndEnd-ndStart == 1 { // scalars are a special case - newAP = borrowAP() + newAP = AP{} newAP.SetShape() // make it a Scalar newAP.lock() } else { // drop any dimension with size 1, except the last dimension + offset := 0 for d := 0; d < dims; d++ { - if newShape[d] == 1 /*&& d != t.dims-1 && dims > 2*/ { + if newShape[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { newShape = append(newShape[:d], newShape[d+1:]...) newStrides = append(newStrides[:d], newStrides[d+1:]...) d-- dims-- + offset++ } } - - //fix up strides - if newShape.IsColVec() { - stride0 := newStrides[0] - ReturnInts(newStrides) - newStrides = BorrowInts(1) - newStrides[0] = stride0 - } - - newAP = NewAP(newShape, newStrides) - newAP.o = order + newAP = MakeAP(newShape, newStrides, order, ap.Δ) } return } // T returns the transposed metadata based on the given input -func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { +func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) { // prep axes if len(axes) > 0 && len(axes) != ap.Dims() { err = errors.Errorf(dimMismatch, ap.Dims(), len(axes)) @@ -244,7 +306,7 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { // if axes is 0, 1, 2, 3... then no op if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 { - return ap, a, noopError{} + return ap.Clone(), a, noopError{} } currentShape := ap.shape @@ -270,12 +332,8 @@ func (ap *AP) T(axes ...int) (retVal *AP, a []int, err error) { } } - retVal = borrowAP() - retVal.shape = shape - retVal.strides = strides - if ap.IsVector() { - retVal.strides = retVal.strides[:1] - } + o := MakeDataOrder(ap.o, Transposed) + retVal = MakeAP(shape, strides, o, ap.Δ) retVal.fin = true return } @@ -286,14 +344,21 @@ func (ap *AP) unlock() { ap.fin = false } func (ap *AP) calcStrides() []int { switch { - case ap.o.isRowMajor(): - return ap.shape.calcStrides() - case ap.o.isColMajor(): - return ap.shape.calcStridesColMajor() + case ap.o.IsRowMajor(): + return ap.shape.CalcStrides() + case ap.o.IsColMajor(): + return ap.shape.CalcStridesColMajor() } panic("unreachable") } +// setDataOrder is a method such that any tensor that embeds *AP will have the same method +func (ap *AP) setDataOrder(o DataOrder) { + if !o.HasSameOrder(ap.o) { + ap.o = ap.o.toggleColMajor() + } +} + // TransposeIndex returns the new index given the old index func TransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int { oldCoord, err := Itol(i, oldShape, oldStrides) diff --git a/ap_test.go b/ap_test.go index 37e0a28..091d6e6 100644 --- a/ap_test.go +++ b/ap_test.go @@ -32,46 +32,40 @@ func sli(start int, opt ...int) dummySlice { return dummySlice{start: start, end: end, step: step} } -func dummyScalar1() *AP { - return &AP{} -} +func dummyScalar1() AP { return AP{} } -func dummyScalar2() *AP { - return &AP{ - shape: Shape{1}, - } -} +func dummyScalar2() AP { return AP{shape: Shape{1}} } -func dummyColVec() *AP { - return &AP{ +func dummyColVec() AP { + return AP{ shape: Shape{5, 1}, strides: []int{1}, } } -func dummyRowVec() *AP { - return &AP{ +func dummyRowVec() AP { + return AP{ shape: Shape{1, 5}, strides: []int{1}, } } -func dummyVec() *AP { - return &AP{ +func dummyVec() AP { + return AP{ shape: Shape{5}, strides: []int{1}, } } -func twothree() *AP { - return &AP{ +func twothree() AP { + return AP{ shape: Shape{2, 3}, strides: []int{3, 1}, } } -func twothreefour() *AP { - return &AP{ +func twothreefour() AP { + return AP{ shape: Shape{2, 3, 4}, strides: []int{12, 4, 1}, } @@ -83,7 +77,7 @@ func TestAccessPatternBasics(t *testing.T) { ap.SetShape(1, 2) assert.Equal(Shape{1, 2}, ap.Shape()) - assert.Equal([]int{1}, ap.Strides()) + assert.Equal([]int{2, 1}, ap.Strides()) assert.Equal(2, ap.Dims()) assert.Equal(2, ap.Size()) @@ -100,21 +94,21 @@ func TestAccessPatternBasics(t *testing.T) { ap.unlock() ap.SetShape(1, 2) assert.Equal(Shape{1, 2}, ap.Shape()) - assert.Equal([]int{1}, ap.Strides()) + assert.Equal([]int{2, 1}, ap.Strides()) assert.Equal(2, ap.Dims()) assert.Equal(2, ap.Size()) - if ap.String() != "Shape: (1, 2), Stride: [1], Lock: false" { - t.Error("AP formatting error. Got %q", ap.String()) + if ap.String() != "Shape: (1, 2), Stride: [2 1], Lock: false" { + t.Errorf("AP formatting error. Got %q", ap.String()) } ap2 := ap.Clone() - assert.Equal(ap, ap2) + assert.Equal(*ap, ap2) } func TestAccessPatternIsX(t *testing.T) { assert := assert.New(t) - var ap *AP + var ap AP ap = dummyScalar1() assert.True(ap.IsScalar()) @@ -151,7 +145,7 @@ func TestAccessPatternIsX(t *testing.T) { func TestAccessPatternT(t *testing.T) { assert := assert.New(t) - var ap, apT *AP + var ap, apT AP var axes []int var err error @@ -216,16 +210,22 @@ var sliceTests = []struct { {"A[1:3]", Shape{4, 5}, []Slice{sli(1, 3)}, 5, 15, Shape{2, 5}, []int{5, 1}, true}, {"A[0:10] (intentionally over)", Shape{4, 5}, []Slice{sli(0, 10)}, 0, 20, Shape{4, 5}, []int{5, 1}, true}, // as if nothing happened {"A[:, 1:3]", Shape{4, 5}, []Slice{nil, sli(1, 3)}, 1, 18, Shape{4, 2}, []int{5, 1}, false}, + + // tensor + {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, 0, 4, Shape{2, 2}, []int{2, 1}, true}, + {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, 0, 2, Shape{1, 2}, []int{4, 1}, false}, + {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, + {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, 0, 4, Shape{1, 2, 2}, []int{4, 2, 1}, true}, } func TestAccessPatternS(t *testing.T) { assert := assert.New(t) - var ap, apS *AP + var ap, apS AP var ndStart, ndEnd int var err error for _, sts := range sliceTests { - ap = NewAP(sts.shape, sts.shape.calcStrides()) + ap = MakeAP(sts.shape, sts.shape.CalcStrides(), 0, 0) if apS, ndStart, ndEnd, err = ap.S(sts.shape.TotalSize(), sts.slices...); err != nil { t.Errorf("%v errored: %v", sts.name, err) continue @@ -234,7 +234,7 @@ func TestAccessPatternS(t *testing.T) { assert.Equal(sts.correctEnd, ndEnd, "Wrong end: %v. Want %d Got %d", sts.name, sts.correctEnd, ndEnd) assert.True(sts.correctShape.Eq(apS.shape), "Wrong shape: %v. Want %v. Got %v", sts.name, sts.correctShape, apS.shape) assert.Equal(sts.correctStride, apS.strides, "Wrong strides: %v. Want %v. Got %v", sts.name, sts.correctStride, apS.strides) - assert.Equal(sts.contiguous, apS.DataOrder().isContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous) + assert.Equal(sts.contiguous, apS.DataOrder().IsContiguous(), "Wrong contiguity for %v Want %t.", sts.name, sts.contiguous) } } diff --git a/api_arith_test.go b/api_arith_test.go index d7bd9a5..687e4b7 100644 --- a/api_arith_test.go +++ b/api_arith_test.go @@ -26,7 +26,7 @@ func TestMod(t *testing.T) { // scalar if res, err = Mod(a, 1.0); err != nil { - t.Fatal("Error: %v", err) + t.Fatalf("Error: %v", err) } assert.Equal(t, correct, res.Data()) } @@ -41,10 +41,10 @@ func TestFMA(t *testing.T) { y2 := y.Clone().(*Dense) we, willFailEq := willerr(a, numberTypes, nil) - // _, ok1 := q.Engine().(FMAer) - // _, ok2 := q.Engine().(Muler) - // _, ok3 := q.Engine().(Adder) - // we = we || (!ok1 && (!ok2 || !ok3)) + _, ok1 := q.Engine().(FMAer) + _, ok2 := q.Engine().(Muler) + _, ok3 := q.Engine().(Adder) + we = we || (!ok1 && (!ok2 || !ok3)) f, err := FMA(a, x, y) if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly { diff --git a/api_cmp_generated_test.go b/api_cmp_generated_test.go index 163ae5c..002587b 100644 --- a/api_cmp_generated_test.go +++ b/api_cmp_generated_test.go @@ -62,7 +62,7 @@ func TestGt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -120,7 +120,7 @@ func TestGte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -178,7 +178,7 @@ func TestLt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -236,7 +236,7 @@ func TestLte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -294,7 +294,7 @@ func TestEq(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -328,7 +328,7 @@ func TestEq(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestNe(t *testing.T) { @@ -363,7 +363,7 @@ func TestNe(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestGt_assame(t *testing.T) { @@ -422,7 +422,7 @@ func TestGt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -482,7 +482,7 @@ func TestGte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -542,7 +542,7 @@ func TestLt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -602,7 +602,7 @@ func TestLte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -662,7 +662,7 @@ func TestEq_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -699,7 +699,7 @@ func TestEq_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestNe_assame(t *testing.T) { @@ -737,7 +737,7 @@ func TestNe_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestGtScalar(t *testing.T) { @@ -792,7 +792,7 @@ func TestGtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -848,7 +848,7 @@ func TestGteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -904,7 +904,7 @@ func TestLtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -960,7 +960,7 @@ func TestLteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1016,7 +1016,7 @@ func TestEqScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1048,7 +1048,7 @@ func TestEqScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestNeScalar(t *testing.T) { @@ -1081,7 +1081,7 @@ func TestNeScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } func TestGtScalar_assame(t *testing.T) { @@ -1138,7 +1138,7 @@ func TestGtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -1196,7 +1196,7 @@ func TestGteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -1254,7 +1254,7 @@ func TestLtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -1312,7 +1312,7 @@ func TestLteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1370,7 +1370,7 @@ func TestEqScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1405,7 +1405,7 @@ func TestEqScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestNeScalar_assame(t *testing.T) { @@ -1441,6 +1441,6 @@ func TestNeScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } diff --git a/api_matop.go b/api_matop.go index 0db34e7..8c687b2 100644 --- a/api_matop.go +++ b/api_matop.go @@ -108,3 +108,10 @@ func Materialize(t Tensor) Tensor { return t } } + +func Diag(t Tensor) (retVal Tensor, err error) { + if d, ok := t.Engine().(Diager); ok { + return d.Diag(t) + } + return nil, errors.Errorf("Unable to perform diagonalization of tensor ") +} diff --git a/api_utils.go b/api_utils.go index 12bea19..3cf55f0 100644 --- a/api_utils.go +++ b/api_utils.go @@ -53,11 +53,11 @@ func SortIndex(in interface{}) (out []int) { // SampleIndex samples a slice or a Tensor. // TODO: tidy this up. func SampleIndex(in interface{}) int { - var l int + // var l int switch list := in.(type) { case []int: var sum, i int - l = len(list) + // l = len(list) r := rand.Int() for { sum += list[i] @@ -69,7 +69,7 @@ func SampleIndex(in interface{}) int { case []float64: var sum float64 var i int - l = len(list) + // l = len(list) r := rand.Float64() for { sum += list[i] @@ -85,7 +85,7 @@ func SampleIndex(in interface{}) int { var sum float64 r := rand.Float64() data := list.Float64s() - l = len(data) + // l = len(data) for { datum := data[i] if math.IsNaN(datum) || math.IsInf(datum, 0) { @@ -102,7 +102,7 @@ func SampleIndex(in interface{}) int { var sum float32 r := rand.Float32() data := list.Float32s() - l = len(data) + // l = len(data) for { datum := data[i] if math32.IsNaN(datum) || math32.IsInf(datum, 0) { @@ -121,5 +121,5 @@ func SampleIndex(in interface{}) int { default: panic("Not yet implemented") } - return l - 1 + return -1 } diff --git a/array.go b/array.go index 4162280..321c9cf 100644 --- a/array.go +++ b/array.go @@ -18,10 +18,8 @@ type array struct { // makeHeader makes a array Header func makeHeader(t Dtype, length int) storage.Header { - size := int(calcMemSize(t, length)) - s := make([]byte, size) return storage.Header{ - Ptr: unsafe.Pointer(&s[0]), + Ptr: malloc(t, length), L: length, C: length, } @@ -75,6 +73,7 @@ func arrayFromSlice(x interface{}) array { } } +// fromSlice populates the value from a slice func (a *array) fromSlice(x interface{}) { xT := reflect.TypeOf(x) if xT.Kind() != reflect.Slice { @@ -91,20 +90,45 @@ func (a *array) fromSlice(x interface{}) { a.v = x } +// fromSliceOrTensor populates the value from a slice or anything that can form an array +func (a *array) fromSliceOrArrayer(x interface{}) { + if T, ok := x.(arrayer); ok { + xp := T.arrPtr() + + // if the underlying array hasn't been allocated, or not enough has been allocated + if a.Ptr == nil || a.L < xp.L || a.C < xp.C { + a.t = xp.t + a.L = xp.L + a.C = xp.C + a.Ptr = malloc(a.t, a.L) + } + + a.t = xp.t + a.L = xp.L + a.C = xp.C + copyArray(a, T.arrPtr()) + a.v = nil // tell the GC to release whatever a.v may hold + a.forcefix() // fix it such that a.v has a value and is not nil + return + } + a.fromSlice(x) +} + +// fix fills the a.v empty interface{} if it's not nil func (a *array) fix() { if a.v == nil { - shdr := reflect.SliceHeader{ - Data: uintptr(a.Ptr), - Len: a.L, - Cap: a.C, - } - sliceT := reflect.SliceOf(a.t.Type) - ptr := unsafe.Pointer(&shdr) - val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) - a.v = val.Interface() + a.forcefix() } } +// forcefix fills the a.v empty interface{}. No checks are made if the thing is empty +func (a *array) forcefix() { + sliceT := reflect.SliceOf(a.t.Type) + ptr := unsafe.Pointer(&a.Header) + val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) + a.v = val.Interface() +} + // byteSlice casts the underlying slice into a byte slice. Useful for copying and zeroing, but not much else func (a array) byteSlice() []byte { return storage.AsByteSlice(&a.Header, a.t.Type) @@ -132,6 +156,7 @@ func (a *array) sliceInto(i, j int, res *array) { res.fix() } +// slice slices an array func (a array) slice(start, end int) array { if end > a.L { panic("Index out of range") @@ -240,6 +265,13 @@ func (a *array) rtype() reflect.Type { return a.t.Type } /* MEMORY MOVEMENT STUFF */ +// malloc is standard Go allocation of a block of memory - the plus side is that Go manages the memory +func malloc(t Dtype, length int) unsafe.Pointer { + size := int(calcMemSize(t, length)) + s := make([]byte, size) + return unsafe.Pointer(&s[0]) +} + // calcMemSize calulates the memory size of an array (given its size) func calcMemSize(dt Dtype, size int) int64 { return int64(dt.Size()) * int64(size) @@ -288,6 +320,7 @@ func copyDense(dst, src DenseTensor) int { // return copyArray(dst.arr(), src.arr()) } +// copyDenseSliced copies a DenseTensor, but both are sliced func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, send int) int { if dst.Dtype() != src.Dtype() { panic("Cannot copy DenseTensors of different types") @@ -316,12 +349,14 @@ func copyDenseSliced(dst DenseTensor, dstart, dend int, src DenseTensor, sstart, return copyArraySliced(dst.arr(), dstart, dend, src.arr(), sstart, send) } +// copyDenseIter copies a DenseTensor, with iterator func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { if dst.Dtype() != src.Dtype() { panic("Cannot copy Dense arrays of different types") } - if !dst.RequiresIterator() && !src.RequiresIterator() { + // if they all don't need iterators, and have the same data order + if !dst.RequiresIterator() && !src.RequiresIterator() && dst.DataOrder().HasSameOrder(src.DataOrder()) { return copyDense(dst, src), nil } @@ -336,6 +371,7 @@ func copyDenseIter(dst, src DenseTensor, diter, siter Iterator) (int, error) { siter = FlatIteratorFromDense(src) } + // if it's a masked tensor, we copy the mask as well if ms, ok := src.(MaskedTensor); ok && ms.IsMasked() { if md, ok := dst.(MaskedTensor); ok { dmask := md.Mask() @@ -388,12 +424,34 @@ func getPointer(a interface{}) unsafe.Pointer { case string: return unsafe.Pointer(&at) case uintptr: - return unsafe.Pointer(&at) + return unsafe.Pointer(at) case unsafe.Pointer: return at // POINTERS + case *bool: + return unsafe.Pointer(at) + case *int: + return unsafe.Pointer(at) + case *int8: + return unsafe.Pointer(at) + case *int16: + return unsafe.Pointer(at) + case *int32: + return unsafe.Pointer(at) + case *int64: + return unsafe.Pointer(at) + case *uint: + return unsafe.Pointer(at) + case *uint8: + return unsafe.Pointer(at) + case *uint16: + return unsafe.Pointer(at) + case *uint32: + return unsafe.Pointer(at) + case *uint64: + return unsafe.Pointer(at) case *float32: return unsafe.Pointer(at) case *float64: @@ -402,11 +460,18 @@ func getPointer(a interface{}) unsafe.Pointer { return unsafe.Pointer(at) case *complex128: return unsafe.Pointer(at) + case *string: + return unsafe.Pointer(at) + case *uintptr: + return unsafe.Pointer(*at) + case *unsafe.Pointer: + return *at } panic("Cannot get pointer") } +// scalarToHeader creates a Header from a scalar value func scalarToHeader(a interface{}) *storage.Header { hdr := borrowHeader() hdr.Ptr = getPointer(a) diff --git a/benchmark_dense_matop_test.go b/benchmark_dense_matop_test.go index 2c5b8a7..2a4ee4a 100644 --- a/benchmark_dense_matop_test.go +++ b/benchmark_dense_matop_test.go @@ -1,6 +1,9 @@ package tensor -import "testing" +import ( + "math/rand" + "testing" +) func BenchmarkDense_Transpose(b *testing.B) { T := New(WithShape(100, 100, 2), WithBacking(Range(Byte, 0, 100*100*2))) @@ -64,7 +67,7 @@ func BenchmarkGetWithIterator(b *testing.B) { f = data[next] } if _, ok := err.(NoOpError); !ok { - b.Error("Error: %v", err) + b.Errorf("Error: %v", err) } } _ = f @@ -85,8 +88,57 @@ func BenchmarkComplicatedGet(b *testing.B) { f = data[next] } if _, ok := err.(NoOpError); !ok { - b.Error("Error: %v", err) + b.Errorf("Error: %v", err) } } _ = f } + +var atCoords [10000][2]int + +func init() { + for i := range atCoords { + atCoords[i][0] = rand.Intn(100) + atCoords[i][1] = rand.Intn(100) + } +} + +var at1, at2 float64 + +// func BenchmarkAtWithNativeIterator(b *testing.B) { +// T := New(WithShape(100, 100), Of(Float64)) +// it, err := NativeMatrixF64(T) +// if err != nil { +// b.Fatalf("Error: %v", err) +// } + +// var j int +// for i := 0; i < b.N; i++ { + +// if j >= len(atCoords) { +// j = 0 +// } + +// at := atCoords[j] +// at1 = it[at[0]][at[1]] +// j++ +// } +// } + +func BenchmarkAt(b *testing.B) { + T := New(WithShape(100, 100), Of(Float64)) + var j int + for i := 0; i < b.N; i++ { + if j >= len(atCoords) { + j = 0 + } + + at := atCoords[j] + _, err := T.At(at[0], at[1]) + if err != nil { + b.Errorf("Error: %v", err) + } + + j++ + } +} diff --git a/consopt.go b/consopt.go index 9118cad..ee4b4cf 100644 --- a/consopt.go +++ b/consopt.go @@ -10,6 +10,7 @@ type ConsOpt func(Tensor) // Of is a construction option for a Tensor. func Of(a Dtype) ConsOpt { + Register(a) f := func(t Tensor) { switch tt := t.(type) { case *Dense: @@ -172,9 +173,11 @@ func WithEngine(e Engine) ConsOpt { if e != nil && !e.AllocAccessible() { tt.flag = MakeMemoryFlag(tt.flag, NativelyInaccessible) } - // if oe, ok := e.(standardEngine); ok { - // tt.oe = oe - // } + + tt.oe = nil + if oe, ok := e.(standardEngine); ok { + tt.oe = oe + } case *CS: tt.e = e if e != nil && !e.AllocAccessible() { @@ -185,14 +188,75 @@ func WithEngine(e Engine) ConsOpt { return f } -func AsFortran() ConsOpt { +// AsFortran creates a *Dense with a col-major layout. +// If the optional backing argument is passed, the backing is assumed to be C-order (row major), and +// it will be transposed before being used. +func AsFortran(backing interface{}) ConsOpt { f := func(t Tensor) { switch tt := t.(type) { case *Dense: - if tt.AP == nil { - // create AP + if backing != nil { + // put the data into the tensor, then make a clone tensor to transpose + tt.fromSliceOrArrayer(backing) + // create a temporary tensor, to which the transpose will be done + tmp := NewDense(tt.Dtype(), tt.shape.Clone()) + copyArray(tmp.arrPtr(), tt.arrPtr()) + tmp.T() + tmp.Transpose() + // copy the data back to the current tensor + copyArray(tt.arrPtr(), tmp.arrPtr()) + // cleanup: return the temporary tensor back to the pool + ReturnTensor(tmp) } + tt.AP.o = MakeDataOrder(tt.AP.o, ColMajor) + if tt.AP.shape != nil { + ReturnInts(tt.AP.strides) + tt.AP.strides = nil + tt.AP.strides = tt.AP.calcStrides() + } + case *CS: + panic("AsFortran is not an available option for Compressed Sparse layouts") + } + } + return f +} + +func AsDenseDiag(backing interface{}) ConsOpt { + f := func(t Tensor) { + switch tt := t.(type) { + case *Dense: + if bt, ok := backing.(Tensor); ok { + backing = bt.Data() + } + xT := reflect.TypeOf(backing) + if xT.Kind() != reflect.Slice { + panic("Expected a slice") + } + xV := reflect.ValueOf(backing) + l := xV.Len() + // elT := xT.Elem() + + sli := reflect.MakeSlice(xT, l*l, l*l) + + shape := Shape{l, l} + strides := shape.CalcStrides() + for i := 0; i < l; i++ { + idx, err := Ltoi(shape, strides, i, i) + if err != nil { + panic(err) + } + + at := sli.Index(idx) + xi := xV.Index(i) + at.Set(xi) + } + + tt.fromSliceOrArrayer(sli.Interface()) + tt.setShape(l, l) + + default: + panic("AsDenseDiag is not available as an option for CS") } } return f diff --git a/defaultengine.go b/defaultengine.go index 6dd1f45..cace41a 100644 --- a/defaultengine.go +++ b/defaultengine.go @@ -66,12 +66,7 @@ func (e StdEng) Memcpy(dst, src Memory) error { func (e StdEng) Accessible(mem Memory) (Memory, error) { return mem, nil } -func (e StdEng) WorksWith(order DataOrder) bool { - if order.isColMajor() { - return false - } - return true -} +func (e StdEng) WorksWith(order DataOrder) bool { return true } func (e StdEng) checkAccessible(t Tensor) error { if !t.IsNativelyAccessible() { diff --git a/defaultengine_argmethods.go b/defaultengine_argmethods.go index 3cedd84..5632fa6 100644 --- a/defaultengine_argmethods.go +++ b/defaultengine_argmethods.go @@ -59,17 +59,21 @@ func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e if _, ok := err.(NoOpError); !ok && err != nil { return } else if ok { - newAP = t.Info().Clone() + t.Info().CloneTo(&newAP) } - defer ReturnAP(newAP) it := IteratorFromDense(t) - iteratorLoadAP(it, newAP) + iteratorLoadAP(it, &newAP) lastSize := it.Shape()[len(it.Shape())-1] newShape := it.Shape().Clone() newShape = newShape[:len(newShape)-1] - defer ReturnInts(newShape) + + // cleanup + defer func() { + newAP.zero() + ReturnInts(newShape) + }() if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { mask := mt.Mask() @@ -144,15 +148,19 @@ func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err e } else if ok { newAP = t.Info().Clone() } - defer ReturnAP(newAP) it := IteratorFromDense(t) - iteratorLoadAP(it, newAP) + iteratorLoadAP(it, &newAP) lastSize := it.Shape()[len(it.Shape())-1] newShape := it.Shape().Clone() newShape = newShape[:len(newShape)-1] - defer ReturnInts(newShape) + + // cleanup + defer func() { + newAP.zero() + ReturnInts(newShape) + }() if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { mask := mt.Mask() diff --git a/defaultengine_arith.go b/defaultengine_arith.go index 01d9784..3017aaa 100644 --- a/defaultengine_arith.go +++ b/defaultengine_arith.go @@ -16,7 +16,7 @@ func (e StdEng) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -83,7 +83,7 @@ func (e StdEng) Sub(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -150,7 +150,7 @@ func (e StdEng) Mul(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -217,7 +217,7 @@ func (e StdEng) Div(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -284,7 +284,7 @@ func (e StdEng) Pow(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -351,7 +351,7 @@ func (e StdEng) Mod(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } typ := a.Dtype().Type @@ -418,7 +418,7 @@ func (e StdEng) AddScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -506,7 +506,7 @@ func (e StdEng) SubScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -594,7 +594,7 @@ func (e StdEng) MulScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -682,7 +682,7 @@ func (e StdEng) DivScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -770,7 +770,7 @@ func (e StdEng) PowScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t @@ -858,7 +858,7 @@ func (e StdEng) ModScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } a := t diff --git a/defaultengine_cmp.go b/defaultengine_cmp.go index 98f61e1..b3651d7 100644 --- a/defaultengine_cmp.go +++ b/defaultengine_cmp.go @@ -18,7 +18,7 @@ func (e StdEng) Gt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -98,7 +98,7 @@ func (e StdEng) Gte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -178,7 +178,7 @@ func (e StdEng) Lt(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err erro var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -258,7 +258,7 @@ func (e StdEng) Lte(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err err var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -338,7 +338,7 @@ func (e StdEng) ElEq(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -418,7 +418,7 @@ func (e StdEng) ElNe(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, err er var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -498,7 +498,7 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -564,7 +564,7 @@ func (e StdEng) GtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -610,7 +610,7 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -676,7 +676,7 @@ func (e StdEng) GteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -722,7 +722,7 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -788,7 +788,7 @@ func (e StdEng) LtScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -834,7 +834,7 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -900,7 +900,7 @@ func (e StdEng) LteScalar(t Tensor, s interface{}, leftTensor bool, opts ...Func } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -942,7 +942,7 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -1008,7 +1008,7 @@ func (e StdEng) EqScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) @@ -1050,7 +1050,7 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO var reuse DenseTensor var safe, same bool - if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, same, err = handleFuncOpts(t.Shape(), t.Dtype(), t.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if !safe { @@ -1116,7 +1116,7 @@ func (e StdEng) NeScalar(t Tensor, s interface{}, leftTensor bool, opts ...FuncO } // handle special case where A and B have both len 1 - if dataB.L == 1 && dataB.L == 1 { + if dataA.L == 1 && dataB.L == 1 { switch { case same && safe && reuse != nil && leftTensor: storage.Copy(typ, dataReuse, dataA) diff --git a/defaultengine_linalg.go b/defaultengine_linalg.go index 486e7a0..45a8527 100644 --- a/defaultengine_linalg.go +++ b/defaultengine_linalg.go @@ -286,13 +286,14 @@ func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { var rd *Dense if rd, err = a.TensorMul(b, axesA, axesB); err != nil { + panic(err) return } if reuse != nil { copyDense(reuse, rd) - ReturnAP(reuse.Info()) - reuse.setAP(rd.Info().Clone()) + ap := rd.Info().Clone() + reuse.setAP(&ap) defer ReturnTensor(rd) // swap out the underlying data and metadata // reuse.data, rd.data = rd.data, reuse.data @@ -403,12 +404,35 @@ func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { n := ad.oshape()[1] tA := blas.NoTrans - if ad.oldAP() != nil { + do := a.DataOrder() + z := ad.oldAP().IsZero() + + var lda int + switch { + case do.IsRowMajor() && z: + lda = n + case do.IsRowMajor() && !z: + tA = blas.Trans + lda = n + case do.IsColMajor() && z: tA = blas.Trans + lda = m + m, n = n, m + case do.IsColMajor() && !z: + lda = m + m, n = n, m } - lda := ad.ostrides()[0] + incX, incY := 1, 1 // step size + // ASPIRATIONAL TODO: different incX and incY + // TECHNICAL DEBT. TECHDEBT. TECH DEBT + // Example use case: + // log.Printf("a %v %v", ad.Strides(), ad.ostrides()) + // log.Printf("b %v", b.Strides()) + // incX := a.Strides()[0] + // incY = b.Strides()[0] + switch A := ad.Data().(type) { case []float64: x := bd.Float64s() @@ -438,49 +462,61 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { return errors.Wrapf(err, opFail, "StdEng.MatMul") } - tA, tB := blas.NoTrans, blas.NoTrans - if ad.oldAP() != nil { - tA = blas.Trans - } - - // Special case if b is (1, N) - if bd.oldAP() != nil || bd.IsRowVec() { - tB = blas.Trans - } + ado := a.DataOrder() + bdo := b.DataOrder() + cdo := prealloc.DataOrder() + // get result shapes. k is the shared dimension + // a is (m, k) + // b is (k, n) + // c is (m, n) var m, n, k int m = ad.Shape()[0] k = ad.Shape()[1] n = bd.Shape()[1] // wrt the strides, we use the original strides, because that's what BLAS needs, instead of calling .Strides() - lda := ad.ostrides()[0] - ldb := bd.ostrides()[0] - ldc := pd.ostrides()[0] + // lda in colmajor = number of rows; + // lda in row major = number of cols + var lda, ldb, ldc int + switch { + case ado.IsColMajor(): + lda = m + case ado.IsRowMajor(): + lda = k + } - // special case: if a is (1, N) x (N, M), then we can just use GEMV - if ad.IsRowVec() { - tB = blas.Trans - if bd.oldAP() != nil { - tB = blas.NoTrans + switch { + case bdo.IsColMajor(): + ldb = bd.Shape()[0] + case bdo.IsRowMajor(): + ldb = n + } + + switch { + case cdo.IsColMajor(): + ldc = prealloc.Shape()[0] + case cdo.IsRowMajor(): + ldc = prealloc.Shape()[1] + } + + // check for trans + tA, tB := blas.NoTrans, blas.NoTrans + if !ad.oldAP().IsZero() { + tA = blas.Trans + if ado.IsRowMajor() { + lda = m + } else { + lda = k } - m = bd.Shape()[0] - n = bd.Shape()[1] - switch A := ad.Data().(type) { - case []float64: - B := bd.Float64s() - C := pd.Float64s() - alpha, beta := float64(1), float64(0) - whichblas.Dgemv(tB, m, n, alpha, B, ldb, A, lda, beta, C, ldc) - case []float32: - B := bd.Float32s() - C := pd.Float32s() - alpha, beta := float32(1), float32(0) - whichblas.Sgemv(tB, m, n, alpha, B, ldb, A, lda, beta, C, ldc) - default: - return errors.Errorf(typeNYI, "matMul a is row vec", ad.Data()) + } + if !bd.oldAP().IsZero() { + tB = blas.Trans + if bdo.IsRowMajor() { + ldb = bd.Shape()[0] + } else { + ldb = bd.Shape()[1] } - return } switch A := ad.Data().(type) { @@ -488,12 +524,20 @@ func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { B := bd.Float64s() C := pd.Float64s() alpha, beta := float64(1), float64(0) - whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Dgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } case []float32: B := bd.Float32s() C := pd.Float32s() alpha, beta := float32(1), float32(0) - whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + if ado.IsColMajor() && bdo.IsColMajor() { + whichblas.Sgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) + } else { + whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + } default: return errors.Errorf(typeNYI, "matMul", ad.Data()) } @@ -510,12 +554,40 @@ func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { m := ad.Size() n := bd.Size() + pdo := pd.DataOrder() // the stride of a Vector is always going to be [1], // incX := t.Strides()[0] // incY := other.Strides()[0] incX, incY := 1, 1 - lda := pd.Strides()[0] + // lda := pd.Strides()[0] + var lda int + switch { + case pdo.IsColMajor(): + aShape := a.Shape().Clone() + bShape := b.Shape().Clone() + if err = a.Reshape(aShape[0], 1); err != nil { + return err + } + if err = b.Reshape(1, bShape[0]); err != nil { + return err + } + + if err = e.MatMul(a, b, prealloc); err != nil { + return err + } + + if err = b.Reshape(bShape...); err != nil { + return + } + if err = a.Reshape(aShape...); err != nil { + return + } + return nil + + case pdo.IsRowMajor(): + lda = pd.Shape()[1] + } switch x := ad.Data().(type) { case []float64: @@ -559,13 +631,13 @@ func (e StdEng) checkTwoFloatTensors(a, b Tensor) (ad, bd DenseTensor, err error func (e StdEng) checkThreeFloatTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) { if err = e.checkAccessible(a); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") } if err = e.checkAccessible(b); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") } if err = e.checkAccessible(ret); err != nil { - return nil, nil, nil, errors.Wrap(err, "checkTwoTensors: ret is not accessible") + return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible") } if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { diff --git a/defaultengine_mapreduce.go b/defaultengine_mapreduce.go index 203a839..4964ab4 100644 --- a/defaultengine_mapreduce.go +++ b/defaultengine_mapreduce.go @@ -16,7 +16,7 @@ func (e StdEng) Map(fn interface{}, a Tensor, opts ...FuncOpt) (retVal Tensor, e var reuse DenseTensor var safe, _, incr bool - if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, _, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return } switch { @@ -102,18 +102,18 @@ func (e StdEng) Reduce(fn interface{}, a Tensor, axis int, defaultValue interfac // actual call out to the internal engine switch { - case (axis == 0 && at.DataOrder().isRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().isColMajor()): + case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()): var size, split int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } size = a.Shape()[0] split = a.DataSize() / size storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split) err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, fn) - case (axis == lastAxis && at.DataOrder().isRowMajor()) || (axis == 0 && at.DataOrder().isColMajor()): + case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()): var dimSize int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } dimSize = a.Shape()[axis] @@ -147,18 +147,18 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, // actual call out to the internal engine switch { - case (axis == 0 && at.DataOrder().isRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().isColMajor()): + case (axis == 0 && at.DataOrder().IsRowMajor()) || ((axis == lastAxis || axis == len(a.Shape())-1) && at.DataOrder().IsColMajor()): var size, split int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } size = a.Shape()[0] split = a.DataSize() / size storage.CopySliced(typ, dataReuse, 0, split, dataA, 0, split) err = e.E.ReduceFirst(typ, dataA, dataReuse, split, size, firstFn) - case (axis == lastAxis && at.DataOrder().isRowMajor()) || (axis == 0 && at.DataOrder().isColMajor()): + case (axis == lastAxis && at.DataOrder().IsRowMajor()) || (axis == 0 && at.DataOrder().IsColMajor()): var dimSize int - if at.DataOrder().isColMajor() { + if at.DataOrder().IsColMajor() { return nil, errors.Errorf("NYI: colmajor") } dimSize = a.Shape()[axis] @@ -328,7 +328,7 @@ func (StdEng) prepReduce(a Tensor, axis int, opts ...FuncOpt) (at, reuse DenseTe // FUNC PREP var safe bool - if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, _, _, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { err = errors.Wrap(err, "Unable to prep unary tensor") return } diff --git a/defaultengine_matop_misc.go b/defaultengine_matop_misc.go index 23607c6..9faed77 100644 --- a/defaultengine_matop_misc.go +++ b/defaultengine_matop_misc.go @@ -1,6 +1,12 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) + +var ( + _ Diager = StdEng{} +) func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) { switch tt := t.(type) { @@ -104,9 +110,18 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen all[0] = a copy(all[1:], Ts) + // TODO: OPIMIZATION + // When (axis == 0 && a is row major and all others is row major) || (axis == last axis of A && all tensors are colmajor) + // just flat copy + // + + // isOuter is true when the axis is the outermost axis + // isInner is true when the axis is the inner most axis + isOuter := axis == 0 + isInner := axis == (a.Shape().Dims() - 1) + // special case var start, end int - for _, T := range all { end += T.Shape()[axis] slices := make([]Slice, axis+1) @@ -117,15 +132,124 @@ func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTen return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat") } - if v.IsVector() && T.IsMatrix() && axis == 0 { + switch { + case v.IsVector() && T.IsMatrix() && axis == 0: v.reshape(v.shape[0], 1) + case T.IsRowVec() && axis == 0: + T.reshape(T.Shape()[1]) + case v.Shape().IsScalarEquiv() && T.Shape().IsScalarEquiv(): + copyArray(v.arrPtr(), T.arrPtr()) + if mt, ok := T.(MaskedTensor); ok { + copy(v.mask, mt.Mask()) + } + continue + default: + diff := retVal.Shape().Dims() - v.Shape().Dims() + if diff > 0 && isOuter { + newShape := make(Shape, v.Shape().Dims()+diff) + for i := 0; i < diff; i++ { + newShape[i] = 1 + } + copy(newShape[diff:], v.Shape()) + v.reshape(newShape...) + } else if diff > 0 && isInner { + newShape := v.Shape().Clone() + newStrides := v.strides + for i := 0; i < diff; i++ { + newShape = append(newShape, 1) + newStrides = append(newStrides, 1) + } + v.shape = newShape + v.strides = newStrides + } + } + + var vmask, Tmask []bool + vmask = v.mask + v.mask = nil + if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() { + Tmask = mt.Mask() + mt.SetMask(nil) + } if err = assignArray(v, T); err != nil { return nil, errors.Wrap(err, "Unable to assignArray in denseConcat") } + // if it's a masked tensor, we copy the mask as well + if Tmask != nil { + if vmask != nil { + if cap(vmask) < len(Tmask) { + vmask2 := make([]bool, len(Tmask)) + copy(vmask2, vmask) + vmask = vmask2 + } + copy(vmask, Tmask) + v.SetMask(vmask) + } + // mt.SetMask(Tmask) + } + start = end } return retVal, nil } + +func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) { + a, ok := t.(DenseTensor) + if !ok { + return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()") + } + + if a.Dims() != 2 { + err = errors.Errorf(dimMismatch, 2, a.Dims()) + return + } + + if err = typeclassCheck(a.Dtype(), numberTypes); err != nil { + return nil, errors.Wrap(err, "Diagonal") + } + + rstride := a.Strides()[0] + cstride := a.Strides()[1] + + r := a.Shape()[0] + c := a.Shape()[1] + + m := MinInt(r, c) + stride := rstride + cstride + + b := a.Clone().(DenseTensor) + b.Zero() + + switch a.rtype().Size() { + case 1: + bdata := b.hdr().Uint8s() + adata := a.hdr().Uint8s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 2: + bdata := b.hdr().Uint16s() + adata := a.hdr().Uint16s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 4: + bdata := b.hdr().Uint32s() + adata := a.hdr().Uint32s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + case 8: + bdata := b.hdr().Uint64s() + adata := a.hdr().Uint64s() + for i := 0; i < m; i++ { + bdata[i] = adata[i*stride] + } + default: + return nil, errors.Errorf(typeNYI, "Arbitrary sized diag") + } + return b, nil +} diff --git a/defaultengine_matop_stack.go b/defaultengine_matop_stack.go index 1a43a7e..368ddb5 100644 --- a/defaultengine_matop_stack.go +++ b/defaultengine_matop_stack.go @@ -28,15 +28,13 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV info := t.Info() var newStrides []int - if info.o.isColMajor() { - newStrides = newShape.calcStridesColMajor() + if info.o.IsColMajor() { + newStrides = newShape.CalcStridesColMajor() } else { - newStrides = newShape.calcStrides() + newStrides = newShape.CalcStrides() } - ap := NewAP(newShape, newStrides) - ap.o = info.o - ap.Δ = info.Δ + ap := MakeAP(newShape, newStrides, info.o, info.Δ) allNoMat := !t.RequiresIterator() for _, ot := range others { @@ -46,8 +44,7 @@ func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retV } retVal = recycledDense(t.Dtype(), ap.Shape(), WithEngine(e)) - ReturnAP(retVal.Info()) - retVal.setAP(ap) + retVal.setAP(&ap) // the "viewStack" method is the more generalized method // and will work for all Tensors, regardless of whether it's a view diff --git a/defaultengine_matop_transpose.go b/defaultengine_matop_transpose.go index e66c4a6..8f7c86c 100644 --- a/defaultengine_matop_transpose.go +++ b/defaultengine_matop_transpose.go @@ -44,7 +44,7 @@ func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { u8s := tmpArr.Uint8s() orig := a.hdr().Uint8s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u8s[j] = orig[i] @@ -59,7 +59,7 @@ func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { u16s := tmpArr.Uint16s() orig := a.hdr().Uint16s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u16s[j] = orig[i] @@ -74,7 +74,7 @@ func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { u32s := tmpArr.Uint32s() orig := a.hdr().Uint32s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u32s[j] = orig[i] @@ -89,7 +89,7 @@ func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { u64s := tmpArr.Uint64s() orig := a.hdr().Uint64s() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { u64s[j] = orig[i] @@ -104,7 +104,7 @@ func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { strs := tmpArr.Strings() orig := a.hdr().Strings() - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { strs[j] = orig[i] @@ -122,7 +122,7 @@ func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { arbs := tmpArr.byteSlice() orig := storage.AsByteSlice(a.hdr(), rtype) - it := NewFlatIterator(a.Info()) + it := newFlatIterator(a.Info()) var j int for i, err := it.Next(); err == nil; i, err = it.Next() { srcStart := i * typeSize diff --git a/defaultengine_matop_transpose_inplace.go b/defaultengine_matop_transpose_inplace.go index 6725aea..d8a87e4 100644 --- a/defaultengine_matop_transpose_inplace.go +++ b/defaultengine_matop_transpose_inplace.go @@ -51,6 +51,9 @@ func (e StdEng) denseTranspose1(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint8s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -87,6 +90,9 @@ func (e StdEng) denseTranspose2(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint16s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -123,6 +129,9 @@ func (e StdEng) denseTranspose4(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint32s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -159,6 +168,9 @@ func (e StdEng) denseTranspose8(a DenseTensor, expStrides []int) { var i int data := a.hdr().Uint64s() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) if track.IsSet(i) && track.IsSet(dest) { @@ -195,6 +207,9 @@ func (e StdEng) denseTransposeString(a DenseTensor, expStrides []int) { var i int data := a.hdr().Strings() + if len(data) < 4 { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) @@ -233,6 +248,9 @@ func (e StdEng) denseTransposeArbitrary(a DenseTensor, expStrides []int) { tmp := make([]byte, typeSize, typeSize) var i int data := storage.AsByteSlice(a.hdr(), rtype) + if len(data) < 4*typeSize { + return + } for i = 1; ; { dest := a.transposeIndex(i, axes, expStrides) start := typeSize * i diff --git a/defaultengine_misc.go b/defaultengine_misc.go index b4bf21c..bb70e57 100644 --- a/defaultengine_misc.go +++ b/defaultengine_misc.go @@ -12,7 +12,7 @@ func (e StdEng) Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (retVal T var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), false, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), false, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } diff --git a/defaultengine_prep.go b/defaultengine_prep.go index cb358a7..c203253 100644 --- a/defaultengine_prep.go +++ b/defaultengine_prep.go @@ -3,9 +3,10 @@ package tensor import ( "github.com/pkg/errors" "gorgonia.org/tensor/internal/storage" + // "log" ) -func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { +func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) { fo := ParseFuncOpts(opts...) reuseT, incr := fo.IncrReuse() @@ -16,7 +17,7 @@ func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) if toReuse { if reuse, err = getDenseTensor(reuseT); err != nil { returnOpOpt(fo) - err = errors.Wrapf(err, "Cannot reuse a different type of Tensor in a *Dense-Scalar operation") + err = errors.Wrapf(err, "Cannot reuse a Tensor that isn't a DenseTensor. Got %T instead", reuseT) return } @@ -40,6 +41,11 @@ func handleFuncOpts(expShape Shape, expType Dtype, strict bool, opts ...FuncOpt) return } + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) + } + } returnOpOpt(fo) return @@ -101,7 +107,11 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea } // iter - useIter = a.RequiresIterator() || b.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) + useIter = a.RequiresIterator() || + b.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + !a.DataOrder().HasSameOrder(b.DataOrder()) || + (reuse != nil && (!a.DataOrder().HasSameOrder(reuse.DataOrder()) || !b.DataOrder().HasSameOrder(reuse.DataOrder()))) if useIter { ait = a.Iterator() bit = b.Iterator() @@ -109,6 +119,7 @@ func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Hea iit = reuse.Iterator() } } + // log.Printf("Use Itrer %v ", useIter) // swap if _, ok := a.(*CS); ok { @@ -133,12 +144,14 @@ func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse if a.IsScalar() { return } - if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { + useIter = a.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + (reuse != nil && reuse.DataOrder().HasSameOrder(a.DataOrder())) + if useIter { ait = a.Iterator() if reuse != nil { iit = reuse.Iterator() } - useIter = true } return } @@ -155,12 +168,14 @@ func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse if b.IsScalar() { return } - if b.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) { + useIter = b.RequiresIterator() || + (reuse != nil && reuse.RequiresIterator()) || + (reuse != nil && reuse.DataOrder().HasSameOrder(b.DataOrder())) + if useIter { bit = b.Iterator() if reuse != nil { iit = reuse.Iterator() } - useIter = true } return } diff --git a/defaultengine_unary.go b/defaultengine_unary.go index 4da968a..986e246 100644 --- a/defaultengine_unary.go +++ b/defaultengine_unary.go @@ -14,7 +14,7 @@ func (e StdEng) Neg(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -82,7 +82,7 @@ func (e StdEng) Inv(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -150,7 +150,7 @@ func (e StdEng) Square(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -218,7 +218,7 @@ func (e StdEng) Cube(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -286,7 +286,7 @@ func (e StdEng) Exp(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -354,7 +354,7 @@ func (e StdEng) Tanh(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -422,7 +422,7 @@ func (e StdEng) Log(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -490,7 +490,7 @@ func (e StdEng) Log2(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -558,7 +558,7 @@ func (e StdEng) Log10(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -626,7 +626,7 @@ func (e StdEng) Sqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -694,7 +694,7 @@ func (e StdEng) Cbrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -762,7 +762,7 @@ func (e StdEng) InvSqrt(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -830,7 +830,7 @@ func (e StdEng) Abs(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } @@ -898,7 +898,7 @@ func (e StdEng) Sign(a Tensor, opts ...FuncOpt) (retVal Tensor, err error) { } var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), true, opts...); err != nil { + if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } diff --git a/defaultenginefloat32.go b/defaultenginefloat32.go index f260479..82d48f2 100644 --- a/defaultenginefloat32.go +++ b/defaultenginefloat32.go @@ -10,7 +10,7 @@ import ( "gorgonia.org/vecf32" ) -func handleFuncOptsF32(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF32(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) reuseT, incr := fo.IncrReuse() @@ -30,6 +30,12 @@ func handleFuncOptsF32(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") return } + + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) + } + } returnOpOpt(fo) return @@ -175,7 +181,7 @@ func (e Float32Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), opts...); err != nil { + if reuse, safe, toReuse, incr, err = handleFuncOptsF32(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if err = e.checkThree(a, b, reuse); err != nil { diff --git a/defaultenginefloat64.go b/defaultenginefloat64.go index 6fe2786..b0d9466 100644 --- a/defaultenginefloat64.go +++ b/defaultenginefloat64.go @@ -10,7 +10,7 @@ import ( "gorgonia.org/vecf64" ) -func handleFuncOptsF64(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { +func handleFuncOptsF64(expShape Shape, o DataOrder, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr bool, err error) { fo := ParseFuncOpts(opts...) reuseT, incr := fo.IncrReuse() @@ -30,6 +30,12 @@ func handleFuncOptsF64(expShape Shape, opts ...FuncOpt) (reuse DenseTensor, safe err = errors.Wrapf(err, "Cannot use reuse: shape mismatch") return } + + if !incr && reuse != nil { + reuse.setDataOrder(o) + // err = reuse.reshape(expShape...) + } + } returnOpOpt(fo) return @@ -175,7 +181,7 @@ func (e Float64Engine) Add(a Tensor, b Tensor, opts ...FuncOpt) (retVal Tensor, var reuse DenseTensor var safe, toReuse, incr bool - if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), opts...); err != nil { + if reuse, safe, toReuse, incr, err = handleFuncOptsF64(a.Shape(), a.DataOrder(), opts...); err != nil { return nil, errors.Wrap(err, "Unable to handle funcOpts") } if err = e.checkThree(a, b, reuse); err != nil { diff --git a/dense.go b/dense.go index 95e3f53..2824261 100644 --- a/dense.go +++ b/dense.go @@ -13,7 +13,7 @@ const ( // Dense represents a dense tensor - this is the most common form of tensors. It can be used to represent vectors, matrices.. etc type Dense struct { - *AP + AP array flag MemoryFlag @@ -21,7 +21,7 @@ type Dense struct { oe standardEngine // optimized engine // backup AP. When a transpose is done, the old *AP is backed up here, for easy untransposes - old *AP + old AP transposeWith []int // if viewOf != nil, then this *Dense is a view. @@ -54,7 +54,7 @@ func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) retVal.array.t = dt retVal.array.L = size retVal.array.C = size - retVal.AP = BorrowAP(shape.Dims()) + retVal.AP.zeroWithDims(shape.Dims()) for _, opt := range opts { opt(retVal) @@ -78,9 +78,14 @@ func (t *Dense) addMask(mask []bool) { } func (t *Dense) makeArray(size int) { - if am, ok := t.e.(arrayMaker); ok { - am.makeArray(&t.array, t.t, size) + + switch te := t.e.(type) { + case NonStdEngine: + t.flag = MakeMemoryFlag(t.flag, ManuallyManaged) + case arrayMaker: + te.makeArray(&t.array, t.t, size) return + default: } mem, err := t.e.Alloc(calcMemSize(t.t, size)) @@ -97,7 +102,7 @@ func (t *Dense) makeArray(size int) { } // Info returns the access pattern which explains how the data in the underlying array is accessed. This is mostly used for debugging. -func (t *Dense) Info() *AP { return t.AP } +func (t *Dense) Info() *AP { return &t.AP } // Dtype returns the data type of the *Dense tensor. func (t *Dense) Dtype() Dtype { return t.t } @@ -123,11 +128,11 @@ func (t *Dense) Engine() Engine { return t.e } // Reshape reshapes a *Dense. If the tensors need to be materialized (either it's a view or transpose), it will be materialized before the reshape happens func (t *Dense) Reshape(dims ...int) error { - if t.viewOf != 0 && t.o.isNotContiguous() { + if t.viewOf != 0 && t.o.IsNotContiguous() { return errors.Errorf(methodNYI, "Reshape", "non-contiguous views") } - if t.old != nil { + if !t.old.IsZero() { t.Transpose() } @@ -159,7 +164,7 @@ func (t *Dense) IsView() bool { // IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing func (t *Dense) IsMaterializable() bool { - return t.viewOf != 0 || t.old != nil + return t.viewOf != 0 || !t.old.IsZero() } // IsManuallyManaged returns true if the memory associated with this *Dense is manually managed (by the user) @@ -172,15 +177,16 @@ func (t *Dense) IsNativelyAccessible() bool { return t.flag.nativelyAccessible() func (t *Dense) Clone() interface{} { if t.e != nil { retVal := new(Dense) - retVal.AP = t.AP.Clone() + t.AP.CloneTo(&retVal.AP) retVal.t = t.t retVal.e = t.e retVal.oe = t.oe retVal.flag = t.flag retVal.makeArray(t.L) - if t.old != nil { + if !t.old.IsZero() { retVal.old = t.old.Clone() + t.old.CloneTo(&retVal.old) } copyDense(retVal, t) retVal.lock() @@ -246,13 +252,9 @@ func (t *Dense) setShape(s ...int) { return } -func (t *Dense) setAP(ap *AP) { t.AP = ap } +func (t *Dense) setAP(ap *AP) { t.AP = *ap } func (t *Dense) fix() { - if t.AP == nil { - return - } - if t.e == nil { t.e = StdEng{} } @@ -298,31 +300,33 @@ func (t *Dense) makeMask() { // sanity is a function that sanity checks that a tensor is correct. func (t *Dense) sanity() error { - if t.AP != nil && t.Shape() == nil && t.array.Ptr == nil { + if !t.AP.IsZero() && t.Shape() == nil && t.array.Ptr == nil { return errors.New(emptyTensor) } size := t.L expected := t.Size() if t.viewOf == 0 && size != expected && !t.IsScalar() { - return errors.Errorf(shapeMismatch, t.Shape(), size) + return errors.Wrap(errors.Errorf(shapeMismatch, t.Shape(), size), "sanity check failed") } // TODO: sanity check for views return nil } -func (t *Dense) isTransposed() bool { return t.old == nil } +// isTransposed returns true if the *Dense holds a transposed array. +func (t *Dense) isTransposed() bool { return t.old.IsZero() } // oshape returns the original shape func (t *Dense) oshape() Shape { - if t.old != nil { + if !t.old.IsZero() { return t.old.Shape() } return t.Shape() } +// ostrides returns the original strides func (t *Dense) ostrides() []int { - if t.old != nil { + if !t.old.IsZero() { return t.old.Strides() } return t.Strides() @@ -333,14 +337,14 @@ func (t *Dense) ShallowClone() *Dense { retVal := borrowDense() retVal.e = t.e retVal.oe = t.oe - retVal.AP = t.AP.Clone() + t.AP.CloneTo(&retVal.AP) retVal.flag = t.flag retVal.array = t.array return retVal } -func (t *Dense) oldAP() *AP { return t.old } -func (t *Dense) setOldAP(ap *AP) { t.old = ap } +func (t *Dense) oldAP() *AP { return &t.old } +func (t *Dense) setOldAP(ap *AP) { t.old = *ap } func (t *Dense) transposeAxes() []int { return t.transposeWith } func (t *Dense) parentTensor() *Dense { if t.viewOf != 0 { @@ -537,7 +541,7 @@ func (t *Dense) Memset(x interface{}) error { return errors.Errorf(inaccessibleData, t) } if t.IsMaterializable() { - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) return t.array.memsetIter(x, it) } return t.array.Memset(x) @@ -560,7 +564,7 @@ func (t *Dense) Eq(other interface{}) bool { func (t *Dense) Zero() { if t.IsMaterializable() { - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) if err := t.zeroIter(it); err != nil { panic(err) } @@ -590,7 +594,7 @@ func (t *Dense) RequiresIterator() bool { return false } // non continuous slice, transpose, or masked. If it's a slice and contiguous, then iterator is not required - if !t.o.isContiguous() || t.old != nil || t.IsMasked() { + if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() { return true } return false diff --git a/dense_assign.go b/dense_assign.go index 0fdc1d4..8b2783e 100644 --- a/dense_assign.go +++ b/dense_assign.go @@ -84,11 +84,11 @@ func assignArray(dest, src DenseTensor) (err error) { return } dap := dest.Info() - sap := NewAP(tmpShape, newStrides) - sap.o = src.Info().o + sap := MakeAP(tmpShape, newStrides, src.Info().o, src.Info().Δ) - diter := NewFlatIterator(dap) - siter := NewFlatIterator(sap) + diter := newFlatIterator(dap) + siter := newFlatIterator(&sap) _, err = copyDenseIter(dest, src, diter, siter) + sap.zeroOnly() // cleanup, but not entirely because tmpShape and tmpStrides are separately cleaned up. Don't double free return } diff --git a/dense_cmp_test.go b/dense_cmp_test.go index 4c1db8e..a0bc5b6 100644 --- a/dense_cmp_test.go +++ b/dense_cmp_test.go @@ -62,7 +62,7 @@ func TestDense_Gt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -120,7 +120,7 @@ func TestDense_Gte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -178,7 +178,7 @@ func TestDense_Lt(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -236,7 +236,7 @@ func TestDense_Lte(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -294,7 +294,7 @@ func TestDense_ElEq(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -328,7 +328,7 @@ func TestDense_ElEq(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestDense_ElNe(t *testing.T) { @@ -363,7 +363,7 @@ func TestDense_ElNe(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestDense_Gt_assame(t *testing.T) { @@ -422,7 +422,7 @@ func TestDense_Gt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -482,7 +482,7 @@ func TestDense_Gte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -542,7 +542,7 @@ func TestDense_Lt_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -602,7 +602,7 @@ func TestDense_Lte_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -662,7 +662,7 @@ func TestDense_ElEq_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -699,7 +699,7 @@ func TestDense_ElEq_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } } func TestDense_ElNe_assame(t *testing.T) { @@ -737,7 +737,7 @@ func TestDense_ElNe_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElNe failed: %v", err) + t.Errorf("Transitivity test for ElNe failed: %v", err) } } func TestDense_GtScalar(t *testing.T) { @@ -792,7 +792,7 @@ func TestDense_GtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -848,7 +848,7 @@ func TestDense_GteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -904,7 +904,7 @@ func TestDense_LtScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -960,7 +960,7 @@ func TestDense_LteScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1016,7 +1016,7 @@ func TestDense_ElEqScalar(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1048,7 +1048,7 @@ func TestDense_ElEqScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestDense_ElNeScalar(t *testing.T) { @@ -1081,7 +1081,7 @@ func TestDense_ElNeScalar(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } func TestDense_GtScalar_assame(t *testing.T) { @@ -1138,7 +1138,7 @@ func TestDense_GtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gt failed: %v", err) + t.Errorf("Transitivity test for Gt failed: %v", err) } } @@ -1196,7 +1196,7 @@ func TestDense_GteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Gte failed: %v", err) + t.Errorf("Transitivity test for Gte failed: %v", err) } } @@ -1254,7 +1254,7 @@ func TestDense_LtScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lt failed: %v", err) + t.Errorf("Transitivity test for Lt failed: %v", err) } } @@ -1312,7 +1312,7 @@ func TestDense_LteScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for Lte failed: %v", err) + t.Errorf("Transitivity test for Lte failed: %v", err) } } @@ -1370,7 +1370,7 @@ func TestDense_ElEqScalar_assame(t *testing.T) { return true } if err := quick.Check(transFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Transitivity test for ElEq failed: %v", err) + t.Errorf("Transitivity test for ElEq failed: %v", err) } symFn := func(q *Dense) bool { @@ -1405,7 +1405,7 @@ func TestDense_ElEqScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElEq failed: %v", err) + t.Errorf("Symmetry test for ElEq failed: %v", err) } } func TestDense_ElNeScalar_assame(t *testing.T) { @@ -1441,6 +1441,6 @@ func TestDense_ElNeScalar_assame(t *testing.T) { } if err := quick.Check(symFn, &quick.Config{Rand: newRand(), MaxCount: quickchecks}); err != nil { - t.Error("Symmetry test for ElNe failed: %v", err) + t.Errorf("Symmetry test for ElNe failed: %v", err) } } diff --git a/dense_colmajor_linalg_test.go b/dense_colmajor_linalg_test.go new file mode 100644 index 0000000..feccfc5 --- /dev/null +++ b/dense_colmajor_linalg_test.go @@ -0,0 +1,483 @@ +package tensor + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +var colMajorTraceTests = []struct { + data interface{} + + correct interface{} + err bool +}{ + {[]int{0, 1, 2, 3, 4, 5}, int(4), false}, + {[]int8{0, 1, 2, 3, 4, 5}, int8(4), false}, + {[]int16{0, 1, 2, 3, 4, 5}, int16(4), false}, + {[]int32{0, 1, 2, 3, 4, 5}, int32(4), false}, + {[]int64{0, 1, 2, 3, 4, 5}, int64(4), false}, + {[]uint{0, 1, 2, 3, 4, 5}, uint(4), false}, + {[]uint8{0, 1, 2, 3, 4, 5}, uint8(4), false}, + {[]uint16{0, 1, 2, 3, 4, 5}, uint16(4), false}, + {[]uint32{0, 1, 2, 3, 4, 5}, uint32(4), false}, + {[]uint64{0, 1, 2, 3, 4, 5}, uint64(4), false}, + {[]float32{0, 1, 2, 3, 4, 5}, float32(4), false}, + {[]float64{0, 1, 2, 3, 4, 5}, float64(4), false}, + {[]complex64{0, 1, 2, 3, 4, 5}, complex64(4), false}, + {[]complex128{0, 1, 2, 3, 4, 5}, complex128(4), false}, + {[]bool{true, false, true, false, true, false}, nil, true}, +} + +func TestColMajor_Dense_Trace(t *testing.T) { + assert := assert.New(t) + for i, tts := range colMajorTraceTests { + T := New(WithShape(2, 3), AsFortran(tts.data)) + trace, err := T.Trace() + + if checkErr(t, tts.err, err, "Trace", i) { + continue + } + assert.Equal(tts.correct, trace) + + // + T = New(WithBacking(tts.data)) + _, err = T.Trace() + if err == nil { + t.Error("Expected an error when Trace() on non-matrices") + } + } +} + +var colMajorInnerTests = []struct { + a, b interface{} + shapeA, shapeB Shape + + correct interface{} + err bool +}{ + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3, 1}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3, 1}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{1, 3}, float64(5), false}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{1, 3}, float64(5), false}, + + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3, 1}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3, 1}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{1, 3}, float32(5), false}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{1, 3}, float32(5), false}, + + // stupids: type differences + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float32, 0, 3), Range(Byte, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, nil, true}, + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, + + // differing size + {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{4}, Shape{3}, nil, true}, + + // A is not a matrix + {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{2, 2}, Shape{3}, nil, true}, +} + +func TestColMajor_Dense_Inner(t *testing.T) { + for i, its := range colMajorInnerTests { + a := New(WithShape(its.shapeA...), AsFortran(its.a)) + b := New(WithShape(its.shapeB...), AsFortran(its.b)) + + T, err := a.Inner(b) + if checkErr(t, its.err, err, "Inner", i) { + continue + } + + assert.Equal(t, its.correct, T) + } +} + +var colMajorMatVecMulTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + + // float64s with transposed matrix + {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, + Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, + + // Float32s + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, + Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, + []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, + + // stupids : unpossible shapes (wrong A) + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad A shape + {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad B shape + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + //stupids: bad reuse + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, + + //stupids: bad incr shape + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, + + // stupids: type mismatch A and B + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch A and B (non-Float) + {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, + + // stupids: type mismatch, reuse + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, + + // stupids: type mismatch, incr + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, + + // stupids: type mismatch, incr not a Number + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, + Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, + []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, +} + +func TestColMajor_Dense_MatVecMul(t *testing.T) { + assert := assert.New(t) + for i, mvmt := range colMajorMatVecMulTests { + a := New(WithShape(mvmt.shapeA...), AsFortran(mvmt.a)) + b := New(WithShape(mvmt.shapeB...), AsFortran(mvmt.b)) + + if mvmt.transA { + if err := a.T(); err != nil { + t.Error(err) + continue + } + } + + T, err := a.MatVecMul(b) + if checkErr(t, mvmt.err, err, "Safe", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correct, T.Data()) + + // incr + incr := New(WithShape(mvmt.shapeI...), AsFortran(mvmt.incr)) + T, err = a.MatVecMul(b, WithIncr(incr)) + if checkErr(t, mvmt.errIncr, err, "WithIncr", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(mvmt.shapeR...), AsFortran(mvmt.reuse)) + T, err = a.MatVecMul(b, WithReuse(reuse)) + if checkErr(t, mvmt.errReuse, err, "WithReuse", i) { + continue + } + + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correct, T.Data()) + + // reuse AND incr + T, err = a.MatVecMul(b, WithIncr(incr), WithReuse(reuse)) + if checkErr(t, mvmt.err, err, "WithReuse and WithIncr", i) { + continue + } + assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mvmt.correctIncrReuse, T.Data()) + } +} + +var colMajorMatMulTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, + + // Float32s + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, + + // Edge cases - Row Vecs (Float64) + {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, + Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, + []float64{0, 0, 0, 1, 0, 2}, []float64{100, 103, 101, 105, 102, 107}, []float64{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, + {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, + Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, + {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, + Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1}, + []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, + + // Edge cases - Row Vecs (Float32) + {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, + Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, + []float32{0, 0, 0, 1, 0, 2}, []float32{100, 103, 101, 105, 102, 107}, []float32{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, + {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, + Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, + []float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, + {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, + Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1}, + []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, + + // stupids - bad shape (not matrices): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - bad shape (incompatible shapes): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - bad shape (bad reuse shape): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, + + // stupids - bad shape (bad incr shape): + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, + + // stupids - type mismatch (a,b) + {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids - type mismatch (a,b) + {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids type mismatch (b not float) + {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids type mismatch (a not float) + {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, + + // stupids: type mismatch (incr) + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, + + // stupids: type mismatch (reuse) + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, + + // stupids: type mismatch (reuse) + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, + Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, + []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, +} + +func TestColMajorDense_MatMul(t *testing.T) { + assert := assert.New(t) + for i, mmt := range colMajorMatMulTests { + a := New(WithShape(mmt.shapeA...), AsFortran(mmt.a)) + b := New(WithShape(mmt.shapeB...), AsFortran(mmt.b)) + + T, err := a.MatMul(b) + if checkErr(t, mmt.err, err, "Safe", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(mmt.correct, T.Data(), "Test %d", i) + + // incr + incr := New(WithShape(mmt.shapeI...), AsFortran(mmt.incr)) + T, err = a.MatMul(b, WithIncr(incr)) + if checkErr(t, mmt.errIncr, err, "WithIncr", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(mmt.shapeR...), AsFortran(mmt.reuse)) + T, err = a.MatMul(b, WithReuse(reuse)) + + if checkErr(t, mmt.errReuse, err, "WithReuse", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correct, T.Data()) + + // reuse AND incr + T, err = a.MatMul(b, WithIncr(incr), WithReuse(reuse)) + if checkErr(t, mmt.err, err, "WithIncr and WithReuse", i) { + continue + } + assert.True(mmt.correctShape.Eq(T.Shape())) + assert.Equal(mmt.correctIncrReuse, T.Data()) + } +} + +var colMajorOuterTests = []linalgTest{ + // Float64s + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, false}, + + // Float32s + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float32{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, false}, + + // stupids - a or b not vector + {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - bad incr shape + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, true, false}, + + // stupids - bad reuse shape + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + false, false, true}, + + // stupids - b not Float + {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - a not Float + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids - a-b type mismatch + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, + + // stupids a-b type mismatch + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, + Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, + []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, + true, false, false}, +} + +func TestColMajor_Dense_Outer(t *testing.T) { + assert := assert.New(t) + for i, ot := range colMajorOuterTests { + a := New(WithShape(ot.shapeA...), AsFortran(ot.a)) + b := New(WithShape(ot.shapeB...), AsFortran(ot.b)) + + T, err := a.Outer(b) + if checkErr(t, ot.err, err, "Safe", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correct, T.Data()) + + // incr + incr := New(WithShape(ot.shapeI...), AsFortran(ot.incr)) + T, err = a.Outer(b, WithIncr(incr)) + if checkErr(t, ot.errIncr, err, "WithIncr", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correctIncr, T.Data()) + + // reuse + reuse := New(WithShape(ot.shapeR...), AsFortran(ot.reuse)) + T, err = a.Outer(b, WithReuse(reuse)) + if checkErr(t, ot.errReuse, err, "WithReuse", i) { + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correct, T.Data()) + + // reuse AND incr + T, err = a.Outer(b, WithIncr(incr), WithReuse(reuse)) + if err != nil { + t.Errorf("Reuse and Incr error'd %+v", err) + continue + } + assert.True(ot.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsColMajor()) + assert.Equal(ot.correctIncrReuse, T.Data()) + } +} diff --git a/dense_compat.go b/dense_compat.go index 151ae0a..a1b90ab 100644 --- a/dense_compat.go +++ b/dense_compat.go @@ -394,14 +394,12 @@ func FromMat64(m *mat.Dense, opts ...FuncOpt) *Dense { func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { // checks: if !t.IsNativelyAccessible() { - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") } if !t.IsMatrix() { // error - err = errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return + return nil, errors.Errorf("Cannot convert *Dense to *mat.Dense. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) } fo := ParseFuncOpts(opts...) @@ -420,7 +418,7 @@ func ToMat64(t *Dense, opts ...FuncOpt) (retVal *mat.Dense, err error) { case !t.IsMaterializable(): data = convToFloat64s(t) default: - it := NewFlatIterator(t.AP) + it := newFlatIterator(&t.AP) var next int for next, err = it.Next(); err == nil; next, err = it.Next() { if err = handleNoOp(err); err != nil { diff --git a/dense_format.go b/dense_format.go index 9d31994..859477f 100644 --- a/dense_format.go +++ b/dense_format.go @@ -249,7 +249,7 @@ func (f *fmtState) writeVElision() { // Special care also needs be taken for the verb 's' - it prints a super compressed version of the tensor, only printing 4 cols and 4 rows. func (t *Dense) Format(s fmt.State, c rune) { if c == 'i' { - fmt.Fprintf(s, "INFO:\n\tAP: %v\n\tOLD: %v\n\tTRANS %v\n\t", t.AP, t.old, t.transposeWith) + fmt.Fprintf(s, "INFO:\n\tAP: %v\n\tOLD: %v\n\tTRANS %v\n\tENGINE: %T\n", t.AP, t.old, t.transposeWith, t.e) return } @@ -353,7 +353,7 @@ func (t *Dense) Format(s fmt.State, c rune) { } // standard stuff - it := NewIterator(t.AP) + it := NewIterator(&t.AP) coord := it.Coord() firstRow := true diff --git a/dense_generated.go b/dense_generated.go index 93ea20b..6349bfb 100644 --- a/dense_generated.go +++ b/dense_generated.go @@ -88,7 +88,7 @@ func I(dt Dtype, r, c, k int) *Dense { panic(err) } var nexts []int - iter := NewFlatIterator(s.AP) + iter := newFlatIterator(&s.AP) nexts, err = iter.Slice(rs{i, s.Size(), c + 1}) switch s.t.Kind() { diff --git a/dense_io.go b/dense_io.go index 84896eb..c55c66a 100644 --- a/dense_io.go +++ b/dense_io.go @@ -14,26 +14,140 @@ import ( "strconv" "strings" + flatbuffers "github.com/google/flatbuffers/go" "github.com/pkg/errors" + "gorgonia.org/tensor/internal/serialization/fb" + "gorgonia.org/tensor/internal/serialization/pb" ) +/* GOB SERIALIZATION */ + +// GobEncode implements gob.GobEncoder +func (t *Dense) GobEncode() (p []byte, err error) { + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + + if err = encoder.Encode(t.Shape()); err != nil { + return + } + + if err = encoder.Encode(t.Strides()); err != nil { + return + } + + if err = encoder.Encode(t.AP.o); err != nil { + return + } + + if err = encoder.Encode(t.AP.Δ); err != nil { + return + } + + if err = encoder.Encode(t.mask); err != nil { + return + } + + data := t.Data() + if err = encoder.Encode(&data); err != nil { + return + } + + return buf.Bytes(), err +} + +// GobDecode implements gob.GobDecoder +func (t *Dense) GobDecode(p []byte) (err error) { + buf := bytes.NewBuffer(p) + decoder := gob.NewDecoder(buf) + + var shape Shape + if err = decoder.Decode(&shape); err != nil { + return + } + + var strides []int + if err = decoder.Decode(&strides); err != nil { + return + } + + var o DataOrder + var tr Triangle + if err = decoder.Decode(&o); err == nil { + if err = decoder.Decode(&tr); err != nil { + return + } + } + + t.AP.Init(shape, strides) + t.AP.o = o + t.AP.Δ = tr + + var mask []bool + if err = decoder.Decode(&mask); err != nil { + return + } + + var data interface{} + if err = decoder.Decode(&data); err != nil { + return + } + + t.fromSlice(data) + t.addMask(mask) + t.fix() + if t.e == nil { + t.e = StdEng{} + } + return t.sanity() +} + +/* NPY SERIALIZATION */ + +var npyDescRE = regexp.MustCompile(`'descr':\s*'([^']*)'`) +var rowOrderRE = regexp.MustCompile(`'fortran_order':\s*(False|True)`) +var shapeRE = regexp.MustCompile(`'shape':\s*\(([^\(]*)\)`) + type binaryWriter struct { io.Writer - error + err error seq int } -func (w binaryWriter) w(x interface{}) { - if w.error != nil { +func (w *binaryWriter) w(x interface{}) { + if w.err != nil { return } - binary.Write(w, binary.LittleEndian, x) + w.err = binary.Write(w, binary.LittleEndian, x) w.seq++ } -func (w binaryWriter) Error() string { - return fmt.Sprintf("Error at sequence %d : %v", w.seq, w.error.Error()) +func (w *binaryWriter) Err() error { + if w.err == nil { + return nil + } + return errors.Wrapf(w.err, "Sequence %d", w.seq) +} + +type binaryReader struct { + io.Reader + err error + seq int +} + +func (r *binaryReader) Read(data interface{}) { + if r.err != nil { + return + } + r.err = binary.Read(r.Reader, binary.LittleEndian, data) + r.seq++ +} + +func (r *binaryReader) Err() error { + if r.err == nil { + return nil + } + return errors.Wrapf(r.err, "Sequence %d", r.seq) } // WriteNpy writes the *Tensor as a numpy compatible serialized file. @@ -64,8 +178,8 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { bw.w(byte(1)) // major version bw.w(byte(0)) // minor version bw.w(uint16(len(header))) // 4 bytes to denote header length - if bw.error != nil { - return bw + if err = bw.Err(); err != nil { + return err } bw.Write([]byte(header)) @@ -86,176 +200,57 @@ func (t *Dense) WriteNpy(w io.Writer) (err error) { } } - if bw.error != nil { - return bw - } - return nil -} - -// WriteCSV writes the *Dense to a CSV. It accepts an optional string formatting ("%v", "%f", etc...), which controls what is written to the CSV. -// If tensor is masked, invalid values are replaced by the default fill value. -func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { - // checks: - if !t.IsMatrix() { - // error - err = errors.Errorf("Cannot write *Dense to CSV. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) - return - } - format := "%v" - if len(formats) > 0 { - format = formats[0] - } - - cw := csv.NewWriter(w) - it := IteratorFromDense(t) - coord := it.Coord() - - // rows := t.Shape()[0] - cols := t.Shape()[1] - record := make([]string, 0, cols) - var i, k, lastCol int - isMasked := t.IsMasked() - fillval := t.FillValue() - fillstr := fmt.Sprintf(format, fillval) - for i, err = it.Next(); err == nil; i, err = it.Next() { - record = append(record, fmt.Sprintf(format, t.Get(i))) - if isMasked { - if t.mask[i] { - record[k] = fillstr - } - k++ - } - if lastCol == cols-1 { - if err = cw.Write(record); err != nil { - // TODO: wrap errors - return - } - cw.Flush() - record = record[:0] - } - - // cleanup - switch { - case t.IsRowVec(): - // lastRow = coord[len(coord)-2] - lastCol = coord[len(coord)-1] - case t.IsColVec(): - // lastRow = coord[len(coord)-1] - lastCol = coord[len(coord)-2] - case t.IsVector(): - lastCol = coord[len(coord)-1] - default: - // lastRow = coord[len(coord)-2] - lastCol = coord[len(coord)-1] - } - } - return nil -} - -// GobEncode implements gob.GobEncoder -func (t *Dense) GobEncode() (p []byte, err error) { - var buf bytes.Buffer - encoder := gob.NewEncoder(&buf) - - if err = encoder.Encode(t.Shape()); err != nil { - return - } - - if err = encoder.Encode(t.Strides()); err != nil { - return - } - - if err = encoder.Encode(t.AP.o); err != nil { - return - } - - if err = encoder.Encode(t.AP.Δ); err != nil { - return - } - - if err = encoder.Encode(t.mask); err != nil { - return - } - - data := t.Data() - if err = encoder.Encode(&data); err != nil { - return - } - - return buf.Bytes(), err + return bw.Err() } // ReadNpy reads NumPy formatted files into a *Dense func (t *Dense) ReadNpy(r io.Reader) (err error) { + br := binaryReader{Reader: r} var magic [6]byte - if _, err = r.Read(magic[:]); err != nil { - return - } - if string(magic[:]) != "\x93NUMPY" { - err = errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:])) - return + if br.Read(magic[:]); string(magic[:]) != "\x93NUMPY" { + return errors.Errorf("Not a numpy file. Got %q as the magic number instead", string(magic[:])) } - var version byte - if err = binary.Read(r, binary.LittleEndian, &version); err != nil { - return - } - if version != 1 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + var version, minor byte + if br.Read(&version); version != 1 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } - var minor byte - if err = binary.Read(r, binary.LittleEndian, &minor); err != nil { - return - } - if minor != 0 { - err = errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") - return + if br.Read(&minor); minor != 0 { + return errors.New("Only verion 1.0 of numpy's serialization format is currently supported (65535 bytes ought to be enough for a header)") } var headerLen uint16 - if err = binary.Read(r, binary.LittleEndian, &headerLen); err != nil { - return - } - + br.Read(&headerLen) header := make([]byte, int(headerLen)) - if _, err = r.Read(header); err != nil { + br.Read(header) + if err = br.Err(); err != nil { return } - desc := regexp.MustCompile(`'descr':\s*'([^']*)'`) - match := desc.FindSubmatch(header) - if match == nil { - err = errors.New("No dtype information in npy file") - return + // extract stuff from header + var match [][]byte + if match = npyDescRE.FindSubmatch(header); match == nil { + return errors.New("No dtype information in npy file") } // TODO: check for endianness. For now we assume everything is little endian - var dt Dtype - if dt, err = fromNumpyDtype(string(match[1][1:])); err != nil { + if t.t, err = fromNumpyDtype(string(match[1][1:])); err != nil { return } - t.t = dt - rowOrder := regexp.MustCompile(`'fortran_order':\s*(False|True)`) - match = rowOrder.FindSubmatch(header) - if match == nil { - err = errors.New("No Row Order information found in the numpy file") - return + if match = rowOrderRE.FindSubmatch(header); match == nil { + return errors.New("No Row Order information found in the numpy file") } if string(match[1]) != "False" { - err = errors.New("Cannot yet read from Fortran Ordered Numpy files") - return + return errors.New("Cannot yet read from Fortran Ordered Numpy files") } - shpRe := regexp.MustCompile(`'shape':\s*\(([^\(]*)\)`) - match = shpRe.FindSubmatch(header) - if match == nil { - err = errors.New("No shape information found in npy file") - return + if match = shapeRE.FindSubmatch(header); match == nil { + return errors.New("No shape information found in npy file") } sizesStr := strings.Split(string(match[1]), ",") + var shape Shape for _, s := range sizesStr { s = strings.Trim(s, " ") @@ -273,163 +268,166 @@ func (t *Dense) ReadNpy(r io.Reader) (err error) { if t.e == nil { t.e = StdEng{} } - t.makeArray(size) switch t.t.Kind() { case reflect.Int: data := t.Ints() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int8: data := t.Int8s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int16: data := t.Int16s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int32: data := t.Int32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Int64: data := t.Int64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint: data := t.Uints() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint8: data := t.Uint8s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint16: data := t.Uint16s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint32: data := t.Uint32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Uint64: data := t.Uint64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Float32: data := t.Float32s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Float64: data := t.Float64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Complex64: data := t.Complex64s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } case reflect.Complex128: data := t.Complex128s() for i := 0; i < size; i++ { - if err = binary.Read(r, binary.LittleEndian, &data[i]); err != nil { - return - } + br.Read(&data[i]) } } - t.AP = BorrowAP(len(shape)) + if err = br.Err(); err != nil { + return err + } + + t.AP.zeroWithDims(len(shape)) t.setShape(shape...) t.fix() return t.sanity() } -// GobDecode implements gob.GobDecoder -func (t *Dense) GobDecode(p []byte) (err error) { - buf := bytes.NewBuffer(p) - decoder := gob.NewDecoder(buf) - - var shape Shape - if err = decoder.Decode(&shape); err != nil { - return - } +/* CSV SERIALIZATION */ - var strides []int - if err = decoder.Decode(&strides); err != nil { +// WriteCSV writes the *Dense to a CSV. It accepts an optional string formatting ("%v", "%f", etc...), which controls what is written to the CSV. +// If tensor is masked, invalid values are replaced by the default fill value. +func (t *Dense) WriteCSV(w io.Writer, formats ...string) (err error) { + // checks: + if !t.IsMatrix() { + // error + err = errors.Errorf("Cannot write *Dense to CSV. Expected number of dimensions: <=2, T has got %d dimensions (Shape: %v)", t.Dims(), t.Shape()) return } - - var o DataOrder - var tr Triangle - if err = decoder.Decode(&o); err == nil { - if err = decoder.Decode(&tr); err != nil { - return - } + format := "%v" + if len(formats) > 0 { + format = formats[0] } - t.AP = NewAP(shape, strides) - t.AP.o = o - t.AP.Δ = tr + cw := csv.NewWriter(w) + it := IteratorFromDense(t) + coord := it.Coord() - var mask []bool - if err = decoder.Decode(&mask); err != nil { - return - } + // rows := t.Shape()[0] + cols := t.Shape()[1] + record := make([]string, 0, cols) + var i, k, lastCol int + isMasked := t.IsMasked() + fillval := t.FillValue() + fillstr := fmt.Sprintf(format, fillval) + for i, err = it.Next(); err == nil; i, err = it.Next() { + record = append(record, fmt.Sprintf(format, t.Get(i))) + if isMasked { + if t.mask[i] { + record[k] = fillstr + } + k++ + } + if lastCol == cols-1 { + if err = cw.Write(record); err != nil { + // TODO: wrap errors + return + } + cw.Flush() + record = record[:0] + } - var data interface{} - if err = decoder.Decode(&data); err != nil { - return + // cleanup + switch { + case t.IsRowVec(): + // lastRow = coord[len(coord)-2] + lastCol = coord[len(coord)-1] + case t.IsColVec(): + // lastRow = coord[len(coord)-1] + lastCol = coord[len(coord)-2] + case t.IsVector(): + lastCol = coord[len(coord)-1] + default: + // lastRow = coord[len(coord)-2] + lastCol = coord[len(coord)-1] + } } - t.fromSlice(data) - t.addMask(mask) - t.fix() - return t.sanity() + return nil } -// convFromStrs conversts a []string to a slice of the Dtype provided -func convFromStrs(to Dtype, record []string) (interface{}, error) { +// convFromStrs converts a []string to a slice of the Dtype provided. It takes a provided backing slice. +// If into is nil, then a backing slice will be created. +func convFromStrs(to Dtype, record []string, into interface{}) (interface{}, error) { var err error switch to.Kind() { case reflect.Int: retVal := make([]int, len(record)) + var backing []int + if into == nil { + backing = make([]int, 0, len(record)) + } else { + backing = into.([]int) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 0); err != nil { @@ -437,9 +435,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int8: retVal := make([]int8, len(record)) + var backing []int8 + if into == nil { + backing = make([]int8, 0, len(record)) + } else { + backing = into.([]int8) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 8); err != nil { @@ -447,9 +453,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int8(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int16: retVal := make([]int16, len(record)) + var backing []int16 + if into == nil { + backing = make([]int16, 0, len(record)) + } else { + backing = into.([]int16) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 16); err != nil { @@ -457,9 +471,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int16(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int32: retVal := make([]int32, len(record)) + var backing []int32 + if into == nil { + backing = make([]int32, 0, len(record)) + } else { + backing = into.([]int32) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 32); err != nil { @@ -467,9 +489,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int32(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Int64: retVal := make([]int64, len(record)) + var backing []int64 + if into == nil { + backing = make([]int64, 0, len(record)) + } else { + backing = into.([]int64) + } + for i, v := range record { var i64 int64 if i64, err = strconv.ParseInt(v, 10, 64); err != nil { @@ -477,9 +507,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = int64(i64) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint: retVal := make([]uint, len(record)) + var backing []uint + if into == nil { + backing = make([]uint, 0, len(record)) + } else { + backing = into.([]uint) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 0); err != nil { @@ -487,9 +525,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint8: retVal := make([]uint8, len(record)) + var backing []uint8 + if into == nil { + backing = make([]uint8, 0, len(record)) + } else { + backing = into.([]uint8) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 8); err != nil { @@ -497,9 +543,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint8(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint16: retVal := make([]uint16, len(record)) + var backing []uint16 + if into == nil { + backing = make([]uint16, 0, len(record)) + } else { + backing = into.([]uint16) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 16); err != nil { @@ -507,9 +561,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint16(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint32: retVal := make([]uint32, len(record)) + var backing []uint32 + if into == nil { + backing = make([]uint32, 0, len(record)) + } else { + backing = into.([]uint32) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 32); err != nil { @@ -517,9 +579,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint32(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Uint64: retVal := make([]uint64, len(record)) + var backing []uint64 + if into == nil { + backing = make([]uint64, 0, len(record)) + } else { + backing = into.([]uint64) + } + for i, v := range record { var u uint64 if u, err = strconv.ParseUint(v, 10, 64); err != nil { @@ -527,9 +597,17 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = uint64(u) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Float32: retVal := make([]float32, len(record)) + var backing []float32 + if into == nil { + backing = make([]float32, 0, len(record)) + } else { + backing = into.([]float32) + } + for i, v := range record { var f float64 if f, err = strconv.ParseFloat(v, 32); err != nil { @@ -537,15 +615,33 @@ func convFromStrs(to Dtype, record []string) (interface{}, error) { } retVal[i] = float32(f) } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil case reflect.Float64: retVal := make([]float64, len(record)) + var backing []float64 + if into == nil { + backing = make([]float64, 0, len(record)) + } else { + backing = into.([]float64) + } + for i, v := range record { if retVal[i], err = strconv.ParseFloat(v, 64); err != nil { return nil, err } } - return retVal, nil + backing = append(backing, retVal...) + return backing, nil + case reflect.String: + var backing []string + if into == nil { + backing = make([]string, 0, len(record)) + } else { + backing = into.([]string) + } + backing = append(backing, record...) + return backing, nil default: return nil, errors.Errorf(methodNYI, "convFromStrs", to) } @@ -564,307 +660,223 @@ func (t *Dense) ReadCSV(r io.Reader, opts ...FuncOpt) (err error) { cr := csv.NewReader(r) var record []string - var row interface{} var rows, cols int - - switch as.Kind() { - case reflect.Int: - var backing []int - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int, record); err != nil { - return - } - backing = append(backing, row.([]int)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int8: - var backing []int8 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int8, record); err != nil { - return - } - backing = append(backing, row.([]int8)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int16: - var backing []int16 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int16, record); err != nil { - return - } - backing = append(backing, row.([]int16)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int32: - var backing []int32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int32, record); err != nil { - return - } - backing = append(backing, row.([]int32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Int64: - var backing []int64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Int64, record); err != nil { - return - } - backing = append(backing, row.([]int64)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint: - var backing []uint - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Uint, record); err != nil { - return - } - backing = append(backing, row.([]uint)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint8: - var backing []uint8 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Uint8, record); err != nil { - return - } - backing = append(backing, row.([]uint8)...) - cols = len(record) - rows++ + var backing interface{} + for { + record, err = cr.Read() + if err == io.EOF { + break + } else if err != nil { + return } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint16: - var backing []uint16 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } - - if row, err = convFromStrs(Uint16, record); err != nil { - return - } - backing = append(backing, row.([]uint16)...) - cols = len(record) - rows++ + if backing, err = convFromStrs(as, record, backing); err != nil { + return } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint32: - var backing []uint32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } - - if err != nil { - return - } + cols = len(record) + rows++ + } + t.fromSlice(backing) + t.AP.zero() + t.AP.SetShape(rows, cols) + return nil + return errors.Errorf("not yet handled") +} - if row, err = convFromStrs(Uint32, record); err != nil { - return - } - backing = append(backing, row.([]uint32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Uint64: - var backing []uint64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } +/* FB SERIALIZATION */ - if err != nil { - return - } +// FBEncode encodes to a byte slice using flatbuffers. +// +// Only natively accessible data can be encided +func (t *Dense) FBEncode() ([]byte, error) { + builder := flatbuffers.NewBuilder(1024) + + fb.DenseStartShapeVector(builder, len(t.shape)) + for i := len(t.shape) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.shape[i])) + } + shape := builder.EndVector(len(t.shape)) + + fb.DenseStartStridesVector(builder, len(t.strides)) + for i := len(t.strides) - 1; i >= 0; i-- { + builder.PrependInt32(int32(t.strides[i])) + } + strides := builder.EndVector(len(t.strides)) + + var o uint32 + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + o = 0 + case t.o.IsRowMajor() && !t.o.IsContiguous(): + o = 1 + case t.o.IsColMajor() && t.o.IsContiguous(): + o = 2 + case t.o.IsColMajor() && !t.o.IsContiguous(): + o = 3 + } + + var triangle int32 + switch t.Δ { + case NotTriangle: + triangle = fb.TriangleNOT_TRIANGLE + case Upper: + triangle = fb.TriangleUPPER + case Lower: + triangle = fb.TriangleLOWER + case Symmetric: + triangle = fb.TriangleSYMMETRIC + } + + dt := builder.CreateString(t.Dtype().String()) + data := t.byteSlice() + + fb.DenseStartDataVector(builder, len(data)) + for i := len(data) - 1; i >= 0; i-- { + builder.PrependUint8(data[i]) + } + databyte := builder.EndVector(len(data)) + + fb.DenseStart(builder) + fb.DenseAddShape(builder, shape) + fb.DenseAddStrides(builder, strides) + fb.DenseAddO(builder, o) + fb.DenseAddT(builder, triangle) + fb.DenseAddType(builder, dt) + fb.DenseAddData(builder, databyte) + serialized := fb.DenseEnd(builder) + builder.Finish(serialized) + + return builder.FinishedBytes(), nil +} - if row, err = convFromStrs(Uint64, record); err != nil { - return - } - backing = append(backing, row.([]uint64)...) - cols = len(record) - rows++ +// FBDecode decodes a byteslice from a flatbuffer table into a *Dense +func (t *Dense) FBDecode(buf []byte) error { + serialized := fb.GetRootAsDense(buf, 0) + + o := serialized.O() + switch o { + case 0: + t.o = 0 + case 1: + t.o = MakeDataOrder(NonContiguous) + case 2: + t.o = MakeDataOrder(ColMajor) + case 3: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + + tri := serialized.T() + switch tri { + case fb.TriangleNOT_TRIANGLE: + t.Δ = NotTriangle + case fb.TriangleUPPER: + t.Δ = Upper + case fb.TriangleLOWER: + t.Δ = Lower + case fb.TriangleSYMMETRIC: + t.Δ = Symmetric + } + + t.shape = Shape(BorrowInts(serialized.ShapeLength())) + for i := 0; i < serialized.ShapeLength(); i++ { + t.shape[i] = int(int32(serialized.Shape(i))) + } + + t.strides = BorrowInts(serialized.StridesLength()) + for i := 0; i < serialized.ShapeLength(); i++ { + t.strides[i] = int(serialized.Strides(i)) + } + typ := string(serialized.Type()) + for _, dt := range allTypes.set { + if dt.String() == typ { + t.t = dt + break } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Float32: - var backing []float32 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + } - if err != nil { - return - } + if t.e == nil { + t.e = StdEng{} + } + t.makeArray(t.shape.TotalSize()) - if row, err = convFromStrs(Float32, record); err != nil { - return - } - backing = append(backing, row.([]float32)...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.Float64: - var backing []float64 - for { - record, err = cr.Read() - if err == io.EOF { - break - } + // allocated data. Now time to actually copy over the data + db := t.byteSlice() + copy(db, serialized.DataBytes()) + t.forcefix() + return t.sanity() +} - if err != nil { - return - } +/* PB SERIALIZATION */ + +// PBEncode encodes the Dense into a protobuf byte slice. +func (t *Dense) PBEncode() ([]byte, error) { + var toSerialize pb.Dense + toSerialize.Shape = make([]int32, len(t.shape)) + for i, v := range t.shape { + toSerialize.Shape[i] = int32(v) + } + toSerialize.Strides = make([]int32, len(t.strides)) + for i, v := range t.strides { + toSerialize.Strides[i] = int32(v) + } + + switch { + case t.o.IsRowMajor() && t.o.IsContiguous(): + toSerialize.O = pb.RowMajorContiguous + case t.o.IsRowMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.RowMajorNonContiguous + case t.o.IsColMajor() && t.o.IsContiguous(): + toSerialize.O = pb.ColMajorContiguous + case t.o.IsColMajor() && !t.o.IsContiguous(): + toSerialize.O = pb.ColMajorNonContiguous + } + toSerialize.T = pb.Triangle(t.Δ) + toSerialize.Type = t.t.String() + data := t.byteSlice() + toSerialize.Data = make([]byte, len(data)) + copy(toSerialize.Data, data) + return toSerialize.Marshal() +} - if row, err = convFromStrs(Float64, record); err != nil { - return - } - backing = append(backing, row.([]float64)...) - cols = len(record) - rows++ +// PBDecode unmarshalls a protobuf byteslice into a *Dense. +func (t *Dense) PBDecode(buf []byte) error { + var toSerialize pb.Dense + if err := toSerialize.Unmarshal(buf); err != nil { + return err + } + t.shape = make(Shape, len(toSerialize.Shape)) + for i, v := range toSerialize.Shape { + t.shape[i] = int(v) + } + t.strides = make([]int, len(toSerialize.Strides)) + for i, v := range toSerialize.Strides { + t.strides[i] = int(v) + } + + switch toSerialize.O { + case pb.RowMajorContiguous: + case pb.RowMajorNonContiguous: + t.o = MakeDataOrder(NonContiguous) + case pb.ColMajorContiguous: + t.o = MakeDataOrder(ColMajor) + case pb.ColMajorNonContiguous: + t.o = MakeDataOrder(ColMajor, NonContiguous) + } + t.Δ = Triangle(toSerialize.T) + typ := string(toSerialize.Type) + for _, dt := range allTypes.set { + if dt.String() == typ { + t.t = dt + break } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - case reflect.String: - var backing []string - for { - record, err = cr.Read() - if err == io.EOF { - break - } + } - if err != nil { - return - } - backing = append(backing, record...) - cols = len(record) - rows++ - } - t.fromSlice(backing) - t.AP = new(AP) - t.AP.SetShape(rows, cols) - return nil - default: - return errors.Errorf("%v not yet handled", as) + if t.e == nil { + t.e = StdEng{} } - return errors.Errorf("not yet handled") + t.makeArray(t.shape.TotalSize()) + + // allocated data. Now time to actually copy over the data + db := t.byteSlice() + copy(db, toSerialize.Data) + return t.sanity() } diff --git a/dense_io_test.go b/dense_io_test.go index 01de3f0..3c75973 100644 --- a/dense_io_test.go +++ b/dense_io_test.go @@ -23,7 +23,7 @@ func TestSaveLoadNumpy(t *testing.T) { script := "import numpy as np\nx = np.load('test.npy')\nprint(x)" - cmd := exec.Command("python2") + cmd := exec.Command("python") stdin, err := cmd.StdinPipe() if err != nil { t.Error(err) @@ -204,5 +204,52 @@ func TestDense_GobEncodeDecode(t *testing.T) { assert.Equal(T.mask, T3.mask) } +} + +func TestDense_FBEncodeDecode(t *testing.T) { + assert := assert.New(t) + for _, gtd := range serializationTestData { + T := New(WithShape(2, 2), WithBacking(gtd)) + buf, err := T.FBEncode() + if err != nil { + t.Errorf("UNPOSSIBLE!: %v", err) + continue + } + + T2 := new(Dense) + if err = T2.FBDecode(buf); err != nil { + t.Errorf("Error while decoding %v: %v", gtd, err) + continue + } + assert.Equal(T.Shape(), T2.Shape()) + assert.Equal(T.Strides(), T2.Strides()) + assert.Equal(T.Data(), T2.Data()) + + // TODO: MASKED ARRAY + } +} + +func TestDense_PBEncodeDecode(t *testing.T) { + assert := assert.New(t) + for _, gtd := range serializationTestData { + T := New(WithShape(2, 2), WithBacking(gtd)) + buf, err := T.PBEncode() + if err != nil { + t.Errorf("UNPOSSIBLE!: %v", err) + continue + } + + T2 := new(Dense) + if err = T2.PBDecode(buf); err != nil { + t.Errorf("Error while decoding %v: %v", gtd, err) + continue + } + + assert.Equal(T.Shape(), T2.Shape()) + assert.Equal(T.Strides(), T2.Strides()) + assert.Equal(T.Data(), T2.Data()) + + // TODO: MASKED ARRAY + } } diff --git a/dense_linalg.go b/dense_linalg.go index ca07663..6493808 100644 --- a/dense_linalg.go +++ b/dense_linalg.go @@ -1,6 +1,8 @@ package tensor -import "github.com/pkg/errors" +import ( + "github.com/pkg/errors" +) // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). It only works for matrices func (t *Dense) Trace() (retVal interface{}, err error) { @@ -87,6 +89,9 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err if retVal == nil { retVal = recycledDense(t.t, expectedShape) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } e := t.e @@ -133,10 +138,12 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) if retVal == nil { retVal = recycledDense(t.t, expectedShape) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } e := t.e - if mm, ok := e.(MatMuler); ok { if err = mm.MatMul(t, other, retVal); err != nil { return @@ -170,6 +177,9 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) if retVal == nil { retVal = recycledDense(t.t, expectedShape) + if t.o.IsColMajor() { + AsFortran(nil)(retVal) + } } e := t.e @@ -310,7 +320,6 @@ func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err return } doOther.Transpose() - if err = doOther.Reshape(newShapeO...); err != nil { return } diff --git a/dense_linalg_test.go b/dense_linalg_test.go index bfd316c..a9a24dc 100644 --- a/dense_linalg_test.go +++ b/dense_linalg_test.go @@ -10,6 +10,7 @@ import ( type linalgTest struct { a, b interface{} shapeA, shapeB Shape + transA, transB bool reuse, incr interface{} shapeR, shapeI Shape @@ -118,89 +119,94 @@ func TestDense_Inner(t *testing.T) { var matVecMulTests = []linalgTest{ // Float64s - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, + // float64s with transposed matrix + {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, + Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, + []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, + // Float32s - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, - {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, + {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, // stupids : unpossible shapes (wrong A) - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad A shape - {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, + {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad B shape - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, //stupids: bad reuse - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, //stupids: bad incr shape - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, // stupids: type mismatch A and B - {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B - {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch A and B (non-Float) - {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, // stupids: type mismatch, reuse - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, // stupids: type mismatch, incr - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, // stupids: type mismatch, incr not a Number - {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, + {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, } @@ -211,12 +217,19 @@ func TestDense_MatVecMul(t *testing.T) { a := New(WithBacking(mvmt.a), WithShape(mvmt.shapeA...)) b := New(WithBacking(mvmt.b), WithShape(mvmt.shapeB...)) + if mvmt.transA { + if err := a.T(); err != nil { + t.Error(err) + continue + } + } T, err := a.MatVecMul(b) if checkErr(t, mvmt.err, err, "Safe", i) { continue } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correct, T.Data()) // incr @@ -227,6 +240,7 @@ func TestDense_MatVecMul(t *testing.T) { } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correctIncr, T.Data()) // reuse @@ -237,6 +251,7 @@ func TestDense_MatVecMul(t *testing.T) { } assert.True(mvmt.correctShape.Eq(T.Shape())) + assert.True(T.DataOrder().IsRowMajor()) assert.Equal(mvmt.correct, T.Data()) // reuse AND incr @@ -251,89 +266,89 @@ func TestDense_MatVecMul(t *testing.T) { var matMulTests = []linalgTest{ // Float64s - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, // Float32s - {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, // Edge cases - Row Vecs (Float64) - {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, + {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, []float64{0, 0, 0, 0, 1, 2}, []float64{100, 101, 102, 103, 105, 107}, []float64{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, - {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, + {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, - {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, + {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1}, []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, // Edge cases - Row Vecs (Float32) - {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, + {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, []float32{0, 0, 0, 0, 1, 2}, []float32{100, 101, 102, 103, 105, 107}, []float32{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, - {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, + {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, []float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, - {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, + {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1}, []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, // stupids - bad shape (not matrices): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - bad shape (incompatible shapes): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - bad shape (bad reuse shape): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, // stupids - bad shape (bad incr shape): - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, // stupids - type mismatch (a,b) - {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids - type mismatch (a,b) - {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids type mismatch (b not float) - {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids type mismatch (a not float) - {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, // stupids: type mismatch (incr) - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, // stupids: type mismatch (reuse) - {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, // stupids: type mismatch (reuse) - {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, + {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, } @@ -382,55 +397,55 @@ func TestDense_MatMul(t *testing.T) { var outerTests = []linalgTest{ // Float64s - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, // Float32s - {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, + {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float32{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, false}, // stupids - a or b not vector - {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, + {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - bad incr shape - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, true, false}, // stupids - bad reuse shape - {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, false, false, true}, // stupids - b not Float - {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - a not Float - {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids - a-b type mismatch - {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, + {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, // stupids a-b type mismatch - {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, + {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, true, false, false}, diff --git a/dense_matop.go b/dense_matop.go index 46e8a55..1a3b815 100644 --- a/dense_matop.go +++ b/dense_matop.go @@ -5,17 +5,16 @@ import "github.com/pkg/errors" // T performs a thunked transpose. It doesn't actually do anything, except store extra information about the post-transposed shapes and strides // Usually this is more than enough, as BLAS will handle the rest of the transpose func (t *Dense) T(axes ...int) (err error) { - var transform *AP + var transform AP if transform, axes, err = t.AP.T(axes...); err != nil { return handleNoOp(err) } // is there any old transposes that need to be done first? // this is important, because any old transposes for dim >=3 are merely permutations of the strides - if t.old != nil { + if !t.old.IsZero() { if t.IsVector() { // the transform that was calculated was a waste of time - return it to the pool then untranspose - ReturnAP(transform) t.UT() return } @@ -31,7 +30,6 @@ func (t *Dense) T(axes ...int) (err error) { // if it is reversed, well, we just restore the backed up one if isReversed { - ReturnAP(transform) t.UT() return } @@ -58,18 +56,17 @@ func (t *Dense) T(axes ...int) (err error) { // // Nothing will happen if there was no previous transpose func (t *Dense) UT() { - if t.old != nil { - ReturnAP(t.AP) + if !t.old.IsZero() { ReturnInts(t.transposeWith) t.AP = t.old - t.old = nil + t.old.zeroOnly() t.transposeWith = nil } } // SafeT is exactly like T(), except it returns a new *Dense. The data is also copied over, unmoved. func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) { - var transform *AP + var transform AP if transform, axes, err = t.AP.T(axes...); err != nil { if err = handleNoOp(err); err != nil { return @@ -82,7 +79,7 @@ func (t *Dense) SafeT(axes ...int) (retVal *Dense, err error) { retVal.e = t.e retVal.oe = t.oe retVal.AP = transform - retVal.old = t.AP.Clone() + t.AP.CloneTo(&retVal.old) retVal.transposeWith = axes return @@ -209,7 +206,7 @@ func (t *Dense) CopyTo(other *Dense) error { // // The method treats as equivalent to a colon slice. T.Slice(nil) is equivalent to T[:] in Numpy syntax func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { - var newAP *AP + var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { @@ -236,15 +233,14 @@ func (t *Dense) Slice(slices ...Slice) (retVal View, err error) { // The underlying data is the same. // This method will override ALL the metadata in view. func (t *Dense) SliceInto(view *Dense, slices ...Slice) (retVal View, err error) { - var newAP *AP + var newAP AP var ndStart, ndEnd int if newAP, ndStart, ndEnd, err = t.AP.S(t.len(), slices...); err != nil { return } - ReturnAP(view.AP) - view.AP = nil + view.AP.zero() view.array.v = nil // reset view.t = t.t @@ -314,6 +310,7 @@ func (t *Dense) RollAxis(axis, start int, safe bool) (retVal *Dense, err error) func (t *Dense) transposeIndex(i int, transposePat, strides []int) int { oldCoord, err := Itol(i, t.oshape(), t.ostrides()) if err != nil { + err = errors.Wrapf(err, "transposeIndex ItoL failure. i %d original shape %v. original strides %v", i, t.oshape(), t.ostrides()) panic(err) } diff --git a/dense_matop_memmove.go b/dense_matop_memmove.go index 05033ef..fe05f2a 100644 --- a/dense_matop_memmove.go +++ b/dense_matop_memmove.go @@ -9,7 +9,7 @@ import "github.com/pkg/errors" // https://en.wikipedia.org/wiki/In-place_matrix_transposition func (t *Dense) Transpose() error { // if there is no oldinfo, that means the current info is the latest, and not the transpose - if t.old == nil { + if t.old.IsZero() { return nil } @@ -18,8 +18,7 @@ func (t *Dense) Transpose() error { } defer func() { - ReturnAP(t.old) - t.old = nil + t.old.zero() t.transposeWith = nil }() @@ -27,10 +26,10 @@ func (t *Dense) Transpose() error { // important! because the strides would have changed once the underlying data changed var expStrides []int - if t.AP.o.isColMajor() { - expStrides = expShape.calcStridesColMajor() + if t.AP.o.IsColMajor() { + expStrides = expShape.CalcStridesColMajor() } else { - expStrides = expShape.calcStrides() + expStrides = expShape.CalcStrides() } defer ReturnInts(expStrides) defer func() { diff --git a/dense_matop_test.go b/dense_matop_test.go index 51ee94a..ca02d3e 100644 --- a/dense_matop_test.go +++ b/dense_matop_test.go @@ -135,10 +135,10 @@ var transposeTests = []struct { correctData interface{} }{ {"c.T()", Shape{4, 1}, nil, []float64{0, 1, 2, 3}, - Shape{1, 4}, []int{1}, []int{1}, []float64{0, 1, 2, 3}}, + Shape{1, 4}, []int{1, 1}, []int{4, 1}, []float64{0, 1, 2, 3}}, {"r.T()", Shape{1, 4}, nil, []float32{0, 1, 2, 3}, - Shape{4, 1}, []int{1}, []int{1}, []float32{0, 1, 2, 3}}, + Shape{4, 1}, []int{4, 1}, []int{1, 1}, []float32{0, 1, 2, 3}}, {"v.T()", Shape{4}, nil, []int{0, 1, 2, 3}, Shape{4}, []int{1}, []int{1}, []int{0, 1, 2, 3}}, @@ -216,10 +216,10 @@ func TestDense_Transpose(t *testing.T) { } assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape()) - assert.Equal(tts.correctStrides, T.Strides()) + assert.Equal(tts.correctStrides, T.Strides(), "Transpose %v. Expected stride: %v. Got %v", tts.name, tts.correctStrides, T.Strides()) T.Transpose() assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape()) - assert.Equal(tts.correctStrides2, T.Strides(), "Transpose %v - Wrong strides", tts.name) + assert.Equal(tts.correctStrides2, T.Strides(), "Transpose2 %v - Expected stride %v. Got %v", tts.name, tts.correctStrides2, T.Strides()) assert.Equal(tts.correctData, T.Data(), "Transpose %v", tts.name) } @@ -236,7 +236,7 @@ func TestDense_Transpose(t *testing.T) { t.Errorf("Stacked .T() #1 for vector. Error: %v", err) goto matrev } - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) assert.True(T.IsColVec()) @@ -251,7 +251,7 @@ matrev: t.Errorf("Stacked .T() #2 for matrix reverse. Error: %v", err) goto matnorev } - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) assert.True(Shape{2, 3}.Eq(T.Shape())) @@ -278,12 +278,12 @@ func TestTUT(t *testing.T) { T = New(Of(Float64), WithShape(2, 3, 4)) T.T() T.UT() - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) T.T(2, 0, 1) T.UT() - assert.Nil(T.old) + assert.True(T.old.IsZero()) assert.Nil(T.transposeWith) } @@ -493,16 +493,16 @@ var denseSliceTests = []struct { // colvec {"c[0]", Range(Int64, 0, 5), Shape{5, 1}, []Slice{ss(0)}, ScalarShape(), nil, int64(0)}, - {"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1}, []float32{0, 1}}, - {"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, Shape{2, 1}, []int{2}, []float64{0, 1, 2, 3, 4}}, + {"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1, 1}, []float32{0, 1}}, + {"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, Shape{2, 1}, []int{2, 1}, []float64{0, 1, 2, 3, 4}}, // // rowvec {"r[0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{ss(0)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[0:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 5, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}}, {"r[:, 0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, ss(0)}, ScalarShape(), nil, float64(0)}, - {"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{1}, []float64{0, 1}}, - {"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{2}, []float64{1, 2, 3, 4}}, + {"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{5, 1}, []float64{0, 1}}, + {"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{5, 2}, []float64{1, 2, 3, 4}}, // // matrix {"A[0]", Range(Float64, 0, 6), Shape{2, 3}, []Slice{ss(0)}, Shape{1, 3}, []int{1}, Range(Float64, 0, 3)}, @@ -540,7 +540,7 @@ func TestDense_Slice(t *testing.T) { assert.True(Shape{2}.Eq(V.Shape())) assert.Equal([]int{3}, V.Strides()) assert.Equal([]float32{0, 1, 2, 3}, V.Data()) - assert.Nil(V.(*Dense).old) + assert.True(V.(*Dense).old.IsZero()) // slice a sliced V, err = V.Slice(makeRS(1, 2)) @@ -623,49 +623,61 @@ func TestDense_RollAxis(t *testing.T) { } var concatTests = []struct { - name string - dt Dtype - a interface{} - shape Shape - axis int + name string + dt Dtype + a interface{} + b interface{} + shape Shape + shapeB Shape + axis int correctShape Shape correctData interface{} }{ // Float64 - {"vector", Float64, nil, Shape{2}, 0, Shape{4}, []float64{0, 1, 0, 1}}, - {"matrix; axis 0 ", Float64, nil, Shape{2, 2}, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Float64, nil, Shape{2, 2}, 1, Shape{2, 4}, []float64{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Float64, nil, nil, Shape{2}, nil, 0, Shape{4}, []float64{0, 1, 0, 1}}, + {"matrix; axis 0 ", Float64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Float64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float64{0, 1, 0, 1, 2, 3, 2, 3}}, // Float32 - {"vector", Float32, nil, Shape{2}, 0, Shape{4}, []float32{0, 1, 0, 1}}, - {"matrix; axis 0 ", Float32, nil, Shape{2, 2}, 0, Shape{4, 2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Float32, nil, Shape{2, 2}, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Float32, nil, nil, Shape{2}, nil, 0, Shape{4}, []float32{0, 1, 0, 1}}, + {"matrix; axis 0 ", Float32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Float32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}}, // Int - {"vector", Int, nil, Shape{2}, 0, Shape{4}, []int{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int, nil, Shape{2, 2}, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int, nil, Shape{2, 2}, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int, nil, nil, Shape{2}, nil, 0, Shape{4}, []int{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}}, // Int64 - {"vector", Int64, nil, Shape{2}, 0, Shape{4}, []int64{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int64, nil, Shape{2, 2}, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int64, nil, Shape{2, 2}, 1, Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int64, nil, nil, Shape{2}, nil, 0, Shape{4}, []int64{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}}, // Int32 - {"vector", Int32, nil, Shape{2}, 0, Shape{4}, []int32{0, 1, 0, 1}}, - {"matrix; axis 0 ", Int32, nil, Shape{2, 2}, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Int32, nil, Shape{2, 2}, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Int32, nil, nil, Shape{2}, nil, 0, Shape{4}, []int32{0, 1, 0, 1}}, + {"matrix; axis 0 ", Int32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Int32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}}, // Byte - {"vector", Byte, nil, Shape{2}, 0, Shape{4}, []byte{0, 1, 0, 1}}, - {"matrix; axis 0 ", Byte, nil, Shape{2, 2}, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}}, - {"matrix; axis 1 ", Byte, nil, Shape{2, 2}, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}}, + {"vector", Byte, nil, nil, Shape{2}, nil, 0, Shape{4}, []byte{0, 1, 0, 1}}, + {"matrix; axis 0 ", Byte, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}}, + {"matrix; axis 1 ", Byte, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}}, // Bool - {"vector", Bool, []bool{true, false}, Shape{2}, 0, Shape{4}, []bool{true, false, true, false}}, - {"matrix; axis 0 ", Bool, []bool{true, false, true, false}, Shape{2, 2}, 0, Shape{4, 2}, []bool{true, false, true, false, true, false, true, false}}, - {"matrix; axis 1 ", Bool, []bool{true, false, true, false}, Shape{2, 2}, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, true, false}}, + {"vector", Bool, []bool{true, false}, nil, Shape{2}, nil, 0, Shape{4}, []bool{true, false, true, false}}, + {"matrix; axis 0 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []bool{true, false, true, false, true, false, true, false}}, + {"matrix; axis 1 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, true, false}}, + + // gorgonia/gorgonia#218 related + {"matrix; axis 0", Float64, nil, nil, Shape{2, 2}, Shape{1, 2}, 0, Shape{3, 2}, []float64{0, 1, 2, 3, 0, 1}}, + {"matrix; axis 1", Float64, nil, nil, Shape{2, 2}, Shape{2, 1}, 1, Shape{2, 3}, []float64{0, 1, 0, 2, 3, 1}}, + {"colvec matrix, axis 0", Float64, nil, nil, Shape{2, 1}, Shape{1, 1}, 0, Shape{3, 1}, []float64{0, 1, 0}}, + {"rowvec matrix, axis 1", Float64, nil, nil, Shape{1, 2}, Shape{1, 1}, 1, Shape{1, 3}, []float64{0, 1, 0}}, + + {"3tensor; axis 0", Float64, nil, nil, Shape{2, 3, 2}, Shape{1, 3, 2}, 0, Shape{3, 3, 2}, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5}}, + {"3tensor; axis 2", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 3, 1}, 2, Shape{2, 3, 3}, []float64{0, 1, 0, 2, 3, 1, 4, 5, 2, 6, 7, 3, 8, 9, 4, 10, 11, 5}}, + // {"3tensor; axis 1", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 1, 2}, 1, Shape{2, 4, 2}, []float64{222}}, } func TestDense_Concat(t *testing.T) { @@ -676,15 +688,24 @@ func TestDense_Concat(t *testing.T) { if cts.a == nil { T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) - T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) } else { T0 = New(WithShape(cts.shape...), WithBacking(cts.a)) + } + + switch { + case cts.shapeB == nil && cts.a == nil: + T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) + case cts.shapeB == nil && cts.a != nil: T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a))) + case cts.shapeB != nil && cts.b == nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize()))) + case cts.shapeB != nil && cts.b != nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b)) } T2, err := T0.Concat(cts.axis, T1) if err != nil { - t.Error(err) + t.Errorf("Test %v failed: %v", cts.name, err) continue } assert.True(cts.correctShape.Eq(T2.Shape())) @@ -694,24 +715,31 @@ func TestDense_Concat(t *testing.T) { //Masked case for _, cts := range concatTests { - var T0, T1 *Dense if cts.a == nil { T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) T0.MaskedEqual(castToDt(0.0, cts.dt)) - T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) - T1.MaskedEqual(castToDt(0.0, cts.dt)) } else { T0 = New(WithShape(cts.shape...), WithBacking(cts.a)) T0.MaskedEqual(castToDt(0.0, cts.dt)) + } + + switch { + case cts.shapeB == nil && cts.a == nil: + T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize()))) + case cts.shapeB == nil && cts.a != nil: T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a))) - T1.MaskedEqual(castToDt(0.0, cts.dt)) + case cts.shapeB != nil && cts.b == nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize()))) + case cts.shapeB != nil && cts.b != nil: + T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b)) } + T1.MaskedEqual(castToDt(0.0, cts.dt)) T2, err := T0.Concat(cts.axis, T1) if err != nil { - t.Error(err) + t.Errorf("Test %v failed: %v", cts.name, err) continue } diff --git a/dense_norms.go b/dense_norms.go index ad75c0f..63d460a 100644 --- a/dense_norms.go +++ b/dense_norms.go @@ -94,8 +94,8 @@ func (t *Dense) Norm(ord NormOrder, axes ...int) (retVal *Dense, err error) { if len(axes) == 0 { if ord.IsUnordered() || (ord.IsFrobenius() && dims == 2) || (ord == Norm(2) && dims == 1) { backup := t.AP - ap := BorrowAP(1) - defer ReturnAP(ap) + ap := makeAP(1) + defer ap.zero() ap.unlock() ap.SetShape(t.Size()) diff --git a/dense_svd_test.go b/dense_svd_test.go index 36e4e16..89c5306 100644 --- a/dense_svd_test.go +++ b/dense_svd_test.go @@ -1,6 +1,7 @@ package tensor import ( + "fmt" "testing" "github.com/pkg/errors" @@ -103,6 +104,27 @@ func testSVD(T, T2, s, u, v *Dense, t string, i int) (err error) { return nil } +func Example_DenseSVD() { + T := New( + WithShape(4, 5), + WithBacking([]float64{1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0}), + ) + _, u, _, _ := T.SVD(true, true) + uT := u.Clone().(*Dense) + uT.T() + eye, err := u.MatMul(uT) + fmt.Println(eye) + fmt.Println(err) + + // Output: + // ⎡1 0 0 0⎤ + // ⎢0 1 0 0⎥ + // ⎢0 0 1 0⎥ + // ⎣0 0 0 1⎦ + // + // +} + func TestDense_SVD(t *testing.T) { var T, T2, s, u, v *Dense var err error @@ -134,7 +156,6 @@ func TestDense_SVD(t *testing.T) { t.Errorf("Expected v = %v. Got %v instead", stts.correctVData, v.Data()) } } - // standard tests for i, stfs := range svdtestsFull { T = New(WithShape(stfs...), WithBacking(Random(Float64, stfs.TotalSize()))) @@ -143,14 +164,14 @@ func TestDense_SVD(t *testing.T) { // full if s, u, v, err = T.SVD(true, true); err != nil { t.Error(err) + fmt.Println(err) continue } - if err = testSVD(T, T2, s, u, v, "full", i); err != nil { t.Error(err) + fmt.Println(err) continue } - // thin if s, u, v, err = T.SVD(true, false); err != nil { t.Error(err) @@ -183,8 +204,8 @@ func TestDense_SVD(t *testing.T) { if !allClose(s.Data(), svd.Values(nil), closeenoughf64) { t.Errorf("Singular value mismatch between Full and None decomposition. Expected %v. Got %v instead", svd.Values(nil), s.Data()) } - } + } // this is illogical T = New(Of(Float64), WithShape(2, 2)) if _, _, _, err = T.SVD(false, true); err == nil { diff --git a/engine.go b/engine.go index af56f6b..9e3ede7 100644 --- a/engine.go +++ b/engine.go @@ -59,6 +59,11 @@ type arrayMaker interface { makeArray(arr *array, t Dtype, size int) } +// NonStdEngine are any engines that do not allocate using the default built in allocator +type NonStdEngine interface { + NonStdAlloc() // noop +} + /* Data Agnostic Execution Engine Methods */ // Transposer is any engine that can perform an unsafe transpose of a tensor. @@ -86,6 +91,11 @@ type Repeater interface { Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) } +// Diager is any engine that can return a tensor that only contains the diagonal values of the input +type Diager interface { + Diag(a Tensor) (Tensor, error) +} + /* NUMBER INTERFACES All these are expected to be unsafe on the first tensor */ @@ -369,6 +379,20 @@ type Argminer interface { Argmin(t Tensor, axis int) (Tensor, error) } +// NaNChecker checks that the tensor contains a NaN +// Errors are to be returned if the concept of NaN does not apply to the data type. +// Other errors may also occur. See specific implementations for details +type NaNChecker interface { + HasNaN(t Tensor) (bool, error) +} + +// InfChecker checks that the tensor contains a Inf. +// Errors are to be returned if the concept of Inf does not apply to the data type. +// Other errors may also occur. See specific implementations for details +type InfChecker interface { + HasInf(t Tensor) (bool, error) +} + /* Internal interfaces for faster shit */ type denseArgmaxer interface { diff --git a/example_dense_linalg_test.go b/example_dense_linalg_test.go new file mode 100644 index 0000000..d558481 --- /dev/null +++ b/example_dense_linalg_test.go @@ -0,0 +1,151 @@ +package tensor + +import ( + "fmt" +) + +func ExampleDense_MatMul() { + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(10, 15), WithBacking(Range(Float64, 0, 150))) + T1 := New(WithShape(15, 10), WithBacking(Range(Float64, 150, 0))) + T2, err := MatMul(T0, T1) + handleErr(err) + + fmt.Printf("T2:\n%v", T2) + + // Output: + // T2: + // ⎡ 5600 5495 5390 5285 ... 4970 4865 4760 4655⎤ + // ⎢ 23600 23270 22940 22610 ... 21620 21290 20960 20630⎥ + // ⎢ 41600 41045 40490 39935 ... 38270 37715 37160 36605⎥ + // ⎢ 59600 58820 58040 57260 ... 54920 54140 53360 52580⎥ + // . + // . + // . + // ⎢113600 112145 110690 109235 ... 104870 103415 101960 100505⎥ + // ⎢131600 129920 128240 126560 ... 121520 119840 118160 116480⎥ + // ⎢149600 147695 145790 143885 ... 138170 136265 134360 132455⎥ + // ⎣167600 165470 163340 161210 ... 154820 152690 150560 148430⎦ + +} + +func ExampleDense_MatVecMul() { + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(2, 3), WithBacking(Range(Float64, 1, 7))) + T1 := New(WithShape(3), WithBacking(Range(Float64, 0, 3))) + T2, err := T0.MatVecMul(T1) + handleErr(err) + + fmt.Printf("T2:\n%v\n", T2) + + // Output: + // T2: + // [ 8 17] +} + +func ExampleDense_MatVecMul_rowMajorSliced() { + // ASPIRATIONAL TODO: IncX and incY of differering values + + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(10, 12), WithBacking(Range(Float64, 1, 121))) + T1 := New(WithShape(3, 3), WithBacking(Range(Float64, 1, 10))) + T2, err := T0.Slice(makeRS(1, 3), makeRS(3, 6)) + handleErr(err) + T3, err := T1.Slice(nil, makeRS(1, 2)) + handleErr(err) + + // here the + formatting option is used because you should know that after this particular slice, the result will be a vector + fmt.Printf("T2:\n%+v", T2) + fmt.Printf("T3:\n%+v\n", T3) + + // here we print the underlying slice of T3 just to show that it's actually a much larger slice + fmt.Printf("Underlying Slice: %v\n", T3.Data()) + + T4, err := T2.(*Dense).MatVecMul(T3) + handleErr(err) + + fmt.Printf("T4:\n%v\n", T4) + + // Outputz: + // T2: + // Matrix (2, 3) [10 1] + // ⎡14 15 16⎤ + // ⎣24 25 26⎦ + // T3: + // Vector (3) [3] + // [2 5 8] + // Underlying Slice: [2 3 4 5 6 7 8] + // T4: + // [261 441] + +} + +func ExampleDense_MatMul_sliced() { + //ASPIRATIONAL TODO: incX and incY of different sizes + handleErr := func(err error) { + if err != nil { + panic(err) + } + } + + T0 := New(WithShape(10, 15), WithBacking(Range(Float64, 0, 150))) + T1 := New(WithShape(15, 10), WithBacking(Range(Float64, 150, 0))) + T2, err := MatMul(T0, T1) + handleErr(err) + + fmt.Printf("T2:\n%v", T2) + + // Slice T0 to only take a (2, 3) on the upper quadrant + // T3 := T0[0:3, 0:2] + T3, err := T0.Slice(makeRS(0, 3), makeRS(0, 2)) + handleErr(err) + fmt.Printf("T3:\n%v", T3) + + T4, err := T1.Slice(makeRS(13, 15), makeRS(8, 10)) + handleErr(err) + fmt.Printf("T4:\n%v", T4) + + T5, err := T3.(*Dense).MatMul(T4) + handleErr(err) + fmt.Printf("T3xT4:\n%v", T5) + + // Outputz: + // T2: + // ⎡ 5600 5495 5390 5285 ... 4970 4865 4760 4655⎤ + // ⎢ 23600 23270 22940 22610 ... 21620 21290 20960 20630⎥ + // ⎢ 41600 41045 40490 39935 ... 38270 37715 37160 36605⎥ + // ⎢ 59600 58820 58040 57260 ... 54920 54140 53360 52580⎥ + // . + // . + // . + // ⎢113600 112145 110690 109235 ... 104870 103415 101960 100505⎥ + // ⎢131600 129920 128240 126560 ... 121520 119840 118160 116480⎥ + // ⎢149600 147695 145790 143885 ... 138170 136265 134360 132455⎥ + // ⎣167600 165470 163340 161210 ... 154820 152690 150560 148430⎦ + // T3: + // ⎡ 0 1⎤ + // ⎢15 16⎥ + // ⎣30 31⎦ + // T4: + // ⎡12 11⎤ + // ⎣ 2 1⎦ + // T3xT4: + // ⎡ 2 1⎤ + // ⎢212 181⎥ + // ⎣422 361⎦ +} diff --git a/example_dense_matop_test.go b/example_dense_matop_test.go index 02b1dc8..97a2cb8 100644 --- a/example_dense_matop_test.go +++ b/example_dense_matop_test.go @@ -176,6 +176,8 @@ func ExampleDense_Vstack() { T3 = T1.Clone().(*Dense) if T2, err = T.Vstack(T1, T3); err == nil { fmt.Printf("T.Vstack(T1, T3):\n%v\n", T2) + } else { + fmt.Printf("====\nerr %v\n%v\n===\n", err, T3.Shape()) } // Let's look at failure conditions diff --git a/example_iterator_test.go b/example_iterator_test.go index 896a097..a6b31da 100644 --- a/example_iterator_test.go +++ b/example_iterator_test.go @@ -2,31 +2,51 @@ package tensor import "fmt" -func Example_iterator() { - fmt.Println("Row Major") - T := New(WithShape(2, 3), Of(Float64)) +// This is an example of how to use `IteratorFromDense` from a row-major Dense tensor +func Example_iteratorRowmajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) for i, err := it.Start(); err == nil; i, err = it.Next() { fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) } - /* - // FOR WHEN COL MAJOR IS SUPPORTED - fmt.Println("Col Major") - T = New(WithShape(2, 3), Of(Float64), AsFortran()) - it = IteratorFromDense(T) - - for i, err := it.Start(); err == nil; i, err = it.Next() { - fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) - } - */ // Output: - // Row Major + // T: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // // i: 0, coord: [0 1] // i: 1, coord: [0 2] // i: 2, coord: [1 0] // i: 3, coord: [1 1] // i: 4, coord: [1 2] // i: 5, coord: [0 0] + +} + +// This is an example of using `IteratorFromDense` on a col-major Dense tensor. More importantly +// this example shows the order of the iteration. +func Example_iteratorcolMajor() { + T := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) + it := IteratorFromDense(T) + fmt.Printf("T:\n%v\n", T) + + for i, err := it.Start(); err == nil; i, err = it.Next() { + fmt.Printf("i: %d, coord: %v\n", i, it.Coord()) + } + + // Output: + // T: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // + // i: 0, coord: [0 1] + // i: 2, coord: [0 2] + // i: 4, coord: [1 0] + // i: 1, coord: [1 1] + // i: 3, coord: [1 2] + // i: 5, coord: [0 0] + } diff --git a/example_tensor_basics_test.go b/example_tensor_basics_test.go index d588a54..49c008b 100644 --- a/example_tensor_basics_test.go +++ b/example_tensor_basics_test.go @@ -2,16 +2,19 @@ package tensor import "fmt" +// This example showcases the very basics of the package. func Example_basics() { + // Create a (2, 2)-Matrix of integers a := New(WithShape(2, 2), WithBacking([]int{1, 2, 3, 4})) fmt.Printf("a:\n%v\n", a) + // Create a (2, 3, 4)-tensor of float32s b := New(WithBacking(Range(Float32, 0, 24)), WithShape(2, 3, 4)) fmt.Printf("b:\n%1.1f", b) // Accessing data x, _ := b.At(0, 1, 2) // in Numpy syntax: b[0,1,2] - fmt.Printf("x: %v\n\n", x) + fmt.Printf("x: %1.1f\n\n", x) // Setting data b.SetAt(float32(1000), 0, 1, 2) @@ -31,7 +34,7 @@ func Example_basics() { // ⎢16.0 17.0 18.0 19.0⎥ // ⎣20.0 21.0 22.0 23.0⎦ // - // x: 6 + // x: 6.0 // // b: // ⎡ 0 1 2 3⎤ @@ -42,3 +45,116 @@ func Example_basics() { // ⎢ 16 17 18 19⎥ // ⎣ 20 21 22 23⎦ } + +// This example showcases interactions between different data orders +func Example_differingDataOrders() { + T0 := New(WithShape(2, 3), WithBacking(Range(Int, 0, 6))) // Create a (2, 3)-matrix with the standard row-major backing + T1 := New(WithShape(2, 3), WithBacking(Range(Int, 0, 6)), AsFortran(nil)) // Create a (2, 3)-matrix with a col-major backing + T2, _ := Add(T0, T1) + fmt.Printf("T0:\n%vT1:\n%vT2:\n%vT2 Data Order: %v\n\n", T0, T1, T2, T2.DataOrder()) + + // the result's data order is highly dependent on the order of operation. It will take after the first operand + T0 = New(WithShape(2, 3), WithBacking(Range(Int, 1, 7)), AsFortran(nil)) // Create a (2, 3)-matrix with a col-major backing + T1 = New(WithShape(2, 3), WithBacking(Range(Int, 1, 7))) // Create a (2, 3)-matrix with the standard row-major backing + T2, _ = Add(T0, T1) + fmt.Printf("T0:\n%vT1:\n%vT2:\n%vT2 Data Order: %v\n\n", T0, T1, T2, T2.DataOrder()) + + reuse := New(WithShape(2, 3), WithBacking([]int{1000, 1000, 1000, 1000, 1000, 1000})) + fmt.Printf("reuse Data Order: %v\n", reuse.DataOrder()) + T2, _ = Add(T0, T1, WithReuse(reuse)) + fmt.Printf("T2:\n%vT2 Data Order: %v\n\n", T2, T2.DataOrder()) + + // Output: + // T0: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // T1: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // T2: + // ⎡ 0 3 6⎤ + // ⎣ 4 7 10⎦ + // T2 Data Order: Contiguous, RowMajor + // + // + // T0: + // ⎡1 3 5⎤ + // ⎣2 4 6⎦ + // T1: + // ⎡1 2 3⎤ + // ⎣4 5 6⎦ + // T2: + // ⎡ 2 5 8⎤ + // ⎣ 6 9 12⎦ + // T2 Data Order: Contiguous, ColMajor + // + // + // reuse Data Order: Contiguous, RowMajor + // T2: + // ⎡ 2 5 8⎤ + // ⎣ 6 9 12⎦ + // T2 Data Order: Contiguous, ColMajor + +} + +// The AsFortran construction option is a bit finnicky. +func Example_asFortran() { + // Here the data is passed in and directly used without changing the underlying data + T0 := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5}), AsFortran(nil)) + fmt.Printf("T0:\n%vData: %v\n\n", T0, T0.Data()) + + // Here the data is passed into the AsFortran construction option, and it assumes that the data is already in + // row-major form. Therefore a transpose will be performed. + T1 := New(WithShape(2, 3), AsFortran([]float64{0, 1, 2, 3, 4, 5})) + fmt.Printf("T1:\n%vData: %v\n\n", T1, T1.Data()) + + // Further example of how AsFortran works: + orig := New(WithShape(2, 3), WithBacking([]float64{0, 1, 2, 3, 4, 5})) + T2 := New(WithShape(2, 3), AsFortran(orig)) + fmt.Printf("Original\n%vData: %v\n", orig, orig.Data()) + fmt.Printf("T2:\n%vData: %v\n", T2, T2.Data()) + + // Output: + // T0: + // ⎡0 2 4⎤ + // ⎣1 3 5⎦ + // Data: [0 1 2 3 4 5] + // + // T1: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // Data: [0 3 1 4 2 5] + // + // Original + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // Data: [0 1 2 3 4 5] + // T2: + // ⎡0 1 2⎤ + // ⎣3 4 5⎦ + // Data: [0 3 1 4 2 5] +} + +// The AsDenseDiag construction option creates a dense diagonal matrix from the input, either a slice or a tensor. +// The resulting shape is automatically inferred from the input vector. +// +// This is like Numpy's `diag()` function, except not stupid. Numpy's `diag()` has been a cause of errors because it's somewhat isometric: +// >>> np.diag(np.diag(np.array([1,2,3]))) +// array([1,2,3]) +func Example_asDenseDiag() { + T := New(WithShape(3), WithBacking([]int{1, 2, 3})) + T1 := New(AsDenseDiag(T)) + fmt.Printf("T1:\n%v", T1) + + T2 := New(AsDenseDiag([]float64{3.14, 6.28, 11111})) + fmt.Printf("T2:\n%v", T2) + // Output: + // T1: + //⎡1 0 0⎤ + //⎢0 2 0⎥ + //⎣0 0 3⎦ + // T2: + // ⎡ 3.14 0 0⎤ + // ⎢ 0 6.28 0⎥ + // ⎣ 0 0 11111⎦ +} diff --git a/flags.go b/flags.go index dfe551e..e8a00d0 100644 --- a/flags.go +++ b/flags.go @@ -13,8 +13,13 @@ const ( // A data can either be Contiguous (0) or NonContiguous (2). // The way DataOrder was designed causes the default to be Contiguous. NonContiguous + + // Transposed indicates that the data has been transposed + Transposed ) +var dataOrderNames = []rune("NonContiguous, RowMajorᵀNonContiguous, ColMajorᵀ") + // MakeDataOrder makes a data order. Typical examples: // MakeDataOrder(DataOrder(0)) // Row Major, contiguous // MakeDataOrder(NonContiguous // Row Major, non-contiguous @@ -30,13 +35,47 @@ func MakeDataOrder(fs ...DataOrder) (retVal DataOrder) { return } -func (f DataOrder) isColMajor() bool { return (f & ColMajor) != 0 } -func (f DataOrder) isRowMajor() bool { return !f.isColMajor() } -func (f DataOrder) isContiguous() bool { return !f.isNotContiguous() } -func (f DataOrder) isNotContiguous() bool { return (f & NonContiguous) != 0 } +// IsColMajor returns true if the data order describes a col-major data +func (f DataOrder) IsColMajor() bool { return (f & ColMajor) != 0 } + +// IsRowMajor returns true if the data order describes a row-major data +func (f DataOrder) IsRowMajor() bool { return !f.IsColMajor() } + +// IsContiguous returns true if the data order describes a contiguous data. +func (f DataOrder) IsContiguous() bool { return !f.IsNotContiguous() } + +// IsNotContiguous returns true if the data order describes a noncontiguous data. +func (f DataOrder) IsNotContiguous() bool { return (f & NonContiguous) != 0 } + +// IsTransposed returns true if the data order describes whether the data has been tranposed (but not moved) +func (f DataOrder) IsTransposed() bool { return (f & Transposed) != 0 } + func (f DataOrder) toggleColMajor() DataOrder { return f ^ (ColMajor) } -func (f DataOrder) hasSameOrder(other DataOrder) bool { - return (f.isColMajor() && other.isColMajor()) || (f.isRowMajor() && other.isRowMajor()) + +func (f DataOrder) clearTransposed() DataOrder { return f &^ (Transposed) } + +func (f DataOrder) HasSameOrder(other DataOrder) bool { + return (f.IsColMajor() && other.IsColMajor()) || (f.IsRowMajor() && other.IsRowMajor()) +} + +func (f DataOrder) String() string { + var start, end int + if f.IsRowMajor() { + end = 23 + if f.IsContiguous() { + start = 3 + } + } else { + end = 47 + start = 24 + if f.IsContiguous() { + start = 27 + } + } + if f.IsTransposed() { + end++ + } + return string(dataOrderNames[start:end]) } // Triangle is a flag representing the "triangle"ness of a matrix diff --git a/flags_test.go b/flags_test.go index 98a8772..83dd3be 100644 --- a/flags_test.go +++ b/flags_test.go @@ -35,29 +35,56 @@ func TestMemoryFlag(t *testing.T) { func TestDataOrder(t *testing.T) { var defaultFlag DataOrder - if defaultFlag.isColMajor() || defaultFlag.isNotContiguous() { - t.Errorf("Expected default flag to be row major and contiguous") + if defaultFlag.IsColMajor() || defaultFlag.IsNotContiguous() || defaultFlag.IsTransposed() { + t.Error("Expected default flag to be row major and contiguous and not transposed") } - if !(defaultFlag.isRowMajor() && defaultFlag.isContiguous()) { - t.Errorf("Expected default flag to be row major and contiguous") + if !(defaultFlag.IsRowMajor() && defaultFlag.IsContiguous()) { + t.Error("Expected default flag to be row major and contiguous") + } + if defaultFlag.String() != "Contiguous, RowMajor" { + t.Errorf("Expected string is \"Contiguous, RowMajor\". Got %q", defaultFlag.String()) + } + + ncrm := MakeDataOrder(NonContiguous) + if ncrm.IsColMajor() || ncrm.IsContiguous() { + t.Error("Expected noncontiguous row major.") + } + if ncrm.String() != "NonContiguous, RowMajor" { + t.Errorf("Expected string is \"NonContiguous, RowMajor\". Got %q", defaultFlag.String()) } cm := ColMajor - if cm.isRowMajor() { - t.Errorf("colMajor cannot be rowMajor") + if cm.IsRowMajor() { + t.Error("colMajor cannot be rowMajor") + } + if cm.IsNotContiguous() { + t.Error("ColMajor by default is contiguous") } - if cm.isNotContiguous() { - t.Errorf("ColMajor by default is contiguous") + if cm.String() != "Contiguous, ColMajor" { + t.Errorf(`Expected string is "Contiguous, ColMajor". Got %q`, cm.String()) } // check toggle rm := cm.toggleColMajor() - if rm.isColMajor() { + if rm.IsColMajor() { t.Errorf("toggled cm should be rm") } cm = rm.toggleColMajor() - if cm.isRowMajor() { + if cm.IsRowMajor() { t.Errorf("toggled rm should be cm") } + + transposed := MakeDataOrder(Transposed) + if !transposed.IsTransposed() { + t.Error("Expected transposed flag to be set") + } + if transposed.String() != "Contiguous, RowMajorᵀ" { + t.Errorf("Expected string is \"Contiguous, RowMajorᵀ\". Got %q", defaultFlag.String()) + } + untransposed := transposed.clearTransposed() + if untransposed != defaultFlag { + t.Error("Expected default flag after untransposing") + } + } diff --git a/interfaces.go b/interfaces.go index 40be33d..c4a11c2 100644 --- a/interfaces.go +++ b/interfaces.go @@ -76,7 +76,6 @@ type DenseTensor interface { Tensor Info() *AP - DataOrder() DataOrder IsMatrix() bool IsVector() bool IsRowVec() bool @@ -89,6 +88,7 @@ type DenseTensor interface { rtype() reflect.Type reshape(dims ...int) error + setDataOrder(o DataOrder) isTransposed() bool ostrides() []int oshape() Shape diff --git a/internal/IDLs/generated.fbs b/internal/IDLs/generated.fbs new file mode 100644 index 0000000..47ffce2 --- /dev/null +++ b/internal/IDLs/generated.fbs @@ -0,0 +1,38 @@ +// Generated from generated.proto + +namespace gorgonia.org.tensor.internal.serialization.pb; + +enum Triangle : int { + NOT_TRIANGLE = 0, + UPPER = 1, + LOWER = 2, + SYMMETRIC = 3, +} + +table AP { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; +} + +table Dense { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; + type:string; + data:[ubyte]; +} + +table MaskedDense { + shape:[int]; + strides:[int]; + o:uint; + t:gorgonia.org.tensor.internal.serialization.pb.Triangle; + type:string; + data:[ubyte]; + mask:[bool]; + mask_is_soft:[bool]; +} + diff --git a/internal/IDLs/generated.proto b/internal/IDLs/generated.proto new file mode 100755 index 0000000..c737106 --- /dev/null +++ b/internal/IDLs/generated.proto @@ -0,0 +1,52 @@ +syntax = "proto3"; +package gorgonia.org.tensor.internal.serialization.pb; + +import "github.com/gogo/protobuf/gogoproto/gogo.proto"; + +option (gogoproto.protosizer_all) = true; +option (gogoproto.sizer_all) = false; +option go_package = "pb"; + +message AP { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; +} + +message Dense { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; + string type = 5; + bytes data = 6; +} + +message MaskedDense { + option (gogoproto.goproto_getters) = false; + option (gogoproto.typedecl) = false; + repeated int32 shape = 1; + repeated int32 strides = 2; + uint32 o = 3 [(gogoproto.casttype) = "DataOrder"]; + gorgonia.org.tensor.internal.serialization.pb.Triangle t = 4; + string type = 5; + bytes data = 6; + repeated bool mask = 7; + repeated bool mask_is_soft = 8; +} + +enum Triangle { + option (gogoproto.enumdecl) = false; + option (gogoproto.goproto_enum_prefix) = false; + option (gogoproto.goproto_enum_stringer) = false; + NOT_TRIANGLE = 0 [(gogoproto.enumvalue_customname) = "NotTriangle"]; + UPPER = 1 [(gogoproto.enumvalue_customname) = "Upper"]; + LOWER = 2 [(gogoproto.enumvalue_customname) = "Lower"]; + SYMMETRIC = 3 [(gogoproto.enumvalue_customname) = "Symmetric"]; +} + diff --git a/internal/serialization/README.md b/internal/serialization/README.md new file mode 100644 index 0000000..d3d8149 --- /dev/null +++ b/internal/serialization/README.md @@ -0,0 +1,33 @@ +# Serialization # + +This pseudopackage of sorts handles serialization. The "Canonical" serialized data structure is found in the `pb` subdirectory. + +# Protobuf generation + +Proteus needs to be installed, as does its dependencies. + + +1. `cd pb` +2. `rm generated*` +3. `proteus -f ../../IDLs -p gorgonia.org/tensor/internal/serialization/pb` +4. `cd ../../IDLs` +5. `find gorgonia.org/ -mindepth 2 -type f -exec mv -i '{}' . ';'` +6. `rm -rf gorgonia.org` + + +# FlatBuffer generation +1. generate protobuf first +2. delete the `import "github.com/gogo/protobuf/gogoproto/gogo.proto";` line from the generated protobuf file +3. `flatc --proto PATH/TO/generated.proto` +4. place the `generated.fbs` file in the IDLs directory +4. restore the import line in the `generated.proto` file +5. From this directory: `flatc --go-namespace fb -g PATH/TO/generated.fbs` + + +# Notes # + +`find gorgonia.org/ -mindepth 2 -type f -exec mv -i '{}' . ';'` is used to flatten and put all the stuff in the root IDLs directory. + +# The Serialization Story # + +To serialize, we copy/convert/coerce the data to the internal/serialization data structures, then call the `Marshall` methods from there \ No newline at end of file diff --git a/internal/serialization/doc.go b/internal/serialization/doc.go new file mode 100644 index 0000000..c4cb59b --- /dev/null +++ b/internal/serialization/doc.go @@ -0,0 +1,2 @@ +// package serialization provides the data structures for serialization +package serialization diff --git a/internal/serialization/fb/AP.go b/internal/serialization/fb/AP.go new file mode 100644 index 0000000..b3ca806 --- /dev/null +++ b/internal/serialization/fb/AP.go @@ -0,0 +1,110 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type AP struct { + _tab flatbuffers.Table +} + +func GetRootAsAP(buf []byte, offset flatbuffers.UOffsetT) *AP { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &AP{} + x.Init(buf, n+offset) + return x +} + +func (rcv *AP) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *AP) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *AP) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *AP) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *AP) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *AP) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *AP) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *AP) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *AP) T() int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *AP) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func APStart(builder *flatbuffers.Builder) { + builder.StartObject(4) +} +func APAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func APStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func APAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func APStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func APAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func APAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func APEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/Dense.go b/internal/serialization/fb/Dense.go new file mode 100644 index 0000000..2a961ee --- /dev/null +++ b/internal/serialization/fb/Dense.go @@ -0,0 +1,152 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type Dense struct { + _tab flatbuffers.Table +} + +func GetRootAsDense(buf []byte, offset flatbuffers.UOffsetT) *Dense { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &Dense{} + x.Init(buf, n+offset) + return x +} + +func (rcv *Dense) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *Dense) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *Dense) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *Dense) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *Dense) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *Dense) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *Dense) T() int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *Dense) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func (rcv *Dense) Type() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *Dense) Data(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *Dense) DataLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *Dense) DataBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func DenseStart(builder *flatbuffers.Builder) { + builder.StartObject(6) +} +func DenseAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func DenseStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func DenseAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func DenseStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func DenseAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func DenseAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func DenseAddType(builder *flatbuffers.Builder, type_ flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(type_), 0) +} +func DenseAddData(builder *flatbuffers.Builder, data flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(data), 0) +} +func DenseStartDataVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func DenseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/MaskedDense.go b/internal/serialization/fb/MaskedDense.go new file mode 100644 index 0000000..271e77e --- /dev/null +++ b/internal/serialization/fb/MaskedDense.go @@ -0,0 +1,198 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +import ( + flatbuffers "github.com/google/flatbuffers/go" +) + +type MaskedDense struct { + _tab flatbuffers.Table +} + +func GetRootAsMaskedDense(buf []byte, offset flatbuffers.UOffsetT) *MaskedDense { + n := flatbuffers.GetUOffsetT(buf[offset:]) + x := &MaskedDense{} + x.Init(buf, n+offset) + return x +} + +func (rcv *MaskedDense) Init(buf []byte, i flatbuffers.UOffsetT) { + rcv._tab.Bytes = buf + rcv._tab.Pos = i +} + +func (rcv *MaskedDense) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *MaskedDense) Shape(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *MaskedDense) ShapeLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) Strides(j int) int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetInt32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *MaskedDense) StridesLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) O() uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.GetUint32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *MaskedDense) MutateO(n uint32) bool { + return rcv._tab.MutateUint32Slot(8, n) +} + +func (rcv *MaskedDense) T() int32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.GetInt32(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *MaskedDense) MutateT(n int32) bool { + return rcv._tab.MutateInt32Slot(10, n) +} + +func (rcv *MaskedDense) Type() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *MaskedDense) Data(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) DataLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) DataBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *MaskedDense) Mask(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) MaskLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *MaskedDense) MaskIsSoft(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *MaskedDense) MaskIsSoftLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func MaskedDenseStart(builder *flatbuffers.Builder) { + builder.StartObject(8) +} +func MaskedDenseAddShape(builder *flatbuffers.Builder, shape flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(shape), 0) +} +func MaskedDenseStartShapeVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MaskedDenseAddStrides(builder *flatbuffers.Builder, strides flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(strides), 0) +} +func MaskedDenseStartStridesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func MaskedDenseAddO(builder *flatbuffers.Builder, o uint32) { + builder.PrependUint32Slot(2, o, 0) +} +func MaskedDenseAddT(builder *flatbuffers.Builder, t int32) { + builder.PrependInt32Slot(3, t, 0) +} +func MaskedDenseAddType(builder *flatbuffers.Builder, type_ flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(type_), 0) +} +func MaskedDenseAddData(builder *flatbuffers.Builder, data flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(data), 0) +} +func MaskedDenseStartDataVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseAddMask(builder *flatbuffers.Builder, mask flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(6, flatbuffers.UOffsetT(mask), 0) +} +func MaskedDenseStartMaskVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseAddMaskIsSoft(builder *flatbuffers.Builder, maskIsSoft flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(7, flatbuffers.UOffsetT(maskIsSoft), 0) +} +func MaskedDenseStartMaskIsSoftVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func MaskedDenseEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/internal/serialization/fb/Triangle.go b/internal/serialization/fb/Triangle.go new file mode 100644 index 0000000..599a06b --- /dev/null +++ b/internal/serialization/fb/Triangle.go @@ -0,0 +1,18 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +package fb + +const ( + TriangleNOT_TRIANGLE = 0 + TriangleUPPER = 1 + TriangleLOWER = 2 + TriangleSYMMETRIC = 3 +) + +var EnumNamesTriangle = map[int]string{ + TriangleNOT_TRIANGLE:"NOT_TRIANGLE", + TriangleUPPER:"UPPER", + TriangleLOWER:"LOWER", + TriangleSYMMETRIC:"SYMMETRIC", +} + diff --git a/internal/serialization/pb/dense.go b/internal/serialization/pb/dense.go new file mode 100644 index 0000000..950c3ff --- /dev/null +++ b/internal/serialization/pb/dense.go @@ -0,0 +1,45 @@ +package pb + +//proteus:generate +type DataOrder byte + +// the reason for spreading the states out is because proteaus cannot handle non-iota tates +const ( + RowMajorContiguous = iota + RowMajorNonContiguous + ColMajorContiguous + ColMajorNonContiguous +) + +//proteus:generate +type Triangle byte + +const ( + NotTriangle Triangle = iota + Upper + Lower + Symmetric +) + +//proteus:generate +type AP struct { + Shape []int32 + Strides []int32 + + O DataOrder + T Triangle +} + +//proteus:generate +type Dense struct { + AP + Type string // type name + Data []byte +} + +//proteus:generate +type MaskedDense struct { + Dense + Mask []bool + MaskIsSoft []bool +} diff --git a/internal/serialization/pb/generated.pb.go b/internal/serialization/pb/generated.pb.go new file mode 100644 index 0000000..831ce90 --- /dev/null +++ b/internal/serialization/pb/generated.pb.go @@ -0,0 +1,1457 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: gorgonia.org/tensor/internal/serialization/pb/generated.proto + +/* + Package pb is a generated protocol buffer package. + + It is generated from these files: + gorgonia.org/tensor/internal/serialization/pb/generated.proto + + It has these top-level messages: + AP + Dense + MaskedDense +*/ +package pb + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" +import _ "github.com/gogo/protobuf/gogoproto" + +import io "io" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +var Triangle_name = map[int32]string{ + 0: "NOT_TRIANGLE", + 1: "UPPER", + 2: "LOWER", + 3: "SYMMETRIC", +} +var Triangle_value = map[string]int32{ + "NOT_TRIANGLE": 0, + "UPPER": 1, + "LOWER": 2, + "SYMMETRIC": 3, +} + +func (Triangle) EnumDescriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{0} } + +func (m *AP) Reset() { *m = AP{} } +func (m *AP) String() string { return proto.CompactTextString(m) } +func (*AP) ProtoMessage() {} +func (*AP) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{0} } + +func (m *Dense) Reset() { *m = Dense{} } +func (m *Dense) String() string { return proto.CompactTextString(m) } +func (*Dense) ProtoMessage() {} +func (*Dense) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{1} } + +func (m *MaskedDense) Reset() { *m = MaskedDense{} } +func (m *MaskedDense) String() string { return proto.CompactTextString(m) } +func (*MaskedDense) ProtoMessage() {} +func (*MaskedDense) Descriptor() ([]byte, []int) { return fileDescriptorGenerated, []int{2} } + +func init() { + proto.RegisterType((*AP)(nil), "gorgonia.org.tensor.internal.serialization.pb.AP") + proto.RegisterType((*Dense)(nil), "gorgonia.org.tensor.internal.serialization.pb.Dense") + proto.RegisterType((*MaskedDense)(nil), "gorgonia.org.tensor.internal.serialization.pb.MaskedDense") + proto.RegisterEnum("gorgonia.org.tensor.internal.serialization.pb.Triangle", Triangle_name, Triangle_value) +} +func (m *AP) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *AP) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA2 := make([]byte, len(m.Shape)*10) + var j1 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA2[j1] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j1++ + } + dAtA2[j1] = uint8(num) + j1++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j1)) + i += copy(dAtA[i:], dAtA2[:j1]) + } + if len(m.Strides) > 0 { + dAtA4 := make([]byte, len(m.Strides)*10) + var j3 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j3++ + } + dAtA4[j3] = uint8(num) + j3++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j3)) + i += copy(dAtA[i:], dAtA4[:j3]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + return i, nil +} + +func (m *Dense) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Dense) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA6 := make([]byte, len(m.Shape)*10) + var j5 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA6[j5] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j5++ + } + dAtA6[j5] = uint8(num) + j5++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j5)) + i += copy(dAtA[i:], dAtA6[:j5]) + } + if len(m.Strides) > 0 { + dAtA8 := make([]byte, len(m.Strides)*10) + var j7 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA8[j7] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j7++ + } + dAtA8[j7] = uint8(num) + j7++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j7)) + i += copy(dAtA[i:], dAtA8[:j7]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + if len(m.Type) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Type))) + i += copy(dAtA[i:], m.Type) + } + if len(m.Data) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + return i, nil +} + +func (m *MaskedDense) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *MaskedDense) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Shape) > 0 { + dAtA10 := make([]byte, len(m.Shape)*10) + var j9 int + for _, num1 := range m.Shape { + num := uint64(num1) + for num >= 1<<7 { + dAtA10[j9] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j9++ + } + dAtA10[j9] = uint8(num) + j9++ + } + dAtA[i] = 0xa + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j9)) + i += copy(dAtA[i:], dAtA10[:j9]) + } + if len(m.Strides) > 0 { + dAtA12 := make([]byte, len(m.Strides)*10) + var j11 int + for _, num1 := range m.Strides { + num := uint64(num1) + for num >= 1<<7 { + dAtA12[j11] = uint8(uint64(num)&0x7f | 0x80) + num >>= 7 + j11++ + } + dAtA12[j11] = uint8(num) + j11++ + } + dAtA[i] = 0x12 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(j11)) + i += copy(dAtA[i:], dAtA12[:j11]) + } + if m.O != 0 { + dAtA[i] = 0x18 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.O)) + } + if m.T != 0 { + dAtA[i] = 0x20 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(m.T)) + } + if len(m.Type) > 0 { + dAtA[i] = 0x2a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Type))) + i += copy(dAtA[i:], m.Type) + } + if len(m.Data) > 0 { + dAtA[i] = 0x32 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Data))) + i += copy(dAtA[i:], m.Data) + } + if len(m.Mask) > 0 { + dAtA[i] = 0x3a + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.Mask))) + for _, b := range m.Mask { + if b { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + } + if len(m.MaskIsSoft) > 0 { + dAtA[i] = 0x42 + i++ + i = encodeVarintGenerated(dAtA, i, uint64(len(m.MaskIsSoft))) + for _, b := range m.MaskIsSoft { + if b { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i++ + } + } + return i, nil +} + +func encodeVarintGenerated(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *AP) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + return n +} + +func (m *Dense) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + l = len(m.Type) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + return n +} + +func (m *MaskedDense) ProtoSize() (n int) { + var l int + _ = l + if len(m.Shape) > 0 { + l = 0 + for _, e := range m.Shape { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if len(m.Strides) > 0 { + l = 0 + for _, e := range m.Strides { + l += sovGenerated(uint64(e)) + } + n += 1 + sovGenerated(uint64(l)) + l + } + if m.O != 0 { + n += 1 + sovGenerated(uint64(m.O)) + } + if m.T != 0 { + n += 1 + sovGenerated(uint64(m.T)) + } + l = len(m.Type) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + l = len(m.Data) + if l > 0 { + n += 1 + l + sovGenerated(uint64(l)) + } + if len(m.Mask) > 0 { + n += 1 + sovGenerated(uint64(len(m.Mask))) + len(m.Mask)*1 + } + if len(m.MaskIsSoft) > 0 { + n += 1 + sovGenerated(uint64(len(m.MaskIsSoft))) + len(m.MaskIsSoft)*1 + } + return n +} + +func sovGenerated(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozGenerated(x uint64) (n int) { + return sovGenerated(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *AP) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: AP: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: AP: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Dense) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Dense: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Dense: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Type = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) + if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *MaskedDense) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: MaskedDense: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: MaskedDense: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Shape = append(m.Shape, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Shape", wireType) + } + case 2: + if wireType == 0 { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int32 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Strides = append(m.Strides, v) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Strides", wireType) + } + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field O", wireType) + } + m.O = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.O |= (DataOrder(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field T", wireType) + } + m.T = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.T |= (Triangle(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + case 5: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Type = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + byteLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...) + if m.Data == nil { + m.Data = []byte{} + } + iNdEx = postIndex + case 7: + if wireType == 0 { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Mask = append(m.Mask, bool(v != 0)) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.Mask = append(m.Mask, bool(v != 0)) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field Mask", wireType) + } + case 8: + if wireType == 0 { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.MaskIsSoft = append(m.MaskIsSoft, bool(v != 0)) + } else if wireType == 2 { + var packedLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + packedLen |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if packedLen < 0 { + return ErrInvalidLengthGenerated + } + postIndex := iNdEx + packedLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + for iNdEx < postIndex { + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowGenerated + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + m.MaskIsSoft = append(m.MaskIsSoft, bool(v != 0)) + } + } else { + return fmt.Errorf("proto: wrong wireType = %d for field MaskIsSoft", wireType) + } + default: + iNdEx = preIndex + skippy, err := skipGenerated(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthGenerated + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipGenerated(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthGenerated + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowGenerated + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipGenerated(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthGenerated = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowGenerated = fmt.Errorf("proto: integer overflow") +) + +func init() { + proto.RegisterFile("gorgonia.org/tensor/internal/serialization/pb/generated.proto", fileDescriptorGenerated) +} + +var fileDescriptorGenerated = []byte{ + // 482 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xd4, 0x93, 0x4f, 0x6b, 0x13, 0x41, + 0x18, 0xc6, 0x77, 0x92, 0x6e, 0x9b, 0x4c, 0x13, 0x0d, 0x43, 0x0f, 0x43, 0x94, 0xcd, 0xd8, 0xd3, + 0x22, 0x74, 0x17, 0xf4, 0x20, 0x08, 0x1e, 0x5a, 0x1b, 0x24, 0x90, 0x7f, 0x4c, 0x52, 0x44, 0x2f, + 0x61, 0xb6, 0x3b, 0xd9, 0x0e, 0x4d, 0x76, 0x96, 0x99, 0x29, 0x52, 0xef, 0x42, 0xcd, 0x27, 0xf0, + 0x12, 0xa8, 0xda, 0x83, 0x1f, 0xc3, 0xa3, 0x17, 0xc1, 0x4f, 0x20, 0x92, 0x7e, 0x01, 0xcf, 0x9e, + 0x64, 0x27, 0x44, 0xe2, 0xd1, 0x9b, 0x3d, 0xcd, 0xf3, 0xfc, 0x66, 0x9e, 0x77, 0xde, 0x97, 0x61, + 0xe0, 0x93, 0x44, 0xaa, 0x44, 0xa6, 0x82, 0x05, 0x52, 0x25, 0xa1, 0xe1, 0xa9, 0x96, 0x2a, 0x14, + 0xa9, 0xe1, 0x2a, 0x65, 0x93, 0x50, 0x73, 0x25, 0xd8, 0x44, 0xbc, 0x66, 0x46, 0xc8, 0x34, 0xcc, + 0xa2, 0x30, 0xe1, 0x29, 0x57, 0xcc, 0xf0, 0x38, 0xc8, 0x94, 0x34, 0x12, 0xed, 0xad, 0xc7, 0x83, + 0x65, 0x3c, 0x58, 0xc5, 0x83, 0xbf, 0xe2, 0x41, 0x16, 0xd5, 0xf7, 0x12, 0x61, 0x4e, 0xce, 0xa2, + 0xe0, 0x58, 0x4e, 0xc3, 0x44, 0x26, 0x32, 0xb4, 0x55, 0xa2, 0xb3, 0xb1, 0x75, 0xd6, 0x58, 0xb5, + 0xac, 0xbe, 0xfb, 0x01, 0xc0, 0xc2, 0x7e, 0x1f, 0xed, 0x40, 0x57, 0x9f, 0xb0, 0x8c, 0x63, 0x40, + 0x8a, 0xbe, 0x4b, 0x97, 0x06, 0x61, 0xb8, 0xa5, 0x8d, 0x12, 0x31, 0xd7, 0xb8, 0x60, 0xf9, 0xca, + 0xa2, 0x3b, 0x10, 0x48, 0x5c, 0x24, 0xc0, 0xaf, 0x1e, 0x54, 0x7f, 0x7d, 0x6f, 0x94, 0x0f, 0x99, + 0x61, 0x3d, 0x15, 0x73, 0x45, 0x81, 0x44, 0x4d, 0x08, 0x0c, 0xde, 0x20, 0xc0, 0xbf, 0xf5, 0xe0, + 0x51, 0xf0, 0x4f, 0xdd, 0x07, 0x43, 0x25, 0x58, 0x9a, 0x4c, 0x38, 0x05, 0xe6, 0x71, 0xe9, 0xe2, + 0xb2, 0xe1, 0xfc, 0x7c, 0xdf, 0x70, 0x76, 0xbf, 0x02, 0xe8, 0x1e, 0xf2, 0x54, 0xf3, 0xff, 0xb1, + 0x4f, 0x84, 0xe0, 0x86, 0x39, 0xcf, 0x38, 0x76, 0x09, 0xf0, 0xcb, 0xd4, 0xea, 0x9c, 0xc5, 0xcc, + 0x30, 0xbc, 0x49, 0x80, 0x5f, 0xa1, 0x56, 0xaf, 0xcd, 0xf3, 0xb6, 0x00, 0xb7, 0x3b, 0x4c, 0x9f, + 0xf2, 0xf8, 0xc6, 0x4f, 0x95, 0xb3, 0x29, 0xd3, 0xa7, 0x78, 0x8b, 0x14, 0xfd, 0x12, 0xb5, 0x1a, + 0x11, 0x58, 0xc9, 0xd7, 0x91, 0xd0, 0x23, 0x2d, 0xc7, 0x06, 0x97, 0xec, 0x1e, 0xcc, 0x59, 0x4b, + 0x0f, 0xe4, 0x78, 0xed, 0x6d, 0xef, 0xbf, 0x01, 0xb0, 0xb4, 0xba, 0x17, 0xdd, 0x83, 0x95, 0x6e, + 0x6f, 0x38, 0x1a, 0xd2, 0xd6, 0x7e, 0xf7, 0x59, 0xbb, 0x59, 0x73, 0xea, 0xb7, 0x67, 0x73, 0xb2, + 0xdd, 0x95, 0xe6, 0xcf, 0x91, 0x1d, 0xe8, 0x1e, 0xf5, 0xfb, 0x4d, 0x5a, 0x03, 0xf5, 0xf2, 0x6c, + 0x4e, 0xdc, 0xa3, 0x2c, 0xe3, 0x2a, 0xa7, 0xed, 0xde, 0xf3, 0x26, 0xad, 0x15, 0x96, 0xb4, 0x2d, + 0x5f, 0x71, 0x85, 0xee, 0xc2, 0xf2, 0xe0, 0x45, 0xa7, 0xd3, 0x1c, 0xd2, 0xd6, 0xd3, 0x5a, 0xb1, + 0x5e, 0x9d, 0xcd, 0x49, 0x79, 0x70, 0x3e, 0x9d, 0x72, 0xa3, 0xc4, 0x71, 0xbd, 0x72, 0xf1, 0xd1, + 0x73, 0x3e, 0x5d, 0x79, 0xce, 0xe7, 0x2b, 0xcf, 0x39, 0xc0, 0x5f, 0x16, 0x1e, 0xf8, 0xb6, 0xf0, + 0xc0, 0x8f, 0x85, 0xe7, 0xbc, 0xbb, 0xf6, 0x9c, 0xcb, 0x6b, 0x0f, 0xbc, 0x2c, 0x64, 0x51, 0xb4, + 0x69, 0x7f, 0xca, 0xc3, 0xdf, 0x01, 0x00, 0x00, 0xff, 0xff, 0x0f, 0xff, 0xbb, 0x8f, 0xc8, 0x03, + 0x00, 0x00, +} diff --git a/iterator.go b/iterator.go index 70fa810..0db1158 100644 --- a/iterator.go +++ b/iterator.go @@ -1,6 +1,8 @@ package tensor -import "runtime" +import ( + "runtime" +) func requiresOrderedIterator(e Engine, t Tensor) bool { if t.IsScalar() { @@ -70,7 +72,7 @@ func NewIterator(aps ...*AP) Iterator { case 0: return nil case 1: - return NewFlatIterator(aps[0]) + return newFlatIterator(aps[0]) default: return NewMultIterator(aps...) } @@ -111,8 +113,8 @@ func iteratorLoadAP(it Iterator, ap *AP) { /* FLAT ITERATOR */ -// FlatIterator is an iterator that iterates over Tensors. It utilizes the *AP -// of a Tensor to determine what the next index is. +// FlatIterator is an iterator that iterates over Tensors according to the data's layout. +// It utilizes the *AP of a Tensor to determine what the next index is. // This data structure is similar to Numpy's flatiter, with some standard Go based restrictions of course // (such as, not allowing negative indices) type FlatIterator struct { @@ -129,16 +131,20 @@ type FlatIterator struct { isScalar bool isVector bool + + outerFirst bool } -// NewFlatIterator creates a new FlatIterator. -func NewFlatIterator(ap *AP) *FlatIterator { +// newFlatIterator creates a new FlatIterator. +func newFlatIterator(ap *AP) *FlatIterator { var strides0 int - if ap.IsVector() { - strides0 = ap.strides[0] - } else if ap.o.isColMajor() { + + if len(ap.strides) == 1 { strides0 = ap.strides[0] } + // else if ap.o.isColMajor() { + // strides0 = ap.strides[len(ap.strides)-1] + // } return &FlatIterator{ AP: ap, @@ -147,13 +153,13 @@ func NewFlatIterator(ap *AP) *FlatIterator { strides0: strides0, isScalar: ap.IsScalar(), - isVector: ap.IsVector(), + isVector: len(ap.strides) == 1, } } // FlatIteratorFromDense creates a new FlatIterator from a dense tensor func FlatIteratorFromDense(tt DenseTensor) *FlatIterator { - return NewFlatIterator(tt.Info()) + return newFlatIterator(tt.Info()) } // SetReverse initializes iterator to run backwards @@ -200,6 +206,9 @@ func (it *FlatIterator) Next() (int, error) { if it.reverse { return it.ndPrevious() } + if it.outerFirst { + return it.colMajorNDNext() + } return it.ndNext() } } @@ -233,6 +242,11 @@ func (it *FlatIterator) NextValid() (int, int, error) { a, err := it.ndPrevious() return a, -1, err } + + if it.outerFirst { + a, err := it.colMajorNDNext() + return a, 1, err + } a, err := it.ndNext() return a, 1, err } @@ -293,7 +307,6 @@ func (it *FlatIterator) singlePrevious() (int, error) { if tracked < 0 { it.done = true } - return it.lastIndex, nil } @@ -332,7 +345,39 @@ func (it *FlatIterator) ndNext() (int, error) { } func (it *FlatIterator) colMajorNDNext() (int, error) { - return 0, nil + // the reason for this weird looking bits of code is because the SSA compiler doesn't + // know how to optimize for this bit of code, not keeping things in registers correctly + // @stuartcarnie optimized this iout to great effect + + v := len(it.shape) - 1 + nextIndex := it.nextIndex + it.lastIndex = nextIndex + + // the following 3 lines causes the compiler to perform bounds check here, + // instead of being done in the loop + coord := it.shape[:v+1] + track := it.track[:v+1] + strides := it.strides[:v+1] + for i := 0; i <= v; i++ { + track[i]++ + shapeI := coord[i] + strideI := strides[i] + + if track[i] == shapeI { + if i == v { + it.done = true + } + track[i] = 0 + + nextIndex -= (shapeI - 1) * strideI + continue + } + nextIndex += strideI + break + } + it.nextIndex = nextIndex + return it.lastIndex, nil + } func (it *FlatIterator) ndPrevious() (int, error) { @@ -353,6 +398,7 @@ func (it *FlatIterator) ndPrevious() (int, error) { return it.lastIndex, nil } +// TODO v0.9.0 func (it *FlatIterator) colMajorNDPrevious() (int, error) { return 0, nil } @@ -424,10 +470,12 @@ func (it *FlatIterator) Reset() { switch { case it.IsScalar(): it.nextIndex = 0 - case it.IsRowVec(): - it.nextIndex = (it.shape[1] - 1) * it.strides[0] - case it.IsColVec(), it.IsVector(): + case it.isVector: it.nextIndex = (it.shape[0] - 1) * it.strides[0] + // case it.IsRowVec(): + // it.nextIndex = (it.shape[1] - 1) * it.strides[1] + // case it.IsColVec(): + // it.nextIndex = (it.shape[0] - 1) * it.strides[0] default: it.nextIndex = 0 for i := range it.track { diff --git a/iterator_mult.go b/iterator_mult.go index a0af458..74f9a4b 100644 --- a/iterator_mult.go +++ b/iterator_mult.go @@ -97,16 +97,19 @@ func NewMultIterator(aps ...*AP) *MultIterator { ReturnInts(apStrides) // Borrowed in BroadcastStrides but returned here - dangerous pattern? nBlocks++ } - ap2 := NewAP(it.shape[:maxDims], it.strides[offset:offset+maxDims]) - ap2.o = ap.o - ap2.Δ = ap.Δ - + ap2 := MakeAP(it.shape[:maxDims], it.strides[offset:offset+maxDims], ap.o, ap.Δ) it.whichBlock[i] = f - it.fitArr[nBlocks-1] = NewFlatIterator(ap2) + it.fitArr[nBlocks-1] = newFlatIterator(&ap2) } it.fitArr = it.fitArr[:nBlocks] it.strides = it.strides[:nBlocks*maxDims] + // fill 0s with 1s + for i := range it.strides { + if it.strides[i] == 0 { + it.strides[i] = 1 + } + } it.fit0 = it.fitArr[0] for _, f := range it.fitArr { @@ -120,7 +123,7 @@ func NewMultIterator(aps ...*AP) *MultIterator { // MultIteratorFromDense creates a new MultIterator from a list of dense tensors func MultIteratorFromDense(tts ...DenseTensor) *MultIterator { - aps := BorrowAPList(len(tts)) + aps := make([]*AP, len(tts)) hasMask := BorrowBools(len(tts)) defer ReturnBools(hasMask) @@ -155,7 +158,6 @@ func MultIteratorFromDense(tts ...DenseTensor) *MultIterator { } } it.numMasked = numMasked - ReturnAPList(aps) return it } @@ -221,7 +223,9 @@ func (it *MultIterator) Next() (int, error) { } it.done = false for _, f := range it.fitArr { - f.Next() + if _, err := f.Next(); err != nil { + return -1, err + } it.done = it.done || f.done } for i, j := range it.whichBlock { diff --git a/iterator_test.go b/iterator_test.go index 1d7f170..d0ca6de 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -6,6 +6,12 @@ import ( "github.com/stretchr/testify/assert" ) +// newAP is a helper function now +func newAP(shape Shape, strides []int) *AP { + ap := MakeAP(shape, strides, 0, 0) + return &ap +} + var flatIterTests1 = []struct { shape Shape strides []int @@ -14,8 +20,8 @@ var flatIterTests1 = []struct { }{ {ScalarShape(), []int{}, []int{0}}, // scalar {Shape{5}, []int{1}, []int{0, 1, 2, 3, 4}}, // vector - {Shape{5, 1}, []int{1}, []int{0, 1, 2, 3, 4}}, // colvec - {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}}, // rowvec + {Shape{5, 1}, []int{1, 1}, []int{0, 1, 2, 3, 4}}, // colvec + {Shape{1, 5}, []int{5, 1}, []int{0, 1, 2, 3, 4}}, // rowvec {Shape{2, 3}, []int{3, 1}, []int{0, 1, 2, 3, 4, 5}}, // basic mat {Shape{3, 2}, []int{1, 3}, []int{0, 3, 1, 4, 2, 5}}, // basic mat, transposed {Shape{2}, []int{2}, []int{0, 2}}, // basic 2x2 mat, sliced: Mat[:, 1] @@ -27,6 +33,11 @@ var flatIterTests1 = []struct { {Shape{4, 2, 3}, []int{1, 12, 4}, []int{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}}, // basic 3-Tensor (under (2, 0, 1) transpose) {Shape{3, 2, 4}, []int{4, 12, 1}, []int{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}}, // basic 3-Tensor (under (1, 0, 2) transpose) {Shape{4, 3, 2}, []int{1, 4, 12}, []int{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}}, // basic 3-Tensor (under (2, 1, 0) transpose) + + // ARTIFICIAL CASES - TODO + // These cases should be impossible to reach in normal operation + // You would have to specially construct these + // {Shape{1, 5}, []int{1}, []int{0, 1, 2, 3, 4}}, // rowvec - NEARLY IMPOSSIBLE CASE- TODO } var flatIterSlices = []struct { @@ -49,8 +60,8 @@ func TestFlatIterator(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) for next, err := it.Next(); err == nil; next, err = it.Next() { nexts = append(nexts, next) } @@ -73,8 +84,8 @@ func TestFlatIteratorReverse(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) it.SetReverse() for next, err := it.Next(); err == nil; next, err = it.Next() { nexts = append(nexts, next) @@ -108,7 +119,7 @@ func TestMultIterator(t *testing.T) { for i, fit := range flatIterTests1 { nexts[0] = nexts[0][:0] err = nil - ap[0] = NewAP(fit.shape, fit.strides) + ap[0] = newAP(fit.shape, fit.strides) it = NewMultIterator(ap[0]) if reverse { it.SetReverse() @@ -124,43 +135,45 @@ func TestMultIterator(t *testing.T) { nexts[0][i], nexts[0][j] = nexts[0][j], nexts[0][i] } } - assert.Equal(fit.correct, nexts[0], "Repeating flat test %d", i) + assert.Equal(fit.correct, nexts[0], "Repeating flat test %d. Reverse? %v", i, reverse) } // Test multiple iterators simultaneously - var choices = []int{0, 0, 9, 9, 0, 9} - for j := 0; j < 6; j++ { - fit := flatIterTests1[choices[j]] - nexts[j] = nexts[j][:0] - err = nil - ap[j] = NewAP(fit.shape, fit.strides) - } - it = NewMultIterator(ap...) - if reverse { - it.SetReverse() - } - for _, err := it.Next(); err == nil; _, err = it.Next() { + /* + var choices = []int{0, 0, 9, 9, 0, 9} for j := 0; j < 6; j++ { - nexts[j] = append(nexts[j], it.LastIndex(j)) + fit := flatIterTests1[choices[j]] + nexts[j] = nexts[j][:0] + err = nil + ap[j] = newAP(fit.shape, fit.strides) } - - if _, ok := err.(NoOpError); err != nil && !ok { - t.Error(err) + it = NewMultIterator(ap...) + if reverse { + it.SetReverse() } - } + for _, err := it.Next(); err == nil; _, err = it.Next() { + for j := 0; j < 6; j++ { + nexts[j] = append(nexts[j], it.LastIndex(j)) + } - for j := 0; j < 6; j++ { - fit := flatIterTests1[choices[j]] - if reverse { - for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 { - nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i] + if _, ok := err.(NoOpError); err != nil && !ok { + t.Error(err) } } - if ap[j].IsScalar() { - assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j) - } else { - assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j) + + for j := 0; j < 6; j++ { + fit := flatIterTests1[choices[j]] + if reverse { + for i, k := 0, len(nexts[j])-1; i < k; i, k = i+1, k-1 { + nexts[j][i], nexts[j][k] = nexts[j][k], nexts[j][i] + } + } + if ap[j].IsScalar() { + assert.Equal(fit.correct, nexts[j][:1], "Test multiple iterators %d", j) + } else { + assert.Equal(fit.correct, nexts[j], "Test multiple iterators %d", j) + } } - } + */ } } @@ -177,7 +190,7 @@ func TestIteratorInterface(t *testing.T) { for i, fit := range flatIterTests1 { nexts = nexts[:0] err = nil - ap = NewAP(fit.shape, fit.strides) + ap = newAP(fit.shape, fit.strides) it = NewIterator(ap) for next, err := it.Start(); err == nil; next, err = it.Next() { nexts = append(nexts, next) @@ -223,8 +236,8 @@ func TestFlatIterator_Chan(t *testing.T) { // basic stuff for i, fit := range flatIterTests1 { nexts = nexts[:0] - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) ch := it.Chan() for next := range ch { nexts = append(nexts, next) @@ -242,8 +255,8 @@ func TestFlatIterator_Slice(t *testing.T) { var nexts []int for i, fit := range flatIterTests1 { - ap = NewAP(fit.shape, fit.strides) - it = NewFlatIterator(ap) + ap = newAP(fit.shape, fit.strides) + it = newFlatIterator(ap) nexts, err = it.Slice(nil) if _, ok := err.(NoOpError); err != nil && !ok { t.Error(err) @@ -276,8 +289,8 @@ func TestFlatIterator_Coord(t *testing.T) { // var nexts []int var donecount int - ap = NewAP(Shape{2, 3, 4}, []int{12, 4, 1}) - it = NewFlatIterator(ap) + ap = newAP(Shape{2, 3, 4}, []int{12, 4, 1}) + it = newFlatIterator(ap) var correct = [][]int{ {0, 0, 1}, @@ -315,8 +328,8 @@ func TestFlatIterator_Coord(t *testing.T) { // really this is just for completeness sake func TestFlatIterator_Reset(t *testing.T) { assert := assert.New(t) - ap := NewAP(Shape{2, 3, 4}, []int{12, 4, 1}) - it := NewFlatIterator(ap) + ap := newAP(Shape{2, 3, 4}, []int{12, 4, 1}) + it := newFlatIterator(ap) it.Next() it.Next() @@ -349,7 +362,7 @@ type oldFlatIterator struct { done bool } -// NewFlatIterator creates a new FlatIterator +// newFlatIterator creates a new FlatIterator func newOldFlatIterator(ap *AP) *oldFlatIterator { return &oldFlatIterator{ AP: ap, @@ -406,7 +419,7 @@ func BenchmarkOldFlatIterator(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := newOldFlatIterator(ap) for n := 0; n < b.N; n++ { @@ -426,8 +439,8 @@ func BenchmarkFlatIterator(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) - it := NewFlatIterator(ap) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + it := newFlatIterator(ap) for n := 0; n < b.N; n++ { for _, err := it.Next(); err == nil; _, err = it.Next() { @@ -450,8 +463,8 @@ func BenchmarkFlatIteratorParallel6(b *testing.B) { it := make([]*FlatIterator, 6) for j := 0; j < 6; j++ { - ap[j] = NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) - it[j] = NewFlatIterator(ap[j]) + ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + it[j] = newFlatIterator(ap[j]) } for n := 0; n < b.N; n++ { @@ -476,7 +489,7 @@ func BenchmarkFlatIteratorMulti1(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := NewMultIterator(ap) @@ -496,7 +509,7 @@ func BenchmarkFlatIteratorGeneric1(b *testing.B) { // as if T = NewTensor(WithShape(30, 1000, 1000)) // then T[:, 0:900:15, 250:750:50] - ap := NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap := newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) it := NewIterator(ap) @@ -519,7 +532,7 @@ func BenchmarkFlatIteratorMulti6(b *testing.B) { ap := make([]*AP, 6) for j := 0; j < 6; j++ { - ap[j] = NewAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) + ap[j] = newAP(Shape{30, 60, 10}, []int{1000000, 15000, 50}) } it := NewMultIterator(ap...) diff --git a/native/example_test.go b/native/example_test.go index 94b324a..740d103 100644 --- a/native/example_test.go +++ b/native/example_test.go @@ -6,8 +6,9 @@ import ( . "gorgonia.org/tensor" ) -// There are times where it is more effective to use native Go slice semantics to do work (for example, when performing batch work over kernels) -// NativeIterators are useful for this purpose. +// There are times where it is more effective to use native Go slice semantics to do work (for example, when performing batch work over kernels). +// Iterators are useful for this purpose. This package provides iterators for the standard types +// However, custom types are also available. See Vector, Matrix and Tensor3 examples. func Example_iterator() { var T *Dense T = New(WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) @@ -26,7 +27,7 @@ func Example_iterator() { } // The NativeSelect function squashes the dimensions, and returns an iterator in native Go slice semantics. -func Exampleselect() { +func Example_select() { // Selection is a bit of an interesting use case. Sometimes you don't want to iterate through the layers. // // For example, in a number of use cases where you have a 4-Tensor, you'd typically reshape it to some diff --git a/native/generic.go b/native/generic.go new file mode 100644 index 0000000..79d8dc3 --- /dev/null +++ b/native/generic.go @@ -0,0 +1,72 @@ +package native + +import ( + "reflect" + "unsafe" + + . "gorgonia.org/tensor" +) + +func Vector(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 1, t.Dtype()); err != nil { + return nil, err + } + return t.Data(), nil +} + +func Matrix(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 2, t.Dtype()); err != nil { + return nil, err + } + + shape := t.Shape() + strides := t.Strides() + typ := t.Dtype().Type + rows := shape[0] + cols := shape[1] + rowStride := strides[0] + + retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) + ptr := t.Uintptr() + for i := 0; i < rows; i++ { + e := retVal.Index(i) + sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) + sh.Data = uintptr(i*rowStride)*typ.Size() + ptr + sh.Len = cols + sh.Cap = cols + } + return retVal.Interface(), nil +} + +func Tensor3(t *Dense) (interface{}, error) { + if err := checkNativeIterable(t, 3, t.Dtype()); err != nil { + return nil, err + } + shape := t.Shape() + strides := t.Strides() + typ := t.Dtype().Type + + layers := shape[0] + rows := shape[1] + cols := shape[2] + layerStride := strides[0] + rowStride := strides[1] + retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(reflect.SliceOf(typ))), layers, layers) + ptr := t.Uintptr() + for i := 0; i < layers; i++ { + el := retVal.Index(i) + inner := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) + for j := 0; j < rows; j++ { + e := inner.Index(j) + sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) + sh.Data = uintptr(i*layerStride+j*rowStride)*typ.Size() + ptr + sh.Len = cols + sh.Cap = cols + } + sh := (*reflect.SliceHeader)(unsafe.Pointer(el.Addr().Pointer())) + sh.Data = inner.Index(0).Addr().Pointer() + sh.Len = rows + sh.Cap = rows + } + return retVal.Interface(), nil +} diff --git a/native/generic_test.go b/native/generic_test.go new file mode 100644 index 0000000..cf09802 --- /dev/null +++ b/native/generic_test.go @@ -0,0 +1,67 @@ +package native_test + +import ( + "fmt" + + "gorgonia.org/tensor" + . "gorgonia.org/tensor/native" +) + +type MyType int + +func Example_vector() { + backing := []MyType{ + 0, 1, 2, 3, + } + T := tensor.New(tensor.WithShape(4), tensor.WithBacking(backing)) + val, err := Vector(T) + if err != nil { + fmt.Printf("error: %v", err) + } + it := val.([]MyType) + fmt.Println(it) + + // Output: + // [0 1 2 3] +} + +func Example_matrix() { + backing := []MyType{ + 0, 1, + 2, 3, + 4, 5, + } + T := tensor.New(tensor.WithShape(3, 2), tensor.WithBacking(backing)) + val, err := Matrix(T) + if err != nil { + fmt.Printf("error: %v", err) + } + + it := val.([][]MyType) + fmt.Println(it) + + // Output: + // [[0 1] [2 3] [4 5]] +} + +func Example_tensor3() { + backing := []MyType{ + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + } + T := tensor.New(tensor.WithShape(2, 3, 4), tensor.WithBacking(backing)) + val, err := Tensor3(T) + if err != nil { + fmt.Printf("error: %v", err) + } + it := val.([][][]MyType) + fmt.Println(it) + + //Output: + // [[[0 1 2 3] [4 5 6 7] [8 9 10 11]] [[12 13 14 15] [16 17 18 19] [20 21 22 23]]] +} diff --git a/perf.go b/perf.go index 573d8be..2d20df2 100644 --- a/perf.go +++ b/perf.go @@ -83,19 +83,13 @@ func ReturnTensor(t Tensor) { } switch tt := t.(type) { case *Dense: - if tt.old != nil { - ReturnAP(tt.old) - tt.old = nil - } + tt.AP.zero() if tt.transposeWith != nil { ReturnInts(tt.transposeWith) tt.transposeWith = nil } - // return AP - ReturnAP(tt.AP) - // array reset tt.t = Dtype{} tt.array.Ptr = nil @@ -109,7 +103,7 @@ func ReturnTensor(t Tensor) { tt.flag = 0 // other reset - tt.old = nil + tt.old.zero() tt.viewOf = 0 tt.transposeWith = nil @@ -124,63 +118,14 @@ func ReturnTensor(t Tensor) { } } -/* AP POOL */ - -var apPool = make(chan *AP, PoolSize) - -func borrowAP() *AP { - select { - case ap := <-apPool: - return ap - default: - return new(AP) - } - // return apPool.Get().(*AP) -} - -// BorrowAP gets an AP from the pool. USE WITH CAUTION. -func BorrowAP(dims int) *AP { - ap := borrowAP() - ap.shape = BorrowInts(dims) - ap.strides = BorrowInts(dims) - ap.shape = ap.shape[:cap(ap.shape)] - ap.strides = ap.strides[:cap(ap.strides)] - return ap -} - -// ReturnAP returns the AP to the pool. USE WITH CAUTION. -func ReturnAP(ap *AP) { - ReturnInts([]int(ap.shape)) - ReturnInts(ap.strides) - ap.fin = false - - ap.o = 0 - ap.Δ = 0 - - if len(apPool) < cap(apPool) { - apPool <- ap - } - // apPool.Put(ap) -} - /* ---------------------------------------------------------------- ------------------ Create Pools ------------------------------------------------------------------*/ /* APLIST POOL */ -var apListPool [maxAPDims]sync.Pool - // Init function func init() { - for i := range apListPool { - size := i - apListPool[i].New = func() interface{} { return make([]*AP, size) } - } - - // for i := 0; i < PoolSize; i++ { - // intsPool <- make([]int, 8, 8) - // } for i := range intsPool { size := i @@ -222,11 +167,13 @@ func BorrowInts(size int) []int { if retVal == nil { return make([]int, size) } + // log.Printf("Borrowing %p. Called by %v", retVal, string(debug.Stack())) return retVal.([]int)[:size] } // ReturnInts returns a slice from the pool. USE WITH CAUTION. func ReturnInts(is []int) { + // log.Printf("Returning %p. Called by %v", is, string(debug.Stack())) if is == nil { return } @@ -293,36 +240,6 @@ func ReturnBools(is []bool) { // boolsPool[size].Put(is) } -// BorrowAPList gets an APList from the pool. USE WITH CAUTION. -func BorrowAPList(size int) []*AP { - if size >= 8 { - return make([]*AP, size) - } - - retVal := apListPool[size].Get() - if retVal == nil { - return make([]*AP, size) - } - return retVal.([]*AP) -} - -// ReturnAPList returns the APList to the pool. USE WITH CAUTION. -func ReturnAPList(aps []*AP) { - if aps == nil { - return - } - size := cap(aps) - if size >= 8 { - return - } - aps = aps[:cap(aps)] - for i := range aps { - aps[i] = nil - } - - apListPool[size].Put(aps) -} - // var optPool = make(chan *OpOpt, PoolSize) // var optPool = newRingbuffer(PoolSize) var optPool = &sync.Pool{ diff --git a/shape.go b/shape.go index ba0b18f..cecb41d 100644 --- a/shape.go +++ b/shape.go @@ -24,17 +24,18 @@ func (s Shape) TotalSize() int { return ProdInts([]int(s)) } -func (s Shape) calcStrides() []int { +// CalcStrides calculates the default strides for a shape +func (s Shape) CalcStrides() []int { if s.IsScalar() { return nil } retVal := BorrowInts(len(s)) - if s.IsVector() { - retVal[0] = 1 - retVal = retVal[:1] - return retVal - } + // if s.IsVector() { + // retVal[0] = 1 + // retVal = retVal[:1] + // return retVal + // } acc := 1 for i := len(s) - 1; i >= 0; i-- { @@ -48,9 +49,9 @@ func (s Shape) calcStrides() []int { return retVal } -// calcStridesWithMask is similar to calcStrides, except that it has an argument, masks. It is used to mask out given dimensions +// CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. It is used to mask out given dimensions // during calculation of stride -func (s Shape) calcStridesWithMask(mask []bool) []int { +func (s Shape) CalcStridesWithMask(mask []bool) []int { if s.IsScalar() { return nil } @@ -84,7 +85,8 @@ func (s Shape) calcStridesWithMask(mask []bool) []int { return retVal } -func (s Shape) calcStridesColMajor() []int { +// CalcStridesColMajor is like CalcStrides, but assumes a col major layout +func (s Shape) CalcStridesColMajor() []int { if s.IsScalar() { return nil } @@ -152,7 +154,23 @@ func (s Shape) Clone() Shape { } // IsScalar returns true if the access pattern indicates it's a scalar value -func (s Shape) IsScalar() bool { return len(s) == 0 || (len(s) == 1 && s[0] == 1) } +func (s Shape) IsScalar() bool { + return len(s) == 0 || (len(s) == 1 && s[0] == 1) +} + +// IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value +func (s Shape) IsScalarEquiv() bool { + if len(s) == 0 { + return true + } + isEquiv := true + for i := range s { + if s[i] != 1 { + return false + } + } + return isEquiv +} // IsVector returns whether the access pattern falls into one of three possible definitions of vectors: // vanilla vector (not a row or a col) @@ -172,6 +190,9 @@ func (s Shape) IsMatrix() bool { return len(s) == 2 } // Dims returns the number of dimensions in the shape func (s Shape) Dims() int { return len(s) } +// DimSize returns the size of the dimension wanted. +// +// This method implemnents the DimSizer interface in Gorgonia. func (s Shape) DimSize(d int) (size int, err error) { if (s.IsScalar() && d != 0) || (!s.IsScalar() && d >= len(s)) { err = errors.Errorf(dimMismatch, len(s), d) @@ -221,12 +242,14 @@ func (s Shape) S(slices ...Slice) (retVal Shape, err error) { } // drop any dimension with size 1, except the last dimension + offset := 0 dims := s.Dims() for d := 0; d < dims; d++ { - if retVal[d] == 1 /*&& d != t.dims-1 && dims > 2*/ { + if retVal[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil /*&& d != t.dims-1 && dims > 2*/ { retVal = append(retVal[:d], retVal[d+1:]...) d-- dims-- + offset++ } } @@ -326,7 +349,7 @@ func (s Shape) Concat(axis int, ss ...Shape) (newShape Shape, err error) { } else { // validate that the rest of the dimensions match up if newShape[d] != shp[d] { - err = errors.Errorf(dimMismatch, newShape[d], shp[d]) + err = errors.Wrapf(errors.Errorf(dimMismatch, newShape[d], shp[d]), "Axis: %d, dimension it failed at: %d", axis, d) return } } diff --git a/shape_test.go b/shape_test.go index 51fe64a..aa9e1de 100644 --- a/shape_test.go +++ b/shape_test.go @@ -90,36 +90,36 @@ func TestShapeCalcStride(t *testing.T) { // scalar shape s = Shape{} - assert.Nil(s.calcStrides()) + assert.Nil(s.CalcStrides()) s = Shape{1} - assert.Nil(s.calcStrides()) + assert.Nil(s.CalcStrides()) // vector shape s = Shape{2, 1} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{1, 1}, s.CalcStrides()) s = Shape{1, 2} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{2, 1}, s.CalcStrides()) s = Shape{2} - assert.Equal([]int{1}, s.calcStrides()) + assert.Equal([]int{1}, s.CalcStrides()) // matrix strides s = Shape{2, 2} - assert.Equal([]int{2, 1}, s.calcStrides()) + assert.Equal([]int{2, 1}, s.CalcStrides()) s = Shape{5, 2} - assert.Equal([]int{2, 1}, s.calcStrides()) + assert.Equal([]int{2, 1}, s.CalcStrides()) // 3D strides s = Shape{2, 3, 4} - assert.Equal([]int{12, 4, 1}, s.calcStrides()) + assert.Equal([]int{12, 4, 1}, s.CalcStrides()) // stupid shape s = Shape{-2, 1, 2} fail := func() { - s.calcStrides() + s.CalcStrides() } assert.Panics(fail) } @@ -191,6 +191,10 @@ var shapeSliceTests = []struct { {"vec[3]", Shape{2}, []Slice{rs{3, 4, 0}}, nil, true}, {"vec[:, 0]", Shape{2}, []Slice{nil, rs{0, 1, 0}}, nil, true}, {"vec[1:4:2]", Shape{5}, []Slice{rs{1, 4, 2}}, ScalarShape(), false}, + {"tensor[0, :, :]", Shape{1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil}, Shape{2, 2}, false}, + {"tensor[:, 0, :]", Shape{1, 2, 2}, []Slice{nil, rs{0, 1, 1}, nil}, Shape{1, 2}, false}, + {"tensor[0, :, :, :]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}, nil, nil, nil}, Shape{1, 2, 2}, false}, + {"tensor[0,]", Shape{1, 1, 2, 2}, []Slice{rs{0, 1, 1}}, Shape{1, 2, 2}, false}, } func TestShape_Slice(t *testing.T) { diff --git a/sparse.go b/sparse.go index abb36c1..5de67d4 100644 --- a/sparse.go +++ b/sparse.go @@ -31,7 +31,7 @@ type coo struct { func (c *coo) Len() int { return c.data.L } func (c *coo) Less(i, j int) bool { - if c.o.isColMajor() { + if c.o.IsColMajor() { return c.colMajorLess(i, j) } return c.rowMajorLess(i, j) @@ -182,13 +182,14 @@ func CSCFromCoord(shape Shape, xs, ys []int, data interface{}) *CS { return t } -func (t *CS) Shape() Shape { return t.s } -func (t *CS) Strides() []int { return nil } -func (t *CS) Dtype() Dtype { return t.t } -func (t *CS) Dims() int { return 2 } -func (t *CS) Size() int { return t.s.TotalSize() } -func (t *CS) DataSize() int { return t.L } -func (t *CS) Engine() Engine { return t.e } +func (t *CS) Shape() Shape { return t.s } +func (t *CS) Strides() []int { return nil } +func (t *CS) Dtype() Dtype { return t.t } +func (t *CS) Dims() int { return 2 } +func (t *CS) Size() int { return t.s.TotalSize() } +func (t *CS) DataSize() int { return t.L } +func (t *CS) Engine() Engine { return t.e } +func (t *CS) DataOrder() DataOrder { return t.o } func (t *CS) Slice(...Slice) (View, error) { return nil, errors.Errorf("Slice for sparse tensors not implemented yet") @@ -232,11 +233,12 @@ func (t *CS) T(axes ...int) error { } UnsafePermute(axes, []int(t.s)) t.o = t.o.toggleColMajor() + t.o = MakeDataOrder(t.o, Transposed) return errors.Errorf(methodNYI, "T") } // UT untransposes the CS -func (t *CS) UT() { t.T() } +func (t *CS) UT() { t.T(); t.o = t.o.clearTransposed() } // Transpose is a no-op. The data does not move func (t *CS) Transpose() error { return nil } @@ -307,7 +309,7 @@ func (t *CS) Iterator() Iterator { return NewFlatSparseIterator(t) } func (t *CS) at(coord ...int) (int, bool) { var r, c int - if t.o.isColMajor() { + if t.o.IsColMajor() { r = coord[1] c = coord[0] } else { @@ -330,7 +332,7 @@ func (t *CS) Dense() *Dense { } d := recycledDense(t.t, t.Shape().Clone()) - if t.o.isColMajor() { + if t.o.IsColMajor() { for i := 0; i < len(t.indptr)-1; i++ { for j := t.indptr[i]; j < t.indptr[i+1]; j++ { d.SetAt(t.Get(j), t.indices[j], i) @@ -361,14 +363,14 @@ func (t *CS) Indices() []int { } func (t *CS) AsCSR() { - if t.o.isRowMajor() { + if t.o.IsRowMajor() { return } t.o.toggleColMajor() } func (t *CS) AsCSC() { - if t.o.isColMajor() { + if t.o.IsColMajor() { return } t.o.toggleColMajor() diff --git a/tensor.go b/tensor.go index d1b348a..a04e425 100644 --- a/tensor.go +++ b/tensor.go @@ -36,6 +36,7 @@ type Tensor interface { // Data access related RequiresIterator() bool Iterator() Iterator + DataOrder() DataOrder // ops Slicer @@ -86,7 +87,6 @@ type Tensor interface { // New creates a new Dense Tensor. For sparse arrays use their relevant construction function func New(opts ...ConsOpt) *Dense { d := borrowDense() - d.AP = new(AP) for _, opt := range opts { opt(d) } diff --git a/testutils_test.go b/testutils_test.go index e219ab1..71a43a4 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -240,6 +240,7 @@ func allClose(a, b interface{}, approxFn ...interface{}) bool { return reflect.DeepEqual(a, b) } } + func checkErr(t *testing.T, expected bool, err error, name string, id interface{}) (cont bool) { switch { case expected: diff --git a/types.go b/types.go index fd8e189..69740cf 100644 --- a/types.go +++ b/types.go @@ -299,6 +299,7 @@ func RegisterFloat(a Dtype) { RegisterOrd(a) } +// RegisterOrd registers a dtype as a type that can be typed func RegisterOrd(a Dtype) { for _, dt := range ordTypes.set { if dt == a { @@ -306,8 +307,10 @@ func RegisterOrd(a Dtype) { } } ordTypes.set = append(ordTypes.set, a) + RegisterEq(a) } +// RegisterEq registers a dtype as a type that can be compared for equality func RegisterEq(a Dtype) { for _, dt := range eqTypes.set { if dt == a { @@ -315,6 +318,26 @@ func RegisterEq(a Dtype) { } } eqTypes.set = append(eqTypes.set, a) + Register(a) +} + +// Register registers a new Dtype +func Register(a Dtype) { + for _, dt := range allTypes.set { + if a == dt { + return + } + } + allTypes.set = append(allTypes.set, a) +} + +func dtypeID(a Dtype) int { + for i, v := range allTypes.set { + if a == v { + return i + } + } + return -1 } // NormOrder represents the order of the norm. Ideally, we'd only represent norms with a uint/byte. diff --git a/utils.go b/utils.go index 9dcd936..8e62448 100644 --- a/utils.go +++ b/utils.go @@ -213,7 +213,6 @@ func UnsafePermute(pattern []int, xs ...[]int) (err error) { return nil } - // CheckSlice checks a slice to see if it's sane func CheckSlice(s Slice, size int) error { start := s.Start() @@ -282,9 +281,8 @@ func reuseCheckShape(reuse DenseTensor, s Shape) (err error) { } // clean up any funny things that may be in the reuse - if oldAP := reuse.oldAP(); oldAP != nil { - ReturnAP(oldAP) - reuse.setOldAP(nil) + if oldAP := reuse.oldAP(); !oldAP.IsZero() { + oldAP.zero() } if axes := reuse.transposeAxes(); axes != nil { @@ -309,7 +307,6 @@ func memsetBools(a []bool, v bool) { } } - /* FOR ILLUSTRATIVE PURPOSES */ // Permute permutates a pattern according to xs. This function exists for illustrative purposes (i.e. the dumb, unoptimized version) @@ -385,4 +382,4 @@ func Permute(pattern []int, xs ...[]int) (retVal [][]int, err error) { } return } -*/ \ No newline at end of file +*/