forked from cyankaet/orderml
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
60 lines (56 loc) · 1.81 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_util import make_vec_env
from orderenv import OrderEnv
from ornlenv import OrnlEnv
# Evaluation script: load a saved PPO model, roll it out in OrnlEnv for a fixed
# number of episodes, and write per-episode results to text files in `savedir`.

# Maximum |final - fixed| allowed on each variable for an episode to count as
# a success.
TOLERANCE = 0.1

# One row per recorded quantity, one column per episode:
# fixedTn, fixedJt, fixedNf, fixedBk, success flag (0/1), step the episode ended on.
data = [[] for _ in range(6)]
model = PPO.load("/wrk/kmm11/orderout/models/thirdpaperrun")  # if you're 'playing' a saved model
episodes = 1
max_steps = 15
actions = [[] for _ in range(episodes)]  # action taken at each step, per episode
temps = [[] for _ in range(episodes)]    # first observation component at each step, per episode

for i in range(episodes):
    env = OrnlEnv()
    # Wrap the single environment so it matches the vectorized API the model expects.
    vec_env = make_vec_env(lambda: env, n_envs=1)

    # Stay at 0 (and end_step stays 0) if the episode never reports `done`
    # within max_steps.
    finalTn = 0
    finalJt = 0
    finalNf = 0
    finalBk = 0
    end_step = 0

    obs = vec_env.reset()
    fixedTn, fixedJt, fixedNf, fixedBk = env.getFixedVars()
    print(env.getFixedVars())

    for step in range(max_steps):
        action, _ = model.predict(obs, deterministic=True)
        print("Step {}".format(step + 1))
        print("Action: ", action)
        obs, reward, done, info = vec_env.step(action)
        print('obs=', obs, 'reward=', reward, 'done=', done)
        # Index 0 selects the only environment in the vectorized wrapper.
        actions[i].append(action[0])
        temps[i].append(obs[0][0])
        if done:
            # NOTE(review): stable-baselines3 VecEnvs reset automatically when an
            # episode ends, so `obs` and possibly env.getVars() may already
            # reflect the new episode here - confirm against OrnlEnv.
            finalTn, finalJt, finalNf, finalBk = env.getVars()
            end_step = step + 1
            print("Goal reached!", "reward=", reward)
            break

    # An episode succeeds only when every variable ends within TOLERANCE of its
    # fixed value. The original check omitted the `< 0.1` comparison on the Bk
    # term, so that term was truthy for any non-zero difference.
    deviations = (
        finalTn - fixedTn,
        finalJt - fixedJt,
        finalNf - fixedNf,
        finalBk - fixedBk,
    )
    success = 1 if all(abs(dev) < TOLERANCE for dev in deviations) else 0

    data[0].append(fixedTn)
    data[1].append(fixedJt)
    data[2].append(fixedNf)
    data[3].append(fixedBk)
    data[4].append(success)
    data[5].append(end_step)

# np.savetxt writes plain text even though the file names end in ".npy";
# the names are kept unchanged so existing readers of these files still work.
savedir = '/wrk/kmm11/orderout/thirdpapertest/'
np.savetxt(savedir + "successdata.npy", data)
np.savetxt(savedir + 'temps.npy', temps)
np.savetxt(savedir + 'acts.npy', actions)