It seems like you're calling jax.jit on every step, which can force JAX to retrace and recompile each time. Call jit and vmap once, before you start the loop, and then call the resulting jitted function inside the loop.
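Here is a minimal sketch of that suggestion. It assumes a hypothetical per-runner step(action, state, bound) that only moves a 2-D position and clips it, because the original step() is not shown in the thread. What matters is where jax.jit and jax.vmap are applied: once, outside the loop. Every iteration then reuses the same compiled function as long as the input shapes and dtypes stay the same.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-runner step: move a 2-D position by an action and
# clip it to a square world of half-width `bound`.
def step(action, state, bound):
    return jnp.clip(state + action, -bound, bound)

# Build the batched, compiled function ONCE, outside the loop.
# in_axes: map over actions and states, share the scalar bound.
batched_step = jax.jit(jax.vmap(step, in_axes=(0, 0, None)))

def run(actions_seq, states, bound):
    # actions_seq: (num_steps, num_runners, 2); states: (num_runners, 2)
    for actions in actions_seq:
        # Same shapes/dtypes every iteration, so the cached compilation is reused.
        states = batched_step(actions, states, bound)
    return states

# Usage: 4 runners, 10 steps of random actions.
key = jax.random.PRNGKey(0)
actions_seq = jax.random.normal(key, (10, 4, 2))
final = run(actions_seq, jnp.zeros((4, 2)), 1.0)
```

Writing jax.jit(jax.vmap(step)) inside the loop instead creates a new wrapped function object on every iteration, and that can defeat JAX's compilation cache.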
Hi,
I am new to JAX. I'm interested in it because I think JAX NumPy arrays might be a way to write if-else-heavy workflows that run on the GPU.
In my initial trials, JAX produces the same numerical results, but it is very slow. Could that be because my use case involves a lot of if-else branching? The branching is unavoidable, so my question is: in general, what does JAX do when it evaluates an if-else statement? Does it transfer the result from the GPU back to the CPU, as a print statement would? If so, I don't think JAX is suitable for my problem.
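The following sketch illustrates the general answer to the if-else question; it is not taken from the thread, and the function names are hypothetical. Under jax.jit, a Python if on a traced array cannot be evaluated at trace time, so JAX raises an error instead of copying the value back to the CPU. Data-dependent branches are written with jnp.where (compute both sides, select elementwise) or jax.lax.cond (choose a branch on the device). Under jax.vmap, a cond with a batched predicate is turned into a select, so both branches run.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def python_if(x):
    # A Python `if` on a traced value fails at trace time under jit.
    if x > 0:
        return x
    return -x

@jax.jit
def where_abs(x):
    # Elementwise select: both branches are computed, result stays on device.
    return jnp.where(x > 0, x, -x)

@jax.jit
def cond_abs(x):
    # Scalar predicate: lax.cond picks one branch on the device.
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

try:
    python_if(jnp.float32(-3.0))
    raised = False
except Exception:  # a jax.errors.ConcretizationTypeError (or subclass)
    raised = True

# Under vmap the predicate is batched, so cond is lowered to a select.
batched = jax.vmap(cond_abs)(jnp.array([-1.0, 2.0]))
```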
To make the question concrete: my simple step() function checks whether the agent tries to cross the world boundary and whether it hits a landmark. This is a typical skeleton for path-search code. It contains many (seemingly) unavoidable if-else statements and a for loop. With a lot of effort I made it @jit compatible, and it runs under vmap, so multiple runners doing the same thing can be bundled and run in parallel. But it is very slow: several hundred milliseconds for this simple computation, a few thousand times slower than an equivalent plain NumPy version.
Since I am new to JAX, I'd appreciate suggestions: is JAX suitable for this problem, and if so, how should I use it? If not, do you have any advice on approaching this kind of problem? Thank you.
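Since the original step() is not shown, here is a hedged sketch of how a boundary check, a landmark check, and the surrounding for loop might be written without Python control flow. Everything in it is an assumption for illustration: the names step, rollout and batched_rollout, the (2, 2) world_boundary layout of [low, high] corners, the HIT_RADIUS threshold, and the rules that a runner that would leave the world stays put and a runner that hits a landmark stops moving. The for loop is replaced with jax.lax.scan so the whole rollout compiles into one program, and jit and vmap are applied once.

```python
import jax
import jax.numpy as jnp
from jax import lax

HIT_RADIUS = 0.5  # hypothetical landmark radius

def step(action, pos, world_boundary, landmarks):
    """One branch-free step for a single runner (hypothetical semantics).

    action, pos: (2,); world_boundary: (2, 2) rows = [low, high];
    landmarks: (L, 2).
    """
    proposed = pos + action
    # "if the agent would cross the boundary, stay put" -> jnp.where
    inside = jnp.all((proposed >= world_boundary[0]) & (proposed <= world_boundary[1]))
    new_pos = jnp.where(inside, proposed, pos)
    # "if the agent hits any landmark" -> boolean reduction, no Python if
    dists = jnp.linalg.norm(landmarks - new_pos, axis=-1)
    hit = jnp.any(dists < HIT_RADIUS)
    return new_pos, hit

def rollout(actions, pos, world_boundary, landmarks):
    """Replace the Python for loop with lax.scan over (T, 2) actions."""
    def body(carry, action):
        pos, done = carry
        new_pos, hit = step(action, pos, world_boundary, landmarks)
        # Once a landmark was hit, freeze the runner (a `break` as data).
        new_pos = jnp.where(done, pos, new_pos)
        return (new_pos, done | hit), new_pos
    (final_pos, done), path = lax.scan(body, (pos, jnp.array(False)), actions)
    return final_pos, done, path

# Batch over runners, share the world and landmarks; compile once.
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(0, 0, None, None)))

# Usage: two runners, five steps each.
world = jnp.array([[0.0, 0.0], [10.0, 10.0]])
landmarks = jnp.array([[3.0, 0.0]])
actions = jnp.stack([jnp.tile(jnp.array([1.0, 0.0]), (5, 1)),
                     jnp.tile(jnp.array([-1.0, 0.0]), (5, 1))])
starts = jnp.array([[0.0, 0.0], [1.0, 5.0]])
final_pos, done, path = batched_rollout(actions, starts, world, landmarks)
```

Here runner 0 walks right along y = 0 and stops on the landmark at (3, 0), while runner 1 is blocked by the left wall at x = 0. The trade-off of this style is that both sides of every jnp.where are always computed, in exchange for a program with no host-side control flow.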
```
%timeit -n 5 -r 5 jax_step(actions, states, world_boundary, landmarks)
5 loops, best of 5: 291 ms per loop
```
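When comparing against NumPy, it can help to separate compilation time from run time. This is a sketch under stated assumptions: the step function and the array sizes are placeholders, not the poster's code. Call the compiled function once to trigger tracing and compilation, then time later calls with .block_until_ready(), because JAX dispatches work asynchronously and a timer could otherwise stop before the device has finished.

```python
import timeit
import jax
import jax.numpy as jnp

# Placeholder computation standing in for the real step (hypothetical).
def step(actions, states):
    moved = states + actions
    return jnp.where(jnp.abs(moved) < 1.0, moved, states)

jitted = jax.jit(jax.vmap(step))

actions = jnp.ones((1024, 2)) * 0.1
states = jnp.zeros((1024, 2))

# 1) Warm-up: the first call traces and compiles; keep it out of the timing.
jitted(actions, states).block_until_ready()

# 2) Block on the result so each timed call includes the device work.
totals = timeit.repeat(lambda: jitted(actions, states).block_until_ready(),
                       number=5, repeat=5)
per_call = min(totals) / 5  # seconds per call, best of 5 repeats
```

In IPython the equivalent is %timeit jitted(actions, states).block_until_ready() after one warm-up call.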