main.py
""" Launch RL training and evaluation. """
import sys
import signal
import os
import json
import numpy as np
import torch
from six.moves import shlex_quote
from mpi4py import MPI
from config import argparser
from rl.trainer import Trainer
from util.logger import logger
np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)
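

# This script is typically launched under MPI so that each rank runs one
# worker (illustrative invocation; exact flag names are defined in
# config.argparser):
#   mpirun -np 4 python -m rl --env <env> --prefix <prefix> --seed <seed>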
def run(config):
    """
    Runs Trainer.
    """
    rank = MPI.COMM_WORLD.Get_rank()
    config.rank = rank
    config.is_chef = rank == 0
    config.seed = config.seed + rank
    config.num_workers = MPI.COMM_WORLD.Get_size()
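
    # Only the chef (rank 0) creates log files; the remaining workers silence
    # their logger so console output is not duplicated across ranks.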
    if config.is_chef:
        logger.warn('Run a base worker.')
        make_log_files(config)
    else:
        logger.warn('Run worker %d and disable logger.', config.rank)
        import logging
        logger.setLevel(logging.CRITICAL)
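
    # Exit with the conventional 128 + signal code on hangup, interrupt, or
    # terminate so interrupted MPI jobs shut down cleanly.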
    def shutdown(signal, frame):
        logger.warn('Received signal %s: exiting', signal)
        sys.exit(128 + signal)

    signal.signal(signal.SIGHUP, shutdown)
    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)

    # set global seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
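
    # Assumption: ":1" points rendering at a virtual X display (e.g. Xvfb) on
    # a headless machine; adjust or remove this if your setup differs.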
os.environ["DISPLAY"] = ":1"
if config.gpu is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(config.gpu)
assert torch.cuda.is_available()
config.device = torch.device("cuda")
else:
config.device = torch.device("cpu")

    # build a trainer
    trainer = Trainer(config)

    if config.is_train:
        trainer.train()
        logger.info("Finish training")
    else:
        trainer.evaluate()
        logger.info("Finish evaluating")


def make_log_files(config):
    """
    Sets up log directories and saves git diff and command line.
    """
    config.run_name = 'rl.{}.{}.{}'.format(config.env, config.prefix, config.seed)

    config.log_dir = os.path.join(config.log_root_dir, config.run_name)
    logger.info('Create log directory: %s', config.log_dir)
    os.makedirs(config.log_dir, exist_ok=True)

    config.record_dir = os.path.join(config.log_dir, 'video')
    logger.info('Create video directory: %s', config.record_dir)
    os.makedirs(config.record_dir, exist_ok=True)
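
    # For reproducibility, record the current git commit, any uncommitted
    # diff, and the exact command line used to launch this run.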
    if config.is_train:
        # log git diff
        cmds = [
            "echo `git rev-parse HEAD` >> {}/git.txt".format(config.log_dir),
            "git diff >> {}/git.txt".format(config.log_dir),
            "echo 'python -m rl {}' >> {}/cmd.sh".format(
                ' '.join([shlex_quote(arg) for arg in sys.argv[1:]]),
                config.log_dir),
        ]
        os.system("\n".join(cmds))

        # log config
        param_path = os.path.join(config.log_dir, 'params.json')
        logger.info('Store parameters in %s', param_path)
        with open(param_path, 'w') as fp:
            json.dump(config.__dict__, fp, indent=4, sort_keys=True)
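

# Entry point: parse the config and launch training or evaluation on this
# MPI rank; refuse to run if any command-line arguments were not recognized.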
if __name__ == '__main__':
    args, unparsed = argparser()
    if len(unparsed):
        logger.error('Unparsed arguments detected:\n%s', unparsed)
    else:
        run(args)