-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmnist.go
94 lines (84 loc) · 2.68 KB
/
mnist.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// Package mnist provides utilities for reading the MNIST dataset.
//
// For a description of this data format, see the documentation here:
// http://yann.lecun.com/exdb/mnist/
package mnist
import (
	"fmt"
	"path/filepath"

	"github.com/kujenga/goml/idx"
	"github.com/kujenga/goml/lin"
)
// MNIST holds the dataset after parsing, with one field for each of the four
// files it is distributed as.
type MNIST struct {
	// TrainInputs has one row per training image. Read fills each row with
	// 28*28 flattened pixels, each raw value v stored as v/255*0.99 + 0.01.
	TrainInputs lin.Frame
	// TrainLabels has one row per training label. Read fills each row with
	// 10 entries: 0.99 at the index of the digit and 0.01 everywhere else.
	TrainLabels lin.Frame
	// TestInputs has one row per test image, in the same layout and scaling
	// as TrainInputs.
	TestInputs lin.Frame
	// TestLabels has one row per test label, in the same encoding as
	// TrainLabels.
	TestLabels lin.Frame
}
// Read returns a parsed MNIST dataset from the indicated root directory. It
// expects the four gzip-compressed IDX files to be present in that directory
// under the names used by the original distribution:
// http://yann.lecun.com/exdb/mnist/
//
// Every image becomes a row of 28*28 values, where a raw pixel value v is
// stored as v/255*0.99 + 0.01. Every label becomes a row of 10 values holding
// 0.99 at the index of the label and 0.01 everywhere else.
//
// An error returned while reading any of the files is passed through
// unchanged. Read also returns an error, rather than panicking, when a file
// declares no dimensions, when it holds fewer values than its first dimension
// requires, or when a label falls outside 0-9. On any error the returned
// *MNIST is nil.
func Read(rootDir string) (*MNIST, error) {
	const (
		// imagePixels is the number of values in one flattened 28x28 image.
		imagePixels = 28 * 28
		// numClasses is the length of a one-hot label row (digits 0-9).
		numClasses = 10

		trainImagesFile = "train-images-idx3-ubyte.gz"
		trainLabelsFile = "train-labels-idx1-ubyte.gz"
		testImagesFile  = "t10k-images-idx3-ubyte.gz"
		testLabelsFile  = "t10k-labels-idx1-ubyte.gz"
	)

	trainRawImages, err := idx.ReadFile(filepath.Join(rootDir, trainImagesFile))
	if err != nil {
		return nil, err
	}
	trainRawLabels, err := idx.ReadFile(filepath.Join(rootDir, trainLabelsFile))
	if err != nil {
		return nil, err
	}
	testRawImages, err := idx.ReadFile(filepath.Join(rootDir, testImagesFile))
	if err != nil {
		return nil, err
	}
	testRawLabels, err := idx.ReadFile(filepath.Join(rootDir, testLabelsFile))
	if err != nil {
		return nil, err
	}

	// The first dimension of each file is used as its item count, so it has
	// to exist before it can be read.
	for _, f := range []struct {
		name string
		dims int
	}{
		{trainImagesFile, len(trainRawImages.Dimensions)},
		{trainLabelsFile, len(trainRawLabels.Dimensions)},
		{testImagesFile, len(testRawImages.Dimensions)},
		{testLabelsFile, len(testRawLabels.Dimensions)},
	} {
		if f.dims == 0 {
			return nil, fmt.Errorf("mnist: %s: no dimensions declared", f.name)
		}
	}

	trainImageCount := int(trainRawImages.Dimensions[0])
	trainLabelCount := int(trainRawLabels.Dimensions[0])
	testImageCount := int(testRawImages.Dimensions[0])
	testLabelCount := int(testRawLabels.Dimensions[0])

	// Make sure the data really holds as many values as the conversion loops
	// below are going to index. The comparison divides instead of multiplying
	// so that a huge declared count cannot overflow.
	for _, f := range []struct {
		name    string
		count   int
		perItem int
		have    int
	}{
		{trainImagesFile, trainImageCount, imagePixels, len(trainRawImages.Data)},
		{trainLabelsFile, trainLabelCount, 1, len(trainRawLabels.Data)},
		{testImagesFile, testImageCount, imagePixels, len(testRawImages.Data)},
		{testLabelsFile, testLabelCount, 1, len(testRawLabels.Data)},
	} {
		if f.count < 0 || f.have/f.perItem < f.count {
			return nil, fmt.Errorf(
				"mnist: %s: header declares %d items but data holds %d values (%d needed per item)",
				f.name, f.count, f.have, f.perItem)
		}
	}

	out := &MNIST{
		TrainInputs: make(lin.Frame, trainImageCount),
		TrainLabels: make(lin.Frame, trainLabelCount),
		TestInputs:  make(lin.Frame, testImageCount),
		TestLabels:  make(lin.Frame, testLabelCount),
	}

	// Convert images to the desired scaled format: v/255*0.99 + 0.01.
	for i := range out.TrainInputs {
		base := i * imagePixels
		row := make([]float32, imagePixels)
		for j := range row {
			row[j] = float32(
				trainRawImages.Data[base+j])/255.0*0.99 + 0.01
		}
		out.TrainInputs[i] = row
	}
	for i := range out.TestInputs {
		base := i * imagePixels
		row := make([]float32, imagePixels)
		for j := range row {
			row[j] = float32(
				testRawImages.Data[base+j])/255.0*0.99 + 0.01
		}
		out.TestInputs[i] = row
	}

	// oneHot builds a label row: 0.99 at index label, 0.01 elsewhere. The
	// caller must have checked that label is within [0, numClasses).
	oneHot := func(label int) []float32 {
		row := make([]float32, numClasses)
		for k := range row {
			row[k] = 0.01
		}
		row[label] = 0.99
		return row
	}

	// Convert labels to the desired one-hot format, rejecting any value that
	// cannot be used as an index into a row.
	for i := range out.TrainLabels {
		label := int(trainRawLabels.Data[i])
		if label < 0 || label >= numClasses {
			return nil, fmt.Errorf("mnist: %s: label %d at index %d is outside 0-%d",
				trainLabelsFile, label, i, numClasses-1)
		}
		out.TrainLabels[i] = oneHot(label)
	}
	for i := range out.TestLabels {
		label := int(testRawLabels.Data[i])
		if label < 0 || label >= numClasses {
			return nil, fmt.Errorf("mnist: %s: label %d at index %d is outside 0-%d",
				testLabelsFile, label, i, numClasses-1)
		}
		out.TestLabels[i] = oneHot(label)
	}

	return out, nil
}