PP hangs when pipeline_parallel_microbatches < pipeline_parallel_degree #775
Thanks for flagging @cassanof. There is a bug in the microbatch-count calculation for the multi-stage schedules; this PR fixes it: #781 But oddly, I don't get a hang when trying to reproduce your issue for
Perhaps you are using an older PyTorch version; maybe a fix was added in the pytorch repo? Anyway, once #781 lands, could you try again and see whether it fixes your issue? Thanks!
Thanks a lot for picking this up! I'm using the following version of pytorch:
Happy to report if the issue is resolved after your PR!
#781 was landed!
It still seems to hang at the first microbatch. By the way, I should have mentioned this earlier: I am using PP=8 with 8 GPUs, so it's just 1D PP.
Can you also ensure that
Pipelining is implemented with P2P ops between different ranks. What I think is happening in your case is that the process crashed on rank 0, and rank 1 is waiting to recv from it, which in turn causes the other ranks to wait on their recvs as well. You should see an error on rank 0, though, unless it is getting swallowed somehow. Could you paste your .toml file so I can try your exact configs as well? Thanks!
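For illustration, here is a minimal sketch of that P2P pattern, assuming two ranks and a NCCL backend (this is not torchtitan code; the tensor shapes and script name are made up). A recv only completes once the peer posts the matching send, so a crashed or stuck sender leaves downstream ranks blocked in wait():

```python
# Minimal sketch (not torchtitan code): batched P2P between two ranks, as used
# by pipeline schedules. Run with: torchrun --nproc_per_node=2 p2p_sketch.py
# If rank 0 crashes before posting its isend, rank 1 blocks forever in wait().
import torch
import torch.distributed as dist

def main() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    if rank == 0:
        activations = torch.ones(4, device="cuda")
        ops = [dist.P2POp(dist.isend, activations, 1)]   # send activations to rank 1
    else:
        buf = torch.empty(4, device="cuda")
        ops = [dist.P2POp(dist.irecv, buf, 0)]           # recv activations from rank 0

    for work in dist.batch_isend_irecv(ops):             # coalesced P2P, like _batch_p2p
        work.wait()                                      # hangs here if the peer never sends

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```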
Hi! Yes, here is the relevant bit of the config:
Ran on 8x H100. In particular, it seems to time out on the coalesced op (this run was on 256 H100s instead):
I debugged through this with the basic 1F1B schedule. The deadlock happens in the warmup phase of `_step_microbatches`:

```python
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
Will go through all the microbatches according to the 1F1B schedule.
Args:
microbatches: list of microbatch args.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
if not self._stage_initialized:
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
# Last stage has 1 warmup, second-to-last 2 warmups, ...
# first stage `num_stages` warmups
warmup_chunks = min(
self._n_microbatches,
self._num_stages - self._stage.stage_index,
)
# Chunk counters
fwd_mb_index = 0
bwd_mb_index = 0
# Warmup phase
send_work = None
fwd_sends = []
for _ in range(warmup_chunks):
# Receive activations
fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got fwd_recvs {fwd_recvs}")
if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got recv_work {recv_work}")
recv_work.wait()
# Compute
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got output shape {output.shape}")
# Clear previous chunk's forward sends (hopefully they have well
# finished, otherwise, we are heavily communication bound, in which
# case it doesn't create a lot of benefit to compute next chunk
# eagerly either)
if send_work:
send_work.wait()
# Send activations
fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got fwd_sends {fwd_sends}")
if fwd_mb_index != warmup_chunks - 1:
# Safe to fire
send_work = _batch_p2p(fwd_sends, desc="fwd_send")
# otherwise:
# The last forward send is left to fuse with the first 1B in 1B1F below
# Compute loss
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
fwd_mb_index += 1
        ...
```

This is what I got:

and then it hangs. It seems like rank 1 is not getting the output of rank 0.
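As a sanity check, here is a small sketch (illustrative numbers only, not from an actual run) of how that `warmup_chunks = min(n_microbatches, num_stages - stage_index)` expression comes out for different microbatch counts. With num_microbatches=1, every stage gets exactly one warmup chunk, so the `fwd_mb_index != warmup_chunks - 1` guard above skips the stage's only forward send during warmup and leaves it to be fused with the first backward recv:

```python
# Illustrative only: evaluate the warmup_chunks formula from the snippet above
# for a hypothetical 8-stage pipeline (values are not from an actual run).
def warmup_chunks(n_microbatches: int, num_stages: int, stage_index: int) -> int:
    # Same expression as in _step_microbatches above.
    return min(n_microbatches, num_stages - stage_index)

if __name__ == "__main__":
    num_stages = 8
    for n_mb in (1, 4, 8):
        per_stage = [warmup_chunks(n_mb, num_stages, s) for s in range(num_stages)]
        print(f"n_microbatches={n_mb}: warmup_chunks per stage = {per_stage}")
    # n_microbatches=1 -> [1, 1, 1, 1, 1, 1, 1, 1]: every stage has a single
    # warmup chunk, so `fwd_mb_index != warmup_chunks - 1` is False and the
    # stage's only forward send is deferred to the fused fwd_send + bwd_recv op.
```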
I may have spotted the bug. Due to
Yep. I manually patched the code and it now seems to be working correctly. I'll send a patched and cleaned-up version.
This is a demo schedule that can run with mb=1 and larger degrees:

```python
from typing import List, Optional

import torch.distributed
from torch.distributed.pipelining.schedules import PipelineScheduleSingle, _batch_p2p


class ScheduleMicroOne1F1B(PipelineScheduleSingle):
"""
The 1F1B schedule, modified to support running mb=1, with larger pipeline parallel degrees.
Will perform one forward and one backward on the microbatches in steady state.
"""
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
Will go through all the microbatches according to the 1F1B schedule.
Args:
microbatches: list of microbatch args.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
if not self._stage_initialized:
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
# Last stage has 1 warmup, second-to-last 2 warmups, ...
# first stage `num_stages` warmups
warmup_chunks = min(
self._n_microbatches,
self._num_stages - self._stage.stage_index,
)
print(f"rank={torch.distributed.get_rank()} warmup_chunks: {warmup_chunks} - args: {arg_mbs} - kwargs: {kwarg_mbs}")
single_warmup = warmup_chunks == 1
# Chunk counters
fwd_mb_index = 0
bwd_mb_index = 0
# Warmup phase
send_work = None
fwd_sends = []
for _ in range(warmup_chunks):
# Receive activations
fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
recv_work.wait()
# Compute
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
# Clear previous chunk's forward sends (hopefully they have well
# finished, otherwise, we are heavily communication bound, in which
# case it doesn't create a lot of benefit to compute next chunk
# eagerly either)
if send_work:
send_work.wait()
# Send activations
fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
if fwd_mb_index != warmup_chunks - 1 or single_warmup:
# Safe to fire
send_work = _batch_p2p(fwd_sends, desc="fwd_send")
# otherwise:
# The last forward send is left to fuse with the first 1B in 1B1F below
# Compute loss
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
fwd_mb_index += 1
# If only had a single warmup, we need to wait for the send to finish
if single_warmup and send_work:
send_work.wait()
# Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.
# 1B1F phase
while True: # Don't worry, we have a break inside
# We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
# Now, we need to fire the fwd_sends and bwd_recvs together
if single_warmup:
# We only send the bwd_recvs, as we already sent the fwd_sends in the warmup phase
if bwd_recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
bwd_recv_work.wait()
else:
if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
fuse_work.wait()
# Backward one chunk
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(
bwd_mb_index,
loss=loss,
last_backward=bwd_mb_index == self._n_microbatches - 1,
)
# Get the bwd send ops, but don't fire, to be fused with the 1F below
bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
bwd_mb_index += 1
if fwd_mb_index == self._n_microbatches:
# We are done with 1B1F, so break with some left-over bwd_sends
break
# We prepare 1F of the `1B1F`
fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
# Fuse it with bwd_sends above
if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
fuse_work.wait()
# Now do the fwd
output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
# Compute loss
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
# Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
fwd_mb_index += 1
# Remember we still have some bwd_sends left over after the break? Now it is time to fire it
send_work = _batch_p2p(bwd_sends, desc="bwd_send")
# Cooldown
while bwd_mb_index < self._n_microbatches:
# prepare bwd recv ops
bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
recv_work.wait()
# Backward one chunk
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(
bwd_mb_index,
loss=loss,
last_backward=bwd_mb_index == self._n_microbatches - 1,
)
# Clear previous chunk's backward sends (hopefully they have well finished)
if send_work:
send_work.wait()
# Get the bwd send ops, fire it
bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
send_work = _batch_p2p(bwd_sends, desc="bwd_send")
bwd_mb_index += 1
# Wait for the last backward send to finish
if send_work:
send_work.wait()
# Return losses if there is a container passed in
        self._update_losses(self._stage, losses)
```

I also experimented with GPipe and found that it doesn't suffer from this issue, so I will be using that for now.
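For reference, a minimal sketch of that GPipe fallback using `torch.distributed.pipelining` (illustrative only: the Linear shard, shapes, and loss function are placeholders, not the actual torchtitan setup):

```python
# Illustrative sketch of the GPipe fallback (module split, shapes, and loss
# are placeholders, not the poster's torchtitan setup).
# Run with: torchrun --nproc_per_node=<pp_degree> gpipe_sketch.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

# One tiny Linear per pipeline stage stands in for a real model shard.
# (Older PyTorch versions may also require input_args= with an example microbatch.)
block = nn.Linear(16, 16).to(device)
stage = PipelineStage(block, stage_index=rank, num_stages=world, device=device)

# Fewer microbatches than stages -- the configuration that hangs with 1F1B
# above but reportedly works with GPipe.
schedule = ScheduleGPipe(stage, n_microbatches=1, loss_fn=nn.MSELoss())

x = torch.randn(4, 16, device=device)
target = torch.randn(4, 16, device=device)
if rank == 0:
    schedule.step(x)                                   # first stage feeds inputs
elif rank == world - 1:
    losses: list[torch.Tensor] = []
    schedule.step(target=target, losses=losses)        # last stage computes loss
else:
    schedule.step()                                    # middle stages relay activations

dist.destroy_process_group()
```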
…chedules (#144702) There is an edge case where `Schedule1F1B` will hang when num_microbatches=1 (pytorch/torchtitan#775). For validation it makes sense to check that the number of microbatches is >= the number of stages, since otherwise there will be an even larger bubble. This check can be removed once the single-stage schedules have been updated to use an IR and run with the schedule runtime (issue tracker #144701). Pull Request resolved: #144702 Approved by: https://github.com/kwen2501
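The guard described there might look roughly like the following (hypothetical sketch, not the code that actually landed in pytorch#144702):

```python
# Hypothetical sketch of the validation described above; the real check in
# pytorch#144702 may differ in placement and wording.
def _validate_single_stage_schedule(n_microbatches: int, num_stages: int) -> None:
    if n_microbatches < num_stages:
        raise ValueError(
            f"Got n_microbatches={n_microbatches} < num_stages={num_stages}; "
            "single-stage schedules like 1F1B need at least as many microbatches "
            "as pipeline stages (fewer can deadlock, see pytorch/torchtitan#775, "
            "and only enlarges the pipeline bubble)."
        )
```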
Pipeline parallelism seems to hang when the number of microbatches is less than the pipeline parallel degree.
This issue occurs for both the standard and interleaved 1F1B schedules; I have not tested other schedules.
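For reference, a hypothetical torchtitan config fragment that produces this relationship (section and key placement may differ between torchtitan versions; these are not the reporter's actual values):

```toml
# Hypothetical repro fragment: microbatch count below the PP degree.
# The section name varies between torchtitan versions.
[experimental]
pipeline_parallel_degree = 8
pipeline_parallel_schedule = "1F1B"
pipeline_parallel_microbatches = 1
```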