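For this basic scenario, you could do something based on a convolution:

```python
import jax.numpy as jnp
import jax.scipy as jsp

m = jnp.array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14]])
vals = jnp.array([2, 6, 12])

# Keep only the target values; every other entry becomes 0.
masked = jnp.where(jnp.isin(m, vals), m, 0)
# A 1x3 box filter spreads each target value to its left and right neighbors.
convolved = jsp.signal.convolve(masked, jnp.ones((1, 3)), mode='same').astype(m.dtype)
# Wherever the filter produced nothing, keep the original entry.
result = jnp.where(convolved != 0, convolved, m)
print(result)
```

This would have to be modified if …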
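As a possible variation, here is a minimal sketch that computes the same 1x3 box sum with padding and slicing instead of `jsp.signal.convolve`, so the arithmetic stays in the input's integer dtype with no float round-trip. The name `spread_horizontally` is hypothetical. Like the convolution version, this sketch treats 0 as "untouched", so it assumes the target values are non-zero, and where the neighborhoods of two targets overlap their values are added together.

```python
import jax
import jax.numpy as jnp

@jax.jit
def spread_horizontally(m, vals):
    """Hypothetical helper: 1x3 box sum of the target values along each row,
    falling back to the original entries wherever that sum is 0."""
    # Keep only the target values; everything else becomes 0.
    masked = jnp.where(jnp.isin(m, vals), m, 0)
    # Pad one zero column on each side so the shifted slices stay in bounds.
    padded = jnp.pad(masked, ((0, 0), (1, 1)))
    # Left neighbor + self + right neighbor, computed in m's integer dtype.
    spread = padded[:, :-2] + padded[:, 1:-1] + padded[:, 2:]
    return jnp.where(spread != 0, spread, m)

m = jnp.array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14]])
print(spread_horizontally(m, jnp.array([2, 6, 12])).tolist())
# [[0, 2, 2, 2, 4], [6, 6, 6, 8, 9], [10, 12, 12, 12, 14]]
```

Slicing three shifted views is cheap for a fixed 1x3 window; for larger or 2-D windows the convolution form scales more naturally.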
Hi everyone,
As the title suggests, I want to assign different values to elements at different positions in an array based on multiple conditions simultaneously. Here's a simplified example. Given the array

`m = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]`

I want to modify the elements adjacent to the positions where `m` equals 2, 6, and 12, setting them to 2, 6, and 12, respectively. The result should be:

`m = [[0, 2, 2, 2, 4], [6, 6, 6, 8, 9], [10, 12, 12, 12, 14]]`
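To pin down the intended behavior, here is a slow, loop-based reference sketch. It assumes "adjacent" means the left and right neighbors in the same row, which is what the expected output shows; the helper name `reference` is hypothetical. It is only meant for checking a vectorized version on small inputs, not for use under `jit`.

```python
import jax.numpy as jnp

def reference(m, vals):
    """Hypothetical helper: for each element of `m` equal to one of `vals`,
    set it and its left/right neighbors (same row) to that value."""
    targets = {int(v) for v in vals}
    out = m
    rows, cols = m.shape
    for r in range(rows):
        for c in range(cols):
            v = int(m[r, c])  # read from the original array, not from `out`
            if v in targets:
                lo, hi = max(c - 1, 0), min(c + 2, cols)
                out = out.at[r, lo:hi].set(v)  # functional update returns a new array
    return out

m = jnp.array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14]])
print(reference(m, [2, 6, 12]).tolist())
# [[0, 2, 2, 2, 4], [6, 6, 6, 8, 9], [10, 12, 12, 12, 14]]
```

One semantic difference worth noting: when two neighborhoods overlap, this loop lets the later write win, whereas sum-then-`where` approaches add the overlapping values.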
I used `vmap` and `where` to achieve this, with code as follows:

My general approach is to use `vmap` to locate all target positions, modify them individually, and assign 0 to every position that hasn't been modified. Then I sum all of the resulting arrays. Finally, I use `where` to keep the original values at unmodified positions and take the accumulated sums from the previous step at modified positions.

However, I feel this method is overly complicated, and there might be a simpler, more efficient solution. Does anyone have any suggestions?
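The original code isn't shown here, so the following is only a guess at what the described `vmap` approach might look like, assuming the same horizontal neighborhood as the expected output; `spread_one` is a hypothetical name. `jnp.argwhere` returns an array whose shape depends on the data, so the position lookup runs eagerly as written.

```python
import jax
import jax.numpy as jnp

m = jnp.array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9],
               [10, 11, 12, 13, 14]])
vals = jnp.array([2, 6, 12])

# (row, col) of every target element; data-dependent shape, so computed eagerly.
positions = jnp.argwhere(jnp.isin(m, vals))

def spread_one(pos):
    """Hypothetical helper: an array equal to m[pos] at pos and its left/right
    neighbors in the same row, and 0 everywhere else."""
    r, c = pos[0], pos[1]
    rows = jnp.arange(m.shape[0])[:, None]
    cols = jnp.arange(m.shape[1])[None, :]
    window = (rows == r) & (jnp.abs(cols - c) <= 1)
    return jnp.where(window, m[r, c], 0)

# One zero-filled "patch" per target position, stacked along a new axis 0.
patches = jax.vmap(spread_one)(positions)
# Add the patches; positions that no patch touched stay 0.
summed = patches.sum(axis=0)
# Keep the original value wherever nothing was written.
result = jnp.where(summed != 0, summed, m)
print(result.tolist())
# [[0, 2, 2, 2, 4], [6, 6, 6, 8, 9], [10, 12, 12, 12, 14]]
```

The main extra cost compared with the convolution answer is that `patches` materializes one full-size array per target, with shape `(num_targets, rows, cols)`, before summing.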