Extending Thunder

This notebook shows how to use thunder’s extend submodule to add new operations and custom grad and execution transforms.

[1]:
import sys
sys.path.insert(0, '..')
from numbers import Number

import thunder
import thunder.torch as ltorch
from thunder.core.devices import DeviceType
from thunder.core.proxies import TensorProxy
from thunder.core.transforms import grad, put_grads, get_grad

import torch
import numpy as np

torch.manual_seed(42);
[2]:
from thunder.extend import OperatorExecutor, register_executor
[3]:
# Registers a new operator executor
myex = OperatorExecutor("myex", version="0.1")
register_executor(myex)
[3]:
myex
[4]:
# Our operator executor will use the "multimul" function as a new example operator.
#   This function uses NumPy to perform two multiplications of four inputs.
#   This function is contrived, but it will be useful to illustrate the extend submodule's capabilities.
def multimul_impl(
        a: Number | torch.Tensor,
        b: Number | torch.Tensor,
        c: Number | torch.Tensor,
        d: Number | torch.Tensor,) -> tuple[torch.Tensor, torch.Tensor]:
    return np.multiply(a, b), np.multiply(c, d)
[5]:
# We can verify that multimul is a valid Python function that operates on PyTorch tensors -- at least PyTorch tensors on the CPU.
a = torch.randn((2, 2))
b = torch.randn((2, 2))
multimul_impl(a, b, a, b)
[5]:
(tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]),
 tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]))
[6]:
# To let thunder use multimul we need to define how it propagates metadata. This can be done by directly defining a "meta function",
# or by defining a traceable "like" function that describes what multimul does in terms of existing thunder operations.
#   The "like" function can be used for metadata propagation AND transforming the new operator, as we'll see below.
#   In this case, the "like" function just describes the two multiplications that multimul performs.
def multimul_like(
        a: Number | TensorProxy,
        b: Number | TensorProxy,
        c: Number | TensorProxy,
        d: Number | TensorProxy,
):
    """Describe multimul in terms of existing thunder operations.

    thunder traces this function to propagate metadata for multimul and to
    transform it (e.g. differentiate it), so it should compute the same two
    products that ``multimul_impl`` returns.
    """
    first = a * b
    second = c * d
    return first, second
[7]:
# The operator executor's "register_operator" method returns a "Symbol" object for multimul that can be called directly
#   from compiled thunder code.
multimul = myex.register_operator('multimul', like=multimul_like, fn=multimul_impl)
[8]:
# Example of calling the new multimul symbol
def foo(a, b, c, d):
    # Calls the multimul symbol directly; thunder records it in the trace as a single multimul operation
    return multimul(a, b, c, d)

cfoo = thunder.jit(foo, executors=[myex])
cfoo(a, b, a, b)
[8]:
(tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]),
 tensor([[-0.3781, -0.0240],
         [ 0.5177, -0.1470]]))
[9]:
# The symbol is recorded, like other operations, into thunder's trace
thunder.last_traces(cfoo)[-1]
[9]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2, t_3):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  # t_2: "cpu f32[2, 2]"
  # t_3: "cpu f32[2, 2]"
  (t0, t1) = multimul(t_0, t_1, t_2, t_3)
    # t0 = ltorch.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
    # t1 = ltorch.mul(t_2, t_3)  # t1: "cpu f32[2, 2]"
      # t1 = prims.mul(t_2, t_3)  # t1: "cpu f32[2, 2]"
  del t_0, t_1, t_2, t_3
  return (t0, t1)
[10]:
# multimul is even differentiable because its "like" function is differentiable
a.requires_grad_(True)
b.requires_grad_(True)

cfoo_grad = grad(cfoo)
cfoo_grad(a, b, a, b)
print(thunder.last_traces(cfoo_grad)[-1])

a.requires_grad_(False)
b.requires_grad_(False)
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1, t_2, t_3):
  # t_1: "cpu f32[2, 2]"
  t8 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t8: "cpu f32[2, 2]"
    # t8 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t8: "cpu f32[2, 2]"
      # t8 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t8: "cpu f32[2, 2]"
  t2 = torch.mul(t_1, t8)  # t2: "cpu f32[2, 2]"
    # t2 = ltorch.mul(t_1, t8)  # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t8)  # t2: "cpu f32[2, 2]"
  del t_1
  # t_0: "cpu f32[2, 2]"
  t3 = torch.mul(t_0, t8)  # t3: "cpu f32[2, 2]"
    # t3 = ltorch.mul(t_0, t8)  # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t8)  # t3: "cpu f32[2, 2]"
  del t_0, t8
  # t_3: "cpu f32[2, 2]"
  t9 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t9: "cpu f32[2, 2]"
    # t9 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t9: "cpu f32[2, 2]"
      # t9 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t9: "cpu f32[2, 2]"
  t6 = torch.mul(t_3, t9)  # t6: "cpu f32[2, 2]"
    # t6 = ltorch.mul(t_3, t9)  # t6: "cpu f32[2, 2]"
      # t6 = prims.mul(t_3, t9)  # t6: "cpu f32[2, 2]"
  del t_3
  # t_2: "cpu f32[2, 2]"
  t7 = torch.mul(t_2, t9)  # t7: "cpu f32[2, 2]"
    # t7 = ltorch.mul(t_2, t9)  # t7: "cpu f32[2, 2]"
      # t7 = prims.mul(t_2, t9)  # t7: "cpu f32[2, 2]"
  del t_2, t9
  return [t2, t3, t6, t7]
[10]:
tensor([[-1.1229, -0.1863],
        [ 2.2082, -0.6380]])
[11]:
# We can tell thunder to execute existing operations using multimul by defining a transform
#   from them to multimul, and a "checker" function that returns True when the
#   transform is valid and False otherwise.

# We can translate mul to multimul by ignoring the second multiplication
def mul_to_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:
    """Execution transform that computes ``a * b`` with multimul.

    multimul always performs two multiplications; the second one is given
    zeros and its result is discarded.
    """
    product, _discarded = multimul(a, b, 0, 0)
    return product

# The "checker" function verifies that all inputs are CPU tensors or numbers, because NumPy
#   can't handle other inputs
def mul_to_multimul_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:
    """Return True when ``mul(a, b)`` may be executed by multimul.

    multimul is implemented with NumPy, so every tensor argument must be on
    the CPU. Non-tensor arguments (numbers) are always accepted.
    """

    def is_cpu(x: Number | TensorProxy) -> bool:
        # Inspect the argument actually passed in (x). Previously this checked the
        # closed-over `a`, so `b` was never validated.
        if isinstance(x, TensorProxy):
            return x.device.devicetype == DeviceType.CPU
        return True

    return all(is_cpu(x) for x in (a, b))
[12]:
# The "register_implementation" method describes how to translate mul to multimul
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_multimul)
[13]:
# Verifies the implementation of mul using multimul, and shows the execution transform
def bar(a, b):
    # Plain multiplication; once myex's mul implementation is registered and its checker
    # accepts the inputs, thunder executes this mul with multimul
    return a * b
cbar = thunder.jit(bar, executors=[myex])
cbar(a, b)
thunder.last_traces(cbar)[-1]
[13]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  (t0, _) = multimul(t_0, t_1, 0, 0)
    # t0 = ltorch.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t_0, t_1)  # t0: "cpu f32[2, 2]"
  del t_0, t_1
  return t0
[14]:
# Execution transforms happen AFTER semantic transforms like grad, so even when computing the grad
#   of mul (which involves two multiplications to compute the grad) we still see multimul in the
#   execution trace
a.requires_grad_(True)
b.requires_grad_(True)

cbar_grad = grad(cbar)
cbar_grad(a, b)
thunder.last_traces(cbar_grad)[-1]
[14]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_1: "cpu f32[2, 2]"
  t4 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
    # t4 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
      # t4 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t4: "cpu f32[2, 2]"
  (t2, _) = multimul(t_1, t4, 0, 0)
    # t2 = ltorch.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
  del t_1
  # t_0: "cpu f32[2, 2]"
  (t3, _) = multimul(t_0, t4, 0, 0)
    # t3 = ltorch.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
  del t_0, t4
  return [t2, t3]
[15]:
# In the above grad trace there are two multimuls, and both ignore one of their multiplications.
#   It would be more efficient to perform just one multimul, and we can make this happen
#   by defining a new grad transform for mul that calls multimul once.
#   thunder's grad transforms are defined in a novel way that's not the focus of this notebook,
#   but below we define the grad transform to use multimul.
def mymul_grad(a: TensorProxy, b: TensorProxy) -> TensorProxy:
    """Grad transform for mul that computes both input gradients with one multimul call.

    Returns the forward result ``a * b``; the input gradients are recorded via ``put_grads``.
    """
    # Forward computation
    fwd = a * b

    # Gradient associated with the forward output (obtained from thunder's grad machinery)
    g = get_grad(fwd)
    # d(a*b)/da = b and d(a*b)/db = a, so both grads are products with g --
    #   a single multimul call computes (b * g, a * g)
    a_grad, b_grad = multimul(b, g, a, g)
    put_grads((a, b), (a_grad, b_grad))

    return fwd

# Re-registers the implementation, including the execution transform and now a grad transform
myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_multimul, grad_transform=mymul_grad)
[16]:
# Verifies our new grad transform is used and that a single multimul call is made
cbar_grad = grad(cbar)
cbar_grad(a, b)
thunder.last_traces(cbar_grad)[-1]
[16]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  t4 = torch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
    # t4 = ltorch.full((2, 2), 1.0, device=torch.device("cpu"), dtype=torch.float32)  # t4: "cpu f32[2, 2]"
      # t4 = prims.full((2, 2), 1.0, device=devices.Device("cpu"), dtype=dtypes.float32)  # t4: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  (t2, t3) = multimul(t_1, t4, t_0, t4)
    # t2 = ltorch.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
      # t2 = prims.mul(t_1, t4)  # t2: "cpu f32[2, 2]"
    # t3 = ltorch.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
      # t3 = prims.mul(t_0, t4)  # t3: "cpu f32[2, 2]"
  del t_1, t4, t_0
  return [t2, t3]
[17]:
# Some operations may require that their inputs have particular properties (like being contiguous), or a transform may wish
#   to interleave torch operations with new operations. The transform function supports this. Here
#   we can see an example where the inputs to multimul are made contiguous before it's called.
def mul_to_contiguous_multimul(a: Number | TensorProxy, b: Number | TensorProxy) -> TensorProxy:
    """Execution transform that computes ``a * b`` with multimul on contiguous inputs.

    The ``contiguous`` calls are ordinary torch operations, so they are executed
    by PyTorch rather than by this executor; only the multiplication goes through
    multimul. Numbers are passed through unchanged: they have no ``contiguous``
    method, yet mul_to_multimul_checker accepts them.
    """
    if isinstance(a, TensorProxy):
        a = a.contiguous()
    if isinstance(b, TensorProxy):
        b = b.contiguous()
    # The second multiplication is unused; feed it zeros and discard the result
    result, _ = multimul(a, b, 0, 0)
    return result

myex.register_implementation(ltorch.mul, checker=mul_to_multimul_checker, execution_transform=mul_to_contiguous_multimul)
[18]:
# Verifies the new "prologue" for multimul works as expected. Note that the contiguous operations are
#   executed by PyTorch, and don't have to be executed by your executor
a.requires_grad_(False)
b.requires_grad_(False)

def caz(a, b):
    # Same as bar; with the contiguous execution transform registered, the inputs are
    #   made contiguous (by PyTorch) before multimul is called
    return a * b
ccaz = thunder.jit(caz, executors=[myex])
ccaz(a, b)
thunder.last_traces(ccaz)[-1]
[18]:
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cpu f32[2, 2]"
  # t_1: "cpu f32[2, 2]"
  t1 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0)  # t1: "cpu f32[2, 2]"
    # t1 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t1: "cpu f32[2, 2]"
      # t1 = prims.stride_order(t_0, (1, 0))  # t1: "cpu f32[2, 2]"
  del t_0
  t2 = Tensor.contiguous(t_1, memory_format=_torch_memory_format_0)  # t2: "cpu f32[2, 2]"
    # t2 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0)  # t2: "cpu f32[2, 2]"
      # t2 = prims.stride_order(t_1, (1, 0))  # t2: "cpu f32[2, 2]"
  del t_1
  (t0, _) = multimul(t1, t2, 0, 0)
    # t0 = ltorch.mul(t1, t2)  # t0: "cpu f32[2, 2]"
      # t0 = prims.mul(t1, t2)  # t0: "cpu f32[2, 2]"
  del t1, t2
  return t0
[19]:
# NVIDIA's APEX cross-entropy executor is a good example of a real-world operator executor. It defines
#   fast forward and backward functions for torch.nn.functional.cross_entropy. We can see its custom
#   fwd and bwd operations below
# NOTE This cell and the following cells require the APEX executor to be installed and a CUDA device to run properly
dtype = torch.float32
device = 'cuda'
logits = torch.randn([2048, 50257], device=device, dtype=ltorch.to_torch_dtype(dtype), requires_grad=False)
labels = torch.randint(0, 50257, [2048], device=device)

from thunder.executors.apex_entropyex import apex_ex

def foo(logits, labels):
    # Mean cross-entropy with ignore_index=-1; when jitted with apex_ex this is executed by
    #   apex_cross_entropy (note: this redefines the earlier multimul `foo`)
    return torch.nn.functional.cross_entropy(logits, labels, reduction="mean", ignore_index=-1)
cfoo = thunder.jit(foo, executors=[apex_ex])
[20]:
# Shows the forward operation
cfoo(logits, labels)
thunder.last_traces(cfoo)[-1]
[20]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cuda:0 f32[2048, 50257]"
  # t_1: "cuda:0 i64[2048]"
  (t18, _) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)
  del t_0, t_1
  return t18
[21]:
# Shows APEX's custom forward and backward operations, plus additional PyTorch operations between the two
logits.requires_grad_(True)

cfoo_grad = grad(cfoo)

cfoo_grad(logits, labels)
thunder.last_traces(cfoo_grad)[-1]
[21]:
# Constructed by Delete Last Used (took 0 milliseconds)
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(t_0, t_1):
  # t_0: "cuda:0 f32[2048, 50257]"
  # t_1: "cuda:0 i64[2048]"
  (_, t1) = apex_cross_entropy(t_0, t_1, 'mean', 0.0)
  t6 = Tensor.contiguous(t_0, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[2048, 50257]"
    # t6 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0)  # t6: "cuda:0 f32[2048, 50257]"
      # t6 = prims.stride_order(t_0, (1, 0))  # t6: "cuda:0 f32[2048, 50257]"
  del t_0
  t8 = torch.full((), 1.0, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[]"
    # t8 = ltorch.full((), 1.0, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[]"
      # t8 = prims.full((), 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[]"
  t12 = torch.unsqueeze(t8, 0)  # t12: "cuda:0 f32[1]"
    # t12 = ltorch.unsqueeze(t8, 0)  # t12: "cuda:0 f32[1]"
      # t12 = prims.broadcast_in_dim(t8, [1], [])  # t12: "cuda:0 f32[1]"
  del t8
  t3 = Tensor.expand(t12, [1])  # t3: "cuda:0 f32[1]"
    # t3 = ltorch.expand(t12, [1])  # t3: "cuda:0 f32[1]"
      # t3 = prims.broadcast_in_dim(t12, (1,), (0,))  # t3: "cuda:0 f32[1]"
  del t12
  t4 = Tensor.expand(t3, (2048,))  # t4: "cuda:0 f32[2048]"
    # t4 = ltorch.expand(t3, (2048,))  # t4: "cuda:0 f32[2048]"
      # t4 = prims.broadcast_in_dim(t3, (2048,), (0,))  # t4: "cuda:0 f32[2048]"
  del t3
  t5 = torch.mul(t4, 0.00048828125)  # t5: "cuda:0 f32[2048]"
    # t5 = ltorch.mul(t4, 0.00048828125)  # t5: "cuda:0 f32[2048]"
      # t5 = prims.mul(t4, 0.00048828125)  # t5: "cuda:0 f32[2048]"
  del t4
  t7 = apex_cross_entropy_backward(t5, t6, target=t_1, max_log_sum_exp=t1, label_smoothing=0.0)  # t7: "cuda:0 f32[2048, 50257]"
  del t5, t6, t1, t_1
  return [t7]