-
-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathap.go
447 lines (379 loc) · 11.7 KB
/
ap.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
package tensor
import (
"fmt"
"github.com/pkg/errors"
)
// AP is an access pattern. It tells the various ndarrays how to access their
// data through the use of strides.
//
// Through the AP there are several definitions of things; most notably there
// are two very specific "special cases":
//
// Scalar has Dims() of 0.
//   - (1)
//
// Scalarlikes are higher order tensors, but each with a size of 1. The Dims() are not 0.
//   - (1, 1)
//   - (1, 1, 1)
//   - (1, 1, 1, 1), etc
//
// Vector has Dims() of 1, but its shape can take several forms:
//   - (x, 1)
//   - (1, x)
//   - (x)
//
// Matrix has Dims() of 2. This is the most basic form. The len(shape) has to be equal to 2 as well.
//
// ndarray has Dims() of n.
type AP struct {
shape Shape // len(shape) is the operational definition of the dimensions
strides []int // strides is usually calculated from shape
fin bool // is this struct change-proof? When true, SetShape does nothing.
o DataOrder // data order flags; reported by DataOrder and consulted by C, F and calcStrides
Δ Triangle // triangle flag; carried over unchanged by Clone, CloneTo, S and T
}
// makeAP returns an unlocked AP whose shape and strides are each obtained from
// BorrowInts(size). The data order and triangle fields are left at their zero
// values.
func makeAP(size int) AP {
	var ap AP
	ap.shape = Shape(BorrowInts(size))
	ap.strides = BorrowInts(size)
	return ap
}
// MakeAP builds an AP from the given shape, strides, data order and triangle
// flag. The shape and strides slices are stored as passed (not copied), and the
// resulting AP is locked.
func MakeAP(shape Shape, strides []int, o DataOrder, Δ Triangle) AP {
	var ap AP
	ap.shape, ap.strides = shape, strides
	ap.o, ap.Δ = o, Δ
	ap.fin = true
	return ap
}
// Init initializes an already created AP with a shape and strides, and locks
// it. The slices are stored as passed, not copied.
// It will panic if ap is nil.
func (ap *AP) Init(shape Shape, strides []int) {
	ap.shape, ap.strides, ap.fin = shape, strides, true
}
// SetShape is for very specific times when modifying the AP is necessary, such
// as reshaping and doing I/O related stuff.
//
// Caveats:
//
//   - SetShape will recalculate the strides.
//
//   - If the AP is locked, nothing will happen.
//
//   - Called with no arguments, the shape and strides are truncated to length
//     zero instead of being replaced.
func (ap *AP) SetShape(s ...int) {
	if ap.fin {
		return
	}

	// Scalars are a special case: keep the existing storage around and only
	// truncate it, rather than removing it completely.
	if len(s) == 0 {
		if ap.shape == nil || ap.strides == nil {
			ap.shape = Shape{}
		}
		ap.shape = ap.shape[:0]
		ap.strides = ap.strides[:0]
		return
	}

	// Hand the old slices back before installing the new ones.
	if ap.shape != nil {
		ReturnInts(ap.shape)
		ap.shape = nil
	}
	if ap.strides != nil {
		ReturnInts(ap.strides)
		ap.strides = nil
	}

	ap.shape = Shape(s).Clone()
	ap.strides = ap.calcStrides()
}
// Shape reports the shape held by the AP. The slice is returned as stored,
// not copied.
func (ap *AP) Shape() Shape {
	return ap.shape
}
// Strides reports the strides held by the AP. The slice is returned as
// stored, not copied.
func (ap *AP) Strides() []int {
	return ap.strides
}
// Dims reports the number of dimensions, as given by the shape's Dims method.
func (ap *AP) Dims() int {
	return ap.shape.Dims()
}
// Size reports the expected array size for the shape, as given by the shape's
// TotalSize method.
func (ap *AP) Size() int {
	return ap.shape.TotalSize()
}
// String implements fmt.Stringer. The text is produced by the AP's Format
// method, which fmt invokes for the default verb.
func (ap *AP) String() string {
	return fmt.Sprint(ap)
}
// Format implements fmt.Formatter. The verb c is ignored: every verb renders
// the shape, the strides and the lock flag in the same layout.
func (ap *AP) Format(state fmt.State, c rune) {
	const layout = "Shape: %v, Stride: %v, Lock: %t"
	fmt.Fprintf(state, layout, ap.shape, ap.strides, ap.fin)
}
// IsVector returns whether the access pattern falls into one of three possible
// definitions of vectors:
//   - vanilla vector (not a row or a col)
//   - column vector
//   - row vector
//
// The decision is delegated to the shape.
func (ap *AP) IsVector() bool {
	return ap.shape.IsVector()
}
// IsVectorLike returns true if the shape is vector-like (i.e. the shape only
// has one dim that is a non-1) and, in addition, allones reports true for the
// strides.
func (ap *AP) IsVectorLike() bool {
	if !ap.shape.IsVectorLike() {
		return false
	}
	return allones(ap.strides)
}
// IsColVec returns true when the access pattern has the shape (x, 1), as
// decided by the shape's IsColVec method.
func (ap *AP) IsColVec() bool {
	return ap.shape.IsColVec()
}
// IsRowVec returns true when the access pattern has the shape (1, x), as
// decided by the shape's IsRowVec method.
func (ap *AP) IsRowVec() bool {
	return ap.shape.IsRowVec()
}
// IsScalar returns true if the access pattern indicates a scalar value, as
// decided by the shape's IsScalar method.
func (ap *AP) IsScalar() bool {
	return ap.shape.IsScalar()
}
// IsScalarEquiv returns true if the access pattern is equivalent to a scalar
// shape, as decided by the shape's IsScalarEquiv method.
func (ap *AP) IsScalarEquiv() bool {
	return ap.shape.IsScalarEquiv()
}
// IsMatrix returns true if the shape has exactly two entries. This is mostly
// a convenience method. Row vectors and column vectors are also considered
// matrices.
func (ap *AP) IsMatrix() bool {
	const matrixDims = 2
	return len(ap.shape) == matrixDims
}
// IsZero tells us whether the AP is in its zero state: empty shape, empty
// strides, unlocked, and with zero-valued data order and triangle flags.
func (ap *AP) IsZero() bool {
	if len(ap.shape) != 0 || len(ap.strides) != 0 {
		return false
	}
	if ap.fin {
		return false
	}
	return ap.o == 0 && ap.Δ == 0
}
// zero zeros out an AP, handing its shape and strides to ReturnInts first.
//
// Jorge's original implementation truncated the slices in place
// (ap.shape = ap.shape[:0]; ap.strides = ap.strides[:0]), which would be a lot
// more efficient. However, to cater for the (*Dense).fix() method, a nil shape
// is used to signal unsetness, so the slices cannot merely be truncated.
func (ap *AP) zero() {
	shape, strides := []int(ap.shape), ap.strides
	ReturnInts(shape)
	ReturnInts(strides)
	ap.zeroOnly()
}
// zeroOnly resets every field of the AP to its zero value. Unlike zero, it
// does not pass the old shape and strides to ReturnInts.
func (ap *AP) zeroOnly() {
	*ap = AP{}
}
// zeroWithDims gives the AP a shape and strides of length dims with every
// entry set to zero. The lock, data order and triangle fields are not touched.
//
// Existing storage is reused when it is non-nil and has enough capacity;
// otherwise a slice is obtained from BorrowInts. Previously the reuse branch
// was immediately overwritten by an unconditional BorrowInts call, so it never
// took effect.
func (ap *AP) zeroWithDims(dims int) {
	// A nil slice is never "reused": callers treat a nil shape as unset, so
	// this path must keep producing whatever BorrowInts hands out.
	if ap.shape != nil && cap(ap.shape) >= dims {
		ap.shape = ap.shape[:dims]
		for i := range ap.shape {
			ap.shape[i] = 0
		}
	} else {
		ap.shape = BorrowInts(dims)
	}

	if ap.strides != nil && cap(ap.strides) >= dims {
		ap.strides = ap.strides[:dims]
		for i := range ap.strides {
			ap.strides[i] = 0
		}
	} else {
		ap.strides = BorrowInts(dims)
	}
}
// Clone returns a copy of the AP. The shape and strides are copied into
// storage obtained through makeAP (sized by the capacity of the receiver's
// shape); the lock, data order and triangle fields are carried over as they are.
func (ap *AP) Clone() (retVal AP) {
	retVal = makeAP(cap(ap.shape))
	retVal.fin, retVal.o, retVal.Δ = ap.fin, ap.o, ap.Δ

	// The new storage is sized by capacity, so trim both slices back to the
	// lengths of the originals after copying (this handles vectors).
	copy(retVal.shape, ap.shape)
	retVal.shape = retVal.shape[:len(ap.shape)]

	copy(retVal.strides, ap.strides)
	retVal.strides = retVal.strides[:len(ap.strides)]

	return retVal
}
// CloneTo copies the receiver into dest. dest's existing shape and strides
// storage is reused where append allows it; the lock, data order and triangle
// fields are overwritten with the receiver's values.
func (ap *AP) CloneTo(dest *AP) {
	shape := append(dest.shape[:0], ap.shape...)
	strides := append(dest.strides[:0], ap.strides...)

	dest.shape, dest.strides = shape, strides
	dest.fin, dest.o, dest.Δ = ap.fin, ap.o, ap.Δ
}
// DataOrder reports the data order stored in the AP.
func (ap *AP) DataOrder() DataOrder {
	return ap.o
}
// C returns true if the access pattern describes a C-contiguous array, i.e.
// its data order is both row-major and contiguous.
func (ap *AP) C() bool {
	if !ap.o.IsRowMajor() {
		return false
	}
	return ap.o.IsContiguous()
}
// F returns true if the access pattern describes a Fortran-contiguous array,
// i.e. its data order is both column-major and contiguous.
func (ap *AP) F() bool {
	if !ap.o.IsColMajor() {
		return false
	}
	return ap.o.IsContiguous()
}
// S returns the metadata of the sliced tensor.
//
// size is the starting value for ndEnd (presumably the length of the
// underlying data — confirm against callers). slices holds at most one Slice
// per dimension; a missing or nil entry is passed to SliceDetails as nil.
//
// It returns the access pattern of the slice together with ndStart and ndEnd,
// the offsets accumulated from each dimension's start/end and stride. An error
// is returned when more slices than dimensions are supplied, or when
// SliceDetails rejects one of the slices.
func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err error) {
	if len(slices) > len(ap.shape) {
		// more slices than there are dimensions
		err = errors.Errorf(dimMismatch, len(ap.shape), len(slices))
		return
	}

	ndEnd = size
	newShape := ap.shape.Clone()   // the new shape
	dims := ap.Dims()              // reported dimensions
	newStrides := BorrowInts(dims) // the new strides

	var outerDim int
	order := ap.o
	if ap.o.IsRowMajor() || ap.IsVector() {
		outerDim = 0
	} else {
		outerDim = len(ap.shape) - 1
	}

	for i := 0; i < dims; i++ {
		var sl Slice
		if i <= len(slices)-1 {
			sl = slices[i]
		}

		// dimSize is deliberately not called size, so that it does not shadow
		// the size parameter.
		dimSize := ap.shape[i]
		stride := ap.strides[i]

		var start, end, step int
		if start, end, step, err = SliceDetails(sl, dimSize); err != nil {
			// Argument order follows the format: index, dimension size, slice.
			err = errors.Wrapf(err, "Unable to get slice details on slice %d with size %d: %v", i, dimSize, sl)
			return
		}

		// a slice where start == end is []
		ndStart = ndStart + start*stride
		ndEnd = ndEnd - (dimSize-end)*stride

		if step > 0 {
			// Round the element count up when the step does not divide the
			// range evenly (only for dimensions after the first).
			if newShape[i] = (end - start) / step; (end-start)%step > 0 && i > 0 {
				newShape[i]++
			}
			newStrides[i] = stride * step

			// never report a dimension smaller than 1
			if newShape[i] <= 0 {
				newShape[i] = 1
			}
		} else {
			newShape[i] = (end - start)
			newStrides[i] = stride
		}

		// Slicing any dimension other than the outer one of a non-vector, or
		// stepping by more than one, marks the result as non-contiguous.
		if (sl != nil && (!ap.IsVector() && i != outerDim)) || step > 1 {
			order = MakeDataOrder(order, NonContiguous)
		}
	}

	if ndEnd-ndStart == 1 {
		// scalars are a special case
		newAP = AP{}
		newAP.SetShape() // make it a Scalar
		newAP.lock()
	} else {
		// Drop every dimension that an explicit (non-nil) slice reduced to
		// size 1. offset keeps slices aligned with the shrinking newShape.
		offset := 0
		for d := 0; d < dims; d++ {
			if newShape[d] == 1 && offset+d <= len(slices)-1 && slices[offset+d] != nil {
				newShape = append(newShape[:d], newShape[d+1:]...)
				newStrides = append(newStrides[:d], newStrides[d+1:]...)
				d--
				dims--
				offset++
			}
		}

		newAP = MakeAP(newShape, newStrides, order, ap.Δ)
	}
	return
}
// T returns the transposed metadata based on the given input.
//
// axes is the requested permutation of dimensions; when none are given the
// dimension order is reversed. The axes actually used are returned as a.
//
// A dimension-mismatch error is returned when axes is non-empty but does not
// name every dimension. When the shape is scalar-equivalent, or the axes are
// 0, 1, 2, ..., a clone of the receiver is returned together with a noopError.
func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) {
	// an explicit permutation has to cover every dimension
	if len(axes) > 0 && len(axes) != ap.Dims() {
		err = errors.Errorf(dimMismatch, ap.Dims(), len(axes))
		return
	}

	// no axes supplied: default to reversing the dimensions
	if len(axes) == 0 {
		n := len(ap.shape)
		axes = make([]int, n)
		for i := range axes {
			axes[i] = n - 1 - i
		}
	}
	a = axes

	if ap.shape.IsScalarEquiv() {
		return ap.Clone(), a, noopError{}
	}

	// if axes is 0, 1, 2, 3... then no op
	if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 {
		return ap.Clone(), a, noopError{}
	}

	shape := make(Shape, len(ap.shape))
	strides := make([]int, len(ap.strides))

	if ap.IsScalar() {
		return
	}
	if ap.IsVector() {
		if axes[0] == 0 {
			return
		}
		// swap the two extents; both strides become 1
		strides[0], strides[1] = 1, 1
		shape[0], shape[1] = ap.shape[1], ap.shape[0]
	} else {
		copy(shape, ap.shape)
		copy(strides, ap.strides)
		if err = UnsafePermute(axes, shape, strides); err != nil {
			err = handleNoOp(err)
		}
	}

	retVal = MakeAP(shape, strides, MakeDataOrder(ap.o, Transposed), ap.Δ)
	retVal.lock()
	return
}
// lock marks the AP as locked, which makes SetShape a no-op.
//
// Locking and unlocking are used to ensure that the shape and strides don't
// change. It's not really safe though: a direct mutation of the strides/shape
// would still mutate them, but at least the dimensions cannot change.
func (ap *AP) lock() {
	ap.fin = true
}
// unlock clears the lock flag so that SetShape may modify the AP again.
func (ap *AP) unlock() {
	ap.fin = false
}
// calcStrides computes strides for the current shape according to the AP's
// data order: CalcStrides for row-major, CalcStridesColMajor for col-major.
// It panics if the data order reports neither.
func (ap *AP) calcStrides() []int {
	if ap.o.IsRowMajor() {
		return ap.shape.CalcStrides()
	}
	if ap.o.IsColMajor() {
		return ap.shape.CalcStridesColMajor()
	}
	panic("unreachable")
}
// setDataOrder is a method such that any tensor that embeds *AP will have the
// same method. When o does not have the same order as the AP's current data
// order, the AP's order is toggled via toggleColMajor; otherwise nothing
// changes.
func (ap *AP) setDataOrder(o DataOrder) {
	if o.HasSameOrder(ap.o) {
		return
	}
	ap.o = ap.o.toggleColMajor()
}
// TransposeIndex returns the new index given the old index.
//
// The old index i is converted to coordinates with Itol (using oldShape and
// oldStrides); the coordinates are then permuted by pattern and folded with
// newStrides. It panics if Itol returns an error.
//
// Conceptually this is Permute followed by Ltoi:
//
//	coordss, _ := Permute(pattern, oldCoord)
//	coords := coordss[0]
//	index, _ := Ltoi(newShape, strides, coords...)
//
// but too many checks slow things down, so the permutation and the dot
// product with the new strides are fused into a single loop.
func TransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int {
	coord, err := Itol(i, oldShape, oldStrides)
	if err != nil {
		panic(err) // or return error?
	}

	index := 0
	for dst := range pattern {
		index += coord[pattern[dst]] * newStrides[dst]
	}
	return index
}
// UntransposeIndex returns the old index given the new index. It inverts
// pattern and delegates to TransposeIndex with the inverted pattern.
func UntransposeIndex(i int, oldShape, pattern, oldStrides, newStrides []int) int {
	inverse := make([]int, len(pattern))
	for pos := 0; pos < len(pattern); pos++ {
		inverse[pattern[pos]] = pos
	}
	return TransposeIndex(i, oldShape, inverse, oldStrides, newStrides)
}
// BroadcastStrides handles broadcasting from different shapes.
//
// It computes the strides with which data of shape srcShape (with strides
// srcStrides) is read as if it had shape destShape. Dimensions are matched
// from the last one backwards: a source dimension of size 1 gets stride 0, a
// source dimension equal to the destination keeps its stride, and any other
// combination is an error. Leading destination dimensions with no source
// counterpart get stride 0.
//
// When both shapes are vectors, a single-element slice holding srcStrides[0]
// is returned without further checks.
//
// On error the returned slice is nil.
//
// Deprecated: this function will be unexported
func BroadcastStrides(destShape, srcShape Shape, destStrides, srcStrides []int) (retVal []int, err error) {
	dims := len(destShape)
	start := dims - len(srcShape)

	if destShape.IsVector() && srcShape.IsVector() {
		return []int{srcStrides[0]}, nil
	}

	if start < 0 {
		// the source has more dimensions than the destination
		err = errors.Errorf(dimMismatch, dims, len(srcShape))
		return
	}

	retVal = BorrowInts(len(destStrides))
	for i := dims - 1; i >= start; i-- {
		s := srcShape[i-start]
		switch {
		case s == 1:
			retVal[i] = 0
		case s != destShape[i]:
			// Incompatible extents: give the scratch slice back instead of
			// returning a half-filled result alongside the error.
			ReturnInts(retVal)
			return nil, errors.Errorf("Cannot broadcast from %v to %v", srcShape, destShape)
		default:
			retVal[i] = srcStrides[i-start]
		}
	}
	for i := 0; i < start; i++ {
		retVal[i] = 0
	}
	return
}