How to shard a jax.Array object over an axis with odd-valued length. #25236

Answered by bmaxdk
jwtkeeble asked this question in Q&A

Sharding with Pspec_data fails because 5 is not divisible by 4, so JAX cannot evenly distribute the rows across the devices. As you suggested in your question, one option is to pad the array along that axis until its length is divisible by the number of devices.

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # Use 8 CPUs as our 'devices'

import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.debug import visualize_array_sharding

# Force the CPU backend (set once, before any devices are queried)
jax.config.update('jax_platforms', 'cpu')

n_devices = jax.device_count()
print(f"There are {n_devices} device(s) available")

# Where you see the # ValueError: One of device_put args wa…
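Since the snippet above is cut off, here is a minimal sketch of the padding approach, assuming a 5-row array sharded over a 1-D mesh of 4 devices. The axis name `'data'` and the helpers `pad_to_multiple` and `masked_row_mean` are hypothetical illustrations, not part of the original answer. The sketch also keeps a boolean mask so that the padded rows do not skew reductions such as a mean.

```python
import os
# Simulate 8 CPU devices; this must be set before jax is imported.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

jax.config.update('jax_platforms', 'cpu')


def pad_to_multiple(x, multiple, axis=0):
    """Hypothetical helper: zero-pad `x` along `axis` to a multiple of `multiple`.

    Returns the padded array and a boolean mask that is True for real rows.
    """
    length = x.shape[axis]
    pad = (-length) % multiple  # 0 if the length is already divisible
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis] = (0, pad)  # pad only at the end of the chosen axis
    x_padded = jnp.pad(x, pad_width)  # constant (zero) padding by default
    mask = jnp.arange(length + pad) < length
    return x_padded, mask


# 1-D mesh over 4 of the simulated devices, with a single axis named 'data'.
mesh = Mesh(np.array(jax.devices()[:4]), ('data',))
sharding = NamedSharding(mesh, P('data'))

x = jnp.arange(5 * 3, dtype=jnp.float32).reshape(5, 3)  # 5 rows: not divisible by 4
x_padded, mask = pad_to_multiple(x, 4, axis=0)           # 8 rows: 2 per device

x_sharded = jax.device_put(x_padded, sharding)
mask_sharded = jax.device_put(mask, sharding)


@jax.jit
def masked_row_mean(x, mask):
    # Weight padded rows by 0 and divide by the number of real rows only.
    w = mask[:, None].astype(x.dtype)
    return (x * w).sum(axis=0) / w.sum()


print(x_sharded.shape)  # (8, 3)
print(masked_row_mean(x_sharded, mask_sharded))
```

The mask matters whenever a computation depends on the row count or the padding value; for a plain sum, zero padding alone would be enough.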

Answer selected by jwtkeeble