# session10_cnn.py
import torch
import torchvision
from torch import nn
from torch.nn.functional import one_hot  # only needed by the commented-out one-hot targets below
from torch.utils.data import DataLoader
from torchvision import transforms  # currently unused
from torchvision.transforms import ToTensor

import cv2 as cv  # only needed by the commented-out OpenCV preview below
import matplotlib.pyplot as plt

n_classes = 10  # MNIST digit classes
# AA = torchvision.datasets.MNIST("cnn/", train=True, download=True, transform=ToTensor(),
# target_transform=torchvision.transforms.Compose([
# lambda x: torch.tensor([x]), # or just torch.tensor
# lambda x: one_hot(x, 10)]),
# )
# Labels stay as integer class indices, which is what CrossEntropyLoss expects.
AA = torchvision.datasets.MNIST("cnn/", train=True, download=True, transform=ToTensor())
# BB = torchvision.datasets.MNIST("cnn/", train=False, download=True, transform=ToTensor(),
# target_transform=torchvision.transforms.Compose([
# lambda x: torch.tensor([x]), # or just torch.tensor
# lambda x: one_hot(x, 10)]),
# )
BB = torchvision.datasets.MNIST(
"cnn/", train=False, download=True, transform=ToTensor()
)
# img0 = AA.data[2].numpy()
# img1 = AA.data[1].numpy()
# cv.imshow("Display window",img0 )
# cv.imshow("bla", img1)
#
# # window is displayed until we press any key
# k = cv.waitKey(0)
# cv.destroyAllWindows()
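
# Alternative preview sketch using matplotlib (already imported above); the
# sample index 0 is arbitrary. AA[0] returns a (1, 28, 28) tensor and an
# integer label because of the ToTensor() transform.
# sample_img, sample_label = AA[0]
# plt.imshow(sample_img.squeeze(0), cmap="gray")
# plt.title(f"label: {sample_label}")
# plt.show()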
dty = torch.double  # the whole model runs in float64; float32 would be the more common choice on GPU


class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.stack = nn.Sequential(
            # 28x28 input -> 19x19 feature map with a single 10x10 filter
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(10, 10), dtype=dty),
            # nn.MaxPool2d(kernel_size=(2, 2)),
            nn.Flatten(),
            nn.LazyLinear(out_features=20, dtype=dty),
            # ReLU goes between the hidden and output layers; the output layer
            # returns raw logits, which CrossEntropyLoss expects (it applies
            # log-softmax internally).
            nn.ReLU(),
            nn.LazyLinear(out_features=n_classes, dtype=dty),
        )

    def forward(self, x):
        return self.stack(x)

model = SimpleCNN()
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when CUDA is unavailable
model.to(device)
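
# A dry run with a dummy MNIST-shaped batch materialises the LazyLinear
# parameters before the optimiser below collects them, as recommended in the
# PyTorch lazy-module docs.
with torch.no_grad():
    model(torch.zeros(1, 1, 28, 28, dtype=dty, device=device))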
ce_loss = nn.CrossEntropyLoss()
optimiser = torch.optim.SGD(model.stack.parameters(), lr=0.001)
batch_size = 64
train_dataloader = DataLoader(AA, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(BB, batch_size=batch_size, shuffle=True)
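
# Optional check: inspect one batch before training; images come out as
# [batch, 1, 28, 28] float tensors and labels as integer class indices.
# xb, yb = next(iter(train_dataloader))
# print(xb.shape, xb.dtype, yb.shape, yb.dtype)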
val_losses = []
train_losses = []
n_epochs = 5
for i in range(n_epochs):
    for batch, train in enumerate(train_dataloader):
        x_train = train[0].double().to(device)
        y_train = train[1].long().to(device)
        # y_train = train[1].double().to(device)  # only for the one-hot target variant

        y_pred = model(x_train)
        step_loss = ce_loss(y_pred, y_train)
        # step_loss = ce_loss(y_pred, y_train.reshape(x_train.shape[0], -1))

        optimiser.zero_grad()
        step_loss.backward()
        optimiser.step()

        # Quick validation signal: score a single (shuffled) test batch per training step.
        with torch.no_grad():
            test_imgs, test_labels = next(iter(test_dataloader))
            x_test = test_imgs.double().to(device)
            y_test = test_labels.long().to(device)
            y_test_pred = model(x_test)
            test_loss = ce_loss(y_test_pred, y_test)
            # test_loss = ce_loss(y_test_pred, y_test.reshape(x_test.shape[0], -1))

        print(
            f"epoch {i}, batch {batch}, Train Loss: {step_loss.item()}, Test Loss {test_loss.item()}"
        )
        train_losses.append(step_loss.item())
        val_losses.append(test_loss.item())
#
for p in model.parameters():
    print(p)
#
# n_ = 20
# print(torch.round(y_test_pred[:n_]))
# print(y_test[:n_])
#
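# A rough post-training check: overall test-set accuracy from the argmax of
# the logits against the integer labels.
correct = 0
with torch.no_grad():
    for x_eval, y_eval in test_dataloader:
        logits = model(x_eval.double().to(device))
        correct += (logits.argmax(dim=1) == y_eval.to(device)).sum().item()
print(f"test accuracy: {correct / len(BB):.4f}")
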
plt.figure(figsize=(10, 5))
plt.title("Training and Validation Loss")
plt.plot(val_losses, label="val")
plt.plot(train_losses, label="train")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
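
# Optional: persist the trained weights for later reuse; the filename is an
# arbitrary choice.
# torch.save(model.state_dict(), "session10_cnn.pt")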