Zero to Thunder

Here we take a very short tour of what is possible with Thunder.

To get started we import it (and a bunch of things for this notebook).

[1]:
import sys
sys.path.insert(0, '..')

import torch, thunder

Compiling a first module with Thunder

So let’s get started! As a “Hello World”, let us apply it to a small model: the MLP block of Llama 2, which we take from LitGPT.

[2]:
class LLaMAMLP(torch.nn.Module):
    def __init__(self, n_embd, intermediate_size) -> None:
        super().__init__()
        self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        x = torch.nn.functional.silu(x_fc_1) * x_fc_2
        return self.proj(x)
with torch.device("cuda"):
    m = LLaMAMLP(4096, 11008)
for p in m.parameters():
    p.requires_grad_(False)
print(m)

LLaMAMLP(
  (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
  (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
  (proj): Linear(in_features=11008, out_features=4096, bias=False)
)

Now we can apply Thunder, using its most important function, thunder.jit, which compiles either a torch.nn.Module or a plain function. Here it wraps our MLP in a ThunderModule:

[3]:
thunder_model = thunder.jit(m)
[4]:
thunder_model
[4]:
ThunderModule(
  (_model): LLaMAMLP(
    (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
    (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
    (proj): Linear(in_features=11008, out_features=4096, bias=False)
  )
)

Our Thunder module computes (up to numerical accuracy) the same thing as our original model, and for a small model like this it also has approximately the same performance.

[5]:
x = torch.randn(2, 2048, 4096, device="cuda")
print('deviation:', (thunder_model(x) - m(x)).abs().max().item())

%timeit thunder_model(x); torch.cuda.synchronize()
%timeit m(x); torch.cuda.synchronize()
deviation: 1.4901161193847656e-07
61.3 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
62.1 ms ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
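The %timeit lines above only work inside IPython. In a plain script, one option is PyTorch's own torch.utils.benchmark.Timer, which takes care of CUDA synchronization for you. The helper name median_runtime_ms is our own, and the small CPU layer is only a stand-in for thunder_model, m and x; this is a sketch, not how the notebook measured its numbers.

```python
import torch
import torch.utils.benchmark as benchmark


def median_runtime_ms(fn, *args, min_run_time=0.2):
    # Timer's default clock calls torch.cuda.synchronize() when CUDA is
    # available, so queued kernels are included in the measurement, just like
    # the explicit synchronize in the %timeit lines above.
    timer = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    # blocked_autorange repeats the statement until min_run_time seconds have passed
    measurement = timer.blocked_autorange(min_run_time=min_run_time)
    return measurement.median * 1e3  # seconds -> milliseconds


# CPU-sized stand-in; in the notebook you would pass thunder_model or m, and x
layer = torch.nn.Linear(256, 256)
inp = torch.randn(8, 256)
print(f"median: {median_runtime_ms(layer, inp):.3f} ms")
```

Note that Timer runs with a single CPU thread by default (num_threads=1), which matters for CPU models but not for the GPU timings above.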

So what has changed? Quite a bit!

When we call the Thunder module, it does the computation in a single function without control flow. What’s more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all of this by showing the last computation trace:

[6]:
thunder.last_traces(thunder_model)[-1]
[6]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight):
  # x: "cuda:0 f32[2, 2048, 4096]"
  # t_fc_1_weight: "cuda:0 f32[11008, 4096]"
  # t_fc_2_weight: "cuda:0 f32[11008, 4096]"
  # t_proj_weight: "cuda:0 f32[4096, 11008]"
  x_fc_1 = torch.nn.functional.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
    # x_fc_1 = ltorch.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
      # x_fc_1 = prims.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
  del t_fc_1_weight
  x_fc_2 = torch.nn.functional.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
    # x_fc_2 = ltorch.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
      # x_fc_2 = prims.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
  del x, t_fc_2_weight
  [result] = nvFusion0(x_fc_1, x_fc_2)
    # t9 = prims.neg(x_fc_1)  # t9: "cuda:0 f32[2, 2048, 11008]"
    # t10 = prims.exp(t9)  # t10: "cuda:0 f32[2, 2048, 11008]"
    # t11 = prims.add(1.0, t10)  # t11: "cuda:0 f32[2, 2048, 11008]"
    # t12 = prims.reciprocal(t11)  # t12: "cuda:0 f32[2, 2048, 11008]"
    # a = prims.mul(x_fc_1, t12)  # a: "cuda:0 f32[2, 2048, 11008]"
    # result = prims.mul(a, x_fc_2)  # result: "cuda:0 f32[2, 2048, 11008]"
  del x_fc_1, x_fc_2
  t18 = torch.nn.functional.linear(result, t_proj_weight, None)  # t18: "cuda:0 f32[2, 2048, 4096]"
    # t18 = ltorch.linear(result, t_proj_weight, None)  # t18: "cuda:0 f32[2, 2048, 4096]"
      # t18 = prims.linear(result, t_proj_weight, None)  # t18: "cuda:0 f32[2, 2048, 4096]"
  del result, t_proj_weight
  return t18

For more detail on what is going on in this trace:
- Thunder has transformed the computation (more precisely, m.__call__) into a single function that takes all the MLP parameters as arguments.
- It has recorded the tensor metadata.
- Operations have been mapped from the PyTorch functions to their thunder.torch (aka ltorch) equivalents and decomposed into primitive operations.
- The multiplication and activation (x = torch.nn.functional.silu(x_fc_1) * x_fc_2) have been put into one NVFuser fusion. (NVFuser is one of many optimizations, albeit a particularly important one, and we make it easy to add your own.)
- You can see how the parameters are obtained and the metadata is checked in the prologue; get it through thunder.last_prologue_traces(thunder_model)[-1].
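To make the trace concrete, here is a hand-written plain-PyTorch function that mirrors it: the weights are ordinary arguments, every layer call is torch.nn.functional.linear, and SiLU is spelled out with the same primitives listed under nvFusion0 (neg, exp, add, reciprocal, mul). This is our own illustration run eagerly on CPU with small, arbitrarily chosen sizes, not code generated by Thunder; it just checks that the decomposition agrees with the original module.

```python
import torch
import torch.nn.functional as F


def computation(x, fc_1_weight, fc_2_weight, proj_weight):
    # Parameters are plain arguments, as in the trace's signature
    x_fc_1 = F.linear(x, fc_1_weight)
    x_fc_2 = F.linear(x, fc_2_weight)
    # The primitives inside nvFusion0: silu(x) = x * 1 / (1 + exp(-x))
    t9 = torch.neg(x_fc_1)
    t10 = torch.exp(t9)
    t11 = torch.add(t10, 1.0)
    t12 = torch.reciprocal(t11)
    a = torch.mul(x_fc_1, t12)
    result = torch.mul(a, x_fc_2)
    return F.linear(result, proj_weight)


# Small CPU stand-in for LLaMAMLP(4096, 11008), same 2.6875 width ratio
torch.manual_seed(0)
fc_1 = torch.nn.Linear(64, 172, bias=False)
fc_2 = torch.nn.Linear(64, 172, bias=False)
proj = torch.nn.Linear(172, 64, bias=False)
x = torch.randn(2, 16, 64)

with torch.no_grad():
    expected = proj(F.silu(fc_1(x)) * fc_2(x))  # the module's forward
    actual = computation(x, fc_1.weight, fc_2.weight, proj.weight)

torch.testing.assert_close(actual, expected)  # raises if they differ beyond float32 tolerance
print("max deviation:", (actual - expected).abs().max().item())
```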

You can actually see the whole series of traces: last_traces gives you a list of transformed traces in chronological order. For example, the initial trace, thunder.last_traces(thunder_model)[0], does not have the fusion yet.
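Each step in that list comes from one transformation. The header of the final trace says it was "Constructed by Delete Last Used", and its del statements free every tensor right after its last use. The toy pass below sketches that idea in pure Python on a list of (result, op, args) tuples; the function name and data layout are hypothetical and this is not Thunder's implementation, but on the MLP's four operations it places the del lines exactly where the trace does.

```python
def delete_last_used(ops, outputs):
    """Toy 'delete last used' pass.

    ops: list of (result_name, op_name, arg_names) in execution order.
    outputs: names returned by the function, which are never deleted.
    Returns source lines with a `del` right after each name's final use.
    """
    # Record the index of the last operation that reads each name
    last_use = {}
    for i, (_, _, args) in enumerate(ops):
        for name in args:
            last_use[name] = i

    lines = []
    for i, (result, op, args) in enumerate(ops):
        lines.append(f"{result} = {op}({', '.join(args)})")
        # dict.fromkeys removes duplicate args while keeping their order
        dead = [n for n in dict.fromkeys(args) if last_use[n] == i and n not in outputs]
        if dead:
            lines.append("del " + ", ".join(dead))
    lines.append("return " + ", ".join(outputs))
    return lines


mlp_ops = [
    ("x_fc_1", "linear", ["x", "t_fc_1_weight"]),
    ("x_fc_2", "linear", ["x", "t_fc_2_weight"]),
    ("result", "nvFusion0", ["x_fc_1", "x_fc_2"]),
    ("t18", "linear", ["result", "t_proj_weight"]),
]
lines = delete_last_used(mlp_ops, ["t18"])
print("\n".join(lines))
```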

Compiling a more complex model

Of course, we aim for larger models, so let us do the same with the entire Llama 2 (well, we use a smaller model here to go easy on our CI, but if you have a large GPU, just drop the line that reduces the number of layers):

NOTE: For running the cells below, we require litgpt, which can be installed with pip install 'litgpt[all] @ git+https://github.com/Lightning-AI/litgpt'. See the LitGPT repository on GitHub to learn more about litgpt.

[7]:
from litgpt import GPT
from thunder.tests.litgpt_model import Config
cfg = Config.from_name('Llama-2-7b-hf')
cfg.n_layer = 16 # fewer layers
torch.set_default_dtype(torch.bfloat16)
with torch.device('cuda'):
    m = GPT(cfg)
m

[7]:
GPT(
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(32000, 4096)
    (h): ModuleList(
      (0-15): 16 x Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=12288, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
          (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
          (proj): Linear(in_features=11008, out_features=4096, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)
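The printed module is enough to estimate what trimming the layers buys. The arithmetic below counts parameters from the Linear and Embedding shapes shown above. It assumes each RMSNorm holds a single weight vector of size 4096 and ignores buffers such as the rotary-embedding cache, so treat it as a back-of-the-envelope figure rather than an exact memory measurement.

```python
# Shapes read off the printed GPT module; RMSNorm assumed to hold one n_embd weight
n_embd, intermediate, vocab = 4096, 11008, 32000

attn = n_embd * 12288 + n_embd * n_embd  # fused attn projection + proj
mlp = 3 * n_embd * intermediate          # fc_1, fc_2, proj
norms = 2 * n_embd                       # norm_1, norm_2
per_block = attn + mlp + norms


def n_params(n_layer):
    # n_layer identical blocks + wte embedding + lm_head + final ln_f
    return n_layer * per_block + 2 * vocab * n_embd + n_embd


for n_layer in (16, 32):
    p = n_params(n_layer)
    print(f"{n_layer} layers: {p:,} parameters, {2 * p / 1e9:.1f} GB in bf16")
```

With all 32 layers the count comes out at about 6.7 billion parameters, roughly 13.5 GB of bf16 weights before any activations or gradients; the 16-layer version used here is about half that.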

Again we jit our model and compare the output…

[8]:
thunder_model = thunder.jit(m)

inp = torch.randint(1, m.config.vocab_size, (1, 512), device="cuda")

actual = thunder_model(inp)
expected = m(inp)

print("deviation:", (actual - expected).abs().max().item())

deviation: 0.03125

One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.
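A tiny plain-PyTorch illustration (independent of Thunder) of why the order of operations matters in bf16: bf16 keeps only 7 explicit mantissa bits, so next to 1.0 the representable numbers are 2^-7 apart, and adding 2^-8 twice gives different answers depending on the grouping.

```python
import torch

one = torch.tensor(1.0, dtype=torch.bfloat16)
tiny = torch.tensor(2.0 ** -8, dtype=torch.bfloat16)  # half the bf16 spacing at 1.0

left = (one + tiny) + tiny   # each partial sum is a tie and rounds to even, i.e. back to 1.0
right = one + (tiny + tiny)  # 2**-7 is exactly one bf16 step above 1.0
print(left.item(), right.item())  # 1.0 1.0078125

# In float32 (23 explicit mantissa bits) the grouping makes no difference here
print(((one.float() + tiny.float()) + tiny.float()).item())  # 1.0078125
```

A fused kernel that reorders or regroups additions, or keeps intermediates in higher precision, can therefore land on a different bf16 value than eager PyTorch, which is consistent with the deviation of 0.03125 seen above.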

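Just like before, we can look at the program it ran, though it is a lot longer this time. Since this model's parameters require gradients (we did not switch that off here), the output is hooked into PyTorch's autograd, as its grad_fn shows, and the trace we see is an augmented forward function.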

[9]:
print(actual.grad_fn)
thunder.last_traces(thunder_model)[-1]
<torch.autograd.function.ThunderFunctionBackward object at 0x7f923f792ac0>
[9]:
# Constructed by Delete Last Used (took 10 milliseconds)
import torch
from torch import Tensor
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(*args):
  # args: "Collection"
  t0, \
  t1, \
  t2, \
  t3, \
  t4, \
  t5, \
  t6, \
  t7, \
  t8, \
  t9, \
  t10, \
  t11, \
  t12, \
  t13, \
  t14, \
  t15, \
  t16, \
  t17, \
  t18, \
  t19, \
  t20, \
  t21, \
  t22, \
  t23, \
  t24, \
  t25, \
  t26, \
  t27, \
  t28, \
  t29, \
  t30, \
  t31, \
  t32, \
  t33, \
  t34, \
  t35, \
  t36, \
  t37, \
  t38, \
  t39, \
  t40, \
  t41, \
  t42, \
  t43, \
  t44, \
  t45, \
  t46, \
  t47, \
  t48, \
  t49, \
  t50, \
  t51, \
  t52, \
  t53, \
  t54, \
  t55, \
  t56, \
  t57, \
  t58, \
  t59, \
  t60, \
  t61, \
  t62, \
  t63, \
  t64, \
  t65, \
  t66, \
  t67, \
  t68, \
  t69, \
  t70, \
  t71, \
  t72, \
  t73, \
  t74, \
  t75, \
  t76, \
  t77, \
  t78, \
  t79, \
  t80, \
  t81, \
  t82, \
  t83, \
  t84, \
  t85, \
  t86, \
  t87, \
  t88, \
  t89, \
  t90, \
  t91, \
  t92, \
  t93, \
  t94, \
  t95, \
  t96, \
  t97, \
  t98, \
  t99, \
  t100, \
  t101, \
  t102, \
  t103, \
  t104, \
  t105, \
  t106, \
  t107, \
  t108, \
  t109, \
  t110, \
  t111, \
  t112, \
  t113, \
  t114, \
  t115, \
  t116, \
  t117, \
  = args
  del args
  t122 = torch.nn.functional.embedding(t0, t117, None, None, 2.0, False, False)  # t122: "cuda:0 bf16[1, 512, 4096]"
    # t122 = ltorch.embedding(t0, t117, None, None, 2.0, False, False)  # t122: "cuda:0 bf16[1, 512, 4096]"
      # t1867 = ltorch.reshape(t0, [512])  # t1867: "cuda:0 i64[512]"
        # t1867 = prims.reshape(t0, (512,))  # t1867: "cuda:0 i64[512]"
      # t1868 = prims.take(t117, t1867, 0)  # t1868: "cuda:0 bf16[512, 4096]"
      # t122 = ltorch.reshape(t1868, [1, 512, 4096])  # t122: "cuda:0 bf16[1, 512, 4096]"
        # t122 = prims.reshape(t1868, (1, 512, 4096))  # t122: "cuda:0 bf16[1, 512, 4096]"
  t118 = torch_slice_prim_impl(t1, [0, 0], [512, 128], [1, 1])  # t118: "cuda:0 f32[512, 128]"
  t119 = torch_slice_prim_impl(t2, [0, 0], [512, 128], [1, 1])  # t119: "cuda:0 f32[512, 128]"
  t2015 = torch.unsqueeze(t53, 0)  # t2015: "cuda:0 bf16[1, 4096]"
    # t2015 = ltorch.unsqueeze(t53, 0)  # t2015: "cuda:0 bf16[1, 4096]"
      # t2015 = prims.broadcast_in_dim(t53, [1, 4096], [1])  # t2015: "cuda:0 bf16[1, 4096]"
  t2016 = torch.unsqueeze(t2015, 1)  # t2016: "cuda:0 bf16[1, 1, 4096]"
    # t2016 = ltorch.unsqueeze(t2015, 1)  # t2016: "cuda:0 bf16[1, 1, 4096]"
      # t2016 = prims.broadcast_in_dim(t2015, [1, 1, 4096], [0, 2])  # t2016: "cuda:0 bf16[1, 1, 4096]"
  del t2015
  t133 = Tensor.expand(t2016, (1, 512, 4096))  # t133: "cuda:0 bf16[1, 512, 4096]"
    # t133 = ltorch.expand(t2016, (1, 512, 4096))  # t133: "cuda:0 bf16[1, 512, 4096]"
      # t133 = prims.broadcast_in_dim(t2016, (1, 512, 4096), (0, 1, 2))  # t133: "cuda:0 bf16[1, 512, 4096]"
  del t2016
  t2356 = torch.unsqueeze(t82, 0)  # t2356: "cuda:0 bf16[1, 4096]"
    # t2356 = ltorch.unsqueeze(t82, 0)  # t2356: "cuda:0 bf16[1, 4096]"
      # t2356 = prims.broadcast_in_dim(t82, [1, 4096], [1])  # t2356: "cuda:0 bf16[1, 4096]"
  t2357 = torch.unsqueeze(t2356, 1)  # t2357: "cuda:0 bf16[1, 1, 4096]"
    # t2357 = ltorch.unsqueeze(t2356, 1)  # t2357: "cuda:0 bf16[1, 1, 4096]"
      # t2357 = prims.broadcast_in_dim(t2356, [1, 1, 4096], [0, 2])  # t2357: "cuda:0 bf16[1, 1, 4096]"
  del t2356
  t1609 = Tensor.expand(t2357, (1, 512, 4096))  # t1609: "cuda:0 bf16[1, 512, 4096]"
    # t1609 = ltorch.expand(t2357, (1, 512, 4096))  # t1609: "cuda:0 bf16[1, 512, 4096]"
      # t1609 = prims.broadcast_in_dim(t2357, (1, 512, 4096), (0, 1, 2))  # t1609: "cuda:0 bf16[1, 512, 4096]"
  del t2357
  t2359 = torch.unsqueeze(t58, 0)  # t2359: "cuda:0 bf16[1, 4096]"
    # t2359 = ltorch.unsqueeze(t58, 0)  # t2359: "cuda:0 bf16[1, 4096]"
      # t2359 = prims.broadcast_in_dim(t58, [1, 4096], [1])  # t2359: "cuda:0 bf16[1, 4096]"
  t2360 = torch.unsqueeze(t2359, 1)  # t2360: "cuda:0 bf16[1, 1, 4096]"
    # t2360 = ltorch.unsqueeze(t2359, 1)  # t2360: "cuda:0 bf16[1, 1, 4096]"
      # t2360 = prims.broadcast_in_dim(t2359, [1, 1, 4096], [0, 2])  # t2360: "cuda:0 bf16[1, 1, 4096]"
  del t2359
  t1645 = Tensor.expand(t2360, (1, 512, 4096))  # t1645: "cuda:0 bf16[1, 512, 4096]"
    # t1645 = ltorch.expand(t2360, (1, 512, 4096))  # t1645: "cuda:0 bf16[1, 512, 4096]"
      # t1645 = prims.broadcast_in_dim(t2360, (1, 512, 4096), (0, 1, 2))  # t1645: "cuda:0 bf16[1, 512, 4096]"
  del t2360
  t2044 = torch.unsqueeze(t69, 0)  # t2044: "cuda:0 bf16[1, 4096]"
    # t2044 = ltorch.unsqueeze(t69, 0)  # t2044: "cuda:0 bf16[1, 4096]"
      # t2044 = prims.broadcast_in_dim(t69, [1, 4096], [1])  # t2044: "cuda:0 bf16[1, 4096]"
  t2045 = torch.unsqueeze(t2044, 1)  # t2045: "cuda:0 bf16[1, 1, 4096]"
    # t2045 = ltorch.unsqueeze(t2044, 1)  # t2045: "cuda:0 bf16[1, 1, 4096]"
      # t2045 = prims.broadcast_in_dim(t2044, [1, 1, 4096], [0, 2])  # t2045: "cuda:0 bf16[1, 1, 4096]"
  del t2044
  t205 = Tensor.expand(t2045, (1, 512, 4096))  # t205: "cuda:0 bf16[1, 512, 4096]"
    # t205 = ltorch.expand(t2045, (1, 512, 4096))  # t205: "cuda:0 bf16[1, 512, 4096]"
      # t205 = prims.broadcast_in_dim(t2045, (1, 512, 4096), (0, 1, 2))  # t205: "cuda:0 bf16[1, 512, 4096]"
  del t2045
  t2380 = torch.unsqueeze(t83, 0)  # t2380: "cuda:0 bf16[1, 4096]"
    # t2380 = ltorch.unsqueeze(t83, 0)  # t2380: "cuda:0 bf16[1, 4096]"
      # t2380 = prims.broadcast_in_dim(t83, [1, 4096], [1])  # t2380: "cuda:0 bf16[1, 4096]"
  t2381 = torch.unsqueeze(t2380, 1)  # t2381: "cuda:0 bf16[1, 1, 4096]"
    # t2381 = ltorch.unsqueeze(t2380, 1)  # t2381: "cuda:0 bf16[1, 1, 4096]"
      # t2381 = prims.broadcast_in_dim(t2380, [1, 1, 4096], [0, 2])  # t2381: "cuda:0 bf16[1, 1, 4096]"
  del t2380
  t1717 = Tensor.expand(t2381, (1, 512, 4096))  # t1717: "cuda:0 bf16[1, 512, 4096]"
    # t1717 = ltorch.expand(t2381, (1, 512, 4096))  # t1717: "cuda:0 bf16[1, 512, 4096]"
      # t1717 = prims.broadcast_in_dim(t2381, (1, 512, 4096), (0, 1, 2))  # t1717: "cuda:0 bf16[1, 512, 4096]"
  del t2381
  t2047 = torch.unsqueeze(t60, 0)  # t2047: "cuda:0 bf16[1, 4096]"
    # t2047 = ltorch.unsqueeze(t60, 0)  # t2047: "cuda:0 bf16[1, 4096]"
      # t2047 = prims.broadcast_in_dim(t60, [1, 4096], [1])  # t2047: "cuda:0 bf16[1, 4096]"
  t2048 = torch.unsqueeze(t2047, 1)  # t2048: "cuda:0 bf16[1, 1, 4096]"
    # t2048 = ltorch.unsqueeze(t2047, 1)  # t2048: "cuda:0 bf16[1, 1, 4096]"
      # t2048 = prims.broadcast_in_dim(t2047, [1, 1, 4096], [0, 2])  # t2048: "cuda:0 bf16[1, 1, 4096]"
  del t2047
  t241 = Tensor.expand(t2048, (1, 512, 4096))  # t241: "cuda:0 bf16[1, 512, 4096]"
    # t241 = ltorch.expand(t2048, (1, 512, 4096))  # t241: "cuda:0 bf16[1, 512, 4096]"
      # t241 = prims.broadcast_in_dim(t2048, (1, 512, 4096), (0, 1, 2))  # t241: "cuda:0 bf16[1, 512, 4096]"
  del t2048
  t2383 = torch.unsqueeze(t59, 0)  # t2383: "cuda:0 bf16[1, 4096]"
    # t2383 = ltorch.unsqueeze(t59, 0)  # t2383: "cuda:0 bf16[1, 4096]"
      # t2383 = prims.broadcast_in_dim(t59, [1, 4096], [1])  # t2383: "cuda:0 bf16[1, 4096]"
  t2384 = torch.unsqueeze(t2383, 1)  # t2384: "cuda:0 bf16[1, 1, 4096]"
    # t2384 = ltorch.unsqueeze(t2383, 1)  # t2384: "cuda:0 bf16[1, 1, 4096]"
      # t2384 = prims.broadcast_in_dim(t2383, [1, 1, 4096], [0, 2])  # t2384: "cuda:0 bf16[1, 1, 4096]"
  del t2383
  t1753 = Tensor.expand(t2384, (1, 512, 4096))  # t1753: "cuda:0 bf16[1, 512, 4096]"
    # t1753 = ltorch.expand(t2384, (1, 512, 4096))  # t1753: "cuda:0 bf16[1, 512, 4096]"
      # t1753 = prims.broadcast_in_dim(t2384, (1, 512, 4096), (0, 1, 2))  # t1753: "cuda:0 bf16[1, 512, 4096]"
  del t2384
  t2068 = torch.unsqueeze(t70, 0)  # t2068: "cuda:0 bf16[1, 4096]"
    # t2068 = ltorch.unsqueeze(t70, 0)  # t2068: "cuda:0 bf16[1, 4096]"
      # t2068 = prims.broadcast_in_dim(t70, [1, 4096], [1])  # t2068: "cuda:0 bf16[1, 4096]"
  t2069 = torch.unsqueeze(t2068, 1)  # t2069: "cuda:0 bf16[1, 1, 4096]"
    # t2069 = ltorch.unsqueeze(t2068, 1)  # t2069: "cuda:0 bf16[1, 1, 4096]"
      # t2069 = prims.broadcast_in_dim(t2068, [1, 1, 4096], [0, 2])  # t2069: "cuda:0 bf16[1, 1, 4096]"
  del t2068
  t313 = Tensor.expand(t2069, (1, 512, 4096))  # t313: "cuda:0 bf16[1, 512, 4096]"
    # t313 = ltorch.expand(t2069, (1, 512, 4096))  # t313: "cuda:0 bf16[1, 512, 4096]"
      # t313 = prims.broadcast_in_dim(t2069, (1, 512, 4096), (0, 1, 2))  # t313: "cuda:0 bf16[1, 512, 4096]"
  del t2069
  t2404 = torch.unsqueeze(t84, 0)  # t2404: "cuda:0 bf16[1, 4096]"
    # t2404 = ltorch.unsqueeze(t84, 0)  # t2404: "cuda:0 bf16[1, 4096]"
      # t2404 = prims.broadcast_in_dim(t84, [1, 4096], [1])  # t2404: "cuda:0 bf16[1, 4096]"
  t2405 = torch.unsqueeze(t2404, 1)  # t2405: "cuda:0 bf16[1, 1, 4096]"
    # t2405 = ltorch.unsqueeze(t2404, 1)  # t2405: "cuda:0 bf16[1, 1, 4096]"
      # t2405 = prims.broadcast_in_dim(t2404, [1, 1, 4096], [0, 2])  # t2405: "cuda:0 bf16[1, 1, 4096]"
  del t2404
  t1825 = Tensor.expand(t2405, (1, 512, 4096))  # t1825: "cuda:0 bf16[1, 512, 4096]"
    # t1825 = ltorch.expand(t2405, (1, 512, 4096))  # t1825: "cuda:0 bf16[1, 512, 4096]"
      # t1825 = prims.broadcast_in_dim(t2405, (1, 512, 4096), (0, 1, 2))  # t1825: "cuda:0 bf16[1, 512, 4096]"
  del t2405
  t2071 = torch.unsqueeze(t61, 0)  # t2071: "cuda:0 bf16[1, 4096]"
    # t2071 = ltorch.unsqueeze(t61, 0)  # t2071: "cuda:0 bf16[1, 4096]"
      # t2071 = prims.broadcast_in_dim(t61, [1, 4096], [1])  # t2071: "cuda:0 bf16[1, 4096]"
  t2072 = torch.unsqueeze(t2071, 1)  # t2072: "cuda:0 bf16[1, 1, 4096]"
    # t2072 = ltorch.unsqueeze(t2071, 1)  # t2072: "cuda:0 bf16[1, 1, 4096]"
      # t2072 = prims.broadcast_in_dim(t2071, [1, 1, 4096], [0, 2])  # t2072: "cuda:0 bf16[1, 1, 4096]"
  del t2071
  t349 = Tensor.expand(t2072, (1, 512, 4096))  # t349: "cuda:0 bf16[1, 512, 4096]"
    # t349 = ltorch.expand(t2072, (1, 512, 4096))  # t349: "cuda:0 bf16[1, 512, 4096]"
      # t349 = prims.broadcast_in_dim(t2072, (1, 512, 4096), (0, 1, 2))  # t349: "cuda:0 bf16[1, 512, 4096]"
  del t2072
  t2407 = torch.unsqueeze(t52, 0)  # t2407: "cuda:0 bf16[1, 4096]"
    # t2407 = ltorch.unsqueeze(t52, 0)  # t2407: "cuda:0 bf16[1, 4096]"
      # t2407 = prims.broadcast_in_dim(t52, [1, 4096], [1])  # t2407: "cuda:0 bf16[1, 4096]"
  t2408 = torch.unsqueeze(t2407, 1)  # t2408: "cuda:0 bf16[1, 1, 4096]"
    # t2408 = ltorch.unsqueeze(t2407, 1)  # t2408: "cuda:0 bf16[1, 1, 4096]"
      # t2408 = prims.broadcast_in_dim(t2407, [1, 1, 4096], [0, 2])  # t2408: "cuda:0 bf16[1, 1, 4096]"
  del t2407
  t1861 = Tensor.expand(t2408, (1, 512, 4096))  # t1861: "cuda:0 bf16[1, 512, 4096]"
    # t1861 = ltorch.expand(t2408, (1, 512, 4096))  # t1861: "cuda:0 bf16[1, 512, 4096]"
      # t1861 = prims.broadcast_in_dim(t2408, (1, 512, 4096), (0, 1, 2))  # t1861: "cuda:0 bf16[1, 512, 4096]"
  del t2408
  t2095 = torch.unsqueeze(t62, 0)  # t2095: "cuda:0 bf16[1, 4096]"
    # t2095 = ltorch.unsqueeze(t62, 0)  # t2095: "cuda:0 bf16[1, 4096]"
      # t2095 = prims.broadcast_in_dim(t62, [1, 4096], [1])  # t2095: "cuda:0 bf16[1, 4096]"
  t2096 = torch.unsqueeze(t2095, 1)  # t2096: "cuda:0 bf16[1, 1, 4096]"
    # t2096 = ltorch.unsqueeze(t2095, 1)  # t2096: "cuda:0 bf16[1, 1, 4096]"
      # t2096 = prims.broadcast_in_dim(t2095, [1, 1, 4096], [0, 2])  # t2096: "cuda:0 bf16[1, 1, 4096]"
  del t2095
  t457 = Tensor.expand(t2096, (1, 512, 4096))  # t457: "cuda:0 bf16[1, 512, 4096]"
    # t457 = ltorch.expand(t2096, (1, 512, 4096))  # t457: "cuda:0 bf16[1, 512, 4096]"
      # t457 = prims.broadcast_in_dim(t2096, (1, 512, 4096), (0, 1, 2))  # t457: "cuda:0 bf16[1, 512, 4096]"
  del t2096
  t2092 = torch.unsqueeze(t71, 0)  # t2092: "cuda:0 bf16[1, 4096]"
    # t2092 = ltorch.unsqueeze(t71, 0)  # t2092: "cuda:0 bf16[1, 4096]"
      # t2092 = prims.broadcast_in_dim(t71, [1, 4096], [1])  # t2092: "cuda:0 bf16[1, 4096]"
  t2093 = torch.unsqueeze(t2092, 1)  # t2093: "cuda:0 bf16[1, 1, 4096]"
    # t2093 = ltorch.unsqueeze(t2092, 1)  # t2093: "cuda:0 bf16[1, 1, 4096]"
      # t2093 = prims.broadcast_in_dim(t2092, [1, 1, 4096], [0, 2])  # t2093: "cuda:0 bf16[1, 1, 4096]"
  del t2092
  t421 = Tensor.expand(t2093, (1, 512, 4096))  # t421: "cuda:0 bf16[1, 512, 4096]"
    # t421 = ltorch.expand(t2093, (1, 512, 4096))  # t421: "cuda:0 bf16[1, 512, 4096]"
      # t421 = prims.broadcast_in_dim(t2093, (1, 512, 4096), (0, 1, 2))  # t421: "cuda:0 bf16[1, 512, 4096]"
  del t2093
  t2116 = torch.unsqueeze(t72, 0)  # t2116: "cuda:0 bf16[1, 4096]"
    # t2116 = ltorch.unsqueeze(t72, 0)  # t2116: "cuda:0 bf16[1, 4096]"
      # t2116 = prims.broadcast_in_dim(t72, [1, 4096], [1])  # t2116: "cuda:0 bf16[1, 4096]"
  t2117 = torch.unsqueeze(t2116, 1)  # t2117: "cuda:0 bf16[1, 1, 4096]"
    # t2117 = ltorch.unsqueeze(t2116, 1)  # t2117: "cuda:0 bf16[1, 1, 4096]"
      # t2117 = prims.broadcast_in_dim(t2116, [1, 1, 4096], [0, 2])  # t2117: "cuda:0 bf16[1, 1, 4096]"
  del t2116
  t529 = Tensor.expand(t2117, (1, 512, 4096))  # t529: "cuda:0 bf16[1, 512, 4096]"
    # t529 = ltorch.expand(t2117, (1, 512, 4096))  # t529: "cuda:0 bf16[1, 512, 4096]"
      # t529 = prims.broadcast_in_dim(t2117, (1, 512, 4096), (0, 1, 2))  # t529: "cuda:0 bf16[1, 512, 4096]"
  del t2117
  t2119 = torch.unsqueeze(t63, 0)  # t2119: "cuda:0 bf16[1, 4096]"
    # t2119 = ltorch.unsqueeze(t63, 0)  # t2119: "cuda:0 bf16[1, 4096]"
      # t2119 = prims.broadcast_in_dim(t63, [1, 4096], [1])  # t2119: "cuda:0 bf16[1, 4096]"
  t2120 = torch.unsqueeze(t2119, 1)  # t2120: "cuda:0 bf16[1, 1, 4096]"
    # t2120 = ltorch.unsqueeze(t2119, 1)  # t2120: "cuda:0 bf16[1, 1, 4096]"
      # t2120 = prims.broadcast_in_dim(t2119, [1, 1, 4096], [0, 2])  # t2120: "cuda:0 bf16[1, 1, 4096]"
  del t2119
  t565 = Tensor.expand(t2120, (1, 512, 4096))  # t565: "cuda:0 bf16[1, 512, 4096]"
    # t565 = ltorch.expand(t2120, (1, 512, 4096))  # t565: "cuda:0 bf16[1, 512, 4096]"
      # t565 = prims.broadcast_in_dim(t2120, (1, 512, 4096), (0, 1, 2))  # t565: "cuda:0 bf16[1, 512, 4096]"
  del t2120
  t2140 = torch.unsqueeze(t73, 0)  # t2140: "cuda:0 bf16[1, 4096]"
    # t2140 = ltorch.unsqueeze(t73, 0)  # t2140: "cuda:0 bf16[1, 4096]"
      # t2140 = prims.broadcast_in_dim(t73, [1, 4096], [1])  # t2140: "cuda:0 bf16[1, 4096]"
  t2141 = torch.unsqueeze(t2140, 1)  # t2141: "cuda:0 bf16[1, 1, 4096]"
    # t2141 = ltorch.unsqueeze(t2140, 1)  # t2141: "cuda:0 bf16[1, 1, 4096]"
      # t2141 = prims.broadcast_in_dim(t2140, [1, 1, 4096], [0, 2])  # t2141: "cuda:0 bf16[1, 1, 4096]"
  del t2140
  t637 = Tensor.expand(t2141, (1, 512, 4096))  # t637: "cuda:0 bf16[1, 512, 4096]"
    # t637 = ltorch.expand(t2141, (1, 512, 4096))  # t637: "cuda:0 bf16[1, 512, 4096]"
      # t637 = prims.broadcast_in_dim(t2141, (1, 512, 4096), (0, 1, 2))  # t637: "cuda:0 bf16[1, 512, 4096]"
  del t2141
  t2143 = torch.unsqueeze(t64, 0)  # t2143: "cuda:0 bf16[1, 4096]"
    # t2143 = ltorch.unsqueeze(t64, 0)  # t2143: "cuda:0 bf16[1, 4096]"
      # t2143 = prims.broadcast_in_dim(t64, [1, 4096], [1])  # t2143: "cuda:0 bf16[1, 4096]"
  t2144 = torch.unsqueeze(t2143, 1)  # t2144: "cuda:0 bf16[1, 1, 4096]"
    # t2144 = ltorch.unsqueeze(t2143, 1)  # t2144: "cuda:0 bf16[1, 1, 4096]"
      # t2144 = prims.broadcast_in_dim(t2143, [1, 1, 4096], [0, 2])  # t2144: "cuda:0 bf16[1, 1, 4096]"
  del t2143
  t673 = Tensor.expand(t2144, (1, 512, 4096))  # t673: "cuda:0 bf16[1, 512, 4096]"
    # t673 = ltorch.expand(t2144, (1, 512, 4096))  # t673: "cuda:0 bf16[1, 512, 4096]"
      # t673 = prims.broadcast_in_dim(t2144, (1, 512, 4096), (0, 1, 2))  # t673: "cuda:0 bf16[1, 512, 4096]"
  del t2144
  t2164 = torch.unsqueeze(t74, 0)  # t2164: "cuda:0 bf16[1, 4096]"
    # t2164 = ltorch.unsqueeze(t74, 0)  # t2164: "cuda:0 bf16[1, 4096]"
      # t2164 = prims.broadcast_in_dim(t74, [1, 4096], [1])  # t2164: "cuda:0 bf16[1, 4096]"
  t2165 = torch.unsqueeze(t2164, 1)  # t2165: "cuda:0 bf16[1, 1, 4096]"
    # t2165 = ltorch.unsqueeze(t2164, 1)  # t2165: "cuda:0 bf16[1, 1, 4096]"
      # t2165 = prims.broadcast_in_dim(t2164, [1, 1, 4096], [0, 2])  # t2165: "cuda:0 bf16[1, 1, 4096]"
  del t2164
  t745 = Tensor.expand(t2165, (1, 512, 4096))  # t745: "cuda:0 bf16[1, 512, 4096]"
    # t745 = ltorch.expand(t2165, (1, 512, 4096))  # t745: "cuda:0 bf16[1, 512, 4096]"
      # t745 = prims.broadcast_in_dim(t2165, (1, 512, 4096), (0, 1, 2))  # t745: "cuda:0 bf16[1, 512, 4096]"
  del t2165
  t2167 = torch.unsqueeze(t65, 0)  # t2167: "cuda:0 bf16[1, 4096]"
    # t2167 = ltorch.unsqueeze(t65, 0)  # t2167: "cuda:0 bf16[1, 4096]"
      # t2167 = prims.broadcast_in_dim(t65, [1, 4096], [1])  # t2167: "cuda:0 bf16[1, 4096]"
  t2168 = torch.unsqueeze(t2167, 1)  # t2168: "cuda:0 bf16[1, 1, 4096]"
    # t2168 = ltorch.unsqueeze(t2167, 1)  # t2168: "cuda:0 bf16[1, 1, 4096]"
      # t2168 = prims.broadcast_in_dim(t2167, [1, 1, 4096], [0, 2])  # t2168: "cuda:0 bf16[1, 1, 4096]"
  del t2167
  t781 = Tensor.expand(t2168, (1, 512, 4096))  # t781: "cuda:0 bf16[1, 512, 4096]"
    # t781 = ltorch.expand(t2168, (1, 512, 4096))  # t781: "cuda:0 bf16[1, 512, 4096]"
      # t781 = prims.broadcast_in_dim(t2168, (1, 512, 4096), (0, 1, 2))  # t781: "cuda:0 bf16[1, 512, 4096]"
  del t2168
  t2188 = torch.unsqueeze(t75, 0)  # t2188: "cuda:0 bf16[1, 4096]"
    # t2188 = ltorch.unsqueeze(t75, 0)  # t2188: "cuda:0 bf16[1, 4096]"
      # t2188 = prims.broadcast_in_dim(t75, [1, 4096], [1])  # t2188: "cuda:0 bf16[1, 4096]"
  t2189 = torch.unsqueeze(t2188, 1)  # t2189: "cuda:0 bf16[1, 1, 4096]"
    # t2189 = ltorch.unsqueeze(t2188, 1)  # t2189: "cuda:0 bf16[1, 1, 4096]"
      # t2189 = prims.broadcast_in_dim(t2188, [1, 1, 4096], [0, 2])  # t2189: "cuda:0 bf16[1, 1, 4096]"
  del t2188
  t853 = Tensor.expand(t2189, (1, 512, 4096))  # t853: "cuda:0 bf16[1, 512, 4096]"
    # t853 = ltorch.expand(t2189, (1, 512, 4096))  # t853: "cuda:0 bf16[1, 512, 4096]"
      # t853 = prims.broadcast_in_dim(t2189, (1, 512, 4096), (0, 1, 2))  # t853: "cuda:0 bf16[1, 512, 4096]"
  del t2189
  t2191 = torch.unsqueeze(t66, 0)  # t2191: "cuda:0 bf16[1, 4096]"
    # t2191 = ltorch.unsqueeze(t66, 0)  # t2191: "cuda:0 bf16[1, 4096]"
      # t2191 = prims.broadcast_in_dim(t66, [1, 4096], [1])  # t2191: "cuda:0 bf16[1, 4096]"
  t2192 = torch.unsqueeze(t2191, 1)  # t2192: "cuda:0 bf16[1, 1, 4096]"
    # t2192 = ltorch.unsqueeze(t2191, 1)  # t2192: "cuda:0 bf16[1, 1, 4096]"
      # t2192 = prims.broadcast_in_dim(t2191, [1, 1, 4096], [0, 2])  # t2192: "cuda:0 bf16[1, 1, 4096]"
  del t2191
  t889 = Tensor.expand(t2192, (1, 512, 4096))  # t889: "cuda:0 bf16[1, 512, 4096]"
    # t889 = ltorch.expand(t2192, (1, 512, 4096))  # t889: "cuda:0 bf16[1, 512, 4096]"
      # t889 = prims.broadcast_in_dim(t2192, (1, 512, 4096), (0, 1, 2))  # t889: "cuda:0 bf16[1, 512, 4096]"
  del t2192
  t2212 = torch.unsqueeze(t76, 0)  # t2212: "cuda:0 bf16[1, 4096]"
    # t2212 = ltorch.unsqueeze(t76, 0)  # t2212: "cuda:0 bf16[1, 4096]"
      # t2212 = prims.broadcast_in_dim(t76, [1, 4096], [1])  # t2212: "cuda:0 bf16[1, 4096]"
  t2213 = torch.unsqueeze(t2212, 1)  # t2213: "cuda:0 bf16[1, 1, 4096]"
    # t2213 = ltorch.unsqueeze(t2212, 1)  # t2213: "cuda:0 bf16[1, 1, 4096]"
      # t2213 = prims.broadcast_in_dim(t2212, [1, 1, 4096], [0, 2])  # t2213: "cuda:0 bf16[1, 1, 4096]"
  del t2212
  t961 = Tensor.expand(t2213, (1, 512, 4096))  # t961: "cuda:0 bf16[1, 512, 4096]"
    # t961 = ltorch.expand(t2213, (1, 512, 4096))  # t961: "cuda:0 bf16[1, 512, 4096]"
      # t961 = prims.broadcast_in_dim(t2213, (1, 512, 4096), (0, 1, 2))  # t961: "cuda:0 bf16[1, 512, 4096]"
  del t2213
  t2215 = torch.unsqueeze(t67, 0)  # t2215: "cuda:0 bf16[1, 4096]"
    # t2215 = ltorch.unsqueeze(t67, 0)  # t2215: "cuda:0 bf16[1, 4096]"
      # t2215 = prims.broadcast_in_dim(t67, [1, 4096], [1])  # t2215: "cuda:0 bf16[1, 4096]"
  t2216 = torch.unsqueeze(t2215, 1)  # t2216: "cuda:0 bf16[1, 1, 4096]"
    # t2216 = ltorch.unsqueeze(t2215, 1)  # t2216: "cuda:0 bf16[1, 1, 4096]"
      # t2216 = prims.broadcast_in_dim(t2215, [1, 1, 4096], [0, 2])  # t2216: "cuda:0 bf16[1, 1, 4096]"
  del t2215
  t997 = Tensor.expand(t2216, (1, 512, 4096))  # t997: "cuda:0 bf16[1, 512, 4096]"
    # t997 = ltorch.expand(t2216, (1, 512, 4096))  # t997: "cuda:0 bf16[1, 512, 4096]"
      # t997 = prims.broadcast_in_dim(t2216, (1, 512, 4096), (0, 1, 2))  # t997: "cuda:0 bf16[1, 512, 4096]"
  del t2216
  t2236 = torch.unsqueeze(t77, 0)  # t2236: "cuda:0 bf16[1, 4096]"
    # t2236 = ltorch.unsqueeze(t77, 0)  # t2236: "cuda:0 bf16[1, 4096]"
      # t2236 = prims.broadcast_in_dim(t77, [1, 4096], [1])  # t2236: "cuda:0 bf16[1, 4096]"
  t2237 = torch.unsqueeze(t2236, 1)  # t2237: "cuda:0 bf16[1, 1, 4096]"
    # t2237 = ltorch.unsqueeze(t2236, 1)  # t2237: "cuda:0 bf16[1, 1, 4096]"
      # t2237 = prims.broadcast_in_dim(t2236, [1, 1, 4096], [0, 2])  # t2237: "cuda:0 bf16[1, 1, 4096]"
  del t2236
  t1069 = Tensor.expand(t2237, (1, 512, 4096))  # t1069: "cuda:0 bf16[1, 512, 4096]"
    # t1069 = ltorch.expand(t2237, (1, 512, 4096))  # t1069: "cuda:0 bf16[1, 512, 4096]"
      # t1069 = prims.broadcast_in_dim(t2237, (1, 512, 4096), (0, 1, 2))  # t1069: "cuda:0 bf16[1, 512, 4096]"
  del t2237
  t2239 = torch.unsqueeze(t68, 0)  # t2239: "cuda:0 bf16[1, 4096]"
    # t2239 = ltorch.unsqueeze(t68, 0)  # t2239: "cuda:0 bf16[1, 4096]"
      # t2239 = prims.broadcast_in_dim(t68, [1, 4096], [1])  # t2239: "cuda:0 bf16[1, 4096]"
  t2240 = torch.unsqueeze(t2239, 1)  # t2240: "cuda:0 bf16[1, 1, 4096]"
    # t2240 = ltorch.unsqueeze(t2239, 1)  # t2240: "cuda:0 bf16[1, 1, 4096]"
      # t2240 = prims.broadcast_in_dim(t2239, [1, 1, 4096], [0, 2])  # t2240: "cuda:0 bf16[1, 1, 4096]"
  del t2239
  t1105 = Tensor.expand(t2240, (1, 512, 4096))  # t1105: "cuda:0 bf16[1, 512, 4096]"
    # t1105 = ltorch.expand(t2240, (1, 512, 4096))  # t1105: "cuda:0 bf16[1, 512, 4096]"
      # t1105 = prims.broadcast_in_dim(t2240, (1, 512, 4096), (0, 1, 2))  # t1105: "cuda:0 bf16[1, 512, 4096]"
  del t2240
  t2260 = torch.unsqueeze(t78, 0)  # t2260: "cuda:0 bf16[1, 4096]"
    # t2260 = ltorch.unsqueeze(t78, 0)  # t2260: "cuda:0 bf16[1, 4096]"
      # t2260 = prims.broadcast_in_dim(t78, [1, 4096], [1])  # t2260: "cuda:0 bf16[1, 4096]"
  t2261 = torch.unsqueeze(t2260, 1)  # t2261: "cuda:0 bf16[1, 1, 4096]"
    # t2261 = ltorch.unsqueeze(t2260, 1)  # t2261: "cuda:0 bf16[1, 1, 4096]"
      # t2261 = prims.broadcast_in_dim(t2260, [1, 1, 4096], [0, 2])  # t2261: "cuda:0 bf16[1, 1, 4096]"
  del t2260
  t1177 = Tensor.expand(t2261, (1, 512, 4096))  # t1177: "cuda:0 bf16[1, 512, 4096]"
    # t1177 = ltorch.expand(t2261, (1, 512, 4096))  # t1177: "cuda:0 bf16[1, 512, 4096]"
      # t1177 = prims.broadcast_in_dim(t2261, (1, 512, 4096), (0, 1, 2))  # t1177: "cuda:0 bf16[1, 512, 4096]"
  del t2261
  t2263 = torch.unsqueeze(t54, 0)  # t2263: "cuda:0 bf16[1, 4096]"
    # t2263 = ltorch.unsqueeze(t54, 0)  # t2263: "cuda:0 bf16[1, 4096]"
      # t2263 = prims.broadcast_in_dim(t54, [1, 4096], [1])  # t2263: "cuda:0 bf16[1, 4096]"
  t2264 = torch.unsqueeze(t2263, 1)  # t2264: "cuda:0 bf16[1, 1, 4096]"
    # t2264 = ltorch.unsqueeze(t2263, 1)  # t2264: "cuda:0 bf16[1, 1, 4096]"
      # t2264 = prims.broadcast_in_dim(t2263, [1, 1, 4096], [0, 2])  # t2264: "cuda:0 bf16[1, 1, 4096]"
  del t2263
  t1213 = Tensor.expand(t2264, (1, 512, 4096))  # t1213: "cuda:0 bf16[1, 512, 4096]"
    # t1213 = ltorch.expand(t2264, (1, 512, 4096))  # t1213: "cuda:0 bf16[1, 512, 4096]"
      # t1213 = prims.broadcast_in_dim(t2264, (1, 512, 4096), (0, 1, 2))  # t1213: "cuda:0 bf16[1, 512, 4096]"
  del t2264
  t2284 = torch.unsqueeze(t79, 0)  # t2284: "cuda:0 bf16[1, 4096]"
    # t2284 = ltorch.unsqueeze(t79, 0)  # t2284: "cuda:0 bf16[1, 4096]"
      # t2284 = prims.broadcast_in_dim(t79, [1, 4096], [1])  # t2284: "cuda:0 bf16[1, 4096]"
  t2285 = torch.unsqueeze(t2284, 1)  # t2285: "cuda:0 bf16[1, 1, 4096]"
    # t2285 = ltorch.unsqueeze(t2284, 1)  # t2285: "cuda:0 bf16[1, 1, 4096]"
      # t2285 = prims.broadcast_in_dim(t2284, [1, 1, 4096], [0, 2])  # t2285: "cuda:0 bf16[1, 1, 4096]"
  del t2284
  t1285 = Tensor.expand(t2285, (1, 512, 4096))  # t1285: "cuda:0 bf16[1, 512, 4096]"
    # t1285 = ltorch.expand(t2285, (1, 512, 4096))  # t1285: "cuda:0 bf16[1, 512, 4096]"
      # t1285 = prims.broadcast_in_dim(t2285, (1, 512, 4096), (0, 1, 2))  # t1285: "cuda:0 bf16[1, 512, 4096]"
  del t2285
  t2287 = torch.unsqueeze(t55, 0)  # t2287: "cuda:0 bf16[1, 4096]"
    # t2287 = ltorch.unsqueeze(t55, 0)  # t2287: "cuda:0 bf16[1, 4096]"
      # t2287 = prims.broadcast_in_dim(t55, [1, 4096], [1])  # t2287: "cuda:0 bf16[1, 4096]"
  t2288 = torch.unsqueeze(t2287, 1)  # t2288: "cuda:0 bf16[1, 1, 4096]"
    # t2288 = ltorch.unsqueeze(t2287, 1)  # t2288: "cuda:0 bf16[1, 1, 4096]"
      # t2288 = prims.broadcast_in_dim(t2287, [1, 1, 4096], [0, 2])  # t2288: "cuda:0 bf16[1, 1, 4096]"
  del t2287
  t1321 = Tensor.expand(t2288, (1, 512, 4096))  # t1321: "cuda:0 bf16[1, 512, 4096]"
    # t1321 = ltorch.expand(t2288, (1, 512, 4096))  # t1321: "cuda:0 bf16[1, 512, 4096]"
      # t1321 = prims.broadcast_in_dim(t2288, (1, 512, 4096), (0, 1, 2))  # t1321: "cuda:0 bf16[1, 512, 4096]"
  del t2288
  t2308 = torch.unsqueeze(t80, 0)  # t2308: "cuda:0 bf16[1, 4096]"
    # t2308 = ltorch.unsqueeze(t80, 0)  # t2308: "cuda:0 bf16[1, 4096]"
      # t2308 = prims.broadcast_in_dim(t80, [1, 4096], [1])  # t2308: "cuda:0 bf16[1, 4096]"
  t2309 = torch.unsqueeze(t2308, 1)  # t2309: "cuda:0 bf16[1, 1, 4096]"
    # t2309 = ltorch.unsqueeze(t2308, 1)  # t2309: "cuda:0 bf16[1, 1, 4096]"
      # t2309 = prims.broadcast_in_dim(t2308, [1, 1, 4096], [0, 2])  # t2309: "cuda:0 bf16[1, 1, 4096]"
  del t2308
  t1393 = Tensor.expand(t2309, (1, 512, 4096))  # t1393: "cuda:0 bf16[1, 512, 4096]"
    # t1393 = ltorch.expand(t2309, (1, 512, 4096))  # t1393: "cuda:0 bf16[1, 512, 4096]"
      # t1393 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2))  # t1393: "cuda:0 bf16[1, 512, 4096]"
  del t2309
  t2311 = torch.unsqueeze(t56, 0)  # t2311: "cuda:0 bf16[1, 4096]"
    # t2311 = ltorch.unsqueeze(t56, 0)  # t2311: "cuda:0 bf16[1, 4096]"
      # t2311 = prims.broadcast_in_dim(t56, [1, 4096], [1])  # t2311: "cuda:0 bf16[1, 4096]"
  t2312 = torch.unsqueeze(t2311, 1)  # t2312: "cuda:0 bf16[1, 1, 4096]"
    # t2312 = ltorch.unsqueeze(t2311, 1)  # t2312: "cuda:0 bf16[1, 1, 4096]"
      # t2312 = prims.broadcast_in_dim(t2311, [1, 1, 4096], [0, 2])  # t2312: "cuda:0 bf16[1, 1, 4096]"
  del t2311
  t1429 = Tensor.expand(t2312, (1, 512, 4096))  # t1429: "cuda:0 bf16[1, 512, 4096]"
    # t1429 = ltorch.expand(t2312, (1, 512, 4096))  # t1429: "cuda:0 bf16[1, 512, 4096]"
      # t1429 = prims.broadcast_in_dim(t2312, (1, 512, 4096), (0, 1, 2))  # t1429: "cuda:0 bf16[1, 512, 4096]"
  del t2312
  t2332 = torch.unsqueeze(t81, 0)  # t2332: "cuda:0 bf16[1, 4096]"
    # t2332 = ltorch.unsqueeze(t81, 0)  # t2332: "cuda:0 bf16[1, 4096]"
      # t2332 = prims.broadcast_in_dim(t81, [1, 4096], [1])  # t2332: "cuda:0 bf16[1, 4096]"
  t2333 = torch.unsqueeze(t2332, 1)  # t2333: "cuda:0 bf16[1, 1, 4096]"
    # t2333 = ltorch.unsqueeze(t2332, 1)  # t2333: "cuda:0 bf16[1, 1, 4096]"
      # t2333 = prims.broadcast_in_dim(t2332, [1, 1, 4096], [0, 2])  # t2333: "cuda:0 bf16[1, 1, 4096]"
  del t2332
  t1501 = Tensor.expand(t2333, (1, 512, 4096))  # t1501: "cuda:0 bf16[1, 512, 4096]"
    # t1501 = ltorch.expand(t2333, (1, 512, 4096))  # t1501: "cuda:0 bf16[1, 512, 4096]"
      # t1501 = prims.broadcast_in_dim(t2333, (1, 512, 4096), (0, 1, 2))  # t1501: "cuda:0 bf16[1, 512, 4096]"
  del t2333
  t2335 = torch.unsqueeze(t57, 0)  # t2335: "cuda:0 bf16[1, 4096]"
    # t2335 = ltorch.unsqueeze(t57, 0)  # t2335: "cuda:0 bf16[1, 4096]"
      # t2335 = prims.broadcast_in_dim(t57, [1, 4096], [1])  # t2335: "cuda:0 bf16[1, 4096]"
  t2336 = torch.unsqueeze(t2335, 1)  # t2336: "cuda:0 bf16[1, 1, 4096]"
    # t2336 = ltorch.unsqueeze(t2335, 1)  # t2336: "cuda:0 bf16[1, 1, 4096]"
      # t2336 = prims.broadcast_in_dim(t2335, [1, 1, 4096], [0, 2])  # t2336: "cuda:0 bf16[1, 1, 4096]"
  del t2335
  t1537 = Tensor.expand(t2336, (1, 512, 4096))  # t1537: "cuda:0 bf16[1, 512, 4096]"
    # t1537 = ltorch.expand(t2336, (1, 512, 4096))  # t1537: "cuda:0 bf16[1, 512, 4096]"
      # t1537 = prims.broadcast_in_dim(t2336, (1, 512, 4096), (0, 1, 2))  # t1537: "cuda:0 bf16[1, 512, 4096]"
  del t2336
  t2036 = torch.unsqueeze(t118, 0)  # t2036: "cuda:0 f32[1, 512, 128]"
    # t2036 = ltorch.unsqueeze(t118, 0)  # t2036: "cuda:0 f32[1, 512, 128]"
      # t2036 = prims.broadcast_in_dim(t118, [1, 512, 128], [1, 2])  # t2036: "cuda:0 f32[1, 512, 128]"
  del t118
  t2037 = torch.unsqueeze(t2036, 1)  # t2037: "cuda:0 f32[1, 1, 512, 128]"
    # t2037 = ltorch.unsqueeze(t2036, 1)  # t2037: "cuda:0 f32[1, 1, 512, 128]"
      # t2037 = prims.broadcast_in_dim(t2036, [1, 1, 512, 128], [0, 2, 3])  # t2037: "cuda:0 f32[1, 1, 512, 128]"
  del t2036
  t154 = Tensor.expand(t2037, (1, 32, 512, 128))  # t154: "cuda:0 f32[1, 32, 512, 128]"
    # t154 = ltorch.expand(t2037, (1, 32, 512, 128))  # t154: "cuda:0 f32[1, 32, 512, 128]"
      # t154 = prims.broadcast_in_dim(t2037, (1, 32, 512, 128), (0, 1, 2, 3))  # t154: "cuda:0 f32[1, 32, 512, 128]"
  del t2037
  t2039 = torch.unsqueeze(t119, 0)  # t2039: "cuda:0 f32[1, 512, 128]"
    # t2039 = ltorch.unsqueeze(t119, 0)  # t2039: "cuda:0 f32[1, 512, 128]"
      # t2039 = prims.broadcast_in_dim(t119, [1, 512, 128], [1, 2])  # t2039: "cuda:0 f32[1, 512, 128]"
  del t119
  t2040 = torch.unsqueeze(t2039, 1)  # t2040: "cuda:0 f32[1, 1, 512, 128]"
    # t2040 = ltorch.unsqueeze(t2039, 1)  # t2040: "cuda:0 f32[1, 1, 512, 128]"
      # t2040 = prims.broadcast_in_dim(t2039, [1, 1, 512, 128], [0, 2, 3])  # t2040: "cuda:0 f32[1, 1, 512, 128]"
  del t2039
  t157 = Tensor.expand(t2040, (1, 32, 512, 128))  # t157: "cuda:0 f32[1, 32, 512, 128]"
    # t157 = ltorch.expand(t2040, (1, 32, 512, 128))  # t157: "cuda:0 f32[1, 32, 512, 128]"
      # t157 = prims.broadcast_in_dim(t2040, (1, 32, 512, 128), (0, 1, 2, 3))  # t157: "cuda:0 f32[1, 32, 512, 128]"
  del t2040
  [t129, t137] = nvFusion0(t122, t133)
    # t123 = prims.convert_element_type(t122, dtypes.float32)  # t123: "cuda:0 f32[1, 512, 4096]"
    # t124 = prims.mul(t123, t123)  # t124: "cuda:0 f32[1, 512, 4096]"
    # t125 = prims.sum(t124, (2,))  # t125: "cuda:0 f32[1, 512]"
    # t126 = prims.broadcast_in_dim(t125, [1, 512, 1], [0, 1])  # t126: "cuda:0 f32[1, 512, 1]"
    # t127 = prims.div(t126, 4096.0)  # t127: "cuda:0 f32[1, 512, 1]"
    # t128 = prims.add(t127, 1e-05)  # t128: "cuda:0 f32[1, 512, 1]"
    # t129 = prims.rsqrt(t128)  # t129: "cuda:0 f32[1, 512, 1]"
    # t130 = prims.broadcast_in_dim(t129, (1, 512, 4096), (0, 1, 2))  # t130: "cuda:0 f32[1, 512, 4096]"
    # t131 = prims.mul(t123, t130)  # t131: "cuda:0 f32[1, 512, 4096]"
    # t135 = prims.convert_element_type(t133, dtypes.float32)  # t135: "cuda:0 f32[1, 512, 4096]"
    # t136 = prims.mul(t131, t135)  # t136: "cuda:0 f32[1, 512, 4096]"
    # t137 = prims.convert_element_type(t136, dtypes.bfloat16)  # t137: "cuda:0 bf16[1, 512, 4096]"
  t138 = torch.nn.functional.linear(t137, t3, None)  # t138: "cuda:0 bf16[1, 512, 12288]"
    # t138 = ltorch.linear(t137, t3, None)  # t138: "cuda:0 bf16[1, 512, 12288]"
      # t138 = prims.linear(t137, t3, None)  # t138: "cuda:0 bf16[1, 512, 12288]"
  t139 = torch.reshape(t138, (1, 512, 32, 3, 128))  # t139: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t139 = ltorch.reshape(t138, (1, 512, 32, 3, 128))  # t139: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t139 = prims.reshape(t138, (1, 512, 32, 3, 128))  # t139: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t138
  t140 = torch.permute(t139, (0, 2, 3, 1, 4))  # t140: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t140 = ltorch.permute(t139, (0, 2, 3, 1, 4))  # t140: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t140 = prims.transpose(t139, (0, 2, 3, 1, 4))  # t140: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t139
  (t141, t142, t143) = torch.split(t140, (1, 1, 1), 2)
    # (t141, t142, t143) = ltorch.split(t140, (1, 1, 1), 2)
      # t141 = prims.slice_prim(t140, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t141: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t142 = prims.slice_prim(t140, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t142: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t143 = prims.slice_prim(t140, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t143: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t140
  t144 = torch.reshape(t141, (1, 32, 512, 128))  # t144: "cuda:0 bf16[1, 32, 512, 128]"
    # t144 = ltorch.reshape(t141, (1, 32, 512, 128))  # t144: "cuda:0 bf16[1, 32, 512, 128]"
      # t144 = prims.reshape(t141, (1, 32, 512, 128))  # t144: "cuda:0 bf16[1, 32, 512, 128]"
  del t141
  t145 = torch.reshape(t142, (1, 32, 512, 128))  # t145: "cuda:0 bf16[1, 32, 512, 128]"
    # t145 = ltorch.reshape(t142, (1, 32, 512, 128))  # t145: "cuda:0 bf16[1, 32, 512, 128]"
      # t145 = prims.reshape(t142, (1, 32, 512, 128))  # t145: "cuda:0 bf16[1, 32, 512, 128]"
  del t142
  t146 = torch.reshape(t143, (1, 32, 512, 128))  # t146: "cuda:0 bf16[1, 32, 512, 128]"
    # t146 = ltorch.reshape(t143, (1, 32, 512, 128))  # t146: "cuda:0 bf16[1, 32, 512, 128]"
      # t146 = prims.reshape(t143, (1, 32, 512, 128))  # t146: "cuda:0 bf16[1, 32, 512, 128]"
  del t143
  t147 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t147: "cuda:0 bf16[1, 32, 512, 128]"
  t162 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t162: "cuda:0 bf16[1, 32, 512, 128]"
  t177 = torch_slice_prim_impl(t144, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t177: "cuda:0 bf16[1, 32, 512, 0]"
  del t144
  t179 = torch_slice_prim_impl(t145, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t179: "cuda:0 bf16[1, 32, 512, 0]"
  del t145
  t149 = torch_slice_prim_impl(t147, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t149: "cuda:0 bf16[1, 32, 512, 64]"
  t148 = torch_slice_prim_impl(t147, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t148: "cuda:0 bf16[1, 32, 512, 64]"
  t163 = torch_slice_prim_impl(t162, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t163: "cuda:0 bf16[1, 32, 512, 64]"
  t164 = torch_slice_prim_impl(t162, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t164: "cuda:0 bf16[1, 32, 512, 64]"
  [t152, t167] = nvFusion1(t147, t149, t162, t164)
    # t150 = prims.convert_element_type(t149, dtypes.float32)  # t150: "cuda:0 f32[1, 32, 512, 64]"
    # t151 = prims.neg(t150)  # t151: "cuda:0 f32[1, 32, 512, 64]"
    # t152 = prims.convert_element_type(t151, dtypes.bfloat16)  # t152: "cuda:0 bf16[1, 32, 512, 64]"
    # t165 = prims.convert_element_type(t164, dtypes.float32)  # t165: "cuda:0 f32[1, 32, 512, 64]"
    # t166 = prims.neg(t165)  # t166: "cuda:0 f32[1, 32, 512, 64]"
    # t167 = prims.convert_element_type(t166, dtypes.bfloat16)  # t167: "cuda:0 bf16[1, 32, 512, 64]"
  del t149, t164
  t168 = torch.cat((t167, t163), -1)  # t168: "cuda:0 bf16[1, 32, 512, 128]"
    # t168 = ltorch.cat((t167, t163), -1)  # t168: "cuda:0 bf16[1, 32, 512, 128]"
      # t168 = prims.cat((t167, t163), -1)  # t168: "cuda:0 bf16[1, 32, 512, 128]"
  del t167, t163
  t153 = torch.cat((t152, t148), -1)  # t153: "cuda:0 bf16[1, 32, 512, 128]"
    # t153 = ltorch.cat((t152, t148), -1)  # t153: "cuda:0 bf16[1, 32, 512, 128]"
      # t153 = prims.cat((t152, t148), -1)  # t153: "cuda:0 bf16[1, 32, 512, 128]"
  del t152, t148
  [t161, t176] = nvFusion2(t147, t153, t154, t157, t162, t168)
    # t155 = prims.convert_element_type(t147, dtypes.float32)  # t155: "cuda:0 f32[1, 32, 512, 128]"
    # t170 = prims.convert_element_type(t162, dtypes.float32)  # t170: "cuda:0 f32[1, 32, 512, 128]"
    # t156 = prims.mul(t155, t154)  # t156: "cuda:0 f32[1, 32, 512, 128]"
    # t158 = prims.convert_element_type(t153, dtypes.float32)  # t158: "cuda:0 f32[1, 32, 512, 128]"
    # t159 = prims.mul(t158, t157)  # t159: "cuda:0 f32[1, 32, 512, 128]"
    # t160 = prims.add(t156, t159)  # t160: "cuda:0 f32[1, 32, 512, 128]"
    # t161 = prims.convert_element_type(t160, dtypes.bfloat16)  # t161: "cuda:0 bf16[1, 32, 512, 128]"
    # t171 = prims.mul(t170, t154)  # t171: "cuda:0 f32[1, 32, 512, 128]"
    # t173 = prims.convert_element_type(t168, dtypes.float32)  # t173: "cuda:0 f32[1, 32, 512, 128]"
    # t174 = prims.mul(t173, t157)  # t174: "cuda:0 f32[1, 32, 512, 128]"
    # t175 = prims.add(t171, t174)  # t175: "cuda:0 f32[1, 32, 512, 128]"
    # t176 = prims.convert_element_type(t175, dtypes.bfloat16)  # t176: "cuda:0 bf16[1, 32, 512, 128]"
  del t147, t153, t162, t168
  t178 = torch.cat((t161, t177), -1)  # t178: "cuda:0 bf16[1, 32, 512, 128]"
    # t178 = ltorch.cat((t161, t177), -1)  # t178: "cuda:0 bf16[1, 32, 512, 128]"
      # t178 = prims.cat((t161, t177), -1)  # t178: "cuda:0 bf16[1, 32, 512, 128]"
  del t161, t177
  t180 = torch.cat((t176, t179), -1)  # t180: "cuda:0 bf16[1, 32, 512, 128]"
    # t180 = ltorch.cat((t176, t179), -1)  # t180: "cuda:0 bf16[1, 32, 512, 128]"
      # t180 = prims.cat((t176, t179), -1)  # t180: "cuda:0 bf16[1, 32, 512, 128]"
  del t176, t179
  (t181, t182, t183, t184, _, _, t185, t186, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t178, t180, t146, 0.0, True, scale=0.08838834764831843)
  t188 = torch.permute(t181, (0, 2, 1, 3))  # t188: "cuda:0 bf16[1, 512, 32, 128]"
    # t188 = ltorch.permute(t181, (0, 2, 1, 3))  # t188: "cuda:0 bf16[1, 512, 32, 128]"
      # t188 = prims.transpose(t181, (0, 2, 1, 3))  # t188: "cuda:0 bf16[1, 512, 32, 128]"
  t189 = torch.reshape(t188, (1, 512, 4096))  # t189: "cuda:0 bf16[1, 512, 4096]"
    # t189 = ltorch.reshape(t188, (1, 512, 4096))  # t189: "cuda:0 bf16[1, 512, 4096]"
      # t189 = prims.reshape(t188, (1, 512, 4096))  # t189: "cuda:0 bf16[1, 512, 4096]"
  del t188
  t190 = torch.nn.functional.linear(t189, t85, None)  # t190: "cuda:0 bf16[1, 512, 4096]"
    # t190 = ltorch.linear(t189, t85, None)  # t190: "cuda:0 bf16[1, 512, 4096]"
      # t190 = prims.linear(t189, t85, None)  # t190: "cuda:0 bf16[1, 512, 4096]"
  [t194, t201, t209] = nvFusion3(t122, t190, t205)
    # t191 = prims.convert_element_type(t190, dtypes.float32)  # t191: "cuda:0 f32[1, 512, 4096]"
    # t192 = prims.convert_element_type(t122, dtypes.float32)  # t192: "cuda:0 f32[1, 512, 4096]"
    # t193 = prims.add(t191, t192)  # t193: "cuda:0 f32[1, 512, 4096]"
    # t194 = prims.convert_element_type(t193, dtypes.bfloat16)  # t194: "cuda:0 bf16[1, 512, 4096]"
    # t196 = prims.mul(t193, t193)  # t196: "cuda:0 f32[1, 512, 4096]"
    # t197 = prims.sum(t196, (2,))  # t197: "cuda:0 f32[1, 512]"
    # t198 = prims.broadcast_in_dim(t197, [1, 512, 1], [0, 1])  # t198: "cuda:0 f32[1, 512, 1]"
    # t199 = prims.div(t198, 4096.0)  # t199: "cuda:0 f32[1, 512, 1]"
    # t200 = prims.add(t199, 1e-05)  # t200: "cuda:0 f32[1, 512, 1]"
    # t201 = prims.rsqrt(t200)  # t201: "cuda:0 f32[1, 512, 1]"
    # t202 = prims.broadcast_in_dim(t201, (1, 512, 4096), (0, 1, 2))  # t202: "cuda:0 f32[1, 512, 4096]"
    # t203 = prims.mul(t193, t202)  # t203: "cuda:0 f32[1, 512, 4096]"
    # t207 = prims.convert_element_type(t205, dtypes.float32)  # t207: "cuda:0 f32[1, 512, 4096]"
    # t208 = prims.mul(t203, t207)  # t208: "cuda:0 f32[1, 512, 4096]"
    # t209 = prims.convert_element_type(t208, dtypes.bfloat16)  # t209: "cuda:0 bf16[1, 512, 4096]"
  t210 = torch.nn.functional.linear(t209, t19, None)  # t210: "cuda:0 bf16[1, 512, 11008]"
    # t210 = ltorch.linear(t209, t19, None)  # t210: "cuda:0 bf16[1, 512, 11008]"
      # t210 = prims.linear(t209, t19, None)  # t210: "cuda:0 bf16[1, 512, 11008]"
  t211 = torch.nn.functional.linear(t209, t35, None)  # t211: "cuda:0 bf16[1, 512, 11008]"
    # t211 = ltorch.linear(t209, t35, None)  # t211: "cuda:0 bf16[1, 512, 11008]"
      # t211 = prims.linear(t209, t35, None)  # t211: "cuda:0 bf16[1, 512, 11008]"
  [t225] = nvFusion4(t210, t211)
    # t212 = prims.convert_element_type(t210, dtypes.float32)  # t212: "cuda:0 f32[1, 512, 11008]"
    # t213 = prims.neg(t212)  # t213: "cuda:0 f32[1, 512, 11008]"
    # t214 = prims.exp(t213)  # t214: "cuda:0 f32[1, 512, 11008]"
    # t215 = prims.add(1.0, t214)  # t215: "cuda:0 f32[1, 512, 11008]"
    # t216 = prims.reciprocal(t215)  # t216: "cuda:0 f32[1, 512, 11008]"
    # t220 = prims.mul(t212, t216)  # t220: "cuda:0 f32[1, 512, 11008]"
    # t223 = prims.convert_element_type(t211, dtypes.float32)  # t223: "cuda:0 f32[1, 512, 11008]"
    # t224 = prims.mul(t220, t223)  # t224: "cuda:0 f32[1, 512, 11008]"
    # t225 = prims.convert_element_type(t224, dtypes.bfloat16)  # t225: "cuda:0 bf16[1, 512, 11008]"
  t226 = torch.nn.functional.linear(t225, t86, None)  # t226: "cuda:0 bf16[1, 512, 4096]"
    # t226 = ltorch.linear(t225, t86, None)  # t226: "cuda:0 bf16[1, 512, 4096]"
      # t226 = prims.linear(t225, t86, None)  # t226: "cuda:0 bf16[1, 512, 4096]"
  [t230, t237, t245] = nvFusion5(t194, t226, t241)
    # t228 = prims.convert_element_type(t194, dtypes.float32)  # t228: "cuda:0 f32[1, 512, 4096]"
    # t227 = prims.convert_element_type(t226, dtypes.float32)  # t227: "cuda:0 f32[1, 512, 4096]"
    # t229 = prims.add(t227, t228)  # t229: "cuda:0 f32[1, 512, 4096]"
    # t230 = prims.convert_element_type(t229, dtypes.bfloat16)  # t230: "cuda:0 bf16[1, 512, 4096]"
    # t232 = prims.mul(t229, t229)  # t232: "cuda:0 f32[1, 512, 4096]"
    # t233 = prims.sum(t232, (2,))  # t233: "cuda:0 f32[1, 512]"
    # t234 = prims.broadcast_in_dim(t233, [1, 512, 1], [0, 1])  # t234: "cuda:0 f32[1, 512, 1]"
    # t235 = prims.div(t234, 4096.0)  # t235: "cuda:0 f32[1, 512, 1]"
    # t236 = prims.add(t235, 1e-05)  # t236: "cuda:0 f32[1, 512, 1]"
    # t237 = prims.rsqrt(t236)  # t237: "cuda:0 f32[1, 512, 1]"
    # t238 = prims.broadcast_in_dim(t237, (1, 512, 4096), (0, 1, 2))  # t238: "cuda:0 f32[1, 512, 4096]"
    # t239 = prims.mul(t229, t238)  # t239: "cuda:0 f32[1, 512, 4096]"
    # t243 = prims.convert_element_type(t241, dtypes.float32)  # t243: "cuda:0 f32[1, 512, 4096]"
    # t244 = prims.mul(t239, t243)  # t244: "cuda:0 f32[1, 512, 4096]"
    # t245 = prims.convert_element_type(t244, dtypes.bfloat16)  # t245: "cuda:0 bf16[1, 512, 4096]"
  t246 = torch.nn.functional.linear(t245, t4, None)  # t246: "cuda:0 bf16[1, 512, 12288]"
    # t246 = ltorch.linear(t245, t4, None)  # t246: "cuda:0 bf16[1, 512, 12288]"
      # t246 = prims.linear(t245, t4, None)  # t246: "cuda:0 bf16[1, 512, 12288]"
  t247 = torch.reshape(t246, (1, 512, 32, 3, 128))  # t247: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t247 = ltorch.reshape(t246, (1, 512, 32, 3, 128))  # t247: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t247 = prims.reshape(t246, (1, 512, 32, 3, 128))  # t247: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t246
  t248 = torch.permute(t247, (0, 2, 3, 1, 4))  # t248: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t248 = ltorch.permute(t247, (0, 2, 3, 1, 4))  # t248: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t248 = prims.transpose(t247, (0, 2, 3, 1, 4))  # t248: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t247
  (t249, t250, t251) = torch.split(t248, (1, 1, 1), 2)
    # (t249, t250, t251) = ltorch.split(t248, (1, 1, 1), 2)
      # t249 = prims.slice_prim(t248, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t249: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t250 = prims.slice_prim(t248, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t250: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t251 = prims.slice_prim(t248, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t251: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t248
  t252 = torch.reshape(t249, (1, 32, 512, 128))  # t252: "cuda:0 bf16[1, 32, 512, 128]"
    # t252 = ltorch.reshape(t249, (1, 32, 512, 128))  # t252: "cuda:0 bf16[1, 32, 512, 128]"
      # t252 = prims.reshape(t249, (1, 32, 512, 128))  # t252: "cuda:0 bf16[1, 32, 512, 128]"
  del t249
  t253 = torch.reshape(t250, (1, 32, 512, 128))  # t253: "cuda:0 bf16[1, 32, 512, 128]"
    # t253 = ltorch.reshape(t250, (1, 32, 512, 128))  # t253: "cuda:0 bf16[1, 32, 512, 128]"
      # t253 = prims.reshape(t250, (1, 32, 512, 128))  # t253: "cuda:0 bf16[1, 32, 512, 128]"
  del t250
  t254 = torch.reshape(t251, (1, 32, 512, 128))  # t254: "cuda:0 bf16[1, 32, 512, 128]"
    # t254 = ltorch.reshape(t251, (1, 32, 512, 128))  # t254: "cuda:0 bf16[1, 32, 512, 128]"
      # t254 = prims.reshape(t251, (1, 32, 512, 128))  # t254: "cuda:0 bf16[1, 32, 512, 128]"
  del t251
  t285 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t285: "cuda:0 bf16[1, 32, 512, 0]"
  t287 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t287: "cuda:0 bf16[1, 32, 512, 0]"
  t255 = torch_slice_prim_impl(t252, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t255: "cuda:0 bf16[1, 32, 512, 128]"
  del t252
  t270 = torch_slice_prim_impl(t253, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t270: "cuda:0 bf16[1, 32, 512, 128]"
  del t253
  t256 = torch_slice_prim_impl(t255, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t256: "cuda:0 bf16[1, 32, 512, 64]"
  t257 = torch_slice_prim_impl(t255, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t257: "cuda:0 bf16[1, 32, 512, 64]"
  t272 = torch_slice_prim_impl(t270, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t272: "cuda:0 bf16[1, 32, 512, 64]"
  t271 = torch_slice_prim_impl(t270, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t271: "cuda:0 bf16[1, 32, 512, 64]"
  [t260, t275] = nvFusion6(t255, t257, t270, t272)
    # t258 = prims.convert_element_type(t257, dtypes.float32)  # t258: "cuda:0 f32[1, 32, 512, 64]"
    # t259 = prims.neg(t258)  # t259: "cuda:0 f32[1, 32, 512, 64]"
    # t260 = prims.convert_element_type(t259, dtypes.bfloat16)  # t260: "cuda:0 bf16[1, 32, 512, 64]"
    # t273 = prims.convert_element_type(t272, dtypes.float32)  # t273: "cuda:0 f32[1, 32, 512, 64]"
    # t274 = prims.neg(t273)  # t274: "cuda:0 f32[1, 32, 512, 64]"
    # t275 = prims.convert_element_type(t274, dtypes.bfloat16)  # t275: "cuda:0 bf16[1, 32, 512, 64]"
  del t257, t272
  t261 = torch.cat((t260, t256), -1)  # t261: "cuda:0 bf16[1, 32, 512, 128]"
    # t261 = ltorch.cat((t260, t256), -1)  # t261: "cuda:0 bf16[1, 32, 512, 128]"
      # t261 = prims.cat((t260, t256), -1)  # t261: "cuda:0 bf16[1, 32, 512, 128]"
  del t260, t256
  t276 = torch.cat((t275, t271), -1)  # t276: "cuda:0 bf16[1, 32, 512, 128]"
    # t276 = ltorch.cat((t275, t271), -1)  # t276: "cuda:0 bf16[1, 32, 512, 128]"
      # t276 = prims.cat((t275, t271), -1)  # t276: "cuda:0 bf16[1, 32, 512, 128]"
  del t275, t271
  [t269, t284] = nvFusion7(t154, t157, t255, t261, t270, t276)
    # t263 = prims.convert_element_type(t255, dtypes.float32)  # t263: "cuda:0 f32[1, 32, 512, 128]"
    # t278 = prims.convert_element_type(t270, dtypes.float32)  # t278: "cuda:0 f32[1, 32, 512, 128]"
    # t264 = prims.mul(t263, t154)  # t264: "cuda:0 f32[1, 32, 512, 128]"
    # t266 = prims.convert_element_type(t261, dtypes.float32)  # t266: "cuda:0 f32[1, 32, 512, 128]"
    # t267 = prims.mul(t266, t157)  # t267: "cuda:0 f32[1, 32, 512, 128]"
    # t268 = prims.add(t264, t267)  # t268: "cuda:0 f32[1, 32, 512, 128]"
    # t269 = prims.convert_element_type(t268, dtypes.bfloat16)  # t269: "cuda:0 bf16[1, 32, 512, 128]"
    # t279 = prims.mul(t278, t154)  # t279: "cuda:0 f32[1, 32, 512, 128]"
    # t281 = prims.convert_element_type(t276, dtypes.float32)  # t281: "cuda:0 f32[1, 32, 512, 128]"
    # t282 = prims.mul(t281, t157)  # t282: "cuda:0 f32[1, 32, 512, 128]"
    # t283 = prims.add(t279, t282)  # t283: "cuda:0 f32[1, 32, 512, 128]"
    # t284 = prims.convert_element_type(t283, dtypes.bfloat16)  # t284: "cuda:0 bf16[1, 32, 512, 128]"
  del t255, t261, t270, t276
  t288 = torch.cat((t284, t287), -1)  # t288: "cuda:0 bf16[1, 32, 512, 128]"
    # t288 = ltorch.cat((t284, t287), -1)  # t288: "cuda:0 bf16[1, 32, 512, 128]"
      # t288 = prims.cat((t284, t287), -1)  # t288: "cuda:0 bf16[1, 32, 512, 128]"
  del t284, t287
  t286 = torch.cat((t269, t285), -1)  # t286: "cuda:0 bf16[1, 32, 512, 128]"
    # t286 = ltorch.cat((t269, t285), -1)  # t286: "cuda:0 bf16[1, 32, 512, 128]"
      # t286 = prims.cat((t269, t285), -1)  # t286: "cuda:0 bf16[1, 32, 512, 128]"
  del t269, t285
  (t289, t290, t291, t292, _, _, t293, t294, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t286, t288, t254, 0.0, True, scale=0.08838834764831843)
  t296 = torch.permute(t289, (0, 2, 1, 3))  # t296: "cuda:0 bf16[1, 512, 32, 128]"
    # t296 = ltorch.permute(t289, (0, 2, 1, 3))  # t296: "cuda:0 bf16[1, 512, 32, 128]"
      # t296 = prims.transpose(t289, (0, 2, 1, 3))  # t296: "cuda:0 bf16[1, 512, 32, 128]"
  t297 = torch.reshape(t296, (1, 512, 4096))  # t297: "cuda:0 bf16[1, 512, 4096]"
    # t297 = ltorch.reshape(t296, (1, 512, 4096))  # t297: "cuda:0 bf16[1, 512, 4096]"
      # t297 = prims.reshape(t296, (1, 512, 4096))  # t297: "cuda:0 bf16[1, 512, 4096]"
  del t296
  t298 = torch.nn.functional.linear(t297, t87, None)  # t298: "cuda:0 bf16[1, 512, 4096]"
    # t298 = ltorch.linear(t297, t87, None)  # t298: "cuda:0 bf16[1, 512, 4096]"
      # t298 = prims.linear(t297, t87, None)  # t298: "cuda:0 bf16[1, 512, 4096]"
  [t302, t309, t317] = nvFusion8(t230, t298, t313)
    # t300 = prims.convert_element_type(t230, dtypes.float32)  # t300: "cuda:0 f32[1, 512, 4096]"
    # t299 = prims.convert_element_type(t298, dtypes.float32)  # t299: "cuda:0 f32[1, 512, 4096]"
    # t301 = prims.add(t299, t300)  # t301: "cuda:0 f32[1, 512, 4096]"
    # t302 = prims.convert_element_type(t301, dtypes.bfloat16)  # t302: "cuda:0 bf16[1, 512, 4096]"
    # t304 = prims.mul(t301, t301)  # t304: "cuda:0 f32[1, 512, 4096]"
    # t305 = prims.sum(t304, (2,))  # t305: "cuda:0 f32[1, 512]"
    # t306 = prims.broadcast_in_dim(t305, [1, 512, 1], [0, 1])  # t306: "cuda:0 f32[1, 512, 1]"
    # t307 = prims.div(t306, 4096.0)  # t307: "cuda:0 f32[1, 512, 1]"
    # t308 = prims.add(t307, 1e-05)  # t308: "cuda:0 f32[1, 512, 1]"
    # t309 = prims.rsqrt(t308)  # t309: "cuda:0 f32[1, 512, 1]"
    # t310 = prims.broadcast_in_dim(t309, (1, 512, 4096), (0, 1, 2))  # t310: "cuda:0 f32[1, 512, 4096]"
    # t311 = prims.mul(t301, t310)  # t311: "cuda:0 f32[1, 512, 4096]"
    # t315 = prims.convert_element_type(t313, dtypes.float32)  # t315: "cuda:0 f32[1, 512, 4096]"
    # t316 = prims.mul(t311, t315)  # t316: "cuda:0 f32[1, 512, 4096]"
    # t317 = prims.convert_element_type(t316, dtypes.bfloat16)  # t317: "cuda:0 bf16[1, 512, 4096]"
  t318 = torch.nn.functional.linear(t317, t20, None)  # t318: "cuda:0 bf16[1, 512, 11008]"
    # t318 = ltorch.linear(t317, t20, None)  # t318: "cuda:0 bf16[1, 512, 11008]"
      # t318 = prims.linear(t317, t20, None)  # t318: "cuda:0 bf16[1, 512, 11008]"
  t319 = torch.nn.functional.linear(t317, t36, None)  # t319: "cuda:0 bf16[1, 512, 11008]"
    # t319 = ltorch.linear(t317, t36, None)  # t319: "cuda:0 bf16[1, 512, 11008]"
      # t319 = prims.linear(t317, t36, None)  # t319: "cuda:0 bf16[1, 512, 11008]"
  [t333] = nvFusion9(t318, t319)
    # t320 = prims.convert_element_type(t318, dtypes.float32)  # t320: "cuda:0 f32[1, 512, 11008]"
    # t321 = prims.neg(t320)  # t321: "cuda:0 f32[1, 512, 11008]"
    # t322 = prims.exp(t321)  # t322: "cuda:0 f32[1, 512, 11008]"
    # t323 = prims.add(1.0, t322)  # t323: "cuda:0 f32[1, 512, 11008]"
    # t324 = prims.reciprocal(t323)  # t324: "cuda:0 f32[1, 512, 11008]"
    # t328 = prims.mul(t320, t324)  # t328: "cuda:0 f32[1, 512, 11008]"
    # t331 = prims.convert_element_type(t319, dtypes.float32)  # t331: "cuda:0 f32[1, 512, 11008]"
    # t332 = prims.mul(t328, t331)  # t332: "cuda:0 f32[1, 512, 11008]"
    # t333 = prims.convert_element_type(t332, dtypes.bfloat16)  # t333: "cuda:0 bf16[1, 512, 11008]"
  t334 = torch.nn.functional.linear(t333, t88, None)  # t334: "cuda:0 bf16[1, 512, 4096]"
    # t334 = ltorch.linear(t333, t88, None)  # t334: "cuda:0 bf16[1, 512, 4096]"
      # t334 = prims.linear(t333, t88, None)  # t334: "cuda:0 bf16[1, 512, 4096]"
  [t338, t345, t353] = nvFusion10(t302, t334, t349)
    # t336 = prims.convert_element_type(t302, dtypes.float32)  # t336: "cuda:0 f32[1, 512, 4096]"
    # t335 = prims.convert_element_type(t334, dtypes.float32)  # t335: "cuda:0 f32[1, 512, 4096]"
    # t337 = prims.add(t335, t336)  # t337: "cuda:0 f32[1, 512, 4096]"
    # t338 = prims.convert_element_type(t337, dtypes.bfloat16)  # t338: "cuda:0 bf16[1, 512, 4096]"
    # t340 = prims.mul(t337, t337)  # t340: "cuda:0 f32[1, 512, 4096]"
    # t341 = prims.sum(t340, (2,))  # t341: "cuda:0 f32[1, 512]"
    # t342 = prims.broadcast_in_dim(t341, [1, 512, 1], [0, 1])  # t342: "cuda:0 f32[1, 512, 1]"
    # t343 = prims.div(t342, 4096.0)  # t343: "cuda:0 f32[1, 512, 1]"
    # t344 = prims.add(t343, 1e-05)  # t344: "cuda:0 f32[1, 512, 1]"
    # t345 = prims.rsqrt(t344)  # t345: "cuda:0 f32[1, 512, 1]"
    # t346 = prims.broadcast_in_dim(t345, (1, 512, 4096), (0, 1, 2))  # t346: "cuda:0 f32[1, 512, 4096]"
    # t347 = prims.mul(t337, t346)  # t347: "cuda:0 f32[1, 512, 4096]"
    # t351 = prims.convert_element_type(t349, dtypes.float32)  # t351: "cuda:0 f32[1, 512, 4096]"
    # t352 = prims.mul(t347, t351)  # t352: "cuda:0 f32[1, 512, 4096]"
    # t353 = prims.convert_element_type(t352, dtypes.bfloat16)  # t353: "cuda:0 bf16[1, 512, 4096]"
  t354 = torch.nn.functional.linear(t353, t5, None)  # t354: "cuda:0 bf16[1, 512, 12288]"
    # t354 = ltorch.linear(t353, t5, None)  # t354: "cuda:0 bf16[1, 512, 12288]"
      # t354 = prims.linear(t353, t5, None)  # t354: "cuda:0 bf16[1, 512, 12288]"
  t355 = torch.reshape(t354, (1, 512, 32, 3, 128))  # t355: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t355 = ltorch.reshape(t354, (1, 512, 32, 3, 128))  # t355: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t355 = prims.reshape(t354, (1, 512, 32, 3, 128))  # t355: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t354
  t356 = torch.permute(t355, (0, 2, 3, 1, 4))  # t356: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t356 = ltorch.permute(t355, (0, 2, 3, 1, 4))  # t356: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t356 = prims.transpose(t355, (0, 2, 3, 1, 4))  # t356: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t355
  (t357, t358, t359) = torch.split(t356, (1, 1, 1), 2)
    # (t357, t358, t359) = ltorch.split(t356, (1, 1, 1), 2)
      # t357 = prims.slice_prim(t356, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t357: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t358 = prims.slice_prim(t356, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t358: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t359 = prims.slice_prim(t356, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t359: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t356
  t360 = torch.reshape(t357, (1, 32, 512, 128))  # t360: "cuda:0 bf16[1, 32, 512, 128]"
    # t360 = ltorch.reshape(t357, (1, 32, 512, 128))  # t360: "cuda:0 bf16[1, 32, 512, 128]"
      # t360 = prims.reshape(t357, (1, 32, 512, 128))  # t360: "cuda:0 bf16[1, 32, 512, 128]"
  del t357
  t361 = torch.reshape(t358, (1, 32, 512, 128))  # t361: "cuda:0 bf16[1, 32, 512, 128]"
    # t361 = ltorch.reshape(t358, (1, 32, 512, 128))  # t361: "cuda:0 bf16[1, 32, 512, 128]"
      # t361 = prims.reshape(t358, (1, 32, 512, 128))  # t361: "cuda:0 bf16[1, 32, 512, 128]"
  del t358
  t362 = torch.reshape(t359, (1, 32, 512, 128))  # t362: "cuda:0 bf16[1, 32, 512, 128]"
    # t362 = ltorch.reshape(t359, (1, 32, 512, 128))  # t362: "cuda:0 bf16[1, 32, 512, 128]"
      # t362 = prims.reshape(t359, (1, 32, 512, 128))  # t362: "cuda:0 bf16[1, 32, 512, 128]"
  del t359
  t363 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t363: "cuda:0 bf16[1, 32, 512, 128]"
  t378 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t378: "cuda:0 bf16[1, 32, 512, 128]"
  t393 = torch_slice_prim_impl(t360, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t393: "cuda:0 bf16[1, 32, 512, 0]"
  del t360
  t395 = torch_slice_prim_impl(t361, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t395: "cuda:0 bf16[1, 32, 512, 0]"
  del t361
  t364 = torch_slice_prim_impl(t363, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t364: "cuda:0 bf16[1, 32, 512, 64]"
  t365 = torch_slice_prim_impl(t363, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t365: "cuda:0 bf16[1, 32, 512, 64]"
  t379 = torch_slice_prim_impl(t378, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t379: "cuda:0 bf16[1, 32, 512, 64]"
  t380 = torch_slice_prim_impl(t378, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t380: "cuda:0 bf16[1, 32, 512, 64]"
  [t368, t383] = nvFusion11(t363, t365, t378, t380)
    # t366 = prims.convert_element_type(t365, dtypes.float32)  # t366: "cuda:0 f32[1, 32, 512, 64]"
    # t367 = prims.neg(t366)  # t367: "cuda:0 f32[1, 32, 512, 64]"
    # t368 = prims.convert_element_type(t367, dtypes.bfloat16)  # t368: "cuda:0 bf16[1, 32, 512, 64]"
    # t381 = prims.convert_element_type(t380, dtypes.float32)  # t381: "cuda:0 f32[1, 32, 512, 64]"
    # t382 = prims.neg(t381)  # t382: "cuda:0 f32[1, 32, 512, 64]"
    # t383 = prims.convert_element_type(t382, dtypes.bfloat16)  # t383: "cuda:0 bf16[1, 32, 512, 64]"
  del t365, t380
  t369 = torch.cat((t368, t364), -1)  # t369: "cuda:0 bf16[1, 32, 512, 128]"
    # t369 = ltorch.cat((t368, t364), -1)  # t369: "cuda:0 bf16[1, 32, 512, 128]"
      # t369 = prims.cat((t368, t364), -1)  # t369: "cuda:0 bf16[1, 32, 512, 128]"
  del t368, t364
  t384 = torch.cat((t383, t379), -1)  # t384: "cuda:0 bf16[1, 32, 512, 128]"
    # t384 = ltorch.cat((t383, t379), -1)  # t384: "cuda:0 bf16[1, 32, 512, 128]"
      # t384 = prims.cat((t383, t379), -1)  # t384: "cuda:0 bf16[1, 32, 512, 128]"
  del t383, t379
  [t377, t392] = nvFusion12(t154, t157, t363, t369, t378, t384)
    # t371 = prims.convert_element_type(t363, dtypes.float32)  # t371: "cuda:0 f32[1, 32, 512, 128]"
    # t386 = prims.convert_element_type(t378, dtypes.float32)  # t386: "cuda:0 f32[1, 32, 512, 128]"
    # t372 = prims.mul(t371, t154)  # t372: "cuda:0 f32[1, 32, 512, 128]"
    # t374 = prims.convert_element_type(t369, dtypes.float32)  # t374: "cuda:0 f32[1, 32, 512, 128]"
    # t375 = prims.mul(t374, t157)  # t375: "cuda:0 f32[1, 32, 512, 128]"
    # t376 = prims.add(t372, t375)  # t376: "cuda:0 f32[1, 32, 512, 128]"
    # t377 = prims.convert_element_type(t376, dtypes.bfloat16)  # t377: "cuda:0 bf16[1, 32, 512, 128]"
    # t387 = prims.mul(t386, t154)  # t387: "cuda:0 f32[1, 32, 512, 128]"
    # t389 = prims.convert_element_type(t384, dtypes.float32)  # t389: "cuda:0 f32[1, 32, 512, 128]"
    # t390 = prims.mul(t389, t157)  # t390: "cuda:0 f32[1, 32, 512, 128]"
    # t391 = prims.add(t387, t390)  # t391: "cuda:0 f32[1, 32, 512, 128]"
    # t392 = prims.convert_element_type(t391, dtypes.bfloat16)  # t392: "cuda:0 bf16[1, 32, 512, 128]"
  del t363, t369, t378, t384
  t394 = torch.cat((t377, t393), -1)  # t394: "cuda:0 bf16[1, 32, 512, 128]"
    # t394 = ltorch.cat((t377, t393), -1)  # t394: "cuda:0 bf16[1, 32, 512, 128]"
      # t394 = prims.cat((t377, t393), -1)  # t394: "cuda:0 bf16[1, 32, 512, 128]"
  del t377, t393
  t396 = torch.cat((t392, t395), -1)  # t396: "cuda:0 bf16[1, 32, 512, 128]"
    # t396 = ltorch.cat((t392, t395), -1)  # t396: "cuda:0 bf16[1, 32, 512, 128]"
      # t396 = prims.cat((t392, t395), -1)  # t396: "cuda:0 bf16[1, 32, 512, 128]"
  del t392, t395
  (t397, t398, t399, t400, _, _, t401, t402, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t394, t396, t362, 0.0, True, scale=0.08838834764831843)
  t404 = torch.permute(t397, (0, 2, 1, 3))  # t404: "cuda:0 bf16[1, 512, 32, 128]"
    # t404 = ltorch.permute(t397, (0, 2, 1, 3))  # t404: "cuda:0 bf16[1, 512, 32, 128]"
      # t404 = prims.transpose(t397, (0, 2, 1, 3))  # t404: "cuda:0 bf16[1, 512, 32, 128]"
  t405 = torch.reshape(t404, (1, 512, 4096))  # t405: "cuda:0 bf16[1, 512, 4096]"
    # t405 = ltorch.reshape(t404, (1, 512, 4096))  # t405: "cuda:0 bf16[1, 512, 4096]"
      # t405 = prims.reshape(t404, (1, 512, 4096))  # t405: "cuda:0 bf16[1, 512, 4096]"
  del t404
  t406 = torch.nn.functional.linear(t405, t89, None)  # t406: "cuda:0 bf16[1, 512, 4096]"
    # t406 = ltorch.linear(t405, t89, None)  # t406: "cuda:0 bf16[1, 512, 4096]"
      # t406 = prims.linear(t405, t89, None)  # t406: "cuda:0 bf16[1, 512, 4096]"
  [t410, t417, t425] = nvFusion13(t338, t406, t421)
    # t408 = prims.convert_element_type(t338, dtypes.float32)  # t408: "cuda:0 f32[1, 512, 4096]"
    # t407 = prims.convert_element_type(t406, dtypes.float32)  # t407: "cuda:0 f32[1, 512, 4096]"
    # t409 = prims.add(t407, t408)  # t409: "cuda:0 f32[1, 512, 4096]"
    # t410 = prims.convert_element_type(t409, dtypes.bfloat16)  # t410: "cuda:0 bf16[1, 512, 4096]"
    # t412 = prims.mul(t409, t409)  # t412: "cuda:0 f32[1, 512, 4096]"
    # t413 = prims.sum(t412, (2,))  # t413: "cuda:0 f32[1, 512]"
    # t414 = prims.broadcast_in_dim(t413, [1, 512, 1], [0, 1])  # t414: "cuda:0 f32[1, 512, 1]"
    # t415 = prims.div(t414, 4096.0)  # t415: "cuda:0 f32[1, 512, 1]"
    # t416 = prims.add(t415, 1e-05)  # t416: "cuda:0 f32[1, 512, 1]"
    # t417 = prims.rsqrt(t416)  # t417: "cuda:0 f32[1, 512, 1]"
    # t418 = prims.broadcast_in_dim(t417, (1, 512, 4096), (0, 1, 2))  # t418: "cuda:0 f32[1, 512, 4096]"
    # t419 = prims.mul(t409, t418)  # t419: "cuda:0 f32[1, 512, 4096]"
    # t423 = prims.convert_element_type(t421, dtypes.float32)  # t423: "cuda:0 f32[1, 512, 4096]"
    # t424 = prims.mul(t419, t423)  # t424: "cuda:0 f32[1, 512, 4096]"
    # t425 = prims.convert_element_type(t424, dtypes.bfloat16)  # t425: "cuda:0 bf16[1, 512, 4096]"
  t426 = torch.nn.functional.linear(t425, t21, None)  # t426: "cuda:0 bf16[1, 512, 11008]"
    # t426 = ltorch.linear(t425, t21, None)  # t426: "cuda:0 bf16[1, 512, 11008]"
      # t426 = prims.linear(t425, t21, None)  # t426: "cuda:0 bf16[1, 512, 11008]"
  t427 = torch.nn.functional.linear(t425, t37, None)  # t427: "cuda:0 bf16[1, 512, 11008]"
    # t427 = ltorch.linear(t425, t37, None)  # t427: "cuda:0 bf16[1, 512, 11008]"
      # t427 = prims.linear(t425, t37, None)  # t427: "cuda:0 bf16[1, 512, 11008]"
  [t441] = nvFusion14(t426, t427)
    # t428 = prims.convert_element_type(t426, dtypes.float32)  # t428: "cuda:0 f32[1, 512, 11008]"
    # t429 = prims.neg(t428)  # t429: "cuda:0 f32[1, 512, 11008]"
    # t430 = prims.exp(t429)  # t430: "cuda:0 f32[1, 512, 11008]"
    # t431 = prims.add(1.0, t430)  # t431: "cuda:0 f32[1, 512, 11008]"
    # t432 = prims.reciprocal(t431)  # t432: "cuda:0 f32[1, 512, 11008]"
    # t436 = prims.mul(t428, t432)  # t436: "cuda:0 f32[1, 512, 11008]"
    # t439 = prims.convert_element_type(t427, dtypes.float32)  # t439: "cuda:0 f32[1, 512, 11008]"
    # t440 = prims.mul(t436, t439)  # t440: "cuda:0 f32[1, 512, 11008]"
    # t441 = prims.convert_element_type(t440, dtypes.bfloat16)  # t441: "cuda:0 bf16[1, 512, 11008]"
  t442 = torch.nn.functional.linear(t441, t90, None)  # t442: "cuda:0 bf16[1, 512, 4096]"
    # t442 = ltorch.linear(t441, t90, None)  # t442: "cuda:0 bf16[1, 512, 4096]"
      # t442 = prims.linear(t441, t90, None)  # t442: "cuda:0 bf16[1, 512, 4096]"
  [t446, t453, t461] = nvFusion15(t410, t442, t457)
    # t444 = prims.convert_element_type(t410, dtypes.float32)  # t444: "cuda:0 f32[1, 512, 4096]"
    # t443 = prims.convert_element_type(t442, dtypes.float32)  # t443: "cuda:0 f32[1, 512, 4096]"
    # t445 = prims.add(t443, t444)  # t445: "cuda:0 f32[1, 512, 4096]"
    # t446 = prims.convert_element_type(t445, dtypes.bfloat16)  # t446: "cuda:0 bf16[1, 512, 4096]"
    # t448 = prims.mul(t445, t445)  # t448: "cuda:0 f32[1, 512, 4096]"
    # t449 = prims.sum(t448, (2,))  # t449: "cuda:0 f32[1, 512]"
    # t450 = prims.broadcast_in_dim(t449, [1, 512, 1], [0, 1])  # t450: "cuda:0 f32[1, 512, 1]"
    # t451 = prims.div(t450, 4096.0)  # t451: "cuda:0 f32[1, 512, 1]"
    # t452 = prims.add(t451, 1e-05)  # t452: "cuda:0 f32[1, 512, 1]"
    # t453 = prims.rsqrt(t452)  # t453: "cuda:0 f32[1, 512, 1]"
    # t454 = prims.broadcast_in_dim(t453, (1, 512, 4096), (0, 1, 2))  # t454: "cuda:0 f32[1, 512, 4096]"
    # t455 = prims.mul(t445, t454)  # t455: "cuda:0 f32[1, 512, 4096]"
    # t459 = prims.convert_element_type(t457, dtypes.float32)  # t459: "cuda:0 f32[1, 512, 4096]"
    # t460 = prims.mul(t455, t459)  # t460: "cuda:0 f32[1, 512, 4096]"
    # t461 = prims.convert_element_type(t460, dtypes.bfloat16)  # t461: "cuda:0 bf16[1, 512, 4096]"
  t462 = torch.nn.functional.linear(t461, t6, None)  # t462: "cuda:0 bf16[1, 512, 12288]"
    # t462 = ltorch.linear(t461, t6, None)  # t462: "cuda:0 bf16[1, 512, 12288]"
      # t462 = prims.linear(t461, t6, None)  # t462: "cuda:0 bf16[1, 512, 12288]"
  t463 = torch.reshape(t462, (1, 512, 32, 3, 128))  # t463: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t463 = ltorch.reshape(t462, (1, 512, 32, 3, 128))  # t463: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t463 = prims.reshape(t462, (1, 512, 32, 3, 128))  # t463: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t462
  t464 = torch.permute(t463, (0, 2, 3, 1, 4))  # t464: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t464 = ltorch.permute(t463, (0, 2, 3, 1, 4))  # t464: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t464 = prims.transpose(t463, (0, 2, 3, 1, 4))  # t464: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t463
  (t465, t466, t467) = torch.split(t464, (1, 1, 1), 2)
    # (t465, t466, t467) = ltorch.split(t464, (1, 1, 1), 2)
      # t465 = prims.slice_prim(t464, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t465: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t466 = prims.slice_prim(t464, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t466: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t467 = prims.slice_prim(t464, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t467: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t464
  t468 = torch.reshape(t465, (1, 32, 512, 128))  # t468: "cuda:0 bf16[1, 32, 512, 128]"
    # t468 = ltorch.reshape(t465, (1, 32, 512, 128))  # t468: "cuda:0 bf16[1, 32, 512, 128]"
      # t468 = prims.reshape(t465, (1, 32, 512, 128))  # t468: "cuda:0 bf16[1, 32, 512, 128]"
  del t465
  t469 = torch.reshape(t466, (1, 32, 512, 128))  # t469: "cuda:0 bf16[1, 32, 512, 128]"
    # t469 = ltorch.reshape(t466, (1, 32, 512, 128))  # t469: "cuda:0 bf16[1, 32, 512, 128]"
      # t469 = prims.reshape(t466, (1, 32, 512, 128))  # t469: "cuda:0 bf16[1, 32, 512, 128]"
  del t466
  t470 = torch.reshape(t467, (1, 32, 512, 128))  # t470: "cuda:0 bf16[1, 32, 512, 128]"
    # t470 = ltorch.reshape(t467, (1, 32, 512, 128))  # t470: "cuda:0 bf16[1, 32, 512, 128]"
      # t470 = prims.reshape(t467, (1, 32, 512, 128))  # t470: "cuda:0 bf16[1, 32, 512, 128]"
  del t467
  t471 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t471: "cuda:0 bf16[1, 32, 512, 128]"
  t486 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t486: "cuda:0 bf16[1, 32, 512, 128]"
  t501 = torch_slice_prim_impl(t468, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t501: "cuda:0 bf16[1, 32, 512, 0]"
  del t468
  t503 = torch_slice_prim_impl(t469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t503: "cuda:0 bf16[1, 32, 512, 0]"
  del t469
  t472 = torch_slice_prim_impl(t471, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t472: "cuda:0 bf16[1, 32, 512, 64]"
  t473 = torch_slice_prim_impl(t471, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t473: "cuda:0 bf16[1, 32, 512, 64]"
  t487 = torch_slice_prim_impl(t486, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t487: "cuda:0 bf16[1, 32, 512, 64]"
  t488 = torch_slice_prim_impl(t486, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t488: "cuda:0 bf16[1, 32, 512, 64]"
  [t476, t491] = nvFusion16(t471, t473, t486, t488)
    # t474 = prims.convert_element_type(t473, dtypes.float32)  # t474: "cuda:0 f32[1, 32, 512, 64]"
    # t475 = prims.neg(t474)  # t475: "cuda:0 f32[1, 32, 512, 64]"
    # t476 = prims.convert_element_type(t475, dtypes.bfloat16)  # t476: "cuda:0 bf16[1, 32, 512, 64]"
    # t489 = prims.convert_element_type(t488, dtypes.float32)  # t489: "cuda:0 f32[1, 32, 512, 64]"
    # t490 = prims.neg(t489)  # t490: "cuda:0 f32[1, 32, 512, 64]"
    # t491 = prims.convert_element_type(t490, dtypes.bfloat16)  # t491: "cuda:0 bf16[1, 32, 512, 64]"
  del t473, t488
  t477 = torch.cat((t476, t472), -1)  # t477: "cuda:0 bf16[1, 32, 512, 128]"
    # t477 = ltorch.cat((t476, t472), -1)  # t477: "cuda:0 bf16[1, 32, 512, 128]"
      # t477 = prims.cat((t476, t472), -1)  # t477: "cuda:0 bf16[1, 32, 512, 128]"
  del t476, t472
  t492 = torch.cat((t491, t487), -1)  # t492: "cuda:0 bf16[1, 32, 512, 128]"
    # t492 = ltorch.cat((t491, t487), -1)  # t492: "cuda:0 bf16[1, 32, 512, 128]"
      # t492 = prims.cat((t491, t487), -1)  # t492: "cuda:0 bf16[1, 32, 512, 128]"
  del t491, t487
  [t485, t500] = nvFusion17(t154, t157, t471, t477, t486, t492)
    # t479 = prims.convert_element_type(t471, dtypes.float32)  # t479: "cuda:0 f32[1, 32, 512, 128]"
    # t494 = prims.convert_element_type(t486, dtypes.float32)  # t494: "cuda:0 f32[1, 32, 512, 128]"
    # t480 = prims.mul(t479, t154)  # t480: "cuda:0 f32[1, 32, 512, 128]"
    # t482 = prims.convert_element_type(t477, dtypes.float32)  # t482: "cuda:0 f32[1, 32, 512, 128]"
    # t483 = prims.mul(t482, t157)  # t483: "cuda:0 f32[1, 32, 512, 128]"
    # t484 = prims.add(t480, t483)  # t484: "cuda:0 f32[1, 32, 512, 128]"
    # t485 = prims.convert_element_type(t484, dtypes.bfloat16)  # t485: "cuda:0 bf16[1, 32, 512, 128]"
    # t495 = prims.mul(t494, t154)  # t495: "cuda:0 f32[1, 32, 512, 128]"
    # t497 = prims.convert_element_type(t492, dtypes.float32)  # t497: "cuda:0 f32[1, 32, 512, 128]"
    # t498 = prims.mul(t497, t157)  # t498: "cuda:0 f32[1, 32, 512, 128]"
    # t499 = prims.add(t495, t498)  # t499: "cuda:0 f32[1, 32, 512, 128]"
    # t500 = prims.convert_element_type(t499, dtypes.bfloat16)  # t500: "cuda:0 bf16[1, 32, 512, 128]"
  del t471, t477, t486, t492
  t502 = torch.cat((t485, t501), -1)  # t502: "cuda:0 bf16[1, 32, 512, 128]"
    # t502 = ltorch.cat((t485, t501), -1)  # t502: "cuda:0 bf16[1, 32, 512, 128]"
      # t502 = prims.cat((t485, t501), -1)  # t502: "cuda:0 bf16[1, 32, 512, 128]"
  del t485, t501
  t504 = torch.cat((t500, t503), -1)  # t504: "cuda:0 bf16[1, 32, 512, 128]"
    # t504 = ltorch.cat((t500, t503), -1)  # t504: "cuda:0 bf16[1, 32, 512, 128]"
      # t504 = prims.cat((t500, t503), -1)  # t504: "cuda:0 bf16[1, 32, 512, 128]"
  del t500, t503
  (t505, t506, t507, t508, _, _, t509, t510, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t502, t504, t470, 0.0, True, scale=0.08838834764831843)
  t512 = torch.permute(t505, (0, 2, 1, 3))  # t512: "cuda:0 bf16[1, 512, 32, 128]"
    # t512 = ltorch.permute(t505, (0, 2, 1, 3))  # t512: "cuda:0 bf16[1, 512, 32, 128]"
      # t512 = prims.transpose(t505, (0, 2, 1, 3))  # t512: "cuda:0 bf16[1, 512, 32, 128]"
  t513 = torch.reshape(t512, (1, 512, 4096))  # t513: "cuda:0 bf16[1, 512, 4096]"
    # t513 = ltorch.reshape(t512, (1, 512, 4096))  # t513: "cuda:0 bf16[1, 512, 4096]"
      # t513 = prims.reshape(t512, (1, 512, 4096))  # t513: "cuda:0 bf16[1, 512, 4096]"
  del t512
  t514 = torch.nn.functional.linear(t513, t91, None)  # t514: "cuda:0 bf16[1, 512, 4096]"
    # t514 = ltorch.linear(t513, t91, None)  # t514: "cuda:0 bf16[1, 512, 4096]"
      # t514 = prims.linear(t513, t91, None)  # t514: "cuda:0 bf16[1, 512, 4096]"
  [t518, t525, t533] = nvFusion18(t446, t514, t529)
    # t516 = prims.convert_element_type(t446, dtypes.float32)  # t516: "cuda:0 f32[1, 512, 4096]"
    # t515 = prims.convert_element_type(t514, dtypes.float32)  # t515: "cuda:0 f32[1, 512, 4096]"
    # t517 = prims.add(t515, t516)  # t517: "cuda:0 f32[1, 512, 4096]"
    # t518 = prims.convert_element_type(t517, dtypes.bfloat16)  # t518: "cuda:0 bf16[1, 512, 4096]"
    # t520 = prims.mul(t517, t517)  # t520: "cuda:0 f32[1, 512, 4096]"
    # t521 = prims.sum(t520, (2,))  # t521: "cuda:0 f32[1, 512]"
    # t522 = prims.broadcast_in_dim(t521, [1, 512, 1], [0, 1])  # t522: "cuda:0 f32[1, 512, 1]"
    # t523 = prims.div(t522, 4096.0)  # t523: "cuda:0 f32[1, 512, 1]"
    # t524 = prims.add(t523, 1e-05)  # t524: "cuda:0 f32[1, 512, 1]"
    # t525 = prims.rsqrt(t524)  # t525: "cuda:0 f32[1, 512, 1]"
    # t526 = prims.broadcast_in_dim(t525, (1, 512, 4096), (0, 1, 2))  # t526: "cuda:0 f32[1, 512, 4096]"
    # t527 = prims.mul(t517, t526)  # t527: "cuda:0 f32[1, 512, 4096]"
    # t531 = prims.convert_element_type(t529, dtypes.float32)  # t531: "cuda:0 f32[1, 512, 4096]"
    # t532 = prims.mul(t527, t531)  # t532: "cuda:0 f32[1, 512, 4096]"
    # t533 = prims.convert_element_type(t532, dtypes.bfloat16)  # t533: "cuda:0 bf16[1, 512, 4096]"
  t534 = torch.nn.functional.linear(t533, t22, None)  # t534: "cuda:0 bf16[1, 512, 11008]"
    # t534 = ltorch.linear(t533, t22, None)  # t534: "cuda:0 bf16[1, 512, 11008]"
      # t534 = prims.linear(t533, t22, None)  # t534: "cuda:0 bf16[1, 512, 11008]"
  t535 = torch.nn.functional.linear(t533, t38, None)  # t535: "cuda:0 bf16[1, 512, 11008]"
    # t535 = ltorch.linear(t533, t38, None)  # t535: "cuda:0 bf16[1, 512, 11008]"
      # t535 = prims.linear(t533, t38, None)  # t535: "cuda:0 bf16[1, 512, 11008]"
  [t549] = nvFusion19(t534, t535)
    # t536 = prims.convert_element_type(t534, dtypes.float32)  # t536: "cuda:0 f32[1, 512, 11008]"
    # t537 = prims.neg(t536)  # t537: "cuda:0 f32[1, 512, 11008]"
    # t538 = prims.exp(t537)  # t538: "cuda:0 f32[1, 512, 11008]"
    # t539 = prims.add(1.0, t538)  # t539: "cuda:0 f32[1, 512, 11008]"
    # t540 = prims.reciprocal(t539)  # t540: "cuda:0 f32[1, 512, 11008]"
    # t544 = prims.mul(t536, t540)  # t544: "cuda:0 f32[1, 512, 11008]"
    # t547 = prims.convert_element_type(t535, dtypes.float32)  # t547: "cuda:0 f32[1, 512, 11008]"
    # t548 = prims.mul(t544, t547)  # t548: "cuda:0 f32[1, 512, 11008]"
    # t549 = prims.convert_element_type(t548, dtypes.bfloat16)  # t549: "cuda:0 bf16[1, 512, 11008]"
  t550 = torch.nn.functional.linear(t549, t92, None)  # t550: "cuda:0 bf16[1, 512, 4096]"
    # t550 = ltorch.linear(t549, t92, None)  # t550: "cuda:0 bf16[1, 512, 4096]"
      # t550 = prims.linear(t549, t92, None)  # t550: "cuda:0 bf16[1, 512, 4096]"
  [t554, t561, t569] = nvFusion20(t518, t550, t565)
    # t552 = prims.convert_element_type(t518, dtypes.float32)  # t552: "cuda:0 f32[1, 512, 4096]"
    # t551 = prims.convert_element_type(t550, dtypes.float32)  # t551: "cuda:0 f32[1, 512, 4096]"
    # t553 = prims.add(t551, t552)  # t553: "cuda:0 f32[1, 512, 4096]"
    # t554 = prims.convert_element_type(t553, dtypes.bfloat16)  # t554: "cuda:0 bf16[1, 512, 4096]"
    # t556 = prims.mul(t553, t553)  # t556: "cuda:0 f32[1, 512, 4096]"
    # t557 = prims.sum(t556, (2,))  # t557: "cuda:0 f32[1, 512]"
    # t558 = prims.broadcast_in_dim(t557, [1, 512, 1], [0, 1])  # t558: "cuda:0 f32[1, 512, 1]"
    # t559 = prims.div(t558, 4096.0)  # t559: "cuda:0 f32[1, 512, 1]"
    # t560 = prims.add(t559, 1e-05)  # t560: "cuda:0 f32[1, 512, 1]"
    # t561 = prims.rsqrt(t560)  # t561: "cuda:0 f32[1, 512, 1]"
    # t562 = prims.broadcast_in_dim(t561, (1, 512, 4096), (0, 1, 2))  # t562: "cuda:0 f32[1, 512, 4096]"
    # t563 = prims.mul(t553, t562)  # t563: "cuda:0 f32[1, 512, 4096]"
    # t567 = prims.convert_element_type(t565, dtypes.float32)  # t567: "cuda:0 f32[1, 512, 4096]"
    # t568 = prims.mul(t563, t567)  # t568: "cuda:0 f32[1, 512, 4096]"
    # t569 = prims.convert_element_type(t568, dtypes.bfloat16)  # t569: "cuda:0 bf16[1, 512, 4096]"
  t570 = torch.nn.functional.linear(t569, t7, None)  # t570: "cuda:0 bf16[1, 512, 12288]"
    # t570 = ltorch.linear(t569, t7, None)  # t570: "cuda:0 bf16[1, 512, 12288]"
      # t570 = prims.linear(t569, t7, None)  # t570: "cuda:0 bf16[1, 512, 12288]"
  t571 = torch.reshape(t570, (1, 512, 32, 3, 128))  # t571: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t571 = ltorch.reshape(t570, (1, 512, 32, 3, 128))  # t571: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t571 = prims.reshape(t570, (1, 512, 32, 3, 128))  # t571: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t570
  t572 = torch.permute(t571, (0, 2, 3, 1, 4))  # t572: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t572 = ltorch.permute(t571, (0, 2, 3, 1, 4))  # t572: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t572 = prims.transpose(t571, (0, 2, 3, 1, 4))  # t572: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t571
  (t573, t574, t575) = torch.split(t572, (1, 1, 1), 2)
    # (t573, t574, t575) = ltorch.split(t572, (1, 1, 1), 2)
      # t573 = prims.slice_prim(t572, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t573: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t574 = prims.slice_prim(t572, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t574: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t575 = prims.slice_prim(t572, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t575: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t572
  t576 = torch.reshape(t573, (1, 32, 512, 128))  # t576: "cuda:0 bf16[1, 32, 512, 128]"
    # t576 = ltorch.reshape(t573, (1, 32, 512, 128))  # t576: "cuda:0 bf16[1, 32, 512, 128]"
      # t576 = prims.reshape(t573, (1, 32, 512, 128))  # t576: "cuda:0 bf16[1, 32, 512, 128]"
  del t573
  t577 = torch.reshape(t574, (1, 32, 512, 128))  # t577: "cuda:0 bf16[1, 32, 512, 128]"
    # t577 = ltorch.reshape(t574, (1, 32, 512, 128))  # t577: "cuda:0 bf16[1, 32, 512, 128]"
      # t577 = prims.reshape(t574, (1, 32, 512, 128))  # t577: "cuda:0 bf16[1, 32, 512, 128]"
  del t574
  t578 = torch.reshape(t575, (1, 32, 512, 128))  # t578: "cuda:0 bf16[1, 32, 512, 128]"
    # t578 = ltorch.reshape(t575, (1, 32, 512, 128))  # t578: "cuda:0 bf16[1, 32, 512, 128]"
      # t578 = prims.reshape(t575, (1, 32, 512, 128))  # t578: "cuda:0 bf16[1, 32, 512, 128]"
  del t575
  t579 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t579: "cuda:0 bf16[1, 32, 512, 128]"
  t594 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t594: "cuda:0 bf16[1, 32, 512, 128]"
  t609 = torch_slice_prim_impl(t576, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t609: "cuda:0 bf16[1, 32, 512, 0]"
  del t576
  t611 = torch_slice_prim_impl(t577, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t611: "cuda:0 bf16[1, 32, 512, 0]"
  del t577
  t580 = torch_slice_prim_impl(t579, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t580: "cuda:0 bf16[1, 32, 512, 64]"
  t581 = torch_slice_prim_impl(t579, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t581: "cuda:0 bf16[1, 32, 512, 64]"
  t595 = torch_slice_prim_impl(t594, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t595: "cuda:0 bf16[1, 32, 512, 64]"
  t596 = torch_slice_prim_impl(t594, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t596: "cuda:0 bf16[1, 32, 512, 64]"
  [t584, t599] = nvFusion21(t579, t581, t594, t596)
    # t582 = prims.convert_element_type(t581, dtypes.float32)  # t582: "cuda:0 f32[1, 32, 512, 64]"
    # t583 = prims.neg(t582)  # t583: "cuda:0 f32[1, 32, 512, 64]"
    # t584 = prims.convert_element_type(t583, dtypes.bfloat16)  # t584: "cuda:0 bf16[1, 32, 512, 64]"
    # t597 = prims.convert_element_type(t596, dtypes.float32)  # t597: "cuda:0 f32[1, 32, 512, 64]"
    # t598 = prims.neg(t597)  # t598: "cuda:0 f32[1, 32, 512, 64]"
    # t599 = prims.convert_element_type(t598, dtypes.bfloat16)  # t599: "cuda:0 bf16[1, 32, 512, 64]"
  del t581, t596
  t600 = torch.cat((t599, t595), -1)  # t600: "cuda:0 bf16[1, 32, 512, 128]"
    # t600 = ltorch.cat((t599, t595), -1)  # t600: "cuda:0 bf16[1, 32, 512, 128]"
      # t600 = prims.cat((t599, t595), -1)  # t600: "cuda:0 bf16[1, 32, 512, 128]"
  del t599, t595
  t585 = torch.cat((t584, t580), -1)  # t585: "cuda:0 bf16[1, 32, 512, 128]"
    # t585 = ltorch.cat((t584, t580), -1)  # t585: "cuda:0 bf16[1, 32, 512, 128]"
      # t585 = prims.cat((t584, t580), -1)  # t585: "cuda:0 bf16[1, 32, 512, 128]"
  del t584, t580
  [t593, t608] = nvFusion22(t154, t157, t579, t585, t594, t600)
    # t587 = prims.convert_element_type(t579, dtypes.float32)  # t587: "cuda:0 f32[1, 32, 512, 128]"
    # t602 = prims.convert_element_type(t594, dtypes.float32)  # t602: "cuda:0 f32[1, 32, 512, 128]"
    # t603 = prims.mul(t602, t154)  # t603: "cuda:0 f32[1, 32, 512, 128]"
    # t605 = prims.convert_element_type(t600, dtypes.float32)  # t605: "cuda:0 f32[1, 32, 512, 128]"
    # t606 = prims.mul(t605, t157)  # t606: "cuda:0 f32[1, 32, 512, 128]"
    # t607 = prims.add(t603, t606)  # t607: "cuda:0 f32[1, 32, 512, 128]"
    # t608 = prims.convert_element_type(t607, dtypes.bfloat16)  # t608: "cuda:0 bf16[1, 32, 512, 128]"
    # t588 = prims.mul(t587, t154)  # t588: "cuda:0 f32[1, 32, 512, 128]"
    # t590 = prims.convert_element_type(t585, dtypes.float32)  # t590: "cuda:0 f32[1, 32, 512, 128]"
    # t591 = prims.mul(t590, t157)  # t591: "cuda:0 f32[1, 32, 512, 128]"
    # t592 = prims.add(t588, t591)  # t592: "cuda:0 f32[1, 32, 512, 128]"
    # t593 = prims.convert_element_type(t592, dtypes.bfloat16)  # t593: "cuda:0 bf16[1, 32, 512, 128]"
  del t579, t585, t594, t600
  t612 = torch.cat((t608, t611), -1)  # t612: "cuda:0 bf16[1, 32, 512, 128]"
    # t612 = ltorch.cat((t608, t611), -1)  # t612: "cuda:0 bf16[1, 32, 512, 128]"
      # t612 = prims.cat((t608, t611), -1)  # t612: "cuda:0 bf16[1, 32, 512, 128]"
  del t608, t611
  t610 = torch.cat((t593, t609), -1)  # t610: "cuda:0 bf16[1, 32, 512, 128]"
    # t610 = ltorch.cat((t593, t609), -1)  # t610: "cuda:0 bf16[1, 32, 512, 128]"
      # t610 = prims.cat((t593, t609), -1)  # t610: "cuda:0 bf16[1, 32, 512, 128]"
  del t593, t609
  (t613, t614, t615, t616, _, _, t617, t618, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t610, t612, t578, 0.0, True, scale=0.08838834764831843)
  t620 = torch.permute(t613, (0, 2, 1, 3))  # t620: "cuda:0 bf16[1, 512, 32, 128]"
    # t620 = ltorch.permute(t613, (0, 2, 1, 3))  # t620: "cuda:0 bf16[1, 512, 32, 128]"
      # t620 = prims.transpose(t613, (0, 2, 1, 3))  # t620: "cuda:0 bf16[1, 512, 32, 128]"
  t621 = torch.reshape(t620, (1, 512, 4096))  # t621: "cuda:0 bf16[1, 512, 4096]"
    # t621 = ltorch.reshape(t620, (1, 512, 4096))  # t621: "cuda:0 bf16[1, 512, 4096]"
      # t621 = prims.reshape(t620, (1, 512, 4096))  # t621: "cuda:0 bf16[1, 512, 4096]"
  del t620
  t622 = torch.nn.functional.linear(t621, t93, None)  # t622: "cuda:0 bf16[1, 512, 4096]"
    # t622 = ltorch.linear(t621, t93, None)  # t622: "cuda:0 bf16[1, 512, 4096]"
      # t622 = prims.linear(t621, t93, None)  # t622: "cuda:0 bf16[1, 512, 4096]"
  [t626, t633, t641] = nvFusion23(t554, t622, t637)
    # t624 = prims.convert_element_type(t554, dtypes.float32)  # t624: "cuda:0 f32[1, 512, 4096]"
    # t623 = prims.convert_element_type(t622, dtypes.float32)  # t623: "cuda:0 f32[1, 512, 4096]"
    # t625 = prims.add(t623, t624)  # t625: "cuda:0 f32[1, 512, 4096]"
    # t626 = prims.convert_element_type(t625, dtypes.bfloat16)  # t626: "cuda:0 bf16[1, 512, 4096]"
    # t628 = prims.mul(t625, t625)  # t628: "cuda:0 f32[1, 512, 4096]"
    # t629 = prims.sum(t628, (2,))  # t629: "cuda:0 f32[1, 512]"
    # t630 = prims.broadcast_in_dim(t629, [1, 512, 1], [0, 1])  # t630: "cuda:0 f32[1, 512, 1]"
    # t631 = prims.div(t630, 4096.0)  # t631: "cuda:0 f32[1, 512, 1]"
    # t632 = prims.add(t631, 1e-05)  # t632: "cuda:0 f32[1, 512, 1]"
    # t633 = prims.rsqrt(t632)  # t633: "cuda:0 f32[1, 512, 1]"
    # t634 = prims.broadcast_in_dim(t633, (1, 512, 4096), (0, 1, 2))  # t634: "cuda:0 f32[1, 512, 4096]"
    # t635 = prims.mul(t625, t634)  # t635: "cuda:0 f32[1, 512, 4096]"
    # t639 = prims.convert_element_type(t637, dtypes.float32)  # t639: "cuda:0 f32[1, 512, 4096]"
    # t640 = prims.mul(t635, t639)  # t640: "cuda:0 f32[1, 512, 4096]"
    # t641 = prims.convert_element_type(t640, dtypes.bfloat16)  # t641: "cuda:0 bf16[1, 512, 4096]"
  t643 = torch.nn.functional.linear(t641, t39, None)  # t643: "cuda:0 bf16[1, 512, 11008]"
    # t643 = ltorch.linear(t641, t39, None)  # t643: "cuda:0 bf16[1, 512, 11008]"
      # t643 = prims.linear(t641, t39, None)  # t643: "cuda:0 bf16[1, 512, 11008]"
  t642 = torch.nn.functional.linear(t641, t23, None)  # t642: "cuda:0 bf16[1, 512, 11008]"
    # t642 = ltorch.linear(t641, t23, None)  # t642: "cuda:0 bf16[1, 512, 11008]"
      # t642 = prims.linear(t641, t23, None)  # t642: "cuda:0 bf16[1, 512, 11008]"
  [t657] = nvFusion24(t642, t643)
    # t644 = prims.convert_element_type(t642, dtypes.float32)  # t644: "cuda:0 f32[1, 512, 11008]"
    # t645 = prims.neg(t644)  # t645: "cuda:0 f32[1, 512, 11008]"
    # t646 = prims.exp(t645)  # t646: "cuda:0 f32[1, 512, 11008]"
    # t647 = prims.add(1.0, t646)  # t647: "cuda:0 f32[1, 512, 11008]"
    # t648 = prims.reciprocal(t647)  # t648: "cuda:0 f32[1, 512, 11008]"
    # t652 = prims.mul(t644, t648)  # t652: "cuda:0 f32[1, 512, 11008]"
    # t655 = prims.convert_element_type(t643, dtypes.float32)  # t655: "cuda:0 f32[1, 512, 11008]"
    # t656 = prims.mul(t652, t655)  # t656: "cuda:0 f32[1, 512, 11008]"
    # t657 = prims.convert_element_type(t656, dtypes.bfloat16)  # t657: "cuda:0 bf16[1, 512, 11008]"
  t658 = torch.nn.functional.linear(t657, t94, None)  # t658: "cuda:0 bf16[1, 512, 4096]"
    # t658 = ltorch.linear(t657, t94, None)  # t658: "cuda:0 bf16[1, 512, 4096]"
      # t658 = prims.linear(t657, t94, None)  # t658: "cuda:0 bf16[1, 512, 4096]"
  [t662, t669, t677] = nvFusion25(t626, t658, t673)
    # t660 = prims.convert_element_type(t626, dtypes.float32)  # t660: "cuda:0 f32[1, 512, 4096]"
    # t659 = prims.convert_element_type(t658, dtypes.float32)  # t659: "cuda:0 f32[1, 512, 4096]"
    # t661 = prims.add(t659, t660)  # t661: "cuda:0 f32[1, 512, 4096]"
    # t662 = prims.convert_element_type(t661, dtypes.bfloat16)  # t662: "cuda:0 bf16[1, 512, 4096]"
    # t664 = prims.mul(t661, t661)  # t664: "cuda:0 f32[1, 512, 4096]"
    # t665 = prims.sum(t664, (2,))  # t665: "cuda:0 f32[1, 512]"
    # t666 = prims.broadcast_in_dim(t665, [1, 512, 1], [0, 1])  # t666: "cuda:0 f32[1, 512, 1]"
    # t667 = prims.div(t666, 4096.0)  # t667: "cuda:0 f32[1, 512, 1]"
    # t668 = prims.add(t667, 1e-05)  # t668: "cuda:0 f32[1, 512, 1]"
    # t669 = prims.rsqrt(t668)  # t669: "cuda:0 f32[1, 512, 1]"
    # t670 = prims.broadcast_in_dim(t669, (1, 512, 4096), (0, 1, 2))  # t670: "cuda:0 f32[1, 512, 4096]"
    # t671 = prims.mul(t661, t670)  # t671: "cuda:0 f32[1, 512, 4096]"
    # t675 = prims.convert_element_type(t673, dtypes.float32)  # t675: "cuda:0 f32[1, 512, 4096]"
    # t676 = prims.mul(t671, t675)  # t676: "cuda:0 f32[1, 512, 4096]"
    # t677 = prims.convert_element_type(t676, dtypes.bfloat16)  # t677: "cuda:0 bf16[1, 512, 4096]"
  t678 = torch.nn.functional.linear(t677, t8, None)  # t678: "cuda:0 bf16[1, 512, 12288]"
    # t678 = ltorch.linear(t677, t8, None)  # t678: "cuda:0 bf16[1, 512, 12288]"
      # t678 = prims.linear(t677, t8, None)  # t678: "cuda:0 bf16[1, 512, 12288]"
  t679 = torch.reshape(t678, (1, 512, 32, 3, 128))  # t679: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t679 = ltorch.reshape(t678, (1, 512, 32, 3, 128))  # t679: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t679 = prims.reshape(t678, (1, 512, 32, 3, 128))  # t679: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t678
  t680 = torch.permute(t679, (0, 2, 3, 1, 4))  # t680: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t680 = ltorch.permute(t679, (0, 2, 3, 1, 4))  # t680: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t680 = prims.transpose(t679, (0, 2, 3, 1, 4))  # t680: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t679
  (t681, t682, t683) = torch.split(t680, (1, 1, 1), 2)
    # (t681, t682, t683) = ltorch.split(t680, (1, 1, 1), 2)
      # t681 = prims.slice_prim(t680, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t681: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t682 = prims.slice_prim(t680, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t682: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t683 = prims.slice_prim(t680, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t683: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t680
  t684 = torch.reshape(t681, (1, 32, 512, 128))  # t684: "cuda:0 bf16[1, 32, 512, 128]"
    # t684 = ltorch.reshape(t681, (1, 32, 512, 128))  # t684: "cuda:0 bf16[1, 32, 512, 128]"
      # t684 = prims.reshape(t681, (1, 32, 512, 128))  # t684: "cuda:0 bf16[1, 32, 512, 128]"
  del t681
  t685 = torch.reshape(t682, (1, 32, 512, 128))  # t685: "cuda:0 bf16[1, 32, 512, 128]"
    # t685 = ltorch.reshape(t682, (1, 32, 512, 128))  # t685: "cuda:0 bf16[1, 32, 512, 128]"
      # t685 = prims.reshape(t682, (1, 32, 512, 128))  # t685: "cuda:0 bf16[1, 32, 512, 128]"
  del t682
  t686 = torch.reshape(t683, (1, 32, 512, 128))  # t686: "cuda:0 bf16[1, 32, 512, 128]"
    # t686 = ltorch.reshape(t683, (1, 32, 512, 128))  # t686: "cuda:0 bf16[1, 32, 512, 128]"
      # t686 = prims.reshape(t683, (1, 32, 512, 128))  # t686: "cuda:0 bf16[1, 32, 512, 128]"
  del t683
  t687 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t687: "cuda:0 bf16[1, 32, 512, 128]"
  t702 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t702: "cuda:0 bf16[1, 32, 512, 128]"
  t717 = torch_slice_prim_impl(t684, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t717: "cuda:0 bf16[1, 32, 512, 0]"
  del t684
  t719 = torch_slice_prim_impl(t685, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t719: "cuda:0 bf16[1, 32, 512, 0]"
  del t685
  t688 = torch_slice_prim_impl(t687, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t688: "cuda:0 bf16[1, 32, 512, 64]"
  t689 = torch_slice_prim_impl(t687, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t689: "cuda:0 bf16[1, 32, 512, 64]"
  t703 = torch_slice_prim_impl(t702, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t703: "cuda:0 bf16[1, 32, 512, 64]"
  t704 = torch_slice_prim_impl(t702, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t704: "cuda:0 bf16[1, 32, 512, 64]"
  [t692, t707] = nvFusion26(t687, t689, t702, t704)
    # t690 = prims.convert_element_type(t689, dtypes.float32)  # t690: "cuda:0 f32[1, 32, 512, 64]"
    # t691 = prims.neg(t690)  # t691: "cuda:0 f32[1, 32, 512, 64]"
    # t692 = prims.convert_element_type(t691, dtypes.bfloat16)  # t692: "cuda:0 bf16[1, 32, 512, 64]"
    # t705 = prims.convert_element_type(t704, dtypes.float32)  # t705: "cuda:0 f32[1, 32, 512, 64]"
    # t706 = prims.neg(t705)  # t706: "cuda:0 f32[1, 32, 512, 64]"
    # t707 = prims.convert_element_type(t706, dtypes.bfloat16)  # t707: "cuda:0 bf16[1, 32, 512, 64]"
  del t689, t704
  t708 = torch.cat((t707, t703), -1)  # t708: "cuda:0 bf16[1, 32, 512, 128]"
    # t708 = ltorch.cat((t707, t703), -1)  # t708: "cuda:0 bf16[1, 32, 512, 128]"
      # t708 = prims.cat((t707, t703), -1)  # t708: "cuda:0 bf16[1, 32, 512, 128]"
  del t707, t703
  t693 = torch.cat((t692, t688), -1)  # t693: "cuda:0 bf16[1, 32, 512, 128]"
    # t693 = ltorch.cat((t692, t688), -1)  # t693: "cuda:0 bf16[1, 32, 512, 128]"
      # t693 = prims.cat((t692, t688), -1)  # t693: "cuda:0 bf16[1, 32, 512, 128]"
  del t692, t688
  [t701, t716] = nvFusion27(t154, t157, t687, t693, t702, t708)
    # t695 = prims.convert_element_type(t687, dtypes.float32)  # t695: "cuda:0 f32[1, 32, 512, 128]"
    # t710 = prims.convert_element_type(t702, dtypes.float32)  # t710: "cuda:0 f32[1, 32, 512, 128]"
    # t711 = prims.mul(t710, t154)  # t711: "cuda:0 f32[1, 32, 512, 128]"
    # t713 = prims.convert_element_type(t708, dtypes.float32)  # t713: "cuda:0 f32[1, 32, 512, 128]"
    # t714 = prims.mul(t713, t157)  # t714: "cuda:0 f32[1, 32, 512, 128]"
    # t715 = prims.add(t711, t714)  # t715: "cuda:0 f32[1, 32, 512, 128]"
    # t716 = prims.convert_element_type(t715, dtypes.bfloat16)  # t716: "cuda:0 bf16[1, 32, 512, 128]"
    # t696 = prims.mul(t695, t154)  # t696: "cuda:0 f32[1, 32, 512, 128]"
    # t698 = prims.convert_element_type(t693, dtypes.float32)  # t698: "cuda:0 f32[1, 32, 512, 128]"
    # t699 = prims.mul(t698, t157)  # t699: "cuda:0 f32[1, 32, 512, 128]"
    # t700 = prims.add(t696, t699)  # t700: "cuda:0 f32[1, 32, 512, 128]"
    # t701 = prims.convert_element_type(t700, dtypes.bfloat16)  # t701: "cuda:0 bf16[1, 32, 512, 128]"
  del t687, t693, t702, t708
  t720 = torch.cat((t716, t719), -1)  # t720: "cuda:0 bf16[1, 32, 512, 128]"
    # t720 = ltorch.cat((t716, t719), -1)  # t720: "cuda:0 bf16[1, 32, 512, 128]"
      # t720 = prims.cat((t716, t719), -1)  # t720: "cuda:0 bf16[1, 32, 512, 128]"
  del t716, t719
  t718 = torch.cat((t701, t717), -1)  # t718: "cuda:0 bf16[1, 32, 512, 128]"
    # t718 = ltorch.cat((t701, t717), -1)  # t718: "cuda:0 bf16[1, 32, 512, 128]"
      # t718 = prims.cat((t701, t717), -1)  # t718: "cuda:0 bf16[1, 32, 512, 128]"
  del t701, t717
  (t721, t722, t723, t724, _, _, t725, t726, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t718, t720, t686, 0.0, True, scale=0.08838834764831843)
  t728 = torch.permute(t721, (0, 2, 1, 3))  # t728: "cuda:0 bf16[1, 512, 32, 128]"
    # t728 = ltorch.permute(t721, (0, 2, 1, 3))  # t728: "cuda:0 bf16[1, 512, 32, 128]"
      # t728 = prims.transpose(t721, (0, 2, 1, 3))  # t728: "cuda:0 bf16[1, 512, 32, 128]"
  t729 = torch.reshape(t728, (1, 512, 4096))  # t729: "cuda:0 bf16[1, 512, 4096]"
    # t729 = ltorch.reshape(t728, (1, 512, 4096))  # t729: "cuda:0 bf16[1, 512, 4096]"
      # t729 = prims.reshape(t728, (1, 512, 4096))  # t729: "cuda:0 bf16[1, 512, 4096]"
  del t728
  t730 = torch.nn.functional.linear(t729, t95, None)  # t730: "cuda:0 bf16[1, 512, 4096]"
    # t730 = ltorch.linear(t729, t95, None)  # t730: "cuda:0 bf16[1, 512, 4096]"
      # t730 = prims.linear(t729, t95, None)  # t730: "cuda:0 bf16[1, 512, 4096]"
  [t734, t741, t749] = nvFusion28(t662, t730, t745)
    # t732 = prims.convert_element_type(t662, dtypes.float32)  # t732: "cuda:0 f32[1, 512, 4096]"
    # t731 = prims.convert_element_type(t730, dtypes.float32)  # t731: "cuda:0 f32[1, 512, 4096]"
    # t733 = prims.add(t731, t732)  # t733: "cuda:0 f32[1, 512, 4096]"
    # t734 = prims.convert_element_type(t733, dtypes.bfloat16)  # t734: "cuda:0 bf16[1, 512, 4096]"
    # t736 = prims.mul(t733, t733)  # t736: "cuda:0 f32[1, 512, 4096]"
    # t737 = prims.sum(t736, (2,))  # t737: "cuda:0 f32[1, 512]"
    # t738 = prims.broadcast_in_dim(t737, [1, 512, 1], [0, 1])  # t738: "cuda:0 f32[1, 512, 1]"
    # t739 = prims.div(t738, 4096.0)  # t739: "cuda:0 f32[1, 512, 1]"
    # t740 = prims.add(t739, 1e-05)  # t740: "cuda:0 f32[1, 512, 1]"
    # t741 = prims.rsqrt(t740)  # t741: "cuda:0 f32[1, 512, 1]"
    # t742 = prims.broadcast_in_dim(t741, (1, 512, 4096), (0, 1, 2))  # t742: "cuda:0 f32[1, 512, 4096]"
    # t743 = prims.mul(t733, t742)  # t743: "cuda:0 f32[1, 512, 4096]"
    # t747 = prims.convert_element_type(t745, dtypes.float32)  # t747: "cuda:0 f32[1, 512, 4096]"
    # t748 = prims.mul(t743, t747)  # t748: "cuda:0 f32[1, 512, 4096]"
    # t749 = prims.convert_element_type(t748, dtypes.bfloat16)  # t749: "cuda:0 bf16[1, 512, 4096]"
  t750 = torch.nn.functional.linear(t749, t24, None)  # t750: "cuda:0 bf16[1, 512, 11008]"
    # t750 = ltorch.linear(t749, t24, None)  # t750: "cuda:0 bf16[1, 512, 11008]"
      # t750 = prims.linear(t749, t24, None)  # t750: "cuda:0 bf16[1, 512, 11008]"
  t751 = torch.nn.functional.linear(t749, t40, None)  # t751: "cuda:0 bf16[1, 512, 11008]"
    # t751 = ltorch.linear(t749, t40, None)  # t751: "cuda:0 bf16[1, 512, 11008]"
      # t751 = prims.linear(t749, t40, None)  # t751: "cuda:0 bf16[1, 512, 11008]"
  [t765] = nvFusion29(t750, t751)
    # t752 = prims.convert_element_type(t750, dtypes.float32)  # t752: "cuda:0 f32[1, 512, 11008]"
    # t753 = prims.neg(t752)  # t753: "cuda:0 f32[1, 512, 11008]"
    # t754 = prims.exp(t753)  # t754: "cuda:0 f32[1, 512, 11008]"
    # t755 = prims.add(1.0, t754)  # t755: "cuda:0 f32[1, 512, 11008]"
    # t756 = prims.reciprocal(t755)  # t756: "cuda:0 f32[1, 512, 11008]"
    # t760 = prims.mul(t752, t756)  # t760: "cuda:0 f32[1, 512, 11008]"
    # t763 = prims.convert_element_type(t751, dtypes.float32)  # t763: "cuda:0 f32[1, 512, 11008]"
    # t764 = prims.mul(t760, t763)  # t764: "cuda:0 f32[1, 512, 11008]"
    # t765 = prims.convert_element_type(t764, dtypes.bfloat16)  # t765: "cuda:0 bf16[1, 512, 11008]"
  t766 = torch.nn.functional.linear(t765, t96, None)  # t766: "cuda:0 bf16[1, 512, 4096]"
    # t766 = ltorch.linear(t765, t96, None)  # t766: "cuda:0 bf16[1, 512, 4096]"
      # t766 = prims.linear(t765, t96, None)  # t766: "cuda:0 bf16[1, 512, 4096]"
  [t770, t777, t785] = nvFusion30(t734, t766, t781)
    # t768 = prims.convert_element_type(t734, dtypes.float32)  # t768: "cuda:0 f32[1, 512, 4096]"
    # t767 = prims.convert_element_type(t766, dtypes.float32)  # t767: "cuda:0 f32[1, 512, 4096]"
    # t769 = prims.add(t767, t768)  # t769: "cuda:0 f32[1, 512, 4096]"
    # t770 = prims.convert_element_type(t769, dtypes.bfloat16)  # t770: "cuda:0 bf16[1, 512, 4096]"
    # t772 = prims.mul(t769, t769)  # t772: "cuda:0 f32[1, 512, 4096]"
    # t773 = prims.sum(t772, (2,))  # t773: "cuda:0 f32[1, 512]"
    # t774 = prims.broadcast_in_dim(t773, [1, 512, 1], [0, 1])  # t774: "cuda:0 f32[1, 512, 1]"
    # t775 = prims.div(t774, 4096.0)  # t775: "cuda:0 f32[1, 512, 1]"
    # t776 = prims.add(t775, 1e-05)  # t776: "cuda:0 f32[1, 512, 1]"
    # t777 = prims.rsqrt(t776)  # t777: "cuda:0 f32[1, 512, 1]"
    # t778 = prims.broadcast_in_dim(t777, (1, 512, 4096), (0, 1, 2))  # t778: "cuda:0 f32[1, 512, 4096]"
    # t779 = prims.mul(t769, t778)  # t779: "cuda:0 f32[1, 512, 4096]"
    # t783 = prims.convert_element_type(t781, dtypes.float32)  # t783: "cuda:0 f32[1, 512, 4096]"
    # t784 = prims.mul(t779, t783)  # t784: "cuda:0 f32[1, 512, 4096]"
    # t785 = prims.convert_element_type(t784, dtypes.bfloat16)  # t785: "cuda:0 bf16[1, 512, 4096]"
  t786 = torch.nn.functional.linear(t785, t9, None)  # t786: "cuda:0 bf16[1, 512, 12288]"
    # t786 = ltorch.linear(t785, t9, None)  # t786: "cuda:0 bf16[1, 512, 12288]"
      # t786 = prims.linear(t785, t9, None)  # t786: "cuda:0 bf16[1, 512, 12288]"
  t787 = torch.reshape(t786, (1, 512, 32, 3, 128))  # t787: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t787 = ltorch.reshape(t786, (1, 512, 32, 3, 128))  # t787: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t787 = prims.reshape(t786, (1, 512, 32, 3, 128))  # t787: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t786
  t788 = torch.permute(t787, (0, 2, 3, 1, 4))  # t788: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t788 = ltorch.permute(t787, (0, 2, 3, 1, 4))  # t788: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t788 = prims.transpose(t787, (0, 2, 3, 1, 4))  # t788: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t787
  (t789, t790, t791) = torch.split(t788, (1, 1, 1), 2)
    # (t789, t790, t791) = ltorch.split(t788, (1, 1, 1), 2)
      # t789 = prims.slice_prim(t788, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t789: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t790 = prims.slice_prim(t788, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t790: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t791 = prims.slice_prim(t788, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t791: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t788
  t792 = torch.reshape(t789, (1, 32, 512, 128))  # t792: "cuda:0 bf16[1, 32, 512, 128]"
    # t792 = ltorch.reshape(t789, (1, 32, 512, 128))  # t792: "cuda:0 bf16[1, 32, 512, 128]"
      # t792 = prims.reshape(t789, (1, 32, 512, 128))  # t792: "cuda:0 bf16[1, 32, 512, 128]"
  del t789
  t793 = torch.reshape(t790, (1, 32, 512, 128))  # t793: "cuda:0 bf16[1, 32, 512, 128]"
    # t793 = ltorch.reshape(t790, (1, 32, 512, 128))  # t793: "cuda:0 bf16[1, 32, 512, 128]"
      # t793 = prims.reshape(t790, (1, 32, 512, 128))  # t793: "cuda:0 bf16[1, 32, 512, 128]"
  del t790
  t794 = torch.reshape(t791, (1, 32, 512, 128))  # t794: "cuda:0 bf16[1, 32, 512, 128]"
    # t794 = ltorch.reshape(t791, (1, 32, 512, 128))  # t794: "cuda:0 bf16[1, 32, 512, 128]"
      # t794 = prims.reshape(t791, (1, 32, 512, 128))  # t794: "cuda:0 bf16[1, 32, 512, 128]"
  del t791
  t795 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t795: "cuda:0 bf16[1, 32, 512, 128]"
  t810 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t810: "cuda:0 bf16[1, 32, 512, 128]"
  t825 = torch_slice_prim_impl(t792, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t825: "cuda:0 bf16[1, 32, 512, 0]"
  del t792
  t827 = torch_slice_prim_impl(t793, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t827: "cuda:0 bf16[1, 32, 512, 0]"
  del t793
  t796 = torch_slice_prim_impl(t795, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t796: "cuda:0 bf16[1, 32, 512, 64]"
  t797 = torch_slice_prim_impl(t795, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t797: "cuda:0 bf16[1, 32, 512, 64]"
  t811 = torch_slice_prim_impl(t810, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t811: "cuda:0 bf16[1, 32, 512, 64]"
  t812 = torch_slice_prim_impl(t810, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t812: "cuda:0 bf16[1, 32, 512, 64]"
  [t800, t815] = nvFusion31(t795, t797, t810, t812)
    # t798 = prims.convert_element_type(t797, dtypes.float32)  # t798: "cuda:0 f32[1, 32, 512, 64]"
    # t799 = prims.neg(t798)  # t799: "cuda:0 f32[1, 32, 512, 64]"
    # t800 = prims.convert_element_type(t799, dtypes.bfloat16)  # t800: "cuda:0 bf16[1, 32, 512, 64]"
    # t813 = prims.convert_element_type(t812, dtypes.float32)  # t813: "cuda:0 f32[1, 32, 512, 64]"
    # t814 = prims.neg(t813)  # t814: "cuda:0 f32[1, 32, 512, 64]"
    # t815 = prims.convert_element_type(t814, dtypes.bfloat16)  # t815: "cuda:0 bf16[1, 32, 512, 64]"
  del t797, t812
  t816 = torch.cat((t815, t811), -1)  # t816: "cuda:0 bf16[1, 32, 512, 128]"
    # t816 = ltorch.cat((t815, t811), -1)  # t816: "cuda:0 bf16[1, 32, 512, 128]"
      # t816 = prims.cat((t815, t811), -1)  # t816: "cuda:0 bf16[1, 32, 512, 128]"
  del t815, t811
  t801 = torch.cat((t800, t796), -1)  # t801: "cuda:0 bf16[1, 32, 512, 128]"
    # t801 = ltorch.cat((t800, t796), -1)  # t801: "cuda:0 bf16[1, 32, 512, 128]"
      # t801 = prims.cat((t800, t796), -1)  # t801: "cuda:0 bf16[1, 32, 512, 128]"
  del t800, t796
  [t809, t824] = nvFusion32(t154, t157, t795, t801, t810, t816)
    # t803 = prims.convert_element_type(t795, dtypes.float32)  # t803: "cuda:0 f32[1, 32, 512, 128]"
    # t818 = prims.convert_element_type(t810, dtypes.float32)  # t818: "cuda:0 f32[1, 32, 512, 128]"
    # t819 = prims.mul(t818, t154)  # t819: "cuda:0 f32[1, 32, 512, 128]"
    # t821 = prims.convert_element_type(t816, dtypes.float32)  # t821: "cuda:0 f32[1, 32, 512, 128]"
    # t822 = prims.mul(t821, t157)  # t822: "cuda:0 f32[1, 32, 512, 128]"
    # t823 = prims.add(t819, t822)  # t823: "cuda:0 f32[1, 32, 512, 128]"
    # t824 = prims.convert_element_type(t823, dtypes.bfloat16)  # t824: "cuda:0 bf16[1, 32, 512, 128]"
    # t804 = prims.mul(t803, t154)  # t804: "cuda:0 f32[1, 32, 512, 128]"
    # t806 = prims.convert_element_type(t801, dtypes.float32)  # t806: "cuda:0 f32[1, 32, 512, 128]"
    # t807 = prims.mul(t806, t157)  # t807: "cuda:0 f32[1, 32, 512, 128]"
    # t808 = prims.add(t804, t807)  # t808: "cuda:0 f32[1, 32, 512, 128]"
    # t809 = prims.convert_element_type(t808, dtypes.bfloat16)  # t809: "cuda:0 bf16[1, 32, 512, 128]"
  del t795, t801, t810, t816
  t828 = torch.cat((t824, t827), -1)  # t828: "cuda:0 bf16[1, 32, 512, 128]"
    # t828 = ltorch.cat((t824, t827), -1)  # t828: "cuda:0 bf16[1, 32, 512, 128]"
      # t828 = prims.cat((t824, t827), -1)  # t828: "cuda:0 bf16[1, 32, 512, 128]"
  del t824, t827
  t826 = torch.cat((t809, t825), -1)  # t826: "cuda:0 bf16[1, 32, 512, 128]"
    # t826 = ltorch.cat((t809, t825), -1)  # t826: "cuda:0 bf16[1, 32, 512, 128]"
      # t826 = prims.cat((t809, t825), -1)  # t826: "cuda:0 bf16[1, 32, 512, 128]"
  del t809, t825
  (t829, t830, t831, t832, _, _, t833, t834, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t826, t828, t794, 0.0, True, scale=0.08838834764831843)
  t836 = torch.permute(t829, (0, 2, 1, 3))  # t836: "cuda:0 bf16[1, 512, 32, 128]"
    # t836 = ltorch.permute(t829, (0, 2, 1, 3))  # t836: "cuda:0 bf16[1, 512, 32, 128]"
      # t836 = prims.transpose(t829, (0, 2, 1, 3))  # t836: "cuda:0 bf16[1, 512, 32, 128]"
  t837 = torch.reshape(t836, (1, 512, 4096))  # t837: "cuda:0 bf16[1, 512, 4096]"
    # t837 = ltorch.reshape(t836, (1, 512, 4096))  # t837: "cuda:0 bf16[1, 512, 4096]"
      # t837 = prims.reshape(t836, (1, 512, 4096))  # t837: "cuda:0 bf16[1, 512, 4096]"
  del t836
  t838 = torch.nn.functional.linear(t837, t97, None)  # t838: "cuda:0 bf16[1, 512, 4096]"
    # t838 = ltorch.linear(t837, t97, None)  # t838: "cuda:0 bf16[1, 512, 4096]"
      # t838 = prims.linear(t837, t97, None)  # t838: "cuda:0 bf16[1, 512, 4096]"
  [t842, t849, t857] = nvFusion33(t770, t838, t853)
    # t840 = prims.convert_element_type(t770, dtypes.float32)  # t840: "cuda:0 f32[1, 512, 4096]"
    # t839 = prims.convert_element_type(t838, dtypes.float32)  # t839: "cuda:0 f32[1, 512, 4096]"
    # t841 = prims.add(t839, t840)  # t841: "cuda:0 f32[1, 512, 4096]"
    # t842 = prims.convert_element_type(t841, dtypes.bfloat16)  # t842: "cuda:0 bf16[1, 512, 4096]"
    # t844 = prims.mul(t841, t841)  # t844: "cuda:0 f32[1, 512, 4096]"
    # t845 = prims.sum(t844, (2,))  # t845: "cuda:0 f32[1, 512]"
    # t846 = prims.broadcast_in_dim(t845, [1, 512, 1], [0, 1])  # t846: "cuda:0 f32[1, 512, 1]"
    # t847 = prims.div(t846, 4096.0)  # t847: "cuda:0 f32[1, 512, 1]"
    # t848 = prims.add(t847, 1e-05)  # t848: "cuda:0 f32[1, 512, 1]"
    # t849 = prims.rsqrt(t848)  # t849: "cuda:0 f32[1, 512, 1]"
    # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2))  # t850: "cuda:0 f32[1, 512, 4096]"
    # t851 = prims.mul(t841, t850)  # t851: "cuda:0 f32[1, 512, 4096]"
    # t855 = prims.convert_element_type(t853, dtypes.float32)  # t855: "cuda:0 f32[1, 512, 4096]"
    # t856 = prims.mul(t851, t855)  # t856: "cuda:0 f32[1, 512, 4096]"
    # t857 = prims.convert_element_type(t856, dtypes.bfloat16)  # t857: "cuda:0 bf16[1, 512, 4096]"
  t858 = torch.nn.functional.linear(t857, t25, None)  # t858: "cuda:0 bf16[1, 512, 11008]"
    # t858 = ltorch.linear(t857, t25, None)  # t858: "cuda:0 bf16[1, 512, 11008]"
      # t858 = prims.linear(t857, t25, None)  # t858: "cuda:0 bf16[1, 512, 11008]"
  t859 = torch.nn.functional.linear(t857, t41, None)  # t859: "cuda:0 bf16[1, 512, 11008]"
    # t859 = ltorch.linear(t857, t41, None)  # t859: "cuda:0 bf16[1, 512, 11008]"
      # t859 = prims.linear(t857, t41, None)  # t859: "cuda:0 bf16[1, 512, 11008]"
  [t873] = nvFusion34(t858, t859)
    # t860 = prims.convert_element_type(t858, dtypes.float32)  # t860: "cuda:0 f32[1, 512, 11008]"
    # t861 = prims.neg(t860)  # t861: "cuda:0 f32[1, 512, 11008]"
    # t862 = prims.exp(t861)  # t862: "cuda:0 f32[1, 512, 11008]"
    # t863 = prims.add(1.0, t862)  # t863: "cuda:0 f32[1, 512, 11008]"
    # t864 = prims.reciprocal(t863)  # t864: "cuda:0 f32[1, 512, 11008]"
    # t868 = prims.mul(t860, t864)  # t868: "cuda:0 f32[1, 512, 11008]"
    # t871 = prims.convert_element_type(t859, dtypes.float32)  # t871: "cuda:0 f32[1, 512, 11008]"
    # t872 = prims.mul(t868, t871)  # t872: "cuda:0 f32[1, 512, 11008]"
    # t873 = prims.convert_element_type(t872, dtypes.bfloat16)  # t873: "cuda:0 bf16[1, 512, 11008]"
  t874 = torch.nn.functional.linear(t873, t98, None)  # t874: "cuda:0 bf16[1, 512, 4096]"
    # t874 = ltorch.linear(t873, t98, None)  # t874: "cuda:0 bf16[1, 512, 4096]"
      # t874 = prims.linear(t873, t98, None)  # t874: "cuda:0 bf16[1, 512, 4096]"
  [t878, t885, t893] = nvFusion35(t842, t874, t889)
    # t876 = prims.convert_element_type(t842, dtypes.float32)  # t876: "cuda:0 f32[1, 512, 4096]"
    # t875 = prims.convert_element_type(t874, dtypes.float32)  # t875: "cuda:0 f32[1, 512, 4096]"
    # t877 = prims.add(t875, t876)  # t877: "cuda:0 f32[1, 512, 4096]"
    # t878 = prims.convert_element_type(t877, dtypes.bfloat16)  # t878: "cuda:0 bf16[1, 512, 4096]"
    # t880 = prims.mul(t877, t877)  # t880: "cuda:0 f32[1, 512, 4096]"
    # t881 = prims.sum(t880, (2,))  # t881: "cuda:0 f32[1, 512]"
    # t882 = prims.broadcast_in_dim(t881, [1, 512, 1], [0, 1])  # t882: "cuda:0 f32[1, 512, 1]"
    # t883 = prims.div(t882, 4096.0)  # t883: "cuda:0 f32[1, 512, 1]"
    # t884 = prims.add(t883, 1e-05)  # t884: "cuda:0 f32[1, 512, 1]"
    # t885 = prims.rsqrt(t884)  # t885: "cuda:0 f32[1, 512, 1]"
    # t886 = prims.broadcast_in_dim(t885, (1, 512, 4096), (0, 1, 2))  # t886: "cuda:0 f32[1, 512, 4096]"
    # t887 = prims.mul(t877, t886)  # t887: "cuda:0 f32[1, 512, 4096]"
    # t891 = prims.convert_element_type(t889, dtypes.float32)  # t891: "cuda:0 f32[1, 512, 4096]"
    # t892 = prims.mul(t887, t891)  # t892: "cuda:0 f32[1, 512, 4096]"
    # t893 = prims.convert_element_type(t892, dtypes.bfloat16)  # t893: "cuda:0 bf16[1, 512, 4096]"
  t894 = torch.nn.functional.linear(t893, t10, None)  # t894: "cuda:0 bf16[1, 512, 12288]"
    # t894 = ltorch.linear(t893, t10, None)  # t894: "cuda:0 bf16[1, 512, 12288]"
      # t894 = prims.linear(t893, t10, None)  # t894: "cuda:0 bf16[1, 512, 12288]"
  t895 = torch.reshape(t894, (1, 512, 32, 3, 128))  # t895: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t895 = ltorch.reshape(t894, (1, 512, 32, 3, 128))  # t895: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t895 = prims.reshape(t894, (1, 512, 32, 3, 128))  # t895: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t894
  t896 = torch.permute(t895, (0, 2, 3, 1, 4))  # t896: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t896 = ltorch.permute(t895, (0, 2, 3, 1, 4))  # t896: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t896 = prims.transpose(t895, (0, 2, 3, 1, 4))  # t896: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t895
  (t897, t898, t899) = torch.split(t896, (1, 1, 1), 2)
    # (t897, t898, t899) = ltorch.split(t896, (1, 1, 1), 2)
      # t897 = prims.slice_prim(t896, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t897: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t898 = prims.slice_prim(t896, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t898: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t899 = prims.slice_prim(t896, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t899: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t896
  t900 = torch.reshape(t897, (1, 32, 512, 128))  # t900: "cuda:0 bf16[1, 32, 512, 128]"
    # t900 = ltorch.reshape(t897, (1, 32, 512, 128))  # t900: "cuda:0 bf16[1, 32, 512, 128]"
      # t900 = prims.reshape(t897, (1, 32, 512, 128))  # t900: "cuda:0 bf16[1, 32, 512, 128]"
  del t897
  t901 = torch.reshape(t898, (1, 32, 512, 128))  # t901: "cuda:0 bf16[1, 32, 512, 128]"
    # t901 = ltorch.reshape(t898, (1, 32, 512, 128))  # t901: "cuda:0 bf16[1, 32, 512, 128]"
      # t901 = prims.reshape(t898, (1, 32, 512, 128))  # t901: "cuda:0 bf16[1, 32, 512, 128]"
  del t898
  t902 = torch.reshape(t899, (1, 32, 512, 128))  # t902: "cuda:0 bf16[1, 32, 512, 128]"
    # t902 = ltorch.reshape(t899, (1, 32, 512, 128))  # t902: "cuda:0 bf16[1, 32, 512, 128]"
      # t902 = prims.reshape(t899, (1, 32, 512, 128))  # t902: "cuda:0 bf16[1, 32, 512, 128]"
  del t899
  t935 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t935: "cuda:0 bf16[1, 32, 512, 0]"
  t903 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t903: "cuda:0 bf16[1, 32, 512, 128]"
  t918 = torch_slice_prim_impl(t901, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t918: "cuda:0 bf16[1, 32, 512, 128]"
  del t901
  t933 = torch_slice_prim_impl(t900, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t933: "cuda:0 bf16[1, 32, 512, 0]"
  del t900
  t904 = torch_slice_prim_impl(t903, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t904: "cuda:0 bf16[1, 32, 512, 64]"
  t905 = torch_slice_prim_impl(t903, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t905: "cuda:0 bf16[1, 32, 512, 64]"
  t919 = torch_slice_prim_impl(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t919: "cuda:0 bf16[1, 32, 512, 64]"
  t920 = torch_slice_prim_impl(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t920: "cuda:0 bf16[1, 32, 512, 64]"
  [t908, t923] = nvFusion36(t903, t905, t918, t920)
    # t906 = prims.convert_element_type(t905, dtypes.float32)  # t906: "cuda:0 f32[1, 32, 512, 64]"
    # t907 = prims.neg(t906)  # t907: "cuda:0 f32[1, 32, 512, 64]"
    # t908 = prims.convert_element_type(t907, dtypes.bfloat16)  # t908: "cuda:0 bf16[1, 32, 512, 64]"
    # t921 = prims.convert_element_type(t920, dtypes.float32)  # t921: "cuda:0 f32[1, 32, 512, 64]"
    # t922 = prims.neg(t921)  # t922: "cuda:0 f32[1, 32, 512, 64]"
    # t923 = prims.convert_element_type(t922, dtypes.bfloat16)  # t923: "cuda:0 bf16[1, 32, 512, 64]"
  del t905, t920
  t924 = torch.cat((t923, t919), -1)  # t924: "cuda:0 bf16[1, 32, 512, 128]"
    # t924 = ltorch.cat((t923, t919), -1)  # t924: "cuda:0 bf16[1, 32, 512, 128]"
      # t924 = prims.cat((t923, t919), -1)  # t924: "cuda:0 bf16[1, 32, 512, 128]"
  del t923, t919
  t909 = torch.cat((t908, t904), -1)  # t909: "cuda:0 bf16[1, 32, 512, 128]"
    # t909 = ltorch.cat((t908, t904), -1)  # t909: "cuda:0 bf16[1, 32, 512, 128]"
      # t909 = prims.cat((t908, t904), -1)  # t909: "cuda:0 bf16[1, 32, 512, 128]"
  del t908, t904
  [t917, t932] = nvFusion37(t154, t157, t903, t909, t918, t924)
    # t911 = prims.convert_element_type(t903, dtypes.float32)  # t911: "cuda:0 f32[1, 32, 512, 128]"
    # t926 = prims.convert_element_type(t918, dtypes.float32)  # t926: "cuda:0 f32[1, 32, 512, 128]"
    # t927 = prims.mul(t926, t154)  # t927: "cuda:0 f32[1, 32, 512, 128]"
    # t929 = prims.convert_element_type(t924, dtypes.float32)  # t929: "cuda:0 f32[1, 32, 512, 128]"
    # t930 = prims.mul(t929, t157)  # t930: "cuda:0 f32[1, 32, 512, 128]"
    # t931 = prims.add(t927, t930)  # t931: "cuda:0 f32[1, 32, 512, 128]"
    # t932 = prims.convert_element_type(t931, dtypes.bfloat16)  # t932: "cuda:0 bf16[1, 32, 512, 128]"
    # t912 = prims.mul(t911, t154)  # t912: "cuda:0 f32[1, 32, 512, 128]"
    # t914 = prims.convert_element_type(t909, dtypes.float32)  # t914: "cuda:0 f32[1, 32, 512, 128]"
    # t915 = prims.mul(t914, t157)  # t915: "cuda:0 f32[1, 32, 512, 128]"
    # t916 = prims.add(t912, t915)  # t916: "cuda:0 f32[1, 32, 512, 128]"
    # t917 = prims.convert_element_type(t916, dtypes.bfloat16)  # t917: "cuda:0 bf16[1, 32, 512, 128]"
  del t903, t909, t918, t924
  t936 = torch.cat((t932, t935), -1)  # t936: "cuda:0 bf16[1, 32, 512, 128]"
    # t936 = ltorch.cat((t932, t935), -1)  # t936: "cuda:0 bf16[1, 32, 512, 128]"
      # t936 = prims.cat((t932, t935), -1)  # t936: "cuda:0 bf16[1, 32, 512, 128]"
  del t932, t935
  t934 = torch.cat((t917, t933), -1)  # t934: "cuda:0 bf16[1, 32, 512, 128]"
    # t934 = ltorch.cat((t917, t933), -1)  # t934: "cuda:0 bf16[1, 32, 512, 128]"
      # t934 = prims.cat((t917, t933), -1)  # t934: "cuda:0 bf16[1, 32, 512, 128]"
  del t917, t933
  (t937, t938, t939, t940, _, _, t941, t942, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t934, t936, t902, 0.0, True, scale=0.08838834764831843)
  t944 = torch.permute(t937, (0, 2, 1, 3))  # t944: "cuda:0 bf16[1, 512, 32, 128]"
    # t944 = ltorch.permute(t937, (0, 2, 1, 3))  # t944: "cuda:0 bf16[1, 512, 32, 128]"
      # t944 = prims.transpose(t937, (0, 2, 1, 3))  # t944: "cuda:0 bf16[1, 512, 32, 128]"
  t945 = torch.reshape(t944, (1, 512, 4096))  # t945: "cuda:0 bf16[1, 512, 4096]"
    # t945 = ltorch.reshape(t944, (1, 512, 4096))  # t945: "cuda:0 bf16[1, 512, 4096]"
      # t945 = prims.reshape(t944, (1, 512, 4096))  # t945: "cuda:0 bf16[1, 512, 4096]"
  del t944
  t946 = torch.nn.functional.linear(t945, t99, None)  # t946: "cuda:0 bf16[1, 512, 4096]"
    # t946 = ltorch.linear(t945, t99, None)  # t946: "cuda:0 bf16[1, 512, 4096]"
      # t946 = prims.linear(t945, t99, None)  # t946: "cuda:0 bf16[1, 512, 4096]"
  [t950, t957, t965] = nvFusion38(t878, t946, t961)
    # t948 = prims.convert_element_type(t878, dtypes.float32)  # t948: "cuda:0 f32[1, 512, 4096]"
    # t947 = prims.convert_element_type(t946, dtypes.float32)  # t947: "cuda:0 f32[1, 512, 4096]"
    # t949 = prims.add(t947, t948)  # t949: "cuda:0 f32[1, 512, 4096]"
    # t950 = prims.convert_element_type(t949, dtypes.bfloat16)  # t950: "cuda:0 bf16[1, 512, 4096]"
    # t952 = prims.mul(t949, t949)  # t952: "cuda:0 f32[1, 512, 4096]"
    # t953 = prims.sum(t952, (2,))  # t953: "cuda:0 f32[1, 512]"
    # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1])  # t954: "cuda:0 f32[1, 512, 1]"
    # t955 = prims.div(t954, 4096.0)  # t955: "cuda:0 f32[1, 512, 1]"
    # t956 = prims.add(t955, 1e-05)  # t956: "cuda:0 f32[1, 512, 1]"
    # t957 = prims.rsqrt(t956)  # t957: "cuda:0 f32[1, 512, 1]"
    # t958 = prims.broadcast_in_dim(t957, (1, 512, 4096), (0, 1, 2))  # t958: "cuda:0 f32[1, 512, 4096]"
    # t959 = prims.mul(t949, t958)  # t959: "cuda:0 f32[1, 512, 4096]"
    # t963 = prims.convert_element_type(t961, dtypes.float32)  # t963: "cuda:0 f32[1, 512, 4096]"
    # t964 = prims.mul(t959, t963)  # t964: "cuda:0 f32[1, 512, 4096]"
    # t965 = prims.convert_element_type(t964, dtypes.bfloat16)  # t965: "cuda:0 bf16[1, 512, 4096]"
  t967 = torch.nn.functional.linear(t965, t42, None)  # t967: "cuda:0 bf16[1, 512, 11008]"
    # t967 = ltorch.linear(t965, t42, None)  # t967: "cuda:0 bf16[1, 512, 11008]"
      # t967 = prims.linear(t965, t42, None)  # t967: "cuda:0 bf16[1, 512, 11008]"
  t966 = torch.nn.functional.linear(t965, t26, None)  # t966: "cuda:0 bf16[1, 512, 11008]"
    # t966 = ltorch.linear(t965, t26, None)  # t966: "cuda:0 bf16[1, 512, 11008]"
      # t966 = prims.linear(t965, t26, None)  # t966: "cuda:0 bf16[1, 512, 11008]"
  [t981] = nvFusion39(t966, t967)
    # t968 = prims.convert_element_type(t966, dtypes.float32)  # t968: "cuda:0 f32[1, 512, 11008]"
    # t969 = prims.neg(t968)  # t969: "cuda:0 f32[1, 512, 11008]"
    # t970 = prims.exp(t969)  # t970: "cuda:0 f32[1, 512, 11008]"
    # t971 = prims.add(1.0, t970)  # t971: "cuda:0 f32[1, 512, 11008]"
    # t972 = prims.reciprocal(t971)  # t972: "cuda:0 f32[1, 512, 11008]"
    # t976 = prims.mul(t968, t972)  # t976: "cuda:0 f32[1, 512, 11008]"
    # t979 = prims.convert_element_type(t967, dtypes.float32)  # t979: "cuda:0 f32[1, 512, 11008]"
    # t980 = prims.mul(t976, t979)  # t980: "cuda:0 f32[1, 512, 11008]"
    # t981 = prims.convert_element_type(t980, dtypes.bfloat16)  # t981: "cuda:0 bf16[1, 512, 11008]"
  t982 = torch.nn.functional.linear(t981, t100, None)  # t982: "cuda:0 bf16[1, 512, 4096]"
    # t982 = ltorch.linear(t981, t100, None)  # t982: "cuda:0 bf16[1, 512, 4096]"
      # t982 = prims.linear(t981, t100, None)  # t982: "cuda:0 bf16[1, 512, 4096]"
  [t1001, t986, t993] = nvFusion40(t950, t982, t997)
    # t984 = prims.convert_element_type(t950, dtypes.float32)  # t984: "cuda:0 f32[1, 512, 4096]"
    # t983 = prims.convert_element_type(t982, dtypes.float32)  # t983: "cuda:0 f32[1, 512, 4096]"
    # t985 = prims.add(t983, t984)  # t985: "cuda:0 f32[1, 512, 4096]"
    # t986 = prims.convert_element_type(t985, dtypes.bfloat16)  # t986: "cuda:0 bf16[1, 512, 4096]"
    # t988 = prims.mul(t985, t985)  # t988: "cuda:0 f32[1, 512, 4096]"
    # t989 = prims.sum(t988, (2,))  # t989: "cuda:0 f32[1, 512]"
    # t990 = prims.broadcast_in_dim(t989, [1, 512, 1], [0, 1])  # t990: "cuda:0 f32[1, 512, 1]"
    # t991 = prims.div(t990, 4096.0)  # t991: "cuda:0 f32[1, 512, 1]"
    # t992 = prims.add(t991, 1e-05)  # t992: "cuda:0 f32[1, 512, 1]"
    # t993 = prims.rsqrt(t992)  # t993: "cuda:0 f32[1, 512, 1]"
    # t994 = prims.broadcast_in_dim(t993, (1, 512, 4096), (0, 1, 2))  # t994: "cuda:0 f32[1, 512, 4096]"
    # t995 = prims.mul(t985, t994)  # t995: "cuda:0 f32[1, 512, 4096]"
    # t999 = prims.convert_element_type(t997, dtypes.float32)  # t999: "cuda:0 f32[1, 512, 4096]"
    # t1000 = prims.mul(t995, t999)  # t1000: "cuda:0 f32[1, 512, 4096]"
    # t1001 = prims.convert_element_type(t1000, dtypes.bfloat16)  # t1001: "cuda:0 bf16[1, 512, 4096]"
  t1002 = torch.nn.functional.linear(t1001, t11, None)  # t1002: "cuda:0 bf16[1, 512, 12288]"
    # t1002 = ltorch.linear(t1001, t11, None)  # t1002: "cuda:0 bf16[1, 512, 12288]"
      # t1002 = prims.linear(t1001, t11, None)  # t1002: "cuda:0 bf16[1, 512, 12288]"
  t1003 = torch.reshape(t1002, (1, 512, 32, 3, 128))  # t1003: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1003 = ltorch.reshape(t1002, (1, 512, 32, 3, 128))  # t1003: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1003 = prims.reshape(t1002, (1, 512, 32, 3, 128))  # t1003: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1002
  t1004 = torch.permute(t1003, (0, 2, 3, 1, 4))  # t1004: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1004 = ltorch.permute(t1003, (0, 2, 3, 1, 4))  # t1004: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1004 = prims.transpose(t1003, (0, 2, 3, 1, 4))  # t1004: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1003
  (t1005, t1006, t1007) = torch.split(t1004, (1, 1, 1), 2)
    # (t1005, t1006, t1007) = ltorch.split(t1004, (1, 1, 1), 2)
      # t1005 = prims.slice_prim(t1004, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1005: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1006 = prims.slice_prim(t1004, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1006: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1007 = prims.slice_prim(t1004, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1007: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1004
  t1008 = torch.reshape(t1005, (1, 32, 512, 128))  # t1008: "cuda:0 bf16[1, 32, 512, 128]"
    # t1008 = ltorch.reshape(t1005, (1, 32, 512, 128))  # t1008: "cuda:0 bf16[1, 32, 512, 128]"
      # t1008 = prims.reshape(t1005, (1, 32, 512, 128))  # t1008: "cuda:0 bf16[1, 32, 512, 128]"
  del t1005
  t1009 = torch.reshape(t1006, (1, 32, 512, 128))  # t1009: "cuda:0 bf16[1, 32, 512, 128]"
    # t1009 = ltorch.reshape(t1006, (1, 32, 512, 128))  # t1009: "cuda:0 bf16[1, 32, 512, 128]"
      # t1009 = prims.reshape(t1006, (1, 32, 512, 128))  # t1009: "cuda:0 bf16[1, 32, 512, 128]"
  del t1006
  t1010 = torch.reshape(t1007, (1, 32, 512, 128))  # t1010: "cuda:0 bf16[1, 32, 512, 128]"
    # t1010 = ltorch.reshape(t1007, (1, 32, 512, 128))  # t1010: "cuda:0 bf16[1, 32, 512, 128]"
      # t1010 = prims.reshape(t1007, (1, 32, 512, 128))  # t1010: "cuda:0 bf16[1, 32, 512, 128]"
  del t1007
  t1026 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1026: "cuda:0 bf16[1, 32, 512, 128]"
  t1041 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1041: "cuda:0 bf16[1, 32, 512, 0]"
  t1043 = torch_slice_prim_impl(t1009, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1043: "cuda:0 bf16[1, 32, 512, 0]"
  del t1009
  t1011 = torch_slice_prim_impl(t1008, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1011: "cuda:0 bf16[1, 32, 512, 128]"
  del t1008
  t1027 = torch_slice_prim_impl(t1026, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1027: "cuda:0 bf16[1, 32, 512, 64]"
  t1028 = torch_slice_prim_impl(t1026, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1028: "cuda:0 bf16[1, 32, 512, 64]"
  t1013 = torch_slice_prim_impl(t1011, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1013: "cuda:0 bf16[1, 32, 512, 64]"
  t1012 = torch_slice_prim_impl(t1011, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1012: "cuda:0 bf16[1, 32, 512, 64]"
  [t1016, t1031] = nvFusion41(t1011, t1013, t1026, t1028)
    # t1014 = prims.convert_element_type(t1013, dtypes.float32)  # t1014: "cuda:0 f32[1, 32, 512, 64]"
    # t1015 = prims.neg(t1014)  # t1015: "cuda:0 f32[1, 32, 512, 64]"
    # t1016 = prims.convert_element_type(t1015, dtypes.bfloat16)  # t1016: "cuda:0 bf16[1, 32, 512, 64]"
    # t1029 = prims.convert_element_type(t1028, dtypes.float32)  # t1029: "cuda:0 f32[1, 32, 512, 64]"
    # t1030 = prims.neg(t1029)  # t1030: "cuda:0 f32[1, 32, 512, 64]"
    # t1031 = prims.convert_element_type(t1030, dtypes.bfloat16)  # t1031: "cuda:0 bf16[1, 32, 512, 64]"
  del t1013, t1028
  t1032 = torch.cat((t1031, t1027), -1)  # t1032: "cuda:0 bf16[1, 32, 512, 128]"
    # t1032 = ltorch.cat((t1031, t1027), -1)  # t1032: "cuda:0 bf16[1, 32, 512, 128]"
      # t1032 = prims.cat((t1031, t1027), -1)  # t1032: "cuda:0 bf16[1, 32, 512, 128]"
  del t1031, t1027
  t1017 = torch.cat((t1016, t1012), -1)  # t1017: "cuda:0 bf16[1, 32, 512, 128]"
    # t1017 = ltorch.cat((t1016, t1012), -1)  # t1017: "cuda:0 bf16[1, 32, 512, 128]"
      # t1017 = prims.cat((t1016, t1012), -1)  # t1017: "cuda:0 bf16[1, 32, 512, 128]"
  del t1016, t1012
  [t1025, t1040] = nvFusion42(t1011, t1017, t1026, t1032, t154, t157)
    # t1019 = prims.convert_element_type(t1011, dtypes.float32)  # t1019: "cuda:0 f32[1, 32, 512, 128]"
    # t1034 = prims.convert_element_type(t1026, dtypes.float32)  # t1034: "cuda:0 f32[1, 32, 512, 128]"
    # t1020 = prims.mul(t1019, t154)  # t1020: "cuda:0 f32[1, 32, 512, 128]"
    # t1022 = prims.convert_element_type(t1017, dtypes.float32)  # t1022: "cuda:0 f32[1, 32, 512, 128]"
    # t1023 = prims.mul(t1022, t157)  # t1023: "cuda:0 f32[1, 32, 512, 128]"
    # t1024 = prims.add(t1020, t1023)  # t1024: "cuda:0 f32[1, 32, 512, 128]"
    # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16)  # t1025: "cuda:0 bf16[1, 32, 512, 128]"
    # t1035 = prims.mul(t1034, t154)  # t1035: "cuda:0 f32[1, 32, 512, 128]"
    # t1037 = prims.convert_element_type(t1032, dtypes.float32)  # t1037: "cuda:0 f32[1, 32, 512, 128]"
    # t1038 = prims.mul(t1037, t157)  # t1038: "cuda:0 f32[1, 32, 512, 128]"
    # t1039 = prims.add(t1035, t1038)  # t1039: "cuda:0 f32[1, 32, 512, 128]"
    # t1040 = prims.convert_element_type(t1039, dtypes.bfloat16)  # t1040: "cuda:0 bf16[1, 32, 512, 128]"
  del t1011, t1017, t1026, t1032
  t1042 = torch.cat((t1025, t1041), -1)  # t1042: "cuda:0 bf16[1, 32, 512, 128]"
    # t1042 = ltorch.cat((t1025, t1041), -1)  # t1042: "cuda:0 bf16[1, 32, 512, 128]"
      # t1042 = prims.cat((t1025, t1041), -1)  # t1042: "cuda:0 bf16[1, 32, 512, 128]"
  del t1025, t1041
  t1044 = torch.cat((t1040, t1043), -1)  # t1044: "cuda:0 bf16[1, 32, 512, 128]"
    # t1044 = ltorch.cat((t1040, t1043), -1)  # t1044: "cuda:0 bf16[1, 32, 512, 128]"
      # t1044 = prims.cat((t1040, t1043), -1)  # t1044: "cuda:0 bf16[1, 32, 512, 128]"
  del t1040, t1043
  (t1045, t1046, t1047, t1048, _, _, t1049, t1050, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1042, t1044, t1010, 0.0, True, scale=0.08838834764831843)
  t1052 = torch.permute(t1045, (0, 2, 1, 3))  # t1052: "cuda:0 bf16[1, 512, 32, 128]"
    # t1052 = ltorch.permute(t1045, (0, 2, 1, 3))  # t1052: "cuda:0 bf16[1, 512, 32, 128]"
      # t1052 = prims.transpose(t1045, (0, 2, 1, 3))  # t1052: "cuda:0 bf16[1, 512, 32, 128]"
  t1053 = torch.reshape(t1052, (1, 512, 4096))  # t1053: "cuda:0 bf16[1, 512, 4096]"
    # t1053 = ltorch.reshape(t1052, (1, 512, 4096))  # t1053: "cuda:0 bf16[1, 512, 4096]"
      # t1053 = prims.reshape(t1052, (1, 512, 4096))  # t1053: "cuda:0 bf16[1, 512, 4096]"
  del t1052
  t1054 = torch.nn.functional.linear(t1053, t101, None)  # t1054: "cuda:0 bf16[1, 512, 4096]"
    # t1054 = ltorch.linear(t1053, t101, None)  # t1054: "cuda:0 bf16[1, 512, 4096]"
      # t1054 = prims.linear(t1053, t101, None)  # t1054: "cuda:0 bf16[1, 512, 4096]"
  [t1058, t1065, t1073] = nvFusion43(t1054, t1069, t986)
    # t1056 = prims.convert_element_type(t986, dtypes.float32)  # t1056: "cuda:0 f32[1, 512, 4096]"
    # t1055 = prims.convert_element_type(t1054, dtypes.float32)  # t1055: "cuda:0 f32[1, 512, 4096]"
    # t1057 = prims.add(t1055, t1056)  # t1057: "cuda:0 f32[1, 512, 4096]"
    # t1058 = prims.convert_element_type(t1057, dtypes.bfloat16)  # t1058: "cuda:0 bf16[1, 512, 4096]"
    # t1060 = prims.mul(t1057, t1057)  # t1060: "cuda:0 f32[1, 512, 4096]"
    # t1061 = prims.sum(t1060, (2,))  # t1061: "cuda:0 f32[1, 512]"
    # t1062 = prims.broadcast_in_dim(t1061, [1, 512, 1], [0, 1])  # t1062: "cuda:0 f32[1, 512, 1]"
    # t1063 = prims.div(t1062, 4096.0)  # t1063: "cuda:0 f32[1, 512, 1]"
    # t1064 = prims.add(t1063, 1e-05)  # t1064: "cuda:0 f32[1, 512, 1]"
    # t1065 = prims.rsqrt(t1064)  # t1065: "cuda:0 f32[1, 512, 1]"
    # t1066 = prims.broadcast_in_dim(t1065, (1, 512, 4096), (0, 1, 2))  # t1066: "cuda:0 f32[1, 512, 4096]"
    # t1067 = prims.mul(t1057, t1066)  # t1067: "cuda:0 f32[1, 512, 4096]"
    # t1071 = prims.convert_element_type(t1069, dtypes.float32)  # t1071: "cuda:0 f32[1, 512, 4096]"
    # t1072 = prims.mul(t1067, t1071)  # t1072: "cuda:0 f32[1, 512, 4096]"
    # t1073 = prims.convert_element_type(t1072, dtypes.bfloat16)  # t1073: "cuda:0 bf16[1, 512, 4096]"
  t1074 = torch.nn.functional.linear(t1073, t27, None)  # t1074: "cuda:0 bf16[1, 512, 11008]"
    # t1074 = ltorch.linear(t1073, t27, None)  # t1074: "cuda:0 bf16[1, 512, 11008]"
      # t1074 = prims.linear(t1073, t27, None)  # t1074: "cuda:0 bf16[1, 512, 11008]"
  t1075 = torch.nn.functional.linear(t1073, t43, None)  # t1075: "cuda:0 bf16[1, 512, 11008]"
    # t1075 = ltorch.linear(t1073, t43, None)  # t1075: "cuda:0 bf16[1, 512, 11008]"
      # t1075 = prims.linear(t1073, t43, None)  # t1075: "cuda:0 bf16[1, 512, 11008]"
  [t1089] = nvFusion44(t1074, t1075)
    # t1076 = prims.convert_element_type(t1074, dtypes.float32)  # t1076: "cuda:0 f32[1, 512, 11008]"
    # t1077 = prims.neg(t1076)  # t1077: "cuda:0 f32[1, 512, 11008]"
    # t1078 = prims.exp(t1077)  # t1078: "cuda:0 f32[1, 512, 11008]"
    # t1079 = prims.add(1.0, t1078)  # t1079: "cuda:0 f32[1, 512, 11008]"
    # t1080 = prims.reciprocal(t1079)  # t1080: "cuda:0 f32[1, 512, 11008]"
    # t1084 = prims.mul(t1076, t1080)  # t1084: "cuda:0 f32[1, 512, 11008]"
    # t1087 = prims.convert_element_type(t1075, dtypes.float32)  # t1087: "cuda:0 f32[1, 512, 11008]"
    # t1088 = prims.mul(t1084, t1087)  # t1088: "cuda:0 f32[1, 512, 11008]"
    # t1089 = prims.convert_element_type(t1088, dtypes.bfloat16)  # t1089: "cuda:0 bf16[1, 512, 11008]"
  t1090 = torch.nn.functional.linear(t1089, t102, None)  # t1090: "cuda:0 bf16[1, 512, 4096]"
    # t1090 = ltorch.linear(t1089, t102, None)  # t1090: "cuda:0 bf16[1, 512, 4096]"
      # t1090 = prims.linear(t1089, t102, None)  # t1090: "cuda:0 bf16[1, 512, 4096]"
  [t1094, t1101, t1109] = nvFusion45(t1058, t1090, t1105)
    # t1092 = prims.convert_element_type(t1058, dtypes.float32)  # t1092: "cuda:0 f32[1, 512, 4096]"
    # t1091 = prims.convert_element_type(t1090, dtypes.float32)  # t1091: "cuda:0 f32[1, 512, 4096]"
    # t1093 = prims.add(t1091, t1092)  # t1093: "cuda:0 f32[1, 512, 4096]"
    # t1094 = prims.convert_element_type(t1093, dtypes.bfloat16)  # t1094: "cuda:0 bf16[1, 512, 4096]"
    # t1096 = prims.mul(t1093, t1093)  # t1096: "cuda:0 f32[1, 512, 4096]"
    # t1097 = prims.sum(t1096, (2,))  # t1097: "cuda:0 f32[1, 512]"
    # t1098 = prims.broadcast_in_dim(t1097, [1, 512, 1], [0, 1])  # t1098: "cuda:0 f32[1, 512, 1]"
    # t1099 = prims.div(t1098, 4096.0)  # t1099: "cuda:0 f32[1, 512, 1]"
    # t1100 = prims.add(t1099, 1e-05)  # t1100: "cuda:0 f32[1, 512, 1]"
    # t1101 = prims.rsqrt(t1100)  # t1101: "cuda:0 f32[1, 512, 1]"
    # t1102 = prims.broadcast_in_dim(t1101, (1, 512, 4096), (0, 1, 2))  # t1102: "cuda:0 f32[1, 512, 4096]"
    # t1103 = prims.mul(t1093, t1102)  # t1103: "cuda:0 f32[1, 512, 4096]"
    # t1107 = prims.convert_element_type(t1105, dtypes.float32)  # t1107: "cuda:0 f32[1, 512, 4096]"
    # t1108 = prims.mul(t1103, t1107)  # t1108: "cuda:0 f32[1, 512, 4096]"
    # t1109 = prims.convert_element_type(t1108, dtypes.bfloat16)  # t1109: "cuda:0 bf16[1, 512, 4096]"
  t1110 = torch.nn.functional.linear(t1109, t12, None)  # t1110: "cuda:0 bf16[1, 512, 12288]"
    # t1110 = ltorch.linear(t1109, t12, None)  # t1110: "cuda:0 bf16[1, 512, 12288]"
      # t1110 = prims.linear(t1109, t12, None)  # t1110: "cuda:0 bf16[1, 512, 12288]"
  t1111 = torch.reshape(t1110, (1, 512, 32, 3, 128))  # t1111: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1111 = ltorch.reshape(t1110, (1, 512, 32, 3, 128))  # t1111: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1111 = prims.reshape(t1110, (1, 512, 32, 3, 128))  # t1111: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1110
  t1112 = torch.permute(t1111, (0, 2, 3, 1, 4))  # t1112: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1112 = ltorch.permute(t1111, (0, 2, 3, 1, 4))  # t1112: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1112 = prims.transpose(t1111, (0, 2, 3, 1, 4))  # t1112: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1111
  (t1113, t1114, t1115) = torch.split(t1112, (1, 1, 1), 2)
    # (t1113, t1114, t1115) = ltorch.split(t1112, (1, 1, 1), 2)
      # t1113 = prims.slice_prim(t1112, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1113: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1114 = prims.slice_prim(t1112, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1114: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1115 = prims.slice_prim(t1112, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1115: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1112
  t1116 = torch.reshape(t1113, (1, 32, 512, 128))  # t1116: "cuda:0 bf16[1, 32, 512, 128]"
    # t1116 = ltorch.reshape(t1113, (1, 32, 512, 128))  # t1116: "cuda:0 bf16[1, 32, 512, 128]"
      # t1116 = prims.reshape(t1113, (1, 32, 512, 128))  # t1116: "cuda:0 bf16[1, 32, 512, 128]"
  del t1113
  t1117 = torch.reshape(t1114, (1, 32, 512, 128))  # t1117: "cuda:0 bf16[1, 32, 512, 128]"
    # t1117 = ltorch.reshape(t1114, (1, 32, 512, 128))  # t1117: "cuda:0 bf16[1, 32, 512, 128]"
      # t1117 = prims.reshape(t1114, (1, 32, 512, 128))  # t1117: "cuda:0 bf16[1, 32, 512, 128]"
  del t1114
  t1118 = torch.reshape(t1115, (1, 32, 512, 128))  # t1118: "cuda:0 bf16[1, 32, 512, 128]"
    # t1118 = ltorch.reshape(t1115, (1, 32, 512, 128))  # t1118: "cuda:0 bf16[1, 32, 512, 128]"
      # t1118 = prims.reshape(t1115, (1, 32, 512, 128))  # t1118: "cuda:0 bf16[1, 32, 512, 128]"
  del t1115
  t1119 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1119: "cuda:0 bf16[1, 32, 512, 128]"
  t1134 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1134: "cuda:0 bf16[1, 32, 512, 128]"
  t1149 = torch_slice_prim_impl(t1116, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1149: "cuda:0 bf16[1, 32, 512, 0]"
  del t1116
  t1151 = torch_slice_prim_impl(t1117, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1151: "cuda:0 bf16[1, 32, 512, 0]"
  del t1117
  t1120 = torch_slice_prim_impl(t1119, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1120: "cuda:0 bf16[1, 32, 512, 64]"
  t1121 = torch_slice_prim_impl(t1119, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1121: "cuda:0 bf16[1, 32, 512, 64]"
  t1136 = torch_slice_prim_impl(t1134, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1136: "cuda:0 bf16[1, 32, 512, 64]"
  t1135 = torch_slice_prim_impl(t1134, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1135: "cuda:0 bf16[1, 32, 512, 64]"
  [t1124, t1139] = nvFusion46(t1119, t1121, t1134, t1136)
    # t1122 = prims.convert_element_type(t1121, dtypes.float32)  # t1122: "cuda:0 f32[1, 32, 512, 64]"
    # t1123 = prims.neg(t1122)  # t1123: "cuda:0 f32[1, 32, 512, 64]"
    # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16)  # t1124: "cuda:0 bf16[1, 32, 512, 64]"
    # t1137 = prims.convert_element_type(t1136, dtypes.float32)  # t1137: "cuda:0 f32[1, 32, 512, 64]"
    # t1138 = prims.neg(t1137)  # t1138: "cuda:0 f32[1, 32, 512, 64]"
    # t1139 = prims.convert_element_type(t1138, dtypes.bfloat16)  # t1139: "cuda:0 bf16[1, 32, 512, 64]"
  del t1121, t1136
  t1125 = torch.cat((t1124, t1120), -1)  # t1125: "cuda:0 bf16[1, 32, 512, 128]"
    # t1125 = ltorch.cat((t1124, t1120), -1)  # t1125: "cuda:0 bf16[1, 32, 512, 128]"
      # t1125 = prims.cat((t1124, t1120), -1)  # t1125: "cuda:0 bf16[1, 32, 512, 128]"
  del t1124, t1120
  t1140 = torch.cat((t1139, t1135), -1)  # t1140: "cuda:0 bf16[1, 32, 512, 128]"
    # t1140 = ltorch.cat((t1139, t1135), -1)  # t1140: "cuda:0 bf16[1, 32, 512, 128]"
      # t1140 = prims.cat((t1139, t1135), -1)  # t1140: "cuda:0 bf16[1, 32, 512, 128]"
  del t1139, t1135
  [t1133, t1148] = nvFusion47(t1119, t1125, t1134, t1140, t154, t157)
    # t1127 = prims.convert_element_type(t1119, dtypes.float32)  # t1127: "cuda:0 f32[1, 32, 512, 128]"
    # t1142 = prims.convert_element_type(t1134, dtypes.float32)  # t1142: "cuda:0 f32[1, 32, 512, 128]"
    # t1128 = prims.mul(t1127, t154)  # t1128: "cuda:0 f32[1, 32, 512, 128]"
    # t1130 = prims.convert_element_type(t1125, dtypes.float32)  # t1130: "cuda:0 f32[1, 32, 512, 128]"
    # t1131 = prims.mul(t1130, t157)  # t1131: "cuda:0 f32[1, 32, 512, 128]"
    # t1132 = prims.add(t1128, t1131)  # t1132: "cuda:0 f32[1, 32, 512, 128]"
    # t1133 = prims.convert_element_type(t1132, dtypes.bfloat16)  # t1133: "cuda:0 bf16[1, 32, 512, 128]"
    # t1143 = prims.mul(t1142, t154)  # t1143: "cuda:0 f32[1, 32, 512, 128]"
    # t1145 = prims.convert_element_type(t1140, dtypes.float32)  # t1145: "cuda:0 f32[1, 32, 512, 128]"
    # t1146 = prims.mul(t1145, t157)  # t1146: "cuda:0 f32[1, 32, 512, 128]"
    # t1147 = prims.add(t1143, t1146)  # t1147: "cuda:0 f32[1, 32, 512, 128]"
    # t1148 = prims.convert_element_type(t1147, dtypes.bfloat16)  # t1148: "cuda:0 bf16[1, 32, 512, 128]"
  del t1119, t1125, t1134, t1140
  t1152 = torch.cat((t1148, t1151), -1)  # t1152: "cuda:0 bf16[1, 32, 512, 128]"
    # t1152 = ltorch.cat((t1148, t1151), -1)  # t1152: "cuda:0 bf16[1, 32, 512, 128]"
      # t1152 = prims.cat((t1148, t1151), -1)  # t1152: "cuda:0 bf16[1, 32, 512, 128]"
  del t1148, t1151
  t1150 = torch.cat((t1133, t1149), -1)  # t1150: "cuda:0 bf16[1, 32, 512, 128]"
    # t1150 = ltorch.cat((t1133, t1149), -1)  # t1150: "cuda:0 bf16[1, 32, 512, 128]"
      # t1150 = prims.cat((t1133, t1149), -1)  # t1150: "cuda:0 bf16[1, 32, 512, 128]"
  del t1133, t1149
  (t1153, t1154, t1155, t1156, _, _, t1157, t1158, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1150, t1152, t1118, 0.0, True, scale=0.08838834764831843)
  t1160 = torch.permute(t1153, (0, 2, 1, 3))  # t1160: "cuda:0 bf16[1, 512, 32, 128]"
    # t1160 = ltorch.permute(t1153, (0, 2, 1, 3))  # t1160: "cuda:0 bf16[1, 512, 32, 128]"
      # t1160 = prims.transpose(t1153, (0, 2, 1, 3))  # t1160: "cuda:0 bf16[1, 512, 32, 128]"
  t1161 = torch.reshape(t1160, (1, 512, 4096))  # t1161: "cuda:0 bf16[1, 512, 4096]"
    # t1161 = ltorch.reshape(t1160, (1, 512, 4096))  # t1161: "cuda:0 bf16[1, 512, 4096]"
      # t1161 = prims.reshape(t1160, (1, 512, 4096))  # t1161: "cuda:0 bf16[1, 512, 4096]"
  del t1160
  t1162 = torch.nn.functional.linear(t1161, t103, None)  # t1162: "cuda:0 bf16[1, 512, 4096]"
    # t1162 = ltorch.linear(t1161, t103, None)  # t1162: "cuda:0 bf16[1, 512, 4096]"
      # t1162 = prims.linear(t1161, t103, None)  # t1162: "cuda:0 bf16[1, 512, 4096]"
  [t1166, t1173, t1181] = nvFusion48(t1094, t1162, t1177)
    # t1164 = prims.convert_element_type(t1094, dtypes.float32)  # t1164: "cuda:0 f32[1, 512, 4096]"
    # t1163 = prims.convert_element_type(t1162, dtypes.float32)  # t1163: "cuda:0 f32[1, 512, 4096]"
    # t1165 = prims.add(t1163, t1164)  # t1165: "cuda:0 f32[1, 512, 4096]"
    # t1166 = prims.convert_element_type(t1165, dtypes.bfloat16)  # t1166: "cuda:0 bf16[1, 512, 4096]"
    # t1168 = prims.mul(t1165, t1165)  # t1168: "cuda:0 f32[1, 512, 4096]"
    # t1169 = prims.sum(t1168, (2,))  # t1169: "cuda:0 f32[1, 512]"
    # t1170 = prims.broadcast_in_dim(t1169, [1, 512, 1], [0, 1])  # t1170: "cuda:0 f32[1, 512, 1]"
    # t1171 = prims.div(t1170, 4096.0)  # t1171: "cuda:0 f32[1, 512, 1]"
    # t1172 = prims.add(t1171, 1e-05)  # t1172: "cuda:0 f32[1, 512, 1]"
    # t1173 = prims.rsqrt(t1172)  # t1173: "cuda:0 f32[1, 512, 1]"
    # t1174 = prims.broadcast_in_dim(t1173, (1, 512, 4096), (0, 1, 2))  # t1174: "cuda:0 f32[1, 512, 4096]"
    # t1175 = prims.mul(t1165, t1174)  # t1175: "cuda:0 f32[1, 512, 4096]"
    # t1179 = prims.convert_element_type(t1177, dtypes.float32)  # t1179: "cuda:0 f32[1, 512, 4096]"
    # t1180 = prims.mul(t1175, t1179)  # t1180: "cuda:0 f32[1, 512, 4096]"
    # t1181 = prims.convert_element_type(t1180, dtypes.bfloat16)  # t1181: "cuda:0 bf16[1, 512, 4096]"
  t1182 = torch.nn.functional.linear(t1181, t28, None)  # t1182: "cuda:0 bf16[1, 512, 11008]"
    # t1182 = ltorch.linear(t1181, t28, None)  # t1182: "cuda:0 bf16[1, 512, 11008]"
      # t1182 = prims.linear(t1181, t28, None)  # t1182: "cuda:0 bf16[1, 512, 11008]"
  t1183 = torch.nn.functional.linear(t1181, t44, None)  # t1183: "cuda:0 bf16[1, 512, 11008]"
    # t1183 = ltorch.linear(t1181, t44, None)  # t1183: "cuda:0 bf16[1, 512, 11008]"
      # t1183 = prims.linear(t1181, t44, None)  # t1183: "cuda:0 bf16[1, 512, 11008]"
  [t1197] = nvFusion49(t1182, t1183)
    # t1184 = prims.convert_element_type(t1182, dtypes.float32)  # t1184: "cuda:0 f32[1, 512, 11008]"
    # t1185 = prims.neg(t1184)  # t1185: "cuda:0 f32[1, 512, 11008]"
    # t1186 = prims.exp(t1185)  # t1186: "cuda:0 f32[1, 512, 11008]"
    # t1187 = prims.add(1.0, t1186)  # t1187: "cuda:0 f32[1, 512, 11008]"
    # t1188 = prims.reciprocal(t1187)  # t1188: "cuda:0 f32[1, 512, 11008]"
    # t1192 = prims.mul(t1184, t1188)  # t1192: "cuda:0 f32[1, 512, 11008]"
    # t1195 = prims.convert_element_type(t1183, dtypes.float32)  # t1195: "cuda:0 f32[1, 512, 11008]"
    # t1196 = prims.mul(t1192, t1195)  # t1196: "cuda:0 f32[1, 512, 11008]"
    # t1197 = prims.convert_element_type(t1196, dtypes.bfloat16)  # t1197: "cuda:0 bf16[1, 512, 11008]"
  t1198 = torch.nn.functional.linear(t1197, t104, None)  # t1198: "cuda:0 bf16[1, 512, 4096]"
    # t1198 = ltorch.linear(t1197, t104, None)  # t1198: "cuda:0 bf16[1, 512, 4096]"
      # t1198 = prims.linear(t1197, t104, None)  # t1198: "cuda:0 bf16[1, 512, 4096]"
  [t1202, t1209, t1217] = nvFusion50(t1166, t1198, t1213)
    # t1200 = prims.convert_element_type(t1166, dtypes.float32)  # t1200: "cuda:0 f32[1, 512, 4096]"
    # t1199 = prims.convert_element_type(t1198, dtypes.float32)  # t1199: "cuda:0 f32[1, 512, 4096]"
    # t1201 = prims.add(t1199, t1200)  # t1201: "cuda:0 f32[1, 512, 4096]"
    # t1202 = prims.convert_element_type(t1201, dtypes.bfloat16)  # t1202: "cuda:0 bf16[1, 512, 4096]"
    # t1204 = prims.mul(t1201, t1201)  # t1204: "cuda:0 f32[1, 512, 4096]"
    # t1205 = prims.sum(t1204, (2,))  # t1205: "cuda:0 f32[1, 512]"
    # t1206 = prims.broadcast_in_dim(t1205, [1, 512, 1], [0, 1])  # t1206: "cuda:0 f32[1, 512, 1]"
    # t1207 = prims.div(t1206, 4096.0)  # t1207: "cuda:0 f32[1, 512, 1]"
    # t1208 = prims.add(t1207, 1e-05)  # t1208: "cuda:0 f32[1, 512, 1]"
    # t1209 = prims.rsqrt(t1208)  # t1209: "cuda:0 f32[1, 512, 1]"
    # t1210 = prims.broadcast_in_dim(t1209, (1, 512, 4096), (0, 1, 2))  # t1210: "cuda:0 f32[1, 512, 4096]"
    # t1211 = prims.mul(t1201, t1210)  # t1211: "cuda:0 f32[1, 512, 4096]"
    # t1215 = prims.convert_element_type(t1213, dtypes.float32)  # t1215: "cuda:0 f32[1, 512, 4096]"
    # t1216 = prims.mul(t1211, t1215)  # t1216: "cuda:0 f32[1, 512, 4096]"
    # t1217 = prims.convert_element_type(t1216, dtypes.bfloat16)  # t1217: "cuda:0 bf16[1, 512, 4096]"
  t1218 = torch.nn.functional.linear(t1217, t13, None)  # t1218: "cuda:0 bf16[1, 512, 12288]"
    # t1218 = ltorch.linear(t1217, t13, None)  # t1218: "cuda:0 bf16[1, 512, 12288]"
      # t1218 = prims.linear(t1217, t13, None)  # t1218: "cuda:0 bf16[1, 512, 12288]"
  t1219 = torch.reshape(t1218, (1, 512, 32, 3, 128))  # t1219: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1219 = ltorch.reshape(t1218, (1, 512, 32, 3, 128))  # t1219: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1219 = prims.reshape(t1218, (1, 512, 32, 3, 128))  # t1219: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1218
  t1220 = torch.permute(t1219, (0, 2, 3, 1, 4))  # t1220: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1220 = ltorch.permute(t1219, (0, 2, 3, 1, 4))  # t1220: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1220 = prims.transpose(t1219, (0, 2, 3, 1, 4))  # t1220: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1219
  (t1221, t1222, t1223) = torch.split(t1220, (1, 1, 1), 2)
    # (t1221, t1222, t1223) = ltorch.split(t1220, (1, 1, 1), 2)
      # t1221 = prims.slice_prim(t1220, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1221: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1222 = prims.slice_prim(t1220, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1222: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1223 = prims.slice_prim(t1220, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1223: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1220
  t1224 = torch.reshape(t1221, (1, 32, 512, 128))  # t1224: "cuda:0 bf16[1, 32, 512, 128]"
    # t1224 = ltorch.reshape(t1221, (1, 32, 512, 128))  # t1224: "cuda:0 bf16[1, 32, 512, 128]"
      # t1224 = prims.reshape(t1221, (1, 32, 512, 128))  # t1224: "cuda:0 bf16[1, 32, 512, 128]"
  del t1221
  t1225 = torch.reshape(t1222, (1, 32, 512, 128))  # t1225: "cuda:0 bf16[1, 32, 512, 128]"
    # t1225 = ltorch.reshape(t1222, (1, 32, 512, 128))  # t1225: "cuda:0 bf16[1, 32, 512, 128]"
      # t1225 = prims.reshape(t1222, (1, 32, 512, 128))  # t1225: "cuda:0 bf16[1, 32, 512, 128]"
  del t1222
  t1226 = torch.reshape(t1223, (1, 32, 512, 128))  # t1226: "cuda:0 bf16[1, 32, 512, 128]"
    # t1226 = ltorch.reshape(t1223, (1, 32, 512, 128))  # t1226: "cuda:0 bf16[1, 32, 512, 128]"
      # t1226 = prims.reshape(t1223, (1, 32, 512, 128))  # t1226: "cuda:0 bf16[1, 32, 512, 128]"
  del t1223
  t1227 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1227: "cuda:0 bf16[1, 32, 512, 128]"
  t1242 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1242: "cuda:0 bf16[1, 32, 512, 128]"
  t1257 = torch_slice_prim_impl(t1224, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1257: "cuda:0 bf16[1, 32, 512, 0]"
  del t1224
  t1259 = torch_slice_prim_impl(t1225, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1259: "cuda:0 bf16[1, 32, 512, 0]"
  del t1225
  t1228 = torch_slice_prim_impl(t1227, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1228: "cuda:0 bf16[1, 32, 512, 64]"
  t1229 = torch_slice_prim_impl(t1227, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1229: "cuda:0 bf16[1, 32, 512, 64]"
  t1243 = torch_slice_prim_impl(t1242, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1243: "cuda:0 bf16[1, 32, 512, 64]"
  t1244 = torch_slice_prim_impl(t1242, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1244: "cuda:0 bf16[1, 32, 512, 64]"
  [t1232, t1247] = nvFusion51(t1227, t1229, t1242, t1244)
    # t1230 = prims.convert_element_type(t1229, dtypes.float32)  # t1230: "cuda:0 f32[1, 32, 512, 64]"
    # t1231 = prims.neg(t1230)  # t1231: "cuda:0 f32[1, 32, 512, 64]"
    # t1232 = prims.convert_element_type(t1231, dtypes.bfloat16)  # t1232: "cuda:0 bf16[1, 32, 512, 64]"
    # t1245 = prims.convert_element_type(t1244, dtypes.float32)  # t1245: "cuda:0 f32[1, 32, 512, 64]"
    # t1246 = prims.neg(t1245)  # t1246: "cuda:0 f32[1, 32, 512, 64]"
    # t1247 = prims.convert_element_type(t1246, dtypes.bfloat16)  # t1247: "cuda:0 bf16[1, 32, 512, 64]"
  del t1229, t1244
  t1233 = torch.cat((t1232, t1228), -1)  # t1233: "cuda:0 bf16[1, 32, 512, 128]"
    # t1233 = ltorch.cat((t1232, t1228), -1)  # t1233: "cuda:0 bf16[1, 32, 512, 128]"
      # t1233 = prims.cat((t1232, t1228), -1)  # t1233: "cuda:0 bf16[1, 32, 512, 128]"
  del t1232, t1228
  t1248 = torch.cat((t1247, t1243), -1)  # t1248: "cuda:0 bf16[1, 32, 512, 128]"
    # t1248 = ltorch.cat((t1247, t1243), -1)  # t1248: "cuda:0 bf16[1, 32, 512, 128]"
      # t1248 = prims.cat((t1247, t1243), -1)  # t1248: "cuda:0 bf16[1, 32, 512, 128]"
  del t1247, t1243
  [t1241, t1256] = nvFusion52(t1227, t1233, t1242, t1248, t154, t157)
    # t1235 = prims.convert_element_type(t1227, dtypes.float32)  # t1235: "cuda:0 f32[1, 32, 512, 128]"
    # t1250 = prims.convert_element_type(t1242, dtypes.float32)  # t1250: "cuda:0 f32[1, 32, 512, 128]"
    # t1236 = prims.mul(t1235, t154)  # t1236: "cuda:0 f32[1, 32, 512, 128]"
    # t1238 = prims.convert_element_type(t1233, dtypes.float32)  # t1238: "cuda:0 f32[1, 32, 512, 128]"
    # t1239 = prims.mul(t1238, t157)  # t1239: "cuda:0 f32[1, 32, 512, 128]"
    # t1240 = prims.add(t1236, t1239)  # t1240: "cuda:0 f32[1, 32, 512, 128]"
    # t1241 = prims.convert_element_type(t1240, dtypes.bfloat16)  # t1241: "cuda:0 bf16[1, 32, 512, 128]"
    # t1251 = prims.mul(t1250, t154)  # t1251: "cuda:0 f32[1, 32, 512, 128]"
    # t1253 = prims.convert_element_type(t1248, dtypes.float32)  # t1253: "cuda:0 f32[1, 32, 512, 128]"
    # t1254 = prims.mul(t1253, t157)  # t1254: "cuda:0 f32[1, 32, 512, 128]"
    # t1255 = prims.add(t1251, t1254)  # t1255: "cuda:0 f32[1, 32, 512, 128]"
    # t1256 = prims.convert_element_type(t1255, dtypes.bfloat16)  # t1256: "cuda:0 bf16[1, 32, 512, 128]"
  del t1227, t1233, t1242, t1248
  t1258 = torch.cat((t1241, t1257), -1)  # t1258: "cuda:0 bf16[1, 32, 512, 128]"
    # t1258 = ltorch.cat((t1241, t1257), -1)  # t1258: "cuda:0 bf16[1, 32, 512, 128]"
      # t1258 = prims.cat((t1241, t1257), -1)  # t1258: "cuda:0 bf16[1, 32, 512, 128]"
  del t1241, t1257
  t1260 = torch.cat((t1256, t1259), -1)  # t1260: "cuda:0 bf16[1, 32, 512, 128]"
    # t1260 = ltorch.cat((t1256, t1259), -1)  # t1260: "cuda:0 bf16[1, 32, 512, 128]"
      # t1260 = prims.cat((t1256, t1259), -1)  # t1260: "cuda:0 bf16[1, 32, 512, 128]"
  del t1256, t1259
  (t1261, t1262, t1263, t1264, _, _, t1265, t1266, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1258, t1260, t1226, 0.0, True, scale=0.08838834764831843)
  t1268 = torch.permute(t1261, (0, 2, 1, 3))  # t1268: "cuda:0 bf16[1, 512, 32, 128]"
    # t1268 = ltorch.permute(t1261, (0, 2, 1, 3))  # t1268: "cuda:0 bf16[1, 512, 32, 128]"
      # t1268 = prims.transpose(t1261, (0, 2, 1, 3))  # t1268: "cuda:0 bf16[1, 512, 32, 128]"
  t1269 = torch.reshape(t1268, (1, 512, 4096))  # t1269: "cuda:0 bf16[1, 512, 4096]"
    # t1269 = ltorch.reshape(t1268, (1, 512, 4096))  # t1269: "cuda:0 bf16[1, 512, 4096]"
      # t1269 = prims.reshape(t1268, (1, 512, 4096))  # t1269: "cuda:0 bf16[1, 512, 4096]"
  del t1268
  t1270 = torch.nn.functional.linear(t1269, t105, None)  # t1270: "cuda:0 bf16[1, 512, 4096]"
    # t1270 = ltorch.linear(t1269, t105, None)  # t1270: "cuda:0 bf16[1, 512, 4096]"
      # t1270 = prims.linear(t1269, t105, None)  # t1270: "cuda:0 bf16[1, 512, 4096]"
  [t1274, t1281, t1289] = nvFusion53(t1202, t1270, t1285)
    # t1272 = prims.convert_element_type(t1202, dtypes.float32)  # t1272: "cuda:0 f32[1, 512, 4096]"
    # t1271 = prims.convert_element_type(t1270, dtypes.float32)  # t1271: "cuda:0 f32[1, 512, 4096]"
    # t1273 = prims.add(t1271, t1272)  # t1273: "cuda:0 f32[1, 512, 4096]"
    # t1274 = prims.convert_element_type(t1273, dtypes.bfloat16)  # t1274: "cuda:0 bf16[1, 512, 4096]"
    # t1276 = prims.mul(t1273, t1273)  # t1276: "cuda:0 f32[1, 512, 4096]"
    # t1277 = prims.sum(t1276, (2,))  # t1277: "cuda:0 f32[1, 512]"
    # t1278 = prims.broadcast_in_dim(t1277, [1, 512, 1], [0, 1])  # t1278: "cuda:0 f32[1, 512, 1]"
    # t1279 = prims.div(t1278, 4096.0)  # t1279: "cuda:0 f32[1, 512, 1]"
    # t1280 = prims.add(t1279, 1e-05)  # t1280: "cuda:0 f32[1, 512, 1]"
    # t1281 = prims.rsqrt(t1280)  # t1281: "cuda:0 f32[1, 512, 1]"
    # t1282 = prims.broadcast_in_dim(t1281, (1, 512, 4096), (0, 1, 2))  # t1282: "cuda:0 f32[1, 512, 4096]"
    # t1283 = prims.mul(t1273, t1282)  # t1283: "cuda:0 f32[1, 512, 4096]"
    # t1287 = prims.convert_element_type(t1285, dtypes.float32)  # t1287: "cuda:0 f32[1, 512, 4096]"
    # t1288 = prims.mul(t1283, t1287)  # t1288: "cuda:0 f32[1, 512, 4096]"
    # t1289 = prims.convert_element_type(t1288, dtypes.bfloat16)  # t1289: "cuda:0 bf16[1, 512, 4096]"
  t1290 = torch.nn.functional.linear(t1289, t29, None)  # t1290: "cuda:0 bf16[1, 512, 11008]"
    # t1290 = ltorch.linear(t1289, t29, None)  # t1290: "cuda:0 bf16[1, 512, 11008]"
      # t1290 = prims.linear(t1289, t29, None)  # t1290: "cuda:0 bf16[1, 512, 11008]"
  t1291 = torch.nn.functional.linear(t1289, t45, None)  # t1291: "cuda:0 bf16[1, 512, 11008]"
    # t1291 = ltorch.linear(t1289, t45, None)  # t1291: "cuda:0 bf16[1, 512, 11008]"
      # t1291 = prims.linear(t1289, t45, None)  # t1291: "cuda:0 bf16[1, 512, 11008]"
  [t1305] = nvFusion54(t1290, t1291)
    # t1292 = prims.convert_element_type(t1290, dtypes.float32)  # t1292: "cuda:0 f32[1, 512, 11008]"
    # t1293 = prims.neg(t1292)  # t1293: "cuda:0 f32[1, 512, 11008]"
    # t1294 = prims.exp(t1293)  # t1294: "cuda:0 f32[1, 512, 11008]"
    # t1295 = prims.add(1.0, t1294)  # t1295: "cuda:0 f32[1, 512, 11008]"
    # t1296 = prims.reciprocal(t1295)  # t1296: "cuda:0 f32[1, 512, 11008]"
    # t1300 = prims.mul(t1292, t1296)  # t1300: "cuda:0 f32[1, 512, 11008]"
    # t1303 = prims.convert_element_type(t1291, dtypes.float32)  # t1303: "cuda:0 f32[1, 512, 11008]"
    # t1304 = prims.mul(t1300, t1303)  # t1304: "cuda:0 f32[1, 512, 11008]"
    # t1305 = prims.convert_element_type(t1304, dtypes.bfloat16)  # t1305: "cuda:0 bf16[1, 512, 11008]"
  t1306 = torch.nn.functional.linear(t1305, t106, None)  # t1306: "cuda:0 bf16[1, 512, 4096]"
    # t1306 = ltorch.linear(t1305, t106, None)  # t1306: "cuda:0 bf16[1, 512, 4096]"
      # t1306 = prims.linear(t1305, t106, None)  # t1306: "cuda:0 bf16[1, 512, 4096]"
  [t1310, t1317, t1325] = nvFusion55(t1274, t1306, t1321)
    # t1308 = prims.convert_element_type(t1274, dtypes.float32)  # t1308: "cuda:0 f32[1, 512, 4096]"
    # t1307 = prims.convert_element_type(t1306, dtypes.float32)  # t1307: "cuda:0 f32[1, 512, 4096]"
    # t1309 = prims.add(t1307, t1308)  # t1309: "cuda:0 f32[1, 512, 4096]"
    # t1310 = prims.convert_element_type(t1309, dtypes.bfloat16)  # t1310: "cuda:0 bf16[1, 512, 4096]"
    # t1312 = prims.mul(t1309, t1309)  # t1312: "cuda:0 f32[1, 512, 4096]"
    # t1313 = prims.sum(t1312, (2,))  # t1313: "cuda:0 f32[1, 512]"
    # t1314 = prims.broadcast_in_dim(t1313, [1, 512, 1], [0, 1])  # t1314: "cuda:0 f32[1, 512, 1]"
    # t1315 = prims.div(t1314, 4096.0)  # t1315: "cuda:0 f32[1, 512, 1]"
    # t1316 = prims.add(t1315, 1e-05)  # t1316: "cuda:0 f32[1, 512, 1]"
    # t1317 = prims.rsqrt(t1316)  # t1317: "cuda:0 f32[1, 512, 1]"
    # t1318 = prims.broadcast_in_dim(t1317, (1, 512, 4096), (0, 1, 2))  # t1318: "cuda:0 f32[1, 512, 4096]"
    # t1319 = prims.mul(t1309, t1318)  # t1319: "cuda:0 f32[1, 512, 4096]"
    # t1323 = prims.convert_element_type(t1321, dtypes.float32)  # t1323: "cuda:0 f32[1, 512, 4096]"
    # t1324 = prims.mul(t1319, t1323)  # t1324: "cuda:0 f32[1, 512, 4096]"
    # t1325 = prims.convert_element_type(t1324, dtypes.bfloat16)  # t1325: "cuda:0 bf16[1, 512, 4096]"
  t1326 = torch.nn.functional.linear(t1325, t14, None)  # t1326: "cuda:0 bf16[1, 512, 12288]"
    # t1326 = ltorch.linear(t1325, t14, None)  # t1326: "cuda:0 bf16[1, 512, 12288]"
      # t1326 = prims.linear(t1325, t14, None)  # t1326: "cuda:0 bf16[1, 512, 12288]"
  t1327 = torch.reshape(t1326, (1, 512, 32, 3, 128))  # t1327: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1327 = ltorch.reshape(t1326, (1, 512, 32, 3, 128))  # t1327: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1327 = prims.reshape(t1326, (1, 512, 32, 3, 128))  # t1327: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1326
  t1328 = torch.permute(t1327, (0, 2, 3, 1, 4))  # t1328: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1328 = ltorch.permute(t1327, (0, 2, 3, 1, 4))  # t1328: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1328 = prims.transpose(t1327, (0, 2, 3, 1, 4))  # t1328: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1327
  (t1329, t1330, t1331) = torch.split(t1328, (1, 1, 1), 2)
    # (t1329, t1330, t1331) = ltorch.split(t1328, (1, 1, 1), 2)
      # t1329 = prims.slice_prim(t1328, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1329: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1330 = prims.slice_prim(t1328, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1330: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1331 = prims.slice_prim(t1328, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1331: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1328
  t1332 = torch.reshape(t1329, (1, 32, 512, 128))  # t1332: "cuda:0 bf16[1, 32, 512, 128]"
    # t1332 = ltorch.reshape(t1329, (1, 32, 512, 128))  # t1332: "cuda:0 bf16[1, 32, 512, 128]"
      # t1332 = prims.reshape(t1329, (1, 32, 512, 128))  # t1332: "cuda:0 bf16[1, 32, 512, 128]"
  del t1329
  t1333 = torch.reshape(t1330, (1, 32, 512, 128))  # t1333: "cuda:0 bf16[1, 32, 512, 128]"
    # t1333 = ltorch.reshape(t1330, (1, 32, 512, 128))  # t1333: "cuda:0 bf16[1, 32, 512, 128]"
      # t1333 = prims.reshape(t1330, (1, 32, 512, 128))  # t1333: "cuda:0 bf16[1, 32, 512, 128]"
  del t1330
  t1334 = torch.reshape(t1331, (1, 32, 512, 128))  # t1334: "cuda:0 bf16[1, 32, 512, 128]"
    # t1334 = ltorch.reshape(t1331, (1, 32, 512, 128))  # t1334: "cuda:0 bf16[1, 32, 512, 128]"
      # t1334 = prims.reshape(t1331, (1, 32, 512, 128))  # t1334: "cuda:0 bf16[1, 32, 512, 128]"
  del t1331
  t1335 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1335: "cuda:0 bf16[1, 32, 512, 128]"
  t1350 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1350: "cuda:0 bf16[1, 32, 512, 128]"
  t1365 = torch_slice_prim_impl(t1332, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1365: "cuda:0 bf16[1, 32, 512, 0]"
  del t1332
  t1367 = torch_slice_prim_impl(t1333, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1367: "cuda:0 bf16[1, 32, 512, 0]"
  del t1333
  t1336 = torch_slice_prim_impl(t1335, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1336: "cuda:0 bf16[1, 32, 512, 64]"
  t1337 = torch_slice_prim_impl(t1335, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1337: "cuda:0 bf16[1, 32, 512, 64]"
  t1351 = torch_slice_prim_impl(t1350, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1351: "cuda:0 bf16[1, 32, 512, 64]"
  t1352 = torch_slice_prim_impl(t1350, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1352: "cuda:0 bf16[1, 32, 512, 64]"
  [t1340, t1355] = nvFusion56(t1335, t1337, t1350, t1352)
    # t1338 = prims.convert_element_type(t1337, dtypes.float32)  # t1338: "cuda:0 f32[1, 32, 512, 64]"
    # t1339 = prims.neg(t1338)  # t1339: "cuda:0 f32[1, 32, 512, 64]"
    # t1340 = prims.convert_element_type(t1339, dtypes.bfloat16)  # t1340: "cuda:0 bf16[1, 32, 512, 64]"
    # t1353 = prims.convert_element_type(t1352, dtypes.float32)  # t1353: "cuda:0 f32[1, 32, 512, 64]"
    # t1354 = prims.neg(t1353)  # t1354: "cuda:0 f32[1, 32, 512, 64]"
    # t1355 = prims.convert_element_type(t1354, dtypes.bfloat16)  # t1355: "cuda:0 bf16[1, 32, 512, 64]"
  del t1337, t1352
  t1341 = torch.cat((t1340, t1336), -1)  # t1341: "cuda:0 bf16[1, 32, 512, 128]"
    # t1341 = ltorch.cat((t1340, t1336), -1)  # t1341: "cuda:0 bf16[1, 32, 512, 128]"
      # t1341 = prims.cat((t1340, t1336), -1)  # t1341: "cuda:0 bf16[1, 32, 512, 128]"
  del t1340, t1336
  t1356 = torch.cat((t1355, t1351), -1)  # t1356: "cuda:0 bf16[1, 32, 512, 128]"
    # t1356 = ltorch.cat((t1355, t1351), -1)  # t1356: "cuda:0 bf16[1, 32, 512, 128]"
      # t1356 = prims.cat((t1355, t1351), -1)  # t1356: "cuda:0 bf16[1, 32, 512, 128]"
  del t1355, t1351
  [t1349, t1364] = nvFusion57(t1335, t1341, t1350, t1356, t154, t157)
    # t1343 = prims.convert_element_type(t1335, dtypes.float32)  # t1343: "cuda:0 f32[1, 32, 512, 128]"
    # t1358 = prims.convert_element_type(t1350, dtypes.float32)  # t1358: "cuda:0 f32[1, 32, 512, 128]"
    # t1344 = prims.mul(t1343, t154)  # t1344: "cuda:0 f32[1, 32, 512, 128]"
    # t1346 = prims.convert_element_type(t1341, dtypes.float32)  # t1346: "cuda:0 f32[1, 32, 512, 128]"
    # t1347 = prims.mul(t1346, t157)  # t1347: "cuda:0 f32[1, 32, 512, 128]"
    # t1348 = prims.add(t1344, t1347)  # t1348: "cuda:0 f32[1, 32, 512, 128]"
    # t1349 = prims.convert_element_type(t1348, dtypes.bfloat16)  # t1349: "cuda:0 bf16[1, 32, 512, 128]"
    # t1359 = prims.mul(t1358, t154)  # t1359: "cuda:0 f32[1, 32, 512, 128]"
    # t1361 = prims.convert_element_type(t1356, dtypes.float32)  # t1361: "cuda:0 f32[1, 32, 512, 128]"
    # t1362 = prims.mul(t1361, t157)  # t1362: "cuda:0 f32[1, 32, 512, 128]"
    # t1363 = prims.add(t1359, t1362)  # t1363: "cuda:0 f32[1, 32, 512, 128]"
    # t1364 = prims.convert_element_type(t1363, dtypes.bfloat16)  # t1364: "cuda:0 bf16[1, 32, 512, 128]"
  del t1335, t1341, t1350, t1356
  t1366 = torch.cat((t1349, t1365), -1)  # t1366: "cuda:0 bf16[1, 32, 512, 128]"
    # t1366 = ltorch.cat((t1349, t1365), -1)  # t1366: "cuda:0 bf16[1, 32, 512, 128]"
      # t1366 = prims.cat((t1349, t1365), -1)  # t1366: "cuda:0 bf16[1, 32, 512, 128]"
  del t1349, t1365
  t1368 = torch.cat((t1364, t1367), -1)  # t1368: "cuda:0 bf16[1, 32, 512, 128]"
    # t1368 = ltorch.cat((t1364, t1367), -1)  # t1368: "cuda:0 bf16[1, 32, 512, 128]"
      # t1368 = prims.cat((t1364, t1367), -1)  # t1368: "cuda:0 bf16[1, 32, 512, 128]"
  del t1364, t1367
  (t1369, t1370, t1371, t1372, _, _, t1373, t1374, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1366, t1368, t1334, 0.0, True, scale=0.08838834764831843)
  t1376 = torch.permute(t1369, (0, 2, 1, 3))  # t1376: "cuda:0 bf16[1, 512, 32, 128]"
    # t1376 = ltorch.permute(t1369, (0, 2, 1, 3))  # t1376: "cuda:0 bf16[1, 512, 32, 128]"
      # t1376 = prims.transpose(t1369, (0, 2, 1, 3))  # t1376: "cuda:0 bf16[1, 512, 32, 128]"
  t1377 = torch.reshape(t1376, (1, 512, 4096))  # t1377: "cuda:0 bf16[1, 512, 4096]"
    # t1377 = ltorch.reshape(t1376, (1, 512, 4096))  # t1377: "cuda:0 bf16[1, 512, 4096]"
      # t1377 = prims.reshape(t1376, (1, 512, 4096))  # t1377: "cuda:0 bf16[1, 512, 4096]"
  del t1376
  t1378 = torch.nn.functional.linear(t1377, t107, None)  # t1378: "cuda:0 bf16[1, 512, 4096]"
    # t1378 = ltorch.linear(t1377, t107, None)  # t1378: "cuda:0 bf16[1, 512, 4096]"
      # t1378 = prims.linear(t1377, t107, None)  # t1378: "cuda:0 bf16[1, 512, 4096]"
  [t1382, t1389, t1397] = nvFusion58(t1310, t1378, t1393)
    # t1380 = prims.convert_element_type(t1310, dtypes.float32)  # t1380: "cuda:0 f32[1, 512, 4096]"
    # t1379 = prims.convert_element_type(t1378, dtypes.float32)  # t1379: "cuda:0 f32[1, 512, 4096]"
    # t1381 = prims.add(t1379, t1380)  # t1381: "cuda:0 f32[1, 512, 4096]"
    # t1382 = prims.convert_element_type(t1381, dtypes.bfloat16)  # t1382: "cuda:0 bf16[1, 512, 4096]"
    # t1384 = prims.mul(t1381, t1381)  # t1384: "cuda:0 f32[1, 512, 4096]"
    # t1385 = prims.sum(t1384, (2,))  # t1385: "cuda:0 f32[1, 512]"
    # t1386 = prims.broadcast_in_dim(t1385, [1, 512, 1], [0, 1])  # t1386: "cuda:0 f32[1, 512, 1]"
    # t1387 = prims.div(t1386, 4096.0)  # t1387: "cuda:0 f32[1, 512, 1]"
    # t1388 = prims.add(t1387, 1e-05)  # t1388: "cuda:0 f32[1, 512, 1]"
    # t1389 = prims.rsqrt(t1388)  # t1389: "cuda:0 f32[1, 512, 1]"
    # t1390 = prims.broadcast_in_dim(t1389, (1, 512, 4096), (0, 1, 2))  # t1390: "cuda:0 f32[1, 512, 4096]"
    # t1391 = prims.mul(t1381, t1390)  # t1391: "cuda:0 f32[1, 512, 4096]"
    # t1395 = prims.convert_element_type(t1393, dtypes.float32)  # t1395: "cuda:0 f32[1, 512, 4096]"
    # t1396 = prims.mul(t1391, t1395)  # t1396: "cuda:0 f32[1, 512, 4096]"
    # t1397 = prims.convert_element_type(t1396, dtypes.bfloat16)  # t1397: "cuda:0 bf16[1, 512, 4096]"
  t1398 = torch.nn.functional.linear(t1397, t30, None)  # t1398: "cuda:0 bf16[1, 512, 11008]"
    # t1398 = ltorch.linear(t1397, t30, None)  # t1398: "cuda:0 bf16[1, 512, 11008]"
      # t1398 = prims.linear(t1397, t30, None)  # t1398: "cuda:0 bf16[1, 512, 11008]"
  t1399 = torch.nn.functional.linear(t1397, t46, None)  # t1399: "cuda:0 bf16[1, 512, 11008]"
    # t1399 = ltorch.linear(t1397, t46, None)  # t1399: "cuda:0 bf16[1, 512, 11008]"
      # t1399 = prims.linear(t1397, t46, None)  # t1399: "cuda:0 bf16[1, 512, 11008]"
  [t1413] = nvFusion59(t1398, t1399)
    # t1400 = prims.convert_element_type(t1398, dtypes.float32)  # t1400: "cuda:0 f32[1, 512, 11008]"
    # t1401 = prims.neg(t1400)  # t1401: "cuda:0 f32[1, 512, 11008]"
    # t1402 = prims.exp(t1401)  # t1402: "cuda:0 f32[1, 512, 11008]"
    # t1403 = prims.add(1.0, t1402)  # t1403: "cuda:0 f32[1, 512, 11008]"
    # t1404 = prims.reciprocal(t1403)  # t1404: "cuda:0 f32[1, 512, 11008]"
    # t1408 = prims.mul(t1400, t1404)  # t1408: "cuda:0 f32[1, 512, 11008]"
    # t1411 = prims.convert_element_type(t1399, dtypes.float32)  # t1411: "cuda:0 f32[1, 512, 11008]"
    # t1412 = prims.mul(t1408, t1411)  # t1412: "cuda:0 f32[1, 512, 11008]"
    # t1413 = prims.convert_element_type(t1412, dtypes.bfloat16)  # t1413: "cuda:0 bf16[1, 512, 11008]"
  t1414 = torch.nn.functional.linear(t1413, t108, None)  # t1414: "cuda:0 bf16[1, 512, 4096]"
    # t1414 = ltorch.linear(t1413, t108, None)  # t1414: "cuda:0 bf16[1, 512, 4096]"
      # t1414 = prims.linear(t1413, t108, None)  # t1414: "cuda:0 bf16[1, 512, 4096]"
  [t1418, t1425, t1433] = nvFusion60(t1382, t1414, t1429)
    # t1416 = prims.convert_element_type(t1382, dtypes.float32)  # t1416: "cuda:0 f32[1, 512, 4096]"
    # t1415 = prims.convert_element_type(t1414, dtypes.float32)  # t1415: "cuda:0 f32[1, 512, 4096]"
    # t1417 = prims.add(t1415, t1416)  # t1417: "cuda:0 f32[1, 512, 4096]"
    # t1418 = prims.convert_element_type(t1417, dtypes.bfloat16)  # t1418: "cuda:0 bf16[1, 512, 4096]"
    # t1420 = prims.mul(t1417, t1417)  # t1420: "cuda:0 f32[1, 512, 4096]"
    # t1421 = prims.sum(t1420, (2,))  # t1421: "cuda:0 f32[1, 512]"
    # t1422 = prims.broadcast_in_dim(t1421, [1, 512, 1], [0, 1])  # t1422: "cuda:0 f32[1, 512, 1]"
    # t1423 = prims.div(t1422, 4096.0)  # t1423: "cuda:0 f32[1, 512, 1]"
    # t1424 = prims.add(t1423, 1e-05)  # t1424: "cuda:0 f32[1, 512, 1]"
    # t1425 = prims.rsqrt(t1424)  # t1425: "cuda:0 f32[1, 512, 1]"
    # t1426 = prims.broadcast_in_dim(t1425, (1, 512, 4096), (0, 1, 2))  # t1426: "cuda:0 f32[1, 512, 4096]"
    # t1427 = prims.mul(t1417, t1426)  # t1427: "cuda:0 f32[1, 512, 4096]"
    # t1431 = prims.convert_element_type(t1429, dtypes.float32)  # t1431: "cuda:0 f32[1, 512, 4096]"
    # t1432 = prims.mul(t1427, t1431)  # t1432: "cuda:0 f32[1, 512, 4096]"
    # t1433 = prims.convert_element_type(t1432, dtypes.bfloat16)  # t1433: "cuda:0 bf16[1, 512, 4096]"
  t1434 = torch.nn.functional.linear(t1433, t15, None)  # t1434: "cuda:0 bf16[1, 512, 12288]"
    # t1434 = ltorch.linear(t1433, t15, None)  # t1434: "cuda:0 bf16[1, 512, 12288]"
      # t1434 = prims.linear(t1433, t15, None)  # t1434: "cuda:0 bf16[1, 512, 12288]"
  t1435 = torch.reshape(t1434, (1, 512, 32, 3, 128))  # t1435: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1435 = ltorch.reshape(t1434, (1, 512, 32, 3, 128))  # t1435: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1435 = prims.reshape(t1434, (1, 512, 32, 3, 128))  # t1435: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1434
  t1436 = torch.permute(t1435, (0, 2, 3, 1, 4))  # t1436: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1436 = ltorch.permute(t1435, (0, 2, 3, 1, 4))  # t1436: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1436 = prims.transpose(t1435, (0, 2, 3, 1, 4))  # t1436: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1435
  (t1437, t1438, t1439) = torch.split(t1436, (1, 1, 1), 2)
    # (t1437, t1438, t1439) = ltorch.split(t1436, (1, 1, 1), 2)
      # t1437 = prims.slice_prim(t1436, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1437: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1438 = prims.slice_prim(t1436, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1438: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1439 = prims.slice_prim(t1436, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1439: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1436
  t1440 = torch.reshape(t1437, (1, 32, 512, 128))  # t1440: "cuda:0 bf16[1, 32, 512, 128]"
    # t1440 = ltorch.reshape(t1437, (1, 32, 512, 128))  # t1440: "cuda:0 bf16[1, 32, 512, 128]"
      # t1440 = prims.reshape(t1437, (1, 32, 512, 128))  # t1440: "cuda:0 bf16[1, 32, 512, 128]"
  del t1437
  t1441 = torch.reshape(t1438, (1, 32, 512, 128))  # t1441: "cuda:0 bf16[1, 32, 512, 128]"
    # t1441 = ltorch.reshape(t1438, (1, 32, 512, 128))  # t1441: "cuda:0 bf16[1, 32, 512, 128]"
      # t1441 = prims.reshape(t1438, (1, 32, 512, 128))  # t1441: "cuda:0 bf16[1, 32, 512, 128]"
  del t1438
  t1442 = torch.reshape(t1439, (1, 32, 512, 128))  # t1442: "cuda:0 bf16[1, 32, 512, 128]"
    # t1442 = ltorch.reshape(t1439, (1, 32, 512, 128))  # t1442: "cuda:0 bf16[1, 32, 512, 128]"
      # t1442 = prims.reshape(t1439, (1, 32, 512, 128))  # t1442: "cuda:0 bf16[1, 32, 512, 128]"
  del t1439
  t1443 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1443: "cuda:0 bf16[1, 32, 512, 128]"
  t1458 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1458: "cuda:0 bf16[1, 32, 512, 128]"
  t1473 = torch_slice_prim_impl(t1440, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1473: "cuda:0 bf16[1, 32, 512, 0]"
  del t1440
  t1475 = torch_slice_prim_impl(t1441, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1475: "cuda:0 bf16[1, 32, 512, 0]"
  del t1441
  t1444 = torch_slice_prim_impl(t1443, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1444: "cuda:0 bf16[1, 32, 512, 64]"
  t1445 = torch_slice_prim_impl(t1443, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1445: "cuda:0 bf16[1, 32, 512, 64]"
  t1459 = torch_slice_prim_impl(t1458, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1459: "cuda:0 bf16[1, 32, 512, 64]"
  t1460 = torch_slice_prim_impl(t1458, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1460: "cuda:0 bf16[1, 32, 512, 64]"
  [t1448, t1463] = nvFusion61(t1443, t1445, t1458, t1460)
    # t1446 = prims.convert_element_type(t1445, dtypes.float32)  # t1446: "cuda:0 f32[1, 32, 512, 64]"
    # t1447 = prims.neg(t1446)  # t1447: "cuda:0 f32[1, 32, 512, 64]"
    # t1448 = prims.convert_element_type(t1447, dtypes.bfloat16)  # t1448: "cuda:0 bf16[1, 32, 512, 64]"
    # t1461 = prims.convert_element_type(t1460, dtypes.float32)  # t1461: "cuda:0 f32[1, 32, 512, 64]"
    # t1462 = prims.neg(t1461)  # t1462: "cuda:0 f32[1, 32, 512, 64]"
    # t1463 = prims.convert_element_type(t1462, dtypes.bfloat16)  # t1463: "cuda:0 bf16[1, 32, 512, 64]"
  del t1445, t1460
  t1464 = torch.cat((t1463, t1459), -1)  # t1464: "cuda:0 bf16[1, 32, 512, 128]"
    # t1464 = ltorch.cat((t1463, t1459), -1)  # t1464: "cuda:0 bf16[1, 32, 512, 128]"
      # t1464 = prims.cat((t1463, t1459), -1)  # t1464: "cuda:0 bf16[1, 32, 512, 128]"
  del t1463, t1459
  t1449 = torch.cat((t1448, t1444), -1)  # t1449: "cuda:0 bf16[1, 32, 512, 128]"
    # t1449 = ltorch.cat((t1448, t1444), -1)  # t1449: "cuda:0 bf16[1, 32, 512, 128]"
      # t1449 = prims.cat((t1448, t1444), -1)  # t1449: "cuda:0 bf16[1, 32, 512, 128]"
  del t1448, t1444
  [t1457, t1472] = nvFusion62(t1443, t1449, t1458, t1464, t154, t157)
    # t1451 = prims.convert_element_type(t1443, dtypes.float32)  # t1451: "cuda:0 f32[1, 32, 512, 128]"
    # t1466 = prims.convert_element_type(t1458, dtypes.float32)  # t1466: "cuda:0 f32[1, 32, 512, 128]"
    # t1467 = prims.mul(t1466, t154)  # t1467: "cuda:0 f32[1, 32, 512, 128]"
    # t1469 = prims.convert_element_type(t1464, dtypes.float32)  # t1469: "cuda:0 f32[1, 32, 512, 128]"
    # t1470 = prims.mul(t1469, t157)  # t1470: "cuda:0 f32[1, 32, 512, 128]"
    # t1471 = prims.add(t1467, t1470)  # t1471: "cuda:0 f32[1, 32, 512, 128]"
    # t1472 = prims.convert_element_type(t1471, dtypes.bfloat16)  # t1472: "cuda:0 bf16[1, 32, 512, 128]"
    # t1452 = prims.mul(t1451, t154)  # t1452: "cuda:0 f32[1, 32, 512, 128]"
    # t1454 = prims.convert_element_type(t1449, dtypes.float32)  # t1454: "cuda:0 f32[1, 32, 512, 128]"
    # t1455 = prims.mul(t1454, t157)  # t1455: "cuda:0 f32[1, 32, 512, 128]"
    # t1456 = prims.add(t1452, t1455)  # t1456: "cuda:0 f32[1, 32, 512, 128]"
    # t1457 = prims.convert_element_type(t1456, dtypes.bfloat16)  # t1457: "cuda:0 bf16[1, 32, 512, 128]"
  del t1443, t1449, t1458, t1464
  t1476 = torch.cat((t1472, t1475), -1)  # t1476: "cuda:0 bf16[1, 32, 512, 128]"
    # t1476 = ltorch.cat((t1472, t1475), -1)  # t1476: "cuda:0 bf16[1, 32, 512, 128]"
      # t1476 = prims.cat((t1472, t1475), -1)  # t1476: "cuda:0 bf16[1, 32, 512, 128]"
  del t1472, t1475
  t1474 = torch.cat((t1457, t1473), -1)  # t1474: "cuda:0 bf16[1, 32, 512, 128]"
    # t1474 = ltorch.cat((t1457, t1473), -1)  # t1474: "cuda:0 bf16[1, 32, 512, 128]"
      # t1474 = prims.cat((t1457, t1473), -1)  # t1474: "cuda:0 bf16[1, 32, 512, 128]"
  del t1457, t1473
  (t1477, t1478, t1479, t1480, _, _, t1481, t1482, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1474, t1476, t1442, 0.0, True, scale=0.08838834764831843)
  t1484 = torch.permute(t1477, (0, 2, 1, 3))  # t1484: "cuda:0 bf16[1, 512, 32, 128]"
    # t1484 = ltorch.permute(t1477, (0, 2, 1, 3))  # t1484: "cuda:0 bf16[1, 512, 32, 128]"
      # t1484 = prims.transpose(t1477, (0, 2, 1, 3))  # t1484: "cuda:0 bf16[1, 512, 32, 128]"
  t1485 = torch.reshape(t1484, (1, 512, 4096))  # t1485: "cuda:0 bf16[1, 512, 4096]"
    # t1485 = ltorch.reshape(t1484, (1, 512, 4096))  # t1485: "cuda:0 bf16[1, 512, 4096]"
      # t1485 = prims.reshape(t1484, (1, 512, 4096))  # t1485: "cuda:0 bf16[1, 512, 4096]"
  del t1484
  t1486 = torch.nn.functional.linear(t1485, t109, None)  # t1486: "cuda:0 bf16[1, 512, 4096]"
    # t1486 = ltorch.linear(t1485, t109, None)  # t1486: "cuda:0 bf16[1, 512, 4096]"
      # t1486 = prims.linear(t1485, t109, None)  # t1486: "cuda:0 bf16[1, 512, 4096]"
  [t1490, t1497, t1505] = nvFusion63(t1418, t1486, t1501)
    # t1488 = prims.convert_element_type(t1418, dtypes.float32)  # t1488: "cuda:0 f32[1, 512, 4096]"
    # t1487 = prims.convert_element_type(t1486, dtypes.float32)  # t1487: "cuda:0 f32[1, 512, 4096]"
    # t1489 = prims.add(t1487, t1488)  # t1489: "cuda:0 f32[1, 512, 4096]"
    # t1490 = prims.convert_element_type(t1489, dtypes.bfloat16)  # t1490: "cuda:0 bf16[1, 512, 4096]"
    # t1492 = prims.mul(t1489, t1489)  # t1492: "cuda:0 f32[1, 512, 4096]"
    # t1493 = prims.sum(t1492, (2,))  # t1493: "cuda:0 f32[1, 512]"
    # t1494 = prims.broadcast_in_dim(t1493, [1, 512, 1], [0, 1])  # t1494: "cuda:0 f32[1, 512, 1]"
    # t1495 = prims.div(t1494, 4096.0)  # t1495: "cuda:0 f32[1, 512, 1]"
    # t1496 = prims.add(t1495, 1e-05)  # t1496: "cuda:0 f32[1, 512, 1]"
    # t1497 = prims.rsqrt(t1496)  # t1497: "cuda:0 f32[1, 512, 1]"
    # t1498 = prims.broadcast_in_dim(t1497, (1, 512, 4096), (0, 1, 2))  # t1498: "cuda:0 f32[1, 512, 4096]"
    # t1499 = prims.mul(t1489, t1498)  # t1499: "cuda:0 f32[1, 512, 4096]"
    # t1503 = prims.convert_element_type(t1501, dtypes.float32)  # t1503: "cuda:0 f32[1, 512, 4096]"
    # t1504 = prims.mul(t1499, t1503)  # t1504: "cuda:0 f32[1, 512, 4096]"
    # t1505 = prims.convert_element_type(t1504, dtypes.bfloat16)  # t1505: "cuda:0 bf16[1, 512, 4096]"
  t1506 = torch.nn.functional.linear(t1505, t31, None)  # t1506: "cuda:0 bf16[1, 512, 11008]"
    # t1506 = ltorch.linear(t1505, t31, None)  # t1506: "cuda:0 bf16[1, 512, 11008]"
      # t1506 = prims.linear(t1505, t31, None)  # t1506: "cuda:0 bf16[1, 512, 11008]"
  t1507 = torch.nn.functional.linear(t1505, t47, None)  # t1507: "cuda:0 bf16[1, 512, 11008]"
    # t1507 = ltorch.linear(t1505, t47, None)  # t1507: "cuda:0 bf16[1, 512, 11008]"
      # t1507 = prims.linear(t1505, t47, None)  # t1507: "cuda:0 bf16[1, 512, 11008]"
  [t1521] = nvFusion64(t1506, t1507)
    # t1508 = prims.convert_element_type(t1506, dtypes.float32)  # t1508: "cuda:0 f32[1, 512, 11008]"
    # t1509 = prims.neg(t1508)  # t1509: "cuda:0 f32[1, 512, 11008]"
    # t1510 = prims.exp(t1509)  # t1510: "cuda:0 f32[1, 512, 11008]"
    # t1511 = prims.add(1.0, t1510)  # t1511: "cuda:0 f32[1, 512, 11008]"
    # t1512 = prims.reciprocal(t1511)  # t1512: "cuda:0 f32[1, 512, 11008]"
    # t1516 = prims.mul(t1508, t1512)  # t1516: "cuda:0 f32[1, 512, 11008]"
    # t1519 = prims.convert_element_type(t1507, dtypes.float32)  # t1519: "cuda:0 f32[1, 512, 11008]"
    # t1520 = prims.mul(t1516, t1519)  # t1520: "cuda:0 f32[1, 512, 11008]"
    # t1521 = prims.convert_element_type(t1520, dtypes.bfloat16)  # t1521: "cuda:0 bf16[1, 512, 11008]"
  t1522 = torch.nn.functional.linear(t1521, t110, None)  # t1522: "cuda:0 bf16[1, 512, 4096]"
    # t1522 = ltorch.linear(t1521, t110, None)  # t1522: "cuda:0 bf16[1, 512, 4096]"
      # t1522 = prims.linear(t1521, t110, None)  # t1522: "cuda:0 bf16[1, 512, 4096]"
  [t1526, t1533, t1541] = nvFusion65(t1490, t1522, t1537)
    # t1524 = prims.convert_element_type(t1490, dtypes.float32)  # t1524: "cuda:0 f32[1, 512, 4096]"
    # t1523 = prims.convert_element_type(t1522, dtypes.float32)  # t1523: "cuda:0 f32[1, 512, 4096]"
    # t1525 = prims.add(t1523, t1524)  # t1525: "cuda:0 f32[1, 512, 4096]"
    # t1526 = prims.convert_element_type(t1525, dtypes.bfloat16)  # t1526: "cuda:0 bf16[1, 512, 4096]"
    # t1528 = prims.mul(t1525, t1525)  # t1528: "cuda:0 f32[1, 512, 4096]"
    # t1529 = prims.sum(t1528, (2,))  # t1529: "cuda:0 f32[1, 512]"
    # t1530 = prims.broadcast_in_dim(t1529, [1, 512, 1], [0, 1])  # t1530: "cuda:0 f32[1, 512, 1]"
    # t1531 = prims.div(t1530, 4096.0)  # t1531: "cuda:0 f32[1, 512, 1]"
    # t1532 = prims.add(t1531, 1e-05)  # t1532: "cuda:0 f32[1, 512, 1]"
    # t1533 = prims.rsqrt(t1532)  # t1533: "cuda:0 f32[1, 512, 1]"
    # t1534 = prims.broadcast_in_dim(t1533, (1, 512, 4096), (0, 1, 2))  # t1534: "cuda:0 f32[1, 512, 4096]"
    # t1535 = prims.mul(t1525, t1534)  # t1535: "cuda:0 f32[1, 512, 4096]"
    # t1539 = prims.convert_element_type(t1537, dtypes.float32)  # t1539: "cuda:0 f32[1, 512, 4096]"
    # t1540 = prims.mul(t1535, t1539)  # t1540: "cuda:0 f32[1, 512, 4096]"
    # t1541 = prims.convert_element_type(t1540, dtypes.bfloat16)  # t1541: "cuda:0 bf16[1, 512, 4096]"
  t1542 = torch.nn.functional.linear(t1541, t16, None)  # t1542: "cuda:0 bf16[1, 512, 12288]"
    # t1542 = ltorch.linear(t1541, t16, None)  # t1542: "cuda:0 bf16[1, 512, 12288]"
      # t1542 = prims.linear(t1541, t16, None)  # t1542: "cuda:0 bf16[1, 512, 12288]"
  t1543 = torch.reshape(t1542, (1, 512, 32, 3, 128))  # t1543: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1543 = ltorch.reshape(t1542, (1, 512, 32, 3, 128))  # t1543: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1543 = prims.reshape(t1542, (1, 512, 32, 3, 128))  # t1543: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1542
  t1544 = torch.permute(t1543, (0, 2, 3, 1, 4))  # t1544: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1544 = ltorch.permute(t1543, (0, 2, 3, 1, 4))  # t1544: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1544 = prims.transpose(t1543, (0, 2, 3, 1, 4))  # t1544: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1543
  (t1545, t1546, t1547) = torch.split(t1544, (1, 1, 1), 2)
    # (t1545, t1546, t1547) = ltorch.split(t1544, (1, 1, 1), 2)
      # t1545 = prims.slice_prim(t1544, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1545: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1546 = prims.slice_prim(t1544, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1546: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1547 = prims.slice_prim(t1544, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1547: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1544
  t1548 = torch.reshape(t1545, (1, 32, 512, 128))  # t1548: "cuda:0 bf16[1, 32, 512, 128]"
    # t1548 = ltorch.reshape(t1545, (1, 32, 512, 128))  # t1548: "cuda:0 bf16[1, 32, 512, 128]"
      # t1548 = prims.reshape(t1545, (1, 32, 512, 128))  # t1548: "cuda:0 bf16[1, 32, 512, 128]"
  del t1545
  t1549 = torch.reshape(t1546, (1, 32, 512, 128))  # t1549: "cuda:0 bf16[1, 32, 512, 128]"
    # t1549 = ltorch.reshape(t1546, (1, 32, 512, 128))  # t1549: "cuda:0 bf16[1, 32, 512, 128]"
      # t1549 = prims.reshape(t1546, (1, 32, 512, 128))  # t1549: "cuda:0 bf16[1, 32, 512, 128]"
  del t1546
  t1550 = torch.reshape(t1547, (1, 32, 512, 128))  # t1550: "cuda:0 bf16[1, 32, 512, 128]"
    # t1550 = ltorch.reshape(t1547, (1, 32, 512, 128))  # t1550: "cuda:0 bf16[1, 32, 512, 128]"
      # t1550 = prims.reshape(t1547, (1, 32, 512, 128))  # t1550: "cuda:0 bf16[1, 32, 512, 128]"
  del t1547
  t1551 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1551: "cuda:0 bf16[1, 32, 512, 128]"
  t1566 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1566: "cuda:0 bf16[1, 32, 512, 128]"
  t1581 = torch_slice_prim_impl(t1548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1581: "cuda:0 bf16[1, 32, 512, 0]"
  del t1548
  t1583 = torch_slice_prim_impl(t1549, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1583: "cuda:0 bf16[1, 32, 512, 0]"
  del t1549
  t1552 = torch_slice_prim_impl(t1551, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1552: "cuda:0 bf16[1, 32, 512, 64]"
  t1553 = torch_slice_prim_impl(t1551, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1553: "cuda:0 bf16[1, 32, 512, 64]"
  t1567 = torch_slice_prim_impl(t1566, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1567: "cuda:0 bf16[1, 32, 512, 64]"
  t1568 = torch_slice_prim_impl(t1566, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1568: "cuda:0 bf16[1, 32, 512, 64]"
  [t1556, t1571] = nvFusion66(t1551, t1553, t1566, t1568)
    # t1554 = prims.convert_element_type(t1553, dtypes.float32)  # t1554: "cuda:0 f32[1, 32, 512, 64]"
    # t1555 = prims.neg(t1554)  # t1555: "cuda:0 f32[1, 32, 512, 64]"
    # t1556 = prims.convert_element_type(t1555, dtypes.bfloat16)  # t1556: "cuda:0 bf16[1, 32, 512, 64]"
    # t1569 = prims.convert_element_type(t1568, dtypes.float32)  # t1569: "cuda:0 f32[1, 32, 512, 64]"
    # t1570 = prims.neg(t1569)  # t1570: "cuda:0 f32[1, 32, 512, 64]"
    # t1571 = prims.convert_element_type(t1570, dtypes.bfloat16)  # t1571: "cuda:0 bf16[1, 32, 512, 64]"
  del t1553, t1568
  t1572 = torch.cat((t1571, t1567), -1)  # t1572: "cuda:0 bf16[1, 32, 512, 128]"
    # t1572 = ltorch.cat((t1571, t1567), -1)  # t1572: "cuda:0 bf16[1, 32, 512, 128]"
      # t1572 = prims.cat((t1571, t1567), -1)  # t1572: "cuda:0 bf16[1, 32, 512, 128]"
  del t1571, t1567
  t1557 = torch.cat((t1556, t1552), -1)  # t1557: "cuda:0 bf16[1, 32, 512, 128]"
    # t1557 = ltorch.cat((t1556, t1552), -1)  # t1557: "cuda:0 bf16[1, 32, 512, 128]"
      # t1557 = prims.cat((t1556, t1552), -1)  # t1557: "cuda:0 bf16[1, 32, 512, 128]"
  del t1556, t1552
  [t1565, t1580] = nvFusion67(t154, t1551, t1557, t1566, t157, t1572)
    # t1559 = prims.convert_element_type(t1551, dtypes.float32)  # t1559: "cuda:0 f32[1, 32, 512, 128]"
    # t1574 = prims.convert_element_type(t1566, dtypes.float32)  # t1574: "cuda:0 f32[1, 32, 512, 128]"
    # t1575 = prims.mul(t1574, t154)  # t1575: "cuda:0 f32[1, 32, 512, 128]"
    # t1577 = prims.convert_element_type(t1572, dtypes.float32)  # t1577: "cuda:0 f32[1, 32, 512, 128]"
    # t1578 = prims.mul(t1577, t157)  # t1578: "cuda:0 f32[1, 32, 512, 128]"
    # t1579 = prims.add(t1575, t1578)  # t1579: "cuda:0 f32[1, 32, 512, 128]"
    # t1580 = prims.convert_element_type(t1579, dtypes.bfloat16)  # t1580: "cuda:0 bf16[1, 32, 512, 128]"
    # t1560 = prims.mul(t1559, t154)  # t1560: "cuda:0 f32[1, 32, 512, 128]"
    # t1562 = prims.convert_element_type(t1557, dtypes.float32)  # t1562: "cuda:0 f32[1, 32, 512, 128]"
    # t1563 = prims.mul(t1562, t157)  # t1563: "cuda:0 f32[1, 32, 512, 128]"
    # t1564 = prims.add(t1560, t1563)  # t1564: "cuda:0 f32[1, 32, 512, 128]"
    # t1565 = prims.convert_element_type(t1564, dtypes.bfloat16)  # t1565: "cuda:0 bf16[1, 32, 512, 128]"
  del t1551, t1557, t1566, t1572
  t1584 = torch.cat((t1580, t1583), -1)  # t1584: "cuda:0 bf16[1, 32, 512, 128]"
    # t1584 = ltorch.cat((t1580, t1583), -1)  # t1584: "cuda:0 bf16[1, 32, 512, 128]"
      # t1584 = prims.cat((t1580, t1583), -1)  # t1584: "cuda:0 bf16[1, 32, 512, 128]"
  del t1580, t1583
  t1582 = torch.cat((t1565, t1581), -1)  # t1582: "cuda:0 bf16[1, 32, 512, 128]"
    # t1582 = ltorch.cat((t1565, t1581), -1)  # t1582: "cuda:0 bf16[1, 32, 512, 128]"
      # t1582 = prims.cat((t1565, t1581), -1)  # t1582: "cuda:0 bf16[1, 32, 512, 128]"
  del t1565, t1581
  (t1585, t1586, t1587, t1588, _, _, t1589, t1590, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1582, t1584, t1550, 0.0, True, scale=0.08838834764831843)
  t1592 = torch.permute(t1585, (0, 2, 1, 3))  # t1592: "cuda:0 bf16[1, 512, 32, 128]"
    # t1592 = ltorch.permute(t1585, (0, 2, 1, 3))  # t1592: "cuda:0 bf16[1, 512, 32, 128]"
      # t1592 = prims.transpose(t1585, (0, 2, 1, 3))  # t1592: "cuda:0 bf16[1, 512, 32, 128]"
  t1593 = torch.reshape(t1592, (1, 512, 4096))  # t1593: "cuda:0 bf16[1, 512, 4096]"
    # t1593 = ltorch.reshape(t1592, (1, 512, 4096))  # t1593: "cuda:0 bf16[1, 512, 4096]"
      # t1593 = prims.reshape(t1592, (1, 512, 4096))  # t1593: "cuda:0 bf16[1, 512, 4096]"
  del t1592
  t1594 = torch.nn.functional.linear(t1593, t111, None)  # t1594: "cuda:0 bf16[1, 512, 4096]"
    # t1594 = ltorch.linear(t1593, t111, None)  # t1594: "cuda:0 bf16[1, 512, 4096]"
      # t1594 = prims.linear(t1593, t111, None)  # t1594: "cuda:0 bf16[1, 512, 4096]"
  [t1598, t1605, t1613] = nvFusion68(t1526, t1594, t1609)
    # t1596 = prims.convert_element_type(t1526, dtypes.float32)  # t1596: "cuda:0 f32[1, 512, 4096]"
    # t1595 = prims.convert_element_type(t1594, dtypes.float32)  # t1595: "cuda:0 f32[1, 512, 4096]"
    # t1597 = prims.add(t1595, t1596)  # t1597: "cuda:0 f32[1, 512, 4096]"
    # t1598 = prims.convert_element_type(t1597, dtypes.bfloat16)  # t1598: "cuda:0 bf16[1, 512, 4096]"
    # t1600 = prims.mul(t1597, t1597)  # t1600: "cuda:0 f32[1, 512, 4096]"
    # t1601 = prims.sum(t1600, (2,))  # t1601: "cuda:0 f32[1, 512]"
    # t1602 = prims.broadcast_in_dim(t1601, [1, 512, 1], [0, 1])  # t1602: "cuda:0 f32[1, 512, 1]"
    # t1603 = prims.div(t1602, 4096.0)  # t1603: "cuda:0 f32[1, 512, 1]"
    # t1604 = prims.add(t1603, 1e-05)  # t1604: "cuda:0 f32[1, 512, 1]"
    # t1605 = prims.rsqrt(t1604)  # t1605: "cuda:0 f32[1, 512, 1]"
    # t1606 = prims.broadcast_in_dim(t1605, (1, 512, 4096), (0, 1, 2))  # t1606: "cuda:0 f32[1, 512, 4096]"
    # t1607 = prims.mul(t1597, t1606)  # t1607: "cuda:0 f32[1, 512, 4096]"
    # t1611 = prims.convert_element_type(t1609, dtypes.float32)  # t1611: "cuda:0 f32[1, 512, 4096]"
    # t1612 = prims.mul(t1607, t1611)  # t1612: "cuda:0 f32[1, 512, 4096]"
    # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16)  # t1613: "cuda:0 bf16[1, 512, 4096]"
  t1614 = torch.nn.functional.linear(t1613, t32, None)  # t1614: "cuda:0 bf16[1, 512, 11008]"
    # t1614 = ltorch.linear(t1613, t32, None)  # t1614: "cuda:0 bf16[1, 512, 11008]"
      # t1614 = prims.linear(t1613, t32, None)  # t1614: "cuda:0 bf16[1, 512, 11008]"
  t1615 = torch.nn.functional.linear(t1613, t48, None)  # t1615: "cuda:0 bf16[1, 512, 11008]"
    # t1615 = ltorch.linear(t1613, t48, None)  # t1615: "cuda:0 bf16[1, 512, 11008]"
      # t1615 = prims.linear(t1613, t48, None)  # t1615: "cuda:0 bf16[1, 512, 11008]"
  [t1629] = nvFusion69(t1614, t1615)
    # t1616 = prims.convert_element_type(t1614, dtypes.float32)  # t1616: "cuda:0 f32[1, 512, 11008]"
    # t1617 = prims.neg(t1616)  # t1617: "cuda:0 f32[1, 512, 11008]"
    # t1618 = prims.exp(t1617)  # t1618: "cuda:0 f32[1, 512, 11008]"
    # t1619 = prims.add(1.0, t1618)  # t1619: "cuda:0 f32[1, 512, 11008]"
    # t1620 = prims.reciprocal(t1619)  # t1620: "cuda:0 f32[1, 512, 11008]"
    # t1624 = prims.mul(t1616, t1620)  # t1624: "cuda:0 f32[1, 512, 11008]"
    # t1627 = prims.convert_element_type(t1615, dtypes.float32)  # t1627: "cuda:0 f32[1, 512, 11008]"
    # t1628 = prims.mul(t1624, t1627)  # t1628: "cuda:0 f32[1, 512, 11008]"
    # t1629 = prims.convert_element_type(t1628, dtypes.bfloat16)  # t1629: "cuda:0 bf16[1, 512, 11008]"
  t1630 = torch.nn.functional.linear(t1629, t112, None)  # t1630: "cuda:0 bf16[1, 512, 4096]"
    # t1630 = ltorch.linear(t1629, t112, None)  # t1630: "cuda:0 bf16[1, 512, 4096]"
      # t1630 = prims.linear(t1629, t112, None)  # t1630: "cuda:0 bf16[1, 512, 4096]"
  [t1634, t1641, t1649] = nvFusion70(t1598, t1630, t1645)
    # t1632 = prims.convert_element_type(t1598, dtypes.float32)  # t1632: "cuda:0 f32[1, 512, 4096]"
    # t1631 = prims.convert_element_type(t1630, dtypes.float32)  # t1631: "cuda:0 f32[1, 512, 4096]"
    # t1633 = prims.add(t1631, t1632)  # t1633: "cuda:0 f32[1, 512, 4096]"
    # t1634 = prims.convert_element_type(t1633, dtypes.bfloat16)  # t1634: "cuda:0 bf16[1, 512, 4096]"
    # t1636 = prims.mul(t1633, t1633)  # t1636: "cuda:0 f32[1, 512, 4096]"
    # t1637 = prims.sum(t1636, (2,))  # t1637: "cuda:0 f32[1, 512]"
    # t1638 = prims.broadcast_in_dim(t1637, [1, 512, 1], [0, 1])  # t1638: "cuda:0 f32[1, 512, 1]"
    # t1639 = prims.div(t1638, 4096.0)  # t1639: "cuda:0 f32[1, 512, 1]"
    # t1640 = prims.add(t1639, 1e-05)  # t1640: "cuda:0 f32[1, 512, 1]"
    # t1641 = prims.rsqrt(t1640)  # t1641: "cuda:0 f32[1, 512, 1]"
    # t1642 = prims.broadcast_in_dim(t1641, (1, 512, 4096), (0, 1, 2))  # t1642: "cuda:0 f32[1, 512, 4096]"
    # t1643 = prims.mul(t1633, t1642)  # t1643: "cuda:0 f32[1, 512, 4096]"
    # t1647 = prims.convert_element_type(t1645, dtypes.float32)  # t1647: "cuda:0 f32[1, 512, 4096]"
    # t1648 = prims.mul(t1643, t1647)  # t1648: "cuda:0 f32[1, 512, 4096]"
    # t1649 = prims.convert_element_type(t1648, dtypes.bfloat16)  # t1649: "cuda:0 bf16[1, 512, 4096]"
  t1650 = torch.nn.functional.linear(t1649, t17, None)  # t1650: "cuda:0 bf16[1, 512, 12288]"
    # t1650 = ltorch.linear(t1649, t17, None)  # t1650: "cuda:0 bf16[1, 512, 12288]"
      # t1650 = prims.linear(t1649, t17, None)  # t1650: "cuda:0 bf16[1, 512, 12288]"
  t1651 = torch.reshape(t1650, (1, 512, 32, 3, 128))  # t1651: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1651 = ltorch.reshape(t1650, (1, 512, 32, 3, 128))  # t1651: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1651 = prims.reshape(t1650, (1, 512, 32, 3, 128))  # t1651: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1650
  t1652 = torch.permute(t1651, (0, 2, 3, 1, 4))  # t1652: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1652 = ltorch.permute(t1651, (0, 2, 3, 1, 4))  # t1652: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1652 = prims.transpose(t1651, (0, 2, 3, 1, 4))  # t1652: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1651
  (t1653, t1654, t1655) = torch.split(t1652, (1, 1, 1), 2)
    # (t1653, t1654, t1655) = ltorch.split(t1652, (1, 1, 1), 2)
      # t1653 = prims.slice_prim(t1652, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1653: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1654 = prims.slice_prim(t1652, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1654: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1655 = prims.slice_prim(t1652, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1655: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1652
  t1656 = torch.reshape(t1653, (1, 32, 512, 128))  # t1656: "cuda:0 bf16[1, 32, 512, 128]"
    # t1656 = ltorch.reshape(t1653, (1, 32, 512, 128))  # t1656: "cuda:0 bf16[1, 32, 512, 128]"
      # t1656 = prims.reshape(t1653, (1, 32, 512, 128))  # t1656: "cuda:0 bf16[1, 32, 512, 128]"
  del t1653
  t1657 = torch.reshape(t1654, (1, 32, 512, 128))  # t1657: "cuda:0 bf16[1, 32, 512, 128]"
    # t1657 = ltorch.reshape(t1654, (1, 32, 512, 128))  # t1657: "cuda:0 bf16[1, 32, 512, 128]"
      # t1657 = prims.reshape(t1654, (1, 32, 512, 128))  # t1657: "cuda:0 bf16[1, 32, 512, 128]"
  del t1654
  t1658 = torch.reshape(t1655, (1, 32, 512, 128))  # t1658: "cuda:0 bf16[1, 32, 512, 128]"
    # t1658 = ltorch.reshape(t1655, (1, 32, 512, 128))  # t1658: "cuda:0 bf16[1, 32, 512, 128]"
      # t1658 = prims.reshape(t1655, (1, 32, 512, 128))  # t1658: "cuda:0 bf16[1, 32, 512, 128]"
  del t1655
  t1689 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1689: "cuda:0 bf16[1, 32, 512, 0]"
  t1691 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1691: "cuda:0 bf16[1, 32, 512, 0]"
  t1659 = torch_slice_prim_impl(t1656, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1659: "cuda:0 bf16[1, 32, 512, 128]"
  del t1656
  t1674 = torch_slice_prim_impl(t1657, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1674: "cuda:0 bf16[1, 32, 512, 128]"
  del t1657
  t1660 = torch_slice_prim_impl(t1659, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1660: "cuda:0 bf16[1, 32, 512, 64]"
  t1661 = torch_slice_prim_impl(t1659, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1661: "cuda:0 bf16[1, 32, 512, 64]"
  t1675 = torch_slice_prim_impl(t1674, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1675: "cuda:0 bf16[1, 32, 512, 64]"
  t1676 = torch_slice_prim_impl(t1674, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1676: "cuda:0 bf16[1, 32, 512, 64]"
  [t1664, t1679] = nvFusion71(t1659, t1661, t1674, t1676)
    # t1662 = prims.convert_element_type(t1661, dtypes.float32)  # t1662: "cuda:0 f32[1, 32, 512, 64]"
    # t1663 = prims.neg(t1662)  # t1663: "cuda:0 f32[1, 32, 512, 64]"
    # t1664 = prims.convert_element_type(t1663, dtypes.bfloat16)  # t1664: "cuda:0 bf16[1, 32, 512, 64]"
    # t1677 = prims.convert_element_type(t1676, dtypes.float32)  # t1677: "cuda:0 f32[1, 32, 512, 64]"
    # t1678 = prims.neg(t1677)  # t1678: "cuda:0 f32[1, 32, 512, 64]"
    # t1679 = prims.convert_element_type(t1678, dtypes.bfloat16)  # t1679: "cuda:0 bf16[1, 32, 512, 64]"
  del t1661, t1676
  t1680 = torch.cat((t1679, t1675), -1)  # t1680: "cuda:0 bf16[1, 32, 512, 128]"
    # t1680 = ltorch.cat((t1679, t1675), -1)  # t1680: "cuda:0 bf16[1, 32, 512, 128]"
      # t1680 = prims.cat((t1679, t1675), -1)  # t1680: "cuda:0 bf16[1, 32, 512, 128]"
  del t1679, t1675
  t1665 = torch.cat((t1664, t1660), -1)  # t1665: "cuda:0 bf16[1, 32, 512, 128]"
    # t1665 = ltorch.cat((t1664, t1660), -1)  # t1665: "cuda:0 bf16[1, 32, 512, 128]"
      # t1665 = prims.cat((t1664, t1660), -1)  # t1665: "cuda:0 bf16[1, 32, 512, 128]"
  del t1664, t1660
  [t1673, t1688] = nvFusion72(t154, t157, t1659, t1665, t1674, t1680)
    # t1667 = prims.convert_element_type(t1659, dtypes.float32)  # t1667: "cuda:0 f32[1, 32, 512, 128]"
    # t1682 = prims.convert_element_type(t1674, dtypes.float32)  # t1682: "cuda:0 f32[1, 32, 512, 128]"
    # t1683 = prims.mul(t1682, t154)  # t1683: "cuda:0 f32[1, 32, 512, 128]"
    # t1685 = prims.convert_element_type(t1680, dtypes.float32)  # t1685: "cuda:0 f32[1, 32, 512, 128]"
    # t1686 = prims.mul(t1685, t157)  # t1686: "cuda:0 f32[1, 32, 512, 128]"
    # t1687 = prims.add(t1683, t1686)  # t1687: "cuda:0 f32[1, 32, 512, 128]"
    # t1688 = prims.convert_element_type(t1687, dtypes.bfloat16)  # t1688: "cuda:0 bf16[1, 32, 512, 128]"
    # t1668 = prims.mul(t1667, t154)  # t1668: "cuda:0 f32[1, 32, 512, 128]"
    # t1670 = prims.convert_element_type(t1665, dtypes.float32)  # t1670: "cuda:0 f32[1, 32, 512, 128]"
    # t1671 = prims.mul(t1670, t157)  # t1671: "cuda:0 f32[1, 32, 512, 128]"
    # t1672 = prims.add(t1668, t1671)  # t1672: "cuda:0 f32[1, 32, 512, 128]"
    # t1673 = prims.convert_element_type(t1672, dtypes.bfloat16)  # t1673: "cuda:0 bf16[1, 32, 512, 128]"
  del t1659, t1665, t1674, t1680
  t1692 = torch.cat((t1688, t1691), -1)  # t1692: "cuda:0 bf16[1, 32, 512, 128]"
    # t1692 = ltorch.cat((t1688, t1691), -1)  # t1692: "cuda:0 bf16[1, 32, 512, 128]"
      # t1692 = prims.cat((t1688, t1691), -1)  # t1692: "cuda:0 bf16[1, 32, 512, 128]"
  del t1688, t1691
  t1690 = torch.cat((t1673, t1689), -1)  # t1690: "cuda:0 bf16[1, 32, 512, 128]"
    # t1690 = ltorch.cat((t1673, t1689), -1)  # t1690: "cuda:0 bf16[1, 32, 512, 128]"
      # t1690 = prims.cat((t1673, t1689), -1)  # t1690: "cuda:0 bf16[1, 32, 512, 128]"
  del t1673, t1689
  (t1693, t1694, t1695, t1696, _, _, t1697, t1698, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1690, t1692, t1658, 0.0, True, scale=0.08838834764831843)
  t1700 = torch.permute(t1693, (0, 2, 1, 3))  # t1700: "cuda:0 bf16[1, 512, 32, 128]"
    # t1700 = ltorch.permute(t1693, (0, 2, 1, 3))  # t1700: "cuda:0 bf16[1, 512, 32, 128]"
      # t1700 = prims.transpose(t1693, (0, 2, 1, 3))  # t1700: "cuda:0 bf16[1, 512, 32, 128]"
  t1701 = torch.reshape(t1700, (1, 512, 4096))  # t1701: "cuda:0 bf16[1, 512, 4096]"
    # t1701 = ltorch.reshape(t1700, (1, 512, 4096))  # t1701: "cuda:0 bf16[1, 512, 4096]"
      # t1701 = prims.reshape(t1700, (1, 512, 4096))  # t1701: "cuda:0 bf16[1, 512, 4096]"
  del t1700
  t1702 = torch.nn.functional.linear(t1701, t113, None)  # t1702: "cuda:0 bf16[1, 512, 4096]"
    # t1702 = ltorch.linear(t1701, t113, None)  # t1702: "cuda:0 bf16[1, 512, 4096]"
      # t1702 = prims.linear(t1701, t113, None)  # t1702: "cuda:0 bf16[1, 512, 4096]"
  [t1706, t1713, t1721] = nvFusion73(t1634, t1702, t1717)
    # t1704 = prims.convert_element_type(t1634, dtypes.float32)  # t1704: "cuda:0 f32[1, 512, 4096]"
    # t1703 = prims.convert_element_type(t1702, dtypes.float32)  # t1703: "cuda:0 f32[1, 512, 4096]"
    # t1705 = prims.add(t1703, t1704)  # t1705: "cuda:0 f32[1, 512, 4096]"
    # t1706 = prims.convert_element_type(t1705, dtypes.bfloat16)  # t1706: "cuda:0 bf16[1, 512, 4096]"
    # t1708 = prims.mul(t1705, t1705)  # t1708: "cuda:0 f32[1, 512, 4096]"
    # t1709 = prims.sum(t1708, (2,))  # t1709: "cuda:0 f32[1, 512]"
    # t1710 = prims.broadcast_in_dim(t1709, [1, 512, 1], [0, 1])  # t1710: "cuda:0 f32[1, 512, 1]"
    # t1711 = prims.div(t1710, 4096.0)  # t1711: "cuda:0 f32[1, 512, 1]"
    # t1712 = prims.add(t1711, 1e-05)  # t1712: "cuda:0 f32[1, 512, 1]"
    # t1713 = prims.rsqrt(t1712)  # t1713: "cuda:0 f32[1, 512, 1]"
    # t1714 = prims.broadcast_in_dim(t1713, (1, 512, 4096), (0, 1, 2))  # t1714: "cuda:0 f32[1, 512, 4096]"
    # t1715 = prims.mul(t1705, t1714)  # t1715: "cuda:0 f32[1, 512, 4096]"
    # t1719 = prims.convert_element_type(t1717, dtypes.float32)  # t1719: "cuda:0 f32[1, 512, 4096]"
    # t1720 = prims.mul(t1715, t1719)  # t1720: "cuda:0 f32[1, 512, 4096]"
    # t1721 = prims.convert_element_type(t1720, dtypes.bfloat16)  # t1721: "cuda:0 bf16[1, 512, 4096]"
  t1722 = torch.nn.functional.linear(t1721, t33, None)  # t1722: "cuda:0 bf16[1, 512, 11008]"
    # t1722 = ltorch.linear(t1721, t33, None)  # t1722: "cuda:0 bf16[1, 512, 11008]"
      # t1722 = prims.linear(t1721, t33, None)  # t1722: "cuda:0 bf16[1, 512, 11008]"
  t1723 = torch.nn.functional.linear(t1721, t49, None)  # t1723: "cuda:0 bf16[1, 512, 11008]"
    # t1723 = ltorch.linear(t1721, t49, None)  # t1723: "cuda:0 bf16[1, 512, 11008]"
      # t1723 = prims.linear(t1721, t49, None)  # t1723: "cuda:0 bf16[1, 512, 11008]"
  [t1737] = nvFusion74(t1722, t1723)
    # t1724 = prims.convert_element_type(t1722, dtypes.float32)  # t1724: "cuda:0 f32[1, 512, 11008]"
    # t1725 = prims.neg(t1724)  # t1725: "cuda:0 f32[1, 512, 11008]"
    # t1726 = prims.exp(t1725)  # t1726: "cuda:0 f32[1, 512, 11008]"
    # t1727 = prims.add(1.0, t1726)  # t1727: "cuda:0 f32[1, 512, 11008]"
    # t1728 = prims.reciprocal(t1727)  # t1728: "cuda:0 f32[1, 512, 11008]"
    # t1732 = prims.mul(t1724, t1728)  # t1732: "cuda:0 f32[1, 512, 11008]"
    # t1735 = prims.convert_element_type(t1723, dtypes.float32)  # t1735: "cuda:0 f32[1, 512, 11008]"
    # t1736 = prims.mul(t1732, t1735)  # t1736: "cuda:0 f32[1, 512, 11008]"
    # t1737 = prims.convert_element_type(t1736, dtypes.bfloat16)  # t1737: "cuda:0 bf16[1, 512, 11008]"
  t1738 = torch.nn.functional.linear(t1737, t114, None)  # t1738: "cuda:0 bf16[1, 512, 4096]"
    # t1738 = ltorch.linear(t1737, t114, None)  # t1738: "cuda:0 bf16[1, 512, 4096]"
      # t1738 = prims.linear(t1737, t114, None)  # t1738: "cuda:0 bf16[1, 512, 4096]"
  [t1742, t1749, t1757] = nvFusion75(t1706, t1738, t1753)
    # t1740 = prims.convert_element_type(t1706, dtypes.float32)  # t1740: "cuda:0 f32[1, 512, 4096]"
    # t1739 = prims.convert_element_type(t1738, dtypes.float32)  # t1739: "cuda:0 f32[1, 512, 4096]"
    # t1741 = prims.add(t1739, t1740)  # t1741: "cuda:0 f32[1, 512, 4096]"
    # t1742 = prims.convert_element_type(t1741, dtypes.bfloat16)  # t1742: "cuda:0 bf16[1, 512, 4096]"
    # t1744 = prims.mul(t1741, t1741)  # t1744: "cuda:0 f32[1, 512, 4096]"
    # t1745 = prims.sum(t1744, (2,))  # t1745: "cuda:0 f32[1, 512]"
    # t1746 = prims.broadcast_in_dim(t1745, [1, 512, 1], [0, 1])  # t1746: "cuda:0 f32[1, 512, 1]"
    # t1747 = prims.div(t1746, 4096.0)  # t1747: "cuda:0 f32[1, 512, 1]"
    # t1748 = prims.add(t1747, 1e-05)  # t1748: "cuda:0 f32[1, 512, 1]"
    # t1749 = prims.rsqrt(t1748)  # t1749: "cuda:0 f32[1, 512, 1]"
    # t1750 = prims.broadcast_in_dim(t1749, (1, 512, 4096), (0, 1, 2))  # t1750: "cuda:0 f32[1, 512, 4096]"
    # t1751 = prims.mul(t1741, t1750)  # t1751: "cuda:0 f32[1, 512, 4096]"
    # t1755 = prims.convert_element_type(t1753, dtypes.float32)  # t1755: "cuda:0 f32[1, 512, 4096]"
    # t1756 = prims.mul(t1751, t1755)  # t1756: "cuda:0 f32[1, 512, 4096]"
    # t1757 = prims.convert_element_type(t1756, dtypes.bfloat16)  # t1757: "cuda:0 bf16[1, 512, 4096]"
  t1758 = torch.nn.functional.linear(t1757, t18, None)  # t1758: "cuda:0 bf16[1, 512, 12288]"
    # t1758 = ltorch.linear(t1757, t18, None)  # t1758: "cuda:0 bf16[1, 512, 12288]"
      # t1758 = prims.linear(t1757, t18, None)  # t1758: "cuda:0 bf16[1, 512, 12288]"
  t1759 = torch.reshape(t1758, (1, 512, 32, 3, 128))  # t1759: "cuda:0 bf16[1, 512, 32, 3, 128]"
    # t1759 = ltorch.reshape(t1758, (1, 512, 32, 3, 128))  # t1759: "cuda:0 bf16[1, 512, 32, 3, 128]"
      # t1759 = prims.reshape(t1758, (1, 512, 32, 3, 128))  # t1759: "cuda:0 bf16[1, 512, 32, 3, 128]"
  del t1758
  t1760 = torch.permute(t1759, (0, 2, 3, 1, 4))  # t1760: "cuda:0 bf16[1, 32, 3, 512, 128]"
    # t1760 = ltorch.permute(t1759, (0, 2, 3, 1, 4))  # t1760: "cuda:0 bf16[1, 32, 3, 512, 128]"
      # t1760 = prims.transpose(t1759, (0, 2, 3, 1, 4))  # t1760: "cuda:0 bf16[1, 32, 3, 512, 128]"
  del t1759
  (t1761, t1762, t1763) = torch.split(t1760, (1, 1, 1), 2)
    # (t1761, t1762, t1763) = ltorch.split(t1760, (1, 1, 1), 2)
      # t1761 = prims.slice_prim(t1760, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1])  # t1761: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1762 = prims.slice_prim(t1760, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1])  # t1762: "cuda:0 bf16[1, 32, 1, 512, 128]"
      # t1763 = prims.slice_prim(t1760, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1])  # t1763: "cuda:0 bf16[1, 32, 1, 512, 128]"
  del t1760
  t1764 = torch.reshape(t1761, (1, 32, 512, 128))  # t1764: "cuda:0 bf16[1, 32, 512, 128]"
    # t1764 = ltorch.reshape(t1761, (1, 32, 512, 128))  # t1764: "cuda:0 bf16[1, 32, 512, 128]"
      # t1764 = prims.reshape(t1761, (1, 32, 512, 128))  # t1764: "cuda:0 bf16[1, 32, 512, 128]"
  del t1761
  t1765 = torch.reshape(t1762, (1, 32, 512, 128))  # t1765: "cuda:0 bf16[1, 32, 512, 128]"
    # t1765 = ltorch.reshape(t1762, (1, 32, 512, 128))  # t1765: "cuda:0 bf16[1, 32, 512, 128]"
      # t1765 = prims.reshape(t1762, (1, 32, 512, 128))  # t1765: "cuda:0 bf16[1, 32, 512, 128]"
  del t1762
  t1766 = torch.reshape(t1763, (1, 32, 512, 128))  # t1766: "cuda:0 bf16[1, 32, 512, 128]"
    # t1766 = ltorch.reshape(t1763, (1, 32, 512, 128))  # t1766: "cuda:0 bf16[1, 32, 512, 128]"
      # t1766 = prims.reshape(t1763, (1, 32, 512, 128))  # t1766: "cuda:0 bf16[1, 32, 512, 128]"
  del t1763
  t1767 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1767: "cuda:0 bf16[1, 32, 512, 128]"
  t1782 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1])  # t1782: "cuda:0 bf16[1, 32, 512, 128]"
  t1797 = torch_slice_prim_impl(t1764, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1797: "cuda:0 bf16[1, 32, 512, 0]"
  del t1764
  t1799 = torch_slice_prim_impl(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1])  # t1799: "cuda:0 bf16[1, 32, 512, 0]"
  del t1765
  t1768 = torch_slice_prim_impl(t1767, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1768: "cuda:0 bf16[1, 32, 512, 64]"
  t1769 = torch_slice_prim_impl(t1767, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1769: "cuda:0 bf16[1, 32, 512, 64]"
  t1783 = torch_slice_prim_impl(t1782, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1])  # t1783: "cuda:0 bf16[1, 32, 512, 64]"
  t1784 = torch_slice_prim_impl(t1782, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1])  # t1784: "cuda:0 bf16[1, 32, 512, 64]"
  [t1772, t1787] = nvFusion76(t1767, t1769, t1782, t1784)
    # t1770 = prims.convert_element_type(t1769, dtypes.float32)  # t1770: "cuda:0 f32[1, 32, 512, 64]"
    # t1771 = prims.neg(t1770)  # t1771: "cuda:0 f32[1, 32, 512, 64]"
    # t1772 = prims.convert_element_type(t1771, dtypes.bfloat16)  # t1772: "cuda:0 bf16[1, 32, 512, 64]"
    # t1785 = prims.convert_element_type(t1784, dtypes.float32)  # t1785: "cuda:0 f32[1, 32, 512, 64]"
    # t1786 = prims.neg(t1785)  # t1786: "cuda:0 f32[1, 32, 512, 64]"
    # t1787 = prims.convert_element_type(t1786, dtypes.bfloat16)  # t1787: "cuda:0 bf16[1, 32, 512, 64]"
  del t1769, t1784
  t1788 = torch.cat((t1787, t1783), -1)  # t1788: "cuda:0 bf16[1, 32, 512, 128]"
    # t1788 = ltorch.cat((t1787, t1783), -1)  # t1788: "cuda:0 bf16[1, 32, 512, 128]"
      # t1788 = prims.cat((t1787, t1783), -1)  # t1788: "cuda:0 bf16[1, 32, 512, 128]"
  del t1787, t1783
  t1773 = torch.cat((t1772, t1768), -1)  # t1773: "cuda:0 bf16[1, 32, 512, 128]"
    # t1773 = ltorch.cat((t1772, t1768), -1)  # t1773: "cuda:0 bf16[1, 32, 512, 128]"
      # t1773 = prims.cat((t1772, t1768), -1)  # t1773: "cuda:0 bf16[1, 32, 512, 128]"
  del t1772, t1768
  [t1781, t1796] = nvFusion77(t154, t157, t1767, t1773, t1782, t1788)
    # t1775 = prims.convert_element_type(t1767, dtypes.float32)  # t1775: "cuda:0 f32[1, 32, 512, 128]"
    # t1790 = prims.convert_element_type(t1782, dtypes.float32)  # t1790: "cuda:0 f32[1, 32, 512, 128]"
    # t1791 = prims.mul(t1790, t154)  # t1791: "cuda:0 f32[1, 32, 512, 128]"
    # t1793 = prims.convert_element_type(t1788, dtypes.float32)  # t1793: "cuda:0 f32[1, 32, 512, 128]"
    # t1794 = prims.mul(t1793, t157)  # t1794: "cuda:0 f32[1, 32, 512, 128]"
    # t1795 = prims.add(t1791, t1794)  # t1795: "cuda:0 f32[1, 32, 512, 128]"
    # t1796 = prims.convert_element_type(t1795, dtypes.bfloat16)  # t1796: "cuda:0 bf16[1, 32, 512, 128]"
    # t1776 = prims.mul(t1775, t154)  # t1776: "cuda:0 f32[1, 32, 512, 128]"
    # t1778 = prims.convert_element_type(t1773, dtypes.float32)  # t1778: "cuda:0 f32[1, 32, 512, 128]"
    # t1779 = prims.mul(t1778, t157)  # t1779: "cuda:0 f32[1, 32, 512, 128]"
    # t1780 = prims.add(t1776, t1779)  # t1780: "cuda:0 f32[1, 32, 512, 128]"
    # t1781 = prims.convert_element_type(t1780, dtypes.bfloat16)  # t1781: "cuda:0 bf16[1, 32, 512, 128]"
  del t1767, t1773, t1782, t1788
  t1800 = torch.cat((t1796, t1799), -1)  # t1800: "cuda:0 bf16[1, 32, 512, 128]"
    # t1800 = ltorch.cat((t1796, t1799), -1)  # t1800: "cuda:0 bf16[1, 32, 512, 128]"
      # t1800 = prims.cat((t1796, t1799), -1)  # t1800: "cuda:0 bf16[1, 32, 512, 128]"
  del t1796, t1799
  t1798 = torch.cat((t1781, t1797), -1)  # t1798: "cuda:0 bf16[1, 32, 512, 128]"
    # t1798 = ltorch.cat((t1781, t1797), -1)  # t1798: "cuda:0 bf16[1, 32, 512, 128]"
      # t1798 = prims.cat((t1781, t1797), -1)  # t1798: "cuda:0 bf16[1, 32, 512, 128]"
  del t1781, t1797
  (t1801, t1802, t1803, t1804, _, _, t1805, t1806, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(t1798, t1800, t1766, 0.0, True, scale=0.08838834764831843)
  t1808 = torch.permute(t1801, (0, 2, 1, 3))  # t1808: "cuda:0 bf16[1, 512, 32, 128]"
    # t1808 = ltorch.permute(t1801, (0, 2, 1, 3))  # t1808: "cuda:0 bf16[1, 512, 32, 128]"
      # t1808 = prims.transpose(t1801, (0, 2, 1, 3))  # t1808: "cuda:0 bf16[1, 512, 32, 128]"
  t1809 = torch.reshape(t1808, (1, 512, 4096))  # t1809: "cuda:0 bf16[1, 512, 4096]"
    # t1809 = ltorch.reshape(t1808, (1, 512, 4096))  # t1809: "cuda:0 bf16[1, 512, 4096]"
      # t1809 = prims.reshape(t1808, (1, 512, 4096))  # t1809: "cuda:0 bf16[1, 512, 4096]"
  del t1808
  t1810 = torch.nn.functional.linear(t1809, t115, None)  # t1810: "cuda:0 bf16[1, 512, 4096]"
    # t1810 = ltorch.linear(t1809, t115, None)  # t1810: "cuda:0 bf16[1, 512, 4096]"
      # t1810 = prims.linear(t1809, t115, None)  # t1810: "cuda:0 bf16[1, 512, 4096]"
  [t1814, t1821, t1829] = nvFusion78(t1742, t1810, t1825)
    # t1812 = prims.convert_element_type(t1742, dtypes.float32)  # t1812: "cuda:0 f32[1, 512, 4096]"
    # t1811 = prims.convert_element_type(t1810, dtypes.float32)  # t1811: "cuda:0 f32[1, 512, 4096]"
    # t1813 = prims.add(t1811, t1812)  # t1813: "cuda:0 f32[1, 512, 4096]"
    # t1814 = prims.convert_element_type(t1813, dtypes.bfloat16)  # t1814: "cuda:0 bf16[1, 512, 4096]"
    # t1816 = prims.mul(t1813, t1813)  # t1816: "cuda:0 f32[1, 512, 4096]"
    # t1817 = prims.sum(t1816, (2,))  # t1817: "cuda:0 f32[1, 512]"
    # t1818 = prims.broadcast_in_dim(t1817, [1, 512, 1], [0, 1])  # t1818: "cuda:0 f32[1, 512, 1]"
    # t1819 = prims.div(t1818, 4096.0)  # t1819: "cuda:0 f32[1, 512, 1]"
    # t1820 = prims.add(t1819, 1e-05)  # t1820: "cuda:0 f32[1, 512, 1]"
    # t1821 = prims.rsqrt(t1820)  # t1821: "cuda:0 f32[1, 512, 1]"
    # t1822 = prims.broadcast_in_dim(t1821, (1, 512, 4096), (0, 1, 2))  # t1822: "cuda:0 f32[1, 512, 4096]"
    # t1823 = prims.mul(t1813, t1822)  # t1823: "cuda:0 f32[1, 512, 4096]"
    # t1827 = prims.convert_element_type(t1825, dtypes.float32)  # t1827: "cuda:0 f32[1, 512, 4096]"
    # t1828 = prims.mul(t1823, t1827)  # t1828: "cuda:0 f32[1, 512, 4096]"
    # t1829 = prims.convert_element_type(t1828, dtypes.bfloat16)  # t1829: "cuda:0 bf16[1, 512, 4096]"
  t1831 = torch.nn.functional.linear(t1829, t50, None)  # t1831: "cuda:0 bf16[1, 512, 11008]"
    # t1831 = ltorch.linear(t1829, t50, None)  # t1831: "cuda:0 bf16[1, 512, 11008]"
      # t1831 = prims.linear(t1829, t50, None)  # t1831: "cuda:0 bf16[1, 512, 11008]"
  t1830 = torch.nn.functional.linear(t1829, t34, None)  # t1830: "cuda:0 bf16[1, 512, 11008]"
    # t1830 = ltorch.linear(t1829, t34, None)  # t1830: "cuda:0 bf16[1, 512, 11008]"
      # t1830 = prims.linear(t1829, t34, None)  # t1830: "cuda:0 bf16[1, 512, 11008]"
  [t1845] = nvFusion79(t1830, t1831)
    # t1832 = prims.convert_element_type(t1830, dtypes.float32)  # t1832: "cuda:0 f32[1, 512, 11008]"
    # t1833 = prims.neg(t1832)  # t1833: "cuda:0 f32[1, 512, 11008]"
    # t1834 = prims.exp(t1833)  # t1834: "cuda:0 f32[1, 512, 11008]"
    # t1835 = prims.add(1.0, t1834)  # t1835: "cuda:0 f32[1, 512, 11008]"
    # t1836 = prims.reciprocal(t1835)  # t1836: "cuda:0 f32[1, 512, 11008]"
    # t1840 = prims.mul(t1832, t1836)  # t1840: "cuda:0 f32[1, 512, 11008]"
    # t1843 = prims.convert_element_type(t1831, dtypes.float32)  # t1843: "cuda:0 f32[1, 512, 11008]"
    # t1844 = prims.mul(t1840, t1843)  # t1844: "cuda:0 f32[1, 512, 11008]"
    # t1845 = prims.convert_element_type(t1844, dtypes.bfloat16)  # t1845: "cuda:0 bf16[1, 512, 11008]"
  t1846 = torch.nn.functional.linear(t1845, t116, None)  # t1846: "cuda:0 bf16[1, 512, 4096]"
    # t1846 = ltorch.linear(t1845, t116, None)  # t1846: "cuda:0 bf16[1, 512, 4096]"
      # t1846 = prims.linear(t1845, t116, None)  # t1846: "cuda:0 bf16[1, 512, 4096]"
  [t1857, t1865] = nvFusion80(t1814, t1846, t1861)
    # t1848 = prims.convert_element_type(t1814, dtypes.float32)  # t1848: "cuda:0 f32[1, 512, 4096]"
    # t1847 = prims.convert_element_type(t1846, dtypes.float32)  # t1847: "cuda:0 f32[1, 512, 4096]"
    # t1849 = prims.add(t1847, t1848)  # t1849: "cuda:0 f32[1, 512, 4096]"
    # t1852 = prims.mul(t1849, t1849)  # t1852: "cuda:0 f32[1, 512, 4096]"
    # t1853 = prims.sum(t1852, (2,))  # t1853: "cuda:0 f32[1, 512]"
    # t1854 = prims.broadcast_in_dim(t1853, [1, 512, 1], [0, 1])  # t1854: "cuda:0 f32[1, 512, 1]"
    # t1855 = prims.div(t1854, 4096.0)  # t1855: "cuda:0 f32[1, 512, 1]"
    # t1856 = prims.add(t1855, 1e-05)  # t1856: "cuda:0 f32[1, 512, 1]"
    # t1857 = prims.rsqrt(t1856)  # t1857: "cuda:0 f32[1, 512, 1]"
    # t1858 = prims.broadcast_in_dim(t1857, (1, 512, 4096), (0, 1, 2))  # t1858: "cuda:0 f32[1, 512, 4096]"
    # t1859 = prims.mul(t1849, t1858)  # t1859: "cuda:0 f32[1, 512, 4096]"
    # t1863 = prims.convert_element_type(t1861, dtypes.float32)  # t1863: "cuda:0 f32[1, 512, 4096]"
    # t1864 = prims.mul(t1859, t1863)  # t1864: "cuda:0 f32[1, 512, 4096]"
    # t1865 = prims.convert_element_type(t1864, dtypes.bfloat16)  # t1865: "cuda:0 bf16[1, 512, 4096]"
  t1866 = torch.nn.functional.linear(t1865, t51, None)  # t1866: "cuda:0 bf16[1, 512, 32000]"
    # t1866 = ltorch.linear(t1865, t51, None)  # t1866: "cuda:0 bf16[1, 512, 32000]"
      # t1866 = prims.linear(t1865, t51, None)  # t1866: "cuda:0 bf16[1, 512, 32000]"
  return {'output': t1866, 'flat_args': [t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19, t20, t21, t22, t23, t24, t25, t26, t27, t28, t29, t30, t31, t32, t33, t34, t35, t36, t37, t38, t39, t40, t41, t42, t43, t44, t45, t46, t47, t48, t49, t50, t51, t52, t53, t54, t55, t56, t57, t58, t59, t60, t61, t62, t63, t64, t65, t66, t67, t68, t69, t70, t71, t72, t73, t74, t75, t76, t77, t78, t79, t80, t81, t82, t83, t84, t85, t86, t87, t88, t89, t90, t91, t92, t93, t94, t95, t96, t97, t98, t99, t100, t101, t102, t103, t104, t105, t106, t107, t108, t109, t110, t111, t112, t113, t114, t115, t116, t117], 'flat_output': (t1866,)}, ((t0, t10, t100, t1001, t101, t1010, t102, t103, t104, t1042, t1044, t1045, t1046, t1047, t1048, t1049, t105, t1050, t1053, t1054, t1058, t106, t1065, t1069, t107, t1073, t1074, t1075, t108, t1089, t109, t1090, t1094, t11, t110, t1101, t1105, t1109, t111, t1118, t112, t113, t114, t115, t1150, t1152, t1153, t1154, t1155, t1156, t1157, t1158, t116, t1161, t1162, t1166, t1173, t1177, t1181, t1182, t1183, t1197, t1198, t12, t1202, t1209, t1213, t1217, t122, t1226, t1258, t1260, t1261, t1262, t1263, t1264, t1265, t1266, t1269, t1270, t1274, t1281, t1285, t1289, t129, t1290, t1291, t13, t1305, t1306, t1310, t1317, t1321, t1325, t133, t1334, t1366, t1368, t1369, t137, t1370, t1371, t1372, t1373, t1374, t1377, t1378, t1382, t1389, t1393, t1397, t1398, t1399, t14, t1413, t1414, t1418, t1425, t1429, t1433, t1442, t146, t1474, t1476, t1477, t1478, t1479, t1480, t1481, t1482, t1485, t1486, t1490, t1497, t15, t1501, t1505, t1506, t1507, t1521, t1522, t1526, t1533, t1537, t154, t1541, t1550, t157, t1582, t1584, t1585, t1586, t1587, t1588, t1589, t1590, t1593, t1594, t1598, t16, t1605, t1609, t1613, t1614, t1615, t1629, t1630, t1634, t1641, t1645, t1649, t1658, t1690, t1692, t1693, t1694, t1695, t1696, t1697, t1698, t17, t1701, t1702, t1706, t1713, t1717, t1721, t1722, t1723, t1737, t1738, t1742, t1749, t1753, t1757, t1766, 
t178, t1798, t18, t180, t1800, t1801, t1802, t1803, t1804, t1805, t1806, t1809, t181, t1810, t1814, t182, t1821, t1825, t1829, t183, t1830, t1831, t184, t1845, t1846, t185, t1857, t186, t1861, t1865, t189, t19, t190, t194, t20, t201, t205, t209, t21, t210, t211, t22, t225, t226, t23, t230, t237, t24, t241, t245, t25, t254, t26, t27, t28, t286, t288, t289, t29, t290, t291, t292, t293, t294, t297, t298, t3, t30, t302, t309, t31, t313, t317, t318, t319, t32, t33, t333, t334, t338, t34, t345, t349, t35, t353, t36, t362, t37, t38, t39, t394, t396, t397, t398, t399, t4, t40, t400, t401, t402, t405, t406, t41, t410, t417, t42, t421, t425, t426, t427, t43, t44, t441, t442, t446, t45, t453, t457, t46, t461, t47, t470, t48, t49, t5, t50, t502, t504, t505, t506, t507, t508, t509, t51, t510, t513, t514, t518, t525, t529, t533, t534, t535, t549, t550, t554, t561, t565, t569, t578, t6, t610, t612, t613, t614, t615, t616, t617, t618, t621, t622, t626, t633, t637, t641, t642, t643, t657, t658, t662, t669, t673, t677, t686, t7, t718, t720, t721, t722, t723, t724, t725, t726, t729, t730, t734, t741, t745, t749, t750, t751, t765, t766, t770, t777, t781, t785, t794, t8, t826, t828, t829, t830, t831, t832, t833, t834, t837, t838, t842, t849, t85, t853, t857, t858, t859, t86, t87, t873, t874, t878, t88, t885, t889, t89, t893, t9, t90, t902, t91, t92, t93, t934, t936, t937, t938, t939, t94, t940, t941, t942, t945, t946, t95, t950, t957, t96, t961, t965, t966, t967, t97, t98, t981, t982, t986, t99, t993, t997), (False, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 0.0, 4096.0, 4096.0, 0.08838834764831843, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 
4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 4096.0, 4096.0, 0.0, 0.08838834764831843, 32000, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2))

Well, that is quite a bit to look through. But here is the key point: besides the output, the function now returns a tuple of saved tensors and a tuple of scalar values. This is because Thunder applies the same treatment to the backward and, to that end, saves information from the forward. You can see a hint of this in the output, which has ThunderFunctionBackward as its grad_fn. (You can inspect the backward trace with thunder.last_backward_traces(thunder_model)[-1].)
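Saving forward results for the backward is the same idea that a hand-written torch.autograd.Function expresses with ctx.save_for_backward. The plain-PyTorch sketch below uses SiLU, the activation that regions such as nvFusion69 build from neg, exp, add and reciprocal. The forward stores the input and its sigmoid so that the backward does not have to recompute them. SiLUFn is a hypothetical name, and this is not Thunder's generated code.

PyTorch names the backward node of such a Function after the class, with a Backward suffix. That is presumably where the ThunderFunctionBackward grad_fn comes from.

```python
import torch

class SiLUFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        s = torch.sigmoid(x)
        # Stash forward intermediates so the backward can reuse them
        ctx.save_for_backward(x, s)
        return x * s

    @staticmethod
    def backward(ctx, grad_out):
        x, s = ctx.saved_tensors
        # d/dx [x * sigmoid(x)] = s * (1 + x * (1 - s))
        return grad_out * s * (1 + x * (1 - s))

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
y = SiLUFn.apply(x)
print(type(y.grad_fn).__name__)  # SiLUFnBackward

# Compare against PyTorch's built-in SiLU gradient
(g_custom,) = torch.autograd.grad(y.sum(), x)
(g_ref,) = torch.autograd.grad(torch.nn.functional.silu(x).sum(), x)
print(torch.allclose(g_custom, g_ref))
```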

[10]:
actual
[10]:
tensor([[[ 0.4160, -0.4668,  1.1016,  ...,  0.5430,  1.2656,  0.2891],
         [ 0.3320, -0.0557,  1.7891,  ...,  1.0703,  1.0078,  1.2266],
         [ 0.6836, -0.2871,  0.9531,  ...,  0.0806,  0.7070,  0.8477],
         ...,
         [ 0.7695, -0.1260,  0.7266,  ...,  0.1118, -0.0238, -1.2656],
         [-0.7773, -0.5547, -0.3047,  ..., -0.1807,  0.1895,  0.6875],
         [ 0.8867,  0.4766,  0.3984,  ...,  0.0815, -0.0879,  0.3477]]],
       device='cuda:0', grad_fn=<ThunderFunctionBackward>)
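To make the "saves information from the forward" point concrete, here is a minimal pure-Python sketch of the gating step silu(x_fc_1) * x_fc_2 from LLaMAMLP, using scalars instead of tensors. It is an illustration of the general idea, not Thunder's generated code: the forward returns its output together with the residuals the backward needs, much like the extra values returned by the forward trace above. The function names are our own.

```python
import math
from typing import Tuple


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def gated_forward(a: float, b: float) -> Tuple[float, Tuple[float, float, float, float]]:
    """y = silu(a) * b; also returns the residuals the backward will need."""
    s = sigmoid(a)
    silu_a = a * s
    y = silu_a * b
    saved_for_backward = (a, b, s, silu_a)  # kept alive until the backward runs
    return y, saved_for_backward


def gated_backward(saved_for_backward, grad_y: float) -> Tuple[float, float]:
    """Gradients w.r.t. a and b, computed only from saved residuals and grad_y."""
    a, b, s, silu_a = saved_for_backward
    dsilu_da = s + a * s * (1.0 - s)  # derivative of a * sigmoid(a)
    return grad_y * b * dsilu_da, grad_y * silu_a


y, saved = gated_forward(0.5, -1.25)  # a plays the role of x_fc_1, b of x_fc_2
grad_a, grad_b = gated_backward(saved, grad_y=1.0)
```

The backward never recomputes the forward: everything it needs arrives through `saved_for_backward`, which is why a traced forward has to return (and keep in memory) such intermediates.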

Let us clean up a bit.

[11]:
del actual, expected
import gc
gc.collect();

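But is it faster? Yes! In this run, a forward and backward pass takes about 13% less time with Thunder (208 ms vs. 240 ms).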

[12]:
%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()
%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()
240 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
208 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
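CUDA kernels are launched asynchronously, which is why each timed statement above ends with torch.cuda.synchronize(): without it, the clock could stop before the GPU has finished. Outside of IPython's %timeit, a small helper along these lines does the same job. It is a sketch: the name `benchmark` and its parameters are our own, and the synchronization function is passed in so that the helper itself does not depend on torch.

```python
import statistics
import time
from typing import Callable, Tuple


def benchmark(fn: Callable[[], object], sync: Callable[[], None] = lambda: None,
              warmup: int = 2, repeat: int = 7) -> Tuple[float, float]:
    """Mean and standard deviation (seconds) of wall-clock time per call of fn.

    `sync` must block until all queued asynchronous work is done; for CUDA
    work pass torch.cuda.synchronize, for pure CPU work the no-op default is fine.
    """
    if repeat < 2:
        raise ValueError("need at least two repetitions for a standard deviation")
    for _ in range(warmup):  # keep one-time work (e.g. compilation) out of the timing
        fn()
    sync()  # start measuring from an idle device
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        sync()  # wait for the work fn queued before stopping the clock
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples), statistics.stdev(samples)
```

In this notebook one could call it as `benchmark(lambda: torch.autograd.grad(thunder_model(inp).sum(), m.parameters()), sync=torch.cuda.synchronize)`.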

So far, so good! Thunder should work with LitGPT today, and we are busy adding the support required to run other models as well!

[13]:
del m, thunder_model
import gc
gc.collect()
torch.cuda.empty_cache()

Distributed with Thunder

Those Large Language Models are called Large for a reason, and the memory of a single GPU is comparatively small. So we need multiple GPUs.

Happily, Thunder sports an FSDP (Fully Sharded Data Parallel) interface to use multiple GPUs in our box.
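A back-of-the-envelope count shows why. The sketch below counts the parameters of a Llama-2-style model, assuming the Llama-2-7B shapes (embedding size 4096, MLP size 11008, vocabulary 32000, no biases, RMSNorm, a separate lm_head, full multi-head attention), and divides by the number of ranks as FSDP's sharding does (ignoring padding). It counts parameters only; gradients, optimizer state and activations come on top. Both function names are hypothetical helpers, not part of Thunder or LitGPT.

```python
def llama_param_count(n_layer: int, n_embd: int = 4096, intermediate_size: int = 11008,
                      vocab_size: int = 32000) -> int:
    """Parameters of a Llama-2-style GPT: no biases, RMSNorm, untied lm_head, no GQA."""
    attention = 4 * n_embd * n_embd  # q, k, v and output projections
    mlp = 3 * n_embd * intermediate_size  # fc_1, fc_2 and proj, as in LLaMAMLP
    norms = 2 * n_embd  # two RMSNorm weight vectors per block
    per_layer = attention + mlp + norms
    embeddings = 2 * vocab_size * n_embd  # token embedding and separate lm_head
    final_norm = n_embd
    return n_layer * per_layer + embeddings + final_norm


def param_bytes_per_rank(n_params: int, bytes_per_param: int = 2, world_size: int = 1) -> float:
    """Parameter memory per rank when each rank keeps 1/world_size of the weights."""
    return n_params * bytes_per_param / world_size
```

With all 32 layers this gives 6,738,415,616 parameters, about 13.5 GB in bfloat16 before a single gradient is stored. The 8-layer variant used in the script below has 1,881,214,976 parameters, so each of two ranks holds roughly 1.9 GB of bfloat16 parameter shards.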

You still need to set up the process group, but as far as the model is concerned,

model = thunder.jit(thunder.distributed.fsdp(model))

is all you need. Because it is tricky to run multiprocessing from notebooks, we write a small example into a file and run it through torchrun.

Check out our LitGPT Thunder examples for complete distributed training and finetuning!

[14]:
%%writefile zero_to_thunder_fsdp_simple_example.py
from thunder.tests.litgpt_model import GPT, Config
import os
import torch, torch.distributed
import thunder, thunder.distributed

# Create Model
# NOTE: We create the model on CPU.
device='cpu'
torch.set_default_dtype(torch.bfloat16)
cfg = Config.from_name('Llama-2-7b-hf')
cfg.n_layer = 8 # fewer layers
model = GPT(cfg)

# Setup for distributed
torch.distributed.init_process_group(backend='nccl')
rank = int(os.environ["LOCAL_RANK"])

device = f"cuda:{rank}"
x = torch.randint(1, model.config.vocab_size, (1, 1024), device=device)

# thunder.distributed.fsdp takes care of moving the parameter
# shard to the correct GPU for the current process.
model = thunder.jit(thunder.distributed.fsdp(model)) #  <---------------------------------------
print(f"rank {rank} computing")
# Run forward and backward passes.
for i in range(10):
    res = model(x)
    res.sum().backward()

Overwriting zero_to_thunder_fsdp_simple_example.py

Now we can launch it. Note that you need two GPUs for this to run correctly.

[15]:
!torchrun --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757]
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0320 15:06:06.538000 140013994370240 torch/distributed/run.py:757] *****************************************
rank 1 computing
rank 0 computing

So there: FSDP just by wrapping the model in thunder.distributed.fsdp.
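torchrun starts one process per GPU (two here, from --nproc_per_node=2) and tells each process who it is through environment variables such as RANK, LOCAL_RANK and WORLD_SIZE; the script above reads LOCAL_RANK to pick its GPU. A small helper, sketched below with hypothetical names, collects these values and falls back to single-process defaults when the variables are absent.

```python
import os
from typing import Mapping, NamedTuple, Optional


class DistInfo(NamedTuple):
    rank: int  # global rank across all processes
    local_rank: int  # index of the GPU on this machine
    world_size: int  # total number of processes


def dist_info_from_env(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read the variables torchrun sets; default to a single process."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
    )
```

Inside the script, `device = f"cuda:{dist_info_from_env().local_rank}"` could then replace the direct lookup of `os.environ["LOCAL_RANK"]`.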

Extending Thunder

But we promised that Thunder is extensible. Let’s find out what that means.

Specifically, we will incorporate the fast RoPE embedding kernel from the great Unsloth project into our model (note that nvFuser also creates a fused kernel for this).

In Thunder, extensions (as well as most built-in optimizations, which use exactly the same mechanism) work through executors that handle operations. Let us define one.

[16]:
my_ex = thunder.extend.OperatorExecutor('my_ex', version='0.0.1')
thunder.extend.register_executor(my_ex)
[16]:
my_ex
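Conceptually, an executor is a named collection of implementations for operations, and executors are consulted in priority order; that is why the cells below pass executors=(my_ex,) + thunder.get_default_executors(), putting my_ex first. The toy model below illustrates only this ordering idea. ToyExecutor and pick_executor are hypothetical and much simpler than Thunder's actual machinery.

```python
from typing import Callable, Dict, Optional, Sequence


class ToyExecutor:
    """Hypothetical stand-in for an executor: a name plus op name -> implementation."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.impls: Dict[str, Callable[..., object]] = {}

    def register(self, op: str, fn: Callable[..., object]) -> None:
        self.impls[op] = fn


def pick_executor(op: str, executors: Sequence[ToyExecutor]) -> Optional[ToyExecutor]:
    """Return the first executor, in priority order, that implements `op`."""
    for ex in executors:
        if op in ex.impls:
            return ex
    return None


default_ex = ToyExecutor("default")
default_ex.register("apply_rope", lambda x, cos, sin: "reference apply_rope")
default_ex.register("matmul", lambda a, b: "reference matmul")
my_toy_ex = ToyExecutor("my_ex")
my_toy_ex.register("apply_rope", lambda x, cos, sin: "custom apply_rope")

chosen = pick_executor("apply_rope", (my_toy_ex, default_ex))
```

In the toy, ops that my_toy_ex does not implement (here, matmul) still fall through to the default executor, and reordering the tuple changes which apply_rope is chosen.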

For our base implementation, we take the code from LitGPT’s implementation.

Because we will demonstrate Thunder’s ability to divert functions in the model, we make a version here that will not be diverted.

[17]:
import litgpt
def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    head_size = x.size(-1)
    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
    x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)
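For reference, apply_rope_copy rotates each pair of features (x[j], x[j + head_size/2]) by a position-dependent angle: the first half becomes x1*cos - x2*sin and the second half x2*cos + x1*sin. The pure-Python sketch below applies the same formula to a single head vector. The cache construction assumes LitGPT's convention of angles pos * base^(-2j/head_dim) with base 10000, repeated across both halves; the function names are ours.

```python
import math
from typing import List, Tuple


def rope_cache(seq_len: int, head_dim: int,
               base: float = 10000.0) -> Tuple[List[List[float]], List[List[float]]]:
    """cos/sin tables of shape (seq_len, head_dim); both halves share the same angles."""
    half = head_dim // 2
    inv_freq = [base ** (-2 * j / head_dim) for j in range(half)]
    cos, sin = [], []
    for pos in range(seq_len):
        angles = [pos * f for f in inv_freq] * 2  # repeat for the second half
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


def apply_rope_vec(x: List[float], cos: List[float], sin: List[float]) -> List[float]:
    """x * cos + rotate_half(x) * sin for a single head vector, as in apply_rope_copy."""
    half = len(x) // 2
    rotated = [-v for v in x[half:]] + x[:half]  # torch.cat((-x2, x1), dim=-1)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


cos, sin = rope_cache(seq_len=4, head_dim=8)
x = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75, 0.0, 1.0]
y = apply_rope_vec(x, cos[3], sin[3])  # position 3
```

Because each pair is rotated, RoPE preserves the length of every (x[j], x[j + head_size/2]) pair, and position 0 (angle 0) leaves the vector unchanged.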

Registering operators

Say we have a function apply_rope applying the RoPE transformation in PyTorch.

In Thunder, we define a meta function that only describes the metadata (like shapes) of the outputs, and an implementation function that does the actual computation. We then register the pair with our executor using the register_operator function and, through its replaces argument, tell it to use the new symbol instead of the original function litgpt.model.apply_rope.

[18]:
import torch, thunder
from thunder.tests.litgpt_model import GPT
from thunder import TensorProxy

def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return litgpt.model.apply_rope(x, cos, sin)

def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
    return TensorProxy(like=x)

apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,
                                    replaces=litgpt.model.apply_rope)

Testing our new operator

[19]:
with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)

def test_apply_rope(x, m):
    return litgpt.model.apply_rope(x, m.cos, m.sin)

thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())

expected = test_apply_rope(Q, m); actual = thunder_apply_rope(Q, m); print("deviation:", (expected - actual).abs().max().item())

thunder.last_traces(thunder_apply_rope)[-1]
deviation: 0.0
[19]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, t_1_cos, t_1_sin):
  # x: "cuda:0 bf16[2, 128, 4096, 16]"
  # t_1_cos: "cuda:0 f32[4096, 16]"
  # t_1_sin: "cuda:0 f32[4096, 16]"
  t2 = apply_rope(x, t_1_cos, t_1_sin)  # t2: "cuda:0 bf16[2, 128, 4096, 16]"
  del x, t_1_cos, t_1_sin
  return t2

Optimized kernels

But why did we do this? Well, we can now layer a faster implementation on top. For this, we take the Unsloth fast RoPE embedding kernels and move the bits that were in the forward and backward of their autograd.Function into our implementation functions. Note that we include the transpositions in our setup for compatibility with the LitGPT implementation. This change in the memory layout of the operands can have a large effect on the runtime, though, so our timings are likely not representative of what the Unsloth project gets in their use of the same Triton kernels.

[20]:
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl
import torch

MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2

def calculate_settings(n):
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps = 4
    if   BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >=  8192: num_warps = 16
    elif BLOCK_SIZE >=  2048: num_warps = 8
    return BLOCK_SIZE, num_warps

@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
@triton.jit
def _rope_embedding(
    Q,     Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen, head_dim, group_size, n_heads,
    BACKWARD_PASS: tl.constexpr,
    BLOCK_SIZE : tl.constexpr,
):
    """
        Calculates the RoPE Embedding quickly
        RoPE is Q * cos + rotate_half(Q) * sin
        See our blog post for more info
    """
    row_position  = tl.program_id(0)
    group_head_position = tl.program_id(1)
    col_offsets  = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    mask = col_offsets < half_head_dim

    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)

    if BACKWARD_PASS:
        # See our blog post for more info.
        sin1 = -sin1
    pass

    head_start = group_head_position * group_size
    head_end = min((head_start + group_size), n_heads)

    for i in range(head_start, head_end):
        offs_q1 = row_position * Q_row_stride + i * head_dim + col_offsets
        offs_q2 = row_position * Q_row_stride + i * head_dim + col_offsets + half_head_dim

        # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
        Q1   = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
        Q2   = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)

        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
    pass
pass


def fast_rope_embedding_forward(Q, cos, sin):
    Q = Q.transpose(1, 2).clone()
    cos, sin = cos.squeeze(), sin.squeeze()
    batch, seq_len, n_heads, head_dim = Q.shape
    Q = Q.reshape(batch*seq_len, n_heads*head_dim)
    n_rows, n_cols = Q.shape
    assert(seq_len <= cos.shape[0])

    # [TODO] Changing blocksize to head_dim//2 seems to have
    # some concurrency / un-deterministic issues.
    BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
    group_size = 4 # 4 or 8, too large group_size can hurt performance.
    n_groups = triton.cdiv(n_heads, group_size)

    grid = (n_rows, n_groups, )
    _rope_embedding[grid](
          Q,   Q.stride(0),
        cos, cos.stride(0),
        sin, sin.stride(0),
        seq_len, head_dim, group_size, n_heads,
        BACKWARD_PASS = False,
        BLOCK_SIZE = BLOCK_SIZE,
        num_warps  = num_warps,
    )
    Q = Q.view(batch, seq_len, n_heads, head_dim).transpose(1, 2)
    return Q, (BLOCK_SIZE, num_warps)

def fast_rope_embedding_backward(BLOCK_SIZE, num_warps, cos, sin, dY):
    dY = dY.transpose(1, 2)
    batch, seq_len, n_heads, head_dim = dY.shape
    dY = dY.reshape(batch*seq_len, n_heads*head_dim)
    # Must be reshape not view
    n_rows, n_cols = dY.shape

    group_size = 4 # 4 or 8, too large group_size can hurt performance.
    n_groups = triton.cdiv(n_heads, group_size)

    grid = (n_rows, n_groups, )
    _rope_embedding[grid](
        dY,  dY .stride(0),
        cos, cos.stride(0),
        sin, sin.stride(0),
        seq_len, head_dim, group_size, n_heads,
        BACKWARD_PASS = True,
        BLOCK_SIZE = BLOCK_SIZE,
        num_warps  = num_warps,
    )
    dY = dY.view(batch, seq_len, n_heads, head_dim)
    dY = dY.transpose(1, 2)
    return dY
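The indexing in _rope_embedding is easier to follow when emulated on the CPU. After the transpose and reshape in fast_rope_embedding_forward, Q is a (batch*seq_len, n_heads*head_dim) matrix, so row r holds all heads for sequence position r % seq_len. The grid has one program per (row, head group), and each program rotates the two halves of up to group_size heads in place. The sketch below walks the same grid with plain loops over a flat Python list; it is a reading aid for the kernel, not a replacement for it, and rope_kernel_emulated is our own name.

```python
from typing import List


def rope_kernel_emulated(Q: List[float], row_stride: int,
                         cos: List[List[float]], sin: List[List[float]],
                         seqlen: int, head_dim: int, group_size: int,
                         n_heads: int, backward: bool = False) -> None:
    """Walk the same (row, head-group) grid as _rope_embedding, updating Q in place.

    Q is a flat list holding a row-major (n_rows, n_heads * head_dim) matrix;
    cos/sin are (seq_len, head_dim) tables whose two halves repeat the same angles.
    """
    n_rows = len(Q) // row_stride
    n_groups = -(-n_heads // group_size)  # ceiling division, like triton.cdiv
    half = head_dim // 2
    for row in range(n_rows):  # tl.program_id(0)
        pos = row % seqlen  # rows are flattened (batch, position) pairs
        c = cos[pos][:half]  # the kernel only loads the first half
        s = sin[pos][:half]
        if backward:  # BACKWARD_PASS: negate sin, i.e. rotate the other way
            s = [-v for v in s]
        for group in range(n_groups):  # tl.program_id(1)
            head_start = group * group_size
            head_end = min(head_start + group_size, n_heads)
            for h in range(head_start, head_end):
                base = row * row_stride + h * head_dim
                for j in range(half):  # the unmasked column offsets
                    q1, q2 = Q[base + j], Q[base + j + half]
                    Q[base + j] = q1 * c[j] - q2 * s[j]
                    Q[base + j + half] = q2 * c[j] + q1 * s[j]
```

Calling it with backward=True after a forward call restores the input, because negating sin turns each rotation into its inverse.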

We also define the corresponding meta functions.

[21]:
def fast_rope_embedding_forward_meta(Q, cos, sin):
    batch, n_heads, seq_len, head_dim = Q.shape
    n_rows, n_cols = batch*seq_len, n_heads*head_dim
    assert(seq_len <= cos.shape[0])

    BLOCK_SIZE, num_warps = calculate_settings(head_dim//2)
    return TensorProxy(like=Q), (BLOCK_SIZE, num_warps)

def fast_rope_embedding_backward_meta(BLOCK_SIZE, num_warps, cos, sin, dY):
    return TensorProxy(like=dY)
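The meta functions run during tracing on proxies that carry shapes and dtypes but no data. That is enough to compute the launch settings: for our Q of shape (2, 128, 4096, 16), head_dim // 2 = 8 gives BLOCK_SIZE = 8 and num_warps = 4. Because these come out as plain Python ints at trace time, they appear as the literal constants 8 and 4 in the backward trace further down. The sketch below mimics this with a hypothetical ShapeOnly class standing in for TensorProxy.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ShapeOnly:
    """Hypothetical stand-in for a TensorProxy: metadata only, no data."""
    shape: Tuple[int, ...]
    dtype: str


def launch_settings(n: int) -> Tuple[int, int]:
    """The BLOCK_SIZE / num_warps rule of calculate_settings (without its size check)."""
    block_size = 1 << (n - 1).bit_length()  # next power of two >= n, for n >= 1
    if block_size >= 32768:
        num_warps = 32
    elif block_size >= 8192:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    else:
        num_warps = 4
    return block_size, num_warps


def rope_forward_meta(q: ShapeOnly, cos: ShapeOnly, sin: ShapeOnly):
    """Like fast_rope_embedding_forward_meta: output metadata plus launch settings."""
    batch, n_heads, seq_len, head_dim = q.shape
    assert seq_len <= cos.shape[0]
    return ShapeOnly(q.shape, q.dtype), launch_settings(head_dim // 2)


out, settings = rope_forward_meta(ShapeOnly((2, 128, 4096, 16), "bf16"),
                                  ShapeOnly((4096, 16), "f32"),
                                  ShapeOnly((4096, 16), "f32"))
```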

Register optimized operators

Just like the apply_rope before, we can register operators for the optimized forward and backward.

[22]:
unsloth_apply_rope_forward = my_ex.register_operator('unsloth_apply_rope_forward',
    meta=fast_rope_embedding_forward_meta, fn=fast_rope_embedding_forward)
unsloth_apply_rope_backward = my_ex.register_operator('unsloth_apply_rope_backward',
    meta=fast_rope_embedding_backward_meta, fn=fast_rope_embedding_backward)

Implementations for operators

Do we need to divert apply_rope again? No! We can register the specialized kernel as an implementation of our base apply_rope operator. For this we need an execution transform, which is a fancy name for a function that implements the original operator (apply_rope) in terms of our new operator, so it has the call signature of apply_rope. Because, like many fast implementations, the Unsloth RoPE embedding does not implement the operator in full generality (mainly, it expects a 4d tensor input), we implement a checker function, too: it takes the arguments of the operator we want to specialize and returns a bool indicating whether our implementation handles the given inputs.

[23]:
def apply_rope_to_unsloth(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:
    assert len(x.shape) == 4
    res, *_ = unsloth_apply_rope_forward(x, cos, sin)
    return res

def apply_rope_to_unsloth_checker(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> bool:
    if len(x.shape) != 4:
        return False
    return (x.device.devicetype == thunder.devices.DeviceType.CUDA and
            cos.device.devicetype == thunder.devices.DeviceType.CUDA and
            sin.device.devicetype == thunder.devices.DeviceType.CUDA)

my_ex.register_implementation(apply_rope,
                              checker=apply_rope_to_unsloth_checker,
                              execution_transform=apply_rope_to_unsloth)
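The checker is what makes the specialization safe: when it returns False, our execution transform is not used for that call and, roughly speaking, apply_rope is left to be executed in some other way. The toy dispatcher below shows the idea with strings in place of real kernels; dispatch, is_4d_cuda and the path functions are hypothetical.

```python
from typing import Callable, Sequence, Tuple

Shape = Tuple[int, ...]
Checker = Callable[[Shape, str], bool]
Impl = Callable[[Shape, str], str]


def dispatch(candidates: Sequence[Tuple[Checker, Impl]], fallback: Impl,
             shape: Shape, device: str) -> str:
    """Use the first specialized implementation whose checker accepts the inputs."""
    for checker, impl in candidates:
        if checker(shape, device):
            return impl(shape, device)
    return fallback(shape, device)


def is_4d_cuda(shape: Shape, device: str) -> bool:
    # Mirrors the conditions of apply_rope_to_unsloth_checker on toy inputs.
    return len(shape) == 4 and device == "cuda"


def unsloth_path(shape: Shape, device: str) -> str:
    return "unsloth kernel"


def reference_path(shape: Shape, device: str) -> str:
    return "reference apply_rope"


candidates = [(is_4d_cuda, unsloth_path)]
```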

So let us give it a try! Works great… (the remaining deviation is of the order of bfloat16 rounding error)

[24]:
thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())

expected = test_apply_rope(Q, m)
actual = thunder_apply_rope(Q, m)
print("deviation:", (expected - actual).abs().max().item())

thunder.last_traces(thunder_apply_rope)[-1]
deviation: 0.015625
[24]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, t_1_cos, t_1_sin):
  # x: "cuda:0 bf16[2, 128, 4096, 16]"
  # t_1_cos: "cuda:0 f32[4096, 16]"
  # t_1_sin: "cuda:0 f32[4096, 16]"
  (t2, (_, _)) = unsloth_apply_rope_forward(x, t_1_cos, t_1_sin)
  del x, t_1_cos, t_1_sin
  return t2

And this is also automatic when we instantiate a larger llama2-like model:

[25]:
from thunder.tests.litgpt_model import Config
torch.set_default_dtype(torch.float32)
with torch.device('cuda'):
    m = GPT(Config.from_name('llama2-like'))

for p in m.parameters():
    p.requires_grad_(False)

thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())

inp = torch.randint(1, m.config.vocab_size, (1, 128), device="cuda")
actual = thunder_model(inp)
expected = m(inp)

print("deviation:", (actual - expected).abs().max().item())
deviation: 5.960464477539062e-07

By peeking into the trace, we can see that it actually used the unsloth apply rope:

[26]:
[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\n') if 'apply_rope' in s]
[26]:
['  (q_roped, (_, _)) = unsloth_apply_rope_forward(t55, cos, sin)',
 '  (k_roped, (_, _)) = unsloth_apply_rope_forward(t57, cos, sin)',
 '  (t165, (_, _)) = unsloth_apply_rope_forward(t164, cos, sin)',
 '  (t167, (_, _)) = unsloth_apply_rope_forward(t166, cos, sin)']

But what about the backward?

Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple: we compute the forward, call get_grad to obtain the gradient of the output, compute the backward, and attach the resulting gradient to the input with put_grads.

[27]:
from thunder.core.transforms import get_grad, put_grads

def unsloth_apply_rope_grad(x: TensorProxy, cos: TensorProxy, sin: TensorProxy):
    res, (BLOCK_SIZE, num_warps) = unsloth_apply_rope_forward(x, cos, sin)
    grad_res = get_grad(res)
    grad_x = unsloth_apply_rope_backward(BLOCK_SIZE, num_warps, cos, sin, grad_res)
    put_grads((x,), (grad_x,))
    return res

my_ex.register_implementation(apply_rope, checker=apply_rope_to_unsloth_checker,
                              execution_transform=apply_rope_to_unsloth,
                              grad_transform=unsloth_apply_rope_grad
                              )


Note that the grad transform itself only runs during tracing: in the actual computation, its forward and backward parts are not executed at the same time, but in the forward and the backward pass, respectively.
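Why can the backward reuse the same kernel with sin negated (BACKWARD_PASS in _rope_embedding)? RoPE rotates each pair (x1, x2) by an angle θ: y1 = x1 cos θ - x2 sin θ and y2 = x2 cos θ + x1 sin θ. Its Jacobian is therefore a rotation matrix, and the vector-Jacobian product the grad transform needs, g1 cos θ + g2 sin θ and g2 cos θ - g1 sin θ, is the transposed rotation, i.e. a rotation by -θ. A pure-Python check of this identity follows; rope_pairs and rope_vjp are our own names.

```python
import math
from typing import List


def rope_pairs(x: List[float], angles: List[float], sign: float = 1.0) -> List[float]:
    """Rotate each pair (x[j], x[j + half]) by sign * angles[j], as the Unsloth kernel does."""
    half = len(x) // 2
    out = list(x)
    for j in range(half):
        c, s = math.cos(angles[j]), sign * math.sin(angles[j])
        out[j] = x[j] * c - x[j + half] * s
        out[j + half] = x[j + half] * c + x[j] * s
    return out


def rope_vjp(grad_out: List[float], angles: List[float]) -> List[float]:
    """Gradient w.r.t. x of sum(grad_out * rope(x)): the same rotation with sin negated."""
    return rope_pairs(grad_out, angles, sign=-1.0)
```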

And let us try our function using the optimized backward

[28]:
Q.requires_grad_()

thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())

expected = test_apply_rope(Q, m)
go = torch.ones_like(expected)
gr_expected, = torch.autograd.grad(expected, Q, go)
actual = thunder_apply_rope(Q, m)
gr_actual, = torch.autograd.grad(actual, Q, go)

print("res deviation:", (expected - actual).abs().max().item())
print("grad deviation:", (gr_expected - gr_actual).abs().max().item())
res deviation: 0.015625
grad deviation: 0.0078125

And with last_backward_traces we can check that our module is using the unsloth backward:

[29]:
thunder.last_backward_traces(thunder_apply_rope)[-1]
[29]:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, \
  _, \
  = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t4, \
  = cotangents
  clear_collection(cotangents)
  del cotangents
  t1, \
  t2, \
  = C0
  clear_collection(C0)
  del C0
  t3 = unsloth_apply_rope_backward(8, 4, t1, t2, t4)  # t3: "cuda:0 bf16[2, 128, 4096, 16]"
  del t1, t2, t4
  return (t3, None, None)

Comparing and exploring optimizations

It is also straightforward to compare potential optimizations.

Note again that our use of the Unsloth kernel might not result in the same performance as the Unsloth project sees, due to differences in the hardware used, the software environment, or the memory layout of the operands.

[30]:
def test_apply_rope_copy(x, m):
    return apply_rope_copy(x, m.cos, m.sin)

test_apply_rope_myex = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors())
test_apply_rope_nvfuser = thunder.jit(test_apply_rope_copy)
y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go)
y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go)
y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go)

print("eager")
%timeit y = test_apply_rope(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()
print("thunder + unsloth")
%timeit y = test_apply_rope_myex(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()
print("thunder default (nvfuser)")
%timeit y = test_apply_rope_nvfuser(Q, m); gr, = torch.autograd.grad(y, Q, go); torch.cuda.synchronize()

eager
3.84 ms ± 3.46 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
thunder + unsloth
6.69 ms ± 3.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
thunder default (nvfuser)
1.4 ms ± 4.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
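To read these numbers as ratios, divide the eager time by each variant's time. The values below are copied from the %timeit output of this particular run, so they will differ on other hardware and software versions; speedup_over is a hypothetical helper.

```python
from typing import Dict

# Mean times in milliseconds, copied from the %timeit output above
# (forward plus backward of apply_rope in this run).
timings_ms: Dict[str, float] = {
    "eager": 3.84,
    "thunder + unsloth": 6.69,
    "thunder default (nvfuser)": 1.4,
}


def speedup_over(baseline: str, timings: Dict[str, float]) -> Dict[str, float]:
    """Baseline time divided by each variant's time: > 1 means faster than baseline."""
    base = timings[baseline]
    return {name: round(base / t, 2) for name, t in timings.items()}


speedups = speedup_over("eager", timings_ms)
```

In this run, the default nvFuser path is about 2.7x faster than eager, while our Unsloth-based path runs at 0.57x eager speed, i.e. it takes about 1.7x as long.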

That’s it!

Conclusion

To wrap up, we hope you got a taste of

  • Getting things going with Thunder:

    • Applying Thunder through thunder.jit and

    • using FSDP by just wrapping the model in thunder.distributed.fsdp before compilation.

  • Seeing what’s going on by inspecting traces:

    • thunder.last_traces for the forward traces,

    • thunder.last_backward_traces for the backward.

  • Extending Thunder:

    • registering operators with the OperatorExecutor,

    • defining implementations with custom forward and backward to include optimized kernels.

Keep in mind that Thunder is still experimental and only expected to work with the limited set of models we have tested it with. You will likely find bugs and missing pieces. Naturally, we would love for you to help us fix these! You can find us in the Thunder section of the Lightning forums or in the #thunder channel on the PyTorch Lightning Slack.

Do check out our LitGPT studios and the other tutorial notebooks.