thunder

Compiling functions and modules

jit(fn, /, *[, langctx, executors, ...])

Just-in-time compile a callable (function or model).

functional.jit(fn, /, *[, langctx, ...])

Just-in-time compile a function.

Querying information on compiled functions and modules

compile_data(fn)

Obtains the compilation data from a JITed function.

compile_stats(fn)

Obtains the compilation statistics from a JITed function.

last_traces(fn)

Obtains the list of computation traces that have been produced for the last run of the function.

last_backward_traces(fn)

Obtains the list of backward traces that have been produced for the last run of the function and the selected prologue.

last_prologue_traces(fn)

Obtains the list of prologue traces that have been produced for the last run of the function and the selected prologue.

cache_option(fn)

Returns the cache options set when JITting the function.

cache_hits(fn)

Returns the number of cache hits we found when running the function.

cache_misses(fn)

Returns the number of cache misses we found when running the function.
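To make the two counters concrete, here is a toy sketch in plain Python of a compile cache that counts hits and misses. It is not Thunder's implementation: `ToyJit`, its `_key` function (which distinguishes inputs by their Python types only) and its attribute names are hypothetical, chosen only to show what a hit and a miss mean.

```python
from typing import Any, Callable, Dict, Tuple


class ToyJit:
    """Hypothetical stand-in for a JIT wrapper; NOT Thunder's implementation."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn
        self.cache: Dict[Tuple[type, ...], Callable[..., Any]] = {}
        self.cache_hits = 0
        self.cache_misses = 0

    @staticmethod
    def _key(args: Tuple[Any, ...]) -> Tuple[type, ...]:
        # Assumption: inputs are distinguished by their Python types only.
        return tuple(type(a) for a in args)

    def __call__(self, *args: Any) -> Any:
        key = self._key(args)
        if key in self.cache:
            self.cache_hits += 1
        else:
            # A miss is the point where a real JIT would trace and compile.
            self.cache_misses += 1
            self.cache[key] = self.fn
        return self.cache[key](*args)


def add(a, b):
    return a + b


jadd = ToyJit(add)
jadd(1, 2)      # first call with (int, int): miss
jadd(3, 4)      # same input types: hit
jadd(1.0, 2.0)  # new input types (float, float): miss
print(jadd.cache_hits, jadd.cache_misses)
```

With Thunder itself, the corresponding numbers are read by calling cache_hits(fn) and cache_misses(fn) on the JITed function.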

list_transforms(fn)

Returns the list of (explicit) transforms applied to the JITed function.

last_interpreted_instructions(fn)

Returns the list of instructions the interpreter encountered while tracing through the user program (on the last cache miss).

last_interpreter_log(fn)

Returns the list of instructions and other information the interpreter encountered while tracing through the user program (on the last cache miss).

last_compile_options(fn, /)

Prints how compile options were used (or not).

JITed Model wrapper

class thunder.ThunderModule(model, compiled_model_call)

Bases: Module

A wrapper nn.Module subclass.

This wrapper is returned by thunder.jit; you would typically not instantiate it manually.

get_buffer(name)

Return the buffer given by name if it exists, otherwise raise an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify name.

Parameters:

name – The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

The buffer referenced by name

Return type:

torch.Tensor

Raises:

AttributeError – If name references an invalid path or resolves to something that is not a buffer

get_parameter(name)

Return the parameter given by name if it exists, otherwise raise an error.

See the docstring for get_submodule for a more detailed explanation of this method’s functionality as well as how to correctly specify name.

Parameters:

name – The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)

Returns:

The Parameter referenced by name

Return type:

torch.nn.Parameter

Raises:

AttributeError – If name references an invalid path or resolves to something that is not an nn.Parameter

get_submodule(name)

Return the submodule given by name if it exists, otherwise raise an error.

For example, let’s say you have an nn.Module A that looks like this:

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules, net_c and linear. net_c then has a submodule conv.)

To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").

The runtime of get_submodule is bounded by the degree of module nesting in name. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.

Parameters:

name – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)

Returns:

The submodule referenced by name

Return type:

torch.nn.Module

Raises:

AttributeError – If name references an invalid path or resolves to something that is not an nn.Module
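The dotted-path lookup described above can be sketched in plain Python. This is an illustration under assumptions, not the torch.nn.Module code: `Node` and `lookup` are hypothetical names, and children are stored in a plain dict rather than as registered submodules.

```python
from typing import Dict


class Node:
    """Hypothetical stand-in for a module that owns named children."""

    def __init__(self, **children: "Node") -> None:
        self.children: Dict[str, "Node"] = dict(children)


def lookup(root: Node, path: str) -> Node:
    """Walk a dotted path one level at a time; cost grows with nesting depth."""
    current = root
    for part in path.split("."):
        if part not in current.children:
            # Mirrors the documented behaviour: invalid paths raise AttributeError.
            raise AttributeError(f"no child named {part!r} in path {path!r}")
        current = current.children[part]
    return current


# Same shape as the diagram: A -> net_b -> (net_c -> conv, linear)
conv = Node()
linear = Node()
net_b = Node(net_c=Node(conv=conv), linear=linear)
a = Node(net_b=net_b)

print(lookup(a, "net_b.linear") is linear)
print(lookup(a, "net_b.net_c.conv") is conv)
```

The loop visits one node per path component, which is why the cost depends on nesting depth rather than on the total number of modules in the tree.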

no_sync()

Context manager to disable gradient synchronization in data parallel mode.

This context manager is intended to be used in conjunction with torch.nn.parallel.DistributedDataParallel to disable gradient synchronization in the backward pass. It will not have any effect when used with other modules.

Note

This can produce accumulated gradients that differ from those of torch.nn.parallel.distributed.DistributedDataParallel.no_sync. PyTorch’s gradient synchronization is implemented by applying all-reduce to gradient buckets of torch.nn.Parameter.grad. Thus PyTorch’s no_sync context leads to \(\text{AllReduce} \left( \sum_{i = 0}^{\mathrm{num\_grad\_accum\_steps}} g_i \right)\). In contrast, this context manager synchronizes the accumulated gradients when exiting, leading to \(\text{AllReduce} \left( \sum_{i = 0}^{\mathrm{num\_grad\_accum\_steps} - 1} g_i \right) + \text{AllReduce}(g_{\mathrm{num\_grad\_accum\_steps}})\).
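The two expressions in this note can be compared with a small simulation. The sketch below is plain Python, not Thunder or PyTorch code: `all_reduce_mean` is a hypothetical stand-in that averages one scalar gradient across simulated ranks, and the gradient values are made up. Assuming the all-reduce is a plain sum or mean (and therefore linear), the two forms agree mathematically; in floating point they may differ in the last bits because the additions happen in a different order.

```python
import math
from typing import List

NUM_GRAD_ACCUM_STEPS = 3  # indices g_0 .. g_3, i.e. four gradients per rank

# grads[rank][i] is the (made-up) scalar gradient g_i computed on that rank.
grads: List[List[float]] = [
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
]


def all_reduce_mean(per_rank: List[float]) -> float:
    """Hypothetical all-reduce: average one value per rank."""
    return sum(per_rank) / len(per_rank)


# Form 1: accumulate every g_i locally, then all-reduce once.
single_sync = all_reduce_mean([sum(g) for g in grads])

# Form 2: all-reduce the gradients accumulated inside the context on exit,
# then all-reduce the gradient of the final backward separately.
on_exit = all_reduce_mean([sum(g[:NUM_GRAD_ACCUM_STEPS]) for g in grads])
final = all_reduce_mean([g[NUM_GRAD_ACCUM_STEPS] for g in grads])
two_syncs = on_exit + final

print(single_sync, two_syncs)
print(math.isclose(single_sync, two_syncs))
```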

Warning

You must reuse this context manager in each group of gradient accumulation iterations since gradients will get synchronized on context manager exit.

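# gradient_accumulation_iters: number of backward passes run without synchronization
with model.no_sync():
    for _ in range(gradient_accumulation_iters):
        loss(model(x)).backward()  # uses no-sync-backward trace
# the accumulated gradients are synchronized when the context manager exits
loss(model(x)).backward()  # uses the regular backward trace
optimizer.step()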
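The control flow behind the warning above can be illustrated with a toy context manager. This is a sketch under assumptions, not Thunder's code: `ToyModel`, its `events` log and its `backward` method are hypothetical, and the only behaviour modelled is that synchronization is skipped inside the context and performed once when it exits.

```python
from contextlib import contextmanager
from typing import Iterator, List


class ToyModel:
    """Hypothetical model that records when gradients would be synchronized."""

    def __init__(self) -> None:
        self.sync_enabled = True
        self.events: List[str] = []

    def backward(self) -> None:
        # With synchronization enabled, each backward also synchronizes.
        self.events.append("backward+sync" if self.sync_enabled else "backward")

    @contextmanager
    def no_sync(self) -> Iterator[None]:
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = True
            # Assumption taken from the warning: sync happens on exit.
            self.events.append("sync accumulated grads")


model = ToyModel()
for _ in range(2):  # two optimizer steps
    with model.no_sync():  # entered again for every accumulation group
        for _ in range(3):
            model.backward()
    model.backward()  # regular backward, outside the context
print(model.events)
```

Because the synchronization is tied to leaving the context, a single long-lived `with` block spanning several optimizer steps would not synchronize between them in this toy model; entering it once per accumulation group does.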