Hi all,
I'm having trouble finding a fast and scalable way of computing the following in JAX:
Consider a neural network f(x, theta), where x is an input point and theta denotes the set of p network parameters. For simplicity, let's assume the network has a single output dimension. Denoting the Jacobian with respect to the network parameters evaluated at a single input point x_i by J(x_i; theta), and noting that J(x_i; theta) is a p-dimensional vector [J(x_i; theta)_1, ..., J(x_i; theta)_p], I would like to compute the gradient with respect to theta of the sum of squared Jacobian entries over n input points, i.e., I would like to compute the p-dimensional vector
v = grad_theta sum_{i=1}^n sum_{j=1}^p J(x_i; theta)_j^2
# write grad_theta explicitly:
  = d / d theta ( sum_{i=1}^n sum_{j=1}^p J(x_i; theta)_j^2 )
# move d / d theta inside the sum over input points x_i:
  = sum_{i=1}^n ( d / d theta ( sum_{j=1}^p J(x_i; theta)_j^2 ) )
# write out the p components, using J(x_i; theta)_j = d f(x_i; theta) / d theta_j:
  = [
      sum_{i=1}^n ( d / d theta_1 ( sum_{j=1}^p ( d f(x_i; theta) / d theta_j )^2 ) ),
      ...,
      sum_{i=1}^n ( d / d theta_p ( sum_{j=1}^p ( d f(x_i; theta) / d theta_j )^2 ) )
    ]
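Equivalently, since the network has a single output, J(x_i; theta) is just the per-example gradient of f with respect to theta, so the quantity I'm after is v = grad_theta sum_{i=1}^n || grad_theta f(x_i; theta) ||^2. Here is a minimal sketch of that target in JAX, with a toy scalar-output model standing in for my real network (apply, params, xs are hypothetical placeholders, not my actual setup):

import jax
import jax.numpy as jnp

# toy stand-in for f(x, theta): single scalar output per example
def apply(params, x):
    # params = (w, b) with w of shape (d,) and b a scalar
    w, b = params
    return jnp.tanh(x @ w + b)

def sq_grad_norm_sum(params, xs):
    # per-example parameter gradients, i.e. the vectors J(x_i; theta)
    per_example_grads = jax.vmap(jax.grad(apply), in_axes=(None, 0))(params, xs)
    # sum of squared entries over all parameter leaves and all examples
    return sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(per_example_grads))

# toy data, just to make the sketch runnable
params = (jnp.ones(4), jnp.array(0.0))
xs = jnp.ones((8, 4))  # n = 8 toy inputs
# v = grad_theta of the scalar above
v = jax.jit(jax.grad(sq_grad_norm_sum))(params, xs)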
However, I've encountered the following two problems:
Even for a relatively small number of network parameters (~1 million), I get OOM errors when computing the gradients for more than ~10 input points (using jitting and a GTX 1080 with ~12GB of memory).
Even when setting the number of input points to <= 10, computing the gradients becomes prohibitively slow for networks with >1 million parameters. I need to compute the p-dimensional gradient vector at each gradient step when training a neural network, so each computation needs to be fast.
Ideally, I would like to be able to compute the gradient over a mini-batch of size 128 for networks as large as ResNet18 (or larger if possible) without getting OOM errors, with each gradient computation taking as little time as possible (<= 0.05s/gradient computation on a single GTX 1080). For reference, my current implementation, which naively computes the sum over the squared Jacobian entries and input points and uses JAX's grad to compute the gradients, takes ~0.05s/gradient step for n=10 on a small network (3-layer CNN) and 0.2s/gradient step for n=10 on a slightly larger network (8-layer CNN, ~500k parameters). See below for sample code.
Is there any way to use JAX's functionality to compute the gradients v efficiently and in a way that scales to large networks and to n of up to ~200? Any pointers or suggestions as to how this could be done or how to improve the code below would be greatly appreciated!
Thank you!
from typing import Callable

import haiku as hk
import jax
import tree
from jax import numpy as jnp


def sum_jac(
    fwd_fn: Callable, params: hk.Params
) -> jnp.ndarray:
    """Sum of squared Jacobian entries over the batch and all parameters.

    @param fwd_fn: a function that only takes in parameters and returns the
        model output of shape (batch_dim, output_dim)
    @param params: the model parameters
    """
    jacobian = jax.jacobian(fwd_fn)(params)

    def _get_diag_mat(jac):
        # jac has shape (batch_dim, output_dim, params_dims...)
        # jac_2D has shape (batch_dim * output_dim, nb_params)
        batch_dim, output_dim = jac.shape[:2]
        jac_2D = jnp.reshape(jac, (batch_dim * output_dim, -1))
        # row-wise sum of squares, i.e. the diagonal of jac_2D @ jac_2D.T
        mat = jnp.einsum("ij,ij->i", jac_2D, jac_2D)
        mat = jnp.reshape(mat, (batch_dim, output_dim))
        # mat has shape (batch_dim, output_dim)
        return mat

    diag_mat = tree.map_structure(_get_diag_mat, jacobian)
    diag_mat_sum_array = jnp.stack(tree.flatten(diag_mat), axis=0).sum(axis=0)
    return diag_mat_sum_array.sum()
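And this is roughly how sum_jac enters my training step: I close over the mini-batch to get a params-only forward function and take jax.grad of the summed squared Jacobian entries. net_apply below is a placeholder for the apply function of my actual hk.transform-ed model, so treat this as a sketch rather than my exact code:

import functools

@functools.partial(jax.jit, static_argnums=0)
def jac_norm_grad(net_apply, params, x_batch):
    def summed_sq_jac(p):
        # params-only forward pass over the current mini-batch, as sum_jac expects
        fwd_fn = lambda q: net_apply(q, None, x_batch)  # (batch_dim, output_dim)
        return sum_jac(fwd_fn, p)
    # v = grad_theta of the summed squared Jacobian entries
    return jax.grad(summed_sq_jac)(params)

# usage (placeholder names): v = jac_norm_grad(net.apply, params, x_batch)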