Distributed Data Parallel (DDP)

Thunder has its own Distributed Data Parallel (DDP) transform, which we recommend using, although jitted modules also work with PyTorch’s DistributedDataParallel wrapper.

You can wrap a model in Thunder’s ddp like this:

import thunder
from thunder.distributed import ddp

model = MyModel()

# apply the ddp transform to the module, then jit the transformed module
ddp_model = ddp(model)
cmodel = thunder.jit(ddp_model)

Specifying which rank to broadcast parameters from is optional: if broadcast_from is not given, ddp() broadcasts from the lowest rank in the process group.
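
For example, to broadcast from rank 0 explicitly (a minimal sketch reusing the model from above and the broadcast_from keyword described here):

ddp_model = ddp(model, broadcast_from=0)  # broadcast parameters from rank 0 to all other ranks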

Thunder’s ddp is compatible with PyTorch distributed runners like torchrun (https://pytorch.org/docs/stable/elastic/run.html).
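
Below is a minimal sketch of a training script that can be launched with torchrun. The model class MyModel, the dataloader, and the optimizer choice are placeholders, and the script assumes the usual torch.distributed process-group setup:

import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import ddp

# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

model = MyModel().to(device)  # placeholder model class
cmodel = thunder.jit(ddp(model))

optimizer = torch.optim.SGD(cmodel.parameters(), lr=1e-3)

for batch, target in dataloader:  # placeholder dataloader
    batch, target = batch.to(device), target.to(device)
    loss = torch.nn.functional.mse_loss(cmodel(batch), target)
    loss.backward()  # gradients are averaged across ranks during the backward pass
    optimizer.step()
    optimizer.zero_grad()

If the script above is saved as train_ddp.py (a placeholder name), it can be launched with, for example:

torchrun --nproc_per_node=2 train_ddp.py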

When using PyTorch’s DDP, call DDP on the jitted module:

import thunder
from torch.nn.parallel import DistributedDataParallel as DDP

model = MyModel()

# jit the module first, then hand the jitted module to PyTorch's DDP wrapper
jitted_model = thunder.jit(model)
ddp_model = DDP(jitted_model)
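
As with any PyTorch module, the process group must be initialized and the module moved to its device before it is wrapped in DDP. A minimal sketch, in which the device handling and MyModel are placeholders:

import os

import torch.distributed as dist

import thunder
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])

model = MyModel().to(local_rank)  # placeholder model class, moved to this rank's GPU
jitted_model = thunder.jit(model)
ddp_model = DDP(jitted_model, device_ids=[local_rank])  # gradients sync in backward as usual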

Expressing distributed algorithms like DDP as a simple transform on the trace is one of Thunder’s strengths, and the same mechanism is being leveraged to quickly implement more elaborate distributed strategies, such as Fully Sharded Data Parallel (FSDP).