PP hangs when pipeline_parallel_microbatches < pipeline_parallel_degree #775

Open
cassanof opened this issue Jan 6, 2025 · 10 comments
Labels
bug Something isn't working

Comments

@cassanof

cassanof commented Jan 6, 2025

Pipeline parallelism seems to hang when the number of microbatches is less than the PP degree.
This issue occurs for both the standard and interleaved 1F1B schedules. I have not tested other schedules.

@tianyu-l tianyu-l added the bug Something isn't working label Jan 6, 2025
@H-Huang
Member

H-Huang commented Jan 7, 2025

Thanks for flagging @cassanof. There is a bug in the microbatch-count calculation for the multi-stage schedules. This PR fixes that: #781

But weirdly, when I try to reproduce your issue with num_microbatches < pipeline_parallel_degree, I don't get a hang; instead I get an error like:

      File "/home/howardhuang/local/pytorch/torch/distributed/pipelining/schedules.py", line 345, in check_type_and_len
        raise ValueError(
    ValueError: Expecting 4 arg_mbs but got 2

Perhaps you are using an older PyTorch version, and a fix has since been added to the PyTorch repo? Either way, once #781 lands, could you try again and see if that fixes your issue? Thanks!
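One way a count mismatch like `Expecting 4 arg_mbs but got 2` can arise: if the batch is split along its first dimension with `torch.chunk`-style semantics (every chunk has ceil(batch / n) items except possibly the last), then a batch of 2 asked for 4 chunks yields only 2. The plain-Python sketch below reproduces that arithmetic. `chunk_like_torch` and `check_microbatch_count` are hypothetical helpers, and whether the pipelining splitter behaves exactly like this is an assumption.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def chunk_like_torch(batch: Sequence[T], n_chunks: int) -> List[List[T]]:
    """Split `batch` into at most `n_chunks` pieces, mimicking torch.chunk:
    each piece has ceil(len(batch) / n_chunks) items except possibly the last,
    so fewer than `n_chunks` pieces come back when the batch is small."""
    if n_chunks <= 0:
        raise ValueError("n_chunks must be positive")
    size = max(1, math.ceil(len(batch) / n_chunks))
    return [list(batch[i:i + size]) for i in range(0, len(batch), size)]


def check_microbatch_count(mbs: List[List[T]], n_microbatches: int) -> None:
    """Hypothetical stand-in for the schedule's length check."""
    if len(mbs) != n_microbatches:
        raise ValueError(f"Expecting {n_microbatches} arg_mbs but got {len(mbs)}")


batch = ["sample0", "sample1"]    # batch_size = 2
mbs = chunk_like_torch(batch, 4)  # ask for 4 microbatches
print(len(mbs))                   # 2
try:
    check_microbatch_count(mbs, 4)
except ValueError as e:
    print(e)                      # Expecting 4 arg_mbs but got 2
```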

@cassanof
Author

cassanof commented Jan 7, 2025

Thanks a lot for picking this up!

I'm using the following version of pytorch:

2.6.0.dev20241228+cu126

Happy to report if the issue is resolved after your PR!

@H-Huang
Member

H-Huang commented Jan 7, 2025

#781 has landed!

@cassanof
Author

cassanof commented Jan 8, 2025

It still seems to hang at the first microbatch. By the way, I should have mentioned this earlier: I am using PP=8 with 8 GPUs, so it's just 1D PP.

@H-Huang
Member

H-Huang commented Jan 9, 2025

Can you also ensure that batch_size >= num_microbatches? I added a validation for this in #784.
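A minimal sketch of the kind of up-front guard described above. `validate_microbatching` is a hypothetical name and this is not the code from #784; it only illustrates the `batch_size >= num_microbatches` check and the resulting (largest) microbatch size, without assuming the batch divides evenly.

```python
import math


def validate_microbatching(batch_size: int, n_microbatches: int) -> int:
    """Reject configs that cannot produce `n_microbatches` non-empty
    microbatches; otherwise return the largest microbatch size."""
    if n_microbatches < 1:
        raise ValueError(f"n_microbatches must be >= 1, got {n_microbatches}")
    if batch_size < n_microbatches:
        raise ValueError(
            f"batch_size ({batch_size}) must be >= number of microbatches "
            f"({n_microbatches})"
        )
    return math.ceil(batch_size / n_microbatches)


print(validate_microbatching(4, 4))  # 1: one sample per microbatch
try:
    validate_microbatching(2, 4)
except ValueError as e:
    print(e)
```

Note that batch_size=4 with 4 microbatches (the config posted later in this thread) passes this check, so the guard alone does not account for the hang in that setup.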

Pipelining is implemented with P2P ops between ranks. My guess is that the process for rank 0 crashed and rank 1 is waiting to recv from it, which in turn leaves the other ranks waiting on their own recvs. You should still see an error from rank 0, though, unless it is being swallowed somehow.
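The cascade described above can be illustrated with a toy model: each "rank" is a thread that blocks on a receive from its predecessor before forwarding to its successor, standing in for blocking P2P recvs. If rank 0 dies before sending, every downstream rank waits until its own timeout, much like the NCCL watchdog timeout later in this thread. This uses Python threads and queues, not `torch.distributed`, and `run_chain` is a hypothetical helper.

```python
import queue
import threading
from typing import Dict


def run_chain(world_size: int, rank0_crashes: bool, timeout_s: float = 1.0) -> Dict[int, str]:
    """Toy 1D pipeline: each rank blocks on a receive from its predecessor,
    "computes", then sends to its successor. Returns a status per rank."""
    # links[r] carries the activation from rank r to rank r + 1
    links = [queue.Queue() for _ in range(world_size - 1)]
    status: Dict[int, str] = {}

    def rank_main(rank: int) -> None:
        if rank == 0:
            if rank0_crashes:
                status[rank] = "crashed"  # dies before sending anything
                return
            activation = 0
        else:
            try:
                # Blocking recv with a timeout, standing in for the NCCL watchdog
                activation = links[rank - 1].get(timeout=timeout_s)
            except queue.Empty:
                status[rank] = "timed out"
                return
        activation += 1  # "compute" this stage
        if rank < world_size - 1:
            links[rank].put(activation)  # send to the next stage
        status[rank] = f"ok ({activation})"

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dict(sorted(status.items()))


print(run_chain(4, rank0_crashes=True, timeout_s=0.1))
# {0: 'crashed', 1: 'timed out', 2: 'timed out', 3: 'timed out'}
```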

Could you paste your .toml file so I can try your exact configs as well? Thanks!

@cassanof
Author

cassanof commented Jan 11, 2025

Hi!

Yes, batch_size == num_microbatches in my setting. I'm logging every rank; it hangs, and eventually NCCL times out.

Here is the relevant bit of the config:

[training]
batch_size = 4
seq_len = 1024
...

[experimental]
pipeline_parallel_degree = 8
pipeline_parallel_microbatches = 4
pipeline_parallel_schedule = "1F1B"
...

I ran this on 8x H100. In particular, it seems to time out on the coalesced op (this log is from a run on 256 H100s instead):

[rank12]:[E111 06:46:36.515025591 ProcessGroupNCCL.cpp:628] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=15, OpType=COALESCED, NumelIn=18446744073709551615, NumelOut=18446744073709551615, Timeout(ms)=600000) ran for 600053 milliseconds before timing out.
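Two fields in this watchdog line are easier to read once decoded. `Timeout(ms)=600000` is 10 minutes. `NumelIn`/`NumelOut` = 18446744073709551615 is 2^64 - 1, i.e. -1 stored in an unsigned 64-bit field; it is most likely a placeholder (a coalesced P2P op has no single input/output size) rather than a real tensor length, though that interpretation is not confirmed in this thread. The arithmetic can be checked with:

```python
NUMEL_IN = 18446744073709551615
TIMEOUT_MS = 600000

# 2**64 - 1 is the value -1 takes when reinterpreted as an unsigned 64-bit integer
assert NUMEL_IN == 2**64 - 1
assert (-1) % 2**64 == NUMEL_IN

print(TIMEOUT_MS / 1000 / 60, "minutes")  # 10.0 minutes
```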

@cassanof
Author

I debugged through this with the basic 1F1B schedule. The deadlock happens in the warmup phase of _step_microbatches.
I added these prints and used batch_size=1, PP=8, PP_mb=1:

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got fwd_recvs {fwd_recvs}")
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got recv_work {recv_work}")
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
            print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got output shape {output.shape}")

            # Clear previous chunk's forward sends (hopefully they have
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            print(f"rank={torch.distributed.get_rank()} forward_one_chunk {fwd_mb_index=} got fwd_sends {fwd_sends}")
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left to be fused with the first 1B of 1B1F below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1
...

This is what I got:

rank=1 forward_one_chunk fwd_mb_index=0 got fwd_recvs [P2POp(irecv pg=0, group_src=0, group_dst=1,  torch.Size([1, 1024, 2048]), torch.bfloat16)]
rank=7 forward_one_chunk fwd_mb_index=0 got fwd_recvs [P2POp(irecv pg=0, group_src=6, group_dst=7,  torch.Size([1, 1024, 2048]), torch.bfloat16)]
rank=3 forward_one_chunk fwd_mb_index=0 got fwd_recvs [P2POp(irecv pg=0, group_src=2, group_dst=3,  torch.Size([1, 1024, 2048]), torch.bfloat16)]
rank=2 forward_one_chunk fwd_mb_index=0 got fwd_recvs [P2POp(irecv pg=0, group_src=1, group_dst=2,  torch.Size([1, 1024, 2048]), torch.bfloat16)]
rank=0 forward_one_chunk fwd_mb_index=0 got fwd_recvs []
rank=6 forward_one_chunk fwd_mb_index=0 got fwd_recvs [P2POp(irecv pg=0, group_src=5, group_dst=6,  torch.Size([1, 1024, 2048]), torch.bfloat16)]
rank=5 forward_one_chunk fwd_mb_index=0 got fwd_recvs [P2POp(irecv pg=0, group_src=4, group_dst=5,  torch.Size([1, 1024, 2048]), torch.bfloat16)]
rank=4 forward_one_chunk fwd_mb_index=0 got fwd_recvs [P2POp(irecv pg=0, group_src=3, group_dst=4,  torch.Size([1, 1024, 2048]), torch.bfloat16)]
rank=0 forward_one_chunk fwd_mb_index=0 got output shape torch.Size([1, 1024, 2048])
rank=0 forward_one_chunk fwd_mb_index=0 got fwd_sends [P2POp(isend pg=0, group_src=0, group_dst=1,  torch.Size([1, 1024, 2048]), torch.bfloat16)]

...and then it hangs. It seems like rank 1 never receives the output of rank 0.

@cassanof
Author

I may have spotted the bug. Because self._n_microbatches=1, warmup_chunks=1, so the warmup loop skips firing the forward send and leaves it for the fused fwd_send + bwd_recv op in the 1B1F phase. But rank 1 is waiting on the forward output from rank 0, which never arrives, because rank 0 only sends it as part of the fused op.
I'm not entirely sure how the PyTorch internals work, so I might be completely wrong!
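The reasoning above can be checked without GPUs by replaying only the warmup bookkeeping of the schedule: `warmup_chunks = min(n_microbatches, num_stages - stage_index)`, and a forward send is fired inside the warmup loop only when `fwd_mb_index != warmup_chunks - 1` (the patched schedule later in this thread adds `or single_warmup`). The sketch lists, per stage, which microbatches' forward sends are fired during warmup versus deferred to the fused `fwd_send_bwd_recv` op. `warmup_send_plan` is a hypothetical helper that models only this bookkeeping, not the NCCL behavior that actually blocks.

```python
from typing import Dict, List, Tuple


def warmup_send_plan(
    n_microbatches: int, num_stages: int, patched: bool
) -> Dict[int, Tuple[List[int], List[int]]]:
    """For each non-last stage, return (fired_in_warmup, deferred) microbatch
    indices for forward sends, following the 1F1B warmup loop."""
    plan = {}
    for stage in range(num_stages - 1):  # the last stage sends no activations
        warmup_chunks = min(n_microbatches, num_stages - stage)
        single_warmup = warmup_chunks == 1
        fired, deferred = [], []
        for fwd_mb_index in range(warmup_chunks):
            fire = fwd_mb_index != warmup_chunks - 1
            if patched:
                fire = fire or single_warmup
            (fired if fire else deferred).append(fwd_mb_index)
        plan[stage] = (fired, deferred)
    return plan


orig = warmup_send_plan(n_microbatches=1, num_stages=8, patched=False)
print(orig[0])   # ([], [0]): stage 0's only send waits for the fused op
fixed = warmup_send_plan(n_microbatches=1, num_stages=8, patched=True)
print(fixed[0])  # ([0], []): stage 0 fires its send during warmup
```

With n_microbatches=4 the last warmup send is also deferred on every stage, so the sketch shows which ops get fused, not by itself why only some configurations hang.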

@cassanof
Author

cassanof commented Jan 11, 2025

Yep. I manually patched the code and it now seems to be working correctly. I'll post a patched and cleaned-up Schedule1F1B class here tomorrow.

@cassanof
Author

This is a demo schedule that can run with mb=1 and larger PP degrees:

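from typing import List, Optional

import torch
from torch.distributed.pipelining.schedules import PipelineScheduleSingle, _batch_p2p

class ScheduleMicroOne1F1B(PipelineScheduleSingle):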
    """
    The 1F1B schedule, modified to support mb=1 with larger pipeline parallel degrees.
    Will perform one forward and one backward on the microbatches in steady state.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[List] = None,
        kwarg_mbs: Optional[List] = None,
        target_mbs: Optional[List] = None,
        losses: Optional[List] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )
        print(f"rank={torch.distributed.get_rank()} warmup_chunks: {warmup_chunks} - args: {arg_mbs} - kwargs: {kwarg_mbs}")
        single_warmup = warmup_chunks == 1

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work = None
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
                recv_work.wait()

            # Compute
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Clear previous chunk's forward sends (hopefully they have
            # finished, otherwise, we are heavily communication bound, in which
            # case it doesn't create a lot of benefit to compute next chunk
            # eagerly either)
            if send_work:
                send_work.wait()

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1 or single_warmup:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   The last forward send is left to be fused with the first 1B of 1B1F below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # If only had a single warmup, we need to wait for the send to finish
        if single_warmup and send_work:
            send_work.wait()
        
        # Unless single_warmup, we should have send ops left over, to be fused with the first 1B of the 1B1F phase below.

        # 1B1F phase
        while True:  # Don't worry, we have a break inside
            # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Now, we need to fire the fwd_sends and bwd_recvs together
            if single_warmup:
                # We only send the bwd_recvs, as we already sent the fwd_sends in the warmup phase
                if bwd_recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                    bwd_recv_work.wait()
            else:
                if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
                    fuse_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Get the bwd send ops, but don't fire, to be fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # We are done with 1B1F, so break with some left-over bwd_sends
                break

            # We prepare 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with bwd_sends above
            if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
                fuse_work.wait()

            # Now do the fwd
            output = self._stage.forward_one_chunk(fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
                recv_work.wait()

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Clear previous chunk's backward sends (hopefully they have finished)
            if send_work:
                send_work.wait()

            # Get the bwd send ops, fire it
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        # Wait for the last backward send to finish
        if send_work:
            send_work.wait()

        # Return losses if there is a container passed in
        self._update_losses(self._stage, losses)

I also experimented with GPipe, and found that it doesn't suffer from this issue, so I will be using that for now.
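For comparison, the per-stage order of forward (F) and backward (B) steps in the two schedules can be sketched as below. GPipe runs all forwards and then all backwards, so in this sketch there is no forward send to defer, while 1F1B interleaves them after a warmup of `min(n_microbatches, num_stages - stage_index)` forwards, which is where the deferred, fused sends come from. The helpers are hypothetical, follow the loop structure of `_step_microbatches` above, and do not model communication.

```python
from typing import List


def gpipe_order(n_microbatches: int) -> List[str]:
    """All forwards, then all backwards (same on every stage)."""
    return [f"F{i}" for i in range(n_microbatches)] + [
        f"B{i}" for i in range(n_microbatches)
    ]


def one_f_one_b_order(n_microbatches: int, num_stages: int, stage_index: int) -> List[str]:
    """Mirror the warmup / 1B1F / cooldown loops of the 1F1B schedule."""
    order: List[str] = []
    warmup_chunks = min(n_microbatches, num_stages - stage_index)
    fwd = bwd = 0
    for _ in range(warmup_chunks):  # warmup: forwards only
        order.append(f"F{fwd}")
        fwd += 1
    while True:                     # steady state: 1B then 1F
        order.append(f"B{bwd}")
        bwd += 1
        if fwd == n_microbatches:
            break
        order.append(f"F{fwd}")
        fwd += 1
    while bwd < n_microbatches:     # cooldown: remaining backwards
        order.append(f"B{bwd}")
        bwd += 1
    return order


print(gpipe_order(3))              # ['F0', 'F1', 'F2', 'B0', 'B1', 'B2']
print(one_f_one_b_order(3, 2, 0))  # ['F0', 'F1', 'B0', 'F2', 'B1', 'B2']
print(one_f_one_b_order(3, 2, 1))  # ['F0', 'B0', 'F1', 'B1', 'F2', 'B2']
```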

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Jan 15, 2025
…chedules (#144702)

There is an edge case where `Schedule1F1B` will hang when num_microbatches=1 (pytorch/torchtitan#775). For validation, it makes sense to check that the number of microbatches is >= the number of stages; otherwise there will be an even larger bubble.

This check can be removed once the single-stage schedules use an IR and are updated to run with the schedule runtime (issue tracker #144701).

Pull Request resolved: #144702
Approved by: https://github.com/kwen2501
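The "larger bubble" point can be quantified with the commonly used estimate for synchronous pipeline schedules such as GPipe and 1F1B: with p stages and m microbatches, assuming every stage takes the same time per microbatch and communication is free, the idle fraction of each device is (p - 1) / (m + p - 1). The sketch evaluates it for p = 8; `bubble_fraction` is a hypothetical name.

```python
from fractions import Fraction


def bubble_fraction(num_stages: int, n_microbatches: int) -> Fraction:
    """Idle fraction (p - 1) / (m + p - 1) of an idealized synchronous pipeline."""
    p, m = num_stages, n_microbatches
    return Fraction(p - 1, m + p - 1)


for m in (1, 4, 8, 32):
    f = bubble_fraction(8, m)
    print(f"p=8, m={m:>2}: idle fraction {f} (about {float(f):.2f})")
```

At m = 4 and p = 8 (the config reported earlier in this thread) the estimate is 7/11, so under this model more than half of each device's time is idle, which is why requiring m >= p is a reasonable guard even apart from the hang.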