# hessian.py (forked from wiseodd/last_layer_laplace)
##########################################################################
#
# Courtesy of Felix Dangel: https://github.com/f-dangel/backpack
#
##########################################################################
"""Exact computation of full Hessian using autodiff."""
from torch import cat, zeros, stack
from torch.autograd import grad
from tqdm import tqdm, trange
def exact_hessian(f, parameters, show_progress=False):
r"""Compute all second derivatives of a scalar w.r.t. `parameters`.
The order of parameters corresponds to a one-dimensional
vectorization followed by a concatenation of all tensors in
`parameters`.
Parameters
----------
f : scalar torch.Tensor
Scalar PyTorch function/tensor.
parameters : list or tuple or iterator of torch.Tensor
Iterable object containing all tensors acting as variables of `f`.
show_progress : bool
Show a progressbar while performing the computation.
Returns
-------
torch.Tensor
Hessian of `f` with respect to the concatenated version
of all flattened quantities in `parameters`
Note
----
The parameters in the list are all flattened and concatenated
into one large vector `theta`. Return the matrix :math:`d^2 E /
d \theta^2` with
.. math::
(d^2E / d \theta^2)[i, j] = (d^2E / d \theta[i] d \theta[j]).
The code is a modified version of
https://discuss.pytorch.org/t/compute-the-hessian-matrix-of-a-
network/15270/3
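
    Examples
    --------
    A minimal usage sketch on a small quadratic (values chosen for
    illustration); its Hessian is constant:

    >>> import torch
    >>> x = torch.zeros(2, requires_grad=True)
    >>> f = x[0] ** 2 + 2 * x[0] * x[1]
    >>> exact_hessian(f, [x])
    tensor([[2., 2.],
            [2., 0.]])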
"""
params = list(parameters)
if not all(p.requires_grad for p in params):
raise ValueError("All parameters have to require_grad")
df = grad(f, params, create_graph=True)
# flatten all parameter gradients and concatenate into a vector
dtheta = None
for grad_f in df:
dtheta = (
grad_f.contiguous().view(-1)
if dtheta is None
else cat([dtheta, grad_f.contiguous().view(-1)])
)
# compute second derivatives
hessian_dim = dtheta.size(0)
hessian = zeros(hessian_dim, hessian_dim)
progressbar = tqdm(
iterable=range(hessian_dim),
total=hessian_dim,
desc="[exact] Full Hessian",
disable=(not show_progress),
)
for idx in progressbar:
df2 = grad(dtheta[idx], params, create_graph=True)
d2theta = None
for d2 in df2:
d2theta = (
d2.contiguous().view(-1)
if d2theta is None
else cat([d2theta, d2.contiguous().view(-1)])
)
hessian[idx] = d2theta
return hessian
def exact_hessian_diagonal_blocks(f, parameters, show_progress=True):
"""Compute diagonal blocks of a scalar function's Hessian.
Parameters
----------
f : scalar of torch.Tensor
Scalar PyTorch function
parameters : list or tuple or iterator of torch.Tensor
List of parameters whose second derivatives are to be computed
in a blockwise manner
show_progress : bool, optional
Show a progressbar while performing the computation.
Returns
-------
list of torch.Tensor
Hessian blocks. The order is identical to the order specified
by `parameters`
Note
----
For each parameter, `exact_hessian` is called.
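
    Examples
    --------
    A minimal usage sketch with two independent parameter tensors (shapes
    chosen for illustration); each block is the Hessian w.r.t. one
    flattened parameter:

    >>> import torch
    >>> a = torch.ones(2, requires_grad=True)
    >>> b = torch.ones(1, requires_grad=True)
    >>> f = (a ** 2).sum() + 3 * b[0] ** 2
    >>> blocks = exact_hessian_diagonal_blocks(f, [a, b], show_progress=False)
    >>> [tuple(block.shape) for block in blocks]
    [(2, 2), (1, 1)]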
"""
return [exact_hessian(f, [p], show_progress=show_progress)
for p in parameters]
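

# Minimal self-check sketch: the quadratic below has the constant Hessian
# [[2, 3], [3, 10]], so the result of `exact_hessian` can be compared
# against the closed-form matrix. Values are illustrative only.
if __name__ == "__main__":
    import torch

    x = torch.tensor([1.0, -2.0], requires_grad=True)
    f = x[0] ** 2 + 3 * x[0] * x[1] + 5 * x[1] ** 2
    expected = torch.tensor([[2.0, 3.0], [3.0, 10.0]])
    hessian = exact_hessian(f, [x], show_progress=True)
    assert torch.allclose(hessian, expected)
    print(hessian)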