-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
126 lines (108 loc) · 3.78 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Model training and evaluation."""
import json
import dvclive
from ruamel.yaml import YAML
import os
import torch
import torch.nn.functional as F
import torchvision
from dvclive import Live
# Module-level DVCLive logger; main() calls dvclive.log(...) and
# dvclive.next_step() on it once per epoch.
# NOTE(review): this rebinds the name `dvclive`, shadowing the `dvclive`
# package imported above -- from here on `dvclive` is a Live instance, not
# the module. The Live() object is also constructed at import time, so
# merely importing this file initialises it.
dvclive = Live()
class ConvNet(torch.nn.Module):
    """Toy CNN mapping 1x28x28 images to 10 class scores.

    Layout: conv(1->8, 3x3, pad 1) -> ReLU -> 2x2 max-pool
    -> conv(8->16, 3x3, pad 1) -> ReLU -> flatten
    -> linear(16*14*14 -> 32) -> ReLU -> linear(32 -> 10).
    """

    # Flattened feature width: 16 channels at 14x14. Both convs use 3x3
    # kernels with padding 1 (size-preserving), and the single 2x2 pool
    # halves 28x28 to 14x14.
    _FLAT_FEATURES = 16 * 14 * 14

    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes which RNG draws initialise which
        # weights. The attribute names double as the state_dict keys stored
        # in model.pt.
        self.conv1 = torch.nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.dense1 = torch.nn.Linear(self._FLAT_FEATURES, 32)
        self.dense2 = torch.nn.Linear(32, 10)

    def forward(self, x):
        """Return unnormalised scores, (batch, 10) for (batch, 1, 28, 28) input."""
        features = self.maxpool1(F.relu(self.conv1(x)))
        features = F.relu(self.conv2(features))
        hidden = F.relu(self.dense1(features.view(-1, self._FLAT_FEATURES)))
        return self.dense2(hidden)
def transform(dataset):
    """Split a dataset into model-ready inputs and targets.

    Reshapes ``dataset.data`` to (N, 1, 28, 28) and scales it by 1/255
    (true division, so integer pixels become floats). ``dataset.targets``
    is returned unchanged.
    """
    images = dataset.data
    n_images = len(images)
    inputs = images.reshape(n_images, 1, 28, 28).div(255)
    return inputs, dataset.targets
def _cached_adam(model, lr, weight_decay):
    """Return one Adam optimizer per model, reused across calls.

    Adam keeps per-parameter running moment estimates. Building a new
    optimizer for every step throws them away, so each step degenerates into
    a bias-corrected first step (roughly lr * sign(grad)). The optimizer is
    cached on the model itself and rebuilt only if lr/weight_decay change.
    """
    key = (lr, weight_decay)
    cached = getattr(model, "_train_optimizer", None)
    if cached is None or cached[0] != key:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                     weight_decay=weight_decay)
        cached = (key, optimizer)
        # Plain (non-Module, non-Parameter) attribute: not part of state_dict.
        model._train_optimizer = cached
    return cached[1]


def train(model, x, y, lr, weight_decay, optimizer=None):
    """Run one optimization step on a single mini-batch.

    Args:
        model: the network to update; switched to training mode.
        x: input batch.
        y: target class indices for ``x``.
        lr: Adam learning rate (used only when ``optimizer`` is None).
        weight_decay: Adam weight decay (used only when ``optimizer`` is None).
        optimizer: optional optimizer to step. When omitted, an Adam
            optimizer is created on the first call for this model and reused
            on later calls, so its state persists across batches.
    """
    model.train()
    if optimizer is None:
        optimizer = _cached_adam(model, lr, weight_decay)
    criterion = torch.nn.CrossEntropyLoss()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
def predict(model, x):
    """Return ``model(x)`` computed without gradient tracking.

    Side effect: switches *model* to evaluation mode and leaves it there.
    """
    model.eval()
    with torch.no_grad():
        return model(x)
def get_metrics(y, y_pred, y_pred_label):
    """Compute loss and accuracy for one set of predictions.

    Args:
        y: true class indices.
        y_pred: raw class scores passed to cross-entropy.
        y_pred_label: predicted class indices.

    Returns:
        dict with "loss" (mean cross-entropy of y_pred against y, as a float)
        followed by "acc" (fraction of y_pred_label equal to y).
    """
    loss_value = torch.nn.functional.cross_entropy(y_pred, y).item()
    n_correct = torch.eq(y_pred_label, y).sum().item()
    return {"loss": loss_value, "acc": n_correct / len(y)}
def evaluate(model, x, y):
    """Score *model* on (x, y) and write the results to disk.

    Writes two files in the current working directory:
      * predictions.json -- list of {"actual", "predicted"} records, one per
        sample, as ints;
      * results.json -- {"acc", "loss"} summary, pretty-printed.

    Returns:
        The metrics dict from ``get_metrics`` ("loss" and "acc").
    """
    scores = predict(model, x)
    # Index of the highest score in each row is the predicted label.
    labels = torch.max(scores, 1).indices
    records = [
        {"actual": int(truth), "predicted": int(guess)}
        for truth, guess in zip(y, labels)
    ]
    with open("predictions.json", "w") as out_file:
        json.dump(records, out_file)
    metrics = get_metrics(y, scores, labels)
    summary = {"acc": metrics["acc"], "loss": metrics["loss"]}
    with open("results.json", "w") as out_file:
        json.dump(summary, out_file, indent=4)
    return metrics
def _load_params(path="params.yaml"):
    """Read the experiment hyperparameters (seed, num_epochs, lr, weight_decay)."""
    yaml = YAML(typ='safe')
    with open(path) as f:
        return yaml.load(f)


def _load_mnist_split(train):
    """Download (if needed) one MNIST split and return it as (x, y) tensors."""
    dataset = torchvision.datasets.MNIST("data", download=True, train=train)
    return transform(dataset)


def main():
    """Train the model on MNIST, evaluating on the test split after each epoch.

    Resumes from ``model.pt`` when it exists and overwrites it after every
    epoch. Per-epoch metrics are printed and logged through the module-level
    ``dvclive`` logger. Ctrl+C stops training early without a traceback.
    """
    # Seed 0 is set before the model is built, so the initial weights do not
    # depend on params["seed"]; that seed only affects RNG use after model
    # construction (e.g. shuffling of the training batches).
    torch.manual_seed(0)
    model = ConvNet()
    # Resume from a previous checkpoint, if any.
    if os.path.exists("model.pt"):
        model.load_state_dict(torch.load("model.pt"))
    params = _load_params()
    torch.manual_seed(params["seed"])
    x_train, y_train = _load_mnist_split(train=True)
    x_test, y_test = _load_mnist_split(train=False)
    # Build the loader once: with shuffle=True it reshuffles on every pass,
    # so rebuilding it (and re-zipping all samples into a list) per epoch
    # is unnecessary work.
    train_loader = torch.utils.data.DataLoader(
        dataset=torch.utils.data.TensorDataset(x_train, y_train),
        batch_size=512,
        shuffle=True,
    )
    try:
        for epoch in range(1, params["num_epochs"] + 1):
            for x_batch, y_batch in train_loader:
                train(model, x_batch, y_batch, params["lr"], params["weight_decay"])
            torch.save(model.state_dict(), "model.pt")
            # Evaluate and checkpoint.
            metrics = evaluate(model, x_test, y_test)
            for k, v in metrics.items():
                print(f"Epoch {epoch}: {k}={v}")
                dvclive.log(k, v)
            dvclive.next_step()
    except KeyboardInterrupt:
        # Deliberate best-effort: let the user stop training early; whatever
        # was last saved to model.pt is kept.
        pass
# Script entry point: run training/evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()