Defining custom forward and backward for existing operators¶
We are going to add a custom executor for the forward and backward of the torch.nn.functional.cross_entropy operator.
Here is the SoftmaxCrossEntropyLoss definition from https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py:
import torch
import xentropy_cuda

class SoftmaxCrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
        losses, max_log_sum_exp = xentropy_cuda.forward(
            logits, labels, smoothing, half_to_float)
        losses.masked_fill_(labels==padding_idx, 0)

        ctx.save_for_backward(logits, max_log_sum_exp, labels,
                              torch.FloatTensor([smoothing]),
                              torch.LongTensor([padding_idx]))

        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors

        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==padding_idx.item(), 0)
        grad_logits = xentropy_cuda.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp,
            labels, smoothing.item())

        return grad_logits, None, None, None, None
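Since SoftmaxCrossEntropyLoss is a torch.autograd.Function, the plain-PyTorch usage would look roughly like the sketch below (assuming apex was built with its xentropy extension and a CUDA device is available; the shapes are just illustrative):
import torch

logits = torch.randn(8, 100, device="cuda", requires_grad=True)
labels = torch.randint(0, 100, (8,), device="cuda")

# autograd.Functions are invoked through .apply; this returns the per-sample losses.
losses = SoftmaxCrossEntropyLoss.apply(logits, labels)  # smoothing=0.0, padding_idx=0 by default
losses.sum().backward()  # gradients flow through the custom backward defined above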
[1]:
import sys
sys.path.insert(0, '..')
import thunder
import torch
torch.manual_seed(42)
from thunder.core.proxies import TensorProxy
In Thunder, we define executors to run given ops. Our executor will handle specific ops (rather than fusion regions), so the first step is to create our own OperatorExecutor
and register it with Thunder.
[2]:
from thunder.extend import OperatorExecutor, register_executor
apex_xentropy_ex = OperatorExecutor("apex_xentropy_ex", version="0.1")
register_executor(apex_xentropy_ex)
[2]:
apex_xentropy_ex
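As a quick sanity check, the new executor should now show up among the executors Thunder knows about (assuming your Thunder version exposes thunder.extend.get_all_executors):
import thunder.extend

# apex_xentropy_ex should appear in this list after the register_executor call above.
print(thunder.extend.get_all_executors())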
To get a feel for what’s going on, let’s write a wrapper that prints function calls and their arguments.
[3]:
import functools

_indentation = 0
def _log(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print(" " * _indentation + msg)

def _log_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _log(msg)
    _indentation = 2 + _indentation

def _log_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 2
    _log(msg)

def log(func):
    """A decorator for functions to log arguments and results."""
    name = func.__name__

    def pp(v):
        """Print certain values more succinctly"""
        vtype = str(type(v))
        if isinstance(v, tuple):
            return "({})".format(pp_values(v))
        elif isinstance(v, thunder.core.proxies.TensorProxy):
            return f"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})"
        elif isinstance(v, torch.Tensor):
            return f"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}"
        else:
            return str(v)

    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
        _log_indent("call {}({})".format(name, pp_values(args)))
        res = func(*args)
        _log_unindent("|<- {} = {}\n".format(name, pp(res)))
        return res

    return func_wrapper
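Here is a quick toy check of the decorator (applying the log wrapper defined above to a made-up helper; the real logs further down come from our operator definitions):
@log
def toy_add(a, b):
    # A trivial function, just to demonstrate the call/return logging.
    return a + b

toy_add(1, 2)
# prints roughly:
# call toy_add(1, 2)
# |<- toy_add = 3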
We want to define the operators apex_xentropy_forward
and apex_xentropy_backward
. In Thunder, each operator consists of a meta function, which only computes the metadata (like shapes) of the outputs, and an implementation function that does the actual computation; we register the pair with our executor. So we do this for the forward…
[4]:
@log
def apex_xentropy_forward_meta(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    max_log_sum_exp = TensorProxy(like=target)
    if reduction == "none":
        return TensorProxy(shape=(a.shape[0],), dtype=a.dtype, device=a.device,
                           requires_grad=a.requires_grad), max_log_sum_exp
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

import xentropy_cuda

@log
def apex_xentropy_forward_impl(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    losses, max_log_sum_exp = xentropy_cuda.forward(a, target, label_smoothing, False)
    if reduction == "none":
        losses = losses.to(a.dtype)
    else:
        raise ValueError(f"Invalid reduction: {reduction}")
    return losses, max_log_sum_exp

apex_xentropy_forward = apex_xentropy_ex.register_operator(
    "apex_xentropy_forward", meta=apex_xentropy_forward_meta, fn=apex_xentropy_forward_impl
)
…and the backward…
[5]:
@log
def apex_xentropy_backward_meta(
    grad,
    logits,
    labels,
    max_log_sum_exp,
    smoothing,
):
    return TensorProxy(like=logits)

@log
def apex_xentropy_backward_impl(
    grad,
    logits,
    labels,
    max_log_sum_exp,
    smoothing,
):
    return xentropy_cuda.backward(grad.contiguous(), logits, max_log_sum_exp, labels, smoothing)

apex_xentropy_backward = apex_xentropy_ex.register_operator(
    "apex_xentropy_backward", meta=apex_xentropy_backward_meta, fn=apex_xentropy_backward_impl
)
Because Thunder currently does not allow keyword arguments to be passed to operators, we define a convenience wrapper:
[6]:
def apex_xentropy(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    res, _ = apex_xentropy_forward(a, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
    return res
We can now thunder.jit
functions using our operator:
[7]:
def loss_fn(logits, labels):
    return apex_xentropy(logits, labels, reduction="none")
jfn = thunder.jit(loss_fn)
logits = torch.randn([2048, 50257], device="cuda")
labels = torch.randint(0, 50257, [2048], device="cuda")
actual_result = jfn(logits, labels)
expected_result = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
print("deviation from pytorch implementation:", (actual_result - expected_result).abs().max().item())
call apex_xentropy_forward_meta(TensorProxy(name=t_0, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=t_1, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))
call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.1940, 2.1614, -0.1721, ..., -0.4797, 1.4608, -0.5221],
[ 1.8288, 0.2116, 0.1760, ..., -0.1599, 0.1195, 0.0073],
[-2.1704, 1.0396, 2.2924, ..., 0.6021, 0.6498, -0.6316],
...,
[ 0.4908, -0.3445, 2.6618, ..., -2.0946, -0.2890, 0.1500],
[-1.0561, -1.3547, -1.0354, ..., 0.4304, -0.7882, -0.5496],
[-0.6883, -1.3283, 0.3513, ..., -0.6951, 0.2013, -1.0238]],
device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 9132, 12067, 5347, ..., 9268, 12534, 33582], device='cuda:0'), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.7236, 11.9374, 11.0063, ..., 11.7434, 9.5018, 10.8008],
device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3132, 11.3291, 11.3287, ..., 11.3279, 11.3251, 11.3301],
device='cuda:0'))
deviation from pytorch implementation: 9.5367431640625e-07
We can also inspect the program Thunder recorded and admire the beauty of our operator being called:
[8]:
thunder.last_traces(jfn)[-1]
[8]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(logits, labels):
# logits: "cuda:0 f32[2048, 50257]"
# labels: "cuda:0 i64[2048]"
(res, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)
del logits, labels
return res
But it might be more awesome to have Thunder automatically use our new operators if applicable. We can define a transformation to do this for us. This consists of two parts:

- a checker function that takes the arguments of the function we want to replace (but with Tensor arguments replaced by TensorProxy ones) and outputs True if we handle this case and False if not,
- an execution_transform that is just a function with the same parameters and same return value as the function we want to replace and does the compute (as you would expect, by calling our operator).

Note that we attach this implementation to the thunder.torch.cross_entropy
Symbol (an operator as it appears in Thunder traces, just like our apex_xentropy_forward
is a Symbol).
[9]:
def apex_xentropy_checker(
    a: TensorProxy,
    /,
    target: TensorProxy,
    weight: None | TensorProxy = None,
    size_average = None,
    ignore_index: int = -100,
    reduce = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
) -> bool:
    DeviceType = thunder.devices.DeviceType
    if a.device.devicetype != DeviceType.CUDA or target.device.devicetype != DeviceType.CUDA:
        return False

    probability_target: bool = thunder.core.utils.same_shape(a.shape, target.shape)
    if probability_target or label_smoothing > 0.0:
        return False

    torch_dtype: torch.dtype = thunder.torch.to_torch_dtype(a.dtype)
    if torch_dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False

    if ignore_index >= 0:
        return False

    if weight is not None:
        return False

    # NOTE These parameters are deprecated and not supported
    if size_average is not None or reduce is not None:
        return False

    if reduction not in ["sum", "mean", "none"]:
        return False

    # Checks from
    # https://github.com/NVIDIA/apex/blob/7b2e71b0d4013f8e2f9f1c8dd21980ff1d76f1b6/apex/contrib/csrc/xentropy/xentropy_kernel.cu#L587-L590
    if a.ndim != 2:
        return False

    if target.ndim != 1:
        return False

    if a.shape[0] != target.shape[0]:
        return False

    if a.numel == 0:
        return False

    # Xentropy kernel produces incorrect results if a.shape[1] is less
    # than 30 and not a multiple of 4
    if a.shape[1] < 30 and a.shape[1] % 4 != 0:
        return False

    return True

from thunder.core.transforms import get_grad, put_grads

def cross_entropy_to_apex(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    loss, max_log_sum_exp = apex_xentropy_forward(
        a,
        target,
        weight,
        size_average,
        ignore_index,
        reduce,
        reduction,
        label_smoothing,
    )
    return loss

apex_xentropy_ex.register_implementation(thunder.torch.cross_entropy, checker=apex_xentropy_checker,
                                         execution_transform=cross_entropy_to_apex)
We can now run the “unmodified” PyTorch function using F.cross_entropy
and still get our implementation (but don’t forget to pass the executor in the call to the jit!):
[10]:
def loss_fn(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels, reduction="none")
jfn = thunder.jit(loss_fn, executors=[apex_xentropy_ex])
logits = torch.randn([2048, 50257], device="cuda")
labels = torch.randint(0, 50257, [2048], device="cuda")
actual_result = jfn(logits, labels)
expected_result = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
print("deviation from pytorch implementation:", (actual_result - expected_result).abs().max().item())
print(thunder.last_traces(jfn)[-1])
call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t19, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t18, shape=(2048,), dtype=int64, device=cuda:0))
call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 1.2891, -0.2912, 0.6866, ..., -1.5067, 1.3132, -0.7352],
[-1.9077, -0.8366, -0.0747, ..., 1.6109, -0.7460, 0.7346],
[-1.0830, -0.2586, 0.0402, ..., -0.2030, -1.0907, -1.7308],
...,
[ 0.5805, -0.0830, -0.4658, ..., -0.1023, -1.3720, 0.1850],
[-0.8181, 1.3273, 0.8034, ..., 1.2658, -1.4824, 0.0482],
[ 0.9964, -1.8733, 0.3547, ..., 0.0190, -0.3228, 0.4827]],
device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 8137, 23633, 42622, ..., 39128, 39817, 18664], device='cuda:0'), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([12.9479, 11.7810, 9.1981, ..., 10.1080, 10.4095, 10.5884],
device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3268, 11.3203, 11.3294, ..., 11.3251, 11.3333, 11.3225],
device='cuda:0'))
deviation from pytorch implementation: 9.5367431640625e-07
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(logits, labels):
# logits: "cuda:0 f32[2048, 50257]"
# labels: "cuda:0 i64[2048]"
(t17, _) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)
del logits, labels
return t17
So what about the backward?¶
Well, we can define a gradient function and register it along with our implementation.
We thought a lot about what our extension point for gradients should look like - PyTorch’s autograd.Function
is probably the most well-known approach - and we felt that it would be nice to make the connection between tensors in the computation and their gradients explicit.
So the grad transform we implement below is a function that does the following:

- it takes the same arguments as the forward,
- it computes the forward from its arguments,
- it then uses get_grad to obtain the required gradients for the forward outputs,
- it computes the gradients for the inputs (this is the backward),
- finally, it attaches the computed gradients to the respective tensors with put_grads.

We supply the grad function as an additional argument to register_implementation.
[11]:
@log
def apex_cross_entropy_grad(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    loss, max_log_sum_exp = apex_xentropy_forward(
        a,
        target,
        weight,
        size_average,
        ignore_index,
        reduce,
        reduction,
        label_smoothing,
    )

    grad = get_grad(loss)
    grad_logits = apex_xentropy_backward(
        grad,
        a,
        target,
        max_log_sum_exp,
        label_smoothing,
    )
    put_grads((a,), (grad_logits,))

    return loss

apex_xentropy_ex.register_implementation(thunder.torch.cross_entropy, checker=apex_xentropy_checker,
                                         execution_transform=cross_entropy_to_apex, grad_transform=apex_cross_entropy_grad)
With these registrations, we can compile a function, and it will automatically be split into forward and backward and wrapped in a PyTorch autograd.Function that calls the backward trace computed by Thunder.
[12]:
from thunder import torch as ltorch
torch.manual_seed(0)
logits = torch.randn([2048, 50257], device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, [2048], device="cuda")
def loss_fn(logits, labels):
    return torch.nn.functional.cross_entropy(logits, labels, reduction="none", ignore_index=-1)
cfn = thunder.jit(loss_fn, executors=[apex_xentropy_ex])
actual_loss = cfn(logits, labels)
go = torch.randn_like(actual_loss)
actual_grad, = torch.autograd.grad(actual_loss, logits, go)
expected_loss = loss_fn(logits, labels)
expected_grad, = torch.autograd.grad(expected_loss, logits, go)
print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in logits grad:", (actual_grad - expected_grad).abs().max().item())
thunder.last_traces(cfn)[-1]
call apex_cross_entropy_grad(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])
call apex_xentropy_forward_meta(TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), None, None, [IntegerProxy name=ignore_index, value=-1], None, none, [FloatProxy name=label_smoothing, value=0.0])
|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))
call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=a, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=target, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=label_smoothing, value=0.0])
|<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)
|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)
call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -1, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))
call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), [FloatProxy name=f0, value=0.0])
|<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)
call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-9.2466e-01, -4.2534e-01, -2.6438e+00, ..., 4.5115e-01,
2.4087e-01, 1.9543e+00],
[ 7.5610e-03, -4.9079e-01, 3.6572e-01, ..., 2.5072e+00,
9.0470e-01, -1.4305e+00],
[-4.4104e-01, -7.6137e-01, -1.1172e+00, ..., 5.9006e-02,
-1.0212e+00, 3.0210e-02],
...,
[-4.2869e+00, 1.4900e+00, -9.1910e-01, ..., 3.6535e-03,
-6.8372e-01, 7.1824e-01],
[-4.2704e-02, 1.3505e+00, 2.1361e+00, ..., -1.1139e+00,
6.1626e-01, 4.8158e-01],
[-7.3334e-01, 2.0820e+00, 3.7722e-02, ..., -7.2141e-01,
4.6871e-01, 7.0758e-01]], device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 3957, 45831, 13902, ..., 45225, 32145, 12167], device='cuda:0'), None, None, -1, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([12.4000, 10.9672, 12.6648, ..., 11.7144, 11.8293, 10.9396],
device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3186, 11.3176, 11.3300, ..., 11.3257, 11.3189, 11.3202],
device='cuda:0'))
call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([ 0.8882, -0.0650, -1.2035, ..., -0.4344, -0.0588, -2.5740],
device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[-9.2466e-01, -4.2534e-01, -2.6438e+00, ..., 4.5115e-01,
2.4087e-01, 1.9543e+00],
[ 7.5610e-03, -4.9079e-01, 3.6572e-01, ..., 2.5072e+00,
9.0470e-01, -1.4305e+00],
[-4.4104e-01, -7.6137e-01, -1.1172e+00, ..., 5.9006e-02,
-1.0212e+00, 3.0210e-02],
...,
[-4.2869e+00, 1.4900e+00, -9.1910e-01, ..., 3.6535e-03,
-6.8372e-01, 7.1824e-01],
[-4.2704e-02, 1.3505e+00, 2.1361e+00, ..., -1.1139e+00,
6.1626e-01, 4.8158e-01],
[-7.3334e-01, 2.0820e+00, 3.7722e-02, ..., -7.2141e-01,
4.6871e-01, 7.0758e-01]], device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([ 3957, 45831, 13902, ..., 45225, 32145, 12167], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3186, 11.3176, 11.3300, ..., 11.3257, 11.3189, 11.3202],
device='cuda:0'), 0.0)
|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 4.2787e-06, 7.0495e-06, 7.6679e-07, ..., 1.6936e-05,
1.3724e-05, 7.6143e-05],
[-7.9652e-07, -4.8391e-07, -1.1396e-06, ..., -9.7005e-06,
-1.9535e-06, -1.8908e-07],
[-9.2972e-06, -6.7489e-06, -4.7280e-06, ..., -1.5329e-05,
-5.2049e-06, -1.4894e-05],
...,
[-7.2011e-08, -2.3243e-05, -2.0894e-06, ..., -5.2573e-06,
-2.6439e-06, -1.0743e-05],
[-6.8437e-07, -2.7565e-06, -6.0470e-06, ..., -2.3447e-07,
-1.3227e-06, -1.1561e-06],
[-1.4990e-05, -2.5033e-04, -3.2410e-05, ..., -1.5170e-05,
-4.9872e-05, -6.3328e-05]], device='cuda:0')
Max error in loss: 9.5367431640625e-07
Max error in logits grad: 2.384185791015625e-07
[12]:
[# Constructed by Backward pass
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
C1, \
= saved_for_backward
t2, \
= cotangents
logits, \
labels, \
t0, \
= C0
f0, \
= C1
t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: "cuda:0 f32[2048, 50257]"
return (t3, None),
# Constructed by Transform for execution (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
C1, \
= saved_for_backward
t2, \
= cotangents
logits, \
labels, \
t0, \
= C0
f0, \
= C1
t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: "cuda:0 f32[2048, 50257]"
return (t3, None),
# Constructed by Update Call Context (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
C1, \
= saved_for_backward
t2, \
= cotangents
labels, \
logits, \
t0, \
= C0
f0, \
= C1
t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: "cuda:0 f32[2048, 50257]"
return (t3, None),
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, \
C1, \
= saved_for_backward
clear_collection(saved_for_backward)
del saved_for_backward
t2, \
= cotangents
clear_collection(cotangents)
del cotangents
labels, \
logits, \
t0, \
= C0
clear_collection(C0)
del C0
f0, \
= C1
clear_collection(C1)
del C1
t3 = apex_xentropy_backward(t2, logits, labels, t0, f0) # t3: "cuda:0 f32[2048, 50257]"
del t2, logits, labels, t0, f0
return (t3, None)]
Alternatively, we can also use the grad
transform to get the gradient:
[13]:
logits = torch.randn([2048, 50257], device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, [2048], device="cuda")
grad_jfn = thunder.core.transforms.grad(jfn)
actual_grad, = grad_jfn(logits, labels)
expected_grad, = torch.autograd.grad(loss_fn(logits, labels).sum(), logits)
print("Difference:", (actual_grad - expected_grad).abs().max().item())
print(thunder.last_traces(grad_jfn)[-1])
call apex_cross_entropy_grad(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
call apex_xentropy_forward_meta(TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_meta = (TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0))
call apex_xentropy_backward_meta(TensorProxy(name=t2, shape=(2048,), dtype=float32, device=cuda:0), TensorProxy(name=logits, shape=(2048, 50257), dtype=float32, device=cuda:0), TensorProxy(name=labels, shape=(2048,), dtype=int64, device=cuda:0), TensorProxy(name=t0, shape=(2048,), dtype=int64, device=cuda:0), 0.0)
|<- apex_xentropy_backward_meta = TensorProxy(name=t3, shape=(2048, 50257), dtype=float32, device=cuda:0)
|<- apex_cross_entropy_grad = TensorProxy(name=t1, shape=(2048,), dtype=float32, device=cuda:0)
call apex_xentropy_forward_impl(Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],
[ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],
[ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],
...,
[-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],
[-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],
[ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],
device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), None, None, -100, None, none, 0.0)
|<- apex_xentropy_forward_impl = (Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([10.0233, 11.9095, 11.2898, ..., 10.9289, 10.7487, 10.7455],
device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],
device='cuda:0'))
call apex_xentropy_backward_impl(Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0'), Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[ 0.5390, 0.1760, -1.0790, ..., 0.1695, -0.8082, -0.6984],
[ 2.1555, 1.3938, 0.3928, ..., 0.8937, -0.4949, 1.1610],
[ 0.6784, 1.1188, 0.7508, ..., -0.0941, 0.8380, 0.1878],
...,
[-1.5834, -0.1573, -1.3511, ..., 0.6167, -0.1083, 0.4116],
[-0.5476, 0.5831, 0.0791, ..., -0.4986, -0.5270, 0.0954],
[ 0.2825, -1.0378, -0.5506, ..., 0.0149, 1.3521, -1.0823]],
device='cuda:0', requires_grad=True), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.int64, device=cuda:0) with values tensor([44917, 35770, 41569, ..., 9798, 33992, 36123], device='cuda:0'), Tensor(shape=torch.Size([2048]), stride=(1,), dtype=torch.float32, device=cuda:0) with values tensor([11.3241, 11.3207, 11.3283, ..., 11.3224, 11.3186, 11.3205],
device='cuda:0'), 0.0)
|<- apex_xentropy_backward_impl = Tensor(shape=torch.Size([2048, 50257]), stride=(50257, 1), dtype=torch.float32, device=cuda:0) with values tensor([[2.0706e-05, 1.4403e-05, 4.1058e-06, ..., 1.4309e-05, 5.3827e-06,
6.0079e-06],
[1.0461e-04, 4.8840e-05, 1.7949e-05, ..., 2.9621e-05, 7.3879e-06,
3.8697e-05],
[2.3705e-05, 3.6822e-05, 2.5485e-05, ..., 1.0948e-05, 2.7806e-05,
1.4513e-05],
...,
[2.4836e-06, 1.0338e-05, 3.1331e-06, ..., 2.2417e-05, 1.0857e-05,
1.8259e-05],
[7.0235e-06, 2.1758e-05, 1.3145e-05, ..., 7.3762e-06, 7.1699e-06,
1.3360e-05],
[1.6078e-05, 4.2941e-06, 6.9897e-06, ..., 1.2304e-05, 4.6857e-05,
4.1070e-06]], device='cuda:0')
Difference: 1.3969838619232178e-09
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(logits, labels):
# logits: "cuda:0 f32[2048, 50257]"
# labels: "cuda:0 i64[2048]"
(_, t0) = apex_xentropy_forward(logits, labels, None, None, -100, None, 'none', 0.0)
t4 = torch.full((2048,), 1.0, device=torch.device("cuda:0"), dtype=torch.float32) # t4: "cuda:0 f32[2048]"
# t4 = ltorch.full((2048,), 1.0, device=torch.device("cuda:0"), dtype=torch.float32) # t4: "cuda:0 f32[2048]"
# t4 = prims.full((2048,), 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t4: "cuda:0 f32[2048]"
t3 = apex_xentropy_backward(t4, logits, labels, t0, 0.0) # t3: "cuda:0 f32[2048, 50257]"
del t4, logits, labels, t0
return [t3]
So let’s wrap up what we did here:

- We defined a custom executor with custom operations (Symbols in Thunder parlance), each consisting of a meta (metadata propagation) function and an implementation.
- We defined and registered rules that map existing operations to our new operations. This allows us to apply optimizations to our model without changing the model’s code!
- We defined a gradient rule and saw how both the automatic PyTorch autograd integration and the explicit grad transform use it.

The whole pattern, condensed into one block, follows below.
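Here it is, reusing the names defined in the cells above (a recap only, not meant to be re-executed in this notebook, since the executor is already registered):
import thunder
from thunder.extend import OperatorExecutor, register_executor

# 1. Create and register an executor.
apex_xentropy_ex = OperatorExecutor("apex_xentropy_ex", version="0.1")
register_executor(apex_xentropy_ex)

# 2. Register custom operators, each a (meta, implementation) pair.
apex_xentropy_forward = apex_xentropy_ex.register_operator(
    "apex_xentropy_forward", meta=apex_xentropy_forward_meta, fn=apex_xentropy_forward_impl)
apex_xentropy_backward = apex_xentropy_ex.register_operator(
    "apex_xentropy_backward", meta=apex_xentropy_backward_meta, fn=apex_xentropy_backward_impl)

# 3. Map an existing Symbol to them: checker + execution_transform + optional grad_transform.
apex_xentropy_ex.register_implementation(
    thunder.torch.cross_entropy,
    checker=apex_xentropy_checker,
    execution_transform=cross_entropy_to_apex,
    grad_transform=apex_cross_entropy_grad,
)

# 4. Compile with the executor enabled.
jfn = thunder.jit(loss_fn, executors=[apex_xentropy_ex])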
Now go and implement your favourite optimized operators. We would love to hear about your use-cases!