Defining custom forward and backward for existing operators

We are going to add a custom executor for the forward and backward of the torch.nn.functional.cross_entropy operator.

Here is the SoftmaxCrossEntropyLoss definition from https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py:

import torch

import xentropy_cuda


class SoftmaxCrossEntropyLoss(torch.autograd.Function):
    """Softmax cross-entropy autograd.Function backed by apex's ``xentropy_cuda`` kernels.

    Quoted from NVIDIA apex as the reference for the Thunder operators defined
    below: the forward kernel also returns ``max_log_sum_exp``, which is saved and
    handed to the backward kernel.
    """

    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
        """Compute per-row losses; rows whose label equals ``padding_idx`` get loss 0."""
        losses, max_log_sum_exp = xentropy_cuda.forward(
            logits, labels, smoothing, half_to_float)
        # Zero the loss of padded rows after the kernel has run.
        losses.masked_fill_(labels==padding_idx, 0)

        # The scalar hyperparameters are wrapped in 1-element tensors so they can
        # go through save_for_backward; backward unwraps them with .item().
        ctx.save_for_backward(logits, max_log_sum_exp, labels,
            torch.FloatTensor([smoothing]),
            torch.LongTensor([padding_idx]))

        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        """Return the gradient w.r.t. ``logits``; every other forward input gets ``None``."""
        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors

        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        # NOTE: masked_fill_ is in-place, so when grad_loss was already contiguous
        # this zeroes entries of the incoming gradient tensor itself.
        grad_loss.masked_fill_(labels==padding_idx.item(), 0)
        grad_logits = xentropy_cuda.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp,
            labels, smoothing.item())

        return grad_logits, None, None, None, None
[1]:
import sys
# Put the parent directory first on sys.path (presumably a source checkout of
# thunder, so it wins over an installed copy — adjust for your layout).
sys.path.insert(0, '..')
import thunder
import torch
# Seed the RNG so the random logits/labels in the cells below are reproducible.
torch.manual_seed(42)

from thunder.core.proxies import TensorProxy

In Thunder, we define Executors to run given ops. Our executor will handle specific ops (rather than fusion regions), so our first step is to create our own OperatorExecutor and register it with Thunder.

[2]:
from thunder.extend import OperatorExecutor, register_executor
# Executor that will own the apex cross-entropy operators defined below.
apex_xentropy_ex = OperatorExecutor("apex_xentropy_ex", version="0.1")
# Make the executor known to Thunder; later cells also pass it explicitly via
# thunder.jit(..., executors=[apex_xentropy_ex]).
register_executor(apex_xentropy_ex)
[2]:
apex_xentropy_ex

To get a feel for what’s going on, let’s write a decorator that prints function calls, their arguments and their results.

[3]:
import functools

_indentation = 0


def _log(msg=None):
    """Print ``msg`` prefixed by the current indentation; do nothing if ``msg`` is None."""
    if msg is not None:
        print("  " * _indentation + msg)


def _log_indent(msg=None):
    """Print ``msg``, then increase the indentation used by later messages."""
    global _indentation
    _log(msg)
    _indentation += 2


def _log_unindent(msg=None):
    """Decrease the indentation, then print ``msg``."""
    global _indentation
    _indentation -= 2
    _log(msg)


def log(func):
    """Decorate ``func`` so every call logs its arguments and its result.

    Nested logged calls are printed indented under their caller. If ``func``
    raises, the exception is logged, the indentation is restored, and the
    exception is re-raised unchanged, so later log output stays aligned.
    Keyword arguments are accepted and shown as ``name=value``.
    """
    name = func.__name__

    def pp(v):
        """Format ``v`` succinctly: tuples element-wise, proxies/tensors by metadata."""
        if isinstance(v, tuple):
            return "({})".format(pp_values(v))
        elif isinstance(v, thunder.core.proxies.TensorProxy):
            return f"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})"
        elif isinstance(v, torch.Tensor):
            return f"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}"
        else:
            return str(v)

    def pp_values(args):
        """Format a sequence of values as a comma-separated list."""
        return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args, **kwargs):
        shown = [pp(arg) for arg in args] + [f"{key}={pp(val)}" for key, val in kwargs.items()]
        _log_indent("call {}({})".format(name, ", ".join(shown)))
        try:
            res = func(*args, **kwargs)
        except BaseException as exc:
            # Undo the indentation before propagating; otherwise every later
            # log line would be shifted right by one level.
            _log_unindent(f"|<- {name} raised {type(exc).__name__}: {exc}\n")
            raise
        _log_unindent("|<- {} = {}\n".format(name, pp(res)))
        return res

    return func_wrapper

We want to define the operators apex_xentropy_forward and apex_xentropy_backward. In Thunder, each operator consists of a meta function, which only computes the metadata of the outputs (such as their shapes), and the actual implementation; we then register this pair with our executor. So we do this for the forward…

[4]:
@log
def apex_xentropy_forward_meta(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    """Describe the outputs of ``apex_xentropy_forward`` without computing them.

    Returns a pair of proxies:
        - the per-row losses, shape ``(a.shape[0],)``, with ``a``'s dtype, device
          and requires_grad;
        - ``max_log_sum_exp``, laid out like ``target`` but with ``a``'s dtype.

    Raises:
        ValueError: for any ``reduction`` other than ``"none"``, the only mode
            this operator implements.
    """
    # The kernel returns max_log_sum_exp as a floating-point tensor (the recorded
    # impl output shows float32 for float32 logits), not with target's integer
    # dtype, so only borrow target's layout and take the dtype from the logits.
    max_log_sum_exp = TensorProxy(like=target, dtype=a.dtype)
    if reduction == "none":
        return TensorProxy(shape=(a.shape[0],), dtype=a.dtype, device=a.device,
                           requires_grad=a.requires_grad), max_log_sum_exp
    raise ValueError(f"Invalid reduction: {reduction}")

import xentropy_cuda

@log
def apex_xentropy_forward_impl(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    """Run apex's fused forward kernel; return ``(losses, max_log_sum_exp)``.

    Only ``reduction="none"`` is supported; any other value raises ValueError
    (after the kernel has run). ``weight``, ``size_average``, ``ignore_index`` and
    ``reduce`` are accepted to mirror ``cross_entropy``'s signature but unused.
    """
    # The trailing False is apex's half_to_float flag (see the quoted
    # SoftmaxCrossEntropyLoss.forward), left disabled.
    per_row, lse = xentropy_cuda.forward(a, target, label_smoothing, False)

    if reduction != "none":
        raise ValueError(f"Invalid reduction: {reduction}")

    # Hand back the losses in the logits' dtype, as the meta function declares.
    return per_row.to(a.dtype), lse


# Pair the metadata-only meta function with the real implementation; the
# returned symbol is what shows up in traces as apex_xentropy_forward.
apex_xentropy_forward = apex_xentropy_ex.register_operator(
    "apex_xentropy_forward", meta=apex_xentropy_forward_meta, fn=apex_xentropy_forward_impl
)


…and the backward…

[5]:
@log
def apex_xentropy_backward_meta(
    grad,
    logits,
    labels,
    max_log_sum_exp,
    smoothing,
):
    """Describe the gradient w.r.t. ``logits``: a proxy modeled on ``logits`` itself."""
    grad_logits_proxy = TensorProxy(like=logits)
    return grad_logits_proxy


@log
def apex_xentropy_backward_impl(
    grad,
    logits,
    labels,
    max_log_sum_exp,
    smoothing,
):
    """Run apex's fused backward kernel and return the gradient w.r.t. ``logits``."""
    # The kernel receives a contiguous version of the incoming gradient, as in
    # the quoted SoftmaxCrossEntropyLoss.backward.
    contiguous_grad = grad.contiguous()
    # Note the kernel's argument order: max_log_sum_exp comes before labels.
    return xentropy_cuda.backward(contiguous_grad, logits, max_log_sum_exp, labels, smoothing)

# Register the backward operator the same way: meta for tracing, fn for execution.
apex_xentropy_backward = apex_xentropy_ex.register_operator(
    "apex_xentropy_backward", meta=apex_xentropy_backward_meta, fn=apex_xentropy_backward_impl
)

Because Thunder currently does not allow keyword arguments to be passed to operators, we define a convenience wrapper:

[6]:
def apex_xentropy(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    """Keyword-friendly front end for the ``apex_xentropy_forward`` operator.

    Forwards every argument positionally and returns only the losses,
    discarding the auxiliary ``max_log_sum_exp`` output.
    """
    losses, _max_log_sum_exp = apex_xentropy_forward(
        a, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing
    )
    return losses

We can now thunder.jit functions using our operator:

[7]:
def loss_fn(logits, labels):
    # Call the wrapper directly; it passes everything to the operator positionally.
    return apex_xentropy(logits, labels, reduction="none")

jfn = thunder.jit(loss_fn)

# 2048 rows of 50257-class logits and matching integer labels on the GPU.
logits = torch.randn([2048, 50257], device="cuda")
labels = torch.randint(0, 50257, [2048], device="cuda")

actual_result = jfn(logits, labels)
# Reference: PyTorch's own per-row cross entropy on the same inputs.
expected_result = torch.nn.functional.cross_entropy(logits, labels, reduction="none")

print("deviation from pytorch implementation:", (actual_result - expected_result).abs().max().item())
call apex_xentropy_forward_meta(TensorProxy(name=t_0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t_1, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))

call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.1940,  2.1614, -0.1721,  ..., -0.4797,  1.4608, -0.5221],
        [ 1.8288,  0.2116,  0.1760,  ..., -0.1599,  0.1195,  0.0073],
        [-2.1704,  1.0396,  2.2924,  ...,  0.6021,  0.6498, -0.6316],
        ...,
        [ 0.4908, -0.3445,  2.6618,  ..., -2.0946, -0.2890,  0.1500],
        [-1.0561, -1.3547, -1.0354,  ...,  0.4304, -0.7882, -0.5496],
        [-0.6883, -1.3283,  0.3513,  ..., -0.6951,  0.2013, -1.0238]],
       device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 9132, 12067,  5347,  ...,  9268, 12534, 33582], device='cuda:0'), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.7236, 11.9374, 11.0063,  ..., 11.7434,  9.5018, 10.8008],
       device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3132, 11.3291, 11.3287,  ..., 11.3279, 11.3251, 11.3301],
       device='cuda:0'))

deviation from pytorch implementation: 9.5367431640625e-07

We can also inspect what program thunder recorded to admire the beauty of our operator being called:

[8]:
# Display the last of the traces Thunder recorded for jfn's most recent call.
thunder.last_traces(jfn)[-1]
[8]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(logits, labels):
  # logits: "cuda:0 f32[2048, 50257]"
  # labels: "cuda:0 i64[2048]"
  (res, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)
  del logits, labels
  return res

But it would be even better to have Thunder use our new operators automatically whenever they apply. We can define a transformation to do this for us. It consists of two parts:
- a `checker` function that takes the arguments of the function we want to replace (with Tensor arguments replaced by TensorProxy ones) and returns True if we handle this case and False if not;
- an `execution_transform`, a function with the same parameters and the same return value as the function we want to replace, which does the computation (as you would expect, by calling our operator).

Note that we attach this implementation to the thunder.torch.cross_entropy Symbol (an operator as appearing in Thunder traces, just like our apex_xentropy_forward is a Symbol).

[9]:
def apex_xentropy_checker(
    a: TensorProxy,
    /,
    target: TensorProxy,
    weight: None | TensorProxy = None,
    size_average = None,
    ignore_index: int = -100,
    reduce = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
) -> bool:
    """Return True if apex_xentropy_ex can execute this ``cross_entropy`` call.

    Receives the arguments of ``thunder.torch.cross_entropy`` with tensors
    replaced by TensorProxy objects, so only metadata (device, dtype, shape)
    and the non-tensor arguments can be inspected. Returning False lets Thunder
    fall back to another executor.
    """
    DeviceType = thunder.devices.DeviceType
    if a.device.devicetype != DeviceType.CUDA or target.device.devicetype != DeviceType.CUDA:
        return False

    # Class-probability targets (same shape as the logits) and label smoothing
    # are not handled by this executor.
    probability_target: bool = thunder.core.utils.same_shape(a.shape, target.shape)
    if probability_target or label_smoothing > 0.0:
        return False

    torch_dtype: torch.dtype = thunder.torch.to_torch_dtype(a.dtype)
    if torch_dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False

    # Neither apex_xentropy_forward_impl nor apex_xentropy_backward_impl masks rows
    # whose target equals ignore_index; presumably a negative ignore_index is
    # accepted because valid class indices never equal it.
    # NOTE(review): targets that actually contain a negative ignore_index (e.g. -100
    # padding) are not zeroed — confirm callers never pass them.
    if ignore_index >= 0:
        return False

    if weight is not None:
        return False

    # NOTE These parameters are deprecated and not supported
    if size_average is not None or reduce is not None:
        return False

    # apex_xentropy_forward (both meta and impl) only implements reduction="none"
    # and raises ValueError otherwise. Accepting "mean"/"sum" here would turn a
    # fallback to another executor into a hard failure for the default "mean".
    if reduction != "none":
        return False

    # Checks from
    # https://github.com/NVIDIA/apex/blob/7b2e71b0d4013f8e2f9f1c8dd21980ff1d76f1b6/apex/contrib/csrc/xentropy/xentropy_kernel.cu#L587-L590
    if a.ndim != 2:
        return False

    if target.ndim != 1:
        return False

    if a.shape[0] != target.shape[0]:
        return False

    if a.numel == 0:
        return False

    # Xentropy kernel produces incorrect results if a.shape[1] is less
    # than 30 and not a multiple of 4
    if a.shape[1] < 30 and a.shape[1] % 4 != 0:
        return False

    return True

from thunder.core.transforms import get_grad, put_grads


def cross_entropy_to_apex(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    """Execution transform: compute ``cross_entropy`` with ``apex_xentropy_forward``.

    Takes exactly the parameters of ``thunder.torch.cross_entropy`` and returns
    only the losses; the auxiliary ``max_log_sum_exp`` output is dropped.
    """
    forward_args = (a, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
    losses, _max_log_sum_exp = apex_xentropy_forward(*forward_args)
    return losses

# Let Thunder replace thunder.torch.cross_entropy with our execution transform
# whenever apex_xentropy_checker accepts the call's arguments.
apex_xentropy_ex.register_implementation(thunder.torch.cross_entropy, checker=apex_xentropy_checker,
                                execution_transform=cross_entropy_to_apex)

We can now run the “unmodified” PyTorch function with F.cross_entropy and still get our implementation (but don’t forget to pass the executor to the jit call!):

[10]:
def loss_fn(logits, labels):
    # Plain PyTorch; nothing here refers to the apex operators.
    return torch.nn.functional.cross_entropy(logits, labels, reduction="none")

# Listing apex_xentropy_ex lets its registered implementation replace
# cross_entropy when the checker accepts the call.
jfn = thunder.jit(loss_fn, executors=[apex_xentropy_ex])

logits = torch.randn([2048, 50257], device="cuda")
labels = torch.randint(0, 50257, [2048], device="cuda")

actual_result = jfn(logits, labels)
# Reference: eager PyTorch on the same inputs.
expected_result = torch.nn.functional.cross_entropy(logits, labels, reduction="none")

print("deviation from pytorch implementation:", (actual_result - expected_result).abs().max().item())

# The printed trace should call apex_xentropy_forward instead of cross_entropy.
print(thunder.last_traces(jfn)[-1])
call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t19, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t18, shape=(2048,), dtype=int64, device=cuda:0))

call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 1.2891, -0.2912,  0.6866,  ..., -1.5067,  1.3132, -0.7352],
        [-1.9077, -0.8366, -0.0747,  ...,  1.6109, -0.7460,  0.7346],
        [-1.0830, -0.2586,  0.0402,  ..., -0.2030, -1.0907, -1.7308],
        ...,
        [ 0.5805, -0.0830, -0.4658,  ..., -0.1023, -1.3720,  0.1850],
        [-0.8181,  1.3273,  0.8034,  ...,  1.2658, -1.4824,  0.0482],
        [ 0.9964, -1.8733,  0.3547,  ...,  0.0190, -0.3228,  0.4827]],
       device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 8137, 23633, 42622,  ..., 39128, 39817, 18664], device='cuda:0'), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([12.9479, 11.7810,  9.1981,  ..., 10.1080, 10.4095, 10.5884],
       device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3268, 11.3203, 11.3294,  ..., 11.3251, 11.3333, 11.3225],
       device='cuda:0'))

deviation from pytorch implementation: 9.5367431640625e-07
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(logits, labels):
  # logits: "cuda:0 f32[2048, 50257]"
  # labels: "cuda:0 i64[2048]"
  (t17, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)
  del logits, labels
  return t17

So what about the backward?

Well, we can define a gradient function and register it along with our implementation.

We thought a lot about what our extension point for gradients should look like. PyTorch’s autograd.Function is probably the most well-known approach, but we felt that it would be nice to make the connection between tensors in the computation and their gradients explicit.

So the grad transform we implement below is a function that does the following:

  • it takes the same arguments as the forward,

  • it computes the forward from its arguments,

  • it then uses get_grad to obtain the required gradients for the forward outputs,

  • computes the gradients for the inputs (this is the backward),

  • finally attaches the computed gradients to the respective tensors with put_grads.

We supply the grad function as an additional argument of register_implementation.

[11]:
@log
def apex_cross_entropy_grad(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    """Grad transform for ``cross_entropy`` built from the apex operators.

    Runs the forward with ``apex_xentropy_forward``, fetches the loss cotangent
    with ``get_grad``, computes the logits gradient with ``apex_xentropy_backward``
    and attaches it to ``a`` via ``put_grads`` (``target`` gets no gradient).
    Returns the forward loss.
    """
    loss, lse = apex_xentropy_forward(
        a, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing
    )
    loss_cotangent = get_grad(loss)
    logits_grad = apex_xentropy_backward(loss_cotangent, a, target, lse, label_smoothing)
    put_grads((a,), (logits_grad,))
    return loss

# Register cross_entropy again, now also supplying the grad transform so the
# apex operators are used for both the forward and the backward.
apex_xentropy_ex.register_implementation(thunder.torch.cross_entropy, checker=apex_xentropy_checker,
                                execution_transform=cross_entropy_to_apex, grad_transform=apex_cross_entropy_grad)

With these registrations, we can compile a function and it will be automatically transformed into forward and backward and wrapped in a PyTorch autograd.Function calling the backward trace computed by Thunder.

[12]:
from thunder import torch as ltorch  # NOTE(review): ltorch is not used in this cell

torch.manual_seed(0)

logits = torch.randn([2048, 50257], device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, [2048], device="cuda")

def loss_fn(logits, labels):
    # ignore_index=-1 is negative, so apex_xentropy_checker still accepts the call.
    return torch.nn.functional.cross_entropy(logits, labels, reduction="none", ignore_index=-1)

cfn = thunder.jit(loss_fn, executors=[apex_xentropy_ex])

actual_loss = cfn(logits, labels)
# A random upstream gradient, so the check is not limited to an all-ones cotangent.
go = torch.randn_like(actual_loss)

# Backward through Thunder's PyTorch autograd integration.
actual_grad, = torch.autograd.grad(actual_loss, logits, go)

# Reference: eager PyTorch forward and backward with the same cotangent.
expected_loss = loss_fn(logits, labels)
expected_grad, = torch.autograd.grad(expected_loss, logits, go)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in logits grad:", (actual_grad - expected_grad).abs().max().item())

# Display the last traces recorded for cfn (shown below: the backward-pass traces).
thunder.last_traces(cfn)[-1]
call apex_cross_entropy_grad(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])
    call apex_xentropy_forward_meta(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])
    |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))

    call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=label_smoothing, value=0.0])
    |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)

|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)

call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))

call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=f0, value=0.0])
|<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)

call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-9.2466e-01, -4.2534e-01, -2.6438e+00,  ...,  4.5115e-01,
          2.4087e-01,  1.9543e+00],
        [ 7.5610e-03, -4.9079e-01,  3.6572e-01,  ...,  2.5072e+00,
          9.0470e-01, -1.4305e+00],
        [-4.4104e-01, -7.6137e-01, -1.1172e+00,  ...,  5.9006e-02,
         -1.0212e+00,  3.0210e-02],
        ...,
        [-4.2869e+00,  1.4900e+00, -9.1910e-01,  ...,  3.6535e-03,
         -6.8372e-01,  7.1824e-01],
        [-4.2704e-02,  1.3505e+00,  2.1361e+00,  ..., -1.1139e+00,
          6.1626e-01,  4.8158e-01],
        [-7.3334e-01,  2.0820e+00,  3.7722e-02,  ..., -7.2141e-01,
          4.6871e-01,  7.0758e-01]], device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 3957, 45831, 13902,  ..., 45225, 32145, 12167], device='cuda:0'), None, None, -1, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([12.4000, 10.9672, 12.6648,  ..., 11.7144, 11.8293, 10.9396],
       device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3186, 11.3176, 11.3300,  ..., 11.3257, 11.3189, 11.3202],
       device='cuda:0'))

call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([ 0.8882, -0.0650, -1.2035,  ..., -0.4344, -0.0588, -2.5740],
       device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-9.2466e-01, -4.2534e-01, -2.6438e+00,  ...,  4.5115e-01,
          2.4087e-01,  1.9543e+00],
        [ 7.5610e-03, -4.9079e-01,  3.6572e-01,  ...,  2.5072e+00,
          9.0470e-01, -1.4305e+00],
        [-4.4104e-01, -7.6137e-01, -1.1172e+00,  ...,  5.9006e-02,
         -1.0212e+00,  3.0210e-02],
        ...,
        [-4.2869e+00,  1.4900e+00, -9.1910e-01,  ...,  3.6535e-03,
         -6.8372e-01,  7.1824e-01],
        [-4.2704e-02,  1.3505e+00,  2.1361e+00,  ..., -1.1139e+00,
          6.1626e-01,  4.8158e-01],
        [-7.3334e-01,  2.0820e+00,  3.7722e-02,  ..., -7.2141e-01,
          4.6871e-01,  7.0758e-01]], device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 3957, 45831, 13902,  ..., 45225, 32145, 12167], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3186, 11.3176, 11.3300,  ..., 11.3257, 11.3189, 11.3202],
       device='cuda:0'), 0.0)
|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 4.2787e-06,  7.0495e-06,  7.6679e-07,  ...,  1.6936e-05,
          1.3724e-05,  7.6143e-05],
        [-7.9652e-07, -4.8391e-07, -1.1396e-06,  ..., -9.7005e-06,
         -1.9535e-06, -1.8908e-07],
        [-9.2972e-06, -6.7489e-06, -4.7280e-06,  ..., -1.5329e-05,
         -5.2049e-06, -1.4894e-05],
        ...,
        [-7.2011e-08, -2.3243e-05, -2.0894e-06,  ..., -5.2573e-06,
         -2.6439e-06, -1.0743e-05],
        [-6.8437e-07, -2.7565e-06, -6.0470e-06,  ..., -2.3447e-07,
         -1.3227e-06, -1.1561e-06],
        [-1.4990e-05, -2.5033e-04, -3.2410e-05,  ..., -1.5170e-05,
         -4.9872e-05, -6.3328e-05]], device='cuda:0')

Max error in loss: 9.5367431640625e-07
Max error in logits grad: 2.384185791015625e-07
[12]:
[# Constructed by Backward pass
 import torch
 from thunder.executors.torchex import no_autocast

 @torch.no_grad()
 @no_autocast
 def backward_fn(saved_for_backward, cotangents):
   # saved_for_backward: "Collection"
   # cotangents: "Collection"
   C0, \
   C1, \
   = saved_for_backward
   t2, \
   = cotangents
   logits, \
   labels, \
   t0, \
   = C0
   f0, \
   = C1
   t3 = apex_xentropy_backward(t2, logits, labels, t0, f0)  # t3: "cuda:0 f32[2048, 50257]"
   return (t3, None),
 # Constructed by Transform for execution (took 0 milliseconds)
 import torch
 from thunder.executors.torchex import no_autocast

 @torch.no_grad()
 @no_autocast
 def backward_fn(saved_for_backward, cotangents):
   # saved_for_backward: "Collection"
   # cotangents: "Collection"
   C0, \
   C1, \
   = saved_for_backward
   t2, \
   = cotangents
   logits, \
   labels, \
   t0, \
   = C0
   f0, \
   = C1
   t3 = apex_xentropy_backward(t2, logits, labels, t0, f0)  # t3: "cuda:0 f32[2048, 50257]"
   return (t3, None),
 # Constructed by Update Call Context (took 0 milliseconds)
 import torch
 from thunder.executors.torchex import no_autocast

 @torch.no_grad()
 @no_autocast
 def backward_fn(saved_for_backward, cotangents):
   # saved_for_backward: "Collection"
   # cotangents: "Collection"
   C0, \
   C1, \
   = saved_for_backward
   t2, \
   = cotangents
   labels, \
   logits, \
   t0, \
   = C0
   f0, \
   = C1
   t3 = apex_xentropy_backward(t2, logits, labels, t0, f0)  # t3: "cuda:0 f32[2048, 50257]"
   return (t3, None),
 # Constructed by Delete Last Used (took 0 milliseconds)
 import torch
 from thunder.executors.torchex import no_autocast

 @torch.no_grad()
 @no_autocast
 def backward_fn(saved_for_backward, cotangents):
   # saved_for_backward: "Collection"
   # cotangents: "Collection"
   C0, \
   C1, \
   = saved_for_backward
   clear_collection(saved_for_backward)
   del saved_for_backward
   t2, \
   = cotangents
   clear_collection(cotangents)
   del cotangents
   labels, \
   logits, \
   t0, \
   = C0
   clear_collection(C0)
   del C0
   f0, \
   = C1
   clear_collection(C1)
   del C1
   t3 = apex_xentropy_backward(t2, logits, labels, t0, f0)  # t3: "cuda:0 f32[2048, 50257]"
   del t2, logits, labels, t0, f0
   return (t3, None)]

Alternatively, we can also use the grad transform to get the gradient:

[13]:
logits = torch.randn([2048, 50257], device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, [2048], device="cuda")

# jfn is the cell-[10] jit of cross_entropy(..., reduction="none"). The grad
# transform seeds it with an all-ones cotangent (the torch.full(..., 1.0) in
# the printed trace) and returns the gradient for logits.
grad_jfn = thunder.core.transforms.grad(jfn)
actual_grad, = grad_jfn(logits, labels)

# Summing the per-row losses gives the same all-ones cotangent in PyTorch.
# NOTE: loss_fn here is the cell-[12] definition (ignore_index=-1) while jfn
# wraps the cell-[10] one (default ignore_index); they agree because the labels
# are drawn from [0, 50257) and so never equal either ignore_index.
expected_grad, = torch.autograd.grad(loss_fn(logits, labels).sum(), logits)


print("Difference:", (actual_grad - expected_grad).abs().max().item())
print(thunder.last_traces(grad_jfn)[-1])

call apex_cross_entropy_grad(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
    call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
    |<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))

    call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), 0.0)
    |<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)

|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)

call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390,  0.1760, -1.0790,  ...,  0.1695, -0.8082, -0.6984],
        [ 2.1555,  1.3938,  0.3928,  ...,  0.8937, -0.4949,  1.1610],
        [ 0.6784,  1.1188,  0.7508,  ..., -0.0941,  0.8380,  0.1878],
        ...,
        [-1.5834, -0.1573, -1.3511,  ...,  0.6167, -0.1083,  0.4116],
        [-0.5476,  0.5831,  0.0791,  ..., -0.4986, -0.5270,  0.0954],
        [ 0.2825, -1.0378, -0.5506,  ...,  0.0149,  1.3521, -1.0823]],
       device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569,  ...,  9798, 33992, 36123], device='cuda:0'), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.0233, 11.9095, 11.2898,  ..., 10.9289, 10.7487, 10.7455],
       device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283,  ..., 11.3224, 11.3186, 11.3205],
       device='cuda:0'))

call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390,  0.1760, -1.0790,  ...,  0.1695, -0.8082, -0.6984],
        [ 2.1555,  1.3938,  0.3928,  ...,  0.8937, -0.4949,  1.1610],
        [ 0.6784,  1.1188,  0.7508,  ..., -0.0941,  0.8380,  0.1878],
        ...,
        [-1.5834, -0.1573, -1.3511,  ...,  0.6167, -0.1083,  0.4116],
        [-0.5476,  0.5831,  0.0791,  ..., -0.4986, -0.5270,  0.0954],
        [ 0.2825, -1.0378, -0.5506,  ...,  0.0149,  1.3521, -1.0823]],
       device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569,  ...,  9798, 33992, 36123], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283,  ..., 11.3224, 11.3186, 11.3205],
       device='cuda:0'), 0.0)
|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[2.0706e-05, 1.4403e-05, 4.1058e-06,  ..., 1.4309e-05, 5.3827e-06,
         6.0079e-06],
        [1.0461e-04, 4.8840e-05, 1.7949e-05,  ..., 2.9621e-05, 7.3879e-06,
         3.8697e-05],
        [2.3705e-05, 3.6822e-05, 2.5485e-05,  ..., 1.0948e-05, 2.7806e-05,
         1.4513e-05],
        ...,
        [2.4836e-06, 1.0338e-05, 3.1331e-06,  ..., 2.2417e-05, 1.0857e-05,
         1.8259e-05],
        [7.0235e-06, 2.1758e-05, 1.3145e-05,  ..., 7.3762e-06, 7.1699e-06,
         1.3360e-05],
        [1.6078e-05, 4.2941e-06, 6.9897e-06,  ..., 1.2304e-05, 4.6857e-05,
         4.1070e-06]], device='cuda:0')

Difference: 1.3969838619232178e-09
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(logits, labels):
  # logits: "cuda:0 f32[2048, 50257]"
  # labels: "cuda:0 i64[2048]"
  (_, t0) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)
  t4 = torch.full((2048,), 1.0, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[2048]"
    # t4 = ltorch.full((2048,), 1.0, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[2048]"
      # t4 = prims.full((2048,), 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[2048]"
  t3 = apex_xentropy_backward(t4, logits, labels, t0, 0.0)  # t3: "cuda:0 f32[2048, 50257]"
  del t4, logits, labels, t0
  return [t3]

So let’s wrap up what we did here:

  • We defined a custom executor with custom operations (Symbols in Thunder language), each with a Meta- (data propagation) function and an implementation.

  • We defined and registered rules to map existing operations to our new operations. This allows us to use optimizations on our model without changing the model’s code!

  • We defined a gradient rule and saw how our automatic PyTorch Autograd integration or the explicit grad transform uses it.

Now go and implement your favourite optimized operators. We would love to hear about your use-cases!

[ ]: