Skip to content

Commit

Permalink
Overhaul training loop to use SMLL version
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Sep 27, 2019
1 parent fe82bc2 commit 898fdbd
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 334 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ experiments/treelstm/pt_tlstm/lib/stanford-postagger**
experiments/treelstm/setup/*

# pulled-in resources for training loop
experiments/training_loop/beacon
experiments/training_loop/data
experiments/training_loop/smll
experiments/training_loop/source.py

# data and graphs
*.csv
Expand Down
96 changes: 65 additions & 31 deletions experiments/training_loop/analyze.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
import csv
import numpy as np
import os

from validate_config import validate
from common import invoke_main, write_status, write_json, render_exception
from analysis_util import trials_stat_summary, add_detailed_summary


def data_file(data_dir, fw):
    """Return the path of the per-framework training CSV inside *data_dir*."""
    filename = '{}-train.csv'.format(fw)
    return os.path.join(data_dir, filename)


def compute_summary(data):
    """Summarize *data* as its mean, median, and (population) std deviation."""
    summary = {}
    summary['mean'] = np.mean(data)
    summary['median'] = np.median(data)
    summary['std'] = np.std(data)
    return summary


def main(data_dir, config_dir, output_dir):
config, msg = validate(config_dir)
Expand All @@ -12,42 +26,62 @@ def main(data_dir, config_dir, output_dir):

frameworks = config['frameworks']
devices = config['devices']
num_reps = config['n_inputs']
num_classes = list(config['num_classes'])[0]
batch_size = list(config['batch_sizes'])[0]
epochs = config['epochs']
datasets = config['datasets']
models = config['models']

fieldnames = ['device', 'model', 'dataset', 'rep', 'epoch',
'time', 'loss', 'correct', 'total']

listing_settings = {
'Relay': 'relay',
'Keras': 'keras'
'PyTorch': 'pt'
}

fieldnames = ['device', 'batch_size', 'num_classes', 'epochs']

# output averages on each network for each framework and each device
# report final accuracy, final loss, average time per epoch across reps
ret = {}
for dev in devices:
ret[dev] = {}
for listing, framework in listing_settings.items():
ret[dev][listing] = {}
for epoch_count in epochs:
field_values = {
'device': dev,
'batch_size': batch_size,
'num_classes': num_classes,
'epochs': epoch_count
}

summary, success, msg = trials_stat_summary(data_dir, framework, 'training_loop', num_reps,
fieldnames, field_values)
if not success:
write_status(output_dir, False, msg)
return 1
ret[dev][listing][epoch_count] = summary['mean']
add_detailed_summary(ret, summary, dev, listing, epoch_count)

write_json(output_dir, 'data.json', ret)
write_status(output_dir, True, 'success')
try:
for dev in devices:
ret[dev] = {}
for listing, spec_settings in listing_settings.items():
ret[dev][listing] = {dataset: {model: {} for model in models}
for dataset in datasets}
fw = spec_settings

epoch_times = {dataset: {model: [] for model in models}
for dataset in datasets}
final_accs = {dataset: {model: [] for model in models}
for dataset in datasets}
final_losses = {dataset: {model: [] for model in models}
for dataset in datasets}

filename = data_file(data_dir, fw)
with open(filename, newline='') as csvfile:
reader = csv.DictReader(csvfile, fieldnames)
for row in reader:
if row['device'] != dev:
continue
epoch_times[row['dataset']][row['model']].append(
float(row['time']))
if int(row['epoch']) == epochs - 1:
final_accs[row['dataset']][row['model']].append(
float(row['correct'])/float(row['total']))
final_losses[row['dataset']][row['model']].append(
float(row['loss']))

for dataset in datasets:
for model in models:
ret[dev][listing][dataset][model] = {
'time': compute_summary(epoch_times[dataset][model]),
'acc': compute_summary(final_accs[dataset][model]),
'loss': compute_summary(final_losses[dataset][model])
}

write_json(output_dir, 'data.json', ret)
write_status(output_dir, True, 'success')

except Exception as e:
write_status(output_dir, False, render_exception(e))
return 1


if __name__ == '__main__':
Expand Down
15 changes: 9 additions & 6 deletions experiments/training_loop/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ source $BENCHMARK_DEPS/bash/common.sh
include_shared_python_deps
add_to_pythonpath $(pwd)

python_run_trial "run_keras.py" $config_dir $data_dir
rm -rf smll
rm -rf data
rm -f source.py
git clone [email protected]:uwsampl/smll.git

# This benchmark requires a specific branch of Beacon
rm -rf ./beacon
git clone [email protected]:MarisaKirisame/beacon.git
add_to_pythonpath $(pwd)/beacon
cd smll
stack run -- compile
cd ..
cp smll/python/source.py ./source.py

python_run_trial "run_relay.py" $config_dir $data_dir
python_run_trial "run_pt.py" $config_dir $data_dir
79 changes: 0 additions & 79 deletions experiments/training_loop/run_keras.py

This file was deleted.

146 changes: 146 additions & 0 deletions experiments/training_loop/run_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Based on https://github.com/CSCfi/machine-learning-scripts/blob/master/notebooks/pytorch-mnist-mlp.ipynb
"""
import csv
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

from common import invoke_main, read_config, write_status, render_exception
from trial_util import configure_seed
from validate_config import validate

import source

def load_model(model_name):
    """Instantiate the SMLL-generated learner for *model_name*.

    Only 'mlp' (the generated MNIST MLP in source.py) is supported.
    """
    if model_name != 'mlp':
        raise Exception('Unsupported model: ' + model_name)
    return source.mnist()


def load_raw_data(dataset_name):
    """Return (train, validation) torchvision datasets for *dataset_name*.

    Downloads MNIST into ./data on first use; only 'mnist' is supported.
    """
    if dataset_name != 'mnist':
        raise Exception('Unsupported dataset: ' + dataset_name)
    to_tensor = transforms.ToTensor()
    train_split = datasets.MNIST('./data',
                                 train=True,
                                 download=True,
                                 transform=to_tensor)
    validation_split = datasets.MNIST('./data',
                                      train=False,
                                      transform=to_tensor)
    return (train_split, validation_split)


def get_data_loader(raw_data, batch_size, shuffle):
    """Wrap *raw_data* in a DataLoader with the given batching and shuffling."""
    loader = torch.utils.data.DataLoader(dataset=raw_data,
                                         batch_size=batch_size,
                                         shuffle=shuffle)
    return loader


def train(train_loader, model, device):
    """Run one training epoch over *train_loader*.

    model[0] is the SMLL-generated training pass; calling it on a batch
    presumably updates the learner's parameters as a side effect — the
    returned loss is not used here. TODO confirm against source.py.
    """
    for data, target in train_loader:
        # Move the batch to the training device and flatten 28x28 images
        # into 784-element vectors.
        batch = data.to(device)
        batch = batch.view(-1, 28 * 28)
        labels = target.to(device)
        model[0](batch, labels)


def validate_learner(validation_loader, model, device):
    """Evaluate model[1] (the SMLL inference pass) over the validation set.

    Returns a tuple of (cross-entropy loss summed over batches, number of
    correct predictions, size of the underlying validation dataset).
    """
    loss_fn = nn.CrossEntropyLoss()
    total_loss = 0
    num_correct = 0
    for data, target in validation_loader:
        inputs = data.to(device).view(-1, 28 * 28)
        labels = target.to(device)
        output = model[1](inputs)
        total_loss += loss_fn(output, labels).data.item()
        # The index of the max log-probability is the predicted class.
        predictions = output.data.max(1)[1]
        num_correct += predictions.eq(labels.data).cpu().sum()
    return total_loss, int(num_correct.item()), len(validation_loader.dataset)


def main(config_dir, output_dir):
    """Train every configured model/dataset pairing with the SMLL learner,
    logging per-epoch time, loss, and accuracy to pt-train.csv.

    Returns 0 on success (or when PyTorch is not among the configured
    frameworks) and 1 on failure; the outcome is also reported through
    write_status either way.
    """
    config, msg = validate(config_dir)
    if config is None:
        write_status(output_dir, False, msg)
        return 1

    # Skip entirely — but still report success — when PyTorch was not
    # selected in the experiment config.
    if 'pt' not in config['frameworks']:
        write_status(output_dir, True, 'PyTorch not run')
        return 0

    configure_seed(config)
    # Execution device comes from the SMLL-generated module; 'gpu' is what
    # gets recorded in the CSV regardless.
    device = source.device
    dev = 'gpu' # TODO ensure we can set this appropriately in SMLL

    # NOTE(review): analyze.py reads 'batch_sizes' and 'num_classes'
    # (plural, list-valued) from the config — confirm 'batch_size' is the
    # correct key here.
    batch_size = config['batch_size']
    epochs = config['epochs']
    models = config['models']
    datasets = config['datasets']
    dry_run = config['dry_run']
    reps = config['reps']

    # record: dev, model, dataset, rep, epoch, time, loss, num correct, total
    fieldnames = ['device', 'model', 'dataset', 'rep', 'epoch',
                  'time', 'loss', 'correct', 'total']
    try:
        info = []
        for dataset in datasets:
            raw_train, raw_validation = load_raw_data(dataset)
            for model_name in models:
                for rep in range(reps):
                    training = get_data_loader(raw_train, batch_size, True)
                    model = load_model(model_name)

                    # dry run: train and throw away
                    for dry_epoch in range(dry_run):
                        train(training, model, device)

                    # reload to reset internal state
                    model = load_model(model_name)
                    training = get_data_loader(raw_train, batch_size, True)
                    validation = get_data_loader(raw_validation, batch_size, False)
                    for epoch in range(epochs):
                        # Only the training pass is timed; validation below
                        # is excluded from the reported epoch time.
                        e_start = time.time()
                        train(training, model, device)
                        e_end = time.time()

                        e_time = e_end - e_start
                        loss, correct, total = validate_learner(
                            validation, model, device)
                        info.append((dev, model_name, dataset, rep, epoch,
                                     e_time, loss, correct, total))
                        print(model_name, dataset, rep, epoch,
                              e_time, loss, '{}/{}'.format(correct, total))
                        # Pause between epochs, presumably to let the device
                        # settle between timed measurements — TODO confirm.
                        time.sleep(4)

        # dump to CSV
        filename = os.path.join(output_dir, 'pt-train.csv')
        with open(filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for row in info:
                writer.writerow({
                    fieldnames[i]: row[i]
                    for i in range(len(fieldnames))
                })
    except Exception as e:
        write_status(output_dir, False,
                     'Encountered exception: {}'.format(render_exception(e)))
        return 1

    write_status(output_dir, True, 'success')
    return 0


if __name__ == '__main__':
    # invoke_main (from common) appears to parse the named CLI arguments and
    # forward them to main — verify against common.invoke_main.
    invoke_main(main, 'config_dir', 'output_dir')
Loading

0 comments on commit 898fdbd

Please sign in to comment.