Hi all, I'm trying to shard a large array over multiple devices via `NamedSharding`, but JAX raises a `ValueError` because the array's first dimension is not evenly divisible by the mesh axis size. Is there an easy way to fix this? Or do I need to simply pad the array with zeros, and subsequently mask these padded entries, to ensure each shard has the same shape?
Replies: 1 comment
```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # use 8 CPUs as our 'devices'

import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.debug import visualize_array_sharding

jax.config.update('jax_platforms', 'cpu')

n_devices = jax.device_count()
print(f"There are {n_devices} device(s) available")

# Create a mesh with a 'data' axis of size 4 and a 'model' axis of size 2
mesh = jax.make_mesh(axis_shapes=(n_devices // 2, 2),
                     axis_names=('data', 'model'))

values = jnp.arange(5 * 6).reshape(5, 6)
print(f"Original values shape: {values.shape}")

# Sharding the unpadded array along 'data' raises:
# ValueError: One of device_put args was given the sharding of
# NamedSharding(mesh=Mesh('data': 4, 'model': 2), spec=PartitionSpec('data', None), memory_kind=unpinned_host),
# which implies that the global size of its dimension 0 should be divisible by 4,
# but it is equal to 5 (full shape: (5, 6))

# Calculate the padding needed along the 'data' axis to make dimension 0 divisible by 4
pad_size = (-values.shape[0]) % mesh.shape['data']  # pad_size = 3
values_padded = jnp.pad(values, ((0, pad_size), (0, 0)), mode='constant', constant_values=0)
print(f"Padded values shape: {values_padded.shape}")

# Pspec_model = P(None, 'model')  # replicate across 'data', shard across 'model' [WORKS even unpadded]
Pspec_data = P('data', None)      # shard across 'data', replicate across 'model' [FAILS unless padded]

# Create a sharding object from the mesh and PartitionSpec
my_sharding = NamedSharding(mesh, Pspec_data)
# my_sharding = NamedSharding(mesh, Pspec_model)

# shard_values = jax.device_put(values, device=my_sharding)  # raises the ValueError above
shard_values = jax.device_put(values_padded, device=my_sharding)

# Visualize the array sharding
visualize_array_sharding(shard_values)
```
Here is a more useful source.
Using `Pspec_data` fails since 5 is not divisible by 4: JAX cannot evenly distribute the rows across the four devices on the `'data'` axis. So, to answer your question: yes, as you suspected, you need to pad the array along that axis (and mask the padded entries if they would affect your computation).