Thunder functional jit

This notebook shows how to use thunder’s “functional jit” entrypoint, thunder.functional.jit. This function takes a “functional” Python function and returns another “jitted” Python function with the same signature. When the jitted function is called, thunder executes its understanding of the program. If there are no “sharp edges” (more on that below), then the jitted function will compute the same result as the original function.

Before getting into the details, let’s look at a simple example.

[2]:
import torch

import thunder
from thunder.functional import jit
[3]:
def foo(a, b):
    return a + b

jfoo = jit(foo)

a = torch.randn((2, 2))
b = torch.randn((2, 2))
[4]:
jfoo(a, b)
[4]:
tensor([[ 1.5103, -0.0213],
        [ 1.1842,  0.7658]])
[5]:
foo(a, b)
[5]:
tensor([[ 1.5103, -0.0213],
        [ 1.1842,  0.7658]])

Here a function foo that just adds its inputs together is jitted, and we can verify that the jitted function produces the same result as the original function. We can also inspect what jfoo actually ran by using thunder.last_traces.

[6]:
traces = thunder.last_traces(jfoo)
traces[-1]
[6]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b):
  # a
  # b
  t0 = torch.add(a, b)  # t0
    # t0 = ltorch.add(a, b, alpha=None)  # t0
      # t0 = prims.add(a, b)  # t0
  del a, b
  return t0

Here we see the computation that jfoo performed, which adds two tensors together using PyTorch.

The functional jit can execute “functional” Python functions whose inputs are PyTorch tensors, numbers, strings, PyTorch dtypes, PyTorch devices, None, slice objects, Ellipsis objects, PyTorch size objects, and tuples, lists, and dicts of those values. Other input types are not accepted.
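
For example, collections of supported values can be passed directly. Here is a small sketch (not part of the original notebook) that reuses a and b from above and passes a dict of tensors together with a number:

def scale_and_add(pair, scale):
    # pair is a dict of tensors and scale is a Python number, both supported input types
    return pair["a"] * scale + pair["b"]

jscale_and_add = jit(scale_and_add)
jscale_and_add({"a": a, "b": b}, 2.0)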

[8]:
# Simple class that holds a pair of tensors as "a" and "b"
class TensorPair:
    def __init__(self, a, b):
        self.a = a
        self.b = b

tp = TensorPair(a, b)

def bar(tp):
    return tp.a + tp.b

jbar = jit(bar)

# Attempting to pass a TensorPair object to the jitted function results
#   in a ValueError
try:
    jbar(tp)
except ValueError as ve:
    print(ve)
Cannot unpack object of type <class '__main__.TensorPair'>. Please file an issue requesting support.
[9]:
# A workaround for custom inputs is to translate them to accepted values and collections

def tensorpair_wrapper(tp):
    return jfoo(tp.a, tp.b)

tensorpair_wrapper(tp)
[9]:
tensor([[ 1.5103, -0.0213],
        [ 1.1842,  0.7658]])

The functional jit will translate PyTorch functions to thunder operations by default.

[10]:
def foo_torch(a, b):
    return torch.add(a, b)

jfoo_torch = jit(foo_torch)
jfoo_torch(a, b)
[10]:
tensor([[ 1.5103, -0.0213],
        [ 1.1842,  0.7658]])

As mentioned above, the functional jit is intended to jit “functional” Python functions without “sharp edges.” A “sharp edge” is any behavior in the original Python function that will not be translated to the jitted function. Sharp edges are:

  • Inputs that aren’t from the function’s signature

  • Attempts to modify inputs

  • Calling non-functional operations and/or operations with side effects

The following cells provide examples of sharp edges.

[15]:
# Inputs that aren't from the function's signature

from thunder.core.interpreter import InterpreterError

# partial_add loads the value b, which is not a signature input
def partial_add(a):
    return a + b

jpartial_add = jit(partial_add)

# The value b will cause an error, as it has not been "proxied" by the functional jit.
#   Behind the scenes, the functional jit replaces its inputs with "proxies" to observe
#   how they're used in the program.
try:
    jpartial_add(a)
except InterpreterError as je:
    print(je)
Encountered exception ValueError: tensor([[ 0.1333, -1.0425],
        [ 0.1407, -0.5683]]) had an unexpected type <class 'torch.Tensor'>. Supported types are (<class 'thunder.core.proxies.TensorProxy'>, <class 'numbers.Number'>) while tracing <function partial_add at 0x169b8fbe0>:
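
The fix is to pass b through the function’s signature so the functional jit can proxy it; this is just foo from the beginning of the notebook:

def full_add(a, b):
    return a + b

jfull_add = jit(full_add)
# b is now a proxied input, so tracing succeeds
jfull_add(a, b)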

[17]:
# Attempts to modify inputs

def list_sum(lst):
    accum = lst[0]

    for x in lst[1:]:
        accum = accum + x

    lst.append(accum)

jlist_sum = jit(list_sum)

try:
    jlist_sum([a, b])
except NotImplementedError as nie:
    print(nie)
Appending to an input list is not yet supported
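
A functional rewrite returns the sum instead of appending it to the input list. A minimal sketch (not part of the original notebook) using the same inputs:

def list_sum_functional(lst):
    accum = lst[0]

    for x in lst[1:]:
        accum = accum + x

    # return the result instead of mutating the input list
    return accum

jlist_sum_functional = jit(list_sum_functional)
jlist_sum_functional([a, b])
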
[21]:
# Calling non-functional operations and/or operations with side effects
import random

def add_random(a):
    return a + random.random()

jadd_random = jit(add_random)

jadd_random(a)
[21]:
tensor([[1.6113, 1.2556],
        [1.2779, 1.5685]])

Note that the above example does not throw an error, even though the jitted function does not faithfully emulate the original function. This can be seen by looking at its last trace.

[19]:
thunder.last_traces(jadd_random)[-1]
[19]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a
  t0 = torch.add(a, 0.4846238608207385)  # t0
    # t0 = ltorch.add(a, 0.4846238608207385, alpha=None)  # t0
      # t0 = prims.add(a, 0.4846238608207385)  # t0
  del a
  return t0

In the last trace, we see that the value returned from random.random() is treated as a compile-time constant, even though it’s generated at runtime. This means the jitted function will use the same value on every call instead of drawing a new one. random.random() is a non-functional operation: it reads an implicit random-state input, and it has the side effect of mutating Python’s global random state.

In Python, function calls are actually loads of global variables (the function’s name is looked up outside the function), so they are technically inputs that don’t come from the function’s signature. Calls to supported functions like torch.add are still translated correctly, but, as this example shows, calls to non-functional operations like random.random() can silently change the behavior of the jitted function.
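
One way to keep such a function functional is to generate the random value outside the jitted function and pass it as an input, so every call sees a fresh value. A sketch (not part of the original notebook):

def add_value(a, v):
    return a + v

jadd_value = jit(add_value)

# the random number is generated eagerly and passed as a proxied input,
# so it is no longer baked into the trace as a constant
jadd_value(a, random.random())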
