[PoC][MoE & EP] integrate with FSDP & CP
ghstack-source-id: aecf4d7daa0885a2184d1780ab8c29ec61b310af
Pull Request resolved: #726
tianyu-l committed Dec 10, 2024
1 parent 8dc9d33 commit f7fb732
Showing 5 changed files with 199 additions and 90 deletions.
11 changes: 11 additions & 0 deletions torchtitan/config_manager.py
@@ -364,6 +364,17 @@ def __init__(self):
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--experimental.expert_parallel_degree",
type=int,
default=1,
help="""
Expert parallelism degree. 1 means disabled.
When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree.
When expert_parallel_mode is 'dp2ep', it has to equal k * context_parallel_degree,
where k >= 1 and k divides data_parallel_shard_degree.
""",
)
self.parser.add_argument(
"--experimental.expert_parallel_mode",
type=str,
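
The dp2ep constraint in the new help text can be sanity-checked in a few lines; the sketch below is illustrative (the degrees are made up, not taken from any torchtitan config) and mirrors the check that ParallelDims._validate performs for dp2ep.

def check_dp2ep_degrees(ep: int, cp: int, dp_shard: int) -> None:
    # ep must equal k * cp with k >= 1, and k must divide dp_shard
    # (equivalently: ep % cp == 0 and (dp_shard * cp) % ep == 0)
    assert ep % cp == 0, "expert_parallel_degree must be a multiple of context_parallel_degree"
    k = ep // cp
    assert dp_shard % k == 0, "ep // cp must divide data_parallel_shard_degree"

# e.g. dp_shard=4, cp=2, ep=4  ->  k=2, which divides dp_shard=4
check_dp2ep_degrees(ep=4, cp=2, dp_shard=4)
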
46 changes: 30 additions & 16 deletions torchtitan/parallelisms/expert_parallel.py
@@ -329,22 +329,37 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:

# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ParallelStyle):
def __init__(
self,
*,
tp_mesh: DeviceMesh,
ep_mesh: DeviceMesh,
):
super().__init__()
# TODO: we have to pass in the meshes in addition to device_mesh,
# as DeviceMesh currently raises
# "Cannot create a submesh from a submesh."
self.tp_mesh = tp_mesh
self.ep_mesh = ep_mesh

@staticmethod
def _prepare_input_fn(mod, inputs, device_mesh):
def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh):
input_tensor = inputs[0]
# input_tensor has placements Shard(1) on the tp mesh
assert not isinstance(input_tensor, DTensor)

# a2a(ep)
input_tensor = DTensor.from_local(input_tensor, device_mesh["dp"], (Shard(1),))
input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),))
input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local()
# ag(tp)
input_tensor = DTensor.from_local(input_tensor, device_mesh["tp"], (Shard(1),))
input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),))
input_tensor = input_tensor.redistribute(placements=(Replicate(),))

return input_tensor

def _partition_fn(self, name, module, device_mesh):
# NOTE: the following code should work when FSDP is applied to the non-expert modules.
@staticmethod
def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh):
# TODO: FSDP doesn't support sharding a 2D Tensor
# module.register_parameter(
# "gate_proj",
# nn.Parameter(
@@ -364,9 +379,8 @@ def _partition_fn(self, name, module, device_mesh):
# ),
# ) # Column-wise sharding

# NOTE: the following code works when FSDP is not applied.
# TODO: the above 2D sharding (only on experts) causes optimizer foreach to fail
# TODO: applying FSDP to the non-expert params should resolve the issue
# NOTE: instead, for MoE experts, we shard on the EP mesh and then "forget" it
# TODO: this is problematic from the DCP perspective
module.register_parameter(
"gate_proj",
nn.Parameter(
Expand All @@ -376,7 +390,7 @@ def _partition_fn(self, name, module, device_mesh):
module.gate_proj, device_mesh, [Shard(0), Shard(2)]
).to_local()
),
device_mesh["tp"],
tp_mesh,
(Shard(2),),
)
),
@@ -390,7 +404,7 @@ def _partition_fn(self, name, module, device_mesh):
module.down_proj, device_mesh, [Shard(0), Shard(1)]
).to_local()
),
device_mesh["tp"],
tp_mesh,
(Shard(1),),
)
),
@@ -404,20 +418,20 @@ def _partition_fn(self, name, module, device_mesh):
module.up_proj, device_mesh, [Shard(0), Shard(2)]
).to_local()
),
device_mesh["tp"],
tp_mesh,
(Shard(2),),
)
),
) # Column-wise sharding

@staticmethod
def _prepare_output_fn(mod, outputs, device_mesh):
def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh):
# outputs have placements Partial() on the tp mesh

# rs(tp)
outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
# a2a(ep)
outputs = DTensor.from_local(outputs, device_mesh["dp"], (Shard(0),))
outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),))
outputs = outputs.redistribute(placements=(Shard(1),)).to_local()

return outputs
@@ -426,7 +440,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
self._partition_fn,
self._prepare_input_fn,
self._prepare_output_fn,
partial(self._partition_fn, self.tp_mesh, self.ep_mesh),
partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh),
partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh),
)
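
For reference, a hedged usage sketch of ExpertTensorParallel: the GroupedExperts stand-in, its parameter shapes, and the 2D mesh layout are illustrative assumptions (the real wiring lives in torchtitan's parallelize code, which is not part of this hunk); only the class above and the public torch.distributed APIs are taken as given.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.parallelisms.expert_parallel import ExpertTensorParallel

# assume 4 ranks arranged as ep=2 x tp=2; the mesh dim names here are illustrative
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("ep", "tp"))

class GroupedExperts(nn.Module):
    # stand-in for an MoE experts module holding 3D weights (num_experts, in, out)
    def __init__(self, num_experts=8, dim=256, hidden=512):
        super().__init__()
        self.gate_proj = nn.Parameter(torch.empty(num_experts, dim, hidden))
        self.down_proj = nn.Parameter(torch.empty(num_experts, hidden, dim))
        self.up_proj = nn.Parameter(torch.empty(num_experts, dim, hidden))

experts = GroupedExperts()
parallelize_module(
    experts,
    device_mesh=mesh_2d,  # _partition_fn shards the expert weights over this full 2D mesh
    parallelize_plan=ExpertTensorParallel(
        tp_mesh=mesh_2d["tp"],  # submeshes of the root mesh are passed in explicitly,
        ep_mesh=mesh_2d["ep"],  # per the submesh-of-submesh limitation noted above
    ),
)
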
119 changes: 99 additions & 20 deletions torchtitan/parallelisms/parallel_dims.py
@@ -18,21 +18,24 @@ class ParallelDims:
cp: int
tp: int
pp: int
ep: int
ep_mode: str
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, cp, tp, pp = (
dp_replicate, dp_shard, cp, tp, pp, ep = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
self.ep,
)
for d in (dp_replicate, cp, tp, pp):
for d in (dp_replicate, cp, tp, pp, ep):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."

@@ -41,17 +44,80 @@ def _validate(self):
dp = self.world_size // (cp * tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

if ep > 1:
assert self.ep_mode in ["tp", "tp2ep", "dp2ep"]
if self.ep_mode == "tp" or self.ep_mode == "tp2ep":
assert ep == tp
elif self.ep_mode == "dp2ep":
# EP borrows the full cp degree and part of the dp_shard degree
assert ep % cp == 0 and (dp_shard * cp) % ep == 0
else:
self.ep_mode = "none"

def build_mesh_with_dp2ep(self, device_type):
# In dp2ep, dp_shard and ep are composed from the base mesh dims:
#   dp_shard = dp_shard_1 * dp_shard_2
#   ep       = dp_shard_2 * cp
dp_shard_1 = self.dp_shard * self.cp // self.ep
dp_shard_2 = self.ep // self.cp

dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, dp_shard_1, dp_shard_2, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard_1", "dp_shard_2", "cp", "tp"],
):
# dp_shard_1 is needed even if it's 1, since its FSDP wrapping
# helps the MoE layers do mixed precision training
if d > 1 or name == "dp_shard_1":
dims.append(d)
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submeshes here to ensure all required process groups are
# initialized:
# Mesh for data loading
dp_mesh_dim_names = []
if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
dp_mesh_dim_names.append("dp_shard_1")
if "dp_shard_2" in names:
dp_mesh_dim_names.append("dp_shard_2")
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

# Mesh for param sharding
dp_shard_cp_mesh_dim_name = []
dp_shard_cp_mesh_dim_name.append("dp_shard_1")
if "dp_shard_2" in names:
dp_shard_cp_mesh_dim_name.append("dp_shard_2")
if self.cp_enabled:
dp_shard_cp_mesh_dim_name.append("cp")
mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")

# Mesh for ep
ep_mesh_dim_names = []
if "dp_shard_2" in names:
ep_mesh_dim_names.append("dp_shard_2")
if self.cp_enabled:
ep_mesh_dim_names.append("cp")
assert len(ep_mesh_dim_names) > 0
mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")

return mesh

def build_mesh(self, device_type):
if self.ep_mode == "dp2ep":
return self.build_mesh_with_dp2ep(device_type)

dims = []
names = []
for d, name in zip(
@@ -60,26 +126,35 @@ def build_mesh(self, device_type):
):
if d > 1:
dims.append(d)
if (name == "dp_replicate" and self.dp_shard == 1) or (
name == "dp_shard" and self.dp_replicate == 1
):
names.append("dp")
else:
names.append(name)
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submeshes here to ensure all required process groups are
# initialized
if self.dp_replicate > 1 and self.dp_shard > 1: # HSDP
mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
# initialized:
# Mesh for data loading
dp_mesh_dim_names = []
if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")

if self.dp_shard_enabled:
dp_mesh_dim_names.append("dp_shard")

if self.cp > 1:
if self.dp_replicate > 1 and self.dp_shard > 1: # HSDP
mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
elif self.dp_shard > 1: # FSDP
mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

# Mesh for param sharding
dp_shard_cp_mesh_dim_name = []
if self.dp_shard_enabled:
dp_shard_cp_mesh_dim_name.append("dp_shard")

if self.cp_enabled:
dp_shard_cp_mesh_dim_name.append("cp")

if dp_shard_cp_mesh_dim_name != []:
mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")

return mesh

@@ -107,6 +182,10 @@ def tp_enabled(self):
def pp_enabled(self):
return self.pp > 1

@property
def ep_enabled(self):
return self.ep > 1

@property
def loss_parallel_enabled(self):
return self.tp > 1 and self.enable_loss_parallel
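
The factorization in build_mesh_with_dp2ep is easiest to follow with concrete numbers; the sketch below replays the arithmetic for a hypothetical 16-GPU run with dp_shard=4, cp=2, tp=2, ep=4 (pp=1, dp_replicate=1); these degrees are made up for illustration.

# hypothetical degrees; world_size = dp_shard * cp * tp = 16
dp_shard, cp, tp, ep = 4, 2, 2, 4

dp_shard_1 = dp_shard * cp // ep   # 2 -> outer FSDP dim, always kept in the mesh
dp_shard_2 = ep // cp              # 2 -> inner FSDP dim that EP borrows

assert dp_shard_1 * dp_shard_2 == dp_shard
assert dp_shard_2 * cp == ep

# flattened submeshes built on top of the base mesh (pp / dp_replicate omitted here):
#   "dp"          <- dp_shard_1 x dp_shard_2        (size 4)  for data loading
#   "dp_shard_cp" <- dp_shard_1 x dp_shard_2 x cp   (size 8)  for FSDP param sharding
#   "ep"          <- dp_shard_2 x cp                (size 4)  for expert parallelism
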

