
jax device array on cpu #6055

Answered by jakevdp
JiahaoYao asked this question in Q&A
Mar 13, 2021 · 1 comments · 1 reply

Thanks for the question! There is a long discussion on this at #4486; the short answer is that, whenever possible, CPU-backed DeviceArrays are created from NumPy arrays as views (without copying the data), but XLA has byte-alignment restrictions that can force a copy in some situations.

For example:

import jax.numpy as jnp
import numpy as np

# Creating a numpy array from a JAX array
x_jax = jnp.arange(10)
x_numpy = np.asarray(x_jax)
print("pointers match:", np.byte_bounds(x_numpy)[0] == x_jax.device_buffer.unsafe_buffer_pointer())
# pointers match: True

# Creating a jax array from a numpy array
x_numpy = np.arange(10)
x_jax = jnp.asarray(x_numpy)
print("pointers match:", 
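
To make the byte-alignment point concrete, here is a minimal NumPy-only sketch of how one might check whether an array's data pointer falls on a given boundary. The 64-byte figure (`ASSUMED_ALIGNMENT`) and the helper `is_aligned` are assumptions for illustration, not values or functions taken from XLA or JAX; the alignment XLA actually requires may differ.

```python
import numpy as np

# Assumption for illustration only; not taken from XLA's source.
ASSUMED_ALIGNMENT = 64


def is_aligned(arr: np.ndarray, alignment: int = ASSUMED_ALIGNMENT) -> bool:
    """Return True if the array's first byte sits on an `alignment`-byte boundary."""
    return arr.ctypes.data % alignment == 0


# Over-allocate a byte buffer so an aligned start address can be chosen inside it.
n_items = 10
n_bytes = n_items * np.dtype(np.int64).itemsize
raw = np.zeros(n_bytes + ASSUMED_ALIGNMENT + 8, dtype=np.uint8)
start = (-raw.ctypes.data) % ASSUMED_ALIGNMENT  # bytes to skip to reach a boundary

# Same underlying buffer, two int64 views: one on the boundary, one 8 bytes past it.
aligned = raw[start:start + n_bytes].view(np.int64)
shifted = raw[start + 8:start + 8 + n_bytes].view(np.int64)

print("aligned view on boundary:", is_aligned(aligned))
print("shifted view on boundary:", is_aligned(shifted))
```

Both views are valid NumPy arrays, but only the first starts on the assumed boundary. An array like `shifted` is the kind of input for which a zero-copy view may not be possible, so comparing pointers as in the example above would be expected to report a mismatch in that case.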

Answer selected by JiahaoYao