Thunder Overview
This section introduces Thunder’s core concepts and architecture. For more details, see Inside thunder.
Thunder is a deep learning compiler for PyTorch, which means it translates calls to PyTorch modules and functions into a format that is easy to transform and that executors can consume to produce fast executables. This translation must produce a simple representation focused on tensor operations. Like other deep learning compilers, the format we’ve chosen is a sequence of operations called a program trace.
This translation begins with:
jitted_model = thunder.jit(my_module)
or:
jitted_fn = thunder.jit(my_function)
When given a module, the call to thunder.jit() returns a Thunder-optimized module that shares parameters with the original module (as demonstrated in the Train a MLP on MNIST example). When given a function, it returns a function that, when called, jit-compiles a path through the original function based on information about the inputs.
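For example, a minimal sketch of both entry points (the MLP class and my_function below are placeholders for your own code, not part of Thunder):

import torch
import torch.nn as nn
import thunder

# A small placeholder module and function to compile.
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))

def my_function(a, b):
    return a * b + a

model = MLP()
jitted_model = thunder.jit(model)      # Thunder-optimized module; parameters are shared with `model`
jitted_fn = thunder.jit(my_function)   # jit-compiles a path through `my_function` on first call

x = torch.randn(8, 16)
out = jitted_model(x)                  # same calling convention as the original module
result = jitted_fn(torch.randn(4), torch.randn(4))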
When the jitted module or function is called:
jitted_model(*args, **kwargs)
or:
jitted_fn(*args, **kwargs)
As suggested above, Thunder begins by reviewing the module’s or function’s Python bytecode together with its inputs. It may be surprising that Thunder considers the inputs at all, but since control flow (and therefore the operations captured) may vary depending on the inputs, this is required to produce a trace. Traces are cached, so that if inputs of the same type, shape, and so on are used again, the trace can be reused.
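As an illustrative sketch of why the inputs matter, consider a function whose Python control flow depends on a flag; the operations recorded differ per branch, so (under this sketch’s assumptions about caching granularity) each variant gets its own cached trace:

import torch
import thunder

def scale(x, add_bias):
    # Python-level control flow: the recorded tensor operations depend on this flag,
    # so each branch is traced and cached separately.
    if add_bias:
        return x * 2 + 1
    return x * 2

jscale = thunder.jit(scale)
x = torch.randn(4, 4)

jscale(x, True)    # first call for this input class: traces the `add_bias` branch
jscale(x, True)    # same types, shapes, and flag: the cached trace is reused
jscale(x, False)   # different control flow: a new trace is generated and cached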
Traces are generated by running the bytecode through a custom Python interpreter, itself implemented in Python. This interpreter executes instructions differently from standard CPython: in particular, it constructs a trace of the operations performed on tensors and numbers, and it tracks the provenance of every object in the program, whether it originated inside the interpreter or outside.
Like other machine learning frameworks, traces don’t typically deal directly with PyTorch tensors but with proxies that carry only metadata such as shape, device, dtype, and whether the tensor requires grad. As such, during interpretation for trace generation, no computation is performed on accelerators; instead, the operators along one path through the traceable function are recorded.
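One way to see this is to print the most recent trace; assuming thunder.last_traces(jfoo) returns the sequence of traces produced for the last call, the printed program is written in terms of proxies that carry only metadata:

import torch
import thunder

def foo(a, b):
    return torch.nn.functional.relu(a @ b)

jfoo = thunder.jit(foo)
jfoo(torch.randn(4, 8), torch.randn(8, 2))

# The trace is a Python-like program over tensor proxies (names, shapes, dtypes, devices),
# not over the actual input tensors; no accelerator computation happened during tracing.
print(thunder.last_traces(jfoo)[-1])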
If replacing CPython with an interpreter written in Python sounds problematic from a performance perspective, you would be largely correct. We haven’t yet put any time into optimizing it, and we estimate it consumes roughly 400x as much CPU time as CPython. However, a function only needs to be jitted once per equivalence class of inputs, and the CPU is not a bottleneck in most machine learning pipelines. As long as the metadata of the inputs (such as a tensor’s shape) and the control flow conditions do not change, smart caching lets Thunder immediately execute an optimized trace. The end result is a faster total execution time.
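A rough way to observe this trade-off is to time the first call, which pays the jit cost, against a later call with the same input metadata (exact numbers depend on your hardware and the function being compiled):

import time
import torch
import thunder

def foo(a, b):
    return (a @ b).softmax(dim=-1)

jfoo = thunder.jit(foo)
a, b = torch.randn(256, 256), torch.randn(256, 256)

start = time.perf_counter()
jfoo(a, b)                                  # first call: interpretation and trace generation
print("first call:", time.perf_counter() - start)

start = time.perf_counter()
jfoo(a, b)                                  # same input metadata: the cached trace runs directly
print("cached call:", time.perf_counter() - start)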
Traces can be transformed (for example, for backward()) and optimized (for example, by replacing calls to eager PyTorch operations with calls to faster executors), and the final result of this process is an execution trace. Thunder executes the original call by converting the execution trace into a Python function and calling that function with the actual inputs. For details about this optimization process, see the thunder step by step section.
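For example, after a forward and backward pass through a jitted function, the resulting traces can be inspected; thunder.last_traces and thunder.last_backward_traces are assumed here as the introspection entry points, with the last element of each being the execution trace:

import torch
import thunder

def foo(x, w):
    return torch.tanh(x @ w).sum()

jfoo = thunder.jit(foo)
x = torch.randn(4, 8)
w = torch.randn(8, 8, requires_grad=True)

loss = jfoo(x, w)
loss.backward()

# The last element is the execution trace that is converted into a Python function and run.
print(thunder.last_traces(jfoo)[-1])           # forward execution trace
print(thunder.last_backward_traces(jfoo)[-1])  # backward execution trace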
To recap, the complete translation process is:

1. For PyTorch modules, a Thunder-optimized module is created from the original module.
2. For PyTorch functions, compilation produces a compiled function.
3. When the module or function is called, the trace is generated, swapping some inputs with “proxies”.
4. The trace is transformed and optimized to produce an execution trace.
5. The execution trace is converted into a Python function and called.
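Putting these steps together, a training loop looks like ordinary PyTorch; because the Thunder-optimized module shares parameters with the original, the optimizer can be built over the original module’s parameters (the model, shapes, and hyperparameters below are arbitrary):

import torch
import torch.nn as nn
import thunder

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
jitted_model = thunder.jit(model)                 # shares parameters with `model`
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(16, 32)
target = torch.randint(0, 10, (16,))

for _ in range(3):
    optimizer.zero_grad()
    logits = jitted_model(x)                      # runs the cached execution trace after the first step
    loss = nn.functional.cross_entropy(logits, target)
    loss.backward()                               # gradients flow to the shared parameters
    optimizer.step()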
As mentioned, this translation process is often slow (it takes tens of seconds for the largest configuration of nanoGPT, https://github.com/karpathy/nanoGPT), so Thunder’s performance model expects relatively few of these translations followed by many uses of the result. This matches common training and inference patterns, where the same program is executed many times.