Contributing to Thunder

We appreciate your feedback and contributions. If you have feature requests, questions, or want to contribute code or config files, please don’t hesitate to use the GitHub Issue tracker.

We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment.

For a simple, general overview of Thunder, we recommend reading Inside Thunder first.

Adding operators

Adding operators might be one of the easiest and most fun ways to get involved in contributing to Thunder. The operator GitHub Issue tracker provides a great starting point for deciding which operation to work on first.

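The subsections below first introduce primitives, the core language, the torch language, and language contexts, and then walk through adding operations to the Torch executor. We recommend reading the document sequentially!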

Primitives

The lowest level is the primitive operations, defined in thunder/core/prims.py. Primitive operations, as seen in the Representing operations section, describe all the computation performed, and they are intended to be as simple as possible so that executors like nvFuser find it easy to manipulate them. Thunder’s primitives are similar to PyTorch’s primTorch primitives, and are based on JAX’s jax.lax operations.

Primitives have several parts, as defined by the make_prim function. Most importantly they have an id, a name, and a meta function. The meta function performs error checking and maps the metadata (like dtype, shape, device) of the primitive’s inputs to the metadata of its outputs. Operations that are part of a class, like the elementwise unary or reduction operations, often share a common meta function. More unique operations, like slice, define their own meta functions.
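To make the role of a meta function concrete, here is a minimal sketch in plain Python. It is not Thunder code: `TensorMeta` and `elementwise_binary_meta` are hypothetical names invented for this illustration, and the real meta functions operate on TensorProxy objects. The sketch only shows the idea that a meta function validates inputs and computes output metadata without touching any data.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for a tensor proxy: metadata only, no data."""

    shape: tuple
    dtype: str
    device: str


def elementwise_binary_meta(a: TensorMeta, b: TensorMeta) -> TensorMeta:
    """Hypothetical meta function shared by a class of binary operations."""
    # Error checking: in this sketch the primitive does no broadcasting or
    # type promotion, so the inputs must already agree.
    if a.shape != b.shape:
        raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
    if a.dtype != b.dtype:
        raise TypeError(f"Dtype mismatch: {a.dtype} vs {b.dtype}")
    if a.device != b.device:
        raise ValueError(f"Device mismatch: {a.device} vs {b.device}")
    # Metadata mapping: the output looks like the inputs.
    return TensorMeta(shape=a.shape, dtype=a.dtype, device=a.device)


x = TensorMeta((2, 3), "float32", "cpu")
y = TensorMeta((2, 3), "float32", "cpu")
out = elementwise_binary_meta(x, y)
```

Because the function never looks at tensor values, the same meta function can serve every operation in a class whose outputs have the same metadata rules.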

The actual execution of primitive operations is handled by executors like nvFuser or PyTorch – more on that in a moment.

Before adding a primitive, check with the team on its design. It might be appropriate to add primitives when necessary to describe the semantics of an operation or to improve the numerical accuracy or speed of operations.

There is a tradeoff in the design of primitive operations to keep in mind. On one hand, fewer primitive operations can make program transformation and execution easier. Fewer primitives means fewer transformation rules (since transformation rules are defined on primitives) and a smaller interface with executors. On the other hand, too few primitive operations may make it hard, or impossible, to express all the operations that users are interested in. Too few primitive operations may also make it difficult to execute programs quickly and numerically accurately.

For example, the expm1 operation can mathematically be defined in terms of the exp and subtraction operations, and so it does not need to be a primitive to enable any functionality. Many libraries, including C++’s standard library, still define an expm1 operation for numerical accuracy, and so expm1 is a primitive in Thunder.
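The accuracy argument can be checked with Python's standard `math` module, which is used here only as a convenient stand-in (it is not how Thunder executes the primitive). For a small input, composing exp and subtraction loses most of the significant digits, while a dedicated expm1 does not.

```python
import math

x = 1e-10

# Composition of exp and subtraction: exp(x) rounds to a value very close to
# 1.0, so the subtraction cancels most of the significant digits.
naive = math.exp(x) - 1.0

# Dedicated operation: computed without the intermediate rounding near 1.0.
accurate = math.expm1(x)

# For tiny x, expm1(x) is approximately x, so x serves as the reference here.
naive_rel_error = abs(naive - x) / x
accurate_rel_error = abs(accurate - x) / x

print(f"naive relative error:    {naive_rel_error:.3e}")
print(f"accurate relative error: {accurate_rel_error:.3e}")
```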

The Core Language

Above the primitives is the core language, or clang. Clang operations are mostly written like any other Python operation. They ultimately call the primitive operations, although they may call other operations before doing so (for example, clang.foo might call clang.bar which calls prims.bar).

Core language operations are intended to be common functionality that’s useful when defining user-facing languages like torch or numpy. Many of these operations are just wrappers around primitive operations. For example, the elementwise binary primitives are as simple as possible, so they don’t perform broadcasting or type promotion. The core language elementwise binary operations, however, do perform broadcasting and type promotion. For example, take a look at the following implementation of add from clang

def _elementwise_binary_wrapper(a, b, *, prim, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT):
    computation_dtype, result_dtype = utils.elementwise_type_promotion(a, b, type_promotion_kind=type_promotion_kind)

    a, b = maybe_broadcast(a, b)
    a, b = maybe_convert_to_dtype(a, computation_dtype), maybe_convert_to_dtype(b, computation_dtype)

    result = prim(a, b)
    result = maybe_convert_to_dtype(result, result_dtype)

    return result


@clangop(method_name="add")
def add(a, b):
    return _elementwise_binary_wrapper(a, b, prim=prims.add)
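The two preparatory steps in the wrapper above, broadcasting and type promotion, can be sketched on plain metadata. The helpers below are hypothetical and written only for this illustration: `broadcast_shapes` follows the usual NumPy-style trailing-dimension rule, and `promote_dtypes` uses a deliberately simplified ranking that should not be read as Thunder's actual promotion rules.

```python
from itertools import zip_longest

# Simplified, hypothetical promotion order used only in this sketch.
_DTYPE_RANK = {"bool": 0, "int64": 1, "float32": 2, "float64": 3}


def broadcast_shapes(a_shape, b_shape):
    """Return the broadcast of two shapes, or raise ValueError."""
    result = []
    # Align the shapes from the trailing dimension; missing dims count as 1.
    for x, y in zip_longest(reversed(a_shape), reversed(b_shape), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"Cannot broadcast {a_shape} with {b_shape}")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def promote_dtypes(a_dtype, b_dtype):
    """Pick the higher-ranked of two dtypes according to _DTYPE_RANK."""
    return max(a_dtype, b_dtype, key=_DTYPE_RANK.__getitem__)


shape = broadcast_shapes((3, 1), (4,))
dtype = promote_dtypes("int64", "float32")
```

Keeping this logic in the core language means that each primitive only ever sees inputs that already agree in shape and dtype.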

Before adding a core language operation consider if the functionality expressed is universal enough.

As a style note, operations in Thunder should defer as much error checking as possible. For example, if a primitive’s meta function will perform an error check for X, then the core language operation that calls it should generally not also check for X.

The Torch Language

To translate torch operations into something that Thunder understands, we define a torch language. Operations in the torch language should reflect the behavior of their corresponding torch operations (small deviations are sometimes OK).

When a program is interpreted, torch operations are remapped into these operations, which ultimately call primitive operations.

Language Context

In the core and torch languages, functions are decorated to set a language context and, for torch operations, to describe how to map operations like torch.foo into thunder.torch.foo.

The language context determines what properties and methods tensor objects have. For example, when a + b is written and the first argument is an array or tensor (so, TensorProxy.__add__ is invoked), the language context decides what that addition means. Or when a.size is used, the language context determines what that means (and it’s different in PyTorch and NumPy).

Adding operations to the Torch executor

Now that we are familiar with the hierarchy of operations and the underlying language contexts, let’s see some examples of adding operations.

For simplicity, we only cover adding operations to the torch executor. The sections below are meant to be read sequentially.

Adding a primitive

A good example of adding a primitive operation to the torch executor is PR #136, which adds support for torch.Tensor.unfold.

Let’s outline some of its key parts.

Consider the following update to thunder/core/prims.py

SLICE = auto()
SQUEEZE = auto()
TRANSPOSE = auto()
UNFOLD = auto()
VIEW = auto()
# Memory layout prims (Experimental)
STRIDE_ORDER = auto()


def unfold_meta(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
    dim = utils.canonicalize_dim(a.ndim, dim)
    max_size = 1 if a.ndim == 0 else a.shape[dim]

    utils.check(
        size <= max_size, lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}"
    )
    utils.check(size >= 0, lambda: f"Size is {size} but must be >= 0")
    utils.check(step > 0, lambda: f"Step is {step} but must be > 0")

    shape = list(a.shape)
    shape.append(size)
    shape[dim] = (shape[dim] - size) // step + 1

    return TensorProxy(like=a, shape=shape)


unfold = make_prim(PrimIDs.UNFOLD, "unfold", meta=unfold_meta, tags=(OpTags.SHAPE_OP,))

The above registers a primitive symbol unfold using make_prim with id=PrimIDs.UNFOLD, name=unfold, and meta=unfold_meta. One can see that unfold_meta follows the signature of the underlying torch.Tensor.unfold operation (so that the primitive is directly modeled after the PyTorch operation) with the only exception of expecting a TensorProxy and not the torch.Tensor as its input. The rest of the function checks the inputs and returns a TensorProxy of the appropriate shape. like=a means that the output will inherit the meta-data like device and dtype from a. The primitive is also tagged with tags=(OpTags.SHAPE_OP,), and, therefore, is associated with shape-based operations. We use tags to additionally group operations for group-specific operation optimizations inside Thunder.
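The shape arithmetic in the meta function can be tried out on plain tuples. The function below is a hypothetical, standalone re-statement of that arithmetic for inputs with at least one dimension; it is a sketch for experimentation, not part of Thunder.

```python
def unfold_output_shape(shape, dim, size, step):
    """Output shape of unfold(dim, size, step) for an input with ndim >= 1."""
    ndim = len(shape)
    if ndim == 0:
        raise ValueError("This sketch only handles inputs with ndim >= 1")
    if not -ndim <= dim < ndim:
        raise IndexError(f"Dimension out of range: {dim}")
    dim %= ndim  # canonicalize negative dimensions

    if size > shape[dim]:
        raise RuntimeError(
            f"Maximum size for tensor at dimension {dim} is {shape[dim]} but size is {size}"
        )
    if size < 0:
        raise RuntimeError(f"Size is {size} but must be >= 0")
    if step <= 0:
        raise RuntimeError(f"Step is {step} but must be > 0")

    out = list(shape)
    out.append(size)  # each window becomes a new trailing dimension
    out[dim] = (out[dim] - size) // step + 1  # number of windows along dim
    return tuple(out)


windows_2d = unfold_output_shape((6, 2), 0, 2, 2)
windows_1d = unfold_output_shape((8,), 0, 2, 1)
```

For example, an input of shape (6, 2) unfolded along dimension 0 with size 2 and step 2 yields three windows of length 2, hence the shape (3, 2, 2).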

Once the symbol is created, we need to tell Thunder how to execute it. Since we are updating the torch executor, the following lines are added to the executors/torchex.py file

unbind = _register_torch_operation("unbind")
unfold = _register_torch_operation("unfold", module=torch.Tensor)
unsqueeze = _register_torch_operation("unsqueeze")

_register_implementation(prims.transpose, checker=_always_executable, execution_transform=_transpose_prim_transform)
_register_implementation(prims.unfold, unfold, checker=_always_executable)
_register_implementation(prims.view, view, checker=_always_executable)

The first registration (through _register_torch_operation) creates a new symbol that is directly tied to torch.Tensor.unfold, and the second (through _register_implementation) associates this symbol with prims.unfold upon execution unless the checker fails. Having checker=_always_executable always greenlights this association, and, hence, whenever the torch executor tries to execute prims.unfold, it executes torch.Tensor.unfold. Note, however, that although the checker does have access to the symbol’s inputs, it is different from the meta function. Meta functions are supposed to only validate inputs and to be executor-agnostic. Checkers, on the other hand, are not meant to check the validity of inputs, and they are specific to executors. As such, they are useful for checking and enabling symbols for specific versions of executors like PyTorch, for example.
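The division of labor between an implementation and its checker can be sketched with a tiny registry. Everything below is hypothetical (including the names and the pretend backend version) and is much simpler than Thunder's executor machinery; it only shows a checker gating whether an implementation is used.

```python
# Hypothetical mini-registry: maps a primitive name to (checker, impl) pairs.
_implementations = {}

# Hypothetical backend version, standing in for e.g. an installed library version.
BACKEND_VERSION = (2, 1)


def register_implementation(prim_name, impl, *, checker):
    _implementations.setdefault(prim_name, []).append((checker, impl))


def execute(prim_name, *args):
    """Run the first registered implementation whose checker accepts the call."""
    for checker, impl in _implementations.get(prim_name, []):
        if checker(*args):
            return impl(*args)
    raise NotImplementedError(f"No executable implementation for {prim_name}")


def always_executable(*args):
    return True


def needs_newer_backend(*args):
    # Executor-specific decision: it says nothing about input validity.
    return BACKEND_VERSION >= (3, 0)


register_implementation("neg", lambda x: ("fast", -x), checker=needs_newer_backend)
register_implementation("neg", lambda x: ("fallback", -x), checker=always_executable)

result = execute("neg", 2)
```

With the pretend version (2, 1), the first checker declines and the fallback implementation runs.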

The mapping of the prims.unfold symbol to torch.Tensor.unfold is very simple since the inputs to prims.unfold can be passed directly to torch.Tensor.unfold without any additional preprocessing (the association between TensorProxy and torch.Tensor is handled automatically by the torch executor). This is not the case for every operation, however, and sometimes the symbol’s interface has to undergo a transformation to be compatible with the registered implementation provided by the executor. For example, the following lines from executors/torchex.py

def _full_transform(
    shape: Sequence[int], fill_value: Number, *, device: None | devices.Device, dtype: None | dtypes.dtype
) -> TensorProxy:
    torch_device: None | torch.device = to_torch_device(device)
    torch_dtype: None | torch.dtype = to_torch_dtype(dtype)

    return full(shape, fill_value, device=torch_device, dtype=torch_dtype)

_register_implementation(prims.full, checker=_always_executable, execution_transform=_full_transform)

show how to accomplish that with the execution_transform argument of _register_implementation: Thunder metadata such as device and dtype is converted to the corresponding PyTorch metadata.
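The same pattern can be shown without PyTorch. In the sketch below, Python's standard `array` module plays the role of the executor backend, and framework-level dtype names are translated to backend typecodes before the backend function is called. The names `backend_full`, `full_transform`, and the dtype table are assumptions made for this illustration.

```python
from array import array

# Hypothetical mapping from framework-level dtype names to backend typecodes.
_TO_TYPECODE = {"float32": "f", "float64": "d", "int32": "i"}


def backend_full(length, fill_value, *, typecode):
    """Backend function: only understands the backend's own typecodes."""
    return array(typecode, [fill_value] * length)


def full_transform(length, fill_value, *, dtype):
    """Execution transform: adapt the symbol's interface to the backend's."""
    typecode = _TO_TYPECODE[dtype]  # convert metadata before dispatching
    return backend_full(length, fill_value, typecode=typecode)


filled = full_transform(3, 1.5, dtype="float64")
```

The transform keeps the symbol's interface independent of any one backend while still letting the backend receive arguments in its native form.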

Testing the Operation

In the previous section we saw an example of adding a primitive operation. However, it is not guaranteed that the operation performs as expected. We need to test it!

Operators are typically tested by adding an OpInfo for them. See here to better understand how OpInfos work. An OpInfo contains metadata describing an operator, a sample input generator, a generator of erroneous inputs that is used to test exception handling and meta function correctness, and test directives. It’s used to automatically generate a variety of tests, most importantly tests that verify the operator’s behavior is consistent with its reference implementations.

It is important to determine whether you need to add test_directives in order to skip tests or expect failures of tests.

  • Skip (pytest.mark.skip): Skips are needed when something is not implemented by an executor or for a device.

  • Expected Failures (pytest.mark.xfail): Expected failures indicate that an executor has implemented some aspect of an operation but its behavior is incorrect.

The following OpInfo for prims.unfold was added to thunder/tests/opinfos.py in PR #136

def unfold_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    cases = (
        ((), 0, 1, 3),
        ((), -1, 0, 5),
        ((0,), 0, 0, 1),
        ((8,), 0, 2, 1),
        ((6, 2), 0, 2, 2),
    )

    for shape, dim, size, step in cases:
        yield SampleInput(make(shape), dim, size, step)


def unfold_error_generator(op, device, dtype=torch.float32, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype)

    cases = (
        ((), 0, 2, 1, RuntimeError, "Maximum size for tensor at dimension 0 is 1 but size is 2"),
        ((0,), 0, 0, -1, RuntimeError, "Step is -1 but must be > 0"),
        ((8,), 1, 2, 1, IndexError, r"Dimension out of range \(expected to be in range of \[-1, 0\], but got 1\)"),
        ((8,), 0, -5, 1, RuntimeError, "Size is -5 but must be >= 0"),
        ((8,), 0, 10, 1, RuntimeError, "Maximum size for tensor at dimension 0 is 8 but size is 10"),
    )

    for shape, dim, size, step, err_type, err_msg in cases:
        yield SampleInput(make(shape), dim, size, step), err_type, err_msg


unfold_opinfo = OpInfo(
    clang.unfold,
    sample_input_generator=unfold_sample_generator,
    error_input_generator=unfold_error_generator,
    torch_reference=torch.Tensor.unfold,
)

shape_ops.append(unfold_opinfo)

Note how comprehensive unfold_sample_generator and unfold_error_generator are. unfold_sample_generator does not shy away from testing scalar inputs (shape=()) and empty inputs (shape=(0,), i.e. shapes containing zeros). And unfold_error_generator tests about every aspect of the underlying meta-function.
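To show how such error cases are typically consumed, here is a minimal, self-contained sketch. It does not use Thunder's test harness: `check_unfold_args` and `run_error_cases` are hypothetical helpers that mirror the idea of pairing each erroneous input with an expected exception type and a message pattern.

```python
import re


def check_unfold_args(dim_length, size, step):
    """Hypothetical validation for a single dimension of a given length."""
    if size > dim_length:
        raise RuntimeError(
            f"Maximum size for tensor at dimension 0 is {dim_length} but size is {size}"
        )
    if size < 0:
        raise RuntimeError(f"Size is {size} but must be >= 0")
    if step <= 0:
        raise RuntimeError(f"Step is {step} but must be > 0")


# Each case: (arguments, expected exception type, regex the message must match).
error_cases = (
    ((8, 10, 1), RuntimeError, r"Maximum size .* is 8 but size is 10"),
    ((8, -5, 1), RuntimeError, r"Size is -5 but must be >= 0"),
    ((0, 0, -1), RuntimeError, r"Step is -1 but must be > 0"),
)


def run_error_cases(fn, cases):
    """Return the number of cases that raised the expected error."""
    passed = 0
    for args, err_type, err_pattern in cases:
        try:
            fn(*args)
        except err_type as exc:
            assert re.search(err_pattern, str(exc)), f"Unexpected message: {exc}"
            passed += 1
        else:
            raise AssertionError(f"{args} did not raise {err_type.__name__}")
    return passed


num_passed = run_error_cases(check_unfold_args, error_cases)
```

Matching on the message as well as the exception type catches the situation where the right error is raised for the wrong reason.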

To run the tests for a particular operator, use pytest’s -k option. This will run tests for Thunder’s different executors, supported dtypes, and supported device types. For example, to run the tests for unfold the command would be

$ pytest thunder/tests/test_ops.py -k unfold

Another example of an OpInfo with specified test_directives

acos_opinfo = OpInfo(
    ltorch.acos,
    domain=(-1, 1),
    sample_input_generator=elementwise_unary_generator,
    torch_reference=_elementwise_unary_torch(torch.acos),
    test_directives=(
        # Torch doesn't support CPU float16 or complex32 acos
        DecorateInfo(
            pytest.mark.xfail,
            "test_core_vs_torch_consistency",
            dtypes=(datatypes.float16, datatypes.complex32),
            devicetypes=(devices.DeviceType.CPU,),
        ),
    ),
)
elementwise_unary_ops.append(acos_opinfo)

We strive for Thunder to be of the highest quality possible, so it is always a good idea to be very thorough when it comes to testing.

Adding grad support

Operations are not differentiable by default unless they are implemented as compositions of differentiable operations (this relates to updating the torch language; more on that later). When an operation is a composition of other operations, we say that this operation is decomposable. Primitive operations, by definition, are not decomposable and, as such, require an explicit backward/grad/VJP rule (for simplicity, we use these terms interchangeably). These rules, or grad transforms, are implemented in thunder/core/transforms.py. Note, however, that these rules are not restricted to primitive operations; they can also be implemented for decomposable operations for performance reasons (see Defining custom forward and backward for existing operators, for example).

For now, for simplicity, let’s assume that a new primitive is being added and we would like to make it differentiable. Consider PR #118, which adds backward support for the primitive operation prims.topk (added in PR #88) that is modeled after torch.topk. The lines added to thunder/core/transforms.py

@torchctx
def _topk_prim_grad(
    a: TensorProxy, /, k: int, dim: None | int = None, largest: bool = True, sorted: bool = True, *, out=None
):
    fwd = prims.topk(a, k, dim, largest, sorted, out=out)
    val, idx = fwd

    val_grad = get_grad(val)

    a_grad = ltorch.zeros_like(a)
    # TODO: replace with scatter once we have it.
    # scatter_add is a prim and it relies on atomic ops.
    a_grad = ltorch.scatter_add(a_grad, dim, idx, val_grad)
    put_grad(a, a_grad)

    return fwd


register_grad(pids.TOPK, _topk_prim_grad)

define a grad transform for prims.topk. This operation returns a 2-tuple in forward fwd = (val, idx) with only the first element being differentiable. Note that Thunder interleaves forward and backward computations in grad transforms. Take a look at the lines val_grad = get_grad(val), which extracts the in-flowing backward gradient for val, and put_grad(a, a_grad) which sets the backward gradient for the input a.
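The math behind this rule can be illustrated on a plain Python list. The helpers below are hypothetical one-dimensional stand-ins written for this example: the gradient of the input is zero everywhere except at the selected indices, where the incoming gradients of the values are accumulated.

```python
def topk_1d(values, k):
    """Return the k largest values and their indices, largest first."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    idx = order[:k]
    return [values[i] for i in idx], idx


def topk_backward_1d(values, k, val_grad):
    """Gradient of the input given the gradient flowing into the top-k values."""
    _, idx = topk_1d(values, k)
    a_grad = [0.0] * len(values)  # plays the role of zeros_like(a)
    for i, g in zip(idx, val_grad):  # plays the role of scatter_add
        a_grad[i] += g
    return a_grad


a = [1.0, 5.0, 3.0, 4.0]
val, idx = topk_1d(a, 2)
a_grad = topk_backward_1d(a, 2, [10.0, 20.0])
```

Here the two largest entries sit at indices 1 and 3, so only those positions of the input gradient are non-zero.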

Do you see that comment about the missing scatter? You could be the one who implements it! :)

Updating the Torch Language

The Torch Language operations are the “highest”-level operations and, as such, are decomposable. If the missing operation can be decomposed into already existing operations, then thunder/torch/__init__.py is where its implementation is to be placed.

For example, consider PR #100, which adds support for the Hardswish activation function. The function is implemented in thunder/torch/__init__.py

@torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False)
def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike:
    utils.check(not inplace, lambda: f"hardswish only supports inplace=False", exception_type=NotImplementedError)
    utils.check(
        dtypes.is_float_dtype(a.dtype),
        lambda: f"hardswish only supports floating point dtypes, got {a.dtype}",
        exception_type=ValueError,
    )
    return a * relu6(a + 3) / 6

Note the checks (Thunder does not support in-place operations yet) and that hardswish is a composition of the relu6 operation (defined in the torch language) and the language context-specific binary operations over the objects that TensorProxy represents. All these basic operations are differentiable (for the Torch/NVFuser executors), and so hardswish is implicitly differentiable as well (for the Torch/NVFuser executors).
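The decomposition itself is easy to verify on scalars. The following plain-Python sketch (not Thunder code) applies the same formula, a * relu6(a + 3) / 6, to individual floats.

```python
def relu6(x: float) -> float:
    """Clamp x to the interval [0, 6]."""
    return min(max(x, 0.0), 6.0)


def hardswish(x: float) -> float:
    """Scalar hardswish written as a composition of simpler operations."""
    return x * relu6(x + 3.0) / 6.0


samples = [-4.0, 0.0, 1.0, 3.0]
outputs = [hardswish(x) for x in samples]
```

Below -3 the output is zero, at and above 3 the function is the identity, and in between it interpolates smoothly.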

Afterword

We hope that you find the information provided here useful, and we look forward to your contributions!

We also recommend checking out Defining new Thunder operations and Defining custom forward and backward for existing operators, which cover very similar topics related to extending Thunder out of tree.