TypeError: Argument 'cpu:0' of type <class 'jaxlib.xla_extension.CpuDevice'> is not a valid JAX type. #5821
Replies: 12 comments 1 reply
-
When you call …
desc = jax.device_put(descriptors, jax.devices()[0])
matches_jit = jit(match_faces3)
%timeit matches_jit(desc)
Also, as a side note: JAX's runtime model is asynchronous, so if you're timing operations you should use block_until_ready():
%timeit matches_jit(desc).block_until_ready()
See Asynchronous Dispatch for more information.
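The timing advice can also be sketched outside IPython, where %timeit is not available: put the input on the device once with jax.device_put, call the jitted function once so compilation is not included in the measurement, then time a call that ends in block_until_ready(). This is a minimal sketch; sum_of_squares is a hypothetical stand-in for match_faces3, whose code is not shown in the thread.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(x):
    # Hypothetical stand-in for the real jitted function.
    return jnp.sum(x * x)


# Move the data to a device once; only arrays are passed to the jitted function.
x = jax.device_put(np.random.rand(512).astype(np.float32), jax.devices()[0])

sum_of_squares(x).block_until_ready()  # warm-up call: triggers compilation

start = time.perf_counter()
result = sum_of_squares(x).block_until_ready()  # wait for the asynchronous result
elapsed = time.perf_counter() - start
print(f"{elapsed:.6f} s")
```

Without block_until_ready(), the measured time can reflect only how long it took to dispatch the work, not to finish it.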
-
I tried this previously but got this exception, so I tried it the other way:
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[512])>with<DynamicJaxprTrace(level=0/1)>. This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using …
-
It is hard to say for sure because the example you gave is incomplete, but I suspect …
-
I changed cos into a JAX function but got the same error. Actually, this is the original code, which I ran using NumPy; I'm using JAX to reduce the time taken:
distances = np.empty((len(descriptors), len(database)))
f.write("Descriptors")
f.write(str(descriptors))
time1 = time.time()
for i, desc in enumerate(descriptors):
    for j, identity in enumerate(database):
        dist = []
        for k, id_desc in enumerate(identity[1]):
            dist.append(cosine_dist(desc, id_desc))
        distances[i][j] = dist[np.argmin(dist)]
time2 = time.time() - time1
print("time2", time2)
f.write("Distances")
f.write(str(distances))
-
Thanks for the further information. If you want further help debugging this, I'd suggest putting together a minimal reproducible example and including it here: that will take the guesswork out of finding the cause of your error.
-
[('8174', …
descriptors = np.array([[-1.69,4.4,4.27,1.96,2.7,-5.73,-5.41,1.12,2.5,2.09,5.8,-8.7,6.7,4.1,7.4,6.3,9.7,2.4,6.4,3.3]])
This is what the descriptors look like. The cosine_func is a normal cos function.
-
Thanks. Rather than describing your code in words (what does "normal cos function" mean?), it would be more helpful if you could include a complete code snippet, i.e. something that someone could copy and paste into a runtime to see the same results you are seeing. The minimal reproducible example link from my earlier comment has more details on how to construct such an example.
-
import numpy as np
import time,datetime
from scipy.spatial.distance import cosine
import jax
import jax.numpy as jnp
f = open("write_out.txt", 'w+')
def cosine_dist(x, y):
    return cosine(x, y) * 0.5
# descriptors and database are defined elsewhere (not shown)
distances = np.empty((len(descriptors), len(database)))
f.write("Descriptors")
f.write(str(descriptors))
time1 = time.time()
for i, desc in enumerate(descriptors):
    for j, identity in enumerate(database):
        dist = []
        for k, id_desc in enumerate(identity[1]):
            dist.append(cosine_dist(desc, id_desc))
        distances[i][j] = dist[np.argmin(dist)]
time2 = time.time() - time1
print("time2", time2)
f.write("Distances")
f.write(str(distances))
The JAX version I tried to implement is:
distances = np.empty((len(desc), len(database)))
f.write("Descriptors")
f.write(str(desc))
time1 = time.time()
for i, descr in enumerate(desc):
    for j, identity in enumerate(database):
        dist = []
        for k, id_desc in enumerate(identity[1]):
            dist.append(cosine_dist(descr, id_desc))
        distances[i][j] = dist[jnp.argmin(jnp.asarray(dist))]
time2 = time.time() - time1
print("time2", time2)
distances = distances.tolist()
f.write(str(distances))
When I only made this change, the JAX version took longer, so I tried putting it in a jit function.
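Calling jnp functions one element at a time from a Python loop adds dispatch overhead on every call, which is one likely reason swapping np.argmin for jnp.argmin made the loop slower. A minimal sketch of the same computation vectorized in plain NumPy, assuming (as the loop implies) that descriptors has shape (n_desc, d) and that each identity[1] in database is a 2D array of shape (n_k, d). The names cosine_dist_matrix and match_faces_np are hypothetical, and the example data is random.

```python
import numpy as np


def cosine_dist_matrix(a, b):
    """0.5 * cosine distance between every row of a and every row of b."""
    a_n = a / np.linalg.norm(a, axis=1, keepdims=True)  # unit-length rows
    b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
    return 0.5 * (1.0 - a_n @ b_n.T)  # shape (len(a), len(b))


def match_faces_np(descriptors, database):
    # One column per identity: the smallest distance to any of its descriptors,
    # which is what dist[np.argmin(dist)] computes in the original loop.
    return np.stack(
        [cosine_dist_matrix(descriptors, np.asarray(identity[1])).min(axis=1)
         for identity in database],
        axis=1,
    )  # shape (len(descriptors), len(database))


# Hypothetical example data, for illustration only.
rng = np.random.default_rng(0)
descriptors = rng.standard_normal((3, 20))
database = [("a", rng.standard_normal((4, 20))), ("b", rng.standard_normal((2, 20)))]
distances = match_faces_np(descriptors, database)
print(distances.shape)  # (3, 2)
```

The remaining Python loop runs once per identity rather than once per descriptor pair, and it lets identities have different numbers of reference descriptors.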
-
Thanks! This is much clearer. The issue is that scipy's cosine function cannot be JIT-compiled, because it tries to convert its inputs to NumPy arrays; for example:
from scipy.spatial.distance import cosine
from jax import jit
import jax.numpy as jnp
jit(cosine)(jnp.arange(4), jnp.arange(4))
Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[4])>with<DynamicJaxprTrace(level=0/1)>.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
If you want to JIT-compile this computation, you'll have to implement the cosine distance using JAX functions; for example:
import numpy as np
import jax.numpy as jnp
from jax import jit
from scipy.spatial.distance import cosine
@jit
def jax_cosine(u, v):
    return 1 - jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))
u = np.random.rand(100)
v = np.random.rand(100)
print(np.allclose(cosine(u, v), jax_cosine(u, v)))
# True
Additionally, if you want your implementation to have good performance, you should avoid iterating over array axes in favor of vectorized operations via broadcasting or through tools like vmap.
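A sketch of the vmap route, redefining jax_cosine from the comment above so the snippet is self-contained. It assumes the reference descriptors can be stacked into a single array of shape (n_ids, n_k, d), i.e. every identity has the same number n_k of descriptors, because jit compiles for fixed array shapes; identities with different counts would need padding or another approach. match_faces_jax and the random example data are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def jax_cosine(u, v):
    return 1 - jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))


# pairwise(a, b)[i, k] == jax_cosine(a[i], b[k]):
# the inner vmap maps over rows of b, the outer vmap over rows of a.
pairwise = jax.vmap(jax.vmap(jax_cosine, in_axes=(None, 0)), in_axes=(0, None))


@jax.jit
def match_faces_jax(desc, id_descs):
    # desc: (n_desc, d); id_descs: (n_ids, n_k, d)
    per_id = jax.vmap(lambda ref: pairwise(desc, ref))(id_descs)  # (n_ids, n_desc, n_k)
    # Smallest scaled distance per identity, transposed to (n_desc, n_ids).
    return 0.5 * jnp.min(per_id, axis=-1).T


# Hypothetical example data, for illustration only.
rng = np.random.default_rng(0)
desc = jnp.asarray(rng.standard_normal((3, 20)), dtype=jnp.float32)
id_descs = jnp.asarray(rng.standard_normal((4, 5, 20)), dtype=jnp.float32)
distances = match_faces_jax(desc, id_descs).block_until_ready()
print(distances.shape)  # (3, 4)
```

No Python loop runs inside the jitted function, so there is nothing left to iterate over a traced array.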
-
Thank you, @jakevdp. Now I get:
in match_faces3(desc)
TypeError: list indices must be integers or slices, not DynamicJaxprTracer
-
That makes sense: you're trying to iterate over a traced array. As I mentioned, you should avoid explicit iteration in favor of either vectorized operations or vmap.
-
Actually, looking more closely, the issue is that …
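The TypeError quoted above is consistent with indexing a Python list with a traced value: in the JAX version of the loop, dist is a Python list, and under jit the index returned by jnp.argmin is a tracer, which a list cannot accept as an index. A minimal sketch reproducing and avoiding this; min_via_list and min_via_array are hypothetical names, and the exact exception class may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def min_via_list(values):
    dist = [2.0 * v for v in values]  # a Python list of traced scalars under jit
    # Under jit, jnp.argmin returns a tracer, and a Python list cannot be
    # indexed by a tracer, so this line fails while tracing.
    return dist[jnp.argmin(jnp.asarray(dist))]


def min_via_array(values):
    dist = jnp.stack([2.0 * v for v in values])  # a JAX array instead of a list
    return dist[jnp.argmin(dist)]  # same result as jnp.min(dist)


vals = [jnp.array(3.0), jnp.array(1.0), jnp.array(2.0)]

try:
    jax.jit(min_via_list)(vals)
    bad_failed = False
except Exception as err:  # the thread reports a TypeError here
    bad_failed = True
    print(type(err).__name__, err)

good = jax.jit(min_via_array)(vals)
print(float(good))  # 2.0
```

Since dist[argmin(dist)] is just the minimum, jnp.min(dist, axis=...) on an array avoids the indexing entirely.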
-
TypeError: Argument 'cpu:0' of type <class 'jaxlib.xla_extension.CpuDevice'> is not a valid JAX type.
I encountered this error while trying to run the code below.
desc is a multidimensional array.
I tried everything but couldn't figure out the problem. Please help, @mattjj.
Originally posted by @Joy-Preetha in #4416 (comment)
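The code from the original post is not included here, so the following is only a guess at the cause: the title error appears when a Device object (such as jax.devices()[0]) is passed as an argument to a jitted function, since jit can only trace array-like arguments. A minimal sketch of that failure and of the device_put alternative; scale is a hypothetical function, and newer JAX versions word the error message differently.

```python
import jax
import jax.numpy as jnp


def scale(x):
    return 2.0 * x


device = jax.devices()[0]

# Passing the Device object itself as an argument to a jitted function fails,
# because jit tries to treat every non-static argument as an array.
try:
    jax.jit(lambda x, d: scale(x))(jnp.ones(3), device)
    device_arg_failed = False
except Exception as err:  # reported as a TypeError in this thread
    device_arg_failed = True
    print(type(err).__name__, err)

# Instead, place the data on the device with jax.device_put and pass only arrays.
x = jax.device_put(jnp.ones(3), device)
y = jax.jit(scale)(x)
print(y.tolist())  # [2.0, 2.0, 2.0]
```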