Quick discussion on best JIT usage: if we have a loop, and inside the loop there is a function that can't be JIT-compiled, what is the best way to JIT the whole thing?
I was thinking we might be able to JIT the whole thing and have it ignore `cannot_be_jit()`, something like:
But (a) I'm not sure how to do this, and (b) I'm not sure what that would gain us in efficiency over just JIT-compiling `can_be_jit()`. Any insight would be greatly appreciated!
When you JIT-compile a function, everything within the function is JIT-compiled, and there is no straightforward way to call non-JIT-compiled code from within JIT-compiled code. The reason is essentially that non-JIT code is executed by Python, while JIT-compiled code is executed by XLA, and there's no easy way for XLA to call back into Python.

If you want to mix JIT and non-JIT code, you can do that by JIT-compiling only the pieces you want to be JIT-compiled. In your example, it might be something like this:

```python
from jax import jit

jitted_can_be_jit = jit(can_be_jit)

def loop_function(x):
    for _ in range(10):
        x = cannot_be_jit(x)      # runs as ordinary Python
        x = jitted_can_be_jit(x)  # runs as compiled XLA code
    return x
```
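For illustration, here is a self-contained sketch of that pattern. The question doesn't show the bodies of `cannot_be_jit` and `can_be_jit`, so the ones below are hypothetical stand-ins: the first converts its input to NumPy and branches on the data (so it can't be traced), and the second is plain arithmetic. `TRACE_COUNT` and `whole_loop_jit_fails` are also made up for this example. The sketch shows that wrapping the whole loop in `jax.jit` fails during tracing, while jitting only `can_be_jit` works.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical counter: incremented only while JAX traces can_be_jit.
TRACE_COUNT = {"can_be_jit": 0}


def cannot_be_jit(x):
    # Hypothetical stand-in: converts to NumPy and branches on the data,
    # so it needs concrete values and cannot run under jax.jit.
    x = np.asarray(x)
    if x.sum() > 100.0:
        return x / 2.0
    return x + 1.0


def can_be_jit(x):
    # Python side effect: runs during tracing, not when the cached
    # compiled version is executed.
    TRACE_COUNT["can_be_jit"] += 1
    return 2.0 * x + 1.0


jitted_can_be_jit = jax.jit(can_be_jit)


def loop_function(x):
    for _ in range(10):
        x = cannot_be_jit(x)      # plain Python/NumPy
        x = jitted_can_be_jit(x)  # compiled XLA code, reused after the first call
    return x


def whole_loop_jit_fails(x):
    """Return True if jitting the entire loop fails while tracing."""
    try:
        jax.jit(loop_function)(x)
    except jax.errors.TracerArrayConversionError:
        # np.asarray() inside cannot_be_jit() received an abstract tracer.
        return True
    return False


if __name__ == "__main__":
    x0 = jnp.arange(4.0)
    print(whole_loop_jit_fails(x0))
    print(loop_function(x0))
    print(TRACE_COUNT["can_be_jit"])
```

With these stand-ins, `loop_function(jnp.arange(4.0))` returns `[28, 36, 44, 52]` and `TRACE_COUNT` stays at 1: `jitted_can_be_jit` is traced and compiled once for a `float32[4]` input, and the cached executable is reused on the other nine iterations. Each iteration still moves data between NumPy and JAX, so how much this gains over plain Python likely depends on how expensive the jitted part is.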