-
I've noticed the same, but for sufficiently large problems the jax BFGS outperforms scipy. Could it be "initialization" overhead?
-
I thought so too, but even if you call the jax optimization a few times consecutively, the time required for each calculation remains roughly constant. I wonder if jax for some reason tries to compute the derivative, while scipy does everything numerically.
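One way to look at the derivative question from the scipy side is to compare the evaluation counters with and without an explicit gradient. The snippet below is only a sketch, not the example from the original post: it uses scipy's built-in `rosen` and `rosen_der` as a stand-in objective, and the starting point is an arbitrary choice. When `jac` is omitted, scipy's BFGS approximates the gradient by finite differences, so `nfev` grows well beyond `njev`. As far as I know, `jax.scipy.optimize.minimize` instead obtains the gradient through automatic differentiation of the objective, which is why the counters are not directly comparable between the two libraries.

```python
import numpy as np
from scipy.optimize import minimize, rosen, rosen_der

# Stand-in problem (assumption): 5-dimensional Rosenbrock function.
x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])

# No `jac` given: SciPy estimates the gradient with finite differences,
# so each gradient costs several extra function evaluations.
res_fd = minimize(rosen, x0, method="BFGS")

# Analytic gradient supplied: no finite-difference evaluations are needed.
res_an = minimize(rosen, x0, method="BFGS", jac=rosen_der)

print("finite differences: nfev =", res_fd.nfev, "njev =", res_fd.njev)
print("analytic gradient:  nfev =", res_an.nfev, "njev =", res_an.njev)
```

If the counters from the two libraries differ, that alone does not show that the algorithms differ; it may only reflect how gradients are obtained and counted.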
-
For operations on very small arrays, JAX tends to be orders of magnitude slower than numpy, and this is expected because JAX's computation model involves some per-computation overhead that is non-negligible for small operations. Here's a simpler example:

    In [1]: import numpy as np
    In [2]: import jax.numpy as jnp
    In [3]: x = np.random.rand(10, 10)
    In [4]: %timeit np.dot(x, x)
    1.05 µs ± 21.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    In [5]: y = jnp.asarray(x)
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    In [6]: %timeit jnp.dot(y, y).block_until_ready()
    129 µs ± 5.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

If you do operations on larger arrays, with proportionally more time spent on the actual computation, JAX is often faster. For example:

    In [7]: x = np.random.rand(2000, 2000)
    In [8]: %timeit np.dot(x, x)
    96.3 ms ± 6.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    In [9]: y = jnp.asarray(x)
    In [10]: %timeit jnp.dot(y, y).block_until_ready()
    45.9 ms ± 602 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In your case, your operations are on very small arrays, so I wouldn't expect JAX to out-perform numpy/scipy in terms of total computation time.
-
I'm comparing the BFGS solver in `scipy` and `jax.scipy`, and I am seeing a huge slowdown in jax. Is the jax implementation of the algorithm fundamentally different (see also the counts of function and Jacobian evaluations, `res.nfev` and `res.njev`, below)? Or am I using it inappropriately?

Here's a minimal example:
which gives the output: