Extending Thunder

This section describes how to add an executor to Thunder for a PyTorch operation.

First, define a Python function with the same signature as the targeted operation and have it call your implementation. For example, the Apex executor for torch.nn.functional.cross_entropy might define its implementation like this:

import torch
import xentropy_cuda

def apex_xentropy(
    a: torch.Tensor,  # a is an actual PyTorch tensor
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    # Whether the kernel should return float32 losses for float16 inputs
    half_to_float = a.dtype == torch.half

    losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, half_to_float)

    # Apply the requested reduction (the checker below rejects any other reduction)
    if reduction == "mean":
        return losses.mean()
    if reduction == "sum":
        return losses.sum()
    return losses  # reduction == "none"

When this implementation is used, it will be called with actual PyTorch tensors, not with proxies.
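For instance, assuming Apex's xentropy CUDA extension is installed and a CUDA device is available, the implementation can be exercised directly on real tensors (the shapes below are just for illustration):

logits = torch.randn(8, 10, device="cuda")
target = torch.randint(0, 10, (8,), device="cuda")

# Runs the Apex CUDA kernel eagerly on actual tensors; no proxies are involved
loss = apex_xentropy(logits, target)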

Next, define a “checker” function with the same signature as the targeted operation that returns True if your executor can execute the targeted operation with the given inputs and False otherwise. Checkers, unlike implementations, are called with proxies rather than actual PyTorch tensors, because they run at optimization time. The purpose of a checker is to let an executor claim only the inputs it supports for an operation and defer to another executor for everything else.

A checker function for the Apex executor might look like:

from thunder.core.proxies import TensorProxy

def apex_xentropy_checker(
    a: TensorProxy,  # a is a proxy
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    # Apex's xentropy only supports the "sum", "mean", and "none" reductions
    if reduction not in ("sum", "mean", "none"):
        return False

    return True

Next, create a mapping from the name of the targeted PyTorch operation to a tuple of your replacement operation’s name, its checker, and its implementation:

_op_to_xentropy = {
    "torch.nn.functional.cross_entropy": ("apex_xentropy", apex_xentropy_checker, apex_xentropy),
}

Then define a registration function that practitioners can call to access your executor:

def register_apex_xentropyex(*, add_to_default_executors: bool = True) -> None:
    from thunder.executors import add_operator_executor

    return add_operator_executor("apex_xentropy", _op_to_xentropy, add_to_default_executors=add_to_default_executors)

You can test your executor by registering it, compiling a function that calls the targeted operator, and then verifying that your operation is called (by inspecting the execution trace) and that the compiled function produces the correct output. A good example of this is the tests for the Apex executor.
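A minimal test might look like the sketch below. It assumes a CUDA device and that thunder.compile and thunder.last_traces are available for compiling the function and retrieving its traces; the last trace returned is taken to be the execution trace, whose string form should mention the executor’s operation name.

import torch
import thunder

# Register the executor so it is added to the default executors
register_apex_xentropyex()

def loss_fn(logits, target):
    return torch.nn.functional.cross_entropy(logits, target, reduction="mean")

compiled_loss_fn = thunder.compile(loss_fn)

logits = torch.randn(8, 10, device="cuda")
target = torch.randint(0, 10, (8,), device="cuda")

actual = compiled_loss_fn(logits, target)
expected = loss_fn(logits, target)

# The execution trace should reference apex_xentropy if the executor claimed the operation
execution_trace = thunder.last_traces(compiled_loss_fn)[-1]
assert "apex_xentropy" in str(execution_trace)

# The compiled result should match eager PyTorch
torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)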