-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlevelsets.py
2477 lines (1724 loc) · 97.1 KB
/
levelsets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import gzip
import os
import pprint
import subprocess
import sys
import time
from operator import itemgetter
import diffrax
import equinox
import flax
import ipdb
import jax
import jax.numpy as np
import matplotlib
import matplotlib.pyplot as pl
import meshcat
import meshcat.geometry as geom
import meshcat.transformations as tf
import numpy as onp
import tqdm
import nn_utils
import plotting_utils
import pontryagin_utils
import trajax_refsol
import visualiser
import wandb
from misc import *
# helper functions {{{
# Module-level switch read by prune_and_train: when True, the ensemble
# mean/std over the whole dataset is computed with jax.lax.map over the
# leading axis of all_ys['x'] instead of a nested jax.vmap, to lower peak
# memory (the vmap version has been observed to hit RESOURCE_EXHAUSTED).
save_memory: bool = True
def def_v_meanstds(v_nn):
    """Build batched ensemble-statistics evaluators for a value network.

    ``v_nn(params, x)`` is evaluated for every ensemble member, i.e. along the
    leading axis of the params pytree, at a shared state ``x``.

    Returns:
        (v_meanstds, vx_meanstds), both called as ``fn(xs, vmap_params)`` and
        vmapped over the leading axis of ``xs``:
          - v_meanstds  -> (mean, std) of the value over the ensemble, one
            scalar pair per x (mean/std taken over all ensemble outputs).
          - vx_meanstds -> (mean, std) of dV/dx over the ensemble, taken per
            coordinate (axis 0 = ensemble), one vector pair per x.
    """
    def _ensemble_v(x, vmap_params):
        per_member = jax.vmap(v_nn, in_axes=(0, None))(vmap_params, x)
        return per_member.mean(), per_member.std()

    def _ensemble_vx(x, vmap_params):
        grad_wrt_x = jax.jacobian(v_nn, argnums=1)
        # one gradient row per ensemble member: shape (N_ensemble, nx)
        per_member = jax.vmap(grad_wrt_x, in_axes=(0, None))(vmap_params, x)
        # per-coordinate statistics; reduce further later if a scalar is needed.
        return per_member.mean(axis=0), per_member.std(axis=0)

    def _over_states(fn):
        return jax.vmap(fn, in_axes=(0, None))

    return _over_states(_ensemble_v), _over_states(_ensemble_vx)
def find_min_l(ys, v_lower, v_upper, problem_params):
    """Estimate the smallest running cost l(x, u) just below a value level.

    For every trajectory, take the point with the largest value v that is
    still strictly below ``v_upper``. Evaluate the running cost
    l(x, u*(x, vx)) there and return the minimum over the trajectories that
    also reach ``v_upper``, i.e. have some finite v >= v_upper. Trajectories
    that stop before ``v_upper`` are ignored.

    Args:
        ys: dict of arrays with at least keys 'v', 'x', 'vx'. Axis 0 indexes
            trajectories and axis 1 indexes points along each trajectory
            (assumed from the indexing below — TODO confirm against callers).
        v_lower: currently unused; kept for interface compatibility.
        v_upper: value level the band is capped at.
        problem_params: dict providing the running cost 'l' and whatever
            pontryagin_utils.u_star_general needs.

    Returns:
        Scalar minimum of l over the selected points. NaN if no trajectory
        reaches v_upper.
    """
    v = ys['v']
    all_traj_idx = np.arange(v.shape[0])

    # Replace points at/above v_upper (and non-finite ones) by -inf instead of
    # multiplying by a boolean mask: inf * False == nan, and argmax would then
    # pick exactly those invalid entries.
    v_below_upper = np.where(v < v_upper, v, -np.inf)
    top_per_traj_idx = np.argmax(v_below_upper, axis=1)
    ys_top = jtm(lambda node: node[all_traj_idx, top_per_traj_idx], ys)

    # Only use trajectories that did not stop early, i.e. those with some
    # finite v >= v_upper.
    crosses_v_upper = ((v >= v_upper) & (v < np.inf)).any(axis=1)

    def l_of_y(y):
        x = y['x']
        vx = y['vx']
        u = pontryagin_utils.u_star_general(x, vx, problem_params)
        return problem_params['l'](x, u)

    ls = jax.vmap(l_of_y)(ys_top)
    # Mask out irrelevant trajectories with NaN so nanmin skips them.
    # (`ls + nan * mask` does NOT work: nan * 0 == nan, so every entry
    # would become NaN.)
    ls_relevant = np.where(crosses_v_upper, ls, np.nan)
    return np.nanmin(ls_relevant)
def set_value_target(all_ys, v_k, problem_params, algo_params):
    """Pick the next value level to grow the value sublevel set to.

    The step is sized so that the time spent crossing the new band stays at
    most ``algo_params['T_value_target']``. The smallest running cost l found
    in the data band [v_k/2, v_k] (see find_min_l) serves as a data-based
    surrogate for the smallest dv/dt over the next band:

        v_next = v_k + T_value_target * min_l

    Returns:
        The new value level v_next.
    """
    # min l == min dv/dt along optimal trajectories (data-based surrogate;
    # the actual previous value level could be used as lower bound instead).
    slowest_rate = find_min_l(all_ys, v_k/2, v_k, problem_params)
    return v_k + algo_params['T_value_target'] * slowest_rate
def forward_sim_nn(x0, v_nn, params, problem_params, algo_params, ensemble=True, T=10.):
    """Simulate the closed loop induced by the NN value function.

    The control at state x is pontryagin_utils.u_star_general(x, dV/dx).
    V is ``v_nn(params, x)``; with ``ensemble=True`` it is instead the mean of
    all members along the leading axis of ``params``. The integrated pytree is
    ``{'x': system state, 'cost': accumulated running cost}``, starting at
    cost 0. To estimate an infinite-horizon cost, either integrate for long
    enough or add a terminal (e.g. LQR) cost afterwards.

    Args:
        x0: initial state.
        T: final integration time (integration runs from t=0 to t=T).

    Returns:
        The diffrax solution (all steps plus t0/t1 saved, dense output on).
    """
    if ensemble:
        # scalar ensemble-mean value; differentiated below.
        def value_fn(x):
            return jax.vmap(v_nn, in_axes=(0, None))(params, x).mean()
    else:
        def value_fn(x):
            return v_nn(params, x)

    def closed_loop_rhs(t, y, args):
        state = y['x']
        grad_v = jax.jacobian(value_fn)(state).squeeze()
        control = pontryagin_utils.u_star_general(state, grad_v, problem_params)
        return {
            'x': problem_params['f'](state, control),
            'cost': problem_params['l'](state, control),
        }

    controller = diffrax.PIDController(
        atol=algo_params['pontryagin_solver_atol'],
        rtol=algo_params['pontryagin_solver_rtol'],
        dtmin=algo_params['dtmin'],
        dtmax=algo_params['dtmax'],
    )
    save_spec = diffrax.SaveAt(steps=True, dense=True, t0=True, t1=True)

    use_projection = problem_params['m'] is not None and algo_params['project_manifold']
    if use_projection:
        # only the state is projected back onto the manifold; cost passes through.
        def project_y(y):
            return {'x': problem_params['project_M'](y['x']), 'cost': y['cost']}
        solver = pontryagin_utils.ProjectionSolver(project=project_y)
    else:
        solver = diffrax.Tsit5()

    initial = {
        'x': x0,
        'cost': 0.,
    }
    return diffrax.diffeqsolve(
        diffrax.ODETerm(closed_loop_rhs),
        solver,
        t0=0.,
        t1=T,
        dt0=0.01,
        y0=initial,
        stepsize_controller=controller,
        saveat=save_spec,
        max_steps=algo_params['pontryagin_solver_maxsteps'],
        throw=algo_params['throw'],
    )
def meshcat_forward_sims(x0s, v_nn, nn_params, problem_params, algo_params):
    """Forward-simulate a batch of initial states and show them in meshcat.

    Convenience wrapper for interactive (pdb) use. Each row of ``x0s`` is
    simulated with forward_sim_nn using its default settings. The states are
    then converted to the older angle representation that
    visualiser.plot_trajectories_meshcat takes: entries 2 and 3 are replaced
    by the single angle arctan2(x[2], x[3]). This conversion is hard-coded for
    one particular state layout.
    """
    def simulate_one(x0):
        return forward_sim_nn(x0, v_nn, nn_params, problem_params, algo_params)

    trajs = jax.vmap(simulate_one)(x0s)

    def to_angle_repr(x):
        angle = np.arctan2(x[2], x[3])
        return np.concatenate([x[0:2], np.array([angle]), x[4:]])

    # nested vmap: over trajectories, then over the saved steps of each one.
    x_angle = jax.vmap(jax.vmap(to_angle_repr))(trajs.ys['x'])
    visualiser.plot_trajectories_meshcat({'t': trajs.ts, 'x': x_angle})
# }}}
# define main active learning ingredients: prune & train function {{{
def prune_and_train(key, v_nn, params_sobolev_ens, all_ys, v_interval, previously_suboptimal, problem_params, algo_params, warmstart=False, is_final=False):
    """Prune suboptimal / outlier data and (re)train the value NN ensemble.

    Steps:
      1. Mark data already known to be suboptimal (a better solution is known
         at that point already).
      2. Train the NN a first time, gradually expanding the domain of training
         data (algo_params['nn_value_sweep']), with huber-type losses so that
         conflicting data does not break everything.
      3. Mark as suboptimal all new data that falls into the linear huber
         regions, i.e. data the NN could not fit (easily enough).
      4. Train a second, shorter time on this cleaned dataset. This removes the
         artefacts the outliers left in round one, by settling into the
         equilibrium between gradient and weight decay.

    With ``is_final=True`` (final training round):
      - no pruning before training; ``previously_suboptimal`` is used as is,
      - thin_data is ignored, since the whole dataset is needed,
      - steps 3 and 4 are skipped.

    Args:
        key: JAX PRNG key.
        v_nn: value network object, callable as v_nn(params, x), providing
            train_sobolev_ensemble(_warmstarted), sobolev_loss_batch_mean and
            sobolev_loss_inner.
        params_sobolev_ens: ensemble params (ensemble along the leading axis).
        all_ys: dict of arrays with keys including 'x' and 'v'. Axis 0 indexes
            trajectories and axis 1 indexes points along them (assumed from the
            cumsum/indexing below — TODO confirm).
        v_interval: (v_lower, v_upper) value band of this round.
        previously_suboptimal: boolean mask shaped like all_ys['v'].
        warmstart: continue from params_sobolev_ens instead of training from
            scratch.
        is_final: final training round, see above.

    Returns:
        (params_sobolev_ens, oups_sobolev_ens, is_suboptimal, pruning_metrics)

    Raises:
        ValueError: unknown algo_params['pruning_strategy'].
        NotImplementedError: warmstart=False with is_final=False (the second
            training round only supports warm starting).
    """
    # 0. batched ensemble mean/std evaluators (shared module-level helper).
    v_meanstds, vx_meanstds = def_v_meanstds(v_nn)

    # 1. mark clearly suboptimal data. {{{
    v_lower, v_upper = v_interval
    pruning_metrics = {}

    if is_final:
        # in the final round just use the suboptimality flags as they are.
        is_suboptimal = previously_suboptimal
    else:
        # evaluating the ensemble over the whole dataset at once has caused
        # RESOURCE_EXHAUSTED; lax.map loops over the leading axis instead.
        if save_memory:
            v_meanstds_params = lambda ys: v_meanstds(ys, params_sobolev_ens)
            v_nn_means, v_nn_stds = jax.lax.map(v_meanstds_params, all_ys['x'])
        else:
            # old version, certified to give the same result.
            v_nn_means, v_nn_stds = jax.vmap(v_meanstds, in_axes=(0, None))(all_ys['x'], params_sobolev_ens)

        pruning_strategy = algo_params['pruning_strategy']
        if pruning_strategy not in ('conservative', 'conservative_past', 'generous'):
            raise ValueError(f'unknown pruning strategy "{pruning_strategy}"')

        # conservative, pointwise rule shared by all strategies: the trajectory
        # claims v > v_lower here, but the NN is confident (mean + 3 std) that
        # the point lies inside the v_lower sublevel set -> better solution known.
        nn_v_likely_in_levelset = v_nn_means + 3 * v_nn_stds < v_lower
        trajectory_outside_levelset = v_lower < all_ys['v']
        point_is_suboptimal = trajectory_outside_levelset & nn_v_likely_in_levelset

        if pruning_strategy == 'generous':
            # also drop the points after a suboptimal one, as long as they are
            # above the currently known value level.
            is_suboptimal = point_is_suboptimal.any(axis=1)[:, None] & (all_ys['v'] >= v_lower)
        else:
            # 'conservative' and 'conservative_past' are identical now that the
            # cumsum step below is applied to every strategy.
            is_suboptimal = point_is_suboptimal

        # time goes from 0.0 at idx 0 to negative values at idx 1, 2, ... so
        # cumsum marks as suboptimal the PRECEDING points in physical time
        # (dynamic programming principle), although they are the subsequent
        # ones in array indices.
        is_suboptimal = np.cumsum(is_suboptimal, axis=1) > 0
        # keep previously suboptimal points marked suboptimal.
        is_suboptimal = np.logical_or(previously_suboptimal, is_suboptimal)

    # build the training data out of the pruned dataset.
    in_band = (all_ys['v'] <= v_upper)
    if algo_params['include_future_data']:
        # has proven not to be a great idea: adds data up to one band width above v_upper.
        v_upper_train = v_upper + (v_upper - v_lower)
        in_band = (all_ys['v'] <= v_upper_train)
    if algo_params['thin_data'] and not is_final:
        # exclude data far below the current band.
        v_cutoff = v_lower / algo_params['thin_data_denominator']
        in_band = in_band & (v_cutoff <= all_ys['v'])
    bool_train_idx = in_band & ~is_suboptimal
    # }}}

    # 2. train the NN for the first time. {{{
    usable_ys = jax.tree_util.tree_map(lambda node: node[bool_train_idx], all_ys)
    train_ys, test_ys = nn_utils.train_test_split(usable_ys, train_frac=algo_params['nn_train_fraction'])
    train_key, key = jax.random.split(key)

    if warmstart:
        # continue from previous params, with the value sweep over [v_lower, v_upper].
        params_sobolev_ens, oups_sobolev_ens = v_nn.train_sobolev_ensemble_warmstarted(
            train_key, train_ys, v_lower, v_upper, params_sobolev_ens, problem_params, algo_params
        )
    else:
        # training from scratch: both sweep bounds at v_upper -> no sweep, all data at once.
        params_sobolev_ens, oups_sobolev_ens = v_nn.train_sobolev_ensemble(
            train_key, train_ys, v_upper, v_upper, problem_params, algo_params
        )

    n_params = count_floats(params_sobolev_ens) / algo_params['nn_ensemble_size']
    n_data = count_floats(train_ys)
    pruning_metrics['params_data_ratio'] = n_params / n_data
    (
        pruning_metrics['final_trainloss'],
        pruning_metrics['final_testloss'],
        pruning_metrics['final_weightnorm'],
    ) = _sobolev_fit_metrics(v_nn, key, params_sobolev_ens, oups_sobolev_ens, test_ys, problem_params, algo_params)
    # }}}

    # in the final round all suboptimality flags are known and presumed
    # correct, so no outlier classification / retraining is needed.
    if not is_final:
        # 3. classify outliers. {{{
        v_means_trained, _ = v_meanstds(usable_ys['x'], params_sobolev_ens)
        vx_means_trained, _ = vx_meanstds(usable_ys['x'], params_sobolev_ens)
        # sobolev_loss_inner is vmapped along y, v_pred and vx_pred.
        _, all_auxs = jax.vmap(v_nn.sobolev_loss_inner, in_axes=(None, 0, 0, 0, None, None))(
            key, usable_ys, v_means_trained, vx_means_trained, problem_params, algo_params
        )
        v_outliers = all_auxs['v_loss_linear']
        print(f'v outliers: {100*v_outliers.mean():.3f}%')
        vx_outliers = all_auxs['vx_loss_linear']
        print(f'vx outliers: {100*vx_outliers.mean():.3f}%')
        print(f'either outliers: {100*(vx_outliers|v_outliers).mean():.3f}%')
        print(f'both outliers: {100*(vx_outliers&v_outliers).mean():.3f}%')

        # these boolean masks are all with respect to usable_ys.
        is_outlier = vx_outliers | v_outliers  # is & better here?
        is_new = (v_lower <= usable_ys['v']) & (usable_ys['v'] <= v_upper)
        new_suboptimal = is_outlier & is_new
        # by construction is_suboptimal[bool_train_idx] is all False, so this
        # just writes the new flags into the full-size mask.
        is_suboptimal = is_suboptimal.at[bool_train_idx].set(new_suboptimal)
        is_suboptimal = np.cumsum(is_suboptimal, axis=1) > 0
        # }}}

        # 4. second, shorter training run on the cleaned data. {{{
        bool_train_idx = in_band & ~is_suboptimal
        usable_ys = jax.tree_util.tree_map(lambda node: node[bool_train_idx], all_ys)
        train_ys, test_ys = nn_utils.train_test_split(usable_ys, train_frac=algo_params['nn_train_fraction'])

        # only find an equilibrium of data vs weight decay: small lr, fewer epochs.
        algo_params_second = algo_params.copy()
        algo_params_second['lr_init'] = algo_params['lr_final']
        algo_params_second['nn_N_epochs'] = algo_params['nn_N_epochs'] / 8

        if not warmstart:
            raise NotImplementedError('are you sure? not really doing this anymore. plz implement v sweep here too')

        # fresh key for the second round instead of reusing train_key.
        train_key_second, key = jax.random.split(key)
        # no sweep anymore: always sample up to v_upper.
        params_sobolev_ens, oups_sobolev_ens_new = v_nn.train_sobolev_ensemble_warmstarted(
            train_key_second, train_ys, v_upper, v_upper, params_sobolev_ens, problem_params, algo_params_second
        )

        (
            pruning_metrics['final_trainloss_second'],
            pruning_metrics['final_testloss_second'],
            pruning_metrics['final_weightnorm_second'],
        ) = _sobolev_fit_metrics(v_nn, key, params_sobolev_ens, oups_sobolev_ens_new, test_ys, problem_params, algo_params)

        # shapes are (N_nn_ensemble, N_trainsteps): concatenate along trainsteps.
        oups_sobolev_ens = jtm(lambda a, b: np.concatenate([a, b], axis=1), oups_sobolev_ens, oups_sobolev_ens_new)
        # }}}

    return params_sobolev_ens, oups_sobolev_ens, is_suboptimal, pruning_metrics


def _sobolev_fit_metrics(v_nn, key, params_ens, oups_ens, test_ys, problem_params, algo_params):
    """Summarise one training run as (final train loss, test loss, weight norm)."""
    # train loss: mean over ensemble and the last 100 recorded iterations.
    final_trainloss = oups_ens['lossterms']['total_loss'][:, -100:].mean()
    # test loss: per-member batch-mean sobolev loss on test_ys, then averaged.
    test_losses, _ = jax.vmap(v_nn.sobolev_loss_batch_mean, in_axes=(None, 0, None, None, None))(
        key, params_ens, test_ys, problem_params, algo_params
    )
    final_testloss = np.mean(test_losses)
    # weight norm is not stochastically approximated, so the last entry suffices.
    final_weightnorm = oups_ens['weight_norm'][:, -1].mean()
    return final_trainloss, final_testloss, final_weightnorm
# }}}
def main(problem_params, algo_params):
print(f'jax default backend = {jax.default_backend()}')
pl.rcParams['figure.figsize'] = (16, 10)
key = jax.random.PRNGKey(algo_params['seed'])
# find terminal LQR & xfs {{{
# then define a function unitsphere_to_dXf, which we then feed with uniform
# points from the unitsphere to arrive at boundary conditions for first
# backward shooting step. (if initial_shooting == 'lqr' not quite)
# in manifold case, this is still something which we should do purely
# on the tangent space...
if problem_params['m'] is not None:
K_lqr, P_lqr, Proj_tangent = pontryagin_utils.get_terminal_lqr(problem_params, return_tangent_projection=True)
# find the LQR controller in tangent space basis
# essentially undo what we did inside the lqr function...
P_lqr_tangent = Proj_tangent @ P_lqr @ Proj_tangent.T
K_lqr_tangent = K_lqr @ Proj_tangent.T
# here we can do cholesky just fine
L_lqr_tangent = np.linalg.cholesky(P_lqr_tangent)
assert rnd(L_lqr_tangent @ L_lqr_tangent.T, P_lqr_tangent) < 1e-6, 'cholesky decomposition wrong or inaccurate'
# same as below, except we project the (ambient space) point to the tangent space
# and then back. this has no affine part though and the nonzero x_eq is disregarded,
# so we change it to a lambda function which includes that.
# and finally, we project back to the manifold using the provided function.
unitsphere_to_dXf_linear = Proj_tangent.T @ np.linalg.inv(L_lqr_tangent) @ Proj_tangent * np.sqrt(problem_params['V_f']) * np.sqrt(2)
unitsphere_to_dXf = lambda x: problem_params['project_M'](problem_params['x_eq'] + x.T @ unitsphere_to_dXf_linear)
else:
# state space R^n
K_lqr, P_lqr = pontryagin_utils.get_terminal_lqr(problem_params)
# find a matrix mapping from the unit circle to the value level set
# previously done with eigendecomposition -> sqrt of eigenvalues.
# with the cholesky decomp the trajectories look about the same
# qualitatively. it is nicer so we'll keep that.
# cholesky decomposition says: P = L L.T but not L.T L
L_lqr = np.linalg.cholesky(P_lqr)
assert rnd(L_lqr @ L_lqr.T, P_lqr) < 1e-6, 'cholesky decomposition wrong or inaccurate'
# linear map from the hypersphere to the ellipse V_lqr(x) == V_f
unitsphere_to_dXf = lambda x: problem_params['x_eq'] + x.T @ np.linalg.inv(L_lqr) * np.sqrt(problem_params['V_f']) * np.sqrt(2)
# set xfs for initial batch of trajectories, depending on chosen
# method.
if algo_params['initial_shooting'] == 'uniform':
# purely random ass points for initial batch of trajectories.
normal_pts = jax.random.normal(key, shape=(algo_params['initial_batchsize'], problem_params['nx']))
unitsphere_pts = normal_pts / np.linalg.norm(normal_pts, axis=1)[:, None]
xfs = jax.vmap(unitsphere_to_dXf)(unitsphere_pts)
elif algo_params['initial_shooting'] == 'lqr':
def forward_sim_lqr_until_value(x0, P_lqr, v_goal):
# simulate forward using LQR value function.
# stop once we hit the desired value.
def forwardsim_rhs(t, x, args):
lam_x = P_lqr @ (x - problem_params['x_eq']) # <- for lqr instead
u = pontryagin_utils.u_star_general(x, lam_x, problem_params)
# u = -K_lqr @ (x - problem_params['x_eq'])
return problem_params['f'](x, u)
term = diffrax.ODETerm(forwardsim_rhs)
step_ctrl = diffrax.PIDController(
atol=algo_params['pontryagin_solver_atol'],
rtol=algo_params['pontryagin_solver_rtol'],
dtmin=algo_params['dtmin'],
dtmax=algo_params['dtmax'],
)
saveat = diffrax.SaveAt(steps=True, dense=True, t0=True, t1=True)
def event_fn(state, **kwargs):
x_err = state.y - problem_params['x_eq']
lqr_value = 0.5 * x_err @ P_lqr @ x_err
return lqr_value <= v_goal
terminating_event = diffrax.DiscreteTerminatingEvent(event_fn)
if problem_params['m'] is not None and algo_params['project_manifold']:
solver = pontryagin_utils.ProjectionSolver(project=problem_params['project_M'])
else:
solver = diffrax.Tsit5()
forward_sol = diffrax.diffeqsolve(
term, solver, t0=0., t1=10., dt0=0.01, y0=x0,
stepsize_controller=step_ctrl, saveat=saveat,
max_steps = algo_params['pontryagin_solver_maxsteps'],
throw=algo_params['throw'],
discrete_terminating_event=terminating_event,
)
return forward_sol
# sample uniform random points from surface of unit ball ||x|| = 1
key, normalkey = jax.random.split(key)
normal_pts = jax.random.normal(key, shape=(algo_params['initial_batchsize'], problem_params['nx']))
unitball_pts = normal_pts / np.linalg.norm(normal_pts, axis=1)[:, None]
# probably interior is more 'correct' here but surface should work
# too. to make it the interior i think it is multiplication by
# Unif([0, 1]) ** (1/n).
# copied from above but with higher value level
if problem_params['m'] is not None:
unitsphere_to_dV_linear = Proj_tangent.T @ np.linalg.inv(L_lqr_tangent) @ Proj_tangent * np.sqrt(algo_params['v_init']) * np.sqrt(2)
unitsphere_to_dV = lambda x: problem_params['project_M'](problem_params['x_eq'] + x.T @ unitsphere_to_dV_linear)
else:
unitsphere_to_dV = lambda x: problem_params['x_eq'] + x.T @ np.linalg.inv(L_lqr) * np.sqrt(algo_params['v_init']) * np.sqrt(2)
x0s = jax.vmap(unitsphere_to_dV)(unitball_pts)
sols = jax.vmap(forward_sim_lqr_until_value, in_axes=(0, None, None))(x0s, P_lqr, problem_params['V_f'])
# pl.figure('forward solver m(x)')
# pl.plot(jax.vmap(jax.vmap(problem_params['m']))(sols.ys).T, c='black', alpha=.1)
# pl.show()
xfs_unprojected = jax.vmap(lambda sol: sol.ys[sol.stats['num_accepted_steps']])(sols)
xfs = jax.vmap(problem_params['project_M'])(xfs_unprojected)
# mark the ones that stopped due to time or step limit as unusable
# because only the ones stopped due to DiscreteTerminatingEvent reached
# the low value sublevel set where we accept the LQR solution.
stopped_bc_terminatingevent = sols.result == 1
xfs = xfs.at[~stopped_bc_terminatingevent].set(np.nan)
else:
name = algo_params['initial_shooting']
raise ValueError(f'initial shooting method {name} does not exist')
# test if it worked
V_f = lambda x: 0.5 * (x - problem_params['x_eq']).T @ P_lqr @ (x - problem_params['x_eq'])
vfs = jax.vmap(V_f)(xfs)
# }}}
# define lots of boring ass functions {{{
solve_backward, f_extended = pontryagin_utils.define_backward_solver(
problem_params, algo_params
)
def solve_backward_lqr(x_f, algo_params):
# P_lqr = hessian of value fct.
# everything else follows from usual differentiation rules.
# V_f = lambda x: 0.5 * (x - problem_params['x_eq']).T @ P_lqr @ (x - problem_params['x_eq'])
v_f = V_f(x_f)
vx_f = jax.jacobian(V_f)(x_f)
state_f = {
'x': x_f,
't': 0,
'v': v_f,
'vx': vx_f,
}
if algo_params['pontryagin_solver_vxx']:
vxx_f = P_lqr
state_f['vxx'] = vxx_f
return solve_backward(state_f, v_upper=10. * algo_params['v_init'])
sols_orig = jax.vmap(solve_backward_lqr, in_axes=(0, None))(xfs, algo_params)
if problem_params['system_name'] == 'orbits' and algo_params['showfigs']:
thetas = np.linspace(-np.pi, np.pi, 300)
circle = np.array([np.sin(thetas), np.cos(thetas)]).T
pl.plot(circle[:, 0], circle[:, 1], c='black', alpha=.1, linestyle='--')
if algo_params['initial_shooting'] == 'lqr':
pl.plot(*np.split(sols.ys.reshape(-1, 2), [1], axis=1), '.-', label='forward sols', alpha=.1)
pl.plot(*np.split(sols_orig.ys['x'].reshape(-1, 2), [1], axis=1), '.-', label='backward sols', alpha=.1)
pl.ylim([0.9, 1.1]); pl.xlim([-0.3, 0.3])
pl.legend()
pl.show()
def select_train_pts(value_interval, sols):
# this is basically repeated in prune_and_train, should we always just use that one?
# (old) ideas for additional functionality:
# - include not only strictly the value interval, but at least n_min pts from each trajectory.
# so that if no points happen to be within the value band we include a couple (lower) ones
# to still hopefully improve the fit.
# - return only a random subsample of the data (with a specified fraction)
# - throw away points of the same trajectory that are closer than some threshold (in time or state space?)
# this is also a form of subsampling but maybe better than random.
v_lower, v_upper = value_interval
v_finite = np.logical_and(~np.isnan(sols.ys['v']), ~np.isinf(sols.ys['v']))
v_in_interval = np.logical_and(sols.ys['v'] >= v_lower, sols.ys['v'] <= v_upper)
# sols.ys['vxx'].shape == (N_trajectories, N_ts, nx, nx)
# get the frobenius norms of the hessian & throw out large ones.
if 'vxx' in sols.ys:
vxx_norms = np.linalg.norm(sols.ys['vxx'], axis=(2, 3))
vxx_acceptable = vxx_norms < algo_params['vxx_max_norm'] # some random upper bound based on looking at a plot of v vs ||vxx||
bool_train_idx = np.logical_and(v_in_interval, vxx_acceptable)
else:
bool_train_idx = v_in_interval
all_ys = jtm(lambda node: node[bool_train_idx], sols.ys)
perc = 100 * bool_train_idx.sum() / v_finite.sum()
print(f'full (train+test) dataset size: {bool_train_idx.sum()} points (= {perc:.2f}% of valid points)')
n_data = count_floats(all_ys)
print(f'corresponding to {n_data} degrees of freedom')
# check if there are still NaNs left -- should not be the case.
contains_nan = jtm(lambda n: np.isnan(n).any(), all_ys)
contains_nan_any = jax.tree_util.tree_reduce(operator.or_, contains_nan)
if contains_nan_any:
print('There are still NaNs in training data. dropping into debugger. have fun')
ipdb.set_trace()
return all_ys
def v_meanstd(x, vmap_params):
# find (empirical) mean and std. dev of value function.
vs_ensemble = jax.vmap(v_nn, in_axes=(0, None))(vmap_params, x)
v_mean = vs_ensemble.mean()
v_std = vs_ensemble.std()
return v_mean, v_std
def vx_meanstd(x, vmap_params):
# vmap for nn ensemble.
vx_fct = jax.jacobian(v_nn, argnums=1)
ensemble_vxs = jax.vmap(vx_fct, in_axes=(0, None))(vmap_params, x)
# now we have all_vxs.shape == (N_ensemble, nx)
# we want ensemble mean and std across axis 0.
# stds will be individual for each coordinate, sum/mean whatever later if you want.
vx_mean = ensemble_vxs.mean(axis=0)
vx_std = ensemble_vxs.std(axis=0)
return vx_mean, vx_std
v_meanstds = jax.jit(jax.vmap(v_meanstd, in_axes=(0, None)))
vx_meanstds = jax.jit(jax.vmap(vx_meanstd, in_axes=(0, None)))
def plot_decision_boundary(v_nn, vmap_params, problem_params):
# this x0 i got from random idpb experimentation by finding the
# points with (0, -1) angle and lowest value. among those it is the
# one with largest x.
x0 = np.array([ 2.398837 , 0.06769013, 0. , -1. , -1.166603 , 3.3332477 , -5.1929855 ], dtype=float)
x1 = x0 * np.array([-1, 1, -1, 1, -1, 1, -1])
ts = np.linspace(-1, 1, 200)
xs = np.linspace(x0, x1, 200)
mus, sigmas = v_meanstds(xs, vmap_params)
ax = pl.subplot(211)
pl.plot(ts, mus, label='value mean')
pl.fill_between(ts, mus - sigmas, mus + sigmas, color='C0', alpha=.2, label=f'value 1σ confidence')
pl.legend()
vx_mu, vx_sigma = vx_meanstds(xs, vmap_params)
pl.subplot(212, sharex=ax)
pl.plot(ts, vx_mu, label=problem_params['state_names'])
pl.gca().set_prop_cycle(None)
for j in range(7):
pl.fill_between(ts, vx_mu[:, j] - vx_sigma[:, j], vx_mu[:, j] + vx_sigma[:, j], alpha=.2)
pl.legend()
def plot_v_vx_line(xs, vmap_params, thetas=None):
    """Plot the value and its gradient (ensemble mean +- 1 sigma) along a sequence of states.

    The top subplot shows the value, the bottom one each gradient coordinate.
    Legend labels come from ``problem_params['state_names']`` of the enclosing
    scope.

    Args:
        xs: states stacked along axis 0, e.g. points on a line, optionally
            projected beforehand:
            ``xs = problem_params['project_M'](np.linspace(x0, x1, N))``.
        vmap_params: stacked NN ensemble parameters.
        thetas: x-axis coordinate for each state in ``xs``. Defaults to the
            sample index ``0 .. len(xs) - 1``.
    """
    if thetas is None:
        # the body needs an x-axis; without one it would raise a NameError.
        thetas = np.arange(xs.shape[0])

    mus, sigmas = v_meanstds(xs, vmap_params)
    vx_mu, vx_sigma = vx_meanstds(xs, vmap_params)

    ax = pl.subplot(211)
    pl.plot(thetas, mus, label='value mean')
    pl.fill_between(thetas, mus - sigmas, mus + sigmas, color='C0', alpha=.2, label='value 1σ confidence')
    pl.legend()

    pl.subplot(212, sharex=ax)
    pl.plot(thetas, vx_mu, label=problem_params['state_names'])
    # restart the colour cycle so each band gets the colour of its line.
    pl.gca().set_prop_cycle(None)
    # one band per state coordinate; taken from the data, not hard-coded.
    for j in range(vx_mu.shape[1]):
        pl.fill_between(thetas, vx_mu[:, j] - vx_sigma[:, j], vx_mu[:, j] + vx_sigma[:, j], alpha=.2)
    pl.legend()
def forward_sim_nn_until_value(x0, params, v_k, vmap=False):
    """Simulate the closed loop forward until the state is very likely inside {V <= v_k}.

    The feedback is ``u = pontryagin_utils.u_star_general(x, dV/dx, ...)``,
    where V is the ensemble-mean *unnormalised* value function
    (``v_nn_unnormalised``). Integration stops, via a discrete terminating
    event, once both conditions hold:

    - the upper confidence bound ``v_mean + 2 * v_std <= v_k``, i.e. the state
      lies in an inner approximation of the value level set;
    - the ensemble disagreement is small:
      ``v_std <= sigma_max_abs + v_mean * sigma_max_rel``.

    Args:
        x0: initial state.
        params: stacked NN ensemble parameters (leading axis = ensemble member).
        v_k: value level, in the same units as ``v_nn_unnormalised``.
        vmap: must be True; only the NN-ensemble case is implemented.

    Returns:
        The ``diffrax.diffeqsolve`` solution (steps, dense output, t0 and t1
        saved).

    Raises:
        NotImplementedError: if ``vmap`` is False.
    """
    # the terminating event below needs the ensemble; fail before building anything.
    if not vmap:
        raise NotImplementedError('only vmapped (NN ensemble) case implemented here.')

    def ensemble_values(x):
        # unnormalised value of every ensemble member at state x.
        return jax.vmap(v_nn_unnormalised, in_axes=(0, None))(params, x)

    def v_fct(x):
        # ensemble mean, reduced to a scalar so it can be differentiated.
        return ensemble_values(x).mean()

    def forwardsim_rhs(t, x, args):
        lam_x = jax.jacobian(v_fct)(x).squeeze()
        # lam_x = P_lqr @ x  # <- for lqr instead
        u = pontryagin_utils.u_star_general(x, lam_x, problem_params)
        return problem_params['f'](x, u)

    term = diffrax.ODETerm(forwardsim_rhs)
    step_ctrl = diffrax.PIDController(
        atol=algo_params['pontryagin_solver_atol'],
        rtol=algo_params['pontryagin_solver_rtol'],
        dtmin=algo_params['dtmin'],
        dtmax=algo_params['dtmax'],
    )
    saveat = diffrax.SaveAt(steps=True, dense=True, t0=True, t1=True)

    def event_fn(state, **kwargs):
        # Use the same unnormalised value function as the dynamics above, so
        # v_mean / v_std are in the units of v_k and of sigma_max. The raw
        # network v_nn (what v_meanstd evaluates) is not in those units.
        vs = ensemble_values(state.y)
        v_mean = vs.mean()
        v_std = vs.std()
        # we only quit once we're very sure that we're in the value level set.
        # thus we take an upper confidence band = overestimated value function
        # = inner approx of level set.
        is_very_likely_in_Vk = v_mean + 2 * v_std <= v_k
        sigma_max = algo_params['sigma_max_abs'] + v_mean * algo_params['sigma_max_rel']
        has_low_sigma = v_std <= sigma_max
        return np.logical_and(is_very_likely_in_Vk, has_low_sigma)

    terminating_event = diffrax.DiscreteTerminatingEvent(event_fn)

    if problem_params['m'] is not None and algo_params['project_manifold']:
        solver = pontryagin_utils.ProjectionSolver(project=problem_params['project_M'])
    else:
        solver = diffrax.Tsit5()

    forward_sol = diffrax.diffeqsolve(
        term, solver, t0=0., t1=10., dt0=0.01, y0=x0,
        stepsize_controller=step_ctrl, saveat=saveat,
        max_steps=algo_params['pontryagin_solver_maxsteps'],
        throw=algo_params['throw'],
        discrete_terminating_event=terminating_event,
    )
    return forward_sol
def solve_backward_nn_ens(x_f, vmap_params, v_upper, problem_params, algo_params):
v_fct = lambda x: jax.vmap(v_nn_unnormalised, in_axes=(0, None))(vmap_params, x).mean()
v_f = v_fct(x_f)
vx_f = jax.jacobian(v_fct)(x_f)
v_f_lqr = V_f(x_f)
vx_f_lqr = jax.jacobian(V_f)(x_f)
# if v_f_lqr < problem_params['V_f'], use that information instead.
use_lqr = v_f_lqr < problem_params['V_f']
v_f = jax.lax.select(use_lqr, v_f_lqr, v_f)
vx_f = jax.lax.select(use_lqr, vx_f_lqr, vx_f)
state_f = {
'x': x_f,
't': 0,
'v': v_f,
'vx': vx_f,
}
# if manifold, backproject here.
if problem_params['m'] is not None:
# easy part: project x to the manifold.
state_f['x'] = problem_params['project_M'](x_f)
# now we want to set the costate to 0 in the "irrelevant" normal direction.
# get normal & tangent space projections just like in nn_utils
B = jax.jacobian(problem_params['m'])(x_f)
assert B.shape == (problem_params['nx'],), 'only manifolds of codimension 1 supported rn'
B = B / np.linalg.norm(B)
# orthogonal projection to normal space at current x
P_normal = np.outer(B, B)
# orthogonal projection to tangent space at current x
P_tangent = np.eye(problem_params['nx']) - P_normal
# from this construction we have P_normal + P_tangent = I. can we
# thus just project a costate onto the tangent space? will this
# work out?
# the costate is in T*xM, the cotangent space, whereas the state
# derivative is in TxM. Together they can form the inner product
# <lambda, xdot> as they often do, which equals d/dt V(x(t)).
# We decompose lambda:
# lambda = (P_normal + P_tangent) lambda = lambda_normal + lambda_tangent.
# the inner product becomes <lambda, xdot> =
# = <P_normal lambda, xdot> + <P_tangent lambda, xdot>
# = lambda.T P_normal.T xdot + lambda.T P_tangent.T xdot | writing it out in R^n standard basis
# = <lambda, P_normal.T xdot> + <lambda, P_tangent.T xdot> | changing parentheses without effect & writing as inner product again
# = <lambda, P_normal xdot> + <lambda, P_tangent xdot> | projection matrices symmetric
# = 0 + <lambda, P_tangent xdot> | normal space is orthogonal to tangent space of which xdot is an element
# thus, we see that we can arbitrarily modify the costate in
# normal direction without affecting the relevant inner products.
# this is kind of obvious right? more formally this means
# (something like) the canonical map from T*x R^n to T*x M is a