diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 44089249..f1e5aaa1 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -364,6 +364,17 @@ def __init__(self):
             default=1,
             help="Context parallelism degree. 1 means disabled.",
         )
+        self.parser.add_argument(
+            "--experimental.expert_parallel_degree",
+            type=int,
+            default=1,
+            help="""
+                Expert parallelism degree. 1 means disabled.
+                When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree.
+                When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree,
+                where k >= 1 and k | data_parallel_shard_degree.
+            """,
+        )
         self.parser.add_argument(
             "--experimental.expert_parallel_mode",
             type=str,
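
The help text above compresses the dp2ep sizing rule into one sentence. As a hedged aside (not part of the patch; the function name and values below are illustrative), this is the same constraint that `ParallelDims._validate` enforces later in this diff, spelled out with a worked example:

```python
# Hedged sketch: the expert_parallel_degree constraint from the help text,
# written as a standalone check with illustrative names.
def valid_ep_degree(ep: int, ep_mode: str, tp: int, cp: int, dp_shard: int) -> bool:
    if ep == 1:
        return True  # EP disabled
    if ep_mode in ("tp", "tp2ep"):
        # EP reuses the TP ranks, so the two degrees must match.
        return ep == tp
    if ep_mode == "dp2ep":
        # ep = k * cp with k dividing dp_shard, i.e. EP borrows all of CP
        # plus a slice of the FSDP dimension.
        return ep % cp == 0 and (dp_shard * cp) % ep == 0
    return False

# With dp_shard=8 and cp=2, the valid dp2ep degrees are 2, 4, 8, and 16.
assert valid_ep_degree(4, "dp2ep", tp=1, cp=2, dp_shard=8)
assert not valid_ep_degree(6, "dp2ep", tp=1, cp=2, dp_shard=8)
```
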
diff --git a/torchtitan/parallelisms/expert_parallel.py b/torchtitan/parallelisms/expert_parallel.py
index 72f33015..3653093a 100644
--- a/torchtitan/parallelisms/expert_parallel.py
+++ b/torchtitan/parallelisms/expert_parallel.py
@@ -329,22 +329,37 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
 class ExpertTensorParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        tp_mesh: DeviceMesh,
+        ep_mesh: DeviceMesh,
+    ):
+        super().__init__()
+        # TODO: has to pass in the meshes in addition to device_mesh,
+        # as there's an issue from DeviceMesh that
+        # "Cannot create a submesh from a submesh."
+        self.tp_mesh = tp_mesh
+        self.ep_mesh = ep_mesh
+
     @staticmethod
-    def _prepare_input_fn(mod, inputs, device_mesh):
+    def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh):
         input_tensor = inputs[0]
         # input_tensor of placements Shard(1) on the tp mesh
+        assert not isinstance(input_tensor, DTensor)
 
         # a2a(ep)
-        input_tensor = DTensor.from_local(input_tensor, device_mesh["dp"], (Shard(1),))
+        input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),))
         input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local()
         # ag(tp)
-        input_tensor = DTensor.from_local(input_tensor, device_mesh["tp"], (Shard(1),))
+        input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),))
         input_tensor = input_tensor.redistribute(placements=(Replicate(),))
 
         return input_tensor
 
-    def _partition_fn(self, name, module, device_mesh):
-        # NOTE: the following code should work when FSDP is applied on the non-expert modules.
+    @staticmethod
+    def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh):
+        # TODO: FSDP doesn't support sharding a 2D Tensor
         # module.register_parameter(
         #     "gate_proj",
         #     nn.Parameter(
@@ -364,9 +379,8 @@ def _partition_fn(self, name, module, device_mesh):
         #     ),
         # )  # Column-wise sharding
 
-        # NOTE: the following code works when FSDP is not applied.
-        # TODO: the above 2D sharding (only on experts) causes optimizer foreach to fail
-        # TODO: apply FSDP on the non-expert params should resolve the issue
+        # NOTE: instead, for MoE experts, we shard on the EP mesh and then "forget" it
+        # TODO: this is problematic from the DCP perspective
         module.register_parameter(
             "gate_proj",
             nn.Parameter(
@@ -376,7 +390,7 @@ def _partition_fn(self, name, module, device_mesh):
                             module.gate_proj, device_mesh, [Shard(0), Shard(2)]
                         ).to_local()
                     ),
-                    device_mesh["tp"],
+                    tp_mesh,
                     (Shard(2),),
                 )
             ),
@@ -390,7 +404,7 @@ def _partition_fn(self, name, module, device_mesh):
                             module.down_proj, device_mesh, [Shard(0), Shard(1)]
                         ).to_local()
                     ),
-                    device_mesh["tp"],
+                    tp_mesh,
                     (Shard(1),),
                 )
             ),
@@ -404,20 +418,20 @@ def _partition_fn(self, name, module, device_mesh):
                             module.up_proj, device_mesh, [Shard(0), Shard(2)]
                         ).to_local()
                     ),
-                    device_mesh["tp"],
+                    tp_mesh,
                     (Shard(2),),
                 )
             ),
         )  # Column-wise sharding
 
     @staticmethod
-    def _prepare_output_fn(mod, outputs, device_mesh):
+    def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh):
         # outputs of placements Partial() on the tp mesh
 
         # rs(tp)
         outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
         # a2a(ep)
-        outputs = DTensor.from_local(outputs, device_mesh["dp"], (Shard(0),))
+        outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),))
         outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
 
         return outputs
@@ -426,7 +440,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
             module,
             device_mesh,
-            self._partition_fn,
-            self._prepare_input_fn,
-            self._prepare_output_fn,
+            partial(self._partition_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh),
+            partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh),
         )
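
To see what the `_prepare_input_fn` / `_prepare_output_fn` hooks above do to a tensor, here is a minimal standalone sketch (not part of the patch; it assumes a 4-GPU job arranged as a 2x2 ep-by-tp mesh, illustrative shapes, and a placeholder file name, launched via torchrun). The input that arrives as Shard(1) on the tp mesh is first all-to-all'd over the ep mesh and then all-gathered over the tp mesh, matching the a2a(ep) / ag(tp) comments in the class:

```python
# Hedged sketch of the dp2ep-with-TP input flow in ExpertTensorParallel.
# Run with: torchrun --nproc-per-node 4 sketch.py   (assumes 4 GPUs)
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("ep", "tp"))
ep_mesh, tp_mesh = mesh["ep"], mesh["tp"]

# Local shard of shape (experts_per_rank, tokens, dim); the tokens dim is
# already Shard(1) on the tp mesh when the hook receives it.
x = torch.randn(4, 8, 16, device="cuda")

# a2a(ep): trade tokens for experts, Shard(1) -> Shard(0) on the ep mesh
x = DTensor.from_local(x, ep_mesh, (Shard(1),))
x = x.redistribute(placements=(Shard(0),)).to_local()
# ag(tp): replicate the routed tokens across the tp mesh
x = DTensor.from_local(x, tp_mesh, (Shard(1),))
x = x.redistribute(placements=(Replicate(),))
print(f"rank {dist.get_rank()}: local shape after a2a+ag = {tuple(x.to_local().shape)}")
```
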
diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py
index 2ee9b87b..b38f23af 100644
--- a/torchtitan/parallelisms/parallel_dims.py
+++ b/torchtitan/parallelisms/parallel_dims.py
@@ -18,6 +18,8 @@ class ParallelDims:
     cp: int
     tp: int
     pp: int
+    ep: int
+    ep_mode: str
     world_size: int
     enable_loss_parallel: bool
 
@@ -25,14 +27,15 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
             self.dp_replicate,
             self.dp_shard,
             self.cp,
             self.tp,
             self.pp,
+            self.ep,
         )
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
             assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
 
         assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -41,17 +44,80 @@ def _validate(self):
         dp = self.world_size // (cp * tp * pp)
         self.dp_shard = dp_shard = dp // dp_replicate
 
-        assert dp_replicate >= 1
         assert dp_shard >= 1
-        assert cp >= 1, cp
-        assert tp >= 1, tp
-        assert pp >= 1, pp
         assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
             f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
             f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
         )
 
+        if ep > 1:
+            assert self.ep_mode in ["tp", "tp2ep", "dp2ep"]
+            if self.ep_mode == "tp" or self.ep_mode == "tp2ep":
+                assert ep == tp
+            elif self.ep_mode == "dp2ep":
+                # EP would borrow all cp and some dp_shard degree
+                assert ep % cp == 0 and (dp_shard * cp) % ep == 0
+        else:
+            self.ep_mode = "none"
+
+    def build_mesh_with_dp2ep(self, device_type):
+        # In dp2ep, dp_shard and ep are derived submeshes:
+        # dp_shard = dp_shard_1 * dp_shard_2
+        # ep = dp_shard_2 * cp
+        dp_shard_1 = self.dp_shard * self.cp // self.ep
+        dp_shard_2 = self.ep // self.cp
+
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.dp_replicate, dp_shard_1, dp_shard_2, self.cp, self.tp],
+            ["pp", "dp_replicate", "dp_shard_1", "dp_shard_2", "cp", "tp"],
+        ):
+            # dp_shard_1 is needed even if it's 1, whose FSDP wrapping
+            # helps the MoE layers do mixed precision training
+            if d > 1 or name == "dp_shard_1":
+                dims.append(d)
+                names.append(name)
+
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        names = tuple(names)
+        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+        # Create all the submesh here to ensure all required process groups are
+        # initialized:
+        # Mesh for data loading
+        dp_mesh_dim_names = []
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+        dp_mesh_dim_names.append("dp_shard_1")
+        if "dp_shard_2" in names:
+            dp_mesh_dim_names.append("dp_shard_2")
+        mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_name = []
+        dp_shard_cp_mesh_dim_name.append("dp_shard_1")
+        if "dp_shard_2" in names:
+            dp_shard_cp_mesh_dim_name.append("dp_shard_2")
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_name.append("cp")
+        mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")
+
+        # Mesh for ep
+        ep_mesh_dim_names = []
+        if "dp_shard_2" in names:
+            ep_mesh_dim_names.append("dp_shard_2")
+        if self.cp_enabled:
+            ep_mesh_dim_names.append("cp")
+        assert len(ep_mesh_dim_names) > 0
+        mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
+
+        return mesh
+
     def build_mesh(self, device_type):
+        if self.ep_mode == "dp2ep":
+            return self.build_mesh_with_dp2ep(device_type)
+
         dims = []
         names = []
         for d, name in zip(
@@ -60,26 +126,35 @@ def build_mesh(self, device_type):
         ):
             if d > 1:
                 dims.append(d)
-                if (name == "dp_replicate" and self.dp_shard == 1) or (
-                    name == "dp_shard" and self.dp_replicate == 1
-                ):
-                    names.append("dp")
-                else:
-                    names.append(name)
+                names.append(name)
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
         mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+
         # Create all the submesh here to ensure all required process groups are
-        # initialized
-        if self.dp_replicate > 1 and self.dp_shard > 1:  # HSDP
-            mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp")
+        # initialized:
+        # Mesh for data loading
+        dp_mesh_dim_names = []
+        if self.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+
+        if self.dp_shard_enabled:
+            dp_mesh_dim_names.append("dp_shard")
 
-        if self.cp > 1:
-            if self.dp_replicate > 1 and self.dp_shard > 1:  # HSDP
-                mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
-            elif self.dp_shard > 1:  # FSDP
-                mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
+        if dp_mesh_dim_names != []:
+            mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+
+        # Mesh for param sharding
+        dp_shard_cp_mesh_dim_name = []
+        if self.dp_shard_enabled:
+            dp_shard_cp_mesh_dim_name.append("dp_shard")
+
+        if self.cp_enabled:
+            dp_shard_cp_mesh_dim_name.append("cp")
+
+        if dp_shard_cp_mesh_dim_name != []:
+            mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp")
 
         return mesh
 
@@ -107,6 +182,10 @@ def tp_enabled(self):
     def pp_enabled(self):
         return self.pp > 1
 
+    @property
+    def ep_enabled(self):
+        return self.ep > 1
+
     @property
     def loss_parallel_enabled(self):
         return self.tp > 1 and self.enable_loss_parallel
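
To make the dp_shard_1 / dp_shard_2 bookkeeping in `build_mesh_with_dp2ep` concrete, here is a worked example in plain arithmetic (not part of the patch; the 32-rank configuration is illustrative):

```python
# Hedged worked example: derived mesh sizes for dp_shard=8, cp=2, tp=2, ep=4.
pp, dp_replicate, dp_shard, cp, tp, ep = 1, 1, 8, 2, 2, 4
world_size = pp * dp_replicate * dp_shard * cp * tp  # 32 ranks

dp_shard_1 = dp_shard * cp // ep                      # 4
dp_shard_2 = ep // cp                                 # 2
assert dp_shard_1 * dp_shard_2 == dp_shard            # dp_shard is split in two
assert dp_shard_2 * cp == ep                          # ep borrows dp_shard_2 and cp

# Sizes of the flattened submeshes created on top of the base mesh:
dp = dp_shard_1 * dp_shard_2                          # 8,  data loading
dp_shard_cp = dp_shard_1 * dp_shard_2 * cp            # 16, dense-param sharding
ep_flat = dp_shard_2 * cp                             # 4,  expert parallelism
print(dict(world_size=world_size, dp=dp, dp_shard_cp=dp_shard_cp, ep=ep_flat))
```
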
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index feab87ea..a2e21e27 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -35,7 +35,6 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.parallelisms.utils import check_if_feature_in_pytorch
 
 
 def parallelize_llama(
@@ -66,16 +65,15 @@ def parallelize_llama(
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )
 
-    ep_mode = job_config.experimental.expert_parallel_mode
-    if ep_mode != "none":
+    if parallel_dims.ep_mode != "none":
         apply_ep(
             model,
-            ep_mode=ep_mode,
-            dp_mesh=world_mesh["dp"] if parallel_dims.dp_shard_enabled else None,
+            ep_mode=parallel_dims.ep_mode,
+            ep_mesh=world_mesh["ep"] if parallel_dims.ep_mode == "dp2ep" else None,
             tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
-            dp_tp_mesh=(
-                world_mesh["dp", "tp"]
-                if parallel_dims.dp_shard_enabled and parallel_dims.tp_enabled
+            ep_tp_mesh=(
+                world_mesh["ep", "tp"]
+                if parallel_dims.ep_mode == "dp2ep" and parallel_dims.tp_enabled
                 else None
             ),
         )
@@ -94,48 +92,39 @@ def parallelize_llama(
     if (
         parallel_dims.dp_shard_enabled
+        or parallel_dims.cp_enabled
+        or parallel_dims.ep_mode == "dp2ep"
     ):  # apply FSDP or HSDP, potentially with Context Parallel
-        try:
-            dp_mesh_dim_names = (
-                ("dp_replicate", "dp_shard")
-                if parallel_dims.dp_replicate_enabled
-                else ("dp",)
-            )
-            dp_mesh = (
-                world_mesh["dp_cp"]
-                if parallel_dims.cp_enabled
-                else world_mesh[(*dp_mesh_dim_names,)]
-            )
-        except IndexError:
-            # note: this is a workaround of the above logic for old pytorch version
-            # where https://github.com/pytorch/pytorch/pull/138945 is not included
-            # throw a warning to encourage users to upgrade to a newer pytorch version
-            check_if_feature_in_pytorch(
-                "DeviceMesh flattening over 3D+ meshes",
-                "https://github.com/pytorch/pytorch/pull/138945",
-                "2.6.0.dev20241030",
-            )
-            # TODO: remove this workaround once PyTorch 2.6 is released
-            dp_mesh_dim_names = (
-                ("dp_replicate", "dp_shard")
-                if parallel_dims.dp_replicate_enabled
-                else ("dp",)
-            )
-            # note that mesh can only be flattened from the finest-grained mesh dimensions
-            dp_mesh = (
-                world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
-                if parallel_dims.cp_enabled
-                else world_mesh[dp_mesh_dim_names]
+
+        if not parallel_dims.dp_shard_enabled and parallel_dims.dp_replicate_enabled:
+            # Composability of DDP + CP is not supported.
+            raise RuntimeError(
+                "Composability of DDP + CP or DDP + EP is not supported."
             )
 
-        # apply_fsdp(
-        #     model,
-        #     dp_mesh,
-        #     param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
-        #     reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-        #     pp_enabled=parallel_dims.pp_enabled,
-        #     cpu_offload=job_config.training.enable_cpu_offload,
-        # )
+        # the mesh dim names of which the model params are sharded on
+        dp_mesh_dim_names = []
+        if parallel_dims.dp_replicate_enabled:
+            dp_mesh_dim_names.append("dp_replicate")
+        dp_mesh_dim_names.append("dp_shard_cp")
+
+        # the mesh dim names of which the MoE params are sharded on
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_mode == "dp2ep":
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_1")
+
+        apply_fsdp(
+            model,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
+            pp_enabled=parallel_dims.pp_enabled,
+            cpu_offload=job_config.training.enable_cpu_offload,
+            ep_enabled=(parallel_dims.ep_mode == "dp2ep"),
+            dp_mod_ep_mesh=world_mesh[tuple(dp_mod_ep_mesh_dim_names)],
+        )
 
         if parallel_dims.dp_replicate_enabled:
             logger.info("Applied HSDP to the model")
@@ -254,9 +243,9 @@ def apply_tp(
 def apply_ep(
     model: nn.Module,
     ep_mode: str,
-    dp_mesh: Optional[DeviceMesh] = None,
+    ep_mesh: Optional[DeviceMesh] = None,
     tp_mesh: Optional[DeviceMesh] = None,
-    dp_tp_mesh: Optional[DeviceMesh] = None,
+    ep_tp_mesh: Optional[DeviceMesh] = None,
 ):
     from torch.distributed.tensor import Partial
     from torch.distributed.tensor.parallel import PrepareModuleOutput
@@ -317,10 +306,10 @@ def apply_ep(
         elif ep_mode == "dp2ep":
             if not tp_mesh:
-                assert dp_mesh is not None
+                assert ep_mesh is not None
                 parallelize_module(
                     module=transformer_block.moe.experts,
-                    device_mesh=dp_mesh,
+                    device_mesh=ep_mesh,
                     # input / output sharding on the tokens dim
                     parallelize_plan=ExpertParallel(
                         input_layouts=Shard(1),
@@ -329,7 +318,7 @@ def apply_ep(
                 )
             else:
                 # dp2ep with TP (no Router Parallel)
-                assert dp_tp_mesh is not None
+                assert ep_tp_mesh is not None
                 moe_plan = {
                     # input / output sharding on the seqlen dim
                     "moe": PrepareModuleInputOutput(
@@ -356,8 +345,10 @@ def apply_ep(
 
                 parallelize_module(
                     module=transformer_block.moe.experts,
-                    device_mesh=dp_tp_mesh,
-                    parallelize_plan=ExpertTensorParallel(),
+                    device_mesh=ep_tp_mesh,
+                    parallelize_plan=ExpertTensorParallel(
+                        tp_mesh=tp_mesh, ep_mesh=ep_tp_mesh
+                    ),
                 )
 
     logger.info(f"Applied {ep_mode} Expert Parallelism to the model")
@@ -466,6 +457,8 @@ def apply_fsdp(
     reduce_dtype: torch.dtype,
     pp_enabled: bool,
     cpu_offload: bool = False,
+    ep_enabled: bool = False,
+    dp_mod_ep_mesh: Optional[DeviceMesh] = None,
 ):
     """
     Apply data parallelism to the model. FSDP2 is used here.
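
The `apply_fsdp` change wraps `transformer_block.moe.experts` with its own `fully_shard` call on the narrower dp_mod_ep mesh before the block-level wrap, so expert weights are sharded only across ranks that do not already split them via expert parallelism. Below is a minimal sketch of that nesting; it is not the patch's code: the toy module, the 8-rank 4x2 mesh, and the meta-device init are illustrative assumptions, and it would be launched via torchrun.

```python
# Hedged sketch of two-level FSDP wrapping: experts on a smaller mesh, the
# rest of the block on the full flattened mesh.
# Run with: torchrun --nproc-per-node 8 sketch.py   (assumes 8 GPUs)
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard_1", "dp_shard_2"))
dp_shard_cp = mesh["dp_shard_1", "dp_shard_2"]._flatten(mesh_dim_name="dp_shard_cp")

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(64, 64)
        self.experts = nn.Linear(64, 64)  # stand-in for the MoE expert weights

with torch.device("meta"):
    block = Block()

# inner wrap: expert weights are sharded only over dp_shard_1 (the dp_mod_ep mesh)
fully_shard(block.experts, mesh=mesh["dp_shard_1"])
# outer wrap: remaining params are sharded over the flattened dp_shard_cp mesh
fully_shard(block, mesh=dp_shard_cp)
```
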
@@ -484,6 +477,16 @@ def apply_fsdp(
         # As an optimization, do not reshard after forward for the last
         # transformer block since FSDP would prefetch it immediately
         reshard_after_forward = int(layer_id) < len(model.layers) - 1
+
+        fsdp_mod_ep_config = fsdp_config.copy()
+        fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
+        if ep_enabled:
+            fully_shard(
+                transformer_block.moe.experts,
+                **fsdp_mod_ep_config,
+                reshard_after_forward=reshard_after_forward,
+            )
+
         fully_shard(
             transformer_block,
             **fsdp_config,
diff --git a/train.py b/train.py
index 6545ab1a..057dcd85 100644
--- a/train.py
+++ b/train.py
@@ -50,6 +50,8 @@ def main(job_config: JobConfig):
         cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
+        ep=job_config.experimental.expert_parallel_degree,
+        ep_mode=job_config.experimental.expert_parallel_mode,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
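
End-to-end, the two new knobs flow from JobConfig into ParallelDims as shown in the train.py hunk above. A hedged usage sketch follows (the 32-rank sizing is illustrative, and actually building the mesh requires a torchrun launch of that size):

```python
# Hedged sketch: constructing ParallelDims with the new ep / ep_mode fields.
from torchtitan.parallelisms.parallel_dims import ParallelDims

parallel_dims = ParallelDims(
    dp_replicate=1,
    dp_shard=-1,  # -1 means: infer as world_size // (cp * tp * pp) = 8
    cp=2,
    tp=2,
    pp=1,
    ep=4,  # dp2ep rule: ep % cp == 0 and (dp_shard * cp) % ep == 0
    ep_mode="dp2ep",
    world_size=32,
    enable_loss_parallel=True,
)
world_mesh = parallel_dims.build_mesh(device_type="cuda")  # 4 x 2 x 2 x 2 base mesh
ep_mesh = world_mesh["ep"]  # the flattened dp_shard_2 x cp submesh used by apply_ep
```
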