One of the main goals for torchtitan was to provide a version of distributed LLM training that was not only high-performance but also used native PyTorch techniques and readable code. The challenge is how to compose so many individual library components (FSDP, TP, PP, Float8, Compile, DCP, to name just a few) without making too many changes to the model internals in the process. A lot of the work happens behind the scenes: designing individual components to make fewer assumptions, use common abstractions (e.g. DTensor), and generally "get along". But we also found a few tweaks to the model code invaluable, and wanted to share those changes and the rationale for them.
When applying Pipeline Parallelism, you will have to construct nn.Module objects representing the portion of the model that runs on a given pipeline stage. Whether you plan to edit your model code manually or use techniques like tracing to extract model chunks, a few changes to the original model code can go a long way toward making this process easier.
Most likely, you can write your model so that the top-level nn.Module owns a sequence of child modules that it calls during forward, delegating most of the complexity to the child modules' forwards. If you can reduce your top-level forward to mostly a for-loop over child module calls, you reduce pipeline partitioning to choosing the set of submodules to keep per stage. If the top-level forward contains non-trivial logic, you'll have to find a way to patch that logic back onto the resulting pipeline stage model, which can be annoying.
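To make the "mostly a for-loop" point concrete, here is a minimal sketch with a toy model. `Block`, `TinyModel`, and `split_into_stages` are hypothetical names, not torchtitan code, and the blocks are simple residual MLP stand-ins. Because the top-level forward does nothing but iterate over its children, each stage can be described as a contiguous slice of those children.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical stand-in for a transformer block."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + torch.relu(self.proj(x))


class TinyModel(nn.Module):
    def __init__(self, dim: int = 8, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList([Block(dim) for _ in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The top-level forward is only a loop over children, so no extra
        # logic has to be patched back onto a pipeline stage.
        for layer in self.layers:
            x = layer(x)
        return x


def split_into_stages(model: TinyModel, n_stages: int) -> list[nn.Sequential]:
    """Hypothetical helper: give each stage a contiguous slice of the layers."""
    n = len(model.layers)
    bounds = [round(i * n / n_stages) for i in range(n_stages + 1)]
    return [
        nn.Sequential(*model.layers[bounds[i] : bounds[i + 1]])
        for i in range(n_stages)
    ]
```

Running the stages back to back in one process reproduces the full forward, which is a quick sanity check that no logic was lost in the split. In a real PP job, each stage would live on a different rank.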
Example (PR #321): We used to slice the freqs_cis buffer by seq_len in the top-level forward, pass the result into child modules, and expect that inside the child modules the seq_len would match the size of other local tensors. But when we consider PP splitting, we don't know whether TP was applied, so this could create a mismatch. It's just as easy to perform the freqs_cis slicing inside the child submodule, using the runtime-accurate local seq_len, and this sidesteps the issue at PP slicing time.
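The following is a minimal sketch of the PR #321 pattern, not the PR's actual diff. It assumes a complex-valued freqs_cis precomputed for a maximum sequence length. `precompute_freqs_cis`, `apply_rotary_emb`, and `RotaryBlock` are illustrative names. The caller passes the full buffer, and each block slices it using the sequence length of the tensor it actually receives.

```python
import torch
import torch.nn as nn


def precompute_freqs_cis(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    # Complex rotation factors of shape (max_seq_len, dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, n_heads, head_dim); pair up the last dim as complex numbers.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x_c.shape[1], 1, x_c.shape[-1])
    return torch.view_as_real(x_c * freqs_cis).flatten(3).type_as(x)


class RotaryBlock(nn.Module):
    """Hypothetical child module that receives the *unsliced* freqs_cis."""

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Slice with the local, runtime-accurate sequence length, so the block
        # stays correct no matter how (or whether) the input was sharded.
        seq_len = x.shape[1]
        return apply_rotary_emb(x, freqs_cis[:seq_len])
```

The same block works unchanged for a full-length input and for a shorter local chunk, because the slicing decision is made where the local shape is known.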
Example (PR #322): We decided to reuse the top-level model object on every PP stage, deleting the layers we don't want and making sure the top-level forward still does the right thing. This means we don't need a separate runtime pp_forward that glues together child modules per stage. The first change was using a ModuleDict instead of a ModuleList to store layers. This preserves layer Fully Qualified Names (FQNs) even when some layers are deleted: layers.1 stays layers.1 even if you remove layers.0, which isn't true for a list. This matters for checkpoint save/load, because Distributed Checkpointing (DCP) uses FQNs as globally unique IDs for sharding metadata, so preserving them is a requirement for using it. The second change was making the input and output layers optional: if a layer exists, we run it; otherwise we pass the input through unchanged. With these two changes, we can simply (meta-)initialize the whole model, delete the unused parts per stage, and then materialize the remaining part on GPU before loading a checkpoint.
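Both PR #322 conventions can be sketched with a toy model. The assumptions are stated here: `Transformer` and `build_stage` are hypothetical, the layers are plain `nn.Linear` stand-ins, the stage assignment is passed in by hand, and CPU stands in for GPU. The whole model is built on the meta device, other stages' layers and the unneeded input/output layers are removed, and only what remains is materialized with `to_empty()`. That leaves the weights uninitialized, to be filled by a later checkpoint load.

```python
import torch
import torch.nn as nn


class Transformer(nn.Module):
    """Hypothetical toy model following the two conventions described above."""

    def __init__(self, vocab: int = 16, dim: int = 8, n_layers: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        # ModuleDict keyed by string index: deleting "0" leaves "layers.1" as is.
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})
        self.output = nn.Linear(dim, vocab)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Optional input layer: only the first stage keeps it.
        h = self.tok_embeddings(h) if self.tok_embeddings is not None else h
        for layer in self.layers.values():
            h = layer(h)
        # Optional output layer: only the last stage keeps it.
        return self.output(h) if self.output is not None else h


def build_stage(stage_idx: int, layer_ids: set[int], n_stages: int) -> Transformer:
    # Meta-init allocates no real memory for the full model.
    with torch.device("meta"):
        model = Transformer()
    for name in list(model.layers.keys()):
        if int(name) not in layer_ids:
            del model.layers[name]
    if stage_idx != 0:
        model.tok_embeddings = None
    if stage_idx != n_stages - 1:
        model.output = None
    # Allocate real (uninitialized) storage for the parts this stage kept.
    model.to_empty(device="cpu")
    return model
```

The last stage's state_dict keys stay layers.2.* and layers.3.*, matching the names in a checkpoint of the full model.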
Initializing the pipeline-parallel model is challenging because we assume the model could be too large to fit on a local GPU (or possibly even on CPU), and we also want to use the (bitwise) same initialization as we use for 1D or 2D parallel models, to ease debugging and comparisons between runs. It's not easy to rewrite the original model's init_weights function to tolerate initializing only some layers while also serializing initialization operations globally to keep a consistent RNG order.
For now, we sidestep all these problems with a simple but brutal solution: initialize the whole model on a CPU instance, save a checkpoint file, and then lean on Distributed Checkpointing's "load" functionality to initialize the FQNs present on a given PP stage after stage creation. For future work, we are considering adding a more elaborate initialization scheme to torch.pipelining.
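Here is a single-process sketch of the seed-checkpoint flow. It uses `torch.save`/`torch.load` into an in-memory buffer as a stand-in for both the seed checkpoint file and DCP's load, which in the real setup reads only the FQNs that each rank's state_dict requests. `make_model`, `save_seed_checkpoint`, `build_stage`, and `load_stage` are hypothetical helpers.

```python
import io

import torch
import torch.nn as nn


def make_model(n_layers: int = 4, dim: int = 8) -> nn.ModuleDict:
    # Hypothetical stand-in for the real model; FQNs look like "2.weight".
    return nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})


def save_seed_checkpoint(seed: int = 0) -> io.BytesIO:
    # Step 1 (one CPU process): initialize the whole model once with a fixed
    # seed, so the initial weights do not depend on the PP layout.
    torch.manual_seed(seed)
    buf = io.BytesIO()  # stands in for a checkpoint file on disk
    torch.save(make_model().state_dict(), buf)
    return buf


def build_stage(layer_ids: set[int]) -> nn.ModuleDict:
    # Step 2 (per PP stage): build on meta, drop other stages' layers, then
    # allocate uninitialized storage for the layers this stage owns.
    with torch.device("meta"):
        stage = make_model()
    for name in list(stage.keys()):
        if int(name) not in layer_ids:
            del stage[name]
    stage.to_empty(device="cpu")
    return stage


def load_stage(stage: nn.ModuleDict, buf: io.BytesIO) -> nn.ModuleDict:
    # Step 3: fill exactly the FQNs this stage owns from the seed checkpoint.
    buf.seek(0)
    full_sd = torch.load(buf, weights_only=True)
    local_sd = {fqn: full_sd[fqn] for fqn in stage.state_dict()}
    stage.load_state_dict(local_sd)  # strict by default: every local FQN is filled
    return stage
```

Because every stage reads from the same checkpoint of one unsplit model, any PP layout starts from the same weights that the unsplit model would have had with that seed.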
One issue with seed checkpoints is that we rely on initializing every piece of model state from the checkpoint, which means the model can't have any non-persistent buffers; otherwise we would have to initialize those specially in train.py after pipeline splitting. freqs_cis was originally a non-persistent buffer, and we changed it to persistent so that it could be loaded from the seed checkpoint.
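Persistence matters because `register_buffer(..., persistent=False)` keeps a buffer out of `state_dict()`, so no checkpoint load can ever restore it. The sketch below uses a hypothetical `RopeCache` module and a real-valued tensor instead of the complex freqs_cis, for brevity. It mirrors the per-stage flow: meta-init, `to_empty()`, then load.

```python
import torch
import torch.nn as nn


class RopeCache(nn.Module):
    """Hypothetical module holding a precomputed buffer such as freqs_cis."""

    def __init__(self, persistent: bool, max_seq_len: int = 4):
        super().__init__()
        self.register_buffer(
            "freqs_cis",
            torch.arange(max_seq_len, dtype=torch.float32),
            persistent=persistent,
        )


def restore_from(reference: nn.Module, persistent: bool) -> RopeCache:
    # Meta-init, materialize uninitialized storage, then load a checkpoint,
    # the same way each PP stage is initialized from the seed checkpoint.
    with torch.device("meta"):
        stage = RopeCache(persistent)
    stage.to_empty(device="cpu")
    stage.load_state_dict(reference.state_dict())
    return stage
```

With `persistent=True` the loaded buffer matches the reference. With `persistent=False` the load succeeds but leaves the buffer as uninitialized memory, which is exactly the case that would need special handling after pipeline splitting.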
We intentionally upcast the final output tensor to fp32 inside the loss function rather than in Transformer.forward(), so that the forward and backward casts can be fused with the loss forward and backward, respectively, when we torch.compile() the loss function. This can improve both throughput and memory usage.
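The sketch below shows a loss function of this shape. It is similar in spirit to a cross-entropy loss over flattened (batch, seq) logits, but `cross_entropy_loss` and its signature are assumptions for illustration. The model returns low-precision logits, and the `.float()` cast happens inside the function that gets compiled.

```python
import torch
import torch.nn.functional as F


def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # pred: (batch, seq_len, vocab) logits, e.g. bf16 straight from the model.
    # labels: (batch, seq_len) integer class ids.
    # The upcast lives here, not in the model's forward, so a compiled version
    # of this function can fuse the cast (and its backward) with the loss.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))


# Wrapping is lazy: compilation happens on the first call.
compiled_loss = torch.compile(cross_entropy_loss)
```

The loss itself is computed in fp32, while the gradient flowing back into the model is cast back to the logits' original dtype by the backward of `.float()`.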