-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmnist.py
79 lines (59 loc) · 2.32 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Code from https://github.com/sorki/python-mnist
import gzip
import os
import struct
from array import array
class MNIST(object):
def __init__(self, path='.'):
self.path = path
self.test_img_fname = 't10k-images-idx3-ubyte.gz'
self.test_lbl_fname = 't10k-labels-idx1-ubyte.gz'
self.train_img_fname = 'train-images-idx3-ubyte.gz'
self.train_lbl_fname = 'train-labels-idx1-ubyte.gz'
self.test_images = []
self.test_labels = []
self.train_images = []
self.train_labels = []
def load_testing(self):
ims, labels = self.load(os.path.join(self.path, self.test_img_fname),
os.path.join(self.path, self.test_lbl_fname))
self.test_images = ims
self.test_labels = labels
return ims, labels
def load_training(self):
ims, labels = self.load(os.path.join(self.path, self.train_img_fname),
os.path.join(self.path, self.train_lbl_fname))
self.train_images = ims
self.train_labels = labels
return ims, labels
@classmethod
def load(cls, path_img, path_lbl):
with gzip.open(path_lbl, 'rb') as file:
magic, size = struct.unpack(">II", file.read(8))
if magic != 2049:
raise ValueError('Magic number mismatch, expected 2049,'
'got {}'.format(magic))
labels = array("B", file.read())
with gzip.open(path_img, 'rb') as file:
magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
if magic != 2051:
raise ValueError('Magic number mismatch, expected 2051,'
'got {}'.format(magic))
image_data = array("B", file.read())
images = []
for i in range(size):
images.append([0] * rows * cols)
for i in range(size):
images[i][:] = image_data[i * rows * cols:(i + 1) * rows * cols]
return images, labels
@classmethod
def display(cls, img, width=28, threshold=200):
render = ''
for i in range(len(img)):
if i % width == 0:
render += '\n'
if img[i] > threshold:
render += '@'
else:
render += '.'
return render