
Test failures on Ubuntu 24.04 (RuntimeError: allocating memory on queue: ABORTED) #317

Open
ScottTodd opened this issue Dec 5, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@ScottTodd
Member

Observed on #316 when trying to update test workflows from Ubuntu 22.04 to Ubuntu 24.04:

=========================== short test summary info ============================
FAILED tests/dynamo/tensor_test.py::TensorTest::test_nn_MLP - RuntimeError: Error invoking function: ABORTED; while invoking native function hal.device.queue.alloca; while calling import; 
[ 0] bytecode module.main$async:242 /home/runner/work/iree-turbine/iree-turbine/wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:535:0
FAILED tests/dynamo/tensor_test.py::TensorTest::test_nn_linear - RuntimeError: allocating memory on queue: ABORTED
FAILED tests/dynamo/tensor_test.py::TensorTest::test_unary_op - RuntimeError: allocating memory on queue: ABORTED
====== 3 failed, 290 passed, 507 skipped, 9 xfailed, 8 warnings in 34.70s ======

Environment information

Branched off at 5cde4de

Logs show the same Python version (3.11.10) and the same installed packages (`pip freeze` output) on both OS versions:

+ pip freeze
execnet==2.1.1
filelock==3.16.1
fsspec==2024.10.0
iniconfig==2.0.0
iree-base-compiler==3.1.0rc20241127
iree-base-runtime==3.1.0rc20241127
iree-turbine==2.5.0.dev0
Jinja2==3.1.4
MarkupSafe==3.0.2
ml_dtypes==0.5.0
mpmath==1.3.0
networkx==3.4.2
numpy==2.1.3
packaging==24.2
parameterized==0.9.0
pillow==11.0.0
pluggy==1.5.0
pytest==8.0.0
pytest-xdist==3.5.0
sympy==1.13.1
torch==2.5.1+cpu
torchvision==0.20.1+cpu
typing_extensions==4.12.2

Software included in 22.04: https://github.com/actions/runner-images/blob/main/images/ubuntu/Ubuntu2204-Readme.md
Software included in 24.04: https://github.com/actions/runner-images/blob/main/images/ubuntu/Ubuntu2404-Readme.md

Logs

Sample workflow runs:

Logs:

=================================== FAILURES ===================================
____________________________ TensorTest.test_nn_MLP ____________________________
[gw2] linux -- Python 3.11.10 /home/runner/work/iree-turbine/iree-turbine/wheelhouse/test.venv/bin/python

self = <tensor_test.TensorTest testMethod=test_nn_MLP>

    def test_nn_MLP(self):
        class MLP(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer0 = torch.nn.Linear(64, 32, bias=True)
                self.layer1 = torch.nn.Linear(32, 16, bias=True)
                self.layer2 = torch.nn.Linear(16, 7, bias=True)
                self.layer3 = torch.nn.Linear(7, 7, bias=True)
    
            def forward(self, x: torch.Tensor):
                x = self.layer0(x)
                x = torch.sigmoid(x)
                x = self.layer1(x)
                x = torch.sigmoid(x)
                x = self.layer2(x)
                x = torch.sigmoid(x)
                x = self.layer3(x)
                return x
    
        m = MLP()
        input = torch.randn(16, 64)
        ref_output = m(input)
        m.to("turbine")
        input = input.to("turbine")
>       turbine_output = m(input)

tests/dynamo/tensor_test.py:131: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
tests/dynamo/tensor_test.py:117: in forward
    x = self.layer0(x)
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py:125: in forward
    return F.linear(input, self.weight, self.bias)
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:87: in __torch_function__
    return super_fn(*args, **kwargs or {})
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:79: in super_fn
    return func(*args, **kwargs)
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:230: in __torch_dispatch__
    return compute_method(func, *args, **kwargs)
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:576: in compute_method
    exec_results = exec(*py_args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <iree.turbine.dynamo.executor.EagerSpecializedExecutable object at 0x7f9f204a02c0>
inputs = (<TurbineTensor(Device) of <HalBufferView (32, 64), element_type=0x20000021, 8192 bytes (at offset 0 into 8192), memor...|HOST_VISIBLE, allowed_access=ALL, allowed_usage=TRANSFER|DISPATCH_STORAGE|MAPPING|MAPPING_PERSISTENT> on local-task>,)
arg_list = <VmVariantList(3): [HalBufferView(32x64:0x20000021), fence(1), fence(1)]>
ret_list = <VmVariantList(0): []>, device = <Turbine Device: local-task>

    def __call__(self, *inputs):
        arg_list = VmVariantList(len(inputs))
        ret_list = VmVariantList(
            1
        )  # TODO: Get the number of results from the descriptor.
    
        # Initialize wait and signal fence if not async mode.
        device = inputs[0]._storage.device
        wait_fence, signal_fence = self._initialize_fences(device, inputs, arg_list)
    
        # Move inputs to the device and add to arguments.
        self._inputs_to_device(inputs, arg_list, wait_fence, signal_fence)
    
        # Invoke.
>       self.vm_context.invoke(self.entry_function, arg_list, ret_list)
E       RuntimeError: Error invoking function: ABORTED; while invoking native function hal.device.queue.alloca; while calling import; 
E       [ 0] bytecode module.main$async:242 /home/runner/work/iree-turbine/iree-turbine/wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:535:0

wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/executor.py:183: RuntimeError
__________________________ TensorTest.test_nn_linear ___________________________
[gw2] linux -- Python 3.11.10 /home/runner/work/iree-turbine/iree-turbine/wheelhouse/test.venv/bin/python

self = <tensor_test.TensorTest testMethod=test_nn_linear>

    def test_nn_linear(self):
        m = torch.nn.Linear(20, 30)
        input = torch.randn(128, 20)
        ref_output = m(input)
>       m.to("turbine")

tests/dynamo/tensor_test.py:100: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1340: in to
    return self._apply(convert)
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:927: in _apply
    param_applied = fn(param)
wheelhouse/test.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1326: in convert
    return t.to(
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:84: in __torch_function__
    return self.IMPLEMENTATIONS[func](super_fn, *args, **kwargs or {})
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:382: in to
    new_t = DeviceTensor._async_create_empty(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

size = torch.Size([30, 20]), device = <Turbine Device: local-task>
dtype = torch.float32

    @staticmethod
    def _async_create_empty(
        size: Sequence[int], device: Device, dtype: torch.dtype
    ) -> "DeviceTensor":
        """Creates an uninitialized tensor with a given size and dtype."""
        alloc_size = _calculate_c_contig_size(size, dtype)
        hal_device = device.hal_device
        # Async allocate a buffer, waiting for the device (tx_timeline, tx_timepoint)
        # and signalling tx_timepoint + 1. Because we are just creating an empty
        # (uninitialized) tensor, it is ready when allocation completes.
        tx_semaphore = device._tx_timeline
        current_tx_timepoint = device._tx_timepoint
        wait_semaphores = [(tx_semaphore, current_tx_timepoint)]
        alloca_complete_semaphore = (tx_semaphore, current_tx_timepoint + 1)
        signal_semaphores = [alloca_complete_semaphore]
        device._tx_timepoint += 1
>       buffer = hal_device.queue_alloca(alloc_size, wait_semaphores, signal_semaphores)
E       RuntimeError: allocating memory on queue: ABORTED

wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:307: RuntimeError
___________________________ TensorTest.test_unary_op ___________________________
[gw2] linux -- Python 3.11.10 /home/runner/work/iree-turbine/iree-turbine/wheelhouse/test.venv/bin/python

self = <tensor_test.TensorTest testMethod=test_unary_op>

    def test_unary_op(self):
>       t1 = -5.3 * torch.ones(2, 3).to(device="turbine")

tests/dynamo/tensor_test.py:92: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:84: in __torch_function__
    return self.IMPLEMENTATIONS[func](super_fn, *args, **kwargs or {})
wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:382: in to
    new_t = DeviceTensor._async_create_empty(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

size = torch.Size([2, 3]), device = <Turbine Device: local-task>
dtype = torch.float32

    @staticmethod
    def _async_create_empty(
        size: Sequence[int], device: Device, dtype: torch.dtype
    ) -> "DeviceTensor":
        """Creates an uninitialized tensor with a given size and dtype."""
        alloc_size = _calculate_c_contig_size(size, dtype)
        hal_device = device.hal_device
        # Async allocate a buffer, waiting for the device (tx_timeline, tx_timepoint)
        # and signalling tx_timepoint + 1. Because we are just creating an empty
        # (uninitialized) tensor, it is ready when allocation completes.
        tx_semaphore = device._tx_timeline
        current_tx_timepoint = device._tx_timepoint
        wait_semaphores = [(tx_semaphore, current_tx_timepoint)]
        alloca_complete_semaphore = (tx_semaphore, current_tx_timepoint + 1)
        signal_semaphores = [alloca_complete_semaphore]
        device._tx_timepoint += 1
>       buffer = hal_device.queue_alloca(alloc_size, wait_semaphores, signal_semaphores)
E       RuntimeError: allocating memory on queue: ABORTED

wheelhouse/test.venv/lib/python3.11/site-packages/iree/turbine/dynamo/tensor.py:307: RuntimeError
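For context on the tracebacks above: the failing call site in `_async_create_empty` follows a timeline-semaphore pattern where each allocation waits on the device's current transfer timepoint and signals the next one. A minimal Python sketch of just that bookkeeping (plain ints standing in for IREE HAL semaphores; `FakeDevice` and `plan_alloca` are hypothetical stand-ins, not the real API):

```python
class FakeDevice:
    """Stand-in for the Turbine Device: tracks only the transfer timepoint counter."""
    def __init__(self):
        self._tx_timepoint = 0

def plan_alloca(device):
    # Mirror the bookkeeping in _async_create_empty: wait on the current
    # timepoint, signal timepoint + 1, then advance the device's counter.
    wait_timepoint = device._tx_timepoint
    signal_timepoint = wait_timepoint + 1
    device._tx_timepoint += 1
    return wait_timepoint, signal_timepoint

d = FakeDevice()
print(plan_alloca(d))  # (0, 1)
print(plan_alloca(d))  # (1, 2)
```

The ABORTED error is raised by the real `hal_device.queue_alloca` itself, so the Python-side bookkeeping above is likely not where the fault lies; it just shows what the runtime is being asked to do.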
@ScottTodd ScottTodd added the bug Something isn't working label Dec 5, 2024
@ScottTodd
Member Author

Oh, also seeing these on Ubuntu 22.04 sometimes, so the tests look to be flaky. Maybe they just fail more consistently on 24.04?

https://github.com/iree-org/iree-turbine/actions/runs/12438563139/job/34730710430?pr=354#step:6:4112
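One generic way to quantify this kind of flakiness locally is to rerun a test callable several times and measure the failure fraction (a sketch only; `flake_rate` and `sometimes_aborts` are hypothetical helpers, not project code, and the deterministic stand-in only mimics the intermittent ABORTED error):

```python
def flake_rate(fn, runs=9):
    """Run fn `runs` times; return the fraction of runs raising RuntimeError."""
    failures = 0
    for _ in range(runs):
        try:
            fn()
        except RuntimeError:
            failures += 1
    return failures / runs

# Deterministic stand-in that "aborts" on every third call, standing in for
# the nondeterministic queue_alloca failure seen in CI.
state = {"calls": 0}
def sometimes_aborts():
    state["calls"] += 1
    if state["calls"] % 3 == 0:
        raise RuntimeError("allocating memory on queue: ABORTED")

print(flake_rate(sometimes_aborts, runs=9))  # 3 of 9 runs fail
```

In practice, repeatedly running the three failing tests on both runner images would show whether 24.04 genuinely fails more often or the sample size was just small.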

@ScottTodd
Member Author
