Hi, I understand that when JAX runs on a GPU, time is spent transferring data back and forth between the GPU and the CPU. If I only use JAX on the CPU, is there still such a problem?

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update('jax_platform_name', 'cpu')

A = np.ones((3, 3))  # placeholder: the question does not show how A is defined
A_jnp = jnp.array(A)
A_jnp.device_buffer.device()
# CpuDevice(id=0)
```

So, does the CPU device mean…
Thanks for the question! There is a long discussion on this at #4486. The short answer is that, whenever possible, CPU-backed DeviceArrays are created from NumPy arrays as views (without copying the data), but XLA has some restrictions on byte alignment that can necessitate copies in some situations. For example:

```python
import jax.numpy as jnp
import numpy as np

# Creating a NumPy array from a JAX array
x_jax = jnp.arange(10)
x_numpy = np.asarray(x_jax)
print("pointers match:", np.byte_bounds(x_numpy)[0] == x_jax.device_buffer.unsafe_buffer_pointer())
# pointers match: True

# Creating a JAX array from a NumPy array
# (note: np.arange(10) is int64 on most platforms, and with the default 32-bit
# mode JAX converts it to int32, which by itself also requires a copy)
x_numpy = np.arange(10)
x_jax = jnp.asarray(x_numpy)
print("pointers match:", np.byte_bounds(x_numpy)[0] == x_jax.device_buffer.unsafe_buffer_pointer())
# pointers match: False

# Creating a JAX array from a byte-aligned NumPy array
x_numpy = np.asarray(x_jax)  # NumPy view of a buffer that meets XLA's alignment restrictions
x_jax_2 = jnp.asarray(x_numpy)
print("pointers match:", np.byte_bounds(x_numpy)[0] == x_jax_2.device_buffer.unsafe_buffer_pointer())
# pointers match: True
```
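To make the alignment point a little more concrete, here is a minimal NumPy-only sketch for checking whether an array's data starts on an aligned address, and for allocating an array that does. The 64-byte `ALIGNMENT` value and the helper names `data_pointer`, `is_aligned` and `aligned_empty` are hypothetical choices for illustration: this thread does not state the exact alignment XLA requires, and whether `jnp.asarray` then reuses the buffer still depends on your JAX/XLA version and on the dtype. `data_pointer` reads `__array_interface__` because `np.byte_bounds` was moved to `np.lib.array_utils.byte_bounds` in NumPy 2.0.

```python
import numpy as np

# Hypothetical alignment used for illustration only; the exact requirement of
# XLA's CPU backend is not stated in this thread and may differ by version.
ALIGNMENT = 64


def data_pointer(a: np.ndarray) -> int:
    """Address of the array's first element (works on NumPy 1.x and 2.x)."""
    return a.__array_interface__["data"][0]


def is_aligned(a: np.ndarray, alignment: int = ALIGNMENT) -> bool:
    """True if the array's data starts on an `alignment`-byte boundary."""
    return data_pointer(a) % alignment == 0


def aligned_empty(shape, dtype, alignment: int = ALIGNMENT) -> np.ndarray:
    """Allocate an uninitialized array whose data starts on an aligned address.

    Over-allocates a raw byte buffer by `alignment` bytes, slices it at the
    first aligned offset, then reinterprets those bytes with the requested dtype.
    """
    dtype = np.dtype(dtype)
    nbytes = int(np.prod(shape)) * dtype.itemsize
    raw = np.empty(nbytes + alignment, dtype=np.uint8)
    offset = (-data_pointer(raw)) % alignment  # bytes to the next aligned address
    return raw[offset:offset + nbytes].view(dtype).reshape(shape)


x = aligned_empty(10, np.float32)
print(is_aligned(x))      # True
print(is_aligned(x[1:]))  # False: the view starts 4 bytes past the boundary
```

To check whether a conversion was zero-copy, you could compare `data_pointer(x_numpy)` with the JAX array's buffer address. In recent JAX versions `jax.Array` exposes `unsafe_buffer_pointer()` directly (the `device_buffer` attribute used above has been deprecated), but treat this as an assumption to verify against your installed version. The over-allocate-and-slice trick wastes at most `alignment` bytes per array, which is usually an acceptable cost.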