thunder
Compiling functions and modules
- Just-in-time compile a callable (function or model).
- Just-in-time compile a function.
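The compile-on-first-call behavior these entry points describe can be pictured with a toy wrapper. This is purely an illustration of the idea (compile once, reuse afterwards), not thunder's actual machinery; `toy_jit` and `compile_count` are hypothetical names invented for this sketch:

```python
# Toy sketch of just-in-time compiling a callable: the wrapped function
# is "compiled" on the first call (here just counted and stored) and the
# compiled result is reused on every later call.
def toy_jit(fn):
    state = {"compiled": None, "compile_count": 0}

    def wrapper(*args, **kwargs):
        if state["compiled"] is None:
            # First call: "compile" the callable (stand-in for tracing).
            state["compile_count"] += 1
            state["compiled"] = fn
        return state["compiled"](*args, **kwargs)

    wrapper.compile_count = lambda: state["compile_count"]
    return wrapper

@toy_jit
def add(a, b):
    return a + b

print(add(1, 2))            # 3
print(add(3, 4))            # 7
print(add.compile_count())  # 1 -- compiled only on the first call
```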
Querying information on compiled functions and modules
- Obtains the compilation data from a JITed function.
- Obtains the compilation statistics from a JITed function.
- Obtains the list of computation traces that have been produced for the last run of the function.
- Obtains the list of backward traces that have been produced for the last run of the function and the selected prologue.
- Obtains the list of prologue traces that have been produced for the last run of the function and the selected prologue.
- Returns the cache options set when JITting the function.
- Returns the number of cache hits we found when running the function.
- Returns the number of cache misses we found when running the function.
- Returns the list of (explicit) transforms applied to the JITed function.
- Returns the list of instructions the interpreter encountered while tracing through the user program (on the last cache miss).
- Returns the list of instructions and other information the interpreter encountered while tracing through the user program (on the last cache miss).
- Prints how compile options were used (or not).
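The cache-hit and cache-miss counters above can be pictured with a toy cache keyed on argument types. This is a sketch of the bookkeeping concept only, not thunder's cache; the `ToyCachingJit` class is a hypothetical name for this illustration:

```python
# Toy illustration of compilation-cache bookkeeping: the "compiled"
# result is cached per argument-type signature, and hits/misses are
# counted the way the querying helpers above report them.
class ToyCachingJit:
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}
        self.cache_hits = 0
        self.cache_misses = 0

    def __call__(self, *args):
        key = tuple(type(a) for a in args)  # specialize on argument types
        if key in self.cache:
            self.cache_hits += 1
        else:
            self.cache_misses += 1          # triggers a (re)compilation
            self.cache[key] = self.fn
        return self.cache[key](*args)

mul = ToyCachingJit(lambda a, b: a * b)
mul(2, 3)      # (int, int)     -> miss, compile
mul(4, 5)      # (int, int)     -> hit, reuse
mul(2.0, 3.0)  # (float, float) -> miss, recompile
print(mul.cache_hits, mul.cache_misses)  # 1 2
```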
JITed Model wrapper
- class thunder.ThunderModule(model, compiled_model_call)[source]
  Bases: Module
  A wrapper nn.Module subclass. This wrapper is returned by thunder.jit; you would typically not instantiate it manually.

  - get_buffer(name)[source]
    Return the buffer given by target if it exists, otherwise throw an error.
    See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.
    - Parameters:
      target – The fully-qualified string name of the buffer to look for. (See get_submodule for how to specify a fully-qualified string.)
    - Returns:
      The buffer referenced by target
    - Return type:
      torch.Tensor
    - Raises:
      AttributeError – If the target string references an invalid path or resolves to something that is not a buffer
  - get_parameter(name)[source]
    Return the parameter given by target if it exists, otherwise throw an error.
    See the docstring for get_submodule for a more detailed explanation of this method's functionality as well as how to correctly specify target.
    - Parameters:
      target – The fully-qualified string name of the Parameter to look for. (See get_submodule for how to specify a fully-qualified string.)
    - Returns:
      The Parameter referenced by target
    - Return type:
      torch.nn.Parameter
    - Raises:
      AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Parameter
  - get_submodule(name)[source]
    Return the submodule given by target if it exists, otherwise throw an error.
    For example, let's say you have an nn.Module A that looks like this:

        A(
            (net_b): Module(
                (net_c): Module(
                    (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                )
                (linear): Linear(in_features=100, out_features=200, bias=True)
            )
        )

    (The diagram shows an nn.Module A. A has a nested submodule net_b, which itself has two submodules net_c and linear. net_c then has a submodule conv.)
    To check whether or not we have the linear submodule, we would call get_submodule("net_b.linear"). To check whether we have the conv submodule, we would call get_submodule("net_b.net_c.conv").
    The runtime of get_submodule is bounded by the degree of module nesting in target. A query against named_modules achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, get_submodule should always be used.
    - Parameters:
      target – The fully-qualified string name of the submodule to look for. (See above example for how to specify a fully-qualified string.)
    - Returns:
      The submodule referenced by target
    - Return type:
      Module
    - Raises:
      AttributeError – If the target string references an invalid path or resolves to something that is not an nn.Module
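The fully-qualified lookup that get_submodule (and, analogously, get_buffer and get_parameter) performs amounts to walking the dotted target string one attribute at a time. A minimal sketch, using ad-hoc classes in place of nn.Module and a hypothetical `resolve` helper:

```python
from functools import reduce

# Walk a dotted target string one attribute at a time, as in the
# get_submodule("net_b.net_c.conv") example above. The runtime is
# bounded by the nesting depth of the target, not the total number
# of modules in the tree.
def resolve(root, target):
    try:
        return reduce(getattr, target.split("."), root)
    except AttributeError:
        raise AttributeError(f"invalid target path: {target!r}") from None

# Ad-hoc stand-ins for the nested modules in the diagram above.
class Conv: pass
class NetC:
    def __init__(self): self.conv = Conv()
class NetB:
    def __init__(self): self.net_c = NetC(); self.linear = object()
class A:
    def __init__(self): self.net_b = NetB()

a = A()
print(type(resolve(a, "net_b.net_c.conv")).__name__)  # Conv
```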
  - no_sync()[source]
    Context manager to disable gradient synchronization in data parallel mode.
    This context manager is intended to be used in conjunction with torch.nn.parallel.DistributedDataParallel to disable gradient synchronization in the backward pass. It will not have any effect when used with other modules.

    Note
    This could lead to different accumulated gradients than torch.nn.parallel.distributed.DistributedDataParallel.no_sync. PyTorch's gradient synchronization is implemented by applying all-reduce to gradient buckets of torch.nn.Parameter.grad. Thus the no_sync context leads to \(\text{AllReduce} \left( \sum_{i = 0}^{\text{num\_grad\_accum\_steps}} g_i \right)\). In contrast, this synchronizes accumulated gradients when exiting, leading to \(\text{AllReduce} \left( \sum_{i = 0}^{\text{num\_grad\_accum\_steps} - 1} g_i \right) + \text{AllReduce}(g_{\text{num\_grad\_accum\_steps}})\).

    Warning
    You must reuse this context manager in each group of gradient accumulation iterations since gradients will get synchronized on context manager exit.

        with model.no_sync():
            for _ in range(len(gradient_accumulation_iters)):
                loss(model(x)).backward()  # uses no-sync-backward trace
        loss(model(x)).backward()  # uses the regular backward trace
        optimizer.step()
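The control flow of the accumulation pattern above can be illustrated with a toy model whose backward pass records whether gradient synchronization was active. This is a sketch of the context-manager mechanics only (no distributed machinery, no real gradients); `ToyDDPModel` is a hypothetical name for this illustration:

```python
from contextlib import contextmanager

# Toy model that records whether each backward pass would have
# synchronized gradients. Mirrors the no_sync() pattern above:
# accumulation steps inside the context skip synchronization, and
# the final backward outside the context synchronizes everything.
class ToyDDPModel:
    def __init__(self):
        self.sync_enabled = True
        self.backward_log = []

    @contextmanager
    def no_sync(self):
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = True  # sync resumes on context exit

    def backward(self):
        self.backward_log.append("sync" if self.sync_enabled else "no-sync")

model = ToyDDPModel()
with model.no_sync():
    for _ in range(3):      # gradient accumulation steps
        model.backward()    # no-sync backward
model.backward()            # regular, synchronizing backward
print(model.backward_log)   # ['no-sync', 'no-sync', 'no-sync', 'sync']
```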