Defining new Thunder operators

We are going to add a new operator to Thunder, together with a corresponding executor. The operator will be called `sincos` and will compute the sine and cosine of a given input.

Thunder has three sets of core operators: thunder.torch, thunder.clang, and thunder.prims. The operators in thunder.prims are primitives, that is, operators that are not implemented in terms of other operators. The other two sets are built on top of these primitives.
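
As a quick illustration of this layering, tracing a call to a `thunder.torch` operator shows how it decomposes into primitives (a minimal sketch; `thunder.trace` is used the same way later in this notebook):

import thunder
import torch

def f(x, y):
    return thunder.torch.add(x, y)

# The printed trace shows ltorch.add together with its prims.add decomposition
trc = thunder.trace()(f, torch.randn(2), torch.randn(2))
print(trc)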

[1]:
import thunder
import torch

from thunder.core.proxies import TensorProxy
from enum import Enum

Let us define some helper functions (execute the cell below) for printing what’s going on.

[2]:
import functools

_indentation = 0
def _log(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)

def _log_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _log(msg)
    _indentation = 2 + _indentation

def _log_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 2
    _log(msg)

def log(func):
    """A decorator for functions to log arguments and results."""
    name = func.__name__
    def pp(v):
        """Print certain values more succinctly"""
        if isinstance(v, tuple):
            return "({})".format(pp_values(v))
        elif isinstance(v, thunder.core.proxies.TensorProxy):
            return f"TensorProxy(name={v.name}, shape={v.shape}, dtype={v.dtype}, device={v.device})"
        elif isinstance(v, torch.Tensor):
            return f"Tensor(shape={v.shape}, stride={v.stride()}, dtype={v.dtype}, device={v.device}) with values {v}"
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
        _log_indent("call {}({})".format(name, pp_values(args)))
        res = func(*args)
        _log_unindent("|<- {} = {}\n".format(name, pp(res)))
        return res

    return func_wrapper

Our new operator has the signature `sincos(x: Tensor) -> Tuple[Tensor, Tensor]`. It takes a tensor as input and returns a tuple of two tensors: the sine of the input and the cosine of the input.
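
For reference, the eager PyTorch equivalent of what `sincos` should compute looks like this (a plain function, not yet a Thunder operator):

def sincos_reference(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Plain eager semantics that our Thunder operator will mirror
    return torch.sin(x), torch.cos(x)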

We call every callable that should be recorded in the trace a Symbol. Symbols are the building blocks of the trace, and they are either primitives or composite operators: composite operators are implemented in terms of other operators and primitives, while primitives are not implemented in terms of anything else.

The easiest way to register a new operator is to define a meta function, which describes what the metadata of the outputs looks like given the metadata of the inputs, and an implementation, which deals with concrete objects like Python numbers and PyTorch tensors, and to register both through an executor. This will automatically create a Symbol for us.

So we create an executor:

[3]:
sincos_executor = thunder.extend.OperatorExecutor("sincos_executor", version='0.1')
thunder.add_default_executor(sincos_executor)
[3]:
[sincos_executor, sdpa]

We define the meta function and the implementation:

[4]:
@log
def sincos_meta(inp):
    # Describe the outputs' metadata given the input's metadata:
    # two tensors with the same shape, dtype, and device as the input.
    return (TensorProxy(like=inp), TensorProxy(like=inp))

@log
def sincos_impl(inp):
    # The concrete implementation, operating on real torch.Tensors.
    return torch.sin(inp), torch.cos(inp)

And register them as `sincos`:

[5]:
sincos = sincos_executor.register_operator('sincos', meta=sincos_meta, fn=sincos_impl)
sincos
[5]:
[Symbol name=sincos]

That’s it! We have implemented our new primitive. Let’s test it.

[6]:
def fun(a, b):
    sin, cos = sincos(a)
    return sin + cos + b
[7]:
a = torch.randn(1)
b = torch.randn(1)

fun is now a Thunder function, meaning it can only accept Thunder's TensorProxy objects as inputs. Let's see what happens when we call it with regular PyTorch tensors:

[8]:
try:
    fun(a, b)
except Exception as e:
    print(e)
Attempting to execute outside of a tracing context, which is not supported

In the future we will add support for torch.Tensor and numpy.ndarray inputs to enable an eager mode for Thunder functions, but for now this function works only in tracing mode.

[9]:
# Let's see first how this function is represented as a trace
trace = thunder.trace()(fun, a, b)
print(trace)
call sincos_meta(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))
|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))

# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def fun(a, b):
  # a: "cpu f32[1]"
  # b: "cpu f32[1]"
  (t0, t1) = sincos(a)
  t2 = ltorch.add(t0, t1, alpha=None)  # t2: "cpu f32[1]"
    # t2 = prims.add(t0, t1)  # t2: "cpu f32[1]"
  t3 = ltorch.add(t2, b, alpha=None)  # t3: "cpu f32[1]"
    # t3 = prims.add(t2, b)  # t3: "cpu f32[1]"
  return t3
[10]:
# We can loop over the recorded operations that we call BoundSymbols
for bound_symbol in trace.bound_symbols:
    print(f"Bound symbol with id={bound_symbol.sym.id} is represented in the trace as |{bound_symbol}|")
    if bound_symbol.subsymbols:
        print("  It has the following subsymbols:")
        for subsymbol in bound_symbol.subsymbols:
            print(f"    id={subsymbol.sym.id}  |{subsymbol}|")
Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# a: "cpu f32[1]" |
Bound symbol with id=PrimIDs.UNPACK_TRIVIAL is represented in the trace as |# b: "cpu f32[1]" |
Bound symbol with id=sincos is represented in the trace as |(t0, t1) = sincos(a)|
Bound symbol with id=torch.add is represented in the trace as |t2 = ltorch.add(t0, t1, alpha=None)  # t2: "cpu f32[1]"
  # t2 = prims.add(t0, t1)  # t2: "cpu f32[1]"|
  It has the following subsymbols:
    id=PrimIDs.ADD  |t2 = prims.add(t0, t1)  # t2: "cpu f32[1]"|
Bound symbol with id=torch.add is represented in the trace as |t3 = ltorch.add(t2, b, alpha=None)  # t3: "cpu f32[1]"
  # t3 = prims.add(t2, b)  # t3: "cpu f32[1]"|
  It has the following subsymbols:
    id=PrimIDs.ADD  |t3 = prims.add(t2, b)  # t3: "cpu f32[1]"|
Bound symbol with id=PrimIDs.RETURN is represented in the trace as |return t3|

Let’s see what happens if we try to compile a function that uses our new primitive and run it.

[11]:
cfun = thunder.jit(fun)
[12]:
cfun(a, b)
call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))
|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))

call sincos_impl(Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1413]))
|<- sincos_impl = (Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.1408]), Tensor(shape=torch.Size([1]), stride=(1,), dtype=torch.float32, device=cpu) with values tensor([0.9900]))

/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NoneType, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!
  warnings.warn(s)
/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type bool, which is not identified as an input. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!
  warnings.warn(s)
/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type SequenceIter, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!
  warnings.warn(s)
/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of type int, which is not identified as an input. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!
  warnings.warn(s)
/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type NotImplementedType, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!
  warnings.warn(s)
/home/tv/firma/grid/thunder/lightning-thunder/thunder/core/jit_ext.py:478: UserWarning: We are using a (non-const) value of unknown type StopIteration, which may or may not be safe. This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!
  warnings.warn(s)
[12]:
tensor([0.7666])

Let's check how our function is represented in the execution trace now (change the index to `thunder.last_traces(cfun)[0]` to see the trace before transformations):

[13]:
thunder.last_traces(cfun)[-1]
[13]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cpu f32[1]"
  # b: "cpu f32[1]"
  (res, cos) = sincos(a)
  del a
  result = torch.add(res, cos)  # result: "cpu f32[1]"
    # result = ltorch.add(res, cos, alpha=None)  # result: "cpu f32[1]"
      # result = prims.add(res, cos)  # result: "cpu f32[1]"
  del res, cos
  t3 = torch.add(result, b)  # t3: "cpu f32[1]"
    # t3 = ltorch.add(result, b, alpha=None)  # t3: "cpu f32[1]"
      # t3 = prims.add(result, b)  # t3: "cpu f32[1]"
  del result, b
  return t3
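
If you want to compare, the initial trace is also available, as suggested above:

print(thunder.last_traces(cfun)[0])  # the trace as first recorded, before transformations like Delete Last Used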

For a peek under the hood, we can also first create a new Symbol (without reference to an executor) and then register an executor for it.

[14]:
from thunder.core.symbol import Symbol
@log
def sincos_meta(inp):
    return (TensorProxy(like=inp), TensorProxy(like=inp))

# this gives a nice, unique, printable id
class CustomOps(Enum):
    sincos2 = 0

sincos2 = Symbol(
    id=CustomOps.sincos2,
    name="sincos2",
    meta=sincos_meta,
    is_prim=True,
)
[15]:
def fun2(a, b):
    sin, cos = sincos2(a)
    return sin + cos + b

cfun2 = thunder.jit(fun2)
[16]:
try:
    cfun2(a, b)
except RuntimeError as e:
    print(e)
call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))
|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))

Failed to find an executor for bound symbol bsym=(res, cos) = __main__.sincos2(a)

There's no registered executor for sincos2, so we need to register one for our new primitive. Let's do that.

Check out the “adding-operator-executor.ipynb” notebook to see how to implement an executor for a Symbol.

[17]:
@log
def checker_sincos2(a):
    # We allow sincos2 to be called with any tensor
    return True

@log
def executor_sincos2(a):
    # We need something here that works with TensorProxies during the transformations,
    # so we use functions from thunder.torch or thunder.clang, or other Symbols.
    return thunder.torch.sin(a), thunder.torch.cos(a)

sincos_executor.register_implementation(sincos2, checker=checker_sincos2, execution_transform=executor_sincos2)

[18]:
# Let's try again
cfun2 = thunder.jit(fun2)
cfun2(a, b)
call sincos_meta(TensorProxy(name=t_0, shape=(1,), dtype=float32, device=cpu))
|<- sincos_meta = (TensorProxy(name=t0, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t1, shape=(1,), dtype=float32, device=cpu))

call checker_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))
|<- checker_sincos2 = True

call executor_sincos2(TensorProxy(name=a, shape=(1,), dtype=float32, device=cpu))
|<- executor_sincos2 = (TensorProxy(name=t4, shape=(1,), dtype=float32, device=cpu), TensorProxy(name=t5, shape=(1,), dtype=float32, device=cpu))

[18]:
tensor([0.7666])
[19]:
# Let's check how our function is represented in the execution trace now
thunder.last_traces(cfun2)[-1]
[19]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a: "cpu f32[1]"
  # b: "cpu f32[1]"
  res = torch.sin(a)  # res: "cpu f32[1]"
    # res = ltorch.sin(a)  # res: "cpu f32[1]"
      # res = prims.sin(a)  # res: "cpu f32[1]"
  cos = torch.cos(a)  # cos: "cpu f32[1]"
    # cos = ltorch.cos(a)  # cos: "cpu f32[1]"
      # cos = prims.cos(a)  # cos: "cpu f32[1]"
  del a
  result = torch.add(res, cos)  # result: "cpu f32[1]"
    # result = ltorch.add(res, cos, alpha=None)  # result: "cpu f32[1]"
      # result = prims.add(res, cos)  # result: "cpu f32[1]"
  del res, cos
  t3 = torch.add(result, b)  # t3: "cpu f32[1]"
    # t3 = ltorch.add(result, b, alpha=None)  # t3: "cpu f32[1]"
      # t3 = prims.add(result, b)  # t3: "cpu f32[1]"
  del result, b
  return t3

That's it! We've created our custom operator and registered an executor for it. To recap, we've done the following:

* Created a new Symbol called sincos that represents the sine and cosine computation (but not the actual computation itself). All we know about it is that it takes a tensor as input and returns a tuple of two tensors. We gave this Symbol name and id attributes to identify it in the trace and when processing the trace.
* Implemented the actual computation by calling PyTorch's sin and cos functions.
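
As a compact summary, the whole registration flow from this notebook fits in a few lines (a sketch repeating the steps above, with hypothetical names like my_sincos):

import thunder
import torch
from thunder.core.proxies import TensorProxy

# 1. Create an executor and add it to the default executors
my_executor = thunder.extend.OperatorExecutor("my_executor", version="0.1")
thunder.add_default_executor(my_executor)

# 2. Define the meta (output metadata from input metadata) and the implementation
def my_sincos_meta(inp):
    return (TensorProxy(like=inp), TensorProxy(like=inp))

def my_sincos_impl(inp):
    return torch.sin(inp), torch.cos(inp)

# 3. Register both; this creates a Symbol that can be traced and executed
my_sincos = my_executor.register_operator("my_sincos", meta=my_sincos_meta, fn=my_sincos_impl)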