eval_utils.py
import diffrax
import ipdb
import jax
import jax.numpy as np

import pontryagin_utils


def closed_loop_eval_nn_ensemble(problem_params, algo_params, V_nn, nn_params, x0s):

    def u_star_fct(x):
        # get the optimal control input from the NN ensemble.
        nx = problem_params['nx']
        assert x.shape == (nx,)

        # vmap the gradient evaluation over the leading (ensemble) axis of
        # nn_params, while broadcasting the single state x to all ensemble
        # members by passing it as a batch of one.
        xs = x[None, :]
        costate_estimates = jax.vmap(V_nn.apply_grad, in_axes=(0, None))(nn_params, xs)
        costate_estimates = costate_estimates.reshape(algo_params['nn_ensemble_size'], nx)

        # average the ensemble's costate estimates, then map costate -> input.
        costate = costate_estimates.mean(axis=0)

        # replaces the older, more verbose pontryagin_utils.u_star_costate call;
        # same routine as used during data generation.
        u_star = pontryagin_utils.u_star_new(x, costate, problem_params)
        return u_star

    return closed_loop_eval_general(problem_params, algo_params, u_star_fct, x0s)
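

# Usage sketch (not part of the original file; shapes below are assumptions):
# nn_params is expected to be a pytree whose leaves carry a leading ensemble
# axis of length algo_params['nn_ensemble_size'], so that vmapping
# V_nn.apply_grad over axis 0 evaluates every ensemble member. Roughly:
#
#     x0s = np.array([[1., 0.], [0., 1.]])   # (n_trajectories, nx)
#     sols = closed_loop_eval_nn_ensemble(
#         problem_params, algo_params, V_nn, nn_params, x0s,
#     )
#     # sols.ys then has shape (n_trajectories, max_steps, nx + 1), the last
#     # state component being the accumulated control cost.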


def closed_loop_eval_general(problem_params, algo_params, ustar_fct, x0s):
    # make closed-loop simulations and record the control cost.
    # ustar_fct should be a jax function, X -> U, with no time dependence.

    T = algo_params['sim_T']
    dt = algo_params['sim_dt']
    f = problem_params['f']
    l = problem_params['l']
    nx = problem_params['nx']

    @jax.jit
    def dynamics_extended(t, z, args=None):
        # one extra state entry to record the accumulated control cost.
        # the input is applied according to ustar_fct.
        x, cost = np.split(z, [nx])
        ustar = ustar_fct(x)

        xdot = f(t, x, ustar)
        cost_dot = l(t, x, ustar)
        zdot = np.concatenate([xdot, cost_dot.reshape(1,)])
        return zdot
    def solve_single(x0):
        term = diffrax.ODETerm(dynamics_extended)
        solver = diffrax.Tsit5()  # recommended over the usual RK4/5 in the diffrax docs
        max_steps = int(T / dt)
        saveat = diffrax.SaveAt(steps=True)

        # initialise the control cost state to 0.
        z0 = np.concatenate([x0, np.zeros(1)])

        solution = diffrax.diffeqsolve(
            term, solver, t0=0, t1=T, dt0=dt, y0=z0,
            saveat=saveat, max_steps=max_steps,
        )
        # with SaveAt(steps=True), unused step slots are padded with inf; with
        # the fixed step size chosen here, the last saved entry should be the
        # final (non-inf) state at t1.
        return solution

    # sol = solve_single(np.array([0, 2]))
    solve_multiple = jax.vmap(solve_single)
    all_sols = solve_multiple(x0s)
    return all_sols
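

def _example_closed_loop_eval():
    # Hedged usage sketch, not part of the original file: everything below
    # (the double-integrator dynamics, quadratic cost, feedback gain, horizon)
    # is a made-up assumption, chosen only to show the expected call signature
    # of closed_loop_eval_general and the X -> U form of ustar_fct.
    problem_params = {
        'f': lambda t, x, u: np.array([x[1], u[0]]),   # double integrator
        'l': lambda t, x, u: x @ x + u @ u,            # quadratic running cost
        'nx': 2,
    }
    # T and dt chosen so that int(T / dt) is exact in floating point.
    algo_params = {'sim_T': 4., 'sim_dt': 0.0625}

    # simple linear state feedback, returning an input of shape (nu,) = (1,)
    ustar_fct = lambda x: np.array([-1., -1.]) @ x * np.ones(1)

    x0s = np.array([[1., 0.], [0., 1.]])               # batch of initial states
    return closed_loop_eval_general(problem_params, algo_params, ustar_fct, x0s)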


def compute_controlcost(problem_params, all_sols):
    # all_sols exactly as returned by closed_loop_eval_general.
    # the accumulated control cost is the extra state component (index nx) at
    # the last saved step of each trajectory.
    control_costs = all_sols.ys[:, -1, problem_params['nx']]
    mean_cost = control_costs.mean()
    return mean_cost
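

def _example_controlcost():
    # Hedged sketch continuing the hypothetical example above: the mean
    # accumulated cost over the two simulated trajectories. Only 'nx' of
    # problem_params is needed here.
    all_sols = _example_closed_loop_eval()
    return compute_controlcost({'nx': 2}, all_sols)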