FSDP Tutorial

In this tutorial, we will walk through the implementation of Fully Sharded Data Parallel (FSDP) with the Zero2 sharding strategy in thunder.


In recent times, LLMs have grown so large that their parameters no longer fit on a single GPU. To circumvent this problem, there are various strategies for training these large models, such as Tensor Parallel, Pipeline Parallel and Fully Sharded Data Parallel. In this tutorial, we discuss and implement the Zero2 strategy for Fully Sharded Data Parallel (FSDP).

What is Zero2 strategy for FSDP?

In this strategy, we shard the model parameters across all the available GPUs; that is, each GPU holds only a chunk of each parameter. During the forward pass, all GPUs call the all_gather communication primitive to gather the parameters from the other GPUs. Unlike the Zero3 strategy, which frees the unsharded parameters after the forward pass and gathers them again in the backward pass, we save these unsharded parameters for the backward pass. This avoids the overhead of extra communication, at the cost of higher memory usage. In the backward pass, we use the saved parameters to compute the gradients. Once the gradients are computed, we use the reduce_scatter communication primitive to reduce (average) the gradients across all GPUs and scatter them so that each GPU holds only a chunk of the gradient.
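To make the data flow concrete, here is a single-process simulation of one Zero2 step on a toy 4 x 2 weight and 2 simulated ranks. It is only a sketch: the helpers shard_rows, all_gather and reduce_scatter_mean are hypothetical stand-ins written with plain Python lists, not thunder or torch.distributed APIs, and the per-rank gradients are made up (rank r pretends its gradient is (r + 1) times the weight).

```python
from typing import List

Matrix = List[List[float]]  # a 2-D "tensor" stored as a list of rows


def shard_rows(param: Matrix, world_size: int) -> List[Matrix]:
    """Split a parameter along dim 0 into one equal chunk per rank."""
    assert len(param) % world_size == 0, "dim 0 must be divisible by world_size"
    chunk = len(param) // world_size
    return [param[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards: List[Matrix]) -> List[Matrix]:
    """Every rank receives the concatenation of all shards (the full parameter)."""
    full = [row[:] for shard in shards for row in shard]
    return [[row[:] for row in full] for _ in shards]


def reduce_scatter_mean(grads: List[Matrix]) -> List[Matrix]:
    """Average the full-size gradients across ranks, then give rank r only chunk r."""
    world_size = len(grads)
    rows, cols = len(grads[0]), len(grads[0][0])
    mean = [[sum(g[i][j] for g in grads) / world_size for j in range(cols)]
            for i in range(rows)]
    return shard_rows(mean, world_size)


world_size = 2
param = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # 4 x 2 weight

# Each rank stores only its shard of the parameter.
shards = shard_rows(param, world_size)

# Forward: all_gather the full parameter and keep it around for backward (Zero2).
gathered = all_gather(shards)

# Backward: each rank computes a full-size gradient from its own data.
# Here we pretend rank r's gradient is simply (r + 1) * parameter.
local_grads = [[[(r + 1) * v for v in row] for row in gathered[r]]
               for r in range(world_size)]

# reduce_scatter: average across ranks; each rank keeps one chunk of the gradient.
grad_shards = reduce_scatter_mean(local_grads)
print(grad_shards[0])  # [[1.5, 3.0], [4.5, 6.0]]
print(grad_shards[1])  # [[7.5, 9.0], [10.5, 12.0]]
```

The gathered parameter is reused in the backward step instead of being gathered a second time; under Zero3 it would be freed after the forward pass and gathered again.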

For more information on FSDP, we recommend reading:

  1. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel

  2. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

Example Model

For this example, we use a simple model, Linear(Tanh(Linear(x))), which will be sharded over 2 GPUs.

NOTE: We are only generating the abstract trace, so we don't actually need a system with 2 GPUs for this. They are only required when we execute this trace.

[ ]:
import torch
import torch.distributed
import thunder
import thunder.distributed
from IPython.display import Code
[ ]:
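dim = 64
# Tracing alone does not need GPUs; fall back to the CPU if none is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

def create_model():
    # Linear -> Tanh -> Linear, matching the description above.
    layers = [torch.nn.Linear(dim, dim, bias=False),
              torch.nn.Tanh(),
              torch.nn.Linear(dim, dim, bias=False)]
    return torch.nn.Sequential(*layers).to(device)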

# Model
model = create_model()
# Input
x = torch.randn(dim, dim, device=device)

# We want to obtain a functional version of our model. thunder.jit does that internally,
# and we reach into those internals here.
thunder_model = thunder.jit(model)
cache_rec, i_, _ = thunder.compile_data(thunder_model).get_computation_and_inputs(x)
computation_trace = cache_rec.computation_traces[0]

[ ]:
def wrap_as_highlighted_code(trace):
    return Code(str(trace), language="python")

We can show the functional version:

[ ]:
wrap_as_highlighted_code(computation_trace)

Step 1: Configuration

For our implementation of FSDP, we will generate the trace for sharding our model over 2 GPUs.

[ ]:
# FSDP Config
# Usually these values are set in the environment by `torchrun` but for this example
# we will set them ourselves
world_size = 2  # We have two processes.
global_rank = 0  # Current process is the very first process.
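When a script is launched with torchrun, each worker process reads these values from environment variables instead (torchrun exports WORLD_SIZE, RANK and LOCAL_RANK, among others). A minimal sketch of that lookup, where read_torchrun_env is a hypothetical helper and the single-process defaults are our own choice:

```python
import os
from typing import Mapping


def read_torchrun_env(env: Mapping[str, str] = os.environ) -> dict:
    """Read the rank/world-size variables that torchrun exports for each worker.

    Falls back to a single-process setup when they are absent (e.g. in a notebook).
    """
    return {
        "world_size": int(env.get("WORLD_SIZE", "1")),
        "global_rank": int(env.get("RANK", "0")),       # rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", "0")),  # rank on this node, used to pick the GPU
    }


# Simulate what the second worker of `torchrun --nproc_per_node=2` would see.
cfg = read_torchrun_env({"WORLD_SIZE": "2", "RANK": "1", "LOCAL_RANK": "1"})
print(cfg)  # {'world_size': 2, 'global_rank': 1, 'local_rank': 1}
```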

Step 2: Function to shard parameters

The next step is to write a function that actually shards each parameter along dimension 0.

[ ]:
# NOTE: We shard over 0th dimension of the param.
def shard_param(param: torch.Tensor, rank: int, world_size: int, name: str) -> None:
    # We will keep it simple and error if param's 0th dim is not divisible by ``world_size``.
    # Alternatively, we could pad our parameters so that they are divisible by `world_size`.
    assert param.shape[0] % world_size == 0, (
        f"Current sharding requires the first dimension of the parameter {name!r} ({param.shape[0]})"
        f" to be divisible by the world size ({world_size})"
    )
    chunk_size = param.shape[0] // world_size

    # rank helps us determine which chunk of the parameter we will hold.
    shard = param.data.narrow(0, chunk_size * rank, chunk_size).clone()
    param.data = shard

# Shard each parameter of the model
for param_name, param in model.named_parameters():
    shard_param(param, global_rank, world_size, param_name)
    # Mark the param to denote that it is sharded.
    # This is required by the synchronization primitive we will use below.
    param.distparallel_type = thunder.core.proxies.DistParallelType.FULLY_SHARDED
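The comment in shard_param mentions padding as an alternative to the divisibility assertion. Here is a sketch of that idea using plain Python lists of rows instead of tensors: dim 0 is zero-padded up to a multiple of world_size, and the padding rows are stripped again after gathering. The helpers shard_with_padding and unshard are hypothetical and are not used by the rest of this tutorial.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # a 2-D "tensor" stored as a list of rows


def shard_with_padding(param: Matrix, rank: int, world_size: int) -> Tuple[Matrix, int]:
    """Return this rank's dim-0 shard, zero-padding dim 0 up to a multiple of world_size.

    Also returns how many padding rows were added so they can be dropped after all_gather.
    """
    rows, cols = len(param), len(param[0])
    chunk = -(-rows // world_size)  # ceil(rows / world_size)
    pad = chunk * world_size - rows
    padded = param + [[0.0] * cols for _ in range(pad)]
    return padded[rank * chunk:(rank + 1) * chunk], pad


def unshard(shards: List[Matrix], pad: int) -> Matrix:
    """Concatenate the gathered shards and strip the padding rows."""
    full = [row for shard in shards for row in shard]
    return full[:len(full) - pad]


param = [[1.0], [2.0], [3.0], [4.0], [5.0]]  # 5 rows: not divisible by 2
results = [shard_with_padding(param, r, world_size=2) for r in range(2)]
shards = [s for s, _ in results]
pad = results[0][1]
print(shards)  # [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [0.0]]]
print(unshard(shards, pad) == param)  # True
```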
[ ]:
# Verify our model looks as expected
model
[ ]:
# Let us verify that we have actually sharded the parameters.
# Checking if the weight of 1st Linear layer is sharded over 0th dim.
assert model[0].weight.shape == (dim // world_size, dim)

Step 3: Add an operation to synchronize the parameters before calling model.forward

We have to create a process group, because the synchronization primitive synchronize, which we will use to gather the weights in the forward pass and to reduce-scatter the gradients in the backward pass, requires one.

[ ]:
# Create a process group
options = torch.distributed.distributed_c10d.ProcessGroup.Options(backend="nccl")
process_group = torch.distributed.distributed_c10d.ProcessGroup(torch.distributed.distributed_c10d.Store(),
                                                     global_rank, world_size, options)
torch.distributed.distributed_c10d.GroupMember.WORLD = process_group
[ ]:
# Now we have a functional version of the model, which
# takes as inputs the expected arguments and all the parameters.
functional_forward = computation_trace.python_callable()

# This function creates a model with synchronization
# before calling the forward pass.
def model_with_syncs(x, *params):
    # We call `prims.synchronize` on all the parameters.
    # This is essentially calling `all_gather` so that we have the complete
    # parameter before we actually do the forward computation.
    unsharded_params = []
    for param in params:
        unsharded_params.append(thunder.distributed.prims.synchronize(param, process_group))

    return functional_forward(x, *unsharded_params)

Let us now see what the trace of our model looks like with all the synchronization.

Two main observations regarding the trace below:

  1. We can observe the prims.synchronize calls that we inserted using model_with_syncs.

  2. The outputs of prims.synchronize have the shape of the unsharded (original) parameters.

With this, we have implemented FSDP for the forward pass of our model.

[ ]:
trace = thunder.trace()(model_with_syncs, x, *model.parameters())
wrap_as_highlighted_code(trace)


For backward, we don’t have to do anything because thunder already knows how to compute the backward of prims.synchronize. We can verify that by using the value_and_grad transform to generate the complete forward and backward trace together.
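Why is the backward of an all_gather a reduce_scatter? all_gather is a linear map from the per-rank shards to the full parameter, so its backward is its adjoint: sum the incoming full-size gradients over ranks and hand rank r only chunk r, which is exactly a reduce_scatter with a sum reduction. The sketch below checks this adjoint identity numerically on 1-D toy data; the helpers are hypothetical plain-Python stand-ins for the real collectives.

```python
import random
from typing import List

Vec = List[float]


def all_gather(shards: List[Vec]) -> Vec:
    """Forward op: concatenate the per-rank shards into the full (1-D) parameter."""
    return [v for shard in shards for v in shard]


def reduce_scatter_sum(grads: List[Vec], world_size: int) -> List[Vec]:
    """Backward op: sum full-size gradients over ranks; rank r keeps chunk r."""
    n = len(grads[0])
    chunk = n // world_size
    summed = [sum(g[i] for g in grads) for i in range(n)]
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


random.seed(0)
world_size, chunk = 2, 3
shards = [[random.uniform(-1, 1) for _ in range(chunk)] for _ in range(world_size)]
# Every rank k receives its own upstream gradient for the gathered parameter.
upstream = [[random.uniform(-1, 1) for _ in range(world_size * chunk)]
            for _ in range(world_size)]

# Adjoint identity: sum_k <all_gather(s), g_k> == sum_r <s_r, reduce_scatter(g)_r>.
lhs = sum(dot(all_gather(shards), g) for g in upstream)
rhs = sum(dot(s, gs) for s, gs in zip(shards, reduce_scatter_sum(upstream, world_size)))
print(abs(lhs - rhs) < 1e-12)  # True
```

thunder additionally divides by world_size so that the result is an average rather than a sum, as explained next.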

Observations for the trace below:

  1. prims.synchronize from the previous trace is now decomposed into prims.all_gather and prims.wait. So we can clearly see that we make a communication call to gather the parameter (which is asynchronous) and then wait until we have the complete parameter.

  2. At the end of the trace (after the forward and the backward computation), we see calls to prims.reduce_scatter and prims.wait. These take care of reducing the gradients across all the GPUs and sharding them. One thing to note: when averaging gradients in a dtype with low dynamic range such as float16, naively summing the gradients across GPUs before dividing by world_size can overflow. So we first divide the gradient tensor by world_size and then call reduce_scatter with a sum reduction, which effectively averages the gradients without overflow.
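The overflow argument can be reproduced without GPUs. The sketch below models float16 by rounding every value and every partial sum through Python's struct "e" (IEEE half precision) format; to_fp16 and fp16_sum are hypothetical helpers, and the sequential summation is a simplification of what a real NCCL reduction does. With 2 ranks that both hold a gradient element of 40000 (below the float16 maximum of 65504), summing first overflows, while dividing by world_size first gives the exact average.

```python
import math
import struct
from typing import Iterable


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE float16 value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return math.copysign(math.inf, value)


def fp16_sum(values: Iterable[float]) -> float:
    """Sum sequentially, rounding every partial sum to float16."""
    total = to_fp16(0.0)
    for v in values:
        total = to_fp16(total + v)
    return total


world_size = 2
grads = [to_fp16(40000.0)] * world_size  # same local gradient element on each rank

naive = to_fp16(fp16_sum(grads) / world_size)                     # sum first, then divide
pre_divided = fp16_sum([to_fp16(g / world_size) for g in grads])  # divide first, then sum

print(naive)        # inf
print(pre_divided)  # 40000.0
```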

[ ]:
from thunder.core.transforms import value_and_grad

forward_and_backward_model = value_and_grad(model_with_syncs)

forward_backward_trace = thunder.trace()(forward_and_backward_model, x, *model.parameters())
wrap_as_highlighted_code(forward_backward_trace)


The trace above contains only primitives, which specify the semantics of each operation abstractly but don't perform the actual computation.

Now we will generate the execution trace which can actually perform the compute.

In the execution trace generated below, we can see that all the primitives have been replaced with actual PyTorch operations. Also, our synchronization primitives have been replaced with the PyTorch implementations provided by thunder, i.e. torch_all_gather_prim_impl, torch_reduce_scatter_prim_impl and torch_wait_prim_impl.
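Conceptually, this step looks up a concrete implementation for every abstract primitive. The toy below illustrates only that idea: the trace format, the executor dictionary and run are hypothetical and far simpler than thunder's actual trace and executor machinery.

```python
from typing import Callable, Dict, List, Tuple

# A toy "trace": each entry is (output_name, primitive_name, input_names).
Trace = List[Tuple[str, str, Tuple[str, ...]]]

abstract_trace: Trace = [
    ("t0", "mul", ("x", "w")),
    ("t1", "add", ("t0", "b")),
]

# A toy "executor": maps each abstract primitive to a concrete implementation.
python_executor: Dict[str, Callable[..., float]] = {
    "mul": lambda a, b: a * b,
    "add": lambda a, b: a + b,
}


def run(trace: Trace, executor: Dict[str, Callable[..., float]],
        env: Dict[str, float]) -> Dict[str, float]:
    """Execute a trace by looking up every primitive in the executor table."""
    env = dict(env)  # do not mutate the caller's inputs
    for out, prim, inputs in trace:
        if prim not in executor:
            raise NotImplementedError(f"no executor implements {prim!r}")
        env[out] = executor[prim](*(env[name] for name in inputs))
    return env


result = run(abstract_trace, python_executor, {"x": 3.0, "w": 2.0, "b": 1.0})
print(result["t1"])  # 7.0
```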

[ ]:
optimized_trace = thunder.transform_for_execution(forward_backward_trace, executors_list=thunder.get_always_executors())

# Grab the final trace
exec_trace = optimized_trace[-1]
wrap_as_highlighted_code(exec_trace)

Step 4: Running the actual computation

Running the actual computation requires setting up 2 processes and running the code above in both of them (which can be tricky in a Jupyter notebook). Instead, we will write a small script and run it with torchrun, which takes care of setting up the processes and the relevant state.

NOTE: This requires the machine running this notebook to have at least 2 GPUs.

In the example below, we will use thunder.distributed.fsdp which does the same as what we did above (with some extra checks). The code below should look familiar as it is roughly all the above pieces in a single script.

[ ]:
%%writefile thunder_fsdp_simple_example.py

# imports
import os

import torch
import torch.distributed
import thunder
import thunder.distributed

# # # # # # # #
# Create Model
# # # # # # # #

# NOTE: We create the model on CPU.
dim = 64
def create_model():
    layers = [torch.nn.Linear(dim, dim, bias=False),
              torch.nn.Tanh(),
              torch.nn.Linear(dim, dim, bias=False)]
    return torch.nn.Sequential(*layers)

# Model
model = create_model()
# Input (also created on CPU; moved to the GPU below)
x = torch.randn(dim, dim)

# # # # # # # #
# Setup for distributed
# # # # # # # #

# torchrun sets LOCAL_RANK (and RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) for each process.
rank = int(os.environ["LOCAL_RANK"])

device = f"cuda:{rank}"
torch.cuda.set_device(device)

# Create the default process group used by thunder.distributed.fsdp.
torch.distributed.init_process_group(backend="nccl")

# # # # # # # #
# Move inputs to correct device
# # # # # # # #
x = x.to(device)

# # # # # # # #
# Wrap the model in thunder.distributed.fsdp
# # # # # # # #

# thunder.distributed.fsdp takes care of moving the parameter
# shard to the correct GPU for the current process.
cmodel = thunder.jit(thunder.distributed.fsdp(model))

# Run the forward and backward passes.
y = cmodel(x)
y.sum().backward()

# # # # # # # #
# Check the traces
# # # # # # # #
fwd_traces = thunder.last_traces(cmodel)
bwd_traces = thunder.last_backward_traces(cmodel)

# # # # # # # #
# Print and check to see if they match ours
# # # # # # # #
if rank == 0:
    print(fwd_traces[-1])
    print("*******"* 8)
    print(bwd_traces[-1])

torch.distributed.destroy_process_group()

Let us run the above script and check what the trace looks like.

We can observe that the forward trace has torch_all_gather_prim_impl to gather the parameters before the forward pass, and the backward trace has torch_reduce_scatter_prim_impl to reduce and scatter the gradients back to the different GPUs. This is similar to our implementation above.

[ ]:
!torchrun --nproc_per_node=2 thunder_fsdp_simple_example.py


We have created our implementation of FSDP to shard our model across multiple GPUs. In the process, we also learned that:

  1. thunder provides us with primitives for synchronization across multiple GPUs.

  2. thunder also takes care of implementing the backward support for the synchronization primitives, so we don’t have to explicitly do anything to get the backward working.

  3. We can simply apply thunder.distributed.fsdp to our model, and it will take care of sharding the parameters and adding synchronizations to our model. We can also easily check these modifications by inspecting the traces.