Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiled model raises error "attn_bias is not correctly aligned" in pytorch 2.2 #121943

Closed
flishwang opened this issue Mar 15, 2024 · 8 comments
Closed
Assignees
Labels
high priority module: unknown We do not know who is responsible for this feature, bug, or test case. oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@flishwang
Copy link

flishwang commented Mar 15, 2024

🐛 Describe the bug

When running the following code, errors may occur in pytorch 2.2.0 or 2.2.1, but not in 2.1.0.


from einops import rearrange
import torch,torch.nn as nn
def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')


def random_masking_v4(mask_kernel, percent,loss_kernel, B, H, W, device='cpu', loss_weight_factor = 1.0):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    k1, k2 = mask_kernel
    pad = (loss_kernel -1) // 2
    with torch.no_grad():
        noise1 = torch.rand(B, 1, H + k1 - 1, W + k2 - 1, device=device) * 800
        noise1 = torch.nn.functional.max_pool2d(noise1, kernel_size=(k1, k2), stride=1, padding=0, )
        noise2 = torch.rand(B, 1, H + k2 - 1, W + k1 - 1, device=device) * 800
        noise2 = torch.nn.functional.max_pool2d(noise2, kernel_size=(k2, k1), stride=1, padding=0, )

        noise = (torch.maximum(noise1, noise2)).view(B, 1, H, W)
        noise = (torch.rand(B, 1, H, W, device=device) - noise).view(B, -1)

        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove, shape:B,N
        ids_restore = torch.argsort(ids_shuffle, dim=1) # represents the order of each id
        ids_mask = ids_restore < int(H*W*percent)

        rand_center = torch.cat([ids_shuffle[:, 0:1] // W, ids_shuffle[:, 0:1] % W], 1).unsqueeze(-1)

        cy, cx = torch.meshgrid(torch.arange(H, device=device),
                                torch.arange(W, device=device), indexing='ij')
        coords = torch.stack([cy, cx]).view(1, 2, H * W)
        distance = (((coords - rand_center + torch.rand(B, 2, 1, device=device)) ** 2).sum(1)) ** 0.5  + 1
        ids_order = (distance * 3).int() * ~ids_mask + -100 * ids_mask
        can_see_p1 = ids_order[:,:,None] >= ids_order[:,None,:]
        attn_mask = can_see_p1.unsqueeze(1)

        patch_order = ids_order.view(B,1,H,W).float()
        loss_order = torch.nn.functional.unfold(patch_order,loss_kernel,dilation=1, padding=pad)

        if loss_kernel == 3:
            loss_weight = torch.as_tensor((2,1,2,1,1,1,2,1,2),dtype=torch.float32,device=device)
        elif loss_kernel == 5:
            loss_weight = torch.as_tensor(((8,5,2,5,8),(5,2,1,2,5),(2,1,1,1,2),
                                                (5,2,1,2,5),(8,5,2,5,8)), dtype=torch.float32, device=device)
        else:
            raise NotImplementedError

        loss_weight = 1.0 / loss_weight.view(1,-1,1) ** loss_weight_factor
        loss_mask = ((loss_order-1e-5) > patch_order.view(B,1,H*W)).float()

    return torch.where(attn_mask,0,-9999.0), loss_mask * loss_weight


class Attention(nn.Module):

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    #@torch.compile
    def forward(self, x,mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,attn_mask=mask,
                dropout_p=self.attn_drop.p,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn + mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class CustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = Attention(768,12,True)

    def forward(self,x):
        mask,_ = random_masking_v4((2,5),0.15,3,x.shape[0],
                                   int(x.shape[1]**0.5),int(x.shape[1]**0.5),device=x.device,
                                   )
        return self.attn(x,mask)

model = CustomModel().cuda()
model_without_ddp = model
x =torch.zeros(256,196,768).cuda()

optimizer = torch.optim.AdamW(model_without_ddp.parameters())
model = torch.compile(model)

with torch.cuda.amp.autocast():
    out = model(x)
loss = out.sum()
loss.backward()
optimizer.step()
optimizer.zero_grad(True)

with torch.cuda.amp.autocast():
    out = model(x)
loss = out.sum()
loss.backward()
optimizer.step()
optimizer.zero_grad(True)

error messages:


  File "//test.py", line 130, in <module>
    out = model(x)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "//test.py", line 116, in forward
    def forward(self,x):
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
    return compiled_fn(full_args)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 83, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 408, in forward
    fw_outs = call_func_at_runtime_with_args(
  File "/usr/local/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 864, in __call__
    return self.get_current_callable()(inputs)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 611, in run
    return model(new_inputs)
  File "/usr/local/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 892, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_root/hz/chzdqrbr5gisewqim47noe7zsijncibkeouw2ryw2no4ecybmvnj.py", line 641, in call
    buf21 = aten._scaled_dot_product_efficient_attention(reinterpret_tensor(buf18, (256, 12, 196, 64), (451584, 64, 2304, 1), 0), reinterpret_tensor(buf18, (256, 12, 196, 64), (451584, 64, 2304, 1), 768), reinterpret_tensor(buf18, (256, 12, 196, 64), (451584, 64, 2304, 1), 1536), buf20, True)
  File "/usr/local/lib/python3.10/site-packages/torch/_ops.py", line 755, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: attn_bias is not correctly aligned (strideM). attn_bias.stride(2) = 196, and should be a multiple of 8.

Versions

/usr/local/lib/python3.10/runpy.py:126: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
warn(RuntimeWarning(msg))
Collecting environment information...
PyTorch version: 2.2.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.10.13 (main, Dec 19 2023, 08:15:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-4.15.0-189-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Quadro RTX 6000
GPU 1: Quadro RTX 6000
GPU 2: Quadro RTX 6000
GPU 3: Quadro RTX 6000
GPU 4: Quadro RTX 6000
GPU 5: Quadro RTX 6000
GPU 6: Quadro RTX 6000
GPU 7: Quadro RTX 6000

Nvidia driver version: 535.54.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 64
On-line CPU(s) list: 0-63
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Gold 5218 CPU @ 2.30GHz
Stepping: 7
CPU MHz: 2651.207
CPU max MHz: 3900.0000
CPU min MHz: 1000.0000
BogoMIPS: 4600.00
Virtualization: VT-x
L1d cache: 1 MiB
L1i cache: 1 MiB
L2 cache: 32 MiB
L3 cache: 44 MiB
NUMA node0 CPU(s): 0-15,32-47
NUMA node1 CPU(s): 16-31,48-63
Vulnerability Itlb multihit: KVM: Mitigation: Split huge pages
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req pku ospke avx512_vnni md_clear flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.2.0+cu118
[pip3] torchlaunch==1.0
[pip3] torchvision==0.17.0+cu118
[pip3] triton==2.2.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire

@flishwang flishwang changed the title compiled model raises error "attn_bias is not correctly aligned " for pytorch 2.2 Compiled model raises error "attn_bias is not correctly aligned" in pytorch 2.2 Mar 15, 2024
@anijain2305
Copy link
Contributor

@flishwang I can't repro the error. Can you please try again on nightly?

@anijain2305
Copy link
Contributor

cc @drisspg if this reminds you of anything.

@anijain2305 anijain2305 added module: unknown We do not know who is responsible for this feature, bug, or test case. triage review high priority and removed high priority labels Mar 18, 2024
@flishwang
Copy link
Author

flishwang commented Mar 19, 2024

@flishwang I can't repro the error. Can you please try again on nightly?

Yes, I tried the code on nightly and no error occurs.

@drisspg
Copy link
Contributor

drisspg commented Mar 19, 2024

I fixed a similar issue here: #114837 but I think that should have landed by 2.2

@anijain2305
Copy link
Contributor

Comments from triage meeting

  • Quadro 6000 is sm_75
  • We should add a test
  • We need to reproduce.

@anijain2305 anijain2305 self-assigned this Mar 19, 2024
@jansel jansel added triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module and removed triage review labels Mar 26, 2024
@poedator
Copy link

poedator commented Apr 4, 2024

I observed same error with this repo that uses torch graphs:

git clone git@github.com:Infini-AI-Lab/Sequoia.git
cd Sequoia/test
bash run_A100.sh
...
File "/home/optimus/sequoia/tests/../Engine/Llama_modules.py", line 122, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: p.attn_bias_ptr is not correctly aligned

The error appears in torch==2.2.2; torch==2.1.2 works fine
cc: @dreaming-panda

@fxmarty
Copy link

fxmarty commented Apr 17, 2024

same issue on transformers main (huggingface/transformers@8e5f76f) and torch 2.2.2. This script used to work:

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from transformers.cache_utils import StaticCache

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)

with torch.device("cuda"):
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf",
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    )

inputs = tokenizer(
    ["I would", "Today I am in Paris and", "I am"], padding=True, return_tensors="pt"
).to(model.device)

new_tokens = 10
gen_config = GenerationConfig(
    max_new_tokens=new_tokens,
    min_new_tokens=new_tokens,
    use_cache=True,
    pad_token_id=tokenizer.pad_token_id,
    num_beams=1,
    do_sample=False,
    eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None  # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.

print("----- GENERATE WITHOUT COMPILE")
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")

decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

print("decoded", decoded)

print("compiling...")

model.forward = torch.compile(model.forward, mode="reduce-overhead")
print("Finished compile call")

print("----- GENERATE WITH COMPILE")
gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded)

gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
print("decoded", decoded)

now getting

  File "/home/felix/test_static_cache_bis.py", line 45, in <module>
    gen_out = model.generate(**inputs, generation_config=gen_config, cache_implementation="static")
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/felix/transformers/src/transformers/generation/utils.py", line 1572, in generate
    result = self._greedy_search(
  File "/home/felix/transformers/src/transformers/generation/utils.py", line 2477, in _greedy_search
    outputs = self(
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/felix/transformers/src/transformers/models/llama/modeling_llama.py", line 1143, in forward
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
    return compiled_fn(full_args)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 864, in __call__
    return self.get_current_callable()(inputs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 665, in run
    return compiled_fn(new_inputs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 380, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 408, in cudagraphify
    return manager.add_function(
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1941, in add_function
    return fn, fn(inputs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1755, in run
    out = self._run(new_inputs, function_id)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1796, in _run
    return self.run_eager(new_inputs, function_id)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 1911, in run_eager
    return node.run(new_inputs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/cudagraph_trees.py", line 611, in run
    out = self.wrapped_function.model(new_inputs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_inductor/codecache.py", line 892, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_felix/3m/c3mvkaman7clzvf6o4c4aa673bnphnkaqri2ksupf2k7soqk3jct.py", line 1636, in call
    buf14 = aten._scaled_dot_product_efficient_attention(buf12, arg292_1, arg293_1, buf13, False)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_ops.py", line 755, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/utils/_device.py", line 77, in __torch_function__
    return func(*args, **kwargs)
  File "/home/felix/miniconda3/envs/fx/lib/python3.9/site-packages/torch/_ops.py", line 755, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: attn_bias is not correctly aligned (strideH). attn_bias.stride(1) = 119, and should be a multiple of 8.

This looks to be fixed on PyTorch 2.3 RC.

@flishwang
Copy link
Author

close the issue as bug fixed in 2.3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
high priority module: unknown We do not know who is responsible for this feature, bug, or test case. oncall: pt2 triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

No branches or pull requests

7 participants