Extending Thunder

This notebook shows how to use thunder’s extend submodule to add new operations and custom grad and execution transforms.

import sys
sys.path.insert(0, '..')
from numbers import Number

import thunder
import thunder.torch as ltorch
from thunder.core.devices import DeviceType
from thunder.core.proxies import TensorProxy
from thunder.core.transforms import grad, put_grads, get_grad

import torch
import numpy as np

from thunder.extend import OperatorExecutor, register_executor
# Creates a new operator executor named "myex" and registers it with thunder.
#   Constructing the OperatorExecutor alone does not call register_executor, so the
#   registration is done explicitly here.
myex = OperatorExecutor("myex", version="0.1")
register_executor(myex)
# Our operator executor will use the "multimul" function as a new example operator.
#   This function uses NumPy to perform two multiplications of four inputs.
#   This function's contrived, but will be useful to illustrate the extend submodule's capabilities.
def multimul_impl(
        a: Number | torch.Tensor,
        b: Number | torch.Tensor,
        c: Number | torch.Tensor,
        d: Number | torch.Tensor,) -> tuple[torch.Tensor, torch.Tensor]:
    return np.multiply(a, b), np.multiply(c, d)
# We can verify that multimul is a valid Python function that operates on PyTorch tensors -- at least PyTorch tensors on the CPU.
# Two random 2x2 tensors used as sample inputs throughout the rest of the notebook
a = torch.randn((2, 2))
b = torch.randn((2, 2))
# The same (a, b) pair is passed for both multiplications, so the two results are equal
multimul_impl(a, b, a, b)
(tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]),
 tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]))
# To let thunder use multimul we need to define how it propagates metadata. This can be done by directly defining a "meta function",
# or by defining a traceable "like" function that describes what multimul does in terms of existing thunder operations.
#   The "like" function can be used for metadata propagation AND transforming the new operator, as we'll see below.
#   In this case, the "like" function just describes the two multiplications that multimul performs.
def multimul_like(
        a: Number | TensorProxy,
        b: Number | TensorProxy,
        c: Number | TensorProxy,
        d: Number | TensorProxy,
):
    """Describe multimul as two ordinary multiplications.

    This is the traceable "like" function for multimul: it expresses the
    operator using the ``*`` operator on its inputs.

    Args:
        a: left operand of the first multiplication.
        b: right operand of the first multiplication.
        c: left operand of the second multiplication.
        d: right operand of the second multiplication.

    Returns:
        The pair ``(a * b, c * d)``.
    """
    return a * b, c * d
# The "register_operator" method of an operator executor returns a "Symbol" object for multimul that can be called directly
#   from compiled thunder code.
#   Here like= is given the traceable description (multimul_like) and fn= the NumPy-backed function (multimul_impl).
multimul = myex.register_operator('multimul', like=multimul_like, fn=multimul_impl)
# Example of calling the new multimul symbol
def foo(a, b, c, d):
    """Forward all four arguments to the multimul symbol and return its result unchanged."""
    outputs = multimul(a, b, c, d)
    return outputs

# Compiles foo with myex as the only executor listed, then calls it on the sample tensors
cfoo = thunder.jit(foo, executors=[myex])
cfoo(a, b, a, b)
(tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]),
 tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]))
# The symbol is recorded, like other operations, into thunder's trace
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1, t_2, t_3):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  # t_2: "cpu f32[2, 2]"
  # t_3: "cpu f32[2, 2]"
  (t0, t1) = multimul(t_0, t_1, t_2, t_3)
    # t0 = ltorch.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
    # t1 = ltorch.mul(t_2, t_3)  # t1: "cpu f32[2, 2]"
      # t1 = prims.mul(t_2, t_3)  # t1: "cpu f32[2, 2]"
  del t_0, t_1, t_2, t_3
  return (t0, t1)
# multimul is even differentiable because its "like" function is differentiable

# Wraps the compiled function with thunder's grad transform and evaluates it on the same inputs
cfoo_grad = grad(cfoo)
cfoo_grad(a, b, a, b)

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1, t_2, t_3):
  # t_1: "cpu f32[2, 2]"
  t8 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t8: "cpu f32[2, 2]"
    # t8 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t8: "cpu f32[2, 2]"
      # t8 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t8: "cpu f32[2, 2]"
  t2 = torch.mul(t_1, t8)  # t2: "cpu f32[2, 2]"
    # t2 = ltorch.mul(t_1, t8)  # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t8)  # t2: "cpu f32[2, 2]"
  del t_1
  # t_0: "cpu f32[2, 2]"
  t3 = torch.mul(t_0, t8)  # t3: "cpu f32[2, 2]"
    # t3 = ltorch.mul(t_0, t8)  # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t8)  # t3: "cpu f32[2, 2]"
  del t_0, t8
  # t_3: "cpu f32[2, 2]"
  t9 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t9: "cpu f32[2, 2]"
    # t9 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t9: "cpu f32[2, 2]"
      # t9 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t9: "cpu f32[2, 2]"
  t6 = torch.mul(t_3, t9)  # t6: "cpu f32[2, 2]"
    # t6 = ltorch.mul(t_3, t9)  # t6: "cpu f32[2, 2]"
      # t6 = prims.mul(t_3, t9)  # t6: "cpu f32[2, 2]"
  del t_3
  # t_2: "cpu f32[2, 2]"
  t7 = torch.mul(t_2, t9)  # t7: "cpu f32[2, 2]"
    # t7 = ltorch.mul(t_2, t9)  # t7: "cpu f32[2, 2]"
      # t7 = prims.mul(t_2, t9)  # t7: "cpu f32[2, 2]"
  del t_2, t9
  return [t2, t3, t6, t7]
tensor([[-1.1229, -0.1863],
        [ 2.2082, -0.6380]])
# We can tell thunder to execute existing operations using multimul by defining a transform
#   from them to multimul, and a "checker" function that returns True when the
#   transform is valid and False otherwise.

# We can translate mul to multimul by ignoring the second multiplication
def mul_to_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:
    """Express a single multiplication ``a * b`` through multimul.

    The second multiplication slot of multimul is filled with zeros and its
    result is discarded.

    Args:
        a: left operand.
        b: right operand.

    Returns:
        The first output of ``multimul(a, b, 0, 0)``.
    """
    product, _discarded = multimul(a, b, 0, 0)
    return product

# The "checker" function verifies that all inputs are CPU tensors or numbers, because NumPy
#   can't handle other inputs
def mul_to_multimul_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:
    """Return True when mul(a, b) may be executed by multimul.

    Every input must be either a non-tensor value or a tensor proxy whose
    device type is CPU.

    Args:
        a: left operand of the multiplication.
        b: right operand of the multiplication.

    Returns:
        True if both ``a`` and ``b`` pass the CPU check, False otherwise.
    """
    def is_cpu(x: Number | TensorProxy) -> bool:
        # Inspect the argument actually passed in (x), not the enclosing a,
        #   so that b is checked as well.
        if isinstance(x, TensorProxy):
            return x.device.devicetype == DeviceType.CPU
        # Non-tensor inputs carry no device and are accepted.
        return True

    return all(is_cpu(x) for x in (a, b))
# The "register_implementation" method describes how to translate mul to multimul:
#   checker= is the function that says whether the translation is valid for the given inputs,
#   and execution_transform= is the function that performs the translation.
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_multimul)
# Verifies the implementation of mul using multimul, and shows the execution transform
def bar(a, b):
    """Multiply ``a`` by ``b`` with the ``*`` operator and return the product."""
    product = a * b
    return product
# Compiles bar with the myex executor and runs it on the sample tensors a and b
cbar = thunder.jit(bar, executors=[myex])
cbar(a, b)
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  (t0, _) = multimul(t_0, t_1, 0, 0)
    # t0 = ltorch.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
  del t_0, t_1
  return t0
# Execution transforms happen AFTER semantic transforms like grad, so even when computing the grad
#   of mul (which involves two multiplications to compute the grad) we still see multimul in the
#   execution trace

# Applies the grad transform to the compiled bar and evaluates it on a and b
cbar_grad = grad(cbar)
cbar_grad(a, b)
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1):
  # t_1: "cpu f32[2, 2]"
  t4 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
    # t4 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
      # t4 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t4: "cpu f32[2, 2]"
  (t2, _) = multimul(t_1, t4, 0, 0)
    # t2 = ltorch.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
  del t_1
  # t_0: "cpu f32[2, 2]"
  (t3, _) = multimul(t_0, t4, 0, 0)
    # t3 = ltorch.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
  del t_0, t4
  return [t2, t3]
# In the above grad trace there are two multimuls, and both ignore one of their multiplications.
#   It would be more efficient to perform just one multimul, and we can make this happen
#   by defining a new grad transform for mul that calls multimul once.
#   thunder's grad transforms are defined in a novel way that's not the focus of this notebook,
#   but below we define the grad transform to use multimul.
def mymul_grad(a: TensorProxy, b: TensorProxy) -> TensorProxy:
    """Grad transform for mul that produces both input grads with one multimul call.

    Args:
        a: left operand of the multiplication.
        b: right operand of the multiplication.

    Returns:
        The forward result ``a * b``.
    """
    product = a * b

    # Gradient associated with the forward result
    upstream = get_grad(product)
    # One multimul call computes b * upstream (for a) and a * upstream (for b)
    grad_wrt_a, grad_wrt_b = multimul(b, upstream, a, upstream)
    put_grads((a, b), (grad_wrt_a, grad_wrt_b))

    return product

# Re-registers the implementation, including the execution transform and now a grad transform.
#   The checker and execution transform are the same functions as before; only grad_transform= is new.
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_multimul, grad_transform=mymul_grad)
# Verifies our new grad transform is used and that a single multimul call is made
# Rebuilds the grad function from cbar and evaluates it on a and b
cbar_grad = grad(cbar)
cbar_grad(a, b)
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  t4 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
    # t4 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
      # t4 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t4: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  (t2, t3) = multimul(t_1, t4, t_0, t4)
    # t2 = ltorch.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
    # t3 = ltorch.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
  del t_1, t4, t_0
  return [t2, t3]
# Some operations may require inputs have particular properties (like be contiguous), or a transform may wish
#   to interleave torch operations with new operations. The transform function supports this. Here
#   we can see an example where the inputs to multimul are made contiguous before it's called
def mul_to_contiguous_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:
    """Express ``a * b`` through multimul after making tensor inputs contiguous.

    Args:
        a: left operand; made contiguous first when it is a tensor proxy.
        b: right operand; made contiguous first when it is a tensor proxy.

    Returns:
        The first output of ``multimul(a, b, 0, 0)``.
    """
    # The signature admits plain Python numbers, which have no .contiguous() method,
    #   so only tensor proxies are converted.
    if isinstance(a, TensorProxy):
        a = a.contiguous()
    if isinstance(b, TensorProxy):
        b = b.contiguous()
    # The second multiplication slot is unused; its result is discarded.
    result, _ = multimul(a, b, 0, 0)
    return result

# Registers an implementation for mul again, this time with the contiguous-first execution transform.
#   No grad_transform= is passed in this call.
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_contiguous_multimul)
# Verifies the new "prologue" for multimul works as expected. Note that the contiguous operations are
#   executed by PyTorch, and don't have to be executed by your executor

def caz(a, b):
    """Return the product of ``a`` and ``b`` computed with the ``*`` operator."""
    product = a * b
    return product
# Compiles caz with the myex executor and runs it on the sample tensors a and b
ccaz = thunder.jit(caz, executors=[myex])
ccaz(a, b)
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  t1 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0)  # t1: "cpu f32[2, 2]"
    # t1 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t1: "cpu f32[2, 2]"
      # t1 = prims.stride_order(t_0, (1, 0))  # t1: "cpu f32[2, 2]"
  del t_0
  t2 = Tensor.contiguous(t_1, memory_format=_torch_memory_format_0)  # t2: "cpu f32[2, 2]"
    # t2 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0)  # t2: "cpu f32[2, 2]"
      # t2 = prims.stride_order(t_1, (1, 0))  # t2: "cpu f32[2, 2]"
  del t_1
  (t0, _) = multimul(t1, t2, 0, 0)
    # t0 = ltorch.mul(t1, t2)  # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t1, t2)  # t0: "cpu f32[2, 2]"
  del t1, t2
  return t0
# NVIDIA's APEX cross-entropy executor is a good example of a real-world operator executor. It defines
#   fast forward and backward functions for torch.nn.functional.cross_entropy. We can see its custom
#   fwd and bwd operations below
# NOTE This cell and the following cells require the apex executor be installed to run properly
# Sample inputs for the cross-entropy example; both tensors are created on the CUDA device
dtype = torch.float32
device = 'cuda'
# Random logits of shape [2048, 50257] that do not require grad
logits = torch.randn([2048, 50257], device=device, dtype=ltorch.to_torch_dtype(dtype), requires_grad=False)
# One random integer label per row of logits, drawn with torch.randint(0, 50257, ...)
labels = torch.randint(0, 50257, [2048], device=device)

from thunder.executors.apex_entropyex import apex_ex

def foo(logits, labels):
    """Return the mean cross-entropy loss of ``logits`` against ``labels``.

    The call passes ``ignore_index=-1`` and ``reduction="mean"`` to
    ``torch.nn.functional.cross_entropy``.
    """
    loss = torch.nn.functional.cross_entropy(
        logits,
        labels,
        ignore_index=-1,
        reduction="mean",
    )
    return loss
# Compiles the cross-entropy function with apex_ex as the only executor listed
cfoo = thunder.jit(foo, executors=[apex_ex])
# Shows the forward operation
cfoo(logits, labels)
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1):
  # t_0: "cuda:0 f32[2048, 50257]"
  # t_1: "cuda:0 i64[2048]"
  (t18, _) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)
  del t_0, t_1
  return t18
# Shows APEX's custom forward and backward operations, plus additional PyTorch operations between the two

# Applies the grad transform to the APEX-compiled function
cfoo_grad = grad(cfoo)

# Evaluates the grad function on the same logits and labels used for the forward call
cfoo_grad(logits, labels)
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

def computation(t_0, t_1):
  # t_0: "cuda:0 f32[2048, 50257]"
  # t_1: "cuda:0 i64[2048]"
  (_, t1) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)
  t6 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[2048, 50257]"
    # t6 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[2048, 50257]"
      # t6 = prims.stride_order(t_0, (1, 0))  # t6: "cuda:0 f32[2048, 50257]"
  del t_0
  t8 = torch.full((), 1.0, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[]"
    # t8 = ltorch.full((), 1.0, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[]"
      # t8 = prims.full((), 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[]"
  t12 = torch.unsqueeze(t8, 0)  # t12: "cuda:0 f32[1]"
    # t12 = ltorch.unsqueeze(t8, 0)  # t12: "cuda:0 f32[1]"
      # t12 = prims.broadcast_in_dim(t8, [1], [])  # t12: "cuda:0 f32[1]"
  del t8
  t3 = Tensor.expand(t12, [1])  # t3: "cuda:0 f32[1]"
    # t3 = ltorch.expand(t12, [1])  # t3: "cuda:0 f32[1]"
      # t3 = prims.broadcast_in_dim(t12, (1,), (0,))  # t3: "cuda:0 f32[1]"
  del t12
  t4 = Tensor.expand(t3, (2048,))  # t4: "cuda:0 f32[2048]"
    # t4 = ltorch.expand(t3, (2048,))  # t4: "cuda:0 f32[2048]"
      # t4 = prims.broadcast_in_dim(t3, (2048,), (0,))  # t4: "cuda:0 f32[2048]"
  del t3
  t5 = torch.mul(t4, 0.00048828125)  # t5: "cuda:0 f32[2048]"
    # t5 = ltorch.mul(t4, 0.00048828125)  # t5: "cuda:0 f32[2048]"
      # t5 = prims.mul(t4, 0.00048828125)  # t5: "cuda:0 f32[2048]"
  del t4
  t7 = apex_cross_entropy_backward(t5, t6, target=t_1, max_log_sum_exp=t1, label_smoothing=0.0)  # t7: "cuda:0 f32[2048, 50257]"
  del t5, t6, t1, t_1
  return [t7]