-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
68 lines (57 loc) · 2.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gym
import envs
from IPython.display import clear_output
import matplotlib.pyplot as plt
import os
from datetime import datetime
import torch
import numpy as np
import csv
def make_env(env_name):
    """Return a zero-argument factory that builds the named gym environment.

    Nothing is constructed here; ``gym.make(env_name)`` runs each time the
    returned callable is invoked.

    Args:
        env_name: Environment id forwarded unchanged to ``gym.make``.

    Returns:
        A callable taking no arguments that returns a new environment.
    """

    def _thunk():
        # Construction is deferred to whoever calls the thunk.
        return gym.make(env_name)

    return _thunk
def plot(frame_idx, rewards):
    """Clear the current output and draw the reward history.

    Args:
        frame_idx: Frame counter shown in the figure title.
        rewards: Sequence of rewards to plot; its last element is shown in
            the title, so it must not be empty.
    """
    clear_output(wait=True)

    plt.figure(figsize=(20, 5))
    plt.subplot(1, 3, 1)  # same axes as the 3-digit shorthand 131

    # Built after the figure/axes exist, matching the original evaluation order.
    heading = "frame %s. reward: %s" % (frame_idx, rewards[-1])
    plt.title(heading)
    plt.plot(rewards)

    plt.show()
class save_files:
    """Create per-run result directories and write training artifacts to them.

    On construction, three directories are created under the current working
    directory (``results/reward_step``, ``results/bestreward<stamp>`` and
    ``results/model<stamp>``) and a CSV log for this run is started with the
    header row ``counter,step,reward``.

    Attributes:
        date: Timestamp string taken at construction; part of the run's
            directory names and of the CSV log file name.
        current_dir: Working directory at construction time.
        path_step_reward: Relative directory holding the reward/step CSV log.
        path_best_reward: Relative directory for best-reward dumps.
        path_model: Relative directory for saved model state dicts.
        path: Absolute path of the directory most recently handled by
            ``_save_init``.
        index: Counter written as the first column of the next log row;
            starts at 1.
    """

    def __init__(self):
        self.date = datetime.now().strftime("%Y_%m_%d_%I_%M_%S_%p")
        self.current_dir = os.getcwd()

        self.path_step_reward = "results/reward_step"
        self.path_best_reward = f"results/bestreward{self.date}"
        self.path_model = f"results/model{self.date}"
        for directory in (
            self.path_step_reward,
            self.path_best_reward,
            self.path_model,
        ):
            self._save_init(directory)

        self.index = 1
        self._append_row(["counter", "step", "reward"])

    def _save_init(self, directory):
        """Create ``directory`` under ``self.current_dir`` if it is missing.

        Args:
            directory: Path relative to the working directory captured in
                ``__init__``.
        """
        # ``self.path`` is kept because the original code exposed it.
        self.path = os.path.join(self.current_dir, directory)
        # exist_ok avoids the race between a separate exists() check and the
        # directory creation.
        os.makedirs(self.path, exist_ok=True)

    def _append_row(self, row):
        """Append one row to this run's reward/step CSV log."""
        log_file = f"{self.path_step_reward}/reward_step{self.date}.csv"
        # newline="" is required by the csv module: the writer emits its own
        # line terminator, and newline translation on top of that produces
        # blank lines between rows on platforms that translate "\n".
        with open(log_file, "a", newline="") as f:
            csv.writer(f).writerow(row)

    def best_reward_save(self, all_t, all_actions, all_obs, all_rewards, control_rewards, header):
        """Write the given trajectories to a new timestamped CSV file.

        The five data arguments are joined column-wise with ``np.c_`` and
        saved with ``np.savetxt`` into ``self.path_best_reward``.

        Args:
            all_t: First column(s) of the output.
            all_actions: Next column(s) of the output.
            all_obs: Next column(s) of the output.
            all_rewards: Next column(s) of the output.
            control_rewards: Last column(s) of the output.
            header: Header text passed to ``np.savetxt``.

        NOTE(review): ``np.c_`` needs the inputs to have matching lengths
        along the first axis -- confirm against the caller.
        """
        # This timestamp uses "-" between date and time, unlike ``self.date``;
        # kept as-is so existing file names do not change.
        date = datetime.now().strftime("%Y_%m_%d-%I_%M_%S_%p")
        table = np.c_[all_t, all_actions, all_obs, all_rewards, control_rewards]
        np.savetxt(
            f"{self.path_best_reward}/best_rewards{date}.csv",
            table,
            delimiter=",",
            header=header,
        )

    def reward_step_save(self, best_rew, longest_step, curr_tot_rew, curr_step):
        """Append ``(index, curr_step, curr_tot_rew)`` to the CSV log.

        Args:
            best_rew: Unused; kept so existing callers keep working.
            longest_step: Unused; kept so existing callers keep working.
            curr_tot_rew: Reward value; converted with ``float`` before
                writing.
            curr_step: Step value written as-is.
        """
        self._append_row([self.index, curr_step, float(curr_tot_rew)])
        self.index += 1

    def model_save(self, model):
        """Save ``model.state_dict()`` to a new timestamped ``.pt`` file.

        Args:
            model: Object exposing ``state_dict()``; the result is passed to
                ``torch.save``.
        """
        date = datetime.now().strftime("%Y_%m_%d_%I_%M_%S_%p")
        torch.save(model.state_dict(), f"{self.path_model}/model{date}.pt")