Train an MLP on MNIST

Here’s a complete program that trains a torchvision MLP on MNIST. First, install torchvision:

pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu121

Here’s the code:

import os
import torch
import torchvision
import torchvision.transforms as transforms
import thunder

device = 'cuda'

# Creates transforms that move each image tensor to the device and flatten it
device_transform = transforms.Lambda(lambda t: t.to(device))
flatten_transform = transforms.Lambda(lambda t: t.flatten())
my_transform = transforms.Compose((transforms.ToTensor(), device_transform, flatten_transform))

# Creates train and test datasets
train_dataset = torchvision.datasets.MNIST("/tmp/mnist/train", train=True, download=True, transform=my_transform)
test_dataset = torchvision.datasets.MNIST("/tmp/mnist/test", train=False, download=True, transform=my_transform)

# Creates Samplers
train_sampler = torch.utils.data.RandomSampler(train_dataset)
test_sampler = torch.utils.data.RandomSampler(test_dataset)

# Creates DataLoaders
batch_size = 8
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler)

# Evaluates the model
def eval_model(model, test_loader):
    num_correct = 0
    total_guesses = 0
    for data, targets in iter(test_loader):
        targets = targets.to(device)
        # Acquires the model's best guesses at each class
        results = model(data)
        best_guesses = torch.argmax(results, 1)
        # Updates number of correct and total guesses
        num_correct += torch.eq(targets, best_guesses).sum().item()
        total_guesses += batch_size
    # Prints output
    print("Correctly guessed ", (num_correct/total_guesses) * 100, "% of the dataset")

# Trains the model
def train_model(model, train_loader, *, num_epochs: int = 1):
    loss_fn = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters())
    for epoch in range(num_epochs):
        for data, targets in iter(train_loader):
            targets = targets.to(device)
            # Acquires the model's best guesses at each class
            results = model(data)
            # Computes loss
            loss = loss_fn(results, targets)
            # Updates model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# Constructs the model
model = torchvision.ops.MLP(in_channels=784, hidden_channels=[784, 784, 784, 28, 10], bias=True, dropout=.1).to(device)

# Performs an initial evaluation
model.eval().requires_grad_(False)
jitted_eval_model = thunder.jit(model)
eval_model(jitted_eval_model, test_loader)

# Trains the model
model.train().requires_grad_(True)
jitted_train_model = thunder.jit(model)
train_model(jitted_train_model, train_loader)

model.eval().requires_grad_(False)

# Performs a final evaluation
eval_model(jitted_eval_model, test_loader)

# Evaluates the original, unjitted model
# The unjitted and jitted models share parameters, so the original
# model was also updated by the training above
eval_model(model, test_loader)

# Acquires and prints thunder's "traces", which show what thunder executed
# The training model has both "forward" and "backward" traces, corresponding
# to its forward and backward computations.
# The evaluation model has only one set of traces.
fwd_traces = thunder.last_traces(jitted_train_model)
bwd_traces = thunder.last_backward_traces(jitted_train_model)
eval_traces = thunder.last_traces(jitted_eval_model)

print("This is the trace that thunder executed for training's forward computation:")
print(fwd_traces[-1])

print("This is the trace that thunder executed for training's backward computation:")
print(bwd_traces[-1])

print("This is the trace that thunder executed for eval's computation:")
print(eval_traces[-1])

Let’s look at a few parts of this program more closely.

First, up until the call to thunder.jit(), the program is just Python, PyTorch, and torchvision. thunder.jit() accepts a PyTorch module (or function) and returns a Thunder-optimized module with the same signature, parameters, and buffers.
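For example, here’s a minimal sketch (separate from the MNIST program above, using a small hypothetical torch.nn.Linear module and CPU tensors for simplicity) of jitting a module and checking that the jitted module behaves like the original and shares its parameters:

import torch
import thunder

m = torch.nn.Linear(4, 2)
jm = thunder.jit(m)

x = torch.randn(3, 4)

# Same signature and the same result as calling the original module
torch.testing.assert_close(jm(x), m(x))

# The parameters are shared, so an in-place update to the original module
# is visible through the jitted module as well
with torch.no_grad():
    m.bias.add_(1.0)
torch.testing.assert_close(jm(x), m(x))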

After compilation the program is, again, just Python and PyTorch until the very end. Behind the scenes, each time a Thunder module is called it produces a “trace” representing the sequence of tensor operations to perform. That trace is then transformed and optimized, and the sequence of traces for the most recent inputs can be acquired by calling thunder.last_traces() on the module (the traced program changes when inputs with different data types, devices, or other properties are used). When the module is used for training, thunder.last_traces() returns the sequence of “forward” traces and thunder.last_backward_traces() returns the sequence of “backward” traces; when the module is only used for evaluation, there is just one sequence of traces. Here we print the last trace in each sequence. Traces print as Python programs, and those Python programs are what Thunder actually executes.
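As a quick sketch (reusing jitted_train_model from the program above after it has been called at least once, and assuming the first entry in the sequence is the initial, untransformed trace), the sequence can be inspected like this:

fwd_traces = thunder.last_traces(jitted_train_model)

print(len(fwd_traces))   # the initial trace plus every transformed/optimized trace
print(fwd_traces[0])     # the initial trace, produced by interpreting the module
print(fwd_traces[-1])    # the execution trace -- the Python program thunder actually runs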

Let’s take a look at the execution trace for the training module’s forward:

@torch.no_grad()
@no_autocast
def augmented_forward_fn(t0, t4, t5, t21, t22, t38, t39, t55, t56, t72, t73):
  # t0
  # t4
  # t5
  # t21
  # t22
  # t38
  # t39
  # t55
  # t56
  # t72
  # t73
  t1 = torch.nn.functional.linear(t0, t4, t5)  # t1
    # t1 = ltorch.linear(t0, t4, t5)  # t1
      # t1 = prims.linear(t0, t4, t5)  # t1
  [t10, t2, t7] = nvFusion0(t1)
    # t2 = prims.gt(t1, 0.0)  # t2
    # t3 = prims.where(t2, t1, 0.0)  # t3
    # t6 = prims.uniform((8, 784), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6
    # t7 = prims.lt(t6, 0.9)  # t7
    # t8 = prims.convert_element_type(t7, dtypes.float32)  # t8
    # t9 = prims.mul(t3, t8)  # t9
    # t10 = prims.mul(t9, 1.1111111111111112)  # t10
  del t1
  t11 = torch.nn.functional.linear(t10, t21, t22)  # t11
    # t11 = ltorch.linear(t10, t21, t22)  # t11
      # t11 = prims.linear(t10, t21, t22)  # t11
  [t12, t15, t18] = nvFusion1(t11)
    # t12 = prims.gt(t11, 0.0)  # t12
    # t13 = prims.where(t12, t11, 0.0)  # t13
    # t14 = prims.uniform((8, 784), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t14
    # t15 = prims.lt(t14, 0.9)  # t15
    # t16 = prims.convert_element_type(t15, dtypes.float32)  # t16
    # t17 = prims.mul(t13, t16)  # t17
    # t18 = prims.mul(t17, 1.1111111111111112)  # t18
  del t11
  t19 = torch.nn.functional.linear(t18, t38, t39)  # t19
    # t19 = ltorch.linear(t18, t38, t39)  # t19
      # t19 = prims.linear(t18, t38, t39)  # t19
  [t20, t25, t28] = nvFusion2(t19)
    # t20 = prims.gt(t19, 0.0)  # t20
    # t23 = prims.where(t20, t19, 0.0)  # t23
    # t24 = prims.uniform((8, 784), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t24
    # t25 = prims.lt(t24, 0.9)  # t25
    # t26 = prims.convert_element_type(t25, dtypes.float32)  # t26
    # t27 = prims.mul(t23, t26)  # t27
    # t28 = prims.mul(t27, 1.1111111111111112)  # t28
  del t19
  t29 = torch.nn.functional.linear(t28, t55, t56)  # t29
    # t29 = ltorch.linear(t28, t55, t56)  # t29
      # t29 = prims.linear(t28, t55, t56)  # t29
  [t30, t33, t36] = nvFusion3(t29)
    # t30 = prims.gt(t29, 0.0)  # t30
    # t31 = prims.where(t30, t29, 0.0)  # t31
    # t32 = prims.uniform((8, 28), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t32
    # t33 = prims.lt(t32, 0.9)  # t33
    # t34 = prims.convert_element_type(t33, dtypes.float32)  # t34
    # t35 = prims.mul(t31, t34)  # t35
    # t36 = prims.mul(t35, 1.1111111111111112)  # t36
  del t29
  t37 = torch.nn.functional.linear(t36, t72, t73)  # t37
    # t37 = ltorch.linear(t36, t72, t73)  # t37
      # t37 = prims.linear(t36, t72, t73)  # t37
  [t41, t44] = nvFusion4(t37)
    # t40 = prims.uniform((8, 10), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t40
    # t41 = prims.lt(t40, 0.9)  # t41
    # t42 = prims.convert_element_type(t41, dtypes.float32)  # t42
    # t43 = prims.mul(t37, t42)  # t43
    # t44 = prims.mul(t43, 1.1111111111111112)  # t44
  del t37
  return {'output': (t44, ()), 'flat_args': [t0, t4, t5, t21, t22, t38, t39, t55, t56, t72, t73], 'flat_output': (t44,)}, ((t0, t10, t12, t15, t18, t2, t20, t21, t25, t28, t30, t33, t36, t38, t41, t55, t7, t72), (1.1111111111111112, 1.1111111111111112, 1.1111111111111112, 1.1111111111111112, 1.1111111111111112))

There’s a lot going on here, and if you’d like to get into the details then keep reading! For now, notice that the trace is a functional Python program. Thunder has grouped the model’s operations into regions of primitives and handed each region to nvFuser, which compiled it into an optimized kernel (a “fusion”) and inserted it into the program (nvFusion0, nvFusion1, …). The comments under each fusion list the primitive operations that describe precisely what that fusion computes, although how each fusion is executed is entirely up to nvFuser.
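If you’d rather inspect the fusions programmatically than read the printed program, here’s a hedged sketch; it assumes thunder’s trace objects expose their operations as bound_symbols, each with a sym.name and nested subsymbols:

exec_trace = thunder.last_traces(jitted_train_model)[-1]

for bsym in exec_trace.bound_symbols:
    # Fusion regions created by the nvFuser executor are named nvFusion0, nvFusion1, ...
    if bsym.sym.name.startswith("nvFusion"):
        print(f"{bsym.sym.name} fuses {len(bsym.subsymbols)} primitive operations")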